Skip to content

Commit

Permalink
Add ngf/ndf, fix instance norm
Browse files Browse the repository at this point in the history
  • Loading branch information
herbiebradley committed Dec 9, 2018
1 parent 198f2e2 commit 645347b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 113 deletions.
80 changes: 36 additions & 44 deletions src/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
# --- Training configuration (module-level constants) ---
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
dataset_id = 'horse2zebra'
initial_learning_rate = 0.0002
num_gen_filters = 32   # ngf: base filter count for the generator convolutions
num_disc_filters = 64  # ndf: base filter count for the discriminator convolutions
batch_size = 1  # Set batch size to 4 or 16 if training multigpu
img_size = 256
cyc_lambda = 10  # weight of the cycle-consistency loss term
# The identity loss term is only enabled for monet2photo; every other
# dataset trains with identity_lambda == 0 (term contributes nothing).
if dataset_id == 'monet2photo':
    identity_lambda = 0.5
else:
    identity_lambda = 0
epochs = 2
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)

Expand Down Expand Up @@ -171,77 +172,68 @@ def train(data, model, checkpoint_info, epochs, initial_learning_rate=initial_le
break
with tf.GradientTape(persistent=True) as tape:
# Gen output shape: (batch_size, img_size, img_size, 3)
genA2B_output = genA2B(trainA, training=True)
genB2A_output = genB2A(trainB, training=True)
genA2B_output = genA2B(trainA)
genB2A_output = genB2A(trainB)
# Disc output shape: (batch_size, img_size/8, img_size/8, 1)
discA_real = discA(trainA, training=True)
discB_real = discB(trainB, training=True)
discA_real = discA(trainA)
discB_real = discB(trainB)

discA_fake = discA(genB2A_output, training=True)
discB_fake = discB(genA2B_output, training=True)
discA_fake = discA(genB2A_output)
discB_fake = discB(genA2B_output)
# Sample from history buffer of 50 images:
discA_fake_refined = discA_buffer.query(discA_fake)
discB_fake_refined = discB_buffer.query(discB_fake)

reconstructedA = genB2A(genA2B_output, training=True)
reconstructedB = genA2B(genB2A_output, training=True)
reconstructedA = genB2A(genA2B_output)
reconstructedB = genA2B(genB2A_output)
identityA, identityB = 0, 0
if dataset_id == 'monet2photo':
identityA = genB2A(trainA, training=True)
identityB = genA2B(trainB, training= True)
else:
identityA, identityB = 0, 0
identity_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)
identityA = genB2A(trainA)
identityB = genA2B(trainB)

cyc_loss = cyc_lambda * cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB)
genA2B_loss = generator_loss(discB_fake_refined) + cyc_loss + identity_loss
genB2A_loss = generator_loss(discA_fake_refined) + cyc_loss + identity_loss
id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)
genA2B_loss = generator_loss(discB_fake_refined) + cyc_loss + id_loss
genB2A_loss = generator_loss(discA_fake_refined) + cyc_loss + id_loss
discA_loss = discriminator_loss(discA_real, discA_fake_refined)
discB_loss = discriminator_loss(discB_real, discB_fake_refined)
# Summaries for Tensorboard:
tf.contrib.summary.scalar('loss/genA2B', genA2B_loss)
tf.contrib.summary.scalar('loss/genB2A', genB2A_loss)
tf.contrib.summary.scalar('loss/discA', discA_loss)
tf.contrib.summary.scalar('loss/discB', discB_loss)
tf.contrib.summary.scalar('loss/cyc', cyc_loss)
tf.contrib.summary.scalar('loss/identity', id_loss)
tf.contrib.summary.scalar('learning_rate', learning_rate)
tf.contrib.summary.image('A/generated', genB2A_output)
tf.contrib.summary.image('A/reconstructed', reconstructedA)
tf.contrib.summary.image('B/generated', genA2B_output)
tf.contrib.summary.image('B/reconstructed', reconstructedB)

discA_gradients = tape.gradient(discA_loss, discA.variables)
discB_gradients = tape.gradient(discB_loss, discB.variables)
genA2B_gradients = tape.gradient(genA2B_loss, genA2B.variables)
genB2A_gradients = tape.gradient(genB2A_loss, genB2A.variables)

# Try chaining disc and gen parameters into 2 optimizers?
discA_opt.apply_gradients(zip(discA_gradients, discA.variables), global_step=global_step)
discB_opt.apply_gradients(zip(discB_gradients, discB.variables), global_step=global_step)
genA2B_opt.apply_gradients(zip(genA2B_gradients, genA2B.variables), global_step=global_step)
genB2A_opt.apply_gradients(zip(genB2A_gradients, genB2A.variables), global_step=global_step)
# Summaries
tf.contrib.summary.scalar('loss/genA2B', genA2B_loss)
tf.contrib.summary.scalar('loss/genB2A', genB2A_loss)
tf.contrib.summary.scalar('loss/discA', discA_loss)
tf.contrib.summary.scalar('loss/discB', discB_loss)
tf.contrib.summary.scalar('loss/cyc', cyc_loss)
tf.contrib.summary.scalar('learning_rate', learning_rate)
tf.contrib.summary.histogram('discA/real', discA_real)
tf.contrib.summary.histogram('discA/fake', discA_fake)
tf.contrib.summary.histogram('discB/real', discB_real)
tf.contrib.summary.histogram('discB/fake', discA_fake)
# Transform images from [-1, 1] to [0, 1) for Tensorboard.
tf.contrib.summary.image('A/generated', (genB2A_output * 0.5) + 0.5)
tf.contrib.summary.image('A/reconstructed', (reconstructedA * 0.5) + 0.5)
tf.contrib.summary.image('B/generated', (genA2B_output * 0.5) + 0.5)
tf.contrib.summary.image('B/reconstructed', (reconstructedB * 0.5) + 0.5)

if train_step % 100 == 0:
# Here we do global step / 4 because there are 4 gradient updates per batch.
print("Global Training Step: ", global_step.numpy() // 4)
print("Epoch Training Step: ", train_step + 1)
# Assign decayed learning rate:
learning_rate.assign(models.get_learning_rate(initial_learning_rate, global_step, batches_per_epoch))
print("Learning rate in total epoch {} is: {}".format(global_step.numpy() // (4 * batches_per_epoch),
learning_rate.numpy()))
# Checkpoint the model:
if (epoch + 1) % 5 == 0:
checkpoint_path = checkpoint.save(file_prefix=checkpoint_prefix)
print("Checkpoint saved at ", checkpoint_path)
print ("Time taken for local epoch {} is {} sec\n".format(epoch + 1, time.time()-start))
print("Global Training Step: ", global_step.numpy() // 4)
print ("Time taken for local epoch {} is {} sec\n".format(global_step.numpy() // (4 * batches_per_epoch),
time.time()-start))

if __name__ == "__main__":
    checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
    # Preprocess data on the CPU for significant performance gains.
    with tf.device("/cpu:0"):
        data = load_train_data(dataset_id, project_dir)
    # NOTE(review): GPU placement for the model is currently disabled;
    # re-enable the device context below if pinning to GPU is desired.
    #with tf.device("/gpu:0"):
    model = define_model(initial_learning_rate, training=True)
    checkpoint_info = define_checkpoint(checkpoint_dir, model, training=True)
    train(data, model, checkpoint_info, epochs=epochs)
139 changes: 70 additions & 69 deletions src/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,154 +6,155 @@

class Encoder(tf.keras.Model):
    """Downsampling front end of the CycleGAN generator.

    One stride-1 7x7 conv (on reflection-padded input) followed by two
    stride-2 3x3 convs, so spatial size is reduced 4x and channels grow
    from `ngf` to `ngf * 4`. Each conv is followed by parameter-free
    instance norm and ReLU.

    Args:
        ngf: Base number of generator filters for the first conv layer.
    """

    def __init__(self, ngf):
        super(Encoder, self).__init__()
        # Small variance in initialization helps with preventing colour inversion.
        self.conv1 = tf.keras.layers.Conv2D(ngf, kernel_size=7, strides=1, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        self.conv2 = tf.keras.layers.Conv2D(ngf * 2, kernel_size=3, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        self.conv3 = tf.keras.layers.Conv2D(ngf * 4, kernel_size=3, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

    def call(self, inputs):
        """Reflection padding is used to reduce artifacts."""
        x = tf.pad(inputs, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        # Instance norm without learnable affine params (center/scale off).
        # TODO: Implement instance norm to more closely match orig. paper (momentum=0.1)?
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)
        return x


class Residual(tf.keras.Model):
    """Residual block on `ngf * 4` channels.

    Two reflection-padded 3x3 convs with parameter-free instance norm
    (ReLU after the first only), added back onto the block input, so the
    output keeps the input's shape.

    Args:
        ngf: Base generator filter count; this block uses `ngf * 4` channels.
    """

    def __init__(self, ngf):
        super(Residual, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(ngf * 4, kernel_size=3, strides=1, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        self.conv2 = tf.keras.layers.Conv2D(ngf * 4, kernel_size=3, strides=1, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

    def call(self, inputs):
        # Reflection-pad by 1 so the unpadded 3x3 conv preserves H and W.
        x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)

        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv2(x)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)

        x = tf.add(x, inputs)  # Add is better than concatenation.
        return x


class Decoder(tf.keras.Model):
    """Upsampling back end of the CycleGAN generator.

    Two stride-2 transposed convs (undoing the Encoder's 4x downsampling,
    `ngf * 2` then `ngf` channels) followed by a reflection-padded 7x7 conv
    down to 3 output channels, with a final tanh so outputs lie in [-1, 1].

    Args:
        ngf: Base number of generator filters.
    """

    def __init__(self, ngf):
        super(Decoder, self).__init__()

        self.conv1 = tf.keras.layers.Conv2DTranspose(ngf * 2, kernel_size=3, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        self.conv2 = tf.keras.layers.Conv2DTranspose(ngf, kernel_size=3, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        self.conv3 = tf.keras.layers.Conv2D(3, kernel_size=7, strides=1, kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

    def call(self, inputs):
        x = self.conv1(inputs)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.relu(x)

        x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
        x = self.conv3(x)
        # NOTE(review): normalizing the output layer before tanh is unusual
        # for CycleGAN (the reference architecture omits it) — confirm intent.
        x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
        x = tf.nn.tanh(x)
        return x


class Generator(tf.keras.Model):
    """CycleGAN generator: Encoder -> residual blocks -> Decoder.

    Uses 6 residual blocks when `img_size` is 128 and 9 otherwise
    (e.g. 256x256 inputs). Residual blocks are held as individual
    attributes (res1..res9) so checkpoint variable names stay stable.

    Args:
        ngf: Base number of generator filters (default 32).
        img_size: Input image side length; selects 6 vs 9 residual blocks.
        skip: Unused placeholder for a skip connection. TODO: Add skip.
    """

    def __init__(self, ngf=32, img_size=256, skip=False):
        super(Generator, self).__init__()

        self.img_size = img_size
        self.skip = skip  # TODO: Add skip
        self.encoder = Encoder(ngf)
        # res1..res6 are always present; res7..res9 only for larger images.
        self.res1 = Residual(ngf)
        self.res2 = Residual(ngf)
        self.res3 = Residual(ngf)
        self.res4 = Residual(ngf)
        self.res5 = Residual(ngf)
        self.res6 = Residual(ngf)
        if self.img_size != 128:
            self.res7 = Residual(ngf)
            self.res8 = Residual(ngf)
            self.res9 = Residual(ngf)
        self.decoder = Decoder(ngf)

    @tf.contrib.eager.defun  # Compile the forward pass into a graph.
    def call(self, inputs):
        x = self.encoder(inputs)
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)
        x = self.res6(x)
        if self.img_size != 128:
            x = self.res7(x)
            x = self.res8(x)
            x = self.res9(x)
        x = self.decoder(x)
        return x

class Discriminator(tf.keras.Model):

def __init__(self):
def __init__(self, ndf=64):
super(Discriminator, self).__init__()
# TODO: check padding here, should it be same?
self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv2 = tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv3 = tf.keras.layers.Conv2D(256, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv4 = tf.keras.layers.Conv2D(512, kernel_size=4, strides=1, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

self.conv1 = tf.keras.layers.Conv2D(ndf, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv2 = tf.keras.layers.Conv2D(ndf * 2, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv3 = tf.keras.layers.Conv2D(ndf * 4, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv4 = tf.keras.layers.Conv2D(ndf * 8, kernel_size=4, strides=1, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.conv5 = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same', kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
self.leaky = tf.keras.layers.LeakyReLU(0.2)

@tf.contrib.eager.defun
def call(self, inputs, training=True):
def call(self, inputs):
x = self.conv1(inputs)
x = self.leaky(x)

x = self.conv2(x)
x = tf.contrib.layers.instance_norm(x, epsilon=1e-05, trainable=training)
x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
x = self.leaky(x)

x = self.conv3(x)
x = tf.contrib.layers.instance_norm(x, epsilon=1e-05, trainable=training)
x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
x = self.leaky(x)

x = self.conv4(x)
x = tf.contrib.layers.instance_norm(x, epsilon=1e-05, trainable=training)
x = tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-05, trainable=False)
x = self.leaky(x)

x = self.conv5(x)
Expand Down

0 comments on commit 645347b

Please sign in to comment.