
# Monte Carlo Variational Auto-Encoders

This repository contains the code accompanying the paper 'Monte Carlo Variational Auto-Encoders'.

In this paper, we introduce new objectives for training a VAE. These objectives are inspired by the exactness of MCMC methods, and in particular by annealed importance sampling (AIS).
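As a reminder of the construction the objectives build on, here is the standard AIS estimator in generic notation (ours, not necessarily the paper's):

```latex
% Generic AIS construction (notation is ours, not necessarily the paper's).
% Bridge the encoder q_\phi(z|x) and the joint p_\theta(x,z) with
% annealed densities indexed by 0 = \beta_0 < \beta_1 < \dots < \beta_T = 1:
\[
  \gamma_t(z) = q_\phi(z \mid x)^{1-\beta_t}\, p_\theta(x, z)^{\beta_t},
  \qquad t = 0, \dots, T .
\]
% Sampling z_0 from q_\phi(. | x) and moving z_t ~ K_t(z_{t-1}, .) with
% MCMC kernels K_t that leave \gamma_t invariant yields the unbiased weight
\[
  w = \prod_{t=1}^{T} \frac{\gamma_t(z_{t-1})}{\gamma_{t-1}(z_{t-1})},
  \qquad \mathbb{E}[w] = p_\theta(x),
\]
% so Jensen's inequality gives a lower bound usable as a training objective:
\[
  \log p_\theta(x) \;\ge\; \mathbb{E}[\log w].
\]
```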

The posteriors obtained with our approach are very flexible, capturing complicated shapes that are out of reach for parametric approaches. Results on a toy example are shown below:

*Figure: resulting posteriors.*

On image datasets, in particular MNIST, our new objectives and the induced posterior approximations outperform other approaches by a large margin:

*Figure: likelihood comparison.*

Results on other image datasets are presented below:

*Figure: results on different datasets.*

## How to run the code

All experiments that appear in the paper can be run via the `exps.sh` script. For example:

```bash
python main.py --model L-MCVAE --dataset mnist --act_func gelu --binarize True --hidden_dim 64 --batch_size 100 --net_type conv --num_samples 1 --max_epochs 50 --step_size 0.01 --K 1 --use_transforms True --learnable_transitions False --use_cloned_decoder True
```

For each experiment, we can set:

- `model` -- which model to train: `VAE`, `IWAE`, `L-MCVAE` or `A-MCVAE`
- `dataset` -- which dataset to use (currently available: MNIST, CIFAR-10, OMNIGLOT, Fashion-MNIST and CelebA, which must be downloaded separately)
- `act_func` -- the activation function

Further arguments, with descriptions, are available in `main.py`; a sketch of such a command-line interface follows.
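For orientation only, here is a minimal, hypothetical sketch of how a CLI like the one above could be declared with `argparse`. The flag names mirror the example command, but the types, defaults and choices are assumptions; the authoritative definitions live in `main.py`.

```python
# Hypothetical sketch of the CLI described above: only a subset of flags
# is shown, and defaults/types are assumptions rather than the
# repository's exact definitions in main.py.
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train an MCVAE variant")
    parser.add_argument("--model", default="L-MCVAE",
                        choices=["VAE", "IWAE", "L-MCVAE", "A-MCVAE"],
                        help="which model to train")
    parser.add_argument("--dataset", default="mnist",
                        help="e.g. mnist, cifar, omniglot, fashionmnist, celeba")
    parser.add_argument("--act_func", default="gelu",
                        help="activation function used in the networks")
    # argparse's type=bool treats any non-empty string (even "False") as
    # True, so string flags like "--binarize True" need explicit parsing.
    parser.add_argument("--binarize", type=lambda s: s.lower() == "true",
                        default=False, help="binarize the input images")
    parser.add_argument("--hidden_dim", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--step_size", type=float, default=0.01)
    parser.add_argument("--K", type=int, default=1,
                        help="number of MCMC transitions")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
```

The explicit string-to-bool conversion is the one non-obvious choice here: passing `type=bool` would silently turn `--binarize False` into `True`.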

## Citation

The original paper can be found here. If you use MCVAE, we kindly ask you to cite:

```bibtex
@inproceedings{thin2021monte,
  title={Monte Carlo variational auto-encoders},
  author={Thin, Achille and Kotelevskii, Nikita and Doucet, Arnaud and Durmus, Alain and Moulines, Eric and Panov, Maxim},
  booktitle={International Conference on Machine Learning},
  pages={10247--10257},
  year={2021},
  organization={PMLR}
}
```