This repository has been archived by the owner on May 4, 2024. It is now read-only.

A template repository for GANs

rharish101/mnist-gan

MNIST GAN

This repository trains a conditional GAN on the MNIST dataset. The GAN is optimized using the Wasserstein loss with a gradient penalty. It uses a DCGAN-like architecture, with spectral normalization applied to the critic.
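
For reference, the critic and generator objectives in the commonly used Wasserstein GAN with gradient penalty formulation are sketched below. This is the textbook form, not something taken from this repository's code; the symbols (critic $D$, generator $G$, class label $y$, noise $z$, penalty weight $\lambda$) are assumptions chosen for illustration.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{z,\,y}\big[D(G(z, y), y)\big]
     - \mathbb{E}_{(x,\,y)}\big[D(x, y)\big]
     + \lambda\, \mathbb{E}_{\hat{x},\,y}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}, y) \rVert_2 - 1\big)^2\Big], \\
L_G &= -\,\mathbb{E}_{z,\,y}\big[D(G(z, y), y)\big], \\
\hat{x} &= \epsilon\, x + (1 - \epsilon)\, G(z, y), \qquad \epsilon \sim U[0, 1].
\end{aligned}
```

The penalty term pushes the norm of the critic's gradient towards 1 at points interpolated between real and generated images.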

Instructions

All Python scripts use argparse to parse command-line arguments. To view the list of all positional and optional arguments for any script, run:

./script.py --help
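
As a rough illustration of how argparse produces that help text, here is a minimal sketch. The parser below is hypothetical and is not copied from this repository's scripts; it only reuses the `--save-dir` and `--log-dir` argument names and defaults described later in this README.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Builds a hypothetical parser; the real scripts may define more arguments."""
    parser = argparse.ArgumentParser(description="Example script (illustrative only)")
    parser.add_argument(
        "--save-dir", default="checkpoints", help="directory for model checkpoints"
    )
    parser.add_argument(
        "--log-dir", default="logs/gan", help="parent directory for training logs"
    )
    return parser


parser = build_parser()
# Parse an explicit argument list instead of sys.argv to keep the sketch self-contained.
args = parser.parse_args(["--save-dir", "runs/exp1"])
print(args.save_dir, args.log_dir)
# format_help() returns the same text that --help prints.
print(parser.format_help())
```

argparse adds the `--help` flag automatically, so each `help=` string shows up in the listing without extra code.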

Setup

Poetry is used for conveniently installing and managing dependencies.

  1. [Optional] Create and activate a virtual environment with Python >= 3.8.

  2. Install Poetry globally (recommended), or in a virtual environment. Please refer to Poetry's installation guide for recommended installation options.

    You can use pip to install it:

    pip install poetry
  3. Install all dependencies with Poetry:

    poetry install --no-dev

    If you didn't create and activate a virtual environment in step 1, Poetry creates one for you and installs all dependencies there. To use this virtual environment, run:

    poetry shell
  4. Download the MNIST dataset using the provided script (requires cURL >= 7.19.0):

    ./download_mnist.sh [/path/where/dataset/should/be/saved/]

    By default, this dataset is saved to the directory datasets/MNIST.

For Contributing

pre-commit is used for managing hooks that run before each commit, to ensure code quality and run some basic tests. Thus, this needs to be set up only when one intends to commit changes to git.

  1. Activate the virtual environment where you installed the dependencies.

  2. Install all dependencies, including extra dependencies for development:

    poetry install
  3. Install pre-commit hooks:

    pre-commit install

NOTE: You need to be inside the virtual environment where you installed the above dependencies every time you commit. However, this is not required if you have installed pre-commit globally.

Hyper-Parameter Configuration

Hyper-parameters can be specified through TOML configs. For example, to specify a batch size of 32 for the GAN and a learning rate of 0.001 for the generator, use the following config:

gan_batch_size = 32
gen_lr = 0.001

You can store configs in a directory named configs located in the root of this repository. It has an entry in the .gitignore file so that custom configs aren't picked up by git.

The available hyper-parameters, their documentation and default values are specified in the Config class in the file gan/utils.py.
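
To make the idea of TOML overrides on top of defaults concrete, here is a minimal sketch. The `Config` dataclass below is a hypothetical stand-in, not the class from gan/utils.py: it contains only hyper-parameter names mentioned in this README, and its default values are placeholders.

```python
from dataclasses import dataclass, fields

try:
    import tomllib  # standard library from Python 3.11 onwards
except ImportError:  # assumption: the third-party tomli package on older versions
    import tomli as tomllib


@dataclass(frozen=True)
class Config:
    """Hypothetical subset of hyper-parameters; defaults are placeholders."""

    gan_batch_size: int = 64
    gen_lr: float = 1e-4
    mixed_precision: bool = False


def load_config(toml_text: str) -> Config:
    """Overrides the defaults with keys from the TOML text, rejecting unknown keys."""
    overrides = tomllib.loads(toml_text)
    known = {f.name for f in fields(Config)}
    unknown = set(overrides) - known
    if unknown:
        raise ValueError(f"Unknown hyper-parameters: {sorted(unknown)}")
    return Config(**overrides)


config = load_config("gan_batch_size = 32\ngen_lr = 0.001\n")
print(config)
```

Rejecting unknown keys is a design choice in this sketch; it catches typos in a config early rather than silently ignoring them.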

Training

The GAN uses the Fréchet Inception Distance (FID) to evaluate its performance during training. Computing it requires a trained classifier, so train the classifier before training the GAN.

  • Classifier: Run classifier.py:

    ./classifier.py
  • GAN: Run train.py after training a classifier:

    ./train.py
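
The Fréchet distance behind the evaluation metric compares the mean and covariance of classifier features for real and generated images. The NumPy sketch below shows the general formula only; it is not this repository's implementation, and the function name and the random stand-in features are assumptions for illustration.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    cov_b = np.atleast_2d(np.cov(feats_b, rowvar=False))
    # tr(sqrt(cov_a @ cov_b)) equals the sum of the square roots of the product's
    # eigenvalues, which are real and non-negative up to numerical error.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


# Random vectors stand in for classifier features here.
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
shifted = real + 0.5
print(frechet_distance(real, real), frechet_distance(real, shifted))
```

Comparing a feature set with itself should give a value near 0. Shifting every one of the 8 features by 0.5 leaves the covariance unchanged, so the distance should be close to 8 × 0.5² = 2.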

The weights of trained models are saved in TensorFlow's ckpt format to the directory given by the --save-dir argument. By default, this directory is checkpoints for both the classifier and the GAN.

By default, training logs are stored in a subdirectory named with an ISO 8601 timestamp, inside the parent directory given by the --log-dir argument. By default, this parent directory is logs/classifier for the classifier and logs/gan for the GAN.

The hyper-parameter config, along with the current date and time, is saved as a TOML file in both the model checkpoint directory and the timestamped log directory. For the classifier, it is named config-cls.toml, and for the GAN, it is named config-gan.toml.
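
The timestamped log layout described above can be sketched as follows. This is an illustration under assumptions: the helper name `make_log_dir` is hypothetical, and the exact ISO 8601 variant used by the scripts (for example, whether it includes fractional seconds) is not stated in this README.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def make_log_dir(parent: str, now: datetime) -> Path:
    """Creates and returns parent/<ISO 8601 timestamp>."""
    # Assumption: second-level precision; the real scripts may format this differently.
    log_dir = Path(parent) / now.isoformat(timespec="seconds")
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir


with tempfile.TemporaryDirectory() as tmp:
    log_dir = make_log_dir(f"{tmp}/logs/gan", datetime(2024, 1, 2, 3, 4, 5))
    print(log_dir.name)  # 2024-01-02T03:04:05
```

Note that ISO 8601 timestamps contain colons, which are not valid in directory names on Windows.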

Multi-GPU Training

This implementation supports multi-GPU training on a single machine for both the classifier and the GAN using TensorFlow's tf.distribute.MirroredStrategy.

For choosing which GPUs to train on, set the CUDA_VISIBLE_DEVICES environment variable when running a script as follows:

CUDA_VISIBLE_DEVICES=0,1,3 ./script.py

This selects the GPUs 0, 1 and 3 for training. By default, all available GPUs are chosen.

On-demand GPU Memory

By default, TensorFlow allocates all the available GPU memory on each GPU. To instruct TensorFlow to allocate GPU memory only on demand, set the TF_FORCE_GPU_ALLOW_GROWTH environment variable when running a script as follows:

TF_FORCE_GPU_ALLOW_GROWTH=true ./script.py
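
If a script is launched from Python rather than from a shell, the same environment variables (CUDA_VISIBLE_DEVICES for GPU selection and TF_FORCE_GPU_ALLOW_GROWTH for on-demand memory) can be passed through the subprocess environment. The helper `gpu_env` below is a hypothetical sketch and is not part of this repository.

```python
import os
from typing import Dict, Sequence


def gpu_env(gpus: Sequence[int], allow_growth: bool = False) -> Dict[str, str]:
    """Returns a copy of the current environment with the GPU-related variables set."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in gpus)
    if allow_growth:
        env["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
    return env


env = gpu_env([0, 1, 3], allow_growth=True)
print(env["CUDA_VISIBLE_DEVICES"], env["TF_FORCE_GPU_ALLOW_GROWTH"])
# To launch a script with this environment (left commented out in this sketch):
# import subprocess
# subprocess.run(["./train.py"], env=env, check=True)
```

Copying `os.environ` keeps variables such as PATH intact, so the child process still finds its interpreter and libraries.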

Mixed Precision Training

This implementation supports mixed-precision training. This can be enabled by setting the mixed_precision hyper-parameter in a config, as follows:

mixed_precision = true

Note that this will only provide significant speed-ups if your GPU(s) have special support for mixed-precision compute.

Generation

A generation script is provided to generate images using a trained GAN. This will generate an equal number of images for each class in the dataset.

Run generate.py:

./generate.py

The generated images are saved in the directory given by the --output-dir argument. By default, this directory is outputs. The images will be saved as JPEG images with the file name formatted as {class_num}-{instance_num}.jpg. Here, {class_num} is the index of the image's class, and {instance_num} signifies whether this is the 1st, 2nd, or nth image generated from that class.
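
The naming scheme can be illustrated with a short sketch. The helper `output_names` is hypothetical, and starting `{instance_num}` at 1 is an assumption; this README does not say whether the script counts from 0 or 1.

```python
from typing import List


def output_names(num_classes: int, per_class: int, start: int = 1) -> List[str]:
    """Lists file names of the form {class_num}-{instance_num}.jpg."""
    return [
        f"{class_num}-{instance_num}.jpg"
        for class_num in range(num_classes)
        for instance_num in range(start, start + per_class)  # assumed 1-based
    ]


names = output_names(num_classes=10, per_class=2)
print(names[:4])  # ['0-1.jpg', '0-2.jpg', '1-1.jpg', '1-2.jpg']
```

With 10 classes and 2 images per class, this yields 20 names, matching the equal number of images per class described above.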

Samples

[Images: three rows of ten generated samples, labeled sample 0 through sample 9]

Licenses

This repository uses REUSE to document licenses. Each file either has a header containing copyright and license information, or has an entry in the DEP5 file at .reuse/dep5. The license files that are used in this project can be found in the LICENSES directory.

The MIT license is placed in LICENSE to signify that it covers the majority of the codebase, and for compatibility with GitHub.