Skip to content

Commit

Permalink
Added proto function for LOCAL_VARIABLES and MODEL_VARIABLES.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 158739297
  • Loading branch information
sherrym authored and tensorflower-gardener committed Jun 12, 2017
1 parent b4aa475 commit 2d1823f
Show file tree
Hide file tree
Showing 5 changed files with 1,646 additions and 2 deletions.
9 changes: 9 additions & 0 deletions tensorflow/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -803,10 +803,18 @@ py_test(
],
)

filegroup(
name = "meta_graph_testdata",
srcs = [
"framework/testdata/metrics_export_meta_graph.pb",
],
)

py_test(
name = "framework_meta_graph_test",
size = "small",
srcs = ["framework/meta_graph_test.py"],
data = ["//tensorflow/python:meta_graph_testdata"],
main = "framework/meta_graph_test.py",
srcs_version = "PY2AND3",
deps = [
Expand All @@ -817,6 +825,7 @@ py_test(
":framework",
":framework_for_generated_wrappers",
":math_ops",
":metrics",
":nn_ops",
":platform",
":random_ops",
Expand Down
13 changes: 11 additions & 2 deletions tensorflow/python/framework/meta_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
# Prefix to be added to unbound input names so they are easily identifiable.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# List of collections that didn't register proto functions, as a result in
# a previously exported meta_graph the items are of a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
ops.GraphKeys.MODEL_VARIABLES]


def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
"""Create a `NodeDef` proto with export_scope stripped.
Expand Down Expand Up @@ -667,8 +672,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
key)
continue
from_proto = ops.get_from_proto_function(key)
if from_proto:
assert kind == "bytes_list"
if from_proto and kind == "bytes_list":
proto_type = ops.get_collection_proto_type(key)
for value in col_def.bytes_list.value:
proto = proto_type()
Expand All @@ -677,6 +681,11 @@ def import_scoped_meta_graph(meta_graph_or_file,
key, from_proto(proto, import_scope=scope_to_prepend_to_names))
else:
field = getattr(col_def, kind)
if key in _COMPAT_COLLECTION_LIST:
logging.warning(
"The saved meta_graph is possibly from an older release:\n"
"'%s' collection should be of type 'byte_list', but instead "
"is of type '%s'.", key, kind)
if kind == "node_list":
for value in field.value:
col_op = graph.as_graph_element(
Expand Down
58 changes: 58 additions & 0 deletions tensorflow/python/framework/meta_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
Expand Down Expand Up @@ -599,5 +600,62 @@ def testClearDevices(self):
self.assertEqual("", str(graph2.as_graph_element("matmul").device))


class MetaGraphWithVariableScopeTest(test.TestCase):

def testMetricsCollection(self):

def _enqueue_vector(sess, queue, values, shape=None):
if not shape:
shape = (1, len(values))
dtype = queue.dtypes[0]
sess.run(
queue.enqueue(constant_op.constant(
values, dtype=dtype, shape=shape)))

meta_graph_filename = os.path.join(
_TestDir("metrics_export"), "meta_graph.pb")

graph = ops.Graph()
with self.test_session(graph=graph) as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
_enqueue_vector(sess, values_queue, [-4.2, 9.1])
_enqueue_vector(sess, values_queue, [6.5, 0])
_enqueue_vector(sess, values_queue, [-3.2, 4.0])
values = values_queue.dequeue()

_, update_op = metrics.mean(values)

initializer = variables.local_variables_initializer()
sess.run(initializer)
sess.run(update_op)

meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)

# Verifies that importing a meta_graph with LOCAL_VARIABLES collection
# works correctly.
graph = ops.Graph()
with self.test_session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(meta_graph_filename)
initializer = variables.local_variables_initializer()
sess.run(initializer)

# Verifies that importing an old meta_graph where "local_variables"
# collection is of node_list type works, but cannot build initializer
# with the collection.
graph = ops.Graph()
with self.test_session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(
test.test_src_dir_path(
"python/framework/testdata/metrics_export_meta_graph.pb"))
self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
2)
with self.assertRaisesRegexp(
AttributeError, "'Tensor' object has no attribute 'initializer'"):
initializer = variables.local_variables_initializer()


if __name__ == "__main__":
test.main()
Loading

0 comments on commit 2d1823f

Please sign in to comment.