PyTorch VQGAN



Figure 1. VQGAN Architecture

Note: This is a work in progress.

This repo's purpose is to serve as a cleaner and more feature-rich implementation of VQGAN - Taming Transformers for High-Resolution Image Synthesis - written in PyTorch from scratch, building on the initial work in dome272's repo. There's also a great video by dome272 explaining VQGAN.

I created this repo to better understand VQGAN myself and to provide scripts for faster training and experimentation on toy datasets like MNIST. I also tried to make it as clean as possible, with comments, logging, testing & coverage, custom datasets & visualizations, etc.

What is VQGAN?

VQGAN stands for Vector Quantised Generative Adversarial Network. The main idea behind the paper is to use a CNN to learn a codebook of context-rich visual parts of images, and then use a Transformer to learn the long-range/global interactions between the visual parts embedded in the codebook. Combining these two, we can generate very high-resolution images.

Learning both the short-range and long-range interactions needed to generate high-resolution images is done in two stages.

  1. The first stage uses VQGAN to learn a codebook of context-rich visual representations of the images. In terms of architecture, it is very similar to VQVAE in that it consists of an encoder, a decoder and a codebook. We will learn more about this in the next section.


Figure 2. VQVAE Architecture

  2. The second stage uses a transformer to learn the global interactions between the vectors in the codebook, predicting the next code in a sequence from the previous ones, in order to generate high-resolution images.

Stage 1


Stage 1: VQGAN Architecture

The architecture of VQGAN consists of three major parts: the encoder, the decoder and the codebook, similar to the VQVAE paper.
  1. The encoder ( encoder.py ) learns to represent images in a much lower-dimensional space called embeddings or latents. It consists of convolution, downsample and residual blocks, plus special attention blocks ( non-local blocks ), around 30 million parameters in the default settings.
  2. The embeddings are then quantized using the codebook, and the quantized embeddings are used as input to the decoder ( decoder.py ).
  3. The decoder takes the quantized embeddings and reconstructs the image. Its architecture is similar to the encoder but reversed, with around 40 million parameters in the default settings, slightly more than the encoder due to a larger number of residual blocks.

The main idea behind the codebook and quantization is to convert the continuous latent representation into a discrete one. The codebook is simply a list of n latent vectors ( learned during training ); each latent produced by the encoder is replaced with the closest codebook vector ( in terms of distance ). This is where the VQ part of the name comes from.
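
Below is a minimal sketch of this nearest-neighbour lookup with a straight-through gradient, assuming an encoder output `z` of shape (B, C, H, W) and a learned `codebook` embedding table; the names are illustrative and not the repo's exact API.

```python
import torch

def quantize(z: torch.Tensor, codebook: torch.nn.Embedding) -> torch.Tensor:
    B, C, H, W = z.shape
    z_flat = z.permute(0, 2, 3, 1).reshape(-1, C)      # (B*H*W, C)

    # Distance between every latent vector and every codebook vector.
    dists = torch.cdist(z_flat, codebook.weight)        # (B*H*W, n)
    indices = dists.argmin(dim=1)                       # index of the closest code

    z_q = codebook(indices).reshape(B, H, W, C).permute(0, 3, 1, 2)

    # Straight-through estimator: the forward pass uses z_q,
    # gradients flow back to the encoder through z.
    return z + (z_q - z).detach()
```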

Training

Training involves sending a batch of images through the encoder, quantizing the embeddings, and then sending the quantized embeddings through the decoder to reconstruct the images. The loss function is computed as follows:

$\begin{aligned} \mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})=\|x-\hat{x}\|^{2} &+\left\|\text{sg}[E(x)]-z_{\mathbf{q}}\right\|_{2}^{2}+\left\|\text{sg}\left[z_{\mathbf{q}}\right]-E(x)\right\|_{2}^{2} . \end{aligned}$

The above equation is the sum of the reconstruction loss, the alignment loss and the commitment loss.

  1. Reconstruction loss

    Apparently there is some confusion about whether the reconstruction loss was replaced by a perceptual loss or whether it is a combination of the two; we will go with what was implemented in the official code ( CompVis/taming-transformers#40 ), which is L1 + perceptual loss.

    The reconstruction loss is the sum of the L1 loss and the perceptual loss.
    $\text { L1 Loss }=\sum_{i=1}^{n}\left|y_{\text {true }}-y_{\text {predicted }}\right|$

    The perceptual loss is the L2 distance between the last-layer features of the generated and original images, extracted from a pre-trained model like VGG.

  2. The alignment and commitment losses come from the quantization step: they compare the distance between the latent vectors from the encoder output and the closest vectors from the codebook. sg here denotes the stop-gradient function.
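
Here is a rough sketch of how these three terms could be combined in PyTorch. It assumes the original image `x`, the reconstruction `x_hat`, the encoder output `z_e` and the quantized latents `z_q` are available, and stubs the perceptual term with a generic `perceptual_loss` callable; these names are illustrative, not the repo's exact API.

```python
import torch.nn.functional as F

def vq_loss(x, x_hat, z_e, z_q, perceptual_loss):
    # Reconstruction term: L1 + perceptual, as in the official code.
    rec = F.l1_loss(x_hat, x) + perceptual_loss(x_hat, x)

    # Alignment term: ||sg[E(x)] - z_q||^2, moves the codebook vectors.
    align = F.mse_loss(z_q, z_e.detach())

    # Commitment term: ||sg[z_q] - E(x)||^2, keeps the encoder output close to
    # the codes (implementations usually scale this by a small beta, e.g. 0.25).
    commit = F.mse_loss(z_e, z_q.detach())

    return rec + align + commit
```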


$\mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)=[\log D(x)+\log (1-D(\hat{x}))]$

The above loss is for the discriminator, which takes in real and generated images and learns to classify which ones are real and which are fake. The GAN in VQGAN comes from here :)

The discriminator here is a bit different from conventional discriminators: instead of producing a single score for the whole image, it convolves over the image and predicts, patch by patch, whether each patch is real or fake.
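
A hedged sketch of such a patch-based discriminator loss is shown below; `disc` stands in for the repo's discriminator and is assumed to output a grid of per-patch logits rather than a single score.

```python
import torch
import torch.nn.functional as F

def discriminator_loss(disc, x_real, x_fake):
    # disc(...) is assumed to return a (B, 1, H', W') grid of logits,
    # one logit per image patch.
    logits_real = disc(x_real)
    logits_fake = disc(x_fake.detach())

    # BCE mirrors the log D(x) + log(1 - D(x_hat)) form shown above.
    loss_real = F.binary_cross_entropy_with_logits(
        logits_real, torch.ones_like(logits_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        logits_fake, torch.zeros_like(logits_fake))
    return loss_real + loss_fake
```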


$\lambda=\frac{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{rec}}\right]}{\nabla_{G_{L}}\left[\mathcal{L}_{\mathrm{GAN}}\right]+\delta}$

We calculate lambda as the ratio between the gradients of the reconstruction loss and the GAN loss, both taken with respect to the last layer of the decoder. See calculate_lambda in vqgan.py.
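
Below is a sketch of how that ratio can be computed with autograd, in the spirit of calculate_lambda in vqgan.py ( the exact signature there may differ ).

```python
import torch

def calculate_lambda(rec_loss, gan_loss, last_layer_weight, delta: float = 1e-6):
    # Gradients of each loss w.r.t. the decoder's last layer weights.
    rec_grads = torch.autograd.grad(rec_loss, last_layer_weight, retain_graph=True)[0]
    gan_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

    lam = torch.norm(rec_grads) / (torch.norm(gan_grads) + delta)
    # The official implementation also clamps lambda for stability.
    return torch.clamp(lam, 0, 1e4).detach()
```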

The final loss then becomes -

$\begin{aligned} \mathcal{Q}^{*}=\underset{E, G, \mathcal{Z}}{\arg \min } \max _{D} \mathbb{E}_{x \sim p(x)}\left[\mathcal{L}_{\mathrm{VQ}}(E, G, \mathcal{Z})+\lambda \mathcal{L}_{\mathrm{GAN}}(\{E, G, \mathcal{Z}\}, D)\right] \end{aligned}$

which combines the VQ loss ( reconstruction, alignment and commitment ) with the GAN loss weighted by lambda.
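
Putting the pieces together, the generator side of this min-max objective could be assembled roughly as follows, reusing the hypothetical vq_loss and calculate_lambda helpers sketched above.

```python
import torch
import torch.nn.functional as F

def generator_step_loss(disc, x, x_hat, z_e, z_q, perceptual_loss, last_layer_weight):
    # VQ part: reconstruction + alignment + commitment (see vq_loss above).
    l_vq = vq_loss(x, x_hat, z_e, z_q, perceptual_loss)

    # Adversarial part: try to make the discriminator call x_hat real.
    logits_fake = disc(x_hat)
    l_gan = F.binary_cross_entropy_with_logits(
        logits_fake, torch.ones_like(logits_fake))

    # Adaptive weight from the gradient ratio (see calculate_lambda above).
    rec = F.l1_loss(x_hat, x) + perceptual_loss(x_hat, x)
    lam = calculate_lambda(rec, l_gan, last_layer_weight)

    return l_vq + lam * l_gan   # L_VQ + lambda * L_GAN
```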

Generation

To generate images with VQGAN, we obtain quantized vectors from the Stage 2 transformer and pass them through the decoder to produce the image.
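
Conceptually, that step looks something like the following ( names are illustrative, not the repo's exact API ).

```python
import torch

def decode_indices(indices: torch.Tensor, codebook: torch.nn.Embedding,
                   decoder: torch.nn.Module, h: int, w: int) -> torch.Tensor:
    # indices: (B, h * w) codebook indices sampled by the Stage 2 transformer.
    z_q = codebook(indices)                               # (B, h * w, C)
    z_q = z_q.reshape(indices.shape[0], h, w, -1).permute(0, 3, 1, 2)
    return decoder(z_q)                                   # back to pixel space
```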


Stage 2


Stage 2: Transformers

This stage trains a Transformer 🤖 to predict the next codebook index given the sequence of previous indices obtained from the quantized encoder output. The transformer implementation ( mingpt.py ) is taken from Andrej Karpathy's karpathy/minGPT repo.

Due to the computational cost of generating high-resolution images, a sliding attention window is also used, so the next latent vector is predicted only from its neighbouring vectors in the quantized encoder output.
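
A simplified sketch of the autoregressive sampling loop such a transformer performs is shown below; the hypothetical `transformer` is assumed to return logits over the codebook for every position in the sequence.

```python
import torch

@torch.no_grad()
def sample_codes(transformer, start_indices: torch.Tensor,
                 steps: int, temperature: float = 1.0) -> torch.Tensor:
    indices = start_indices                        # (B, T0) seed indices
    for _ in range(steps):
        logits = transformer(indices)              # (B, T, codebook_size)
        logits = logits[:, -1, :] / temperature    # keep only the last position
        probs = torch.softmax(logits, dim=-1)
        next_index = torch.multinomial(probs, num_samples=1)
        indices = torch.cat([indices, next_index], dim=1)
    return indices
```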

Setup

  1. Clone the repo - https://github.com/Shubhamai/pytorch-vqgan
  2. Create a new conda environment using conda env create --prefix env python=3.7.13 --file=environment.yml
  3. Activate the conda environment using conda activate ./env

Usage

Training

  • You can start training by running python train.py. It reads the default config file from configs/default.yml. To use a different config path, run python train.py --config_path configs/default.yaml.

    Here's roughly what the script does -

    • Downloads the MNIST dataset automatically and saves it in the data directory ( specified in the config ).
    • Trains the VQGAN and transformer models on the MNIST train set with the parameters passed from the config file.
    • Saves training metrics, visualizations and model checkpoints in the experiments/ directory, with the corresponding path specified in the config file.
  • Run aim up to open the experiment tracker to see the metrics and reconstructed & generated images.

Generation

To generate images, simply run python generate.py. The models will be loaded from experiments/checkpoints and the outputs will be saved in the experiments directory.

Tests

I have also just started getting my feet wet with testing and automated testing with GitHub CI/CD, so the tests here might not follow best practices.

To run tests, run pytest --cov-config=.coveragerc --cov=. test

Hardware requirements

The hardware I tried the model on with default settings is -

  • Ryzen 5 4600H
  • NVIDIA GeForce GTX 1660Ti - 6 GB VRAM
  • 12 GB RAM

It took around 2-3 minutes to get good reconstruction results. Since Google Colab has similar hardware in terms of compute power, from what I understand, it should run just fine on Colab :)

Shoutouts

This list contains some helpful blogs and videos that helped me a bunch in understanding VQGAN.

  1. The Illustrated VQGAN by Lj Miranda
  2. VQGAN: Taming Transformers for High-Resolution Image Synthesis [Paper Explained] by Gradient Dude
  3. VQ-GAN: Taming Transformers for High-Resolution Image Synthesis | Paper Explained by The AI Epiphany
  4. VQ-GAN | Paper Explanation and VQ-GAN | PyTorch Implementation by Outlier
  5. TL#006 Robin Rombach Taming Transformers for High Resolution Image Synthesis by one of the paper's authors, Robin Rombach. Thanks for the talk :)

BibTeX

@misc{esser2020taming,
      title={Taming Transformers for High-Resolution Image Synthesis}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.09841},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
