Skip to content

Commit

Permalink
Made sure that the nodes listed as feed, fetch and init_op exist in t…
Browse files Browse the repository at this point in the history
…he graph.

PiperOrigin-RevId: 159034290
  • Loading branch information
benoitsteiner authored and tensorflower-gardener committed Jun 14, 2017
1 parent 69bc160 commit bea7255
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/grappler/grappler_item.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
std::vector<const NodeDef*> queue;
for (const string& root : terminal_nodes) {
const NodeDef* node = name_to_node[NodeName(root)];
CHECK(node);
CHECK(node) << "Unknown root " << root;
queue.push_back(node);
}

Expand Down
23 changes: 23 additions & 0 deletions tensorflow/core/grappler/grappler_item_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,29 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
return nullptr;
}

// Validate feed, fetch and init nodes
std::unordered_set<string> nodes;
for (const auto& node : new_item->graph.node()) {
nodes.insert(node.name());
}
for (const auto& feed : new_item->feed) {
if (nodes.find(feed.first) == nodes.end()) {
LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
return nullptr;
}
}
for (const auto& fetch : new_item->fetch) {
if (nodes.find(fetch) == nodes.end()) {
LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
return nullptr;
}
}
for (const auto& init : new_item->init_ops) {
if (nodes.find(init) == nodes.end()) {
LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
return nullptr;
}
}
return new_item;
}

Expand Down
11 changes: 8 additions & 3 deletions tensorflow/core/grappler/grappler_item_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ void SampleSumSymbolicGradientGraphdef(
auto g0 = SymbolicGradient(scope, std::initializer_list<Input>{x, y, z},
{DT_FLOAT, DT_INT32}, fn);

fetches->mutable_node_list()->add_value(g0[0].name());
// TODO(bsteiner): we should rewrite the feed/fetch nodes to reflect the
// inlining that's done in the item builder
// fetches->mutable_node_list()->add_value(g0[0].name());
fetches->mutable_node_list()->add_value("SymbolicGradient/dx");
fetches->mutable_node_list()->add_value("SymbolicGradient/dy_reshaped");

TF_CHECK_OK(scope.ToGraphDef(def));

Expand Down Expand Up @@ -109,11 +113,12 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) {
std::unique_ptr<GrapplerItem> with_inline = CreateGrapplerItem(def, fetches);

// For the inlined graph, there should be 0 symbolic gradient ops.
CHECK_EQ(0, CountSymbolicGradientOps(with_inline));
EXPECT_EQ(0, CountSymbolicGradientOps(with_inline));

// For the inlined graph, make sure all the required expanded op’s are in the
// graph.
CHECK_EQ(ops_of_inline.size(), CountOpsWithNames(with_inline, ops_of_inline));
EXPECT_EQ(ops_of_inline.size(),
CountOpsWithNames(with_inline, ops_of_inline));
}

} // namespace
Expand Down

0 comments on commit bea7255

Please sign in to comment.