This repository trains a conditional GAN on the MNIST dataset. The GAN is optimized with the Wasserstein loss and the Wasserstein gradient penalty (WGAN-GP). A DCGAN-like architecture is used, along with spectral normalization for the critic.
All Python scripts use argparse to parse command-line arguments. To view all positional and optional arguments for any script, run:
./script.py --help
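As a sketch of how such a script might wire up argparse (the parser below is illustrative, not the repository's actual code; the --save-dir and --log-dir flags and their defaults are taken from the sections further down):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Build a CLI parser; argparse generates the --help text automatically."""
    parser = argparse.ArgumentParser(description="Train a model on MNIST.")
    parser.add_argument("--save-dir", default="checkpoints",
                        help="directory where model weights are saved")
    parser.add_argument("--log-dir", default="logs/gan",
                        help="parent directory for training logs")
    return parser

# Parsing an explicit argument list instead of sys.argv, for illustration.
args = build_parser().parse_args(["--save-dir", "out"])
print(args.save_dir)  # out
```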
Poetry is used for conveniently installing and managing dependencies.
- [Optional] Create and activate a virtual environment with Python >= 3.8.
- Install Poetry globally (recommended) or in a virtual environment. Refer to Poetry's installation guide for the recommended installation options. You can also install it with pip:
  pip install poetry
- Install all dependencies with Poetry:
  poetry install --no-dev
  If you didn't create and activate a virtual environment in step 1, Poetry creates one for you and installs all dependencies there. To use this virtual environment, run:
  poetry shell
- Download the MNIST dataset using the provided script (requires cURL >= 7.19.0):
  ./download_mnist.sh [/path/where/dataset/should/be/saved/]
  By default, the dataset is saved to the datasets/MNIST directory.
pre-commit is used to manage hooks that run before each commit, ensuring code quality and running some basic tests. This setup is needed only if you intend to commit changes to Git.
- Activate the virtual environment where you installed the dependencies.
- Install all dependencies, including the extra development dependencies:
  poetry install
- Install the pre-commit hooks:
  pre-commit install
NOTE: You must be inside the virtual environment where you installed these dependencies every time you commit. This is not required if you have installed pre-commit globally.
Hyper-parameters can be specified through TOML configs. For example, to specify a batch size of 32 for the GAN and a learning rate of 0.001 for the generator, use the following config:
gan_batch_size = 32
gen_lr = 0.001
You can store configs in a directory named configs located in the root of this repository. It has an entry in the .gitignore file so that custom configs aren't picked up by Git. The available hyper-parameters, their documentation, and their default values are specified in the Config class in gan/utils.py.
The GAN's performance during training is evaluated using the Fréchet Inception Distance (FID). This requires a trained classifier, so the classifier must be trained before the GAN.
- Classifier: run classifier.py:
  ./classifier.py
- GAN: after training a classifier, run train.py:
  ./train.py
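For reference, FID is the Fréchet distance between two Gaussians fitted to the feature statistics (mean and covariance) of real and generated images; here, the features presumably come from the trained classifier rather than an Inception network. A minimal NumPy sketch of the formula, ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrtm(C1 C2)):

```python
import numpy as np

def frechet_distance(mu1, cov1, mu2, cov2):
    """Fréchet distance between N(mu1, cov1) and N(mu2, cov2)."""
    diff = mu1 - mu2
    # For a product of PSD covariance matrices, Tr(sqrtm(cov1 @ cov2)) equals
    # the sum of the square roots of its eigenvalues (real and non-negative
    # up to numerical noise, hence the clip).
    eigvals = np.linalg.eigvals(cov1 @ cov2)
    covmean_trace = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return diff @ diff + np.trace(cov1) + np.trace(cov2) - 2.0 * covmean_trace

# Identical statistics give a distance of (numerically) zero.
mu, cov = np.zeros(4), np.eye(4)
print(frechet_distance(mu, cov, mu, cov))
```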
The weights of trained models are saved in TensorFlow's ckpt format to the directory given by the --save-dir argument. By default, this directory is checkpoints for both the classifier and the GAN.
Training logs are stored, by default, inside a subdirectory named with an ISO 8601 timestamp, under the parent directory given by the --log-dir argument. By default, this parent directory is logs/classifier for the classifier and logs/gan for the GAN.
The hyper-parameter config, along with the current date and time, is saved as a TOML file in both the model checkpoint directory and the timestamped log directory. For the classifier, it is named config-cls.toml; for the GAN, config-gan.toml.
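A sketch of how such a config file might be serialized (the dump_config helper is hypothetical and handles only the flat key = value entries this repository's configs use):

```python
import datetime

def dump_config(config: dict) -> str:
    """Render a flat config dict as TOML-style lines, with a date comment."""
    lines = [f"# Saved on {datetime.datetime.now().isoformat()}"]
    for key, value in config.items():
        if isinstance(value, bool):
            rendered = "true" if value else "false"  # TOML booleans are lowercase
        elif isinstance(value, str):
            rendered = f'"{value}"'
        else:
            rendered = str(value)
        lines.append(f"{key} = {rendered}")
    return "\n".join(lines) + "\n"

text = dump_config({"gan_batch_size": 32, "gen_lr": 0.001, "mixed_precision": True})
```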
This implementation supports multi-GPU training on a single machine for both the classifier and the GAN, using TensorFlow's tf.distribute.MirroredStrategy.
To choose which GPUs to train on, set the CUDA_VISIBLE_DEVICES environment variable when running a script, as follows:
CUDA_VISIBLE_DEVICES=0,1,3 ./script.py
This selects GPUs 0, 1, and 3 for training. By default, all available GPUs are used.
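Under tf.distribute.MirroredStrategy, each global batch is split evenly across the visible GPUs (replicas). A small sketch of that arithmetic in pure Python (no TensorFlow required; the helper names and the assumed physical GPU count are illustrative):

```python
def visible_gpu_count(env: dict) -> int:
    """Number of GPUs visible to TensorFlow, given CUDA_VISIBLE_DEVICES.
    Assumes 4 physical GPUs when the variable is unset (illustrative only)."""
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return 4  # all physical GPUs visible (assumed count)
    return len([d for d in value.split(",") if d.strip() != ""])

def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """MirroredStrategy splits each global batch evenly across replicas."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must be divisible by replica count")
    return global_batch_size // num_replicas

gpus = visible_gpu_count({"CUDA_VISIBLE_DEVICES": "0,1,3"})  # 3 GPUs selected
print(per_replica_batch_size(48, gpus))  # 16
```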
By default, TensorFlow allocates all the available memory on each GPU. To instruct TensorFlow to allocate GPU memory only on demand, set the TF_FORCE_GPU_ALLOW_GROWTH environment variable when running a script, as follows:
TF_FORCE_GPU_ALLOW_GROWTH=true ./script.py
This implementation supports mixed-precision training. Enable it by setting the mixed_precision hyper-parameter in a config, as follows:
mixed_precision = true
Note that this only provides significant speed-ups if your GPU(s) have hardware support for mixed-precision compute.
A generation script is provided to generate images using a trained GAN. This will generate an equal number of images for each class in the dataset.
Run generate.py:
./generate.py
The generated images are saved in the directory given by the --output-dir argument; by default, this is outputs. The images are saved as JPEGs with file names formatted as {class_num}-{instance_num}.jpg, where {class_num} is the index of the image's class and {instance_num} indicates whether this is the 1st, 2nd, ..., or nth image generated for that class.
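The naming scheme can be sketched as follows (the helper is hypothetical, and 1-based instance numbering is an assumption; the actual script may index differently):

```python
from pathlib import Path

def image_path(output_dir: str, class_num: int, instance_num: int) -> Path:
    """Path for a generated image: {class_num}-{instance_num}.jpg."""
    return Path(output_dir) / f"{class_num}-{instance_num}.jpg"

# 10 MNIST classes, 2 images per class -> 20 evenly distributed file names.
paths = [image_path("outputs", c, i) for c in range(10) for i in range(1, 3)]
print(paths[0])  # outputs/0-1.jpg
```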
This repository uses REUSE to document licenses. Each file either has a header containing copyright and license information, or has an entry in the DEP5 file at .reuse/dep5. The license files that are used in this project can be found in the LICENSES directory.
The MIT License is placed in LICENSE to signify that it covers the majority of the codebase, and for compatibility with GitHub.