This is a minimal PyTorch implementation of the VQ-VAE model described in "Neural Discrete Representation Learning". I tried to stay as close to the official DeepMind implementation as possible while still being PyTorch-y, and the code comments refer back to the relevant sections and equations in the paper.
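For readers new to the model, the core of VQ-VAE is the quantization step: each encoder output vector is replaced by its nearest codebook entry, and two extra loss terms pull the codebook and the encoder toward each other. The sketch below illustrates that step in dependency-free Python. It is not the code in this repository. The function names, the toy codebook, and the example vector are all hypothetical, and the stop-gradient behavior (which needs an autograd framework such as PyTorch) is described only in comments.

```python
# Illustrative sketch only: nearest-codebook lookup in plain Python.
# All names and values here are hypothetical, not taken from train_vqvae.py.

def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z_e, codebook):
    """Return (index, vector) of the codebook entry closest to z_e."""
    index = min(
        range(len(codebook)),
        key=lambda k: squared_distance(z_e, codebook[k]),
    )
    return index, codebook[index]


def vq_loss_terms(z_e, e, beta=0.25):
    """Numeric values of the codebook and commitment terms for one vector.

    Both terms are based on ||z_e - e||^2. In a real implementation they
    differ in where gradients flow: the codebook term treats z_e as a
    constant (stop-gradient) and updates e, while the commitment term
    treats e as a constant and updates the encoder. beta = 0.25 is the
    commitment weight reported in the paper.
    """
    d = squared_distance(z_e, e)
    return d, beta * d


# Toy codebook with K = 3 entries of dimension D = 2 (made-up numbers).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
z_e = [0.9, 1.2]  # a single, made-up encoder output vector

index, z_q = quantize(z_e, codebook)
codebook_term, commitment_term = vq_loss_terms(z_e, z_q)
print(index, z_q, codebook_term, commitment_term)
```

In a PyTorch version, the same lookup is typically done in batch with a distance matrix and an argmin, and the decoder's gradient is copied straight through from the quantized vector to the encoder output.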
To train the model on the CIFAR-10 dataset with the hyperparameters described in the paper, run:
python3 train_vqvae.py
Training should take only a few minutes on a modern GPU (a Colab notebook can be found here). After training, the script saves the following two images:
Validation Set Samples
Reconstructions