Commit: Fix stash

herbiebradley committed Dec 11, 2018
1 parent 1477bfc commit 8a3d615
Showing 2 changed files with 8 additions and 31 deletions.
src/test.py: 8 changes (0 additions, 8 deletions)

@@ -20,11 +20,7 @@ def test(data, model, checkpoint_info, dataset_id):
     generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
     genA2B = model['genA2B']
     genB2A = model['genB2A']
-<<<<<<< Updated upstream
-
-=======
 
->>>>>>> Stashed changes
     checkpoint, checkpoint_dir = checkpoint_info
     restore_from_checkpoint(checkpoint, checkpoint_dir)
     test_datasetA, test_datasetB, testA_size, testB_size = data
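
Note: restore_from_checkpoint is defined elsewhere in the repository, so its body is not part of this diff. A minimal sketch of what it plausibly does, assuming the project tracks its models with tf.train.Checkpoint (the function name and arguments are taken from the call above; everything else is a guess):

import tensorflow as tf

def restore_from_checkpoint(checkpoint, checkpoint_dir):
    # Look up the most recent checkpoint file written to checkpoint_dir.
    latest_path = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_path is not None:
        # Restore every variable tracked by the tf.train.Checkpoint object.
        checkpoint.restore(latest_path)
        print('Checkpoint restored from ' + latest_path)
    else:
        print('No checkpoint found in ' + checkpoint_dir)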
@@ -61,10 +57,6 @@ def test(data, model, checkpoint_info, dataset_id):
with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
data = load_test_data(dataset_id, project_dir)
with tf.device("/gpu:0"):
<<<<<<< Updated upstream
model = define_model(training=False)
=======
model = define_model(initial_learning_rate, training=False)
>>>>>>> Stashed changes
checkpoint_info = initialize_checkpoint(checkpoint_dir, model, training=False)
test(data, model, checkpoint_info, dataset_id)
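
The stashed side of this conflict, which the commit keeps, passes initial_learning_rate into define_model, suggesting the optimizers are now built inside define_model rather than at the call site. define_model itself is not part of this diff; a hypothetical sketch of the new signature, with stand-in one-layer networks where the repository has real generators and discriminators:

import tensorflow as tf

def define_model(initial_learning_rate, training=True):
    def make_net():
        # Placeholder network; the real Generator/Discriminator classes
        # live in the repository's models module.
        return tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding='same')])

    model = {'genA2B': make_net(), 'genB2A': make_net(),
             'discA': make_net(), 'discB': make_net()}
    if training:
        # One Adam optimizer per network, a common CycleGAN arrangement (assumed).
        model['optimizers'] = {name: tf.train.AdamOptimizer(initial_learning_rate, beta1=0.5)
                               for name in list(model)}
    return model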
src/train.py: 31 changes (8 additions, 23 deletions)

@@ -25,13 +25,7 @@
 batch_size = 1 # Set batch size to 4 or 16 if training multigpu
 img_size = 256
 cyc_lambda = 10
-<<<<<<< Updated upstream
 identity_lambda = 0.5
-=======
-identity_lambda = 0
-if dataset_id == 'facades':
-    identity_lambda = 0.5
->>>>>>> Stashed changes
 epochs = 15
 save_epoch_freq = 5
 batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)
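
cyc_lambda and identity_lambda weight the two L1 penalties that appear in the training hunk below. The loss helpers themselves are imported from elsewhere in the repository; minimal sketches matching the call signatures used in this diff, assuming the standard CycleGAN formulation (the 0.5 averaging factor is an assumption):

import tensorflow as tf

def cycle_consistency_loss(real, reconstructed):
    # L1 distance between an image and its round-trip reconstruction.
    return tf.reduce_mean(tf.abs(real - reconstructed))

def identity_loss(realA, realB, identityA, identityB):
    # L1 penalty for altering an image that is already in the generator's
    # target domain, averaged over the two directions (assumed).
    lossA = tf.reduce_mean(tf.abs(realA - identityA))
    lossB = tf.reduce_mean(tf.abs(realB - identityB))
    return 0.5 * (lossA + lossB)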
@@ -143,32 +137,23 @@ def train(data, model, checkpoint_info, epochs):
         discA_real = discA(trainA)
         discB_real = discB(trainB)
 
-        discA_fake_refined = discA(genB2A_output)
-        discB_fake_refined = discB(genA2B_output)
+        discA_fake = discA(genB2A_output)
+        discB_fake = discB(genA2B_output)
         # Sample from history buffer of 50 images:
-        #discA_fake_refined = discA_buffer.query(discA_fake)
-        #discB_fake_refined = discB_buffer.query(discB_fake)
+        discA_fake_refined = discA_buffer.query(discA_fake)
+        discB_fake_refined = discB_buffer.query(discB_fake)
 
-<<<<<<< Updated upstream
-        identityA = genB2A(trainA)
-        identityB = genA2B(trainB)
-=======
         reconstructedA = genB2A(genA2B_output)
         reconstructedB = genA2B(genB2A_output)
         identityA = genB2A(trainA)
         identityB = genA2B(trainB)
 
-        cyc_loss = cyc_lambda * cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB)
->>>>>>> Stashed changes
         id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)
 
         genA2B_loss_basic = generator_loss(discB_fake_refined)
         genB2A_loss_basic = generator_loss(discA_fake_refined)
-        cyc_loss_A = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
-        cyc_loss_B = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)
+        cyc_lossA = cyc_lambda * cycle_consistency_loss(trainA, reconstructedA)
+        cyc_lossB = cyc_lambda * cycle_consistency_loss(trainB, reconstructedB)
 
-        genA2B_loss = genA2B_loss_basic + cyc_loss_A + cyc_loss_B + id_loss
-        genB2A_loss = genB2A_loss_basic + cyc_loss_B + cyc_loss_B + id_loss
+        genA2B_loss = genA2B_loss_basic + cyc_lossA + cyc_lossB + id_loss
+        genB2A_loss = genB2A_loss_basic + cyc_lossA + cyc_lossB + id_loss

         discA_loss = discriminator_loss(discA_real, discA_fake_refined)
         discB_loss = discriminator_loss(discB_real, discB_fake_refined)
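
discA_buffer and discB_buffer are created outside this hunk; the comment above says they hold a history of 50 generated images, the buffer trick introduced by Shrivastava et al. and used in the original CycleGAN implementation. A minimal sketch of such a buffer, assuming a query() method that mixes the current fake with previously generated ones (the class name and details are hypothetical):

import random

class ImageHistoryBuffer(object):
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.images = []

    def query(self, image):
        # While the buffer is filling up, store and return the new image.
        if len(self.images) < self.max_size:
            self.images.append(image)
            return image
        # Once full, half the time swap the new image for a randomly chosen
        # stored one and return the older image; otherwise pass it through.
        if random.random() < 0.5:
            index = random.randrange(self.max_size)
            old_image = self.images[index]
            self.images[index] = image
            return old_image
        return image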
