Skip to content

Commit

Permalink
Add separate test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
herbiebradley committed Dec 9, 2018
1 parent 645347b commit d9ec247
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 62 deletions.
61 changes: 10 additions & 51 deletions src/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
import glob

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import models
from preprocessing.load_data import load_train_data, load_test_data, save_images
from pipeline.load_data import load_train_data
from models.losses import generator_loss, discriminator_loss, cycle_consistency_loss, identity_loss
from models.networks import Generator, Discriminator
from utils.image_history_buffer import ImageHistoryBuffer
Expand Down Expand Up @@ -72,67 +70,28 @@ def restore_from_checkpoint(checkpoint, checkpoint_dir):
else:
print("No checkpoint found, initializing model.")

def define_model(initial_learning_rate, training=True):
def define_model(initial_learning_rate=0.0002, training=True):
if not training:
genA2B = Generator(img_size=img_size)
genB2A = Generator(img_size=img_size)
genA2B = Generator(num_gen_filters, img_size=img_size)
genB2A = Generator(num_gen_filters, img_size=img_size)
return {'genA2B':genA2B, 'genB2A':genB2A}
else:
discA = Discriminator()
discB = Discriminator()
genA2B = Generator(img_size=img_size)
genB2A = Generator(img_size=img_size)
discA = Discriminator(num_disc_filters)
discB = Discriminator(num_disc_filters)
genA2B = Generator(num_gen_filters, img_size=img_size)
genB2A = Generator(num_gen_filters, img_size=img_size)
learning_rate = tf.contrib.eager.Variable(initial_learning_rate, dtype=tf.float32, name='learning_rate')
discA_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
discB_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
genA2B_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
genB2A_opt = tf.train.AdamOptimizer(learning_rate, beta1=0.5)

nets = {'discA':discA, 'discB':discB, 'genA2B':genA2B, 'genB2A':genB2A}
optimizers = {'discA_opt':discA_opt, 'discB_opt':discB_opt, 'genA2B_opt':genA2B_opt,
'genB2A_opt':genB2A_opt, 'learning_rate':learning_rate}
return nets, optimizers

def test(data, model, checkpoint_info, dataset_id):
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
generatedA = os.path.join(path_to_dataset, 'generatedA' + os.sep)
generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
genA2B = model['genA2B']
genB2A = model['genB2A']

checkpoint, checkpoint_dir = checkpoint_info
restore_from_checkpoint(checkpoint, checkpoint_dir)
test_datasetA, test_datasetB, testA_size, testB_size = data
test_datasetA = iter(test_datasetA)
test_datasetB = iter(test_datasetB)

for imageB in range(testB_size):
start = time.time()
try:
# Get next testing image:
testB = test_datasetB.get_next()
except tf.errors.OutOfRangeError:
print("Error, run out of data")
break
genB2A_output = genB2A(testB, training=False)
with tf.device("/cpu:0"):
save_images(genB2A_output, save_dir=generatedA, image_index=imageB)
print("Generating {} test A images finished in {} sec\n".format(testA_size, time.time()-start))

for imageA in range(testA_size):
start = time.time()
try:
# Get next testing image:
testA = test_datasetA.get_next()
except tf.errors.OutOfRangeError:
print("Error, run out of data")
break
genA2B_output = genA2B(testA, training=False)
with tf.device("/cpu:0"):
save_images(genA2B_output, save_dir=generatedB, image_index=imageA)
print("Generating {} test B images finished in {} sec\n".format(testB_size, time.time()-start))

def train(data, model, checkpoint_info, epochs, initial_learning_rate=initial_learning_rate):
def train(data, model, checkpoint_info, epochs):
nets, optimizers = model
discA = nets['discA']
discB = nets['discB']
Expand Down
70 changes: 59 additions & 11 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,63 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import time

import tensorflow as tf

from cyclegan import define_checkpoint, define_model, restore_from_checkpoint
from pipeline.load_data import load_test_data, save_images

"""Pytorch testing ground"""
dataset_id = 'horse2zebra'
project_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
testA_path = os.path.join(path_to_dataset, 'testA')
img_path = testA_path + os.sep + 'n02381460_20.jpg'
checkpoint_dir = os.path.join(project_dir, 'saved_models', 'checkpoints')
dataset_id = 'horse2zebra'
initial_learning_rate = 0.0002

def test(data, model, checkpoint_info, dataset_id):
path_to_dataset = os.path.join(project_dir, 'data', 'raw', dataset_id + os.sep)
generatedA = os.path.join(path_to_dataset, 'generatedA' + os.sep)
generatedB = os.path.join(path_to_dataset, 'generatedB' + os.sep)
genA2B = model['genA2B']
genB2A = model['genB2A']

checkpoint, checkpoint_dir = checkpoint_info
restore_from_checkpoint(checkpoint, checkpoint_dir)
test_datasetA, test_datasetB, testA_size, testB_size = data
test_datasetA = iter(test_datasetA)
test_datasetB = iter(test_datasetB)

for imageB in range(testB_size):
start = time.time()
try:
# Get next testing image:
testB = test_datasetB.get_next()
except tf.errors.OutOfRangeError:
print("Error, run out of data")
break
genB2A_output = genB2A(testB)
with tf.device("/cpu:0"):
save_images(genB2A_output, save_dir=generatedA, image_index=imageB)
print("Generating {} test A images finished in {} sec\n".format(testA_size, time.time()-start))

for imageA in range(testA_size):
start = time.time()
try:
# Get next testing image:
testA = test_datasetA.get_next()
except tf.errors.OutOfRangeError:
print("Error, run out of data")
break
genA2B_output = genA2B(testA)
with tf.device("/cpu:0"):
save_images(genA2B_output, save_dir=generatedB, image_index=imageA)
print("Generating {} test B images finished in {} sec\n".format(testB_size, time.time()-start))

A = Image.open(img_path).convert('RGB')
A = transforms.ToTensor()(A)
if __name__ == "__main__":
with tf.device("/cpu:0"): # Preprocess data on CPU for significant performance gains.
data = load_test_data(dataset_id, project_dir)
with tf.device("/gpu:0"):
model = define_model(initial_learning_rate, training=False)
checkpoint_info = define_checkpoint(checkpoint_dir, model, training=False)
test(data, model, checkpoint_info, dataset_id)

0 comments on commit d9ec247

Please sign in to comment.