Add WGAN-GP and LSGAN losses, fix bugs
herbiebradley committed Jun 29, 2019
1 parent d4f349f commit 11dad4c
Showing 9 changed files with 111 additions and 57 deletions.
4 changes: 0 additions & 4 deletions README.md
@@ -8,10 +8,6 @@ Requirements:
 - Tensorflow 1.11
 - Python 3.6
 
-Note that the eval_cityscapes folder allows you to get FCN scores if you have the
-full cityscapes dataset and caffe installed; folder copied from:
-https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/tree/master/scripts/eval_cityscapes
-
 Project Organization
 ------------
 
23 changes: 14 additions & 9 deletions src/data/dataset.py
@@ -3,7 +3,11 @@
 import tensorflow as tf
 
 class Dataset(object):
-
+    """
+    Fully optimised tf.data loader.
+    For more info about the optimal way to use tf.data, see
+    https://www.reddit.com/r/MachineLearning/comments/a08fx6/d_why_tfdata_is_so_much_better_than_feed_dict_and/
+    """
     def __init__(self, opt):
         self.opt = opt
         self.gpu_id = "/gpu:" + str(self.opt.gpu_id)
@@ -67,8 +71,8 @@ def load_test_data(self):
         test_datasetA = test_datasetA.prefetch(buffer_size=self.opt.num_threads)
         test_datasetB = test_datasetB.prefetch(buffer_size=self.opt.num_threads)
         if self.opt.gpu_id != -1:
-            train_datasetA = train_datasetA.apply(tf.contrib.data.prefetch_to_device(self.gpu_id, buffer_size=1))
-            train_datasetB = train_datasetB.apply(tf.contrib.data.prefetch_to_device(self.gpu_id, buffer_size=1))
+            test_datasetA = test_datasetA.apply(tf.contrib.data.prefetch_to_device(self.gpu_id, buffer_size=1))
+            test_datasetB = test_datasetB.apply(tf.contrib.data.prefetch_to_device(self.gpu_id, buffer_size=1))
         return iter(test_datasetA), iter(test_datasetB)
 
     def load_image(self, image_file):
@@ -87,10 +91,10 @@ def load_image(self, image_file):
         return image
 
     def save_images(self, test_images, image_index):
-        image_paths = [(os.path.join(opt.results_dir, 'generatedA', 'test' + str(image_index) + '_real.jpg'),
-                        os.path.join(opt.results_dir, 'generatedA', 'test' + str(image_index) + '_fake.jpg'),
-                        os.path.join(opt.results_dir, 'generatedB', 'test' + str(image_index) + '_real.jpg'),
-                        os.path.join(opt.results_dir, 'generatedB', 'test' + str(image_index) + '_fake.jpg')]
+        image_paths = [os.path.join(self.opt.results_dir, 'generatedA', 'test' + str(image_index) + '_real.jpg'),
+                       os.path.join(self.opt.results_dir, 'generatedA', 'test' + str(image_index) + '_fake.jpg'),
+                       os.path.join(self.opt.results_dir, 'generatedB', 'test' + str(image_index) + '_real.jpg'),
+                       os.path.join(self.opt.results_dir, 'generatedB', 'test' + str(image_index) + '_fake.jpg')]
         for i in range(len(test_images)):
             # Reshape to get rid of batch size dimension in the tensor.
             image = tf.reshape(test_images[i], shape=[self.opt.img_size, self.opt.img_size, 3])
@@ -99,8 +103,9 @@ def save_images(self, test_images, image_index):
             # Convert to uint8 (range [0, 255]), saturate to avoid possible under/overflow.
             image = tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
             # JPEG encode image into string Tensor.
-            image_string = tf.image.encode_jpeg(image, format='rgb', quality=95)
-            tf.write_file(filename=image_paths[i], contents=image_string)
+            with tf.device("/cpu:0"):
+                image_string = tf.image.encode_jpeg(image, format='rgb', quality=95)
+                tf.write_file(filename=image_paths[i], contents=image_string)
 
     def get_batches_per_epoch(self, opt):
         # floor(Avg dataset size / batch_size)
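
For context, a minimal sketch of the tf.data pattern this loader follows, in TF 1.x eager mode. The glob pattern, image size, batch size, and thread count are illustrative stand-ins for the corresponding opt fields, not the repository's actual values:

import tensorflow as tf

tf.enable_eager_execution()

def load_image(image_file):
    # Decode, resize, and scale pixel values to [-1, 1].
    image = tf.image.decode_jpeg(tf.read_file(image_file), channels=3)
    image = tf.image.resize_images(image, [256, 256])
    return (tf.cast(image, tf.float32) / 127.5) - 1.0

dataset = tf.data.Dataset.list_files("data/trainA/*.jpg", shuffle=True)
dataset = dataset.map(load_image, num_parallel_calls=4)
dataset = dataset.batch(1)
dataset = dataset.prefetch(buffer_size=4)
# prefetch_to_device must stay the final transformation; skip it on CPU-only machines.
dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))

for batch in dataset.take(2):
    print(batch.shape)  # (1, 256, 256, 3)
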
4 changes: 3 additions & 1 deletion src/data/download_data.py
@@ -3,6 +3,9 @@
 
 import tensorflow as tf
 
+"""
+This script downloads the cyclegan datasets, but will overwrite if dataset is already there.
+"""
 def download_data(dataset_id, download_location):
     path_to_zip = tf.keras.utils.get_file(dataset_id + '.zip', cache_subdir=download_location,
                                           origin='https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + dataset_id + '.zip',
@@ -15,5 +18,4 @@ def download_data(dataset_id, download_location):
     parser.add_argument('--data_dir', required=True, help='download data to this directory')
     opt = parser.parse_args()
 
-    # TODO: Add code to check if dataset is already there.
     download_data(opt.dataset_id, opt.data_dir)
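
The dataset id flag is not visible in this hunk; assuming it is --dataset_id (it is read back as opt.dataset_id above), a typical invocation would look like:

python src/data/download_data.py --dataset_id horse2zebra --data_dir ./data

where horse2zebra is one of the dataset names hosted on the Berkeley server.
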
64 changes: 39 additions & 25 deletions src/models/cyclegan.py
@@ -7,7 +7,11 @@
 from utils.image_history_buffer import ImageHistoryBuffer
 
 class CycleGANModel(object):
-
+    """
+    CycleGAN model class, responsible for checkpointing and the forward and backward pass.
+    Inspired by:
+    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py
+    """
     def __init__(self, opt):
         self.opt = opt
 
@@ -41,13 +45,12 @@ def initialize_checkpoint(self):
                                                   global_step=self.global_step)
         else:
             self.checkpoint = tf.train.Checkpoint(genA2B=self.genA2B,
-                                                  genB2A=self.genB2A,
-                                                  global_step=self.global_step)
+                                                  genB2A=self.genB2A)
 
     def restore_checkpoint(self):
         checkpoint_dir = os.path.join(self.opt.save_dir, 'checkpoints')
         latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
-        if self.opt.load_checkpoint and latest_checkpoint is not None:
+        if (not self.opt.training or self.opt.load_checkpoint) and latest_checkpoint is not None:
             # Use assert_existing_objects_matched() instead of asset_consumed() here because
             # optimizers aren't initialized fully until first gradient update.
             # This will throw an exception if the checkpoint does not restore the model weights.
@@ -69,50 +72,61 @@ def forward(self):
         self.fakeA = self.genB2A(self.dataB)
         self.reconstructedB = self.genA2B(self.fakeA)
 
-    def backward_D(self, netD, real, fake):
+    def backward_D(self, netD, real, fake, tape):
         # Disc output shape: (batch_size, img_size/8, img_size/8, 1)
         pred_real = netD(real)
         pred_fake = netD(tf.stop_gradient(fake)) # Detaches generator from D
-        disc_loss = discriminator_loss(pred_real, pred_fake)
+        disc_loss = discriminator_loss(pred_real, pred_fake, self.opt.gan_mode)
+        if self.opt.gan_mode == 'wgangp': # GRADIENT PENALTY
+            with tape.stop_recording():
+                epsilon = tf.random_uniform(shape=[BATCH_SIZE, 1, 1, 1], minval=0., maxval=1.)
+                X_hat = real + epsilon * (fake - real)
+                def gp_func(X_hat):
+                    return netD(X_hat)
+                gp_grad_func = tf.contrib.eager.gradients_function(gp_func)
+                grad_critic_X_hat = gp_grad_func(X_hat)[0]
+                slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_critic_X_hat), axis=[1, 2, 3]))
+                gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
+            disc_loss += 10 * gradient_penalty # Lambda = 10 in gradient penalty
         return disc_loss
 
-    def backward_discA(self):
+    def backward_discA(self, tape):
         # Sample from history buffer of 50 images:
         fake_A = self.discA_buffer.query(self.fakeA)
-        discA_loss = self.backward_D(self.discA, self.dataA, fake_A)
-        return discA_loss
+        self.discA_loss = self.backward_D(self.discA, self.dataA, fake_A, tape)
+        return self.discA_loss
 
-    def backward_discB(self):
+    def backward_discB(self, tape):
         # Sample from history buffer of 50 images:
         fake_B = self.discB_buffer.query(self.fakeB)
-        discB_loss = self.backward_D(self.discB, self.dataB, fake_B)
-        return discB_loss
+        self.discB_loss = self.backward_D(self.discB, self.dataB, fake_B, tape)
+        return self.discB_loss
 
     def backward_G(self):
         if self.opt.identity_lambda > 0:
             identityA = self.genB2A(self.dataA)
-            id_lossA = identity_loss(self.dataA, identityA) * self.opt.cyc_lambda * self.opt.identity_lambda
+            self.id_lossA = identity_loss(self.dataA, identityA) * self.opt.cyc_lambda * self.opt.identity_lambda
 
             identityB = self.genA2B(self.dataB)
-            id_lossB = identity_loss(self.dataB, identityB) * self.opt.cyc_lambda * self.opt.identity_lambda
+            self.id_lossB = identity_loss(self.dataB, identityB) * self.opt.cyc_lambda * self.opt.identity_lambda
         else:
             id_lossA, id_lossB = 0, 0
 
-        genA2B_loss = generator_loss(self.discB(self.fakeB))
-        genB2A_loss = generator_loss(self.discA(self.fakeA))
+        self.genA2B_loss = generator_loss(self.discB(self.dataB), self.discB(self.fakeB), self.opt.gan_mode)
+        self.genB2A_loss = generator_loss(self.discA(self.dataA), self.discA(self.fakeA), self.opt.gan_mode)
 
-        cyc_lossA = cycle_loss(self.dataA, self.reconstructedA) * self.opt.cyc_lambda
-        cyc_lossB = cycle_loss(self.dataB, self.reconstructedB) * self.opt.cyc_lambda
+        self.cyc_lossA = cycle_loss(self.dataA, self.reconstructedA) * self.opt.cyc_lambda
+        self.cyc_lossB = cycle_loss(self.dataB, self.reconstructedB) * self.opt.cyc_lambda
 
-        gen_loss = genA2B_loss + genB2A_loss + cyc_lossA + cyc_lossB + id_lossA + id_lossB
+        gen_loss = self.genA2B_loss + self.genB2A_loss + self.cyc_lossA + self.cyc_lossB + self.id_lossA + self.id_lossB
         return gen_loss
 
     def optimize_parameters(self):
         for net in (self.discA, self.discB):
             for layer in net.layers:
                 layer.trainable = False
 
-        with tf.GradientTape() as genTape: # Upgrade to 1.12 for watching?
+        with tf.GradientTape() as genTape:
             genTape.watch([self.genA2B.variables, self.genB2A.variables])
 
             self.forward()
@@ -128,11 +142,11 @@ def optimize_parameters(self):
             for layer in net.layers:
                 layer.trainable = True
 
-        with tf.GradientTape(persistent=True) as discTape: # Try 2 disc tapes?
+        with tf.GradientTape(persistent=True) as discTape:
             discTape.watch([self.discA.variables, self.discB.variables])
 
-            discA_loss = self.backward_discA()
-            discB_loss = self.backward_discB()
+            self.forward()
+            discA_loss = self.backward_discA(discTape)
+            discB_loss = self.backward_discB(discTape)
 
         discA_gradients = discTape.gradient(discA_loss, self.discA.variables)
         discB_gradients = discTape.gradient(discB_loss, self.discB.variables)
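
One note on the gradient penalty block above: it references BATCH_SIZE, which is not defined anywhere in this diff, so the shape of epsilon presumably needs to come from the options. A minimal sketch of the same penalty, assuming the batch size is passed in explicitly and using a nested GradientTape in place of tf.contrib.eager.gradients_function:

import tensorflow as tf

tf.enable_eager_execution()

def gradient_penalty(netD, real, fake, batch_size, gp_lambda=10.0):
    # Interpolate uniformly at random between real and fake samples.
    epsilon = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
    x_hat = real + epsilon * (fake - real)
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(x_hat)
        pred = netD(x_hat)
    # Gradient of the critic output w.r.t. the interpolates.
    grads = gp_tape.gradient(pred, x_hat)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # Penalise deviation of the per-sample gradient norm from 1 (Gulrajani et al., 2017).
    return gp_lambda * tf.reduce_mean((slopes - 1.) ** 2)
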
38 changes: 32 additions & 6 deletions src/models/losses.py
@@ -1,12 +1,28 @@
 import tensorflow as tf
+"""
+This module defines all CycleGAN losses. Options are included for
+LSGAN, WGAN, and RGAN.
+"""
 
-def discriminator_loss(disc_of_real_output, disc_of_gen_output, label_value=1, use_lsgan=True):
-    if use_lsgan: # Use least squares loss
+def discriminator_loss(disc_of_real_output, disc_of_gen_output, gan_mode='lsgan', label_value=1):
+    if gan_mode == 'lsgan': # Use least squares loss
         real_loss = tf.reduce_mean(tf.squared_difference(disc_of_real_output, label_value))
         generated_loss = tf.reduce_mean(tf.squared_difference(disc_of_gen_output, 1-label_value))
 
         total_disc_loss = (real_loss + generated_loss) * 0.5 # * 0.5 slows down rate D learns compared to G
 
+    elif gan_mode == 'wgangp': # WGAN-GP loss
+        total_disc_loss = tf.reduce_mean(disc_of_gen_output) - tf.reduce_mean(disc_of_real_output)
+
+    elif gan_mode == 'rgan': # RGAN with vanilla GAN loss
+        real = disc_of_real_output - disc_of_gen_output
+        fake = disc_of_gen_output - disc_of_real_output
+
+        real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real), logits=real)
+        generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(fake), logits=fake)
+
+        total_disc_loss = real_loss + generated_loss
+
     else: # Use vanilla GAN loss
         real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(disc_of_real_output), logits=disc_of_real_output)
         generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(disc_of_gen_output), logits=disc_of_gen_output)
@@ -15,13 +31,23 @@ def discriminator_loss(disc_of_real_output, disc_of_gen_output, label_value=1, use_lsgan=True):
 
     return total_disc_loss
 
-def generator_loss(disc_of_gen_output, label_value=1, use_lsgan=True):
-    if use_lsgan: # Use least squares loss
+def generator_loss(disc_of_real_output, disc_of_gen_output, gan_mode='lsgan', label_value=1):
+    if gan_mode == 'lsgan': # Use least squares loss
         gen_loss = tf.reduce_mean(tf.squared_difference(disc_of_gen_output, label_value))
 
+    elif gan_mode == 'wgangp': # WGAN-GP loss
+        gen_loss = -tf.reduce_mean(disc_of_gen_output)
+
+    elif gan_mode == 'rgan': # RGAN with vanilla GAN loss
+        real = disc_of_real_output - disc_of_gen_output
+        fake = disc_of_gen_output - disc_of_real_output
+
+        real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real), logits=real)
+        generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(fake), logits=fake)
+        gen_loss = real_loss + generated_loss
+
     else: # Use vanilla GAN loss
-        gen_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(disc_generated_output), logits=disc_generated_output)
-        #l1_loss = tf.reduce_mean(tf.abs(target - gen_output)) # Look up pix2pix loss
+        gen_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(disc_of_gen_output), logits=disc_of_gen_output)
 
     return gen_loss

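
As a quick sanity check, the two loss functions can be exercised on made-up discriminator outputs. The values below are illustrative only, and 'vanilla' stands in for any gan_mode string that falls through to the else branch:

import tensorflow as tf

tf.enable_eager_execution()

# Assumes discriminator_loss and generator_loss from above are in scope.
disc_real = tf.constant([[0.9], [0.8]])  # made-up D outputs on real images
disc_fake = tf.constant([[0.2], [0.1]])  # made-up D outputs on generated images

for mode in ('lsgan', 'wgangp', 'rgan', 'vanilla'):
    d_loss = discriminator_loss(disc_real, disc_fake, gan_mode=mode)
    g_loss = generator_loss(disc_real, disc_fake, gan_mode=mode)
    print(mode, float(d_loss), float(g_loss))
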
4 changes: 4 additions & 0 deletions src/models/networks.py
@@ -1,5 +1,9 @@
 import tensorflow as tf
 
+"""
+This file defines the CycleGAN generator and discriminator.
+Options are included for extra skips, instance norm, dropout, and resize conv instead of deconv
+"""
 class Encoder(tf.keras.Model):
 
     def __init__(self, opt):
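
The "resize conv instead of deconv" option mentioned in the docstring is not shown in this hunk. A common formulation, and presumably what the option toggles, upsamples with a nearest-neighbour resize followed by a stride-1 convolution, which avoids the checkerboard artifacts of transposed convolutions; a sketch under that assumption:

import tensorflow as tf

def upsample_resize_conv(x, filters, kernel_size=3):
    # Double the spatial dims with nearest-neighbour resize, then apply a
    # stride-1 conv; an alternative to Conv2DTranspose with stride 2.
    # (In a real model the Conv2D layer would be built once in __init__,
    # not on every call.)
    new_size = [2 * int(x.shape[1]), 2 * int(x.shape[2])]
    x = tf.image.resize_images(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return tf.keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
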
6 changes: 4 additions & 2 deletions src/test.py
@@ -8,10 +8,12 @@
 from models.cyclegan import CycleGANModel
 
 tf.enable_eager_execution()
-
+"""
+Run this module for testing.
+Required args: --data_dir, --save_dir
+"""
 if __name__ == "__main__":
     opt = Options().parse(training=False)
-    # TODO: Test if this is always on CPU:
     dataset = Dataset(opt)
     model = CycleGANModel(opt)
 
19 changes: 10 additions & 9 deletions src/train.py
@@ -8,10 +8,12 @@
 from models.cyclegan import CycleGANModel
 
 tf.enable_eager_execution()
-
+"""
+Run this module for training.
+Required args: --data_dir, --save_dir, --results_dir
+"""
 if __name__ == "__main__":
     opt = Options().parse(training=True)
-    # TODO: Test if this is always on CPU:
     dataset = Dataset(opt)
     model = CycleGANModel(opt)
 
@@ -31,12 +33,12 @@
             model.set_input(dataset.data)
             model.optimize_parameters()
             # Summaries for Tensorboard:
-            #tf.contrib.summary.scalar('loss/genA2B', genA2B_loss_basic)
-            #tf.contrib.summary.scalar('loss/genB2A', genB2A_loss_basic)
-            #tf.contrib.summary.scalar('loss/discA', discA_loss)
-            #tf.contrib.summary.scalar('loss/discB', discB_loss)
-            #tf.contrib.summary.scalar('loss/cyc', cyc_loss)
-            #tf.contrib.summary.scalar('loss/identity', id_loss)
+            tf.contrib.summary.scalar('loss/genA2B', model.genA2B_loss)
+            tf.contrib.summary.scalar('loss/genB2A', model.genB2A_loss)
+            tf.contrib.summary.scalar('loss/discA', model.discA_loss)
+            tf.contrib.summary.scalar('loss/discB', model.discB_loss)
+            tf.contrib.summary.scalar('loss/cyc', model.cyc_lossA + model.cyc_lossB)
+            tf.contrib.summary.scalar('loss/identity', model.id_lossA + model.id_lossB)
             tf.contrib.summary.scalar('learning_rate', model.learning_rate)
             tf.contrib.summary.image('A/generated', model.fakeA)
             tf.contrib.summary.image('A/reconstructed', model.reconstructedA)
@@ -48,6 +50,5 @@
         if epoch % opt.save_epoch_freq == 0:
             model.save_model()
         print("Global Training Step: ", global_step.numpy() // 3)
-        # TODO: Better progress prints (epoch bar filling up?)
         print("Time taken for total epoch {} is {} sec\n".format(global_step.numpy() \
                                                                  // (3 * batches_per_epoch), time.time()-start))
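
Note that tf.contrib.summary.* calls only record when executed inside a summary writer's context. A minimal sketch of the scaffolding train.py presumably wraps around the loop; the log path and the use of always_record_summaries are illustrative choices:

import tensorflow as tf

tf.enable_eager_execution()

global_step = tf.train.get_or_create_global_step()
writer = tf.contrib.summary.create_file_writer("./logs")

with writer.as_default(), tf.contrib.summary.always_record_summaries():
    # ... inside the training loop, summaries are recorded against global_step:
    tf.contrib.summary.scalar('loss/genA2B', tf.constant(0.5))
    global_step.assign_add(1)
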
6 changes: 5 additions & 1 deletion src/utils/options.py
@@ -2,7 +2,11 @@
 import multiprocessing
 
 class Options(object):
-
+    """
+    Options class - defines all train/test options and prints them out as a summary.
+    Inspired by
+    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/options/base_options.py
+    """
     def __init__(self):
         parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
         # basic options
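
A minimal sketch of the pattern this class follows. All flags here are placeholders except --gan_mode, whose values are taken from the losses above; the real class defines many more options:

import argparse
import multiprocessing

class Options(object):
    def __init__(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--data_dir', required=True, help='path to the dataset')
        parser.add_argument('--gan_mode', default='lsgan',
                            choices=['lsgan', 'wgangp', 'rgan', 'vanilla'],
                            help='GAN loss variant')
        parser.add_argument('--num_threads', type=int, default=multiprocessing.cpu_count(),
                            help='threads for data loading')
        self.parser = parser

    def parse(self, training=True):
        opt = self.parser.parse_args()
        opt.training = training
        # Print all options as a summary:
        for name, value in sorted(vars(opt).items()):
            print('{}: {}'.format(name, value))
        return opt
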
