Training fix
herbiebradley committed Dec 11, 2018
1 parent 8332027 commit 5ebacc7
Showing 3 changed files with 30 additions and 30 deletions.
src/models/losses.py (4 changes: 2 additions & 2 deletions)
@@ -31,9 +31,9 @@ def generator_loss(disc_of_gen_output, use_lsgan=True):

    return gen_loss

-def cycle_consistency_loss(dataA, dataB, reconstructedA, reconstructedB, norm='l1'):
+def cycle_consistency_loss(dataA, reconstructedA, norm='l1'):
    if norm == 'l1':
-        loss = tf.reduce_mean(tf.abs(reconstructedA - dataA)) + tf.reduce_mean(tf.abs(reconstructedB - dataB))
+        loss = tf.reduce_mean(tf.abs(reconstructedA - dataA))
        return loss
    else:
        raise NotImplementedError #TODO: l2 norm option
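Note: the loss is now computed for a single cycle direction, and the caller in src/train.py sums the two directions (see the cyc_loss_A / cyc_loss_B lines further down). A minimal sketch of how the new signature is used, with a hypothetical l2 branch for the TODO; the l2 option and the standalone snippet are assumptions, not part of this commit:

import tensorflow as tf

def cycle_consistency_loss(dataA, reconstructedA, norm='l1'):
    # Mean reconstruction error for one cycle direction (A -> B -> A or B -> A -> B).
    if norm == 'l1':
        return tf.reduce_mean(tf.abs(reconstructedA - dataA))
    elif norm == 'l2':
        # Hypothetical l2 variant suggested by the TODO above; not in the committed code.
        return tf.reduce_mean(tf.square(reconstructedA - dataA))
    else:
        raise NotImplementedError

# As used in src/train.py after this commit:
# cyc_loss_A = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
# cyc_loss_B = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)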
src/test.py (9 changes: 4 additions & 5 deletions)
@@ -7,21 +7,20 @@

import tensorflow as tf

-from train import define_checkpoint, define_model, restore_from_checkpoint
+from train import initialize_checkpoint, define_model, restore_from_checkpoint
from pipeline.data import load_test_data, save_images

project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
dataset_id = 'facades'
-initial_learning_rate = 0.0002

def test(data, model, checkpoint_info, dataset_id):
    path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
    generatedA = os.path.join(path_to_dataset, 'generatedA' + os.sep)
    generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
    genA2B = model['genA2B']
    genB2A = model['genB2A']
    return None

    checkpoint, checkpoint_dir = checkpoint_info
    restore_from_checkpoint(checkpoint, checkpoint_dir)
    test_datasetA, test_datasetB, testA_size, testB_size = data
@@ -57,7 +56,7 @@ def test(data, model, checkpoint_info, dataset_id):
if __name__ == "__main__":
    with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
        data = load_test_data(dataset_id, project_dir)
-    #with tf.device("/gpu:0"):
-    model = define_model(initial_learning_rate, training=False)
+    with tf.device("/gpu:0"):
+        model = define_model(training=False)
+    checkpoint_info = initialize_checkpoint(checkpoint_dir, model, training=False)
    test(data, model, checkpoint_info, dataset_id)
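The import rename above tracks a rename in src/train.py: define_checkpoint is now initialize_checkpoint. Its body is not part of this diff; below is a rough sketch of the calling pattern test.py relies on, assuming the helpers wrap tf.train.Checkpoint and tf.train.latest_checkpoint (both function bodies are assumptions based only on the call sites visible here):

import tensorflow as tf

def initialize_checkpoint(checkpoint_dir, model, training=False):
    # Assumed: bundle the generators (and, when training, the discriminators
    # and optimizers) into a tf.train.Checkpoint object.
    checkpoint = tf.train.Checkpoint(genA2B=model['genA2B'], genB2A=model['genB2A'])
    return checkpoint, checkpoint_dir

def restore_from_checkpoint(checkpoint, checkpoint_dir):
    # Assumed: restore the most recent checkpoint in checkpoint_dir, if any.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is not None:
        checkpoint.restore(latest)

In test.py these are consumed exactly as shown in the hunk above: checkpoint, checkpoint_dir = checkpoint_info, then restore_from_checkpoint(checkpoint, checkpoint_dir).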
src/train.py (47 changes: 24 additions & 23 deletions)
@@ -20,15 +20,13 @@
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
dataset_id = 'facades'
initial_learning_rate = 0.0002
-num_gen_filters = 64
+num_gen_filters = 32
num_disc_filters = 64
batch_size = 1 # Set batch size to 4 or 16 if training multigpu
img_size = 256
cyc_lambda = 10
-identity_lambda = 0
-if dataset_id == 'monet2photo':
-    identity_lambda = 0.5
-epochs = 50
+identity_lambda = 0.5
+epochs = 15
save_epoch_freq = 5
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)

@@ -132,7 +130,9 @@ def train(data, model, checkpoint_info, epochs):
            with tf.GradientTape(persistent=True) as tape:
                # Gen output shape: (batch_size, img_size, img_size, 3)
                genA2B_output = genA2B(trainA)
+                reconstructedA = genB2A(genA2B_output)
                genB2A_output = genB2A(trainB)
+                reconstructedB = genA2B(genB2A_output)
                # Disc output shape: (batch_size, img_size/8, img_size/8, 1)
                discA_real = discA(trainA)
                discB_real = discB(trainB)
@@ -143,22 +143,23 @@ def train(data, model, checkpoint_info, epochs):
                discA_fake_refined = discA_buffer.query(discA_fake)
                discB_fake_refined = discB_buffer.query(discB_fake)

-                reconstructedA = genB2A(genA2B_output)
-                reconstructedB = genA2B(genB2A_output)
-                identityA, identityB = 0, 0
-                if dataset_id == 'monet2photo':
-                    identityA = genB2A(trainA)
-                    identityB = genA2B(trainB)

-                cyc_loss = cyc_lambda * cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB)
+                identityA = genB2A(trainA)
+                identityB = genA2B(trainB)
                id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)
-                genA2B_loss = generator_loss(discB_fake_refined) + cyc_loss + id_loss
-                genB2A_loss = generator_loss(discA_fake_refined) + cyc_loss + id_loss

+                genA2B_loss_basic = generator_loss(discB_fake_refined)
+                genB2A_loss_basic = generator_loss(discA_fake_refined)
+                cyc_loss_A = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
+                cyc_loss_B = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)

+                genA2B_loss = genA2B_loss_basic + cyc_loss_A + cyc_loss_B + id_loss
+                genB2A_loss = genB2A_loss_basic + cyc_loss_A + cyc_loss_B + id_loss

                discA_loss = discriminator_loss(discA_real, discA_fake_refined)
                discB_loss = discriminator_loss(discB_real, discB_fake_refined)
                # Summaries for Tensorboard:
-                tf.contrib.summary.scalar('loss/genA2B', genA2B_loss)
-                tf.contrib.summary.scalar('loss/genB2A', genB2A_loss)
+                tf.contrib.summary.scalar('loss/genA2B', genA2B_loss_basic)
+                tf.contrib.summary.scalar('loss/genB2A', genB2A_loss_basic)
                tf.contrib.summary.scalar('loss/discA', discA_loss)
                tf.contrib.summary.scalar('loss/discB', discB_loss)
                tf.contrib.summary.scalar('loss/cyc', cyc_loss)
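Two things change in this block: the identity mapping is now computed unconditionally (with identity_lambda = 0.5 from the constants hunk above, it is active for every dataset, not only monet2photo), and each generator's objective is assembled from named parts that mirror the CycleGAN formulation. For reference, a sketch of the identity term and a mapping of the names above to the standard objective; identity_loss lives in src/models/losses.py and its body is not shown in this diff, so the implementation below is an assumption based on the CycleGAN paper:

import tensorflow as tf

def identity_loss(dataA, dataB, identityA, identityB):
    # Assumed implementation: feeding a real A image through genB2A (and a real B
    # image through genA2B) should leave it roughly unchanged.
    return tf.reduce_mean(tf.abs(identityA - dataA)) + tf.reduce_mean(tf.abs(identityB - dataB))

# Mapping of the names above to the CycleGAN objective (Zhu et al., 2017),
# with G = genA2B and F = genB2A:
#   genA2B_loss_basic, genB2A_loss_basic -> the two adversarial (GAN) terms
#   cyc_loss_A + cyc_loss_B              -> cyc_lambda * (||F(G(A)) - A||_1 + ||G(F(B)) - B||_1),
#                                           the cycle term shared by both generators
#   id_loss                              -> identity term, weighted by identity_lambda * cyc_lambda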
@@ -169,23 +170,23 @@ def train(data, model, checkpoint_info, epochs):
                tf.contrib.summary.image('B/generated', genA2B_output)
                tf.contrib.summary.image('B/reconstructed', reconstructedB)

-            discA_gradients = tape.gradient(discA_loss, discA.variables)
-            discB_gradients = tape.gradient(discB_loss, discB.variables)
-            # Try chaining disc and gen parameters into 2 optimizers?
            genA2B_gradients = tape.gradient(genA2B_loss, genA2B.variables)
            genB2A_gradients = tape.gradient(genB2A_loss, genB2A.variables)
+            # Try chaining disc and gen parameters into 2 optimizers?
-            discA_opt.apply_gradients(zip(discA_gradients, discA.variables), global_step=global_step)
-            discB_opt.apply_gradients(zip(discB_gradients, discB.variables), global_step=global_step)
            genA2B_opt.apply_gradients(zip(genA2B_gradients, genA2B.variables), global_step=global_step)
            genB2A_opt.apply_gradients(zip(genB2A_gradients, genB2A.variables), global_step=global_step)
+            discA_gradients = tape.gradient(discA_loss, discA.variables)
+            discB_gradients = tape.gradient(discB_loss, discB.variables)
+            discA_opt.apply_gradients(zip(discA_gradients, discA.variables), global_step=global_step)
+            discB_opt.apply_gradients(zip(discB_gradients, discB.variables), global_step=global_step)
            # Assign decayed learning rate:
            learning_rate.assign(models.get_learning_rate(initial_learning_rate, global_step, batches_per_epoch))
        # Checkpoint the model:
        if (epoch + 1) % save_epoch_freq == 0:
            checkpoint_path = checkpoint.save(file_prefix=checkpoint_prefix)
            print("Checkpoint saved at ", checkpoint_path)
        print("Global Training Step: ", global_step.numpy() // 4)
-        print ("Time taken for local epoch {} is {} sec\n".format(global_step.numpy() // (4 * batches_per_epoch),
+        print ("Time taken for total epoch {} is {} sec\n".format(global_step.numpy() // (4 * batches_per_epoch),
                                                                   time.time()-start))

if __name__ == "__main__":
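The final hunk reorders the optimizer steps: both generators are now updated before the discriminators, with all four gradients taken from the single persistent GradientTape opened earlier. A minimal self-contained sketch of that pattern in TF 1.x eager mode, using toy stand-in networks rather than the repository's models (everything below except the ordering and the persistent-tape usage is an assumption for illustration):

import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x eager mode, as this project uses

# Toy stand-ins for one generator/discriminator pair.
gen = tf.keras.Sequential([tf.keras.layers.Dense(4)])
disc = tf.keras.Sequential([tf.keras.layers.Dense(1)])
gen_opt = tf.train.AdamOptimizer(2e-4)
disc_opt = tf.train.AdamOptimizer(2e-4)
global_step = tf.train.get_or_create_global_step()

x = tf.random_normal([1, 4])
with tf.GradientTape(persistent=True) as tape:  # persistent: gradients are taken several times below
    fake = gen(x)
    disc_fake = disc(fake)
    gen_loss = tf.reduce_mean(tf.square(disc_fake - 1.0))  # LSGAN-style generator loss
    disc_loss = tf.reduce_mean(tf.square(disc_fake))        # simplified discriminator loss

# Generator step first, then discriminator step, mirroring the reordered hunk above.
gen_grads = tape.gradient(gen_loss, gen.variables)
gen_opt.apply_gradients(zip(gen_grads, gen.variables), global_step=global_step)
disc_grads = tape.gradient(disc_loss, disc.variables)
disc_opt.apply_gradients(zip(disc_grads, disc.variables), global_step=global_step)
del tape  # release the persistent tape's resources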
