Skip to content

Commit

Permalink
Small change
Browse files Browse the repository at this point in the history
  • Loading branch information
herbiebradley committed Dec 10, 2018
1 parent d9ec247 commit 289a8ae
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
12 changes: 6 additions & 6 deletions src/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import os
import time
import multiprocessing
import glob

import tensorflow as tf

import models
from pipeline.load_data import load_train_data
from pipeline.data import load_train_data
from models.losses import generator_loss, discriminator_loss, cycle_consistency_loss, identity_loss
from models.networks import Generator, Discriminator
from utils.image_history_buffer import ImageHistoryBuffer
Expand All @@ -19,17 +18,18 @@

"""Hyperparameters (TODO: Move to argparse)"""
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
dataset_id = 'horse2zebra'
dataset_id = 'facades'
initial_learning_rate = 0.0002
num_gen_filters = 32
num_gen_filters = 64
num_disc_filters = 64
batch_size = 1 # Set batch size to 4 or 16 if training multigpu
img_size = 256
cyc_lambda = 10
identity_lambda = 0
if dataset_id == 'monet2photo':
identity_lambda = 0.5
epochs = 2
epochs = 50
save_epoch_freq = 5
batches_per_epoch = models.get_batches_per_epoch(dataset_id, project_dir)

def define_checkpoint(checkpoint_dir, model, training=True):
Expand Down Expand Up @@ -181,7 +181,7 @@ def train(data, model, checkpoint_info, epochs):
# Assign decayed learning rate:
learning_rate.assign(models.get_learning_rate(initial_learning_rate, global_step, batches_per_epoch))
# Checkpoint the model:
if (epoch + 1) % 5 == 0:
if (epoch + 1) % save_epoch_freq == 0:
checkpoint_path = checkpoint.save(file_prefix=checkpoint_prefix)
print("Checkpoint saved at ", checkpoint_path)
print("Global Training Step: ", global_step.numpy() // 4)
Expand Down
4 changes: 3 additions & 1 deletion src/pytorchtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np

"""Pytorch testing ground"""
dataset_id = 'horse2zebra'
dataset_id = 'facades'
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
testA_path = os.path.join(path_to_dataset, 'testA')
img_path = testA_path + os.sep + 'n02381460_20.jpg'

A = Image.open(img_path).convert('RGB')
A = transforms.ToTensor()(A)

# Things to try: ngf = 64, cyc_lambda = 1, 5, 20, identity loss, don't divide D loss by 2
8 changes: 4 additions & 4 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import tensorflow as tf

from cyclegan import define_checkpoint, define_model, restore_from_checkpoint
from pipeline.load_data import load_test_data, save_images
from pipeline.data import load_test_data, save_images

project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
dataset_id = 'horse2zebra'
dataset_id = 'facades'
initial_learning_rate = 0.0002

def test(data, model, checkpoint_info, dataset_id):
Expand All @@ -21,7 +21,7 @@ def test(data, model, checkpoint_info, dataset_id):
generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
genA2B = model['genA2B']
genB2A = model['genB2A']

return None
checkpoint, checkpoint_dir = checkpoint_info
restore_from_checkpoint(checkpoint, checkpoint_dir)
test_datasetA, test_datasetB, testA_size, testB_size = data
Expand Down Expand Up @@ -57,7 +57,7 @@ def test(data, model, checkpoint_info, dataset_id):
if __name__ == "__main__":
with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
data = load_test_data(dataset_id, project_dir)
with tf.device("/gpu:0"):
#with tf.device("/gpu:0"):
model = define_model(initial_learning_rate, training=False)
checkpoint_info = define_checkpoint(checkpoint_dir, model, training=False)
test(data, model, checkpoint_info, dataset_id)

0 comments on commit 289a8ae

Please sign in to comment.