From e85d3df92deb9d717befdf173966a2913ac2aea0 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Thu, 29 Jun 2017 11:44:13 -0700 Subject: [PATCH] Prepare to remove a bunch of proto.h includes from tensorflow/core headers The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. This CL does not remove any actual headers, but changes a bunch of files so that header removal is possible in a followup CL. It also marks the headers that will be removed with // TODO(b/62899350): Remove RELNOTES: n/a PiperOrigin-RevId: 160552878 --- tensorflow/c/BUILD | 1 + tensorflow/c/c_api.cc | 11 +++ tensorflow/c/c_api_internal.h | 9 +-- tensorflow/c/c_api_test.cc | 1 + tensorflow/cc/BUILD | 2 + tensorflow/cc/framework/cc_op_gen.cc | 4 + tensorflow/cc/framework/gradients_test.cc | 1 + tensorflow/cc/framework/scope.cc | 2 +- tensorflow/compiler/aot/compile.cc | 1 + tensorflow/compiler/jit/BUILD | 1 + tensorflow/compiler/jit/xla_device.cc | 2 + .../compiler/tf2xla/kernels/const_op.cc | 1 + tensorflow/compiler/tf2xla/test_util.h | 1 + tensorflow/compiler/tf2xla/xla_op_registry.cc | 1 + ...single_image_random_dot_stereograms_ops.cc | 14 ++-- .../kernels/sparse_feature_cross_kernel.cc | 1 + tensorflow/contrib/tpu/ops/infeed_ops.cc | 8 +- .../convert_graphdef_memmapped_format_lib.cc | 2 + tensorflow/core/BUILD | 8 +- .../core/common_runtime/direct_session.cc | 1 + tensorflow/core/common_runtime/function.cc | 1 + .../core/common_runtime/function_test.cc | 1 + .../core/common_runtime/gpu/gpu_util.cc | 1 + .../core/common_runtime/graph_runner.cc | 1 + .../kernel_benchmark_testlib.cc | 1 + .../core/common_runtime/shape_refiner.cc | 5 ++ .../core/common_runtime/shape_refiner.h | 4 + .../simple_graph_execution_state.cc | 1 + .../common_runtime/step_stats_collector.cc | 2 + tensorflow/core/debug/debug_io_utils.cc | 1 + tensorflow/core/distributed_runtime/BUILD | 2 + .../core/distributed_runtime/executor_test.cc | 1 + .../core/distributed_runtime/graph_mgr.cc | 1 + .../distributed_runtime/master_session.cc | 2 + tensorflow/core/distributed_runtime/rpc/BUILD | 1 + .../rpc/grpc_tensor_coding.cc | 2 + .../core/distributed_runtime/tensor_coding.cc | 2 + .../distributed_runtime/tensor_coding_test.cc | 1 + .../example/example_parser_configuration.cc | 2 + tensorflow/core/framework/allocator.cc | 8 ++ tensorflow/core/framework/attr_value_util.cc | 3 + tensorflow/core/framework/attr_value_util.h | 10 ++- .../core/framework/attr_value_util_test.cc | 1 + .../core/framework/common_shape_fns_test.cc | 14 +--- tensorflow/core/framework/device_base.cc | 4 + tensorflow/core/framework/device_base.h | 13 ++-- tensorflow/core/framework/fake_input.cc | 1 + tensorflow/core/framework/function.cc | 1 + tensorflow/core/framework/function.h | 4 +- tensorflow/core/framework/graph_def_util.cc | 2 + tensorflow/core/framework/graph_def_util.h | 6 +- .../core/framework/kernel_def_builder.cc | 13 +++- .../core/framework/kernel_def_builder.h | 16 ++-- tensorflow/core/framework/memory_types.h | 4 +- tensorflow/core/framework/node_def_builder.cc | 71 ++++++++++++++++-- tensorflow/core/framework/node_def_builder.h | 75 +++++++++---------- tensorflow/core/framework/node_def_util.cc | 1 + tensorflow/core/framework/node_def_util.h | 5 +- tensorflow/core/framework/op.h | 2 +- tensorflow/core/framework/op_def_builder.cc | 1 + tensorflow/core/framework/op_def_util.cc | 1 + tensorflow/core/framework/op_gen_lib.cc | 9 ++- tensorflow/core/framework/op_gen_lib.h | 13 +++- tensorflow/core/framework/op_kernel.cc | 1 + tensorflow/core/framework/op_kernel.h | 16 ++-- tensorflow/core/framework/op_kernel_test.cc | 2 + .../framework/partial_tensor_shape_test.cc | 1 + tensorflow/core/framework/reader_base.cc | 1 + tensorflow/core/framework/reader_base.h | 4 +- tensorflow/core/framework/resource_mgr.cc | 1 + tensorflow/core/framework/resource_mgr.h | 3 +- .../core/framework/resource_mgr_test.cc | 1 + tensorflow/core/framework/shape_inference.cc | 53 +++++++++++++ tensorflow/core/framework/shape_inference.h | 22 +++++- .../core/framework/shape_inference_test.cc | 19 ++--- .../core/framework/shape_inference_testutil.h | 2 - tensorflow/core/framework/tensor.cc | 2 + tensorflow/core/framework/tensor.h | 13 +++- tensorflow/core/framework/tensor_reference.h | 1 - tensorflow/core/framework/tensor_shape.cc | 1 + tensorflow/core/framework/tensor_shape.h | 3 +- .../core/framework/tensor_shape_test.cc | 1 + tensorflow/core/framework/tensor_test.cc | 1 + tensorflow/core/framework/versions.cc | 1 + tensorflow/core/framework/versions.h | 4 +- tensorflow/core/graph/costmodel.cc | 2 + tensorflow/core/graph/graph.cc | 13 +++- tensorflow/core/graph/graph.h | 12 +-- tensorflow/core/graph/graph_constructor.cc | 2 + .../core/graph/graph_constructor_test.cc | 1 + .../core/graph/graph_def_builder_test.cc | 1 + tensorflow/core/graph/graph_partition.cc | 2 + tensorflow/core/graph/graph_partition_test.cc | 1 + tensorflow/core/graph/node_builder.cc | 1 + .../grappler/clusters/single_machine_test.cc | 1 + .../core/grappler/clusters/virtual_cluster.cc | 1 + .../grappler/clusters/virtual_cluster_test.cc | 1 + tensorflow/core/grappler/costs/BUILD | 4 + .../costs/analytical_cost_estimator.cc | 1 + .../core/grappler/costs/graph_properties.cc | 2 +- .../grappler/costs/graph_properties_test.cc | 1 + .../grappler/costs/op_level_cost_estimator.cc | 2 + tensorflow/core/grappler/costs/utils.cc | 1 + .../core/grappler/costs/virtual_scheduler.cc | 3 + .../grappler/costs/virtual_scheduler_test.cc | 2 + .../core/grappler/grappler_item_builder.cc | 2 + .../core/grappler/optimizers/auto_parallel.cc | 1 + .../grappler/optimizers/constant_folding.cc | 1 + .../grappler/optimizers/layout_optimizer.cc | 2 + .../grappler/optimizers/memory_optimizer.cc | 1 + tensorflow/core/kernels/BUILD | 2 +- tensorflow/core/kernels/decode_image_op.cc | 1 + tensorflow/core/kernels/gather_nd_op.cc | 1 + .../kernels/hexagon/graph_transfer_utils.cc | 1 + .../core/kernels/hexagon/graph_transferer.cc | 2 +- .../hexagon/hexagon_graph_execution_test.cc | 1 + tensorflow/core/kernels/identity_reader_op.cc | 2 + .../core/kernels/reduction_ops_common.cc | 2 + .../remote_fused_graph_execute_utils.cc | 13 ++-- .../remote_fused_graph_execute_utils_test.cc | 4 +- tensorflow/core/kernels/scatter_nd_op.cc | 1 + .../core/kernels/serialize_sparse_op.cc | 1 + tensorflow/core/kernels/sparse_cross_op.cc | 1 + tensorflow/core/kernels/unique_op_test.cc | 1 + .../core/kernels/whole_file_read_ops.cc | 2 + tensorflow/core/ops/array_ops.cc | 46 +++++------- tensorflow/core/ops/array_ops_test.cc | 2 + tensorflow/core/ops/math_ops_test.cc | 1 + tensorflow/core/ops/parsing_ops.cc | 14 ++-- tensorflow/core/ops/parsing_ops_test.cc | 17 ++--- tensorflow/core/ops/resource_variable_ops.cc | 4 +- tensorflow/core/ops/state_ops.cc | 16 ++-- tensorflow/core/ops/state_ops_test.cc | 21 ++---- tensorflow/core/util/equal_graph_def.cc | 2 + tensorflow/core/util/equal_graph_def.h | 5 +- .../core/util/memmapped_file_system_test.cc | 1 + tensorflow/core/util/padding.h | 4 +- tensorflow/core/util/stat_summarizer.cc | 10 +++ tensorflow/core/util/stat_summarizer.h | 8 +- .../core/util/tensor_bundle/tensor_bundle.cc | 1 + .../core/util/tensor_slice_writer_test.cc | 1 + .../adding_an_op/zero_out_op_kernel_2.cc | 1 + tensorflow/python/BUILD | 1 + .../python/framework/cpp_shape_inference.cc | 1 + tensorflow/python/grappler/model_analyzer.cc | 1 + tensorflow/python/lib/core/py_func.cc | 1 + .../fold_old_batch_norms_test.cc | 1 + .../strip_unused_nodes_test.cc | 1 + .../graph_transforms/summarize_graph_main.cc | 2 + .../tools/graph_transforms/transform_utils.h | 2 + 150 files changed, 555 insertions(+), 240 deletions(-) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3ab4e8efcdb5b0..9267ef77efb9ba 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -62,6 +62,7 @@ tf_cuda_library( "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", ], }), diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index c320fe94a1f72e..fb06389af2de4d 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -28,12 +28,15 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -1587,6 +1590,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, // TF_Graph functions --------------------------------------------------------- +TF_Graph::TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f2773ae20f8bcf..7e987a65f7b24e 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -56,13 +56,8 @@ struct TF_Library { }; struct TF_Graph { - TF_Graph() - : graph(tensorflow::OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) {} + TF_Graph(); + tensorflow::mutex mu; tensorflow::Graph graph GUARDED_BY(mu); diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 04540bd793dab3..35d6e295c2b8a8 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 9801add1dac9d9..b461a475c13df2 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -45,6 +45,7 @@ tf_cc_test( "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -432,6 +433,7 @@ cc_library_with_android_deps( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 71aa986f918de6..80dd272f6f9dd5 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,8 +18,12 @@ limitations under the License. #include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 6a249825812b4d..2aad9784808ea1 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 32c0822de69da7..1948dd4e46b932 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = - new ShapeRefiner(graph->versions().producer(), graph->op_registry()); + new ShapeRefiner(graph->versions(), graph->op_registry()); return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); } diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 317baa89715c93..59ff14600bc70a 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7ebd8422181981..8b2d0b7659a5e4 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -139,6 +139,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 5e336c5287bd9e..615e2230f42f63 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -31,9 +31,11 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index ad676e7a2bb3d3..9833323d851e00 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 362558bcfc0e85..e6e4ae92ed23f3 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 20adf300ec6a0a..d059c7a23ef295 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc index 6efcc29654fe94..9f0bf37aed3fc9 100755 --- a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc @@ -54,8 +54,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel { float normalize_min; float border_level; int number_colors; - ::tensorflow::TensorShapeProto output_image_shape; - ::tensorflow::TensorShapeProto output_data_window; + ::tensorflow::PartialTensorShape output_image_shape; + ::tensorflow::PartialTensorShape output_data_window; uint8 Cblack = 0; uint8 Cwhite = 255; @@ -109,15 +109,15 @@ class SingleImageRandomDotStereogramsOp : public OpKernel { input_Yvalue = input_tensor.shape().dim_size(0); // Y value is the number of rows - output_Ximage = output_image_shape.dim(0).size(); - output_Yimage = output_image_shape.dim(1).size(); - output_Cimage = output_image_shape.dim(2).size(); + output_Ximage = output_image_shape.dim_size(0); + output_Yimage = output_image_shape.dim_size(1); + output_Cimage = output_image_shape.dim_size(2); if (number_colors > 256) // Go to full color image output_Cimage = 3; - int data_Xwindow = output_data_window.dim(0).size(); - int data_Ywindow = output_data_window.dim(1).size(); + int data_Xwindow = output_data_window.dim_size(0); + int data_Ywindow = output_data_window.dim_size(1); int deltaX_border_image = output_Ximage - data_Xwindow; int deltaY_border_image = output_Yimage - data_Ywindow; diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 72df272af89543..badf9d486a2dcc 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc index be4d4f964936f7..c12e83137aa8f4 100644 --- a/tensorflow/contrib/tpu/ops/infeed_ops.cc +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -29,10 +29,8 @@ REGISTER_OP("InfeedDequeue") .SetShapeFn([](InferenceContext* c) { PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - TensorShapeProto shape_proto; - shape.AsProto(&shape_proto); ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); c->set_output(0, out); return Status::OK(); }) @@ -87,10 +85,8 @@ REGISTER_OP("InfeedDequeueTuple") std::vector shapes; TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); for (int i = 0; i < shapes.size(); ++i) { - TensorShapeProto shape_proto; - shapes[i].AsProto(&shape_proto); ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out)); c->set_output(i, out); } return Status::OK(); diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc index 3e387129eb04b8..2992a61ea8186c 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -15,11 +15,13 @@ limitations under the License. #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/immutable_constant_op.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 24b0f702b398f4..d4097f40320636 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -500,7 +500,6 @@ cc_library( # Generates library per group of ops. tf_gen_op_libs( op_lib_names = [ - "array_ops", "bitwise_ops", "candidate_sampling_ops", "control_flow_ops", @@ -534,6 +533,13 @@ tf_gen_op_libs( ], ) +tf_gen_op_libs( + op_lib_names = [ + "array_ops", + ], + deps = [":protos_all_cc"], +) + tf_gen_op_libs( op_lib_names = [ "audio_ops", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 4b951691fbdd8d..7bc48aba278cb2 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 99389968ee6950..e3cc97c9461483 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/gradients.h" #include "tensorflow/core/graph/graph_constructor.h" diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index dec6ca996aa460..b00bb453b1c85f 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index ae9e5aeaa3dda6..b69c1ae8fec21f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 74b2252c7c6a45..2ce1e8b4830281 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 4a5b88d5fda595..420dfe338efb47 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/notification.h" diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index a6204b9d0dbf1f..e61ea9d84cd717 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -39,6 +40,10 @@ ShapeRefiner::ShapeRefiner(int graph_def_version, ops_registry_(ops), graph_runner_(Env::Default()) {} +ShapeRefiner::ShapeRefiner(const VersionDef& versions, + const OpRegistryInterface* ops) + : ShapeRefiner(versions.producer(), ops) {} + ShapeRefiner::~ShapeRefiner() { // The lifetime of the tensors are bound to the GraphRunner, so the tensors // should be deleted before it. diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 1af7835392fcf2..21e58381a5f28e 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -36,6 +36,10 @@ class GraphProperties; class ShapeRefiner { public: ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); + + // Same as ShapeRefiner(versions.producer(), ops) + ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops); + ~ShapeRefiner(); // Performs validation of 'node' and runs 'node's shape function, diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index c00eb3a2fc6970..358550f6415ce1 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/subgraph.h" diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 9b43385d6f754c..d410a164eac00a 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/scanner.h" diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 69fc36778928fc..875a4763f80cda 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -27,6 +27,7 @@ limitations under the License. #endif #include "tensorflow/core/debug/debugger_event_metadata.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index f59e5f4dc218e2..3bea72c4a3a13c 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -153,6 +153,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", ], ) @@ -205,6 +206,7 @@ cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc index 17843ff6b060b6..1a4980a61b208a 100644 --- a/tensorflow/core/distributed_runtime/executor_test.cc +++ b/tensorflow/core/distributed_runtime/executor_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 69f5b7d944440f..205843d3429867 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_partition.h" diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index e98d9cecfce3bf..2b6e0c52680390 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -26,11 +26,13 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/scheduler.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/blocking_counter.h" diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index bd381dd10f1fbe..c7349f0dd7e82d 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -108,6 +108,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", "@grpc//:grpc++_unsecure", ], diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc index 90e311a4930795..0dd6b5c89eaa30 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc @@ -18,7 +18,9 @@ limitations under the License. #include "grpc++/support/slice.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_reference.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/io/proto_encode_helper.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index f98bd17ab9307a..94d54a2b16bb38 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -17,6 +17,8 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc index 540b76ada68eb8..52a057bdb2f95f 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc index 485cf6da4b0ad1..5660465c51adbd 100644 --- a/tensorflow/core/example/example_parser_configuration.cc +++ b/tensorflow/core/example/example_parser_configuration.cc @@ -17,9 +17,11 @@ limitations under the License. #include #include "tensorflow/core/example/feature.pb_text.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 943dcab36269db..e7092f549b21e9 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -48,6 +48,14 @@ constexpr size_t Allocator::kAllocatorAlignment; Allocator::~Allocator() {} +void RunResourceCtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle(); +} + +void RunResourceDtor(ResourceHandle* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); +} + // If true, cpu allocator collects more stats. static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index b18ce3decc0268..9fdb3da6a0d33b 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb_text.h" #include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/core/errors.h" @@ -287,6 +288,8 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { return ProtoParseFromString(to_parse, out); } +void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } + #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); } diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h index 0e25cec4abc078..08cc3b7158ef47 100644 --- a/tensorflow/core/framework/attr_value_util.h +++ b/tensorflow/core/framework/attr_value_util.h @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" // TODO(62899350): Remove #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -29,6 +29,10 @@ limitations under the License. namespace tensorflow { +// Forward declare protos so their symbols can be removed from .so exports +class AttrValue; +class NameAttrList; + // A human-readable rendering of attr_value, that is more concise than a // text-format proto. string SummarizeAttrValue(const AttrValue& attr_value); @@ -80,9 +84,7 @@ void SetAttrValue(gtl::ArraySlice value, AttrValue* out); void SetAttrValue(gtl::ArraySlice value, AttrValue* out); void SetAttrValue(gtl::ArraySlice value, AttrValue* out); -inline void SetAttrValue(const AttrValue& value, AttrValue* out) { - *out = value; -} +void SetAttrValue(const AttrValue& value, AttrValue* out); // Returns true if a and b have the same value. // NOTE: May return false negatives for tensor values. diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index c14ea9b322a244..5d30d327ae111a 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value_util.h" #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index d14e1dfee09cff..99d3ef3af1d6a8 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -26,19 +26,11 @@ namespace shape_inference { namespace { -TensorShapeProto S(std::initializer_list dims) { - PartialTensorShape shape(dims); - TensorShapeProto ret; - shape.AsProto(&ret); - return ret; +PartialTensorShape S(std::initializer_list dims) { + return PartialTensorShape(dims); } -TensorShapeProto Unknown() { - PartialTensorShape shape; - TensorShapeProto ret; - shape.AsProto(&ret); - return ret; -} +PartialTensorShape Unknown() { return PartialTensorShape(); } OpDef MakeOpDef(int num_inputs, int num_outputs) { OpRegistrationData op_reg_data; diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc index ea0ed3ccbbd977..f5bc24aafe3cdf 100644 --- a/tensorflow/core/framework/device_base.cc +++ b/tensorflow/core/framework/device_base.cc @@ -19,4 +19,8 @@ namespace tensorflow { DeviceBase::~DeviceBase() {} +const DeviceAttributes& DeviceBase::attributes() const { + LOG(FATAL) << "Device does not implement attributes()"; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 27fe28fe60a9bd..e1eb387d88b4c3 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -19,9 +19,9 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_attributes.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" @@ -44,10 +44,12 @@ class Stream; namespace tensorflow { class Device; +class DeviceAttributes; class Env; class EventMgr; class OpKernelContext; class ResourceMgr; +class TensorProto; namespace thread { class ThreadPool; @@ -194,11 +196,8 @@ class DeviceBase { DeviceContext* /*dc*/, Allocator* /*allocator*/) {} - virtual const DeviceAttributes& attributes() const { - LOG(FATAL) << "Device does not implement attributes()"; - static DeviceAttributes dummy; - return dummy; - } + // Unimplemented by default + virtual const DeviceAttributes& attributes() const; // Materializes the given TensorProto into 'tensor' stored in Device // memory. Most devices will want to override this. diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index 7a21dd5066c341..ad301a8aa4ba4b 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/fake_input.h" #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_util.h" diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index fe6e9a6cd60776..9d43aab5a5dc9a 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/function.pb_text.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 6c2da84790c021..b8d5b8797af31c 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_FUNCTION_H_ #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/selective_registration.h" @@ -33,6 +34,7 @@ limitations under the License. namespace tensorflow { class CancellationManager; +class GraphDef; class OpKernel; class ResourceMgr; class ScopedStepContainer; diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index aeedf4b0efce10..bd018b7243897a 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index 950737c39aae00..838c9fd4ce3190 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -17,13 +17,15 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ #include - -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +// Forward declare proto so that it's symbols can be removed from .so exports +class GraphDef; + // Produce a human-readable version of a GraphDef that is more concise // than a text-format proto. string SummarizeGraphDef(const GraphDef& graph_def); diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc index 6366ac5bebb304..eb86f18ff06c38 100644 --- a/tensorflow/core/framework/kernel_def_builder.cc +++ b/tensorflow/core/framework/kernel_def_builder.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb_text.h" -#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/kernel_def.pb.h" namespace tensorflow { @@ -24,6 +25,10 @@ KernelDefBuilder::KernelDefBuilder(const char* op_name) { kernel_def_->set_op(op_name); } +KernelDefBuilder::~KernelDefBuilder() { + DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; +} + KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { kernel_def_->set_device_type(device_type); return *this; @@ -61,4 +66,10 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) { return *this; } +const KernelDef* KernelDefBuilder::Build() { + KernelDef* r = kernel_def_; + kernel_def_ = nullptr; + return r; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h index 84657f8dbf339b..27f768c72fd1e5 100644 --- a/tensorflow/core/framework/kernel_def_builder.h +++ b/tensorflow/core/framework/kernel_def_builder.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ #define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ -#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" @@ -24,16 +24,16 @@ limitations under the License. namespace tensorflow { +// Forward declare proto so that kernels don't need to depend on it +class KernelDef; + // Builder class passed to the REGISTER_KERNEL_BUILDER() macro. class KernelDefBuilder { public: // Starts with just the name field set. // Caller MUST call Build() and take ownership of the result. explicit KernelDefBuilder(const char* op_name); - - ~KernelDefBuilder() { - DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; - } + ~KernelDefBuilder(); // Required: specify the type of device this kernel supports. // Returns *this. @@ -68,11 +68,7 @@ class KernelDefBuilder { // Returns a pointer to a KernelDef with fields set based on the // above calls to this instance. // Caller takes ownership of the result. - const KernelDef* Build() { - KernelDef* r = kernel_def_; - kernel_def_ = nullptr; - return r; - } + const KernelDef* Build(); private: KernelDef* kernel_def_; diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h index e35e22f5907b09..a82aea9f0763bb 100644 --- a/tensorflow/core/framework/memory_types.h +++ b/tensorflow/core/framework/memory_types.h @@ -16,12 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ #define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_ -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" namespace tensorflow { +class NodeDef; + // Returns into *{input,output}_memory_types the memory type of each // {input,output} tensor. // diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 9385d1266a90e5..f9cf6ce87359d6 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -83,6 +84,25 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { return *this; } +NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index, + DataType dt) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) { + Input(src.node, src.index, src.data_type); + return *this; +} + +// For inputs that take a list of tensors. +NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice src_list) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) ListInput(arg, src_list); + return *this; +} + void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node, int src_index, DataType dt) { @@ -228,14 +248,51 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const { } } -void NodeDefBuilder::CheckInconsistency(StringPiece attr_name, - const AttrValue& found, - const AttrValue& attr_value) { - if (!AreAttrValuesEqual(found, attr_value)) { - errors_.push_back(strings::StrCat( - "Inconsistent values for attr '", attr_name, "' ", - SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value))); +NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { + if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { + if (!AreAttrValuesEqual(*found, value)) { + errors_.push_back(strings::StrCat("Inconsistent values for attr '", name, + "' ", SummarizeAttrValue(*found), + " vs. ", SummarizeAttrValue(value))); + } + } else { + AddNodeAttr(name, value, &node_def_); } + return *this; } +#define ATTR(T) \ + NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ + return Attr(name, attr_value); \ + } +ATTR(StringPiece) +ATTR(const char*) +ATTR(int32) +ATTR(int64) +ATTR(float) +ATTR(double) +ATTR(bool) +ATTR(DataType) +ATTR(const PartialTensorShape&) +ATTR(const Tensor&) +ATTR(const TensorProto&) +ATTR(const NameAttrList&) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(const std::vector&) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +ATTR(gtl::ArraySlice) +#undef ATTR + } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index fd26d0ae6494a5..d7f1d36540ad87 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -19,7 +19,7 @@ limitations under the License. #include #include #include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -72,22 +72,11 @@ class NodeDefBuilder { // *and in the same order as the input_args appear in the OpDef.* // For inputs that take a single tensor. - NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) { - const OpDef::ArgDef* arg = NextArgDef(); - if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); - return *this; - } - NodeDefBuilder& Input(const NodeOut& src) { - Input(src.node, src.index, src.data_type); - return *this; - } + NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt); + NodeDefBuilder& Input(const NodeOut& src); // For inputs that take a list of tensors. - NodeDefBuilder& Input(gtl::ArraySlice src_list) { - const OpDef::ArgDef* arg = NextArgDef(); - if (arg != nullptr) ListInput(arg, src_list); - return *this; - } + NodeDefBuilder& Input(gtl::ArraySlice src_list); // To create inputs in tests, see fake_input.h. NodeDefBuilder& Input(FakeInputFunctor fake_input); @@ -100,13 +89,39 @@ class NodeDefBuilder { // Sets the attr, if not already set. If already set with a different // value, an error will be returned from Finalize(). + NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); + NodeDefBuilder& Attr(StringPiece name, StringPiece value); + NodeDefBuilder& Attr(StringPiece name, const char* value); + NodeDefBuilder& Attr(StringPiece name, int32 value); + NodeDefBuilder& Attr(StringPiece name, int64 value); + NodeDefBuilder& Attr(StringPiece name, float value); + NodeDefBuilder& Attr(StringPiece name, double value); + NodeDefBuilder& Attr(StringPiece name, bool value); + NodeDefBuilder& Attr(StringPiece name, DataType value); + NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value); + NodeDefBuilder& Attr(StringPiece name, const Tensor& value); + NodeDefBuilder& Attr(StringPiece name, const TensorProto& value); + NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, const std::vector& value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, + gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, + gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice value); + template - NodeDefBuilder& Attr(StringPiece attr_name, T&& value); - // Note: overload needed to allow {...} expressions for value. - template - NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list value) { - Attr>(attr_name, std::move(value)); - return *this; + NodeDefBuilder& Attr(StringPiece name, std::initializer_list value) { + return Attr(name, gtl::ArraySlice(value)); } // Finish building the NodeDef, returning any errors or setting @@ -152,9 +167,6 @@ class NodeDefBuilder { return input_arg->is_ref() ? MakeRefType(dt) : dt; } - void CheckInconsistency(StringPiece attr_name, const AttrValue& found, - const AttrValue& attr_value); - const OpDef* op_def_; NodeDef node_def_; int inputs_specified_; @@ -162,21 +174,6 @@ class NodeDefBuilder { std::vector errors_; }; -// IMPLEMENTATION ------------------------------------------------------------- - -template -NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) { - const AttrValue* found = AttrSlice(node_def_).Find(attr_name); - if (found == nullptr) { - AddNodeAttr(attr_name, std::forward(value), &node_def_); - } else { - AttrValue attr_value; - SetAttrValue(std::forward(value), &attr_value); - CheckInconsistency(attr_name, *found, attr_value); - } - return *this; -} - } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 79feb20d53a7c8..b98a6033d0a28b 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 5d4864db665541..a829243a75a024 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" @@ -30,8 +30,9 @@ namespace tensorflow { class Node; -// We forward declare NodeDef so that kernels don't need to depend on protos +// We forward declare protos so that kernels don't need to depend on them class NodeDef; +class OpDef; // Name of the attribute used to encode node colocation constraints. // diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index a4dd06de4538f5..1c63a6f4c0e1e4 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/selective_registration.h" diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 58a30a87a8f569..62b504691b2b54 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index c36e6dd653b481..2f25b6e18fc05d 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index da623ae5b25b3e..517120ecab37ef 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -71,6 +73,9 @@ bool ConsumeEquals(StringPiece* description) { return false; } +OpGenOverrideMap::OpGenOverrideMap() {} +OpGenOverrideMap::~OpGenOverrideMap() {} + Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) { std::vector v = str_util::Split(filenames, ","); for (const string& f : v) { @@ -86,7 +91,7 @@ Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) { OpGenOverrides all; protobuf::TextFormat::ParseFromString(contents, &all); for (const auto& one : all.op()) { - map_[one.name()] = one; + map_[one.name()].reset(new OpGenOverride(one)); } return Status::OK(); } @@ -142,7 +147,7 @@ const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const { // Look up const auto iter = map_.find(op_def->name()); if (iter == map_.end()) return nullptr; - const OpGenOverride& proto = iter->second; + const OpGenOverride& proto = *iter->second; // Apply overrides from `proto`. if (!proto.rename_to().empty()) { diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index e92dc8d92417e9..a74b651bb67b3a 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -18,14 +18,18 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/op_gen_overrides.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { +// Forward declare protos so their symbols can be removed from .so exports +class OpDef; +class OpGenOverride; + inline string Spaces(int n) { return string(n, ' '); } // Wrap prefix + str to be at most width characters, indenting every line @@ -43,6 +47,9 @@ bool ConsumeEquals(StringPiece* description); // look up the specific override for any given op. class OpGenOverrideMap { public: + OpGenOverrideMap(); + ~OpGenOverrideMap(); + // `filenames` is a comma-separated list of file names. If an op // is mentioned in more than one file, the last one takes priority. Status LoadFileList(Env* env, const string& filenames); @@ -61,7 +68,7 @@ class OpGenOverrideMap { const OpGenOverride* ApplyOverride(OpDef* op_def) const; private: - std::unordered_map map_; + std::unordered_map> map_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 3892320b7d07e9..067d30d21a8a84 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/kernel_def.pb_text.h" #include "tensorflow/core/framework/log_memory.h" diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index fd5c8fdbb8c562..1b716c5a5a3743 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -24,19 +24,19 @@ limitations under the License. #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/function.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/kernel_def.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/session_state.h" -#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/tracking_allocator.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -65,9 +65,13 @@ class TensorSliceReaderCacheWrapper; } // namespace checkpoint class AsyncOpKernel; +class FunctionCallFrame; +class FunctionLibraryRuntime; class OpKernelConstruction; // declared below class OpKernelContext; // declared below +class OpRegistryInterface; class ResourceMgr; +class ScopedStepContainer; class OpKernel { public: diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index f87b7178449c65..47523358bed408 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc index f8ebd99bf88a49..54ae019f9b4812 100644 --- a/tensorflow/core/framework/partial_tensor_shape_test.cc +++ b/tensorflow/core/framework/partial_tensor_shape_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index ebed957d99df36..b8c771a0a1955b 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_base.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h index 0528841814b822..2b5052e959b799 100644 --- a/tensorflow/core/framework/reader_base.h +++ b/tensorflow/core/framework/reader_base.h @@ -19,12 +19,14 @@ limitations under the License. #include #include #include "tensorflow/core/framework/queue_interface.h" -#include "tensorflow/core/framework/reader_base.pb.h" +#include "tensorflow/core/framework/reader_base.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/reader_interface.h" #include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { +class ReaderBaseState; + // Default implementation of ReaderInterface. class ReaderBase : public ReaderInterface { public: diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 3018e4f655181d..bc3ba914e095be 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 0e1a5a82d3fa4b..c92c0a36bad0e5 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -22,8 +22,9 @@ limitations under the License. #include #include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index df4d8c35915178..07272e2374cbf4 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 62f85d2dac29c2..6947f680021a75 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -79,6 +80,58 @@ InferenceContext::InferenceContext( PostInputInit(std::move(handle_data)); } +// Same as above, but with PartialTensorShape instead of TensorShapeProto +InferenceContext::InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector< + std::unique_ptr>>>& + input_handle_shapes_and_types) + : graph_def_version_(graph_def_version), + node_def_(*CHECK_NOTNULL(node_def)) { + std::vector input_tensors_as_shape_handles; + for (const PartialTensorShape& p : input_tensors_as_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape)); + if (!construction_status_.ok()) { + return; + } + input_tensors_as_shape_handles.push_back(shape); + } + PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles); + if (!construction_status_.ok()) return; + for (const PartialTensorShape& p : input_shapes) { + ShapeHandle shape; + construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape)); + if (!construction_status_.ok()) { + return; + } + inputs_.push_back(shape); + } + std::vector>> handle_data( + input_shapes.size()); + for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) { + const auto& v = input_handle_shapes_and_types[i]; + if (v == nullptr) { + continue; + } + handle_data[i].reset(new std::vector(v->size())); + auto& new_v = *handle_data[i]; + for (int j = 0; j < v->size(); ++j) { + const auto& p = (*v)[j]; + construction_status_.Update( + MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape)); + if (!construction_status_.ok()) { + return; + } + new_v[j].dtype = p.second; + } + } + PostInputInit(std::move(handle_data)); +} + InferenceContext::InferenceContext( int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector& input_shapes, diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 460aefe29e35e6..716cec5c4a52c6 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" @@ -189,6 +189,26 @@ class InferenceContext { std::unique_ptr>>>& input_handle_shapes_and_types); + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape + // function makes a call to MakeShapeFromShapeTensor; in particular, when + // the input_tensors[i] is nullptr but the shape represented by it is + // partially known from analysis of the graph. + // can have fewer elements than . Values of + // do not need to outlive the context. + // + // REQUIRES: is not NULL, and must outlive the + // InferenceContext. + InferenceContext( + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector>>>& + input_handle_shapes_and_types); + ~InferenceContext(); // Runs the shape inference function 'fn' with 'this' as the diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 66cfbf874751d9..57d8dc9353cd27 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,19 +37,11 @@ OpDef MakeOpDefWithLists() { return op_reg_data.op_def; } -TensorShapeProto S(std::initializer_list dims) { - PartialTensorShape shape(dims); - TensorShapeProto ret; - shape.AsProto(&ret); - return ret; +PartialTensorShape S(std::initializer_list dims) { + return PartialTensorShape(dims); } -TensorShapeProto Unknown() { - PartialTensorShape shape; - TensorShapeProto ret; - shape.AsProto(&ret); - return ret; -} +PartialTensorShape Unknown() { return PartialTensorShape(); } } // namespace @@ -1537,7 +1530,7 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { {}); auto make_shape = [&c](std::initializer_list dim_sizes) { ShapeHandle s; - TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s)); + TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); return s; }; auto get_shapes_and_types_from_context = [&](int idx) { @@ -1648,7 +1641,7 @@ void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { {}); auto make_shape = [&c](std::initializer_list dim_sizes) { ShapeHandle s; - TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s)); + TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s)); return s; }; auto get_shapes_and_types_from_context = [&](int idx) { diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 03c39e6dc1c7ff..fbfd24538bc7a5 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -16,7 +16,6 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ #include -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/status.h" @@ -28,7 +27,6 @@ limitations under the License. namespace tensorflow { -class NodeDef; class Tensor; struct ShapeInferenceTestOp { diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index cc9f11ef6afe38..980703f190b159 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/resource_handle.pb.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/coding.h" diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index a164fe61b5f001..a8f9d215114da9 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/allocation_description.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/tensor_description.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" @@ -35,8 +35,13 @@ limitations under the License. namespace tensorflow { -class TensorBuffer; // Forward declaration. +// Forward declarations. In particular, we forward declare protos so that their +// symbols can be removed from .so exports. +class AllocationDescription; +class TensorBuffer; class TensorCApi; +class TensorDescription; +class TensorProto; /// @ingroup core /// Represents an n-dimensional array of values. diff --git a/tensorflow/core/framework/tensor_reference.h b/tensorflow/core/framework/tensor_reference.h index 186820785dd164..37e588d4f10898 100644 --- a/tensorflow/core/framework/tensor_reference.h +++ b/tensorflow/core/framework/tensor_reference.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_ #define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_ -#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 1284214952cd80..14d9cea20ea955 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index b2016074614297..e56c3d7b930681 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -35,6 +35,7 @@ namespace tensorflow { template class TensorShapeIter; class TensorShape; +class TensorShapeProto; class PartialTensorShape; // END_SKIP_DOXYGEN diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index d6fe9a1511b3de..51a7b14fed2367 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 369f64e9e2d467..6c9c803af6c734 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/core/framework/versions.cc b/tensorflow/core/framework/versions.cc index 58937556d9ba56..3ff0723ceec257 100644 --- a/tensorflow/core/framework/versions.cc +++ b/tensorflow/core/framework/versions.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/versions.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/core/framework/versions.h b/tensorflow/core/framework/versions.h index 01429b26a633d6..e8f07f9016a03e 100644 --- a/tensorflow/core/framework/versions.h +++ b/tensorflow/core/framework/versions.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_ #define TENSORFLOW_FRAMEWORK_VERSIONS_H_ -#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class VersionDef; + // Check whether data with the given versions is compatible with the given // consumer and min producer. upper_name and lower_name are used to form // error messages upon failure. Example usage: diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index f798af85e15e36..3ed32068ae19b7 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/core/graph/costmodel.h" #include +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index e06e47926462e7..9469e3f98faf44 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -255,9 +257,11 @@ Status Node::input_node(int idx, const Node** const_n) const { // Graph Graph::Graph(const OpRegistryInterface* ops) - : ops_(ops, FunctionDefLibrary()), arena_(8 << 10 /* 8kB */) { - versions_.set_producer(TF_GRAPH_DEF_VERSION); - versions_.set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); + : ops_(ops, FunctionDefLibrary()), + versions_(new VersionDef), + arena_(8 << 10 /* 8kB */) { + versions_->set_producer(TF_GRAPH_DEF_VERSION); + versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); // Initialize the name interning table for assigned_device_name. device_names_.push_back(""); @@ -301,6 +305,9 @@ Graph::~Graph() { // destroy them. } +const VersionDef& Graph::versions() const { return *versions_; } +void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; } + Node* Graph::AddNode(const NodeDef& node_def, Status* status) { const OpDef* op_def; status->Update(ops_.LookUpOpDef(node_def.op(), &op_def)); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index e19e0b727d6df8..565f430455c4f2 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -41,10 +41,10 @@ limitations under the License. #include #include #include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/versions.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/graph/edgeset.h" #include "tensorflow/core/lib/core/arena.h" #include "tensorflow/core/lib/core/refcount.h" @@ -59,7 +59,9 @@ namespace tensorflow { class Edge; class EdgeSetTest; class Graph; +class GraphDef; class Node; +class VersionDef; class NeighborIter; // Declared below class NodeIter; // Declared below @@ -370,8 +372,8 @@ class Graph { static const int kControlSlot; // The GraphDef version range of this graph (see graph.proto). - const VersionDef& versions() const { return versions_; } - void set_versions(const VersionDef& versions) { versions_ = versions; } + const VersionDef& versions() const; + void set_versions(const VersionDef& versions); // Adds a new node to this graph, and returns it. Infers the Op and // input/output types for the node. *this owns the returned instance. @@ -514,7 +516,7 @@ class Graph { FunctionLibraryDefinition ops_; // GraphDef versions - VersionDef versions_; + const std::unique_ptr versions_; // Allocator which will give us good locality. core::Arena arena_; diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index a2929d0210b4b3..582c8727c6dc69 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -27,8 +27,10 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index b8d1879fa0cc62..f222b9b5f1dd96 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/kernels/ops_util.h" diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc index 867eca0c41f46b..e85de71ef79988 100644 --- a/tensorflow/core/graph/graph_def_builder_test.cc +++ b/tensorflow/core/graph/graph_def_builder_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index f8c6895dfa164f..f452299a8347c6 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/costmodel.h" diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index ca49ea0ac49a80..cb9e4b7973a6cc 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 500ac129e8b5d0..138952dcb33e7b 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc index 84e796c96016ef..b73b084793e9d9 100644 --- a/tensorflow/core/grappler/clusters/single_machine_test.cc +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/utils.h" diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 95329f3f14759c..e717f6e761f22a 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/grappler/costs/virtual_scheduler.h" diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc index 6f25e7b0d4d7f5..ec21f5f4260d86 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 96d37c9a97e262..37623f8997201f 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -60,6 +60,7 @@ cc_test( "//tensorflow/cc:scope", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -196,6 +197,7 @@ cc_test( ":virtual_placer", ":virtual_scheduler", "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -232,6 +234,7 @@ cc_library( ":cost_estimator", ":op_performance_data_cc", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/clusters:utils", "//third_party/eigen3", ], @@ -265,6 +268,7 @@ cc_library( ":virtual_scheduler", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", ], ) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index e530f66415be2f..569efaf96d68ab 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/types.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 7361968cea880d..cd673a64a052e0 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -179,7 +179,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( Status GraphProperties::InferStatically() { Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); shape_refiner.set_require_shape_inference_fns(false); ImportGraphDefOptions options; Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 29b6adef5e75cb..109f973956e8e8 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/grappler_item.h" diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index ba6686e7df9f18..7d3298ded4f7cd 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/utils.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 2fbd54d7591879..4135d9b3313d31 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 3be4f917bf6b66..11650cb6a2e74d 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -19,6 +19,9 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/op_types.h" diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index dc54b8d0d1671e..9743db33db2e97 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/grappler/costs/virtual_scheduler.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 607f2c286b3044..86775306ff4933 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -28,8 +28,10 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variable.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/inputs/utils.h" #include "tensorflow/core/grappler/op_types.h" diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc index d46b849ad416c2..3f58a2abeac9f7 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 0ed73bfae933ac..63159dc3aac53c 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 28d663e2f7fdb3..ddf8a5acd28e33 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/devices.h" diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 16a638a7d31241..462cfb928f6f38 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0b1c1085f28fca..f0dc3312c27297 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3337,7 +3337,7 @@ tf_kernel_library( tf_kernel_library( name = "serialize_sparse_op", prefix = "serialize_sparse_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + ["//tensorflow/core:protos_all_cc"], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc index 76f8c225432dd7..f5a74048af4e8b 100644 --- a/tensorflow/core/kernels/decode_image_op.cc +++ b/tensorflow/core/kernels/decode_image_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/gif/gif_io.h" #include "tensorflow/core/lib/jpeg/jpeg_mem.h" #include "tensorflow/core/lib/png/png_io.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index 73f30cdae37ffb..9526f1119b73e0 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index 04697c3b15f477..8538ebc5c66aa7 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index a82ae61ad9dde9..65677cc3de3571 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -87,7 +87,7 @@ Status GraphTransferer::LoadGraphFromProto( const std::vector& output_node_names, const bool shape_inference_for_unknown_shape) { Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner); if (!status.ok()) { return status; diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index cb9091e29f8ebe..130109b813fce5 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -26,6 +26,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp #include +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h" diff --git a/tensorflow/core/kernels/identity_reader_op.cc b/tensorflow/core/kernels/identity_reader_op.cc index ddd012b910810a..6e5714b313887e 100644 --- a/tensorflow/core/kernels/identity_reader_op.cc +++ b/tensorflow/core/kernels/identity_reader_op.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_base.pb.h" #include "tensorflow/core/framework/reader_op_kernel.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index 0cb63abf2a276a..5eba4288acccfb 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/lib/strings/str_util.h" + namespace tensorflow { TensorShape ReductionHelper::out_reshape() const { diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 890ff15c4875dc..ed088bfba7eddc 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/public/session.h" @@ -580,8 +582,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( } else { ImportGraphDefOptions opts; Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), - graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_RETURN_IF_ERROR( ImportGraphDef(opts, *graph_def, &graph, &shape_refiner)); TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors, @@ -724,7 +725,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( const std::unordered_set& node_names, const GraphDef& graph_def, std::vector* cluster_infos) { Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); std::unordered_set remaining_nodes = node_names; @@ -829,7 +830,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster)); Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); for (Node* node : graph.nodes()) { @@ -883,7 +884,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( const std::vector& border_outputs, const GraphDef& graph_def, ClusterInfo* cluster) { Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); std::unordered_set visited; @@ -955,7 +956,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def)); Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_RETURN_IF_ERROR( ImportGraphDef({}, input_graph_def, &graph, &shape_refiner)); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc index 059070305f44f5..e7a2da4c532fc1 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc @@ -308,7 +308,7 @@ TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) { NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); ImportGraphDefOptions opts; Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); Status status = ImportGraphDef(opts, def, &graph, &shape_refiner); ASSERT_TRUE(RemoteFusedGraphExecuteUtils::PropagateShapeInference( def, inputs, &graph, &shape_refiner) @@ -427,7 +427,7 @@ TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) { NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); TF_ASSERT_OK(ImportGraphDef({}, def, &graph, &shape_refiner)); Node* node; diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 49fcf878d5de2d..1428546d52ad96 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_ops.h" #include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 67234e2a401cfe..2c7ad5bab08c40 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index c7bf250fad79de..07d935d55fe061 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/util/work_sharder.h" diff --git a/tensorflow/core/kernels/unique_op_test.cc b/tensorflow/core/kernels/unique_op_test.cc index 0dc9066273f282..176280b7a10119 100644 --- a/tensorflow/core/kernels/unique_op_test.cc +++ b/tensorflow/core/kernels/unique_op_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index 8f42bb28324ecc..d7ef4bc4c49b72 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_base.pb.h" #include "tensorflow/core/framework/reader_op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 557443f2299002..1b4d5436c063ab 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/util/mirror_pad_mode.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/strided_slice_op.h" @@ -172,11 +173,11 @@ REGISTER_OP("ParallelConcat") .Attr("shape: shape") .SetShapeFn([](InferenceContext* c) { // Validate that the shape attr is correct. - TensorShapeProto passed_shape_proto; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &passed_shape_proto)); + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle passed_shape; TF_RETURN_IF_ERROR( - c->MakeShapeFromShapeProto(passed_shape_proto, &passed_shape)); + c->MakeShapeFromPartialTensorShape(shape, &passed_shape)); if (!c->FullyDefined(passed_shape)) { return errors::InvalidArgument("shape attr must be fully defined."); } @@ -637,11 +638,9 @@ REGISTER_OP("ImmutableConst") .SetShapeFn([](InferenceContext* c) { TensorShape shape_from_attr; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_from_attr)); - TensorShapeProto shape_proto; - shape_from_attr.AsProto(&shape_proto); ShapeHandle output_shape; TF_RETURN_IF_ERROR( - c->MakeShapeFromShapeProto(shape_proto, &output_shape)); + c->MakeShapeFromPartialTensorShape(shape_from_attr, &output_shape)); c->set_output(0, output_shape); return Status::OK(); }) @@ -1306,11 +1305,11 @@ REGISTER_OP("_ParallelConcatStart") .Attr("dtype: type") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { - TensorShapeProto shape_proto; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto)); + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle output_shape; TF_RETURN_IF_ERROR( - c->MakeShapeFromShapeProto(shape_proto, &output_shape)); + c->MakeShapeFromPartialTensorShape(shape, &output_shape)); c->set_output(0, output_shape); return Status::OK(); }) @@ -2266,11 +2265,10 @@ REGISTER_OP("StridedSlice") return Status::OK(); } - TensorShapeProto input_shape_proto; + PartialTensorShape input_shape({}); for (int i = 0; i < c->Rank(input); ++i) { auto dim = c->Dim(input, i); - input_shape_proto.add_dim()->set_size(c->ValueKnown(dim) ? c->Value(dim) - : -1); + input_shape.AddDim(c->ValueKnown(dim) ? c->Value(dim) : -1); } int32 begin_mask, end_mask, ellipsis_mask, new_axis_mask, @@ -2288,7 +2286,7 @@ REGISTER_OP("StridedSlice") bool is_identity, is_simple_slice, slice_dim0; gtl::InlinedVector begin, end, strides; TF_RETURN_IF_ERROR(ValidateStridedSliceOp( - begin_value, end_value, *strides_value, input_shape_proto, begin_mask, + begin_value, end_value, *strides_value, input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape, &is_identity, &is_simple_slice, &slice_dim0, &begin, &end, &strides)); @@ -2866,10 +2864,8 @@ REGISTER_OP("Placeholder") return shape_inference::UnknownShape(c); } - TensorShapeProto shape_proto; - shape.AsProto(&shape_proto); ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); c->set_output(0, out); return Status::OK(); }) @@ -2894,10 +2890,10 @@ REGISTER_OP("PlaceholderV2") .Attr("dtype: type") .Attr("shape: shape") .SetShapeFn([](InferenceContext* c) { - TensorShapeProto shape; + PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle output; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape, &output)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output)); c->set_output(0, output); return Status::OK(); }) @@ -2925,10 +2921,8 @@ REGISTER_OP("PlaceholderWithDefault") ShapeHandle input = c->input(0); PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - TensorShapeProto shape_proto; - shape.AsProto(&shape_proto); ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); // We merge for compatibility checking, but return the output, // since output_shape may be less precise than input_shape. @@ -3899,9 +3893,8 @@ REGISTER_OP("SpaceToDepth") TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size, &output_depth)); - c->set_output(0, - c->MakeShape({c->Dim(input, 0), output_height, output_width, - output_depth})); + c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height, + output_width, output_depth})); return Status::OK(); }) .Doc(R"doc( @@ -4005,9 +3998,8 @@ REGISTER_OP("DepthToSpace") TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size, true /* evenly_divisible */, &output_depth)); - c->set_output(0, - c->MakeShape({c->Dim(input, 0), output_height, output_width, - output_depth})); + c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height, + output_width, output_depth})); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index b1d334e4545f7f..dc5d46e6fae3f0 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index c10e667f564ee6..28f9969de56c93 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference_testutil.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 22f87f5fdfee8f..2e605fdffcfbb2 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -85,11 +85,10 @@ REGISTER_OP("ParseExample") } // Output dense_shapes. - TensorShapeProto shape_proto; for (int i = 0; i < attrs.num_dense; ++i) { - attrs.dense_shapes[i].AsProto(&shape_proto); ShapeHandle dense; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &dense)); + TF_RETURN_IF_ERROR( + c->MakeShapeFromPartialTensorShape(attrs.dense_shapes[i], &dense)); TF_RETURN_IF_ERROR(c->Concatenate(input, dense, &dense)); c->set_output(output_idx++, dense); } @@ -196,11 +195,10 @@ REGISTER_OP("ParseSingleSequenceExample") } // Output context_dense_shapes. - TensorShapeProto shape_proto; for (int i = 0; i < attrs.num_context_dense; ++i) { - attrs.context_dense_shapes[i].AsProto(&shape_proto); ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + attrs.context_dense_shapes[i], &s)); c->set_output(output_idx++, s); } @@ -218,9 +216,9 @@ REGISTER_OP("ParseSingleSequenceExample") // Output feature_list_dense_shapes. for (int i = 0; i < attrs.num_feature_list_dense; ++i) { - attrs.feature_list_dense_shapes[i].AsProto(&shape_proto); ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + attrs.feature_list_dense_shapes[i], &s)); TF_RETURN_IF_ERROR( c->Concatenate(c->Vector(InferenceContext::kUnknownDim), s, &s)); c->set_output(output_idx++, s); diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc index 5c29e21d00b059..c6e521e33e9801 100644 --- a/tensorflow/core/ops/parsing_ops_test.cc +++ b/tensorflow/core/ops/parsing_ops_test.cc @@ -59,25 +59,24 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) { INFER_ERROR("Shape of a default must be", op, "?;[2];?"); } -static std::vector MakeDenseShapes(int size, - bool add_extra_shape, - int unknown_outer_dims) { - std::vector shapes(size); +static std::vector MakeDenseShapes(int size, + bool add_extra_shape, + int unknown_outer_dims) { + std::vector shapes(size); for (int i = 0; i < size; ++i) { // Make shapes be the sequence [?,1]; [?,1,2], [?,1,2,3]... // where the number of prefixed ? depends on unknown_outer_dims. if (i == 0) { + shapes[i].Clear(); for (int d = 0; d < unknown_outer_dims; ++d) { - shapes[i].add_dim()->set_size(-1); + shapes[i].AddDim(-1); } } else { shapes[i] = shapes[i - 1]; } - shapes[i].add_dim()->set_size(i + 1); - } - if (add_extra_shape) { - shapes.resize(shapes.size() + 1); + shapes[i].AddDim(i + 1); } + if (add_extra_shape) shapes.push_back(PartialTensorShape({})); return shapes; } diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 3b48559b1fc3f9..034946f17adb2c 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -68,10 +68,10 @@ REGISTER_OP("VarHandleOp") c->set_output(0, c->Scalar()); DataType t; TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); - TensorShapeProto p; + PartialTensorShape p; TF_RETURN_IF_ERROR(c->GetAttr("shape", &p)); ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); c->set_output_handle_shapes_and_types(0, std::vector{{s, t}}); diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 35f965b6a90f3e..1c9ae90a26fcf3 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -30,11 +30,11 @@ REGISTER_OP("VariableV2") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { - TensorShapeProto shape_proto; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto)); + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle output_shape; TF_RETURN_IF_ERROR( - c->MakeShapeFromShapeProto(shape_proto, &output_shape)); + c->MakeShapeFromPartialTensorShape(shape, &output_shape)); c->set_output(0, output_shape); return Status::OK(); }) @@ -72,10 +72,8 @@ REGISTER_OP("Variable") return shape_inference::UnknownShape(c); } - TensorShapeProto shape_proto; - shape.AsProto(&shape_proto); ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); c->set_output(0, out); return Status::OK(); }) @@ -103,10 +101,10 @@ REGISTER_OP("TemporaryVariable") .Attr("var_name: string = ''") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { - TensorShapeProto shape_proto; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto)); + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); ShapeHandle output; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &output)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output)); c->set_output(0, output); return Status::OK(); }) diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc index bcc1c924937ec0..6d05dd0b96c3c5 100644 --- a/tensorflow/core/ops/state_ops_test.cc +++ b/tensorflow/core/ops/state_ops_test.cc @@ -63,49 +63,41 @@ TEST(StateOpsTest, ScatterUpdate_ShapeFn) { TEST(StateOpsTest, TemporaryVariable_ShapeFn) { ShapeInferenceTestOp op("TemporaryVariable"); TensorShape shape({1, 2, 3}); - TensorShapeProto shape_proto; - shape.AsProto(&shape_proto); TF_ASSERT_OK(NodeDefBuilder("test", "TemporaryVariable") - .Attr("shape", shape_proto) + .Attr("shape", shape) .Finalize(&op.node_def)); INFER_OK(op, "", "[1,2,3]"); } TEST(StateOpsTest, Variable_ShapeFn) { ShapeInferenceTestOp op("Variable"); - TensorShapeProto shape_proto; // Unknown rank. - PartialTensorShape().AsProto(&shape_proto); TF_ASSERT_OK(NodeDefBuilder("test", "Variable") - .Attr("shape", shape_proto) + .Attr("shape", PartialTensorShape()) .Finalize(&op.node_def)); INFER_OK(op, "", "?"); // For historical reasons an empty TensorShapeProto can be either an unknown // rank or a scalar, so the shape function conservatively says "unknown" - shape_proto.Clear(); TF_ASSERT_OK(NodeDefBuilder("test", "Variable") - .Attr("shape", shape_proto) + .Attr("shape", TensorShape({})) .Finalize(&op.node_def)); INFER_OK(op, "", "?"); // Specified shape. - TensorShape({1, 2, 3}).AsProto(&shape_proto); TF_ASSERT_OK(NodeDefBuilder("test", "Variable") - .Attr("shape", shape_proto) + .Attr("shape", TensorShape({1, 2, 3})) .Finalize(&op.node_def)); INFER_OK(op, "", "[1,2,3]"); } TEST(StateOpsTest, VariableV2_ShapeFn) { ShapeInferenceTestOp op("VariableV2"); - TensorShapeProto shape_proto; // Unknown rank. - shape_proto.set_unknown_rank(true); TF_ASSERT_OK(NodeDefBuilder("test", "VariableV2") - .Attr("shape", shape_proto) + .Attr("shape", PartialTensorShape()) .Finalize(&op.node_def)); INFER_OK(op, "", "?"); @@ -116,9 +108,8 @@ TEST(StateOpsTest, VariableV2_ShapeFn) { INFER_OK(op, "", "[]"); // Specified shape. - TensorShape({1, 2, 3}).AsProto(&shape_proto); TF_ASSERT_OK(NodeDefBuilder("test", "VariableV2") - .Attr("shape", shape_proto) + .Attr("shape", TensorShape({1, 2, 3})) .Finalize(&op.node_def)); INFER_OK(op, "", "[1,2,3]"); } diff --git a/tensorflow/core/util/equal_graph_def.cc b/tensorflow/core/util/equal_graph_def.cc index 45d6a6662a538c..919a46bfb85044 100644 --- a/tensorflow/core/util/equal_graph_def.cc +++ b/tensorflow/core/util/equal_graph_def.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/core/util/equal_graph_def.h b/tensorflow/core/util/equal_graph_def.h index 1ce6181c2e7e41..14f5bdfda4d0b1 100644 --- a/tensorflow/core/util/equal_graph_def.h +++ b/tensorflow/core/util/equal_graph_def.h @@ -16,13 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ #define TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { +class GraphDef; +class NodeDef; + struct EqualGraphDefOptions { // Should internal attributes (attribute names that start with '_') be // ignored? diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc index 1d01c6b0839bf1..24ce5ebafce36a 100644 --- a/tensorflow/core/util/memmapped_file_system_test.cc +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/util/memmapped_file_system.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h index 6f56d9c25a0c5f..2e6003226c6508 100644 --- a/tensorflow/core/util/padding.h +++ b/tensorflow/core/util/padding.h @@ -21,11 +21,13 @@ limitations under the License. #include -#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +class NodeDef; + // Padding: the padding we apply to the input tensor along the rows and columns // dimensions. This is usually used to make sure that the spatial dimensions do // not shrink when we progress with convolutions. Two types of padding are diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc index fa59f735818c24..8447028e382438 100644 --- a/tensorflow/core/util/stat_summarizer.cc +++ b/tensorflow/core/util/stat_summarizer.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -35,6 +37,14 @@ StatSummarizer::StatSummarizer(const StatSummarizerOptions& options) StatSummarizer::StatSummarizer(const tensorflow::GraphDef& tensorflow_graph) : StatSummarizer(StatSummarizerOptions()) {} +StatSummarizer::~StatSummarizer() {} + +void StatSummarizer::Reset() { + run_total_us_.Reset(); + memory_.Reset(); + details_.clear(); +} + void StatSummarizer::Validate(const Detail* detail, const NodeExecStats& ns) const { if (detail->outputs.size() != ns.output_size()) { diff --git a/tensorflow/core/util/stat_summarizer.h b/tensorflow/core/util/stat_summarizer.h index 6111e276ea69b9..f7b63e86869c27 100644 --- a/tensorflow/core/util/stat_summarizer.h +++ b/tensorflow/core/util/stat_summarizer.h @@ -154,6 +154,8 @@ class StatSummarizer { // GraphDef is not needed by the StatSummarizer. explicit StatSummarizer(const tensorflow::GraphDef& tensorflow_graph); + ~StatSummarizer(); + // Adds another run's StepStats output to the aggregate counts. void ProcessStepStats(const StepStats& step_stats); @@ -181,11 +183,7 @@ class StatSummarizer { SortingMetric sorting_metric, int num_stats) const; - void Reset() { - run_total_us_.Reset(); - memory_.Reset(); - details_.clear(); - } + void Reset(); // Returns number of runs. int num_runs() const { return run_total_us_.count(); } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 301eae2c4de11a..41dc9f8a78210c 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/framework/versions.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc index be636c04c47bbf..d935eba2e5d05d 100644 --- a/tensorflow/core/util/tensor_slice_writer_test.cc +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc index 04c34c5968530a..4a04e5c3c94662 100644 --- a/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc +++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0c04d36aff3204..b761391f91d3ba 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -233,6 +233,7 @@ cc_library( ":numpy_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:script_ops_op_lib", "//third_party/py/numpy:headers", "//util/python:python_headers", diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index 8ebdbafb85c5a5..34f68b4fae2a61 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/python/framework/cpp_shape_inference.pb.h" diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 1374967ca71dc6..4ec7620bce9462 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/python/grappler/model_analyzer.h" #include +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index c48296eccb0fca..a1618d5349cd1b 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "numpy/arrayobject.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 1c4958d83c935e..8123b0a0c6206f 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc index 4eb074998f71e8..c0107014e2cf11 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index 91670f54d49d05..e79e7ba121c93d 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -23,8 +23,10 @@ limitations under the License. // bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ // --in_graph=my_graph.pb +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 6ed549a9589af2..2db0a24267bba4 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -20,10 +20,12 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow {