diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index e037a00b9844ed..15c8633ed4417b 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -268,10 +268,14 @@ bool ParseExample(protobuf::io::CodedInputStream* stream, parsed::Example* example) { DCHECK(stream != nullptr); DCHECK(example != nullptr); - if (stream->ExpectTag(kDelimitedTag(1))) { - if (!ParseFeatures(stream, example)) return false; + // Loop over the input stream which may contain multiple serialized Example + // protos merged together as strings. This behavior is consistent with Proto's + // ParseFromString when string representations are concatenated. + while (!stream->ExpectAtEnd()) { + if (stream->ExpectTag(kDelimitedTag(1))) { + if (!ParseFeatures(stream, example)) return false; + } } - if (!stream->ExpectAtEnd()) return false; return true; } @@ -439,6 +443,8 @@ Status FastParseSerializedExample( size, " but output shape: ", shape.DebugString()); }; + // TODO(b/31499934): Make sure concatented serialized tf.Example protos + // get parsed correctly when they contain dense features and add tests. switch (config.dense[d].dtype) { case DT_INT64: { SmallVector list; diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 9da3f6ad2a2cb8..0590d801c3badc 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -63,6 +63,23 @@ void TestCorrectness(const string& serialized) { // TestCorrectness(example); // } +// Test that concatenating two Example protos in their serialized string +// representations gets parsed identically by TestFastParse(..) and the regular +// Example.ParseFromString(..). +TEST(FastParse, SingleInt64WithContext) { + Example example; + (*example.mutable_features()->mutable_feature())["age"] + .mutable_int64_list() + ->add_value(13); + + Example context; + (*context.mutable_features()->mutable_feature())["zipcode"] + .mutable_int64_list() + ->add_value(94043); + + TestCorrectness(strings::StrCat(Serialize(example), Serialize(context))); +} + TEST(FastParse, NonPacked) { TestCorrectness( "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d");