Conditional Variational Auto-encoder

Introduction

This tutorial implements the paper Learning Structured Output Representation using Deep Conditional Generative Models [1], which introduced Conditional Variational Auto-encoders in 2015, using the Pyro probabilistic programming language (PPL).

Supervised deep learning has been successfully applied to many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when a large amount of training data is provided, the lack of probabilistic inference in current supervised deep learning methods makes it difficult to model complex structured output representations. In this work, Kihyuk Sohn, Honglak Lee, and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes and allows fast prediction using stochastic feed-forward inference. They call the model the Conditional Variational Auto-encoder (CVAE).

The CVAE is a conditional directed graphical model whose input observations modulate the prior on the Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood, and the authors formulate this variational learning objective in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE in comparison to its deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using the MNIST database.
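
To make this concrete, Pyro lets us express a CVAE as a model/guide pair whose ELBO, E_q(z|x,y)[log p(y|x,z)] - KL(q(z|x,y) || p(z|x)), is exactly the SGVB objective above. The sketch below is a minimal, hypothetical version of such a pair, not the repository's actual code: the network names and the layer sizes (chosen for the one-quadrant MNIST task described in the next section) are assumptions.

    import torch
    import torch.nn as nn
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    class CVAE(nn.Module):
        # Hypothetical sketch. x = observed quadrant(s), y = quadrants to
        # predict, both flattened and binarized to {0, 1}.
        def __init__(self, x_dim=196, y_dim=588, z_dim=20, hidden=400):
            super().__init__()
            # prior p(z|x): the input modulates the Gaussian latent prior
            self.prior_net = nn.Sequential(
                nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * z_dim))
            # generation p(y|x,z): decodes the missing quadrants
            self.generation_net = nn.Sequential(
                nn.Linear(x_dim + z_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, y_dim), nn.Sigmoid())
            # recognition q(z|x,y): used only during training
            self.recognition_net = nn.Sequential(
                nn.Linear(x_dim + y_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, 2 * z_dim))

        def model(self, x, y):
            pyro.module("cvae", self)
            with pyro.plate("data", x.shape[0]):
                loc, log_scale = self.prior_net(x).chunk(2, dim=-1)
                z = pyro.sample("z", dist.Normal(loc, log_scale.exp()).to_event(1))
                probs = self.generation_net(torch.cat([x, z], dim=-1))
                pyro.sample("y", dist.Bernoulli(probs).to_event(1), obs=y)

        def guide(self, x, y):
            pyro.module("cvae", self)
            with pyro.plate("data", x.shape[0]):
                loc, log_scale = self.recognition_net(
                    torch.cat([x, y], dim=-1)).chunk(2, dim=-1)
                pyro.sample("z", dist.Normal(loc, log_scale.exp()).to_event(1))

    cvae = CVAE()
    svi = SVI(cvae.model, cvae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
    # loss = svi.step(x_batch, y_batch)  # one SGVB step per mini-batch

At test time, predictions are drawn by sampling z from the prior network and decoding with the generation network; this is the fast stochastic feed-forward inference mentioned above.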

The problem

Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted. The image below shows the case where one quadrant is the input:

[Figure: a digit's upper-left quadrant given as input, with the remaining three quadrants as the output to predict]
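
For illustration, a hypothetical helper along these lines could produce the input/output split. It is not the repository's code; the quadrant ordering and the -1 "hidden" marker are assumptions made for the sketch.

    import torch

    def split_quadrants(img, num_inputs=1):
        # img: a 28x28 MNIST digit. Returns (x, y): x keeps `num_inputs`
        # quadrants visible, y holds the quadrants to be predicted.
        mask = torch.zeros(28, 28, dtype=torch.bool)
        mask[:14, :14] = True              # 1st input quadrant: upper-left
        if num_inputs >= 2:
            mask[14:, :14] = True          # 2nd: lower-left
        if num_inputs >= 3:
            mask[:14, 14:] = True          # 3rd: upper-right
        hidden = torch.full_like(img, -1.0)   # -1 marks "not visible"
        x = torch.where(mask, img, hidden)    # conditioning input
        y = torch.where(mask, hidden, img)    # prediction target
        return x, y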

Our objective is to learn a model that can perform probabilistic inference and make diverse predictions from a single input. We are not simply modeling a many-to-one function, as in classification tasks; we may need to model a mapping from a single input to many possible outputs. A key limitation of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five.

Evaluating the results

For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NNs can only make a single deterministic prediction, and as a result the output looks blurry and doesn’t look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.

[Figure: predictions from the baseline NN (single, blurry) versus diverse, realistic samples from the CVAE for several partial-digit inputs]

We also provide quantitative evidence by estimating the marginal conditional log-likelihoods (CLLs) in the next table; the values are negative CLLs, so lower is better.

Model                 1 quadrant   2 quadrants   3 quadrants
NN (baseline)              100.4          61.9          25.4
CVAE (Monte Carlo)          71.8          51.0          24.2
Performance gap             28.6          10.9           1.2
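
Here, "CVAE (Monte Carlo)" means log p(y|x) is estimated by drawing latent samples from the prior: log p(y|x) ≈ logsumexp_s log p(y|x, z_s) - log S, with z_s ~ p(z|x). A minimal sketch of such an estimator, reusing the hypothetical CVAE networks from the introduction:

    import math
    import torch

    @torch.no_grad()
    def estimate_cll(cvae, x, y, num_samples=100):
        # log p(y|x) ~= logsumexp_s log p(y|x, z_s) - log S,  z_s ~ p(z|x).
        # Sketch only; `cvae` is the hypothetical model defined earlier,
        # and y is assumed binarized to {0, 1}.
        log_ps = []
        for _ in range(num_samples):
            loc, log_scale = cvae.prior_net(x).chunk(2, dim=-1)
            z = torch.distributions.Normal(loc, log_scale.exp()).sample()
            probs = cvae.generation_net(torch.cat([x, z], dim=-1))
            log_ps.append(torch.distributions.Bernoulli(probs).log_prob(y).sum(-1))
        return torch.stack(log_ps).logsumexp(dim=0) - math.log(num_samples)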

We achieved results similar to those reported by the authors in the paper. We trained for only 50 epochs with an early-stopping patience of 3 epochs; to improve the results, we could let the model train for longer. Nevertheless, we observe the same effect shown in the paper: the estimated CLL of the CVAE significantly outperforms that of the baseline NN.

See the full code on GitHub.

IMPORTANT

There have been issue reports when running the code with Pyro versions other than the one pinned in requirements.txt. To make sure the code works, the recommended approach is to create a clean virtual environment (conda or virtualenv) and run pip install -r requirements.txt in that environment.

References

[1] Kihyuk Sohn, Xinchen Yan, Honglak Lee. Learning Structured Output Representation using Deep Conditional Generative Models. NeurIPS 2015.
