Skip to content

Commit

Permalink
Python 3 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
sergomezcol committed Feb 9, 2017
1 parent 2cff59a commit caa1448
Show file tree
Hide file tree
Showing 14 changed files with 24 additions and 11 deletions.
1 change: 1 addition & 0 deletions convergence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf

import meta
Expand Down
1 change: 1 addition & 0 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf

from tensorflow.contrib.learn.python.learn import monitored_session as ms
Expand Down
7 changes: 4 additions & 3 deletions meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def _make_nets(variables, config, net_assignments):
"a single net config.")

with tf.variable_scope("vars_optimizer"):
key, kwargs = config.items()[0]
key = next(iter(config))
kwargs = config[key]
net = networks.factory(**kwargs)

nets = {key: net}
Expand Down Expand Up @@ -246,7 +247,7 @@ def __init__(self, **kwargs):
def save(self, sess, path=None):
"""Save meta-optimizer."""
result = {}
for k, net in self._nets.iteritems():
for k, net in self._nets.items():
if path is None:
filename = None
key = k
Expand Down Expand Up @@ -376,7 +377,7 @@ def time_step(t, fx_array, x, state):
nest.flatten(_nested_assign(state, s_final)))

# Log internal variables.
for k, net in nets.iteritems():
for k, net in nets.items():
print("Optimizer '{}' variables".format(k))
print([op.name for op in nn.get_variables_in_module(net)])

Expand Down
6 changes: 4 additions & 2 deletions meta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from nose_parameterized import parameterized
import numpy as np
from six.moves import xrange
import tensorflow as tf

import meta
Expand Down Expand Up @@ -62,7 +63,7 @@ def testResults(self):

# Torch results
torch_cost = 0.7325327
torch_final_x = np.array([0.8559])
torch_final_x = 0.8559

self.assertAlmostEqual(cost, torch_cost, places=4)
self.assertAlmostEqual(final_x[0], torch_final_x, places=4)
Expand Down Expand Up @@ -208,7 +209,8 @@ def testSaveAndLoad(self):

# Save optimizer.
tmp_dir = tempfile.mkdtemp()
net_path = optimizer.save(sess, path=tmp_dir).keys()[0]
save_result = optimizer.save(sess, path=tmp_dir)
net_path = next(iter(save_result))

# Retrain original optimizer.
cost, x = train(sess, minimize_ops, num_unrolls, num_epochs)
Expand Down
4 changes: 2 additions & 2 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def factory(net, net_options=(), net_path=None):
net_options = dict(net_options)

if net_path:
with open(net_path, "r") as f:
with open(net_path, "rb") as f:
net_options["initializer"] = pickle.load(f)

return net_class(**net_options)
Expand All @@ -56,7 +56,7 @@ def save(network, sess, filename=None):
to_save[module_name][variable_name] = v.eval(sess)

if filename:
with open(filename, "w") as f:
with open(filename, "wb") as f:
pickle.dump(to_save, f)

return to_save
Expand Down
5 changes: 3 additions & 2 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from __future__ import print_function

import abc
import types
import six
from six import string_types
from six.moves import xrange
import tensorflow as tf


Expand Down Expand Up @@ -116,7 +117,7 @@ def __init__(self, name):
Raises:
ValueError: If name is not specified.
"""
if not isinstance(name, types.StringTypes):
if not isinstance(name, string_types):
raise ValueError("Name must be a string.")
self._is_connected = False
self._template = tf.make_template(name, self._build,
Expand Down
1 change: 1 addition & 0 deletions nn/batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf

from tensorflow.contrib.layers.python.layers import utils
Expand Down
1 change: 1 addition & 0 deletions nn/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import collections

from six.moves import xrange
import tensorflow as tf

from nn import base
Expand Down
1 change: 1 addition & 0 deletions nn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import collections

from six.moves import xrange
import tensorflow as tf

from nn import base
Expand Down
1 change: 1 addition & 0 deletions nn/rnn_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


import six
from six.moves import xrange
import tensorflow as tf

from tensorflow.python.framework import tensor_shape
Expand Down
4 changes: 2 additions & 2 deletions nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def check_initializers(initializers, keys):
", ".join("'{}'".format(key) for key in keys)))

def check_nested_callables(dictionary):
for key, entry in dictionary.iteritems():
for key, entry in dictionary.items():
if isinstance(entry, dict):
check_nested_callables(entry)
elif not callable(entry):
Expand Down Expand Up @@ -156,7 +156,7 @@ def check_partitioners(partitioners, keys):
", ".join("'{}'".format(key) for key in keys)))

def check_nested_callables(dictionary):
for key, entry in dictionary.iteritems():
for key, entry in dictionary.items():
if isinstance(entry, dict):
check_nested_callables(entry)
elif not callable(entry):
Expand Down
1 change: 1 addition & 0 deletions problems_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf

from nose_parameterized import parameterized
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import os

from six.moves import xrange
import tensorflow as tf

from tensorflow.contrib.learn.python.learn import monitored_session as ms
Expand Down
1 change: 1 addition & 0 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from timeit import default_timer as timer

import numpy as np
from six.moves import xrange

import problems

Expand Down

0 comments on commit caa1448

Please sign in to comment.