Skip to content

Commit

Permalink
Merge pull request sjvasquez#2 from cclauss/modernize-python2-code
Browse files Browse the repository at this point in the history
Modernize Python 2 code to get ready for Python 3
  • Loading branch information
Sean Vasquez committed Feb 27, 2018
2 parents 0a769ee + 66c86ab commit de4cfde
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
3 changes: 2 additions & 1 deletion drawing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from collections import defaultdict

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -209,7 +210,7 @@ def draw(

if save_file is not None:
plt.savefig(save_file)
print 'saved to {}'.format(save_file)
print('saved to {}'.format(save_file))
else:
plt.show()
plt.close('all')
9 changes: 5 additions & 4 deletions prepare_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
from xml.etree import ElementTree

Expand Down Expand Up @@ -54,7 +55,7 @@ def collect_data():

stroke_fnames, transcriptions, writer_ids = [], [], []
for i, fname in enumerate(fnames):
print i, fname
print(i, fname)
if fname == 'data/raw/ascii/z01/z01-000/z01-000z.txt':
continue

Expand Down Expand Up @@ -98,10 +99,10 @@ def collect_data():


if __name__ == '__main__':
print 'traversing data directory...'
print('traversing data directory...')
stroke_fnames, transcriptions, writer_ids = collect_data()

print 'dumping to numpy arrays...'
print('dumping to numpy arrays...')
x = np.zeros([len(stroke_fnames), drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
x_len = np.zeros([len(stroke_fnames)], dtype=np.int16)
c = np.zeros([len(stroke_fnames), drawing.MAX_CHAR_LEN], dtype=np.int8)
Expand All @@ -111,7 +112,7 @@ def collect_data():

for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
if i % 200 == 0:
print i, '\t', '/', len(stroke_fnames)
print(i, '\t', '/', len(stroke_fnames))
x_i = get_stroke_sequence(stroke_fname)
valid_mask[i] = ~np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

Expand Down
7 changes: 4 additions & 3 deletions rnn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os

import numpy as np
Expand All @@ -20,9 +21,9 @@ def __init__(self, data_dir):
self.test_df = DataFrame(columns=data_cols, data=data)
self.train_df, self.val_df = self.test_df.train_test_split(train_size=0.95, random_state=2018)

print 'train size', len(self.train_df)
print 'val size', len(self.val_df)
print 'test size', len(self.test_df)
print('train size', len(self.train_df))
print('val size', len(self.val_df))
print('test size', len(self.test_df))

def train_batch_generator(self, batch_size):
return self.batch_generator(
Expand Down
30 changes: 17 additions & 13 deletions tf_base_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from collections import deque
from datetime import datetime
import logging
Expand Down Expand Up @@ -146,7 +147,7 @@ def fit(self):

# validation evaluation
val_start = time.time()
val_batch_df = val_generator.next()
val_batch_df = next(val_generator)
val_feed_dict = {
getattr(self, placeholder_name, None): data
for placeholder_name, data in val_batch_df.items() if hasattr(self, placeholder_name)
Expand All @@ -173,19 +174,19 @@ def fit(self):
if hasattr(self, 'monitor_tensors'):
for name, tensor in self.monitor_tensors.items():
[np_val] = self.session.run([tensor], feed_dict=val_feed_dict)
print name
print 'min', np_val.min()
print 'max', np_val.max()
print 'mean', np_val.mean()
print 'std', np_val.std()
print 'nans', np.isnan(np_val).sum()
print
print
print
print(name)
print('min', np_val.min())
print('max', np_val.max())
print('mean', np_val.mean())
print('std', np_val.std())
print('nans', np.isnan(np_val).sum())
print()
print()
print()

# train step
train_start = time.time()
train_batch_df = train_generator.next()
train_batch_df = next(train_generator)
train_feed_dict = {
getattr(self, placeholder_name, None): data
for placeholder_name, data in train_batch_df.items() if hasattr(self, placeholder_name)
Expand Down Expand Up @@ -272,7 +273,7 @@ def predict(self, chunk_size=256):
test_generator = self.reader.test_batch_generator(chunk_size)
for i, test_batch_df in enumerate(test_generator):
if i % 10 == 0:
print i*len(test_batch_df)
print(i*len(test_batch_df))

test_feed_dict = {
getattr(self, placeholder_name, None): data
Expand Down Expand Up @@ -337,7 +338,10 @@ def init_logging(self, log_dir):
date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
log_file = 'log_{}.txt'.format(date_str)

reload(logging) # bad
try:  # Python 2
    reload(logging)  # bad
except NameError:  # Python 3: builtin reload was removed; a bare re-import is a no-op
    from importlib import reload
    reload(logging)
logging.basicConfig(
filename=os.path.join(log_dir, log_file),
level=self.logging_level,
Expand Down

0 comments on commit de4cfde

Please sign in to comment.