Skip to content

Commit

Permalink
Minor bugfixes, refactor options class
Browse files Browse the repository at this point in the history
  • Loading branch information
herbiebradley committed Mar 16, 2019
1 parent 9826590 commit bd5a280
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import tensorflow as tf

def Dataset(object):
class Dataset(object):

def __init__(self, opt):
self.opt = opt
Expand Down
2 changes: 1 addition & 1 deletion src/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, opt):
self.norm = False
self.dropout = tf.keras.layers.Dropout(opt.dropout_prob)
if self.resize_conv:
self.upsample = tf.keras.layers.Upsampling2D(size=(2, 2), interpolation='nearest')
self.upsample = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
self.conv1 = tf.keras.layers.Conv2D(opt.ngf * 2, kernel_size=3, strides=1,
kernel_initializer=tf.truncated_normal_initializer(stddev=opt.init_scale))
self.conv2 = tf.keras.layers.Conv2D(opt.ngf, kernel_size=3, strides=1,
Expand Down
2 changes: 1 addition & 1 deletion src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
tf.enable_eager_execution()

if __name__ == "__main__":
opt = Options(training=False)
opt = Options().parse(training=False)
# TODO: Test if this is always on CPU:
dataset = Dataset(opt)
model = CycleGANModel(opt)
Expand Down
2 changes: 1 addition & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
tf.enable_eager_execution()

if __name__ == "__main__":
opt = Options(training=True)
opt = Options().parse(training=True)
# TODO: Test if this is always on CPU:
dataset = Dataset(opt)
model = CycleGANModel(opt)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/image_history_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ImageHistoryBuffer(object):
image_history_buffer: Numpy array of image batches used to calculate average loss.
"""
def __init__(self, opt):
self.max_buffer_size = opt.max_buffer_size
self.max_buffer_size = opt.buffer_size
self.batch_size = opt.batch_size
self.image_history_buffer = np.zeros((0, opt.img_size, opt.img_size, 3))
assert(self.batch_size >= 1)
Expand Down
20 changes: 11 additions & 9 deletions src/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

class Options(object):

def __init__(self, training=True):
def __init__(self):
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# basic options
parser.add_argument('--data_dir', required=True, help='path to directory where the dataset is stored, should have subfolders trainA, trainB, testA, testB')
parser.add_argument('--save_dir', required=True, help='checkpoints and Tensorboard summaries are saved here')
parser.add_argument('--gpu_id', type=str, default='0', help='gpu id to run model on, use -1 for CPU, multigpu not supported')
parser.add_argument('--training', action='store_false', help='boolean for training/testing')
# model options
parser.add_argument('--ngf', type=int, default=64, help='number of gen filters in the first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='number of disc filters in the first conv layer')
Expand All @@ -22,21 +23,23 @@ def __init__(self, training=True):
cpu_count = multiprocessing.cpu_count()
parser.add_argument('--num_threads', type=int, default=cpu_count, help='number of CPU threads to use for loading data')
parser.add_argument('--img_size', type=int, default=256, help='input image size')
self.parser = parser

def parse(self, training):
opt, _ = self.parser.parse_known_args()
# get training/testing options
if training:
parser = self.get_train_options(parser)
self.parser.set_defaults(training=True)
self.parser = self.get_train_options(self.parser)
else:
parser = self.get_test_options(parser)

self.parser = parser
opt = parser.parse_args()
self.parser = self.get_test_options(self.parser)

opt = self.parser.parse_args()
self.print_options(opt)
return opt

def get_train_options(self, parser):
# training specific options
parser.add_argument('--training', action='store_true', help='boolean for training/testing')
parser.add_argument('--load_checkpoint', action='store_true', help='if true, loads latest checkpoint')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--summary_freq', type=int, default=100, help='frequency of saving saving tensorboard summaries in training steps')
Expand All @@ -54,7 +57,6 @@ def get_train_options(self, parser):

def get_test_options(self, parser):
# test specific options
parser.add_argument('--training', action='store_false', help='boolean for training/testing')
parser.add_argument('--results_dir', required=True, help='directory to save result images')
parser.add_argument('--num_test', type=int, default=50, help='number of test images to generate')
return parser
Expand All @@ -65,7 +67,7 @@ def print_options(self, opt):
for option, value in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(option)
if v != default:
if value != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(option), str(value), comment)
message += '----------------- End -------------------'
Expand Down

0 comments on commit bd5a280

Please sign in to comment.