
Commit

Fix
Herbie Bradley committed Dec 11, 2018
1 parent 5ebacc7 commit 1477bfc
Showing 2 changed files with 28 additions and 5 deletions. Note: the additions below include literal <<<<<<< / ======= / >>>>>>> merge-conflict markers; these are part of the committed file content, not diff formatting.
src/test.py (8 additions, 0 deletions)
@@ -20,7 +20,11 @@ def test(data, model, checkpoint_info, dataset_id):
     generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
     genA2B = model['genA2B']
     genB2A = model['genB2A']
+<<<<<<< Updated upstream
 
+=======
+
+>>>>>>> Stashed changes
     checkpoint, checkpoint_dir = checkpoint_info
     restore_from_checkpoint(checkpoint, checkpoint_dir)
     test_datasetA, test_datasetB, testA_size, testB_size = data
@@ -57,6 +61,10 @@ def test(data, model, checkpoint_info, dataset_id):
     with tf.device("/cpu:0"):  # Preprocess data on CPU for significant performance gains.
         data = load_test_data(dataset_id, project_dir)
     with tf.device("/gpu:0"):
+<<<<<<< Updated upstream
         model = define_model(training=False)
+=======
+        model = define_model(initial_learning_rate, training=False)
+>>>>>>> Stashed changes
     checkpoint_info = initialize_checkpoint(checkpoint_dir, model, training=False)
     test(data, model, checkpoint_info, dataset_id)
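
restore_from_checkpoint and initialize_checkpoint are project helpers whose bodies are not shown in this diff. A minimal sketch of what the restore step could look like, assuming TensorFlow's tf.train.Checkpoint API; the function body below is an illustration, not the repository's actual code:

import tensorflow as tf

def restore_from_checkpoint(checkpoint, checkpoint_dir):
    # Find the newest checkpoint file in the directory, if one exists.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is not None:
        # Restore the variables tracked by the tf.train.Checkpoint object.
        checkpoint.restore(latest)
        print('Restored weights from', latest)
    else:
        print('No checkpoint found in', checkpoint_dir)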
src/train.py (20 additions, 5 deletions)
@@ -25,7 +25,13 @@
 batch_size = 1  # Set batch size to 4 or 16 if training multigpu
 img_size = 256
 cyc_lambda = 10
+<<<<<<< Updated upstream
 identity_lambda = 0.5
+=======
+identity_lambda = 0
+if dataset_id == 'facades':
+    identity_lambda = 0.5
+>>>>>>> Stashed changes
 epochs = 15
 save_epoch_freq = 5
 batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)
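
identity_lambda weights CycleGAN's identity-mapping term, which the id_loss line in the next hunk computes. A sketch of an identity_loss consistent with that call site, assuming the L1 formulation from the CycleGAN paper (the function body is an assumption, not the repository's code):

import tensorflow as tf

def identity_loss(realA, realB, identityA, identityB):
    # Penalize each generator for altering an image that is already in
    # its output domain: identityA = genB2A(realA), identityB = genA2B(realB).
    lossA = tf.reduce_mean(tf.abs(identityA - realA))
    lossB = tf.reduce_mean(tf.abs(identityB - realB))
    return lossA + lossB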
@@ -137,14 +143,23 @@ def train(data, model, checkpoint_info, epochs):
         discA_real = discA(trainA)
         discB_real = discB(trainB)
 
-        discA_fake = discA(genB2A_output)
-        discB_fake = discB(genA2B_output)
+        discA_fake_refined = discA(genB2A_output)
+        discB_fake_refined = discB(genA2B_output)
         # Sample from history buffer of 50 images:
-        discA_fake_refined = discA_buffer.query(discA_fake)
-        discB_fake_refined = discB_buffer.query(discB_fake)
+        #discA_fake_refined = discA_buffer.query(discA_fake)
+        #discB_fake_refined = discB_buffer.query(discB_fake)
 
+<<<<<<< Updated upstream
         identityA = genB2A(trainA)
         identityB = genA2B(trainB)
+=======
+        reconstructedA = genB2A(genA2B_output)
+        reconstructedB = genA2B(genB2A_output)
+        identityA = genB2A(trainA)
+        identityB = genA2B(trainB)
+
+        cyc_loss = cyc_lambda * cycle_consistency_loss(trainA, trainB, reconstructedA, reconstructedB)
+>>>>>>> Stashed changes
         id_loss = identity_lambda * cyc_lambda * identity_loss(trainA, trainB, identityA, identityB)
 
         genA2B_loss_basic = generator_loss(discB_fake_refined)
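
The stashed side introduces the cycle-consistency term: each image is translated to the other domain and back, and the round-trip reconstruction is compared with the original. A sketch of cycle_consistency_loss matching the call above, assuming the L1 form used in the CycleGAN paper (the body is an assumption, not the repository's code):

import tensorflow as tf

def cycle_consistency_loss(realA, realB, reconstructedA, reconstructedB):
    # reconstructedA = genB2A(genA2B(realA)); reconstructedB = genA2B(genB2A(realB)).
    # L1 distance between each image and its round-trip reconstruction.
    lossA = tf.reduce_mean(tf.abs(reconstructedA - realA))
    lossB = tf.reduce_mean(tf.abs(reconstructedB - realB))
    return lossA + lossB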
@@ -193,7 +208,7 @@ def train(data, model, checkpoint_info, epochs):
     checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
     with tf.device("/cpu:0"):  # Preprocess data on CPU for significant performance gains.
         data = load_train_data(dataset_id, project_dir)
-    #with tf.device("/gpu:0"):
+    with tf.device("/gpu:0"):
         model = define_model(initial_learning_rate, training=True)
     checkpoint_info = initialize_checkpoint(checkpoint_dir, model, training=True)
     train(data, model, checkpoint_info, epochs=epochs)
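
The now-commented-out discA_buffer.query calls in the second hunk refer to the 50-image history buffer (image pool) trick from Shrivastava et al. (2017), also used by CycleGAN: the discriminator is shown a mix of current and previously generated outputs to stabilize training. A sketch of such a buffer, matching the query interface seen in the diff; the class body itself is illustrative, not the repository's code:

import random

class HistoryBuffer:
    # Pool of up to max_size previously generated images.
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.images = []

    def query(self, image):
        # Until the pool is full, store and return the new image.
        if len(self.images) < self.max_size:
            self.images.append(image)
            return image
        # Afterwards, half the time swap the new image for a stored one,
        # so the discriminator also sees older generator outputs.
        if random.random() < 0.5:
            idx = random.randrange(self.max_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image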
