Add lr decay to model class, finish training loop in train.py
herbiebradley committed Feb 5, 2019
1 parent 35188e5 commit 84d855c
Showing 4 changed files with 61 additions and 104 deletions.
6 changes: 4 additions & 2 deletions src/models/__init__.py
@@ -6,17 +6,19 @@

import tensorflow as tf

# TODO: Merge into dataset class.
def get_batches_per_epoch(dataset_id, project_dir, batch_size=1):
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
trainA_path = os.path.join(path_to_dataset, 'trainA')
trainB_path = os.path.join(path_to_dataset, 'trainB')
trainA_size = len(os.listdir(trainA_path))
trainB_size = len(os.listdir(trainB_path))
batches_per_epoch = (trainA_size + trainB_size) // (2 * batch_size) # floor(Average dataset size / batch_size)
batches_per_epoch = (trainA_size + trainB_size) // (2 * batch_size) # floor(Avg dataset size / batch_size)
return batches_per_epoch

# TODO: Merge into basemodel class?
def get_learning_rate(initial_learning_rate, global_step, batches_per_epoch, const_iterations=100, decay_iterations=100):
global_step = global_step.numpy() / 4 # /4 because there are 4 gradient updates per batch.
global_step = global_step.numpy() / 3 # /3 because there are 3 gradient updates per batch.
total_epochs = global_step // batches_per_epoch
learning_rate_lambda = 1.0 - max(0, total_epochs - const_iterations) / float(decay_iterations + 1)
return initial_learning_rate * max(0, learning_rate_lambda)
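A quick pure-Python sketch of the schedule get_learning_rate implements with its defaults (const_iterations=100, decay_iterations=100): the learning rate holds at its initial value for the first 100 epochs, then decays roughly linearly to zero over the following 100. This is a stand-alone illustration, not code from the commit.

# Stand-alone sketch of the decay schedule above (defaults assumed).
def lr_at_epoch(total_epochs, initial_lr=2e-4, const_iterations=100, decay_iterations=100):
    lr_lambda = 1.0 - max(0, total_epochs - const_iterations) / float(decay_iterations + 1)
    return initial_lr * max(0, lr_lambda)

for epoch in (0, 100, 150, 200):
    print(epoch, lr_at_epoch(epoch))
# 0 -> 2e-4, 100 -> 2e-4, 150 -> ~1.0e-4, 200 -> ~2e-6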
14 changes: 9 additions & 5 deletions src/models/cyclegan.py
@@ -68,9 +68,9 @@ def restore_checkpoint(self):
else:
print("No checkpoint found, initializing model.")

def load_batch(self, input_batch):
self.dataA = input_batch[0].get_next()
self.dataB = input_batch[1].get_next()
def set_input(self, input):
self.dataA = input["A"]
self.dataB = input["B"]

def forward(self):
# Gen output shape: (batch_size, img_size, img_size, 3)
@@ -88,11 +88,13 @@ def backward_D(self, netD, real, fake):
return disc_loss

def backward_discA(self):
# Sample from history buffer of 50 images:
fake_A = self.discA_buffer.query(self.fakeA)
discA_loss = self.backward_D(self.discA, self.dataA, fake_A)
return discA_loss

def backward_discB(self):
# Sample from history buffer of 50 images:
fake_B = self.discB_buffer.query(self.fakeB)
discB_loss = self.backward_D(self.discB, self.dataB, fake_B)
return discB_loss
@@ -155,5 +157,7 @@ def save_model(self):
checkpoint_path = self.checkpoint.save(file_prefix=checkpoint_prefix)
print("Checkpoint saved at ", checkpoint_path)

def update_learning_rate(self):
raise NotImplementedError
def update_learning_rate(self, batches_per_epoch):
new_learning_rate = models.get_learning_rate(self.initial_learning_rate,
self.global_step, batches_per_epoch)
self.learning_rate.assign(new_learning_rate)
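For context, a minimal sketch of the tf.Variable learning-rate pattern that update_learning_rate appears to rely on, assuming self.learning_rate is a tf.Variable handed to the model's optimizers at construction time (the constructor is not shown in this diff). Under that assumption, a single assign per epoch changes the rate used by every subsequent apply_gradients call.

# Minimal TF 1.x eager sketch; the variable-based optimizer setup is an assumption,
# not code from this commit.
import tensorflow as tf

tf.enable_eager_execution()

learning_rate = tf.Variable(2e-4, trainable=False)
optimizer = tf.train.AdamOptimizer(learning_rate)  # reads the variable's current value on each update

w = tf.Variable(1.0)
with tf.GradientTape() as tape:
    loss = tf.square(w)
grads = tape.gradient(loss, [w])
optimizer.apply_gradients(zip(grads, [w]))  # this step uses lr = 2e-4

learning_rate.assign(1e-4)  # later apply_gradients calls now use the decayed rate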
8 changes: 4 additions & 4 deletions src/pipeline/data.py
@@ -63,7 +63,7 @@ def load_train_data(dataset_id, project_dir, batch_size=1):
# Queue up batches asynchronously onto the GPU.
# As long as there is a pool of batches CPU side a GPU prefetch of 1 is fine.
# TODO: If GPU exists:
#train_datasetA = train_datasetA.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))
train_datasetA = train_datasetA.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))

train_datasetB = tf.data.Dataset.list_files(trainB_path + os.sep + '*.jpg', shuffle=False)
train_datasetB = train_datasetB.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=trainB_size))
@@ -72,9 +72,9 @@ def load_train_data(dataset_id, project_dir, batch_size=1):
num_parallel_calls=threads,
drop_remainder=True))
train_datasetB = train_datasetB.prefetch(buffer_size=threads)
#train_datasetB = train_datasetB.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))

return train_datasetA, train_datasetB
train_datasetB = train_datasetB.apply(tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=1))
# Create a tf.data.Iterator from the Datasets:
return iter(train_datasetA), iter(train_datasetB)

def load_test_data(dataset_id, project_dir):
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
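One caveat with the newly enabled GPU prefetch: tf.contrib.data.prefetch_to_device is documented as needing to be the final transformation in the pipeline, which is why it comes after the CPU-side prefetch here. Below is a small consumption sketch, using placeholder paths, showing how the returned eager iterators can be used; it mirrors the loop in train.py further down and is not part of the commit.

# Hypothetical consumption sketch (TF 1.x eager); values below are placeholders.
import os
import tensorflow as tf
from pipeline.data import load_train_data

tf.enable_eager_execution()

project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
dataA, dataB = load_train_data('facades', project_dir)  # two eager iterators

for step in range(3):  # a few steps, just to show the access pattern
    batch = {"A": dataA.get_next(), "B": dataB.get_next()}  # each of shape (1, 256, 256, 3)
    print(step, batch["A"].shape, batch["B"].shape)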
137 changes: 44 additions & 93 deletions src/train.py
@@ -8,118 +8,69 @@
import tensorflow as tf

import models
from pipeline.data import load_train_data
from models.losses import generator_loss, discriminator_loss, cycle_consistency_loss, identity_loss
from models.networks import Generator, Discriminator
from models.cyclegan import CycleGANModel
from utils.image_history_buffer import ImageHistoryBuffer
from pipeline.data import load_train_data

tf.enable_eager_execution()

"""Hyperparameters (TODO: Move to argparse)"""
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
dataset_id = 'facades'
epochs = 15
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)
batch_size = 1 # Set batch size to 4 or 16 if training multigpu
initial_learning_rate = 0.0002
num_gen_filters = 32
num_disc_filters = 64
batch_size = 1 # Set batch size to 4 or 16 if training multigpu
img_size = 256
cyc_lambda = 10
identity_lambda = 0.5
epochs = 15
save_epoch_freq = 5
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)

def train(data, model, options):
# Create a tf.data.Iterator from the Datasets:
data = iter(data[0]), iter(data[1])
# Initialize Tensorboard summary writer:
log_dir = os.path.join(project_dir, 'saved_models', 'tensorboard')
summary_writer = tf.contrib.summary.create_file_writer(log_dir, flush_millis=10000)
for epoch in range(epochs):
with summary_writer.as_default():
start = time.time()
for train_step in range(batches_per_epoch):
# Record summaries every 100 train_steps; there are 4 gradient updates per step.
with tf.contrib.summary.record_summaries_every_n_global_steps(400, global_step=global_step):
try:
# Get next training batches:
trainA = train_datasetA.get_next()
trainB = train_datasetB.get_next()
except tf.errors.OutOfRangeError:
print("Error, run out of data")
break
with tf.GradientTape(persistent=True) as tape:
# Gen output shape: (batch_size, img_size, img_size, 3)
genA2B_output = genA2B(trainA)
reconstructedA = genB2A(genA2B_output)
genB2A_output = genB2A(trainB)
reconstructedB = genA2B(genB2A_output)
# Disc output shape: (batch_size, img_size/8, img_size/8, 1)
discA_real = discA(trainA)
discB_real = discB(trainB)

discA_fake = discA(genB2A_output)
discB_fake = discB(genA2B_output)
# Sample from history buffer of 50 images:
discA_fake_refined = discA_buffer.query(discA_fake)
discB_fake_refined = discB_buffer.query(discB_fake)

identityA = genB2A(trainA)
identityB = genA2B(trainB)
id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)

genA2B_loss_basic = generator_loss(discB_fake)
genB2A_loss_basic = generator_loss(discA_fake)
cyc_lossA = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
cyc_lossB = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)

genA2B_loss = genA2B_loss_basic + cyc_lossA + cyc_lossB + id_loss
genB2A_loss = genB2A_loss_basic + cyc_lossB + cyc_lossB + id_loss
#tf.contrib.summary.scalar('loss/genA2B', genA2B_loss_basic)
#tf.contrib.summary.scalar('loss/genB2A', genB2A_loss_basic)
#tf.contrib.summary.scalar('loss/discA', discA_loss)
#tf.contrib.summary.scalar('loss/discB', discB_loss)
#tf.contrib.summary.scalar('loss/cyc', cyc_loss)
#tf.contrib.summary.scalar('loss/identity', id_loss)
#tf.contrib.summary.scalar('learning_rate', learning_rate)

discA_loss = discriminator_loss(discA_real, discA_fake_refined)
discB_loss = discriminator_loss(discB_real, discB_fake_refined)
# Summaries for Tensorboard:
tf.contrib.summary.scalar('loss/genA2B', genA2B_loss_basic)
tf.contrib.summary.scalar('loss/genB2A', genB2A_loss_basic)
tf.contrib.summary.scalar('loss/discA', discA_loss)
tf.contrib.summary.scalar('loss/discB', discB_loss)
tf.contrib.summary.scalar('loss/cyc', cyc_loss)
tf.contrib.summary.scalar('loss/identity', id_loss)
tf.contrib.summary.scalar('learning_rate', learning_rate)
tf.contrib.summary.image('A/generated', genB2A_output)
tf.contrib.summary.image('A/reconstructed', reconstructedA)
tf.contrib.summary.image('B/generated', genA2B_output)
tf.contrib.summary.image('B/reconstructed', reconstructedB)

# Try chaining disc and gen parameters into 2 optimizers?
genA2B_gradients = tape.gradient(genA2B_loss, genA2B.variables)
genB2A_gradients = tape.gradient(genB2A_loss, genB2A.variables)
genA2B_opt.apply_gradients(zip(genA2B_gradients, genA2B.variables), global_step=global_step)
genB2A_opt.apply_gradients(zip(genB2A_gradients, genB2A.variables), global_step=global_step)
discA_gradients = tape.gradient(discA_loss, discA.variables)
discB_gradients = tape.gradient(discB_loss, discB.variables)
discA_opt.apply_gradients(zip(discA_gradients, discA.variables), global_step=global_step)
discB_opt.apply_gradients(zip(discB_gradients, discB.variables), global_step=global_step)
# Assign decayed learning rate:
learning_rate.assign(models.get_learning_rate(initial_learning_rate, global_step, batches_per_epoch))
# Checkpoint the model:
if (epoch + 1) % save_epoch_freq == 0:
checkpoint_path = checkpoint.save(file_prefix=checkpoint_prefix)
print("Checkpoint saved at ", checkpoint_path)
print("Global Training Step: ", global_step.numpy() // 4)
print ("Time taken for total epoch {} is {} sec\n".format(global_step.numpy() // (4 * batches_per_epoch),
time.time()-start))
def train_one_epoch():
raise NotImplementedError

if __name__ == "__main__":
checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
data = load_train_data(dataset_id, project_dir)
with tf.device("/cpu:0"):
# Preprocess data on CPU for significant performance gains:
dataA, dataB = load_train_data(dataset_id, project_dir)
model = CycleGANModel(initial_learning_rate, num_gen_filters,
num_disc_filters, batch_size, cyc_lambda,
identity_lambda, checkpoint_dir, img_size,
training=True)
#with tf.device("/gpu:0"):
data = iter(data[0]), iter(data[1])
model.load_batch(data)
model.optimize_parameters()
with tf.device("/gpu:0"):
# Initialize Tensorboard summary writer:
log_dir = os.path.join(project_dir, 'saved_models', 'tensorboard')
summary_writer = tf.contrib.summary.create_file_writer(log_dir, flush_millis=10000)
for epoch in range(epochs):
start = time.time()
with summary_writer.as_default():
for train_step in range(batches_per_epoch):
# Record summaries every 100 train_steps; there are 3 gradient updates per step.
with tf.contrib.summary.record_summaries_every_n_global_steps(300, global_step=model.global_step):
# Get next training batches:
batch = {"A": dataA.get_next(), "B": dataB.get_next()}
model.set_input(batch)
model.optimize_parameters()
print("Iteration ", train_step)
# Summaries for Tensorboard:
tf.contrib.summary.image('A/generated', model.fakeA)
tf.contrib.summary.image('A/reconstructed', model.reconstructedA)
tf.contrib.summary.image('B/generated', model.fakeB)
tf.contrib.summary.image('B/reconstructed', model.reconstructedB)
# Assign decayed learning rate:
model.update_learning_rate(batches_per_epoch)
# Checkpoint the model:
if (epoch + 1) % save_epoch_freq == 0:
model.save_model()
print("Global Training Step: ", global_step.numpy() // 3)
print ("Time taken for total epoch {} is {} sec\n".format(global_step.numpy() \
// (3 * batches_per_epoch), time.time()-start))
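The new train_one_epoch stub currently raises NotImplementedError; below is a hypothetical sketch of how it might eventually wrap the per-step body of the loop above. The argument names are illustrative, not from this commit.

# Hypothetical refactor target for the train_one_epoch stub; mirrors the loop above.
def train_one_epoch(model, dataA, dataB, batches_per_epoch):
    for train_step in range(batches_per_epoch):
        with tf.contrib.summary.record_summaries_every_n_global_steps(300, global_step=model.global_step):
            batch = {"A": dataA.get_next(), "B": dataB.get_next()}
            model.set_input(batch)
            model.optimize_parameters()
    # Per-epoch work (image summaries, lr decay, checkpointing) stays with the caller.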
