Adversarial Pose Estimation

Abstract

This repository aims to replicate the results of the paper 'Adversarial PoseNet: A Structure-aware Convolutional Network for Human Pose Estimation' (https://arxiv.org/pdf/1705.00389v2.pdf). The idea is to improve human pose estimation with a GAN-based framework, in which a (conditional) generator learns the distribution P(y|x), where x is the image and y is the heatmap for the person. Typical keypoint detectors simply apply a similarity-based loss (MSE or cross-entropy) between the predicted and ground-truth heatmaps. However, these losses tend to produce overly smooth outputs because they are averaged over the entire spatial domain. The idea here is to make the predictions "crisper and sharper" by employing discriminators that distinguish ground-truth from predicted heatmaps in two different ways.

Framework

An overview of the architecture is shown below (figure: arch).
This framework consists of a generator network and two discriminator networks. Why two? One discriminator captures whether an (x, y) pair is plausible, which is what a traditional conditional discriminator does, while the second compares the "quality" of heatmaps generated by the network against the ground-truth heatmaps. The former, the Pose Discriminator, takes the image and the heatmaps as input; the latter, the Confidence Discriminator, takes only the heatmaps as input. This makes the heatmaps sharper than what is traditionally achieved with a per-pixel loss.
(figure: disc)
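Concretely, the two discriminators differ only in the inputs they see. Below is a minimal sketch of how those inputs could be assembled, assuming 16 joints and 64x64 heatmaps (both values are illustrative, not necessarily the settings used in this repository):

```python
import torch
import torch.nn.functional as F

# Illustrative sketch only (not the repository's exact modules): how the two
# discriminators see their inputs. Batch size, joint count and resolutions
# below are assumptions for the example.
images   = torch.randn(4, 3, 256, 256)   # x: input images
heatmaps = torch.rand(4, 16, 64, 64)     # y: ground-truth or generated heatmaps

# Pose Discriminator: conditioned on the image, so the image (resized to the
# heatmap resolution) is concatenated with the heatmaps along the channel axis.
images_small    = F.interpolate(images, size=heatmaps.shape[-2:],
                                mode='bilinear', align_corners=False)
pose_disc_input = torch.cat([images_small, heatmaps], dim=1)   # (4, 3 + 16, 64, 64)

# Confidence Discriminator: judges heatmap "quality" alone, without the image.
conf_disc_input = heatmaps                                     # (4, 16, 64, 64)
```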
The architectures used for the discriminators are the same as those described in the paper. The generator is a standard stacked-hourglass network with intermediate supervision modules. The paper uses an MSE loss for learning the heatmaps. However, we noticed that training with MSE is very slow, because the maximum per-pixel difference is 1 and the ground-truth heatmaps are sparse. Hence, we use a weighted binary cross-entropy loss to balance the ratio of positives to negatives (see the sketch below). This results in much faster training and convergence. The discriminator loss is a plain binary cross-entropy loss, since real and fake pairs are given in equal proportion.
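A minimal sketch of such a weighted binary cross-entropy, assuming sigmoid-activated heatmaps in [0, 1]; the positive weight of 10 is illustrative, not necessarily the value used in this code:

```python
import torch
import torch.nn.functional as F

def weighted_bce(pred, target, pos_weight=10.0):
    # Up-weight the sparse positive (joint) pixels so they are not drowned out
    # by the overwhelming number of background pixels.
    weights = torch.where(target > 0.5,
                          pos_weight * torch.ones_like(target),
                          torch.ones_like(target))
    return F.binary_cross_entropy(pred, target, weight=weights)

# Example: predicted and ground-truth heatmaps for a 4-image batch, 16 joints.
pred   = torch.rand(4, 16, 64, 64)   # generator output after a sigmoid
target = torch.rand(4, 16, 64, 64)   # ground-truth Gaussian heatmaps
loss   = weighted_bce(pred, target)
```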

Dependencies

The list of dependencies can be found in the requirements.txt file. Simply run pip install -r requirements.txt to install them.

Running the code

WARNING: GAN training can be unstable and may also depend on your PyTorch/CUDA versions. If the default configuration doesn't work, try tuning the parameters.

Running the code for training is fairly easy. Follow these steps.

  • Edit config/default_config.py to set the hyperparameters of your choice. For convenience, default values are already provided.
  • Download the extended LSP (LSPet) dataset to a directory of your choice (a short loading sketch follows these steps). Your dataset directory should look like this (if the root dataset dir is lspet_dataset/):
  lspet_dataset/
    images/
      im00001.jpg
      im00002.jpg
      ...
    joints.mat
    README.txt
  • Set the --path parameter in the train.sh script to this dataset path. The script contains all the other parameters required to train the model.
  • Run the script.
  • The pretrained model can be found in the Downloads section of this README.
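For reference, here is a small sketch of reading the LSPet annotations and rendering Gaussian ground-truth heatmaps. The (14, 3, N) layout of joints.mat follows the LSPet release, but the heatmap size and sigma are illustrative, and this is not the repository's actual dataset code:

```python
import numpy as np
import scipy.io

# Load the LSPet annotations: joints has shape (14, 3, N) with (x, y, visibility).
joints = scipy.io.loadmat('lspet_dataset/joints.mat')['joints']
x, y, visible = joints[:, 0, 0], joints[:, 1, 0], joints[:, 2, 0]  # first image

def gaussian_heatmap(cx, cy, size=64, sigma=1.0):
    # Render a single Gaussian bump centred on (cx, cy) in heatmap coordinates.
    ys, xs = np.mgrid[0:size, 0:size]
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))

# One heatmap per joint (coordinates would first be rescaled to the 64x64 grid).
heatmaps = np.stack([gaussian_heatmap(cx, cy) for cx, cy in zip(x, y)])
```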

Results

We obtained a PCK score of 0.606893 on the validation dataset, training with the binary cross-entropy loss and a batch size of 1. This score is low; however, we trained for only about a day (since our earlier code had bugs). Here are some qualitative results:

(figures: mpii_1, mpii_2 — qualitative pose estimation results)
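The PCK metric itself is straightforward: a predicted joint counts as correct if it falls within a fraction alpha of a per-person reference length from the ground truth (torso size for PCK, head segment size for PCKh). A hedged sketch, not the code from Naman's repository:

```python
import torch

def pck(pred, gt, ref_dist, alpha=0.2):
    # pred, gt: (N, K, 2) joint coordinates; ref_dist: (N,) reference length
    # per sample (e.g. torso size for PCK, head segment size for PCKh).
    err = torch.norm(pred - gt, dim=-1)               # (N, K) pixel errors
    correct = err <= alpha * ref_dist.unsqueeze(-1)   # within alpha * reference
    return correct.float().mean().item()
```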

References

If you find this repository useful and would like to use it in your work, consider citing the original paper.

@article{DBLP:journals/corr/ChenSWLY17,
  author    = {Yu Chen and
               Chunhua Shen and
               Xiu{-}Shen Wei and
               Lingqiao Liu and
               Jian Yang},
  title     = {Adversarial PoseNet: {A} Structure-aware Convolutional Network for
               Human Pose Estimation},
  journal   = {CoRR},
  volume    = {abs/1705.00389},
  year      = {2017},
url       = {http://arxiv.org/abs/1705.00389},
  archivePrefix = {arXiv},
  eprint    = {1705.00389},
  timestamp = {Mon, 13 Aug 2018 16:47:51 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/ChenSWLY17},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

We also thank Naman's repository for providing the code for PCK and PCKh metrics.

Downloads

TODO: Put a drive link to pretrained model.
