TensorFlow-based framework to ease training of generative models

MohitBurkule/simplegan

 
 

SimpleGAN

Framework to ease training of generative models

SimpleGAN is a TensorFlow-based framework that makes training generative models easier. It provides high-level APIs with customization options, so users can train a generative model in a few lines of code or reuse modules from the existing architectures to run custom training loops and experiments.

Requirements

Make sure you have the following packages installed

Installation

Latest development release:

  $ pip install git+https://github.com/mohitburkule/simplegan.git

or, to force a reinstall into the user site-packages directory:

  $ pip install --user --force-reinstall git+https://github.com/mohitburkule/simplegan.git
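After installing, a quick sanity check can confirm that the packages are visible to the interpreter. The snippet below is a minimal sketch using only the standard library; the helper name `is_installed` is hypothetical and not part of SimpleGAN.

```python
import importlib.util


def is_installed(package_name):
    """Return True if a top-level package can be found on the import path."""
    return importlib.util.find_spec(package_name) is not None


# Report whether TensorFlow and SimpleGAN are importable in this environment.
for name in ("tensorflow", "simplegan"):
    status = "found" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```

If either package is reported as missing, re-run the install command in the same environment that runs this script.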

Getting Started

DCGAN
from simplegan.gan import DCGAN

## initialize model
gan = DCGAN()

## load train data
train_ds = gan.load_data(use_mnist = True)

## get samples from the data object
samples = gan.get_sample(train_ds, n_samples = 5)

## train the model
gan.fit(train_ds = train_ds)

## get generated samples from model
generated_samples = gan.generate_samples(n_samples = 5)

Custom training loops for GANs
import tensorflow as tf
from simplegan.gan import Pix2Pix

## initialize model
gan = Pix2Pix()

## get generator module of Pix2Pix
generator = gan.generator() ## A tf.keras model

## get discriminator module of Pix2Pix
discriminator = gan.discriminator() ## A tf.keras model

## training loop
with tf.GradientTape() as tape:
    """ Custom training loops """
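To make the placeholder above more concrete, here is a minimal sketch of what one step of a custom adversarial training loop might look like. It is not SimpleGAN's own training code: the two small dense models are hypothetical stand-ins for the `tf.keras` models returned by `gan.generator()` and `gan.discriminator()`, and the loss is a plain unconditional GAN objective rather than the Pix2Pix objective. The names `train_step`, `bce`, `gen_opt`, and `disc_opt` are assumptions made for illustration.

```python
import tensorflow as tf

# Hypothetical stand-ins for gan.generator() and gan.discriminator();
# any pair of tf.keras models with compatible shapes would do.
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(8),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),  # raw logit: real vs. fake
])

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
gen_opt = tf.keras.optimizers.Adam(1e-4)
disc_opt = tf.keras.optimizers.Adam(1e-4)


def train_step(real_batch, noise_dim=16):
    """Run one adversarial update and return (gen_loss, disc_loss)."""
    batch_size = real_batch.shape[0]
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_batch = generator(noise, training=True)
        real_logits = discriminator(real_batch, training=True)
        fake_logits = discriminator(fake_batch, training=True)
        # Generator wants its fakes classified as real (label 1).
        gen_loss = bce(tf.ones_like(fake_logits), fake_logits)
        # Discriminator wants reals labelled 1 and fakes labelled 0.
        disc_loss = (bce(tf.ones_like(real_logits), real_logits)
                     + bce(tf.zeros_like(fake_logits), fake_logits))
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return gen_loss, disc_loss


# One update on a random "real" batch of 4 samples with 8 features each.
gen_loss, disc_loss = train_step(tf.random.normal([4, 8]))
```

Using a separate tape per network keeps the generator and discriminator gradients independent, so each optimizer only updates its own model's variables.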

Convolutional Autoencoder
from simplegan.autoencoder import ConvolutionalAutoencoder

## initialize autoencoder
autoenc = ConvolutionalAutoencoder()

## load train and test data
train_ds, test_ds = autoenc.load_data(use_cifar10 = True)

## get sample from data object
train_sample = autoenc.get_sample(data = train_ds, n_samples = 5)
test_sample = autoenc.get_sample(data = test_ds, n_samples = 1)

## train the autoencoder
autoenc.fit(train_ds = train_ds, epochs = 5, optimizer = 'RMSprop', learning_rate = 0.002)

## get generated test samples from model
generated_samples = autoenc.generate_samples(test_ds = test_ds.take(1))
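For readers who want to see what a convolutional autoencoder involves under the hood, the following is a minimal sketch in plain `tf.keras`. It is not SimpleGAN's architecture: the layer sizes are arbitrary assumptions, random arrays stand in for CIFAR-10 images, and only the `RMSprop` optimizer and the `0.002` learning rate are taken from the example above.

```python
import numpy as np
import tensorflow as tf

# Illustrative architecture for 32x32 RGB inputs (the CIFAR-10 image size).
autoencoder = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    # Encoder: 32x32 -> 16x16 -> 8x8
    tf.keras.layers.Conv2D(16, 3, strides=2, padding="same", activation="relu"),
    tf.keras.layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),
    # Decoder: 8x8 -> 16x16 -> 32x32
    tf.keras.layers.Conv2DTranspose(16, 3, strides=2, padding="same", activation="relu"),
    tf.keras.layers.Conv2DTranspose(3, 3, strides=2, padding="same", activation="sigmoid"),
])
autoencoder.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.002),
    loss="mse",
)

# Random data in [0, 1) stands in for real images; the target equals the input.
images = np.random.rand(8, 32, 32, 3).astype("float32")
autoencoder.fit(images, images, epochs=1, batch_size=4, verbose=0)
reconstructed = autoencoder.predict(images, verbose=0)
```

The model is trained to reproduce its own input, so the reconstruction has the same shape as the image batch it was given.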

To have a look at more examples in detail, check here

Documentation

Check out the docs page at https://simplegan.readthedocs.io.

Provided models

Model Generated Images
Vanilla Autoencoder None
Convolutional Autoencoder
Variational Autoencoder [Paper]
Vector Quantized - Variational Autoencoder [Paper]
Vanilla GAN [Paper]
DCGAN [Paper]
WGAN [Paper]
CGAN [Paper]
InfoGAN [Paper]
Pix2Pix [Paper]
CycleGAN [Paper]
3DGAN (VoxelGAN) [Paper]
Self-Attention GAN (SAGAN) [Paper]

Contributing

We appreciate all contributions. If you plan to fix bugs or add new features or models, please file an issue and discuss it before opening a pull request.

Citation

@software{simplegan,
    author = {{Rohith Gandhi et al.}},
    title = {simplegan},
    url = {https://simplegan.readthedocs.io},
    version = {0.2.9},
}
