Skip to content

Commit

Permalink
Add comments to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Feb 15, 2018
1 parent 7117de2 commit 45f58b3
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions codebase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,60 @@
import numpy as np

def u2t(x):
"""Convert uint8 to [-1, 1] float
"""
return x.astype('float32') / 255 * 2 - 1

def s2t(x):
"""Convert [0, 1] float to [-1, 1] float
"""
return x * 2 - 1

def delete_existing(path):
"""Delete directory if it exists
Used for automatically rewrites existing log directories
"""
if args.run < 999:
assert not os.path.exists(path), "Cannot overwrite {:s}".format(path)

else:
if os.path.exists(path):
shutil.rmtree(path)

def save_accuracy(M, fn_acc_key, tag, dataloader,
def save_accuracy(M, fn_acc_key, tag, data,
train_writer=None, step=None, print_list=None,
full=True):
"""Log the accuracy of model on data to tf.summary.FileWriter
M - (TensorDict) the model
fn_acc_key - (str) key to query the relevant function in M
tag - (str) summary tag for FileWriter
data - (Data) data object with images/labels attributes
train_writer - (FileWriter)
step - (int) global step in file writer
print_list - (list) list of vals to print to stdout
full - (bool) use full dataset v. first 1000 samples
"""
fn_acc = getattr(M, fn_acc_key, None)
if fn_acc:
acc, summary = exact_accuracy(fn_acc, tag, dataloader, full)
acc, summary = compute_accuracy(fn_acc, tag, data, full)
train_writer.add_summary(summary, step + 1)
print_list += [os.path.basename(tag), acc]

def exact_accuracy(fn_acc, tag, dataloader, full=True):
# Fixed shuffling scheme
state = np.random.get_state()
np.random.seed(0)
shuffle = np.random.permutation(len(dataloader.images))
np.random.set_state(state)
def compute_accuracy(fn_acc, tag, data, full=True):
"""Compute accuracy w.r.t. data
fn_acc - (fn) Takes (x, y) as input and returns accuracy
tag - (str) summary tag for FileWriter
data - (Data) data object with images/labels attributes
full - (bool) use full dataset v. first 1000 samples
"""
with tb.nputils.FixedSeed(0):
shuffle = np.random.permutation(len(data.images))

xs = dataloader.images[shuffle]
ys = dataloader.labels[shuffle] if dataloader.labels is not None else None
xs = data.images[shuffle]
ys = data.labels[shuffle] if data.labels is not None else None

if not full:
xs = xs[:1000]
Expand All @@ -47,9 +70,9 @@ def exact_accuracy(fn_acc, tag, dataloader, full=True):
bs = 200

for i in xrange(0, n, bs):
x = u2t(xs[i:i+bs]) if dataloader.cast else xs[i:i+bs]
y = ys[i:i+bs] if ys is not None else dataloader.labeler(x)
acc += fn_acc(x, y) * len(x) / n
x = data.preprocess(xs[i:i+bs])
y = ys[i:i+bs] if ys is not None else data.labeler(x)
acc += fn_acc(x, y) / n * len(x)

summary = tf.Summary.Value(tag=tag, simple_value=acc)
summary = tf.Summary(value=[summary])
Expand Down

0 comments on commit 45f58b3

Please sign in to comment.