Add training methods in cyclegan model class. Freeze D when training G, stop gradients propagating from G when training D, and fix the image history buffer.
herbiebradley committed Feb 4, 2019
1 parent bb68513 commit 35188e5
Showing 5 changed files with 109 additions and 61 deletions.
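The commit message above describes two training-loop details: the discriminator layers are frozen while the generator is updated, and generated images are wrapped in tf.stop_gradient() before the discriminator is updated, so neither step touches the other network. Below is a minimal, self-contained sketch of that pattern (not the repository's code); it assumes TF 1.x with eager execution and uses toy one-layer Conv2D networks as stand-ins for the real Generator and Discriminator classes.

import tensorflow as tf

tf.enable_eager_execution()

# Toy stand-ins for the real Generator/Discriminator (hypothetical, illustration only).
gen = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding='same')])
disc = tf.keras.Sequential([tf.keras.layers.Conv2D(1, 3, padding='same')])
gen_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5)
disc_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5)

real = tf.random_normal([1, 64, 64, 3])   # dummy input batch
_ = disc(gen(real))                       # build both networks so their variables exist

# --- Generator step: freeze D, then update only G's variables ---
for layer in disc.layers:
    layer.trainable = False
with tf.GradientTape() as gen_tape:
    fake = gen(real)
    gen_loss = tf.reduce_mean(tf.squared_difference(disc(fake), 1.0))  # LSGAN: G wants D(G(x)) -> 1
gen_grads = gen_tape.gradient(gen_loss, gen.variables)
gen_opt.apply_gradients(zip(gen_grads, gen.variables))

# --- Discriminator step: unfreeze D, detach fakes from G with stop_gradient ---
for layer in disc.layers:
    layer.trainable = True
with tf.GradientTape() as disc_tape:
    pred_real = disc(real)
    pred_fake = disc(tf.stop_gradient(fake))   # no gradients flow back into the generator
    disc_loss = 0.5 * (tf.reduce_mean(tf.squared_difference(pred_real, 1.0))
                       + tf.reduce_mean(tf.square(pred_fake)))
disc_grads = disc_tape.gradient(disc_loss, disc.variables)
disc_opt.apply_gradients(zip(disc_grads, disc.variables))

In the commit itself, gradients are in any case taken only with respect to the generator variables, so the freeze acts mainly as a safeguard against accidentally updating the discriminator during the generator step.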
src/models/cyclegan.py (122 changes: 86 additions & 36 deletions)
@@ -18,43 +18,36 @@ def __init__(self, initial_learning_rate, num_gen_filters, num_disc_filters,
img_size, training):
self.isTrain = training
self.checkpoint_dir = checkpoint_dir
self.initial_learning_rate = initial_learning_rate
self.cyc_lambda = cyc_lambda
self.identity_lambda = identity_lambda

self.genA2B = Generator(num_gen_filters, img_size=img_size)
self.genB2A = Generator(num_gen_filters, img_size=img_size)

if self.isTrain:
self.discA = Discriminator(num_disc_filters)
self.discB = Discriminator(num_disc_filters)
self.learning_rate = tf.contrib.eager.Variable(initial_learning_rate, dtype=tf.float32, name='learning_rate')
self.discA_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
self.discB_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
self.genA2B_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
self.genB2A_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
self.learning_rate = tf.contrib.eager.Variable(initial_learning_rate,
dtype=tf.float32, name='learning_rate')
self.disc_opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
self.gen_opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
self.global_step = tf.train.get_or_create_global_step()
# Initialize history buffers:
self.discA_buffer = ImageHistoryBuffer(50, batch_size, img_size // 8) # / 8 for PatchGAN
self.discB_buffer = ImageHistoryBuffer(50, batch_size, img_size // 8)
self.discA_buffer = ImageHistoryBuffer(50, batch_size, img_size)
self.discB_buffer = ImageHistoryBuffer(50, batch_size, img_size)
# Restore latest checkpoint:
self.initialize_checkpoint()
self.restore_checkpoint()


def forward(self):
raise NotImplementedError

def optimize_parameters(self):
raise NotImplementedError

def initialize_checkpoint(self):
if self.isTrain:
self.checkpoint = tf.train.Checkpoint(discA=self.discA,
discB=self.discB,
genA2B=self.genA2B,
genB2A=self.genB2A,
discA_opt=self.discA_opt,
discB_opt=self.discB_opt,
genA2B_opt=self.genA2B_opt,
genB2A_opt=self.genB2A_opt,
disc_opt=self.disc_opt,
gen_opt=self.gen_opt,
learning_rate=self.learning_rate,
global_step=self.global_step)
else:
@@ -76,34 +69,91 @@ def restore_checkpoint(self):
print("No checkpoint found, initializing model.")

def load_batch(self, input_batch):
self.realA = input_batch[0].get_next()
self.realB = input_batch[1].get_next()
self.dataA = input_batch[0].get_next()
self.dataB = input_batch[1].get_next()

def forward(self):
self.fakeB = self.genA2B(self.realA)
# Gen output shape: (batch_size, img_size, img_size, 3)
self.fakeB = self.genA2B(self.dataA)
self.reconstructedA = self.genB2A(self.fakeB)

self.fakeA = self.genB2A(self.realB)
self.fakeA = self.genB2A(self.dataB)
self.reconstructedB = self.genA2B(self.fakeA)

def backward_D_basic(self, disc, real, fake):
# Real
pred_real = disc(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = disc(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
# backward
loss_D.backward()
return loss_D
def backward_D(self, netD, real, fake):
# Disc output shape: (batch_size, img_size/8, img_size/8, 1)
pred_real = netD(real)
pred_fake = netD(tf.stop_gradient(fake)) # Detaches generator from D
disc_loss = discriminator_loss(pred_real, pred_fake)
return disc_loss

def backward_discA(self):
fake_A = self.discA_buffer.query(self.fakeA)
discA_loss = self.backward_D(self.discA, self.dataA, fake_A)
return discA_loss

def backward_discB(self):
fake_B = self.discB_buffer.query(self.fakeB)
discB_loss = self.backward_D(self.discB, self.dataB, fake_B)
return discB_loss

def backward_G(self):
if self.identity_lambda > 0:
identityA = self.genB2A(self.dataA)
id_lossA = identity_loss(self.dataA, identityA) * self.cyc_lambda * self.identity_lambda

identityB = self.genA2B(self.dataB)
id_lossB = identity_loss(self.dataB, identityB) * self.cyc_lambda * self.identity_lambda
else:
id_lossA, id_lossB = 0, 0

genA2B_loss = generator_loss(self.discB(self.fakeB))
genB2A_loss = generator_loss(self.discA(self.fakeA))

cyc_lossA = cycle_consistency_loss(self.dataA, self.reconstructedA) * self.cyc_lambda
cyc_lossB = cycle_consistency_loss(self.dataB, self.reconstructedB) * self.cyc_lambda

gen_loss = genA2B_loss + genB2A_loss + cyc_lossA + cyc_lossB + id_lossA + id_lossB
return gen_loss

def optimize_parameters(self):
raise NotImplementedError
for net in (self.discA, self.discB):
for layer in net.layers:
layer.trainable = False

with tf.GradientTape() as genTape: # Upgrade to 1.12 for watching?
genTape.watch([self.genA2B.variables, self.genB2A.variables])

self.forward()
gen_loss = self.backward_G()

gen_variables = [self.genA2B.variables, self.genB2A.variables]
gen_gradients = genTape.gradient(gen_loss, gen_variables)
self.gen_opt.apply_gradients(list(zip(gen_gradients[0], gen_variables[0])) \
+ list(zip(gen_gradients[1], gen_variables[1])),
global_step=self.global_step)

for net in (self.discA, self.discB):
for layer in net.layers:
layer.trainable = True

with tf.GradientTape(persistent=True) as discTape: # Try 2 disc tapes?
discTape.watch([self.discA.variables, self.discB.variables])

discA_loss = self.backward_discA()
discB_loss = self.backward_discB()

discA_gradients = discTape.gradient(discA_loss, self.discA.variables)
discB_gradients = discTape.gradient(discB_loss, self.discB.variables)
self.disc_opt.apply_gradients(zip(discA_gradients, self.discA.variables),
global_step=self.global_step)
self.disc_opt.apply_gradients(zip(discB_gradients, self.discB.variables),
global_step=self.global_step)

def save_model(self):
raise NotImplementedError
checkpoint_prefix = os.path.join(self.checkpoint_dir, 'ckpt')
checkpoint_path = self.checkpoint.save(file_prefix=checkpoint_prefix)
print("Checkpoint saved at ", checkpoint_path)

def update_learning_rate(self):
raise NotImplementedError
src/models/losses.py (16 changes: 7 additions & 9 deletions)
@@ -4,11 +4,10 @@

import tensorflow as tf

def discriminator_loss(disc_of_real_output, disc_of_gen_output, use_lsgan=True):
label_value = 1 # TODO: Implement proper label for smoothing
def discriminator_loss(disc_of_real_output, disc_of_gen_output, label_value=1, use_lsgan=True):
if use_lsgan: # Use least squares loss
real_loss = tf.reduce_mean(tf.squared_difference(disc_of_real_output, label_value))
generated_loss = tf.reduce_mean(tf.square(disc_of_gen_output))
generated_loss = tf.reduce_mean(tf.squared_difference(disc_of_gen_output, 1-label_value))

total_disc_loss = (real_loss + generated_loss) * 0.5 # * 0.5 slows down rate D learns compared to G

@@ -20,8 +19,7 @@ def discriminator_loss(disc_of_real_output, disc_of_gen_output, use_lsgan=True):

return total_disc_loss

def generator_loss(disc_of_gen_output, use_lsgan=True):
label_value = 1
def generator_loss(disc_of_gen_output, label_value=1, use_lsgan=True):
if use_lsgan: # Use least squares loss
gen_loss = tf.reduce_mean(tf.squared_difference(disc_of_gen_output, label_value))

@@ -31,16 +29,16 @@ def generator_loss(disc_of_gen_output, use_lsgan=True):

return gen_loss

def cycle_consistency_loss(dataA, reconstructedA, norm='l1'):
def cycle_consistency_loss(data, reconstructed, norm='l1'):
if norm == 'l1':
loss = tf.reduce_mean(tf.abs(reconstructedA - dataA))
loss = tf.reduce_mean(tf.abs(reconstructed - data))
return loss
else:
raise NotImplementedError #TODO: l2 norm option

def identity_loss(trainA, trainB, identityA, identityB, norm='l1'):
def identity_loss(data, identity, norm='l1'):
if norm == 'l1':
loss = tf.reduce_mean(tf.abs(identityA - trainA)) + tf.reduce_mean(tf.abs(identityB - trainB))
loss = tf.reduce_mean(tf.abs(identity - data))
return loss
else:
raise NotImplementedError #TODO: l2 norm option
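A small usage sketch (not part of the commit) of the refactored loss helpers, assuming they are imported as in src/train.py; the dummy 32x32x1 tensors are hypothetical stand-ins for PatchGAN discriminator outputs. Passing label_value=0.9 turns the least-squares targets into 0.9 for real and 0.1 for generated patches, i.e. label smoothing on both terms.

import tensorflow as tf
from models.losses import discriminator_loss, generator_loss

tf.enable_eager_execution()

# Dummy PatchGAN-style outputs (hypothetical shapes, for illustration only).
pred_real = tf.random_uniform([1, 32, 32, 1])   # D(real)
pred_fake = tf.random_uniform([1, 32, 32, 1])   # D(G(x))

d_loss = discriminator_loss(pred_real, pred_fake, label_value=0.9)  # smoothed targets 0.9 / 0.1
g_loss = generator_loss(pred_fake, label_value=1)                   # generator wants D(G(x)) -> 1
print(d_loss.numpy(), g_loss.numpy())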
src/pipeline/data.py (4 changes: 2 additions & 2 deletions)
@@ -63,7 +63,7 @@ def load_train_data(dataset_id, project_dir, batch_size=1):
# Queue up batches asynchronously onto the GPU.
# As long as there is a pool of batches CPU side a GPU prefetch of 1 is fine.
# TODO: If GPU exists:
train_datasetA = train_datasetA.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))
#train_datasetA = train_datasetA.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))

train_datasetB = tf.data.Dataset.list_files(trainB_path + os.sep + '*.jpg', shuffle=False)
train_datasetB = train_datasetB.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=trainB_size))
@@ -72,7 +72,7 @@ def load_train_data(dataset_id, project_dir, batch_size=1):
num_parallel_calls=threads,
drop_remainder=True))
train_datasetB = train_datasetB.prefetch(buffer_size=threads)
train_datasetB = train_datasetB.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))
#train_datasetB = train_datasetB.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))

return train_datasetA, train_datasetB

src/train.py (22 changes: 13 additions & 9 deletions)
@@ -11,6 +11,7 @@
from pipeline.data import load_train_data
from models.losses import generator_loss, discriminator_loss, cycle_consistency_loss, identity_loss
from models.networks import Generator, Discriminator
from models.cyclegan import CycleGANModel
from utils.image_history_buffer import ImageHistoryBuffer

tf.enable_eager_execution()
@@ -29,10 +30,9 @@
save_epoch_freq = 5
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)

def train(data, model, checkpoint_info, epochs):
def train(data, model, options):
# Create a tf.data.Iterator from the Datasets:
data = iter(data[0]), iter(data[1])
model.load_batch(data)
# Initialize Tensorboard summary writer:
log_dir = os.path.join(project_dir, 'saved_models', 'tensorboard')
summary_writer = tf.contrib.summary.create_file_writer(log_dir, flush_millis=10000)
@@ -69,8 +69,8 @@ def train(data, model, checkpoint_info, epochs):
identityB = genA2B(trainB)
id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)

genA2B_loss_basic = generator_loss(discB_fake_refined)
genB2A_loss_basic = generator_loss(discA_fake_refined)
genA2B_loss_basic = generator_loss(discB_fake)
genB2A_loss_basic = generator_loss(discA_fake)
cyc_lossA = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
cyc_lossB = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)

@@ -110,12 +110,16 @@ def train(data, model, checkpoint_info, epochs):
print("Global Training Step: ", global_step.numpy() // 4)
print ("Time taken for total epoch {} is {} sec\n".format(global_step.numpy() // (4 * batches_per_epoch),
time.time()-start))

if __name__ == "__main__":
checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
data = load_train_data(dataset_id, project_dir)
with tf.device("/gpu:0"):
model = define_model(initial_learning_rate, training=True)
checkpoint_info = initialize_checkpoint(checkpoint_dir, model, training=True)
train(data, model, checkpoint_info, epochs=epochs)
model = CycleGANModel(initial_learning_rate, num_gen_filters,
num_disc_filters, batch_size, cyc_lambda,
identity_lambda, checkpoint_dir, img_size,
training=True)
#with tf.device("/gpu:0"):
data = iter(data[0]), iter(data[1])
model.load_batch(data)
model.optimize_parameters()
src/utils/image_history_buffer.py (6 changes: 1 addition & 5 deletions)
@@ -20,12 +20,11 @@ class ImageHistoryBuffer(object):
Attributes:
image_history_buffer: Numpy array of image batches used to calculate average loss.
"""
def __init__(self, max_buffer_size, batch_size, img_size):
self.max_buffer_size = max_buffer_size
self.batch_size = batch_size
self.image_history_buffer = np.zeros((0, img_size, img_size,1))
self.image_history_buffer = np.zeros((0, img_size, img_size, 3))
assert(self.batch_size >= 1)

def query(self, image_batch):
@@ -41,9 +40,7 @@ def query(self, image_batch):
Returns:
Tensor: Processed batch.
"""

image_batch = image_batch.numpy()
self._add_to_image_history_buffer(image_batch)
if self.batch_size > 1:
@@ -60,7 +57,6 @@ def _add_to_image_history_buffer(self, image_batch):
Args:
image_batch (ndarray): Incoming image batch.
"""
images_to_add = max(1, self.batch_size // 2)

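The buffer change above is the "fixes image history buffer" part of the commit: the pool now holds full-size, 3-channel generator outputs instead of single-channel PatchGAN-sized maps. A minimal usage sketch (not repository code), assuming the import path used in src/train.py; the img_size and batch_size values are hypothetical.

import tensorflow as tf
from utils.image_history_buffer import ImageHistoryBuffer

tf.enable_eager_execution()

img_size, batch_size = 256, 1   # hypothetical values, for illustration only
buffer = ImageHistoryBuffer(max_buffer_size=50, batch_size=batch_size, img_size=img_size)

fake_batch = tf.random_normal([batch_size, img_size, img_size, 3])  # generator output shape
history_batch = buffer.query(fake_batch)  # per the docstring, a processed batch drawing on past fakes
# history_batch is what backward_discA/backward_discB then feed to the discriminator.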
