diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index dc9d9b0f8cb8c..efd6557e36237 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -115,8 +115,13 @@ public FieldType getBaseType() { } /** Create a {@link Value} specifying which field to set and the value to set. */ - public Value createValue(String caseType, T value) { - return createValue(getCaseEnumType().valueOf(caseType), value); + public Value createValue(String caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); + } + + /** Create a {@link Value} selecting the field to set by its case enum number. */ + public Value createValue(int caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); } /** Create a {@link Value} specifying which field to set and the value to set. 
*/ diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java index e9a5d48ed35b6..2b47496f449be 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java @@ -163,15 +163,15 @@ private void indexDescriptorByName() { .values() .forEach( fileDescriptor -> { - fileDescriptor - .getMessageTypes() - .forEach( - descriptor -> { - descriptorMap.put(descriptor.getFullName(), descriptor); - }); + fileDescriptor.getMessageTypes().forEach(this::indexDescriptor); }); } + private void indexDescriptor(Descriptors.Descriptor descriptor) { + descriptorMap.put(descriptor.getFullName(), descriptor); + descriptor.getNestedTypes().forEach(this::indexDescriptor); // recurse: nested types are indexed by full name too + } + private void indexOptionsByNumber(Collection fileDescriptors) { fieldOptionMap = new HashMap<>(); fileOptionMap = new HashMap<>(); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java new file mode 100644 index 0000000000000..6d39b68e1aa59 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -0,0 +1,797 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapKeyMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Duration; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.Timestamp; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; + 
+@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoDynamicMessageSchema implements Serializable { + static class Context implements Serializable { // schema-only context: cannot create message builders + private final Schema schema; + + Context(Schema schema) { + this.schema = schema; + } + + public Schema getSchema() { + return schema; + } + + public DynamicMessage.Builder invokeNewBuilder() { + throw new IllegalStateException("Should not be calling invokeNewBuilder"); + } + + public Context getSubContext(Schema.Field field) { + return new Context(field.getType().getRowSchema()); + } + } + + static class DescriptorContext extends Context { + private final String messageName; + private final ProtoDomain domain; + private transient Descriptors.Descriptor descriptor; // looked up lazily from domain in invokeNewBuilder() + + DescriptorContext(String messageName, ProtoDomain domain) { + super(ProtoSchemaTranslator.getSchema(domain.getDescriptor(messageName))); + this.messageName = messageName; + this.domain = domain; + } + + @Override + public DynamicMessage.Builder invokeNewBuilder() { + if (descriptor == null) { + descriptor = domain.getDescriptor(messageName); + } + return DynamicMessage.newBuilder(descriptor); + } + + @Override + public Context getSubContext(Schema.Field field) { + String messageName = getMessageName(field.getType()); + return new DescriptorContext(messageName, domain); + } + } + + public static final long serialVersionUID = 1L; + private final Context context; + private transient SchemaCoder schemaCoder; // rebuilt by init() + private transient List getters; // rebuilt by init() + + private ProtoDynamicMessageSchema(String messageName, ProtoDomain domain) { + this.context = new DescriptorContext(messageName, domain); + init(); + } + + private ProtoDynamicMessageSchema(Context context) { + this.context = context; + init(); + } + + private Object readResolve() { // serialization hook: re-creates the transient fields + init(); + return this; + } + + private void init() { + getters = createFieldLayer(context.getSchema()); + schemaCoder = + SchemaCoder.of( + context.getSchema(), + TypeDescriptor.of(Message.class), + new MessageToRowFunction(), 
+ new RowToMessageFunction()); + } + + ProtoFieldConvert createFieldLayer(Schema.Field field) { + Schema.FieldType fieldType = field.getType(); + String messageName = getMessageName(fieldType); + if (messageName != null && messageName.length() > 0) { + Schema.Field valueField = + Schema.Field.of("value", withFieldNumber(Schema.FieldType.BOOLEAN, 1)); // placeholder type: the converters read only the field number (1) + switch (messageName) { + case "google.protobuf.StringValue": + case "google.protobuf.DoubleValue": + case "google.protobuf.FloatValue": + case "google.protobuf.BoolValue": + case "google.protobuf.Int64Value": + case "google.protobuf.Int32Value": + case "google.protobuf.UInt64Value": + case "google.protobuf.UInt32Value": + return new WrapperConvert(field, new PrimitiveConvert(valueField)); + case "google.protobuf.BytesValue": + return new WrapperConvert(field, new BytesConvert(valueField)); + case "google.protobuf.Timestamp": + case "google.protobuf.Duration": + // handled by the LOGICAL_TYPE case of the switch below + break; + } + } + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + case INT64: + case FLOAT: + case DOUBLE: + case STRING: + case BOOLEAN: + return new PrimitiveConvert(field); + case BYTES: + return new BytesConvert(field); + case ARRAY: + case ITERABLE: + return new ArrayConvert(this, field); + case MAP: + return new MapConvert(this, field); + case LOGICAL_TYPE: + String identifier = field.getType().getLogicalType().getIdentifier(); + switch (identifier) { + case ProtoSchemaLogicalTypes.Fixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.Fixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt64.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt64.IDENTIFIER: + return new LogicalTypeConvert(field, fieldType.getLogicalType()); + case ProtoSchemaLogicalTypes.TimestampNanos.IDENTIFIER: + return new 
TimestampConvert(field, fieldType.getLogicalType()); + case ProtoSchemaLogicalTypes.DurationNanos.IDENTIFIER: + return new DurationConvert(field, fieldType.getLogicalType()); + case EnumerationType.IDENTIFIER: + return new EnumConvert(field, fieldType.getLogicalType()); + case OneOfType.IDENTIFIER: + return new OneOfConvert(this, field, fieldType.getLogicalType()); + default: + throw new IllegalStateException("Unexpected logical type : " + identifier); + } + case ROW: + return new MessageConvert(this, field); + default: + throw new IllegalStateException("Unexpected value: " + fieldType); + } + } + + private List createFieldLayer(Schema schema) { + // One converter per schema field, in schema field order. + List fieldOverlays = new ArrayList<>(); + for (Schema.Field field : schema.getFields()) { + fieldOverlays.add(createFieldLayer(field)); + } + return fieldOverlays; + } + + public Schema getSchema() { + return this.schemaCoder.getSchema(); + } + + public SchemaCoder getSchemaCoder() { + return schemaCoder; + } + + public SerializableFunction getToRowFunction() { + return schemaCoder.getToRowFunction(); + } + + public SerializableFunction getFromRowFunction() { + return schemaCoder.getFromRowFunction(); + } + + public static ProtoDynamicMessageSchema forDescriptor(ProtoDomain domain, String messageName) { + return new ProtoDynamicMessageSchema(messageName, domain); + } + + public static ProtoDynamicMessageSchema forDescriptor( + ProtoDomain domain, Descriptors.Descriptor descriptor) { + return new ProtoDynamicMessageSchema(descriptor.getFullName(), domain); + } + + static ProtoDynamicMessageSchema forContext(Context context, Schema.Field field) { + return new ProtoDynamicMessageSchema(context.getSubContext(field)); + } + + static ProtoDynamicMessageSchema forSchema(Schema schema) { + return new ProtoDynamicMessageSchema(new Context(schema)); // plain Context: its invokeNewBuilder() throws + } + + private class MessageToRowFunction implements SerializableFunction { + + private MessageToRowFunction() {} + + @Override + public Row apply(Message input) { + Schema schema = 
schemaCoder.getSchema(); + Row.Builder builder = Row.withSchema(schema); + for (ProtoFieldConvert getter : getters) { + builder.addValue(getter.getFromProtoMessage(input)); + } + return builder.build(); + } + } + + private class RowToMessageFunction implements SerializableFunction { + + private RowToMessageFunction() {} + + @Override + public Message apply(Row input) { + DynamicMessage.Builder builder = context.invokeNewBuilder(); + Iterator values = input.getValues().iterator(); + Iterator overlayIterator = getters.iterator(); + // Row values and converters are consumed in lockstep, by position. + for (int i = 0; i < input.getValues().size(); i++) { + ProtoFieldConvert getter = overlayIterator.next(); + Object value = values.next(); + getter.setOnProtoMessage(builder, value); + } + return builder.build(); + } + } + + abstract static class ProtoFieldConvert { + // Proto field number read from the field type's metadata; -1 when it cannot be parsed. + private int number; + + FieldDescriptor getFieldDescriptor(Message message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + FieldDescriptor getFieldDescriptor(Message.Builder message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + ProtoFieldConvert(Schema.Field field) { + try { + this.number = getFieldNumber(field.getType()); + } catch (NumberFormatException e) { + this.number = -1; + } + } + + abstract Object getFromProtoMessage(Message message); + + abstract ValueT convertFromProtoValue(Object object); + + abstract void setOnProtoMessage(Message.Builder object, InT value); + + abstract Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value); + } + + static class PrimitiveConvert extends ProtoFieldConvert { + PrimitiveConvert(Schema.Field field) { + super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder 
message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Converter for Bytes. Protobuf Bytes are natively represented as ByteStrings that require + * special handling for byte[] of size 0. + */ + static class BytesConvert extends PrimitiveConvert { + BytesConvert(Schema.Field field) { + super(field); + } + + @Override + Object convertFromProtoValue(Object object) { + // The row side uses byte[], so unwrap the ByteString. + return ((ByteString) object).toByteArray(); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null && ((byte[]) value).length > 0) { + // Null or empty arrays are skipped, leaving the proto field unset. + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + if (value != null) { + return ByteString.copyFrom((byte[]) value); + } + return null; + } + } + + static class WrapperConvert extends ProtoFieldConvert { + private ProtoFieldConvert valueConvert; + + WrapperConvert(Schema.Field field, ProtoFieldConvert valueConvert) { + super(field); + this.valueConvert = valueConvert; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + Message wrapper = (Message) message.getField(getFieldDescriptor(message)); + return valueConvert.getFromProtoMessage(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(getFieldDescriptor(message).getMessageType()); + valueConvert.setOnProtoMessage(builder, value); + 
message.setField(getFieldDescriptor(message), builder.build()); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class TimestampConvert extends ProtoFieldConvert { + ProtoSchemaLogicalTypes.TimestampNanos logicalType; + + TimestampConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (ProtoSchemaLogicalTypes.TimestampNanos) logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + Message timestamp = (Message) object; + try { + return Timestamp.parseFrom(timestamp.toByteArray()); // round-trip through bytes to get a concrete Timestamp + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Unable to parse timestamp", e); + } + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + message.setField(getFieldDescriptor(message), logicalType.toInputType((Row) value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class DurationConvert extends ProtoFieldConvert { + ProtoSchemaLogicalTypes.DurationNanos logicalType; + + DurationConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (ProtoSchemaLogicalTypes.DurationNanos) logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Duration 
convertFromProtoValue(Object object) { + Message duration = (Message) object; + try { + return Duration.parseFrom(duration.toByteArray()); // round-trip through bytes to get a concrete Duration + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Unable to parse duration", e); + } + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + message.setField(getFieldDescriptor(message), logicalType.toInputType((Row) value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class MessageConvert extends ProtoFieldConvert { + private final SerializableFunction fromRowFunction; + private final SerializableFunction toRowFunction; + + MessageConvert(ProtoDynamicMessageSchema rootProtoSchema, Schema.Field field) { + super(field); + ProtoDynamicMessageSchema protoSchema = + ProtoDynamicMessageSchema.forContext(rootProtoSchema.context, field); + SchemaCoder schemaCoder = protoSchema.getSchemaCoder(); + toRowFunction = schemaCoder.getToRowFunction(); + fromRowFunction = schemaCoder.getFromRowFunction(); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return toRowFunction.apply(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return fromRowFunction.apply(value); + } + } + + /** + * Proto has a well defined way of storing maps, by having a Message with two fields, named "key" + * and 
"value" in a repeated field. This overlay translates between Row.map and the Protobuf + * map. + */ + static class MapConvert extends ProtoFieldConvert { + private ProtoFieldConvert key; + private ProtoFieldConvert value; + + MapConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType fieldType = field.getType(); + key = + protoSchema.createFieldLayer( + Schema.Field.of( + "KEY", + withMessageName(fieldType.getMapKeyType(), getMapKeyMessageName(fieldType)))); + value = + protoSchema.createFieldLayer( + Schema.Field.of( + "VALUE", + withMessageName(fieldType.getMapValueType(), getMapValueMessageName(fieldType)))); + } + + @Override + Map getFromProtoMessage(Message message) { + List list = (List) message.getField(getFieldDescriptor(message)); + if (list.isEmpty()) { + return null; + } + Map rowMap = new HashMap<>(); + list.forEach( + entryMessage -> { + Descriptors.Descriptor entryDescriptor = entryMessage.getDescriptorForType(); + FieldDescriptor keyFieldDescriptor = entryDescriptor.findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = entryDescriptor.findFieldByName("value"); + rowMap.put( + key.convertFromProtoValue(entryMessage.getField(keyFieldDescriptor)), + this.value.convertFromProtoValue(entryMessage.getField(valueFieldDescriptor))); + }); + return rowMap; + } + + @Override + Map convertFromProtoValue(Object object) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Map map) { + if (map != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List messageMap = new ArrayList<>(); + map.forEach( + (k, v) -> { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(fieldDescriptor.getMessageType()); + FieldDescriptor keyFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("key"); + builder.setField( + keyFieldDescriptor, this.key.convertToProtoValue(keyFieldDescriptor, k)); + FieldDescriptor 
valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + builder.setField( + valueFieldDescriptor, value.convertToProtoValue(valueFieldDescriptor, v)); + messageMap.add(builder.build()); + }); + message.setField(fieldDescriptor, messageMap); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class ArrayConvert extends ProtoFieldConvert { + private ProtoFieldConvert element; // converter applied to every element of the repeated field + + ArrayConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType collectionElementType = field.getType().getCollectionElementType(); + this.element = + protoSchema.createFieldLayer( + Schema.Field.of( + "ELEMENT", + withMessageName(collectionElementType, getMessageName(field.getType())))); + } + + @Override + List getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + List convertFromProtoValue(Object value) { + List list = (List) value; + List arrayList = new ArrayList<>(); + list.forEach( + entry -> { + arrayList.add(element.convertFromProtoValue(entry)); + }); + return arrayList; + } + + @Override + void setOnProtoMessage(Message.Builder message, List list) { + if (list != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List targetList = new ArrayList<>(); + list.forEach( + (e) -> { + targetList.add(element.convertToProtoValue(fieldDescriptor, e)); + }); + message.setField(fieldDescriptor, targetList); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** Enum overlay converts between an EnumerationType value (by number) and a ProtoBuf Enum. 
*/ + static class EnumConvert extends ProtoFieldConvert { + EnumerationType logicalType; + + EnumConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (EnumerationType) logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + EnumerationType.Value convertFromProtoValue(Object in) { + return logicalType.valueOf(((Descriptors.EnumValueDescriptor) in).getNumber()); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); + return enumType.findValueByNumber((Integer) value); // value is cast to the enum number + } + } + + static class OneOfConvert extends ProtoFieldConvert { + OneOfType logicalType; + Map oneOfConvert = new HashMap<>(); // keyed by the proto field number of each oneof case + + OneOfConvert( + ProtoDynamicMessageSchema protoSchema, Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (OneOfType) logicalType; + for (Schema.Field oneOfField : this.logicalType.getOneOfSchema().getFields()) { + int fieldNumber = getFieldNumber(oneOfField.getType()); + oneOfConvert.put( + fieldNumber, new NullableConvert(oneOfField, protoSchema.createFieldLayer(oneOfField))); + } + } + + @Override + Object getFromProtoMessage(Message message) { + for (Map.Entry entry : this.oneOfConvert.entrySet()) { + Object value = entry.getValue().getFromProtoMessage(message); + if (value != null) { // NullableConvert returns null for cases that are not set + return logicalType.createValue(entry.getKey(), value); + } + } + return null; + } + + @Override + OneOfType.Value convertFromProtoValue(Object in) { + throw new IllegalStateException("Value 
conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Row value) { + OneOfType.Value oneOf = logicalType.toInputType(value); + int caseIndex = oneOf.getCaseType().getValue(); // looked up in oneOfConvert, which is keyed by proto field number + oneOfConvert.get(caseIndex).setOnProtoMessage(message, oneOf.getValue()); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + /** + * This overlay handles nullable fields. If a primitive field needs to be nullable, this overlay + * is wrapped around the original overlay and maps an unset proto field to null. + */ + static class NullableConvert extends ProtoFieldConvert { + + private ProtoFieldConvert fieldOverlay; + + NullableConvert(Schema.Field field, ProtoFieldConvert fieldOverlay) { + super(field); + this.fieldOverlay = fieldOverlay; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + return fieldOverlay.getFromProtoMessage(message); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + fieldOverlay.setOnProtoMessage(message, value); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + static class LogicalTypeConvert extends ProtoFieldConvert { + + private Schema.LogicalType logicalType; + + LogicalTypeConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return 
convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return logicalType.toBaseType(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); // NOTE(review): no toInputType() here, unlike toBaseType() on read; confirm intended + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaLogicalTypes.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaLogicalTypes.java index 0d4a5a6560ac1..7a3b1f06b5067 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaLogicalTypes.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaLogicalTypes.java @@ -85,7 +85,7 @@ public static Timestamp toTimestamp(Row row) { /** A duration represented in nanoseconds. */ public static class DurationNanos extends NanosType { - public static final String IDENTIFIER = "ProtoTimestamp"; + public static final String IDENTIFIER = "ProtoDuration"; public DurationNanos() { super(IDENTIFIER); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java index d27f48020efc3..373f28fc7af3a 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java @@ -112,6 +112,13 @@ public class ProtoSchemaTranslator { /** This METADATA tag is used to store the field number of a proto tag. 
*/ public static final String PROTO_NUMBER_METADATA_TAG = "PROTO_NUMBER"; + public static final String PROTO_MESSAGE_NAME_METADATA_TAG = "PROTO_MESSAGE_NAME"; // full name of the proto message behind a field + + public static final String PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG = "PROTO_MAP_KEY_MESSAGE_NAME"; // full name of a message-typed map key + + public static final String PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG = + "PROTO_MAP_VALUE_MESSAGE_NAME"; // full name of a message-typed map value + /** Attach a proto field number to a type. */ public static FieldType withFieldNumber(FieldType fieldType, int index) { return fieldType.withMetadata(PROTO_NUMBER_METADATA_TAG, Long.toString(index)); @@ -122,12 +129,42 @@ public static int getFieldNumber(FieldType fieldType) { return Integer.parseInt(fieldType.getMetadataString(PROTO_NUMBER_METADATA_TAG)); } + /** Attach the name of the proto message to a type. */ + public static FieldType withMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the proto message name stored on a type. */ + public static String getMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the map key's proto message to a map type. */ + public static FieldType withMapKeyMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the proto message name stored for a map key. */ + public static String getMapKeyMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the map value's proto message to a map type. */ + public static FieldType withMapValueMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a map value. 
*/ + public static String getMapValueMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG); + } + /** Return a Beam scheam representing a proto class. */ public static Schema getSchema(Class clazz) { return getSchema(ProtobufUtil.getDescriptorForClass(clazz)); } - private static Schema getSchema(Descriptors.Descriptor descriptor) { + static Schema getSchema(Descriptors.Descriptor descriptor) { // package-private: also called from ProtoDynamicMessageSchema Set oneOfFields = Sets.newHashSet(); List fields = Lists.newArrayListWithCapacity(descriptor.getFields().size()); for (OneofDescriptor oneofDescriptor : descriptor.getOneofs()) { @@ -137,8 +174,7 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { oneOfFields.add(fieldDescriptor.getNumber()); // Store proto field number in metadata. FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); // field number plus message names subFields.add(Field.nullable(fieldDescriptor.getName(), fieldType)); checkArgument( enumIds.putIfAbsent(fieldDescriptor.getName(), fieldDescriptor.getNumber()) == null); @@ -151,14 +187,34 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { if (!oneOfFields.contains(fieldDescriptor.getNumber())) { // Store proto field number in metadata. 
FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); fields.add(Field.of(fieldDescriptor.getName(), fieldType)); } } return Schema.builder().addFields(fields).build(); } + private static FieldType withMetaData( + FieldType inType, Descriptors.FieldDescriptor fieldDescriptor) { + FieldType fieldType = withFieldNumber(inType, fieldDescriptor.getNumber()); + if (fieldDescriptor.isMapField()) { // the map entry message has a "key" and a "value" field + FieldDescriptor keyFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + if (keyFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + fieldType = + withMapKeyMessageName(fieldType, keyFieldDescriptor.getMessageType().getFullName()); + } + if (valueFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + fieldType = + withMapValueMessageName(fieldType, valueFieldDescriptor.getMessageType().getFullName()); + } + } else if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + fieldType = withMessageName(fieldType, fieldDescriptor.getMessageType().getFullName()); + } + return fieldType; + } + private static FieldType beamFieldTypeFromProtoField( Descriptors.FieldDescriptor protoFieldDescriptor) { FieldType fieldType = null; diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java new file mode 100644 index 0000000000000..66b964d7fad87 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_INT32; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_INT32; +import static 
org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf; +import 
org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Collection of tests for values on Protobuf Messages and Rows. */ +@RunWith(JUnit4.class) +public class ProtoDynamicMessageSchemaTest { + + private ProtoDynamicMessageSchema schemaFromDescriptor(Descriptors.Descriptor descriptor) { + ProtoDomain domain = ProtoDomain.buildFrom(descriptor); + return ProtoDynamicMessageSchema.forDescriptor(domain, descriptor); + } + + private DynamicMessage toDynamic(Message message) throws InvalidProtocolBufferException { + return DynamicMessage.parseFrom(message.getDescriptorForType(), message.toByteArray()); + } + + @Test + public void testPrimitiveSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testPrimitiveProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(PRIMITIVE_ROW, toRow.apply(toDynamic(PRIMITIVE_PROTO))); + } + + @Test + public void testPrimitiveRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + 
assertEquals(PRIMITIVE_PROTO.toString(), fromRow.apply(PRIMITIVE_ROW).toString()); + } + + @Test + public void testRepeatedSchema() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(REPEATED_SCHEMA, schema); + } + + @Test + public void testRepeatedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(REPEATED_ROW, toRow.apply(toDynamic(REPEATED_PROTO))); + } + + @Test + public void testRepeatedRowToProto() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(REPEATED_PROTO.toString(), fromRow.apply(REPEATED_ROW).toString()); + } + + // Test map type + @Test + public void testMapSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(MAP_PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testMapProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(MAP_PRIMITIVE_ROW, toRow.apply(toDynamic(MAP_PRIMITIVE_PROTO))); + } + + @Test + public void testMapRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(MAP_PRIMITIVE_PROTO.toString(), fromRow.apply(MAP_PRIMITIVE_ROW).toString()); + } + + @Test + public void testNestedSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); 
+ Schema schema = schemaProvider.getSchema(); + assertEquals(NESTED_SCHEMA, schema); + } + + @Test + public void testNestedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(NESTED_ROW, toRow.apply(toDynamic(NESTED_PROTO))); + } + + @Test + public void testNestedRowToProto() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(NESTED_PROTO.toString(), fromRow.apply(NESTED_ROW).toString()); + } + + @Test + public void testOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ONEOF_SCHEMA, schema); + } + + @Test + public void testOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(ONEOF_ROW_INT32.toString(), toRow.apply(toDynamic(ONEOF_PROTO_INT32)).toString()); + assertEquals(ONEOF_ROW_BOOL.toString(), toRow.apply(toDynamic(ONEOF_PROTO_BOOL)).toString()); + assertEquals( + ONEOF_ROW_STRING.toString(), toRow.apply(toDynamic(ONEOF_PROTO_STRING)).toString()); + assertEquals( + ONEOF_ROW_PRIMITIVE.toString(), toRow.apply(toDynamic(ONEOF_PROTO_PRIMITIVE)).toString()); + } + + @Test + public void testOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction fromRow = 
schemaProvider.getFromRowFunction(); + assertEquals(ONEOF_PROTO_INT32.toString(), fromRow.apply(ONEOF_ROW_INT32).toString()); + assertEquals(ONEOF_PROTO_BOOL.toString(), fromRow.apply(ONEOF_ROW_BOOL).toString()); + assertEquals(ONEOF_PROTO_STRING.toString(), fromRow.apply(ONEOF_ROW_STRING).toString()); + assertEquals(ONEOF_PROTO_PRIMITIVE.toString(), fromRow.apply(ONEOF_ROW_PRIMITIVE).toString()); + } + + @Test + public void testOuterOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(OUTER_ONEOF_SCHEMA, schema); + } + + @Test + public void testOuterOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(OUTER_ONEOF_ROW.toString(), toRow.apply(toDynamic(OUTER_ONEOF_PROTO)).toString()); + } + + @Test + public void testOuterOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(OUTER_ONEOF_PROTO.toString(), fromRow.apply(OUTER_ONEOF_ROW).toString()); + } + + private static final EnumerationType ENUM_TYPE = + EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3)); + private static final Schema ENUM_SCHEMA = + Schema.builder() + .addField( + "enum", + withFieldNumber(Schema.FieldType.logicalType(ENUM_TYPE).withNullable(true), 1)) + .build(); + private static final Row ENUM_ROW = + Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); + private static final EnumMessage ENUM_PROTO = + EnumMessage.newBuilder().setEnum(EnumMessage.Enum.TWO).build(); + + @Test + public void testEnumSchema() { + 
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ENUM_SCHEMA, schema); + } + + @Test + public void testEnumProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(ENUM_ROW, toRow.apply(toDynamic(ENUM_PROTO))); + } + + @Test + public void testEnumRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(ENUM_PROTO.toString(), fromRow.apply(ENUM_ROW).toString()); + } + + @Test + public void testWktMessageSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(WKT_MESSAGE_SCHEMA, schema); + } + + @Test + public void testWktProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(WKT_MESSAGE_ROW, toRow.apply(toDynamic(WKT_MESSAGE_PROTO))); + } + + @Test + public void testWktRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(WKT_MESSAGE_PROTO.toString(), fromRow.apply(WKT_MESSAGE_ROW).toString()); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java index 88892d8432025..9ad59791f6d1a 100644 --- 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java @@ -19,6 +19,8 @@ import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; @@ -248,16 +250,24 @@ class TestProtoSchemas { static final Schema NESTED_SCHEMA = Schema.builder() .addField( - "nested", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1)) + "nested", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1), + "proto3_schema_messages.Primitive")) .addField( - "nested_list", withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2)) + "nested_list", + withMessageName( + withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2), + "proto3_schema_messages.Primitive")) .addField( "nested_map", - withFieldNumber( - FieldType.map( - FieldType.STRING.withNullable(true), - FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)), - 3)) + withMapValueMessageName( + withFieldNumber( + FieldType.map( + FieldType.STRING.withNullable(true), + FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)), + 3), + "proto3_schema_messages.Primitive")) .build(); // A sample instance of the row. 
@@ -282,7 +292,11 @@ class TestProtoSchemas { Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2)), Field.of("oneof_bool", withFieldNumber(FieldType.BOOLEAN, 3)), Field.of("oneof_string", withFieldNumber(FieldType.STRING, 4)), - Field.of("oneof_primitive", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5))); + Field.of( + "oneof_primitive", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5), + "proto3_schema_messages.Primitive"))); private static final Map ONE_OF_ENUM_MAP = ONEOF_FIELDS.stream() .collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f.getType()))); @@ -325,7 +339,10 @@ class TestProtoSchemas { // The schema for the OuterOneOf proto. private static final List OUTER_ONEOF_FIELDS = ImmutableList.of( - Field.of("oneof_oneof", withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1)), + Field.of( + "oneof_oneof", + withMessageName( + withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1), "proto3_schema_messages.OneOf")), Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2))); private static final Map OUTER_ONE_OF_ENUM_MAP = OUTER_ONEOF_FIELDS.stream() @@ -347,19 +364,47 @@ class TestProtoSchemas { static final Schema WKT_MESSAGE_SCHEMA = Schema.builder() - .addNullableField("double", withFieldNumber(FieldType.DOUBLE, 1)) - .addNullableField("float", withFieldNumber(FieldType.FLOAT, 2)) - .addNullableField("int32", withFieldNumber(FieldType.INT32, 3)) - .addNullableField("int64", withFieldNumber(FieldType.INT64, 4)) - .addNullableField("uint32", withFieldNumber(FieldType.logicalType(new UInt32()), 5)) - .addNullableField("uint64", withFieldNumber(FieldType.logicalType(new UInt64()), 6)) - .addNullableField("bool", withFieldNumber(FieldType.BOOLEAN, 13)) - .addNullableField("string", withFieldNumber(FieldType.STRING, 14)) - .addNullableField("bytes", withFieldNumber(FieldType.BYTES, 15)) .addNullableField( - "timestamp", withFieldNumber(FieldType.logicalType(new TimestampNanos()), 16)) + "double", + 
withMessageName(withFieldNumber(FieldType.DOUBLE, 1), "google.protobuf.DoubleValue")) + .addNullableField( + "float", + withMessageName(withFieldNumber(FieldType.FLOAT, 2), "google.protobuf.FloatValue")) + .addNullableField( + "int32", + withMessageName(withFieldNumber(FieldType.INT32, 3), "google.protobuf.Int32Value")) + .addNullableField( + "int64", + withMessageName(withFieldNumber(FieldType.INT64, 4), "google.protobuf.Int64Value")) + .addNullableField( + "uint32", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt32()), 5), + "google.protobuf.UInt32Value")) + .addNullableField( + "uint64", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt64()), 6), + "google.protobuf.UInt64Value")) + .addNullableField( + "bool", + withMessageName(withFieldNumber(FieldType.BOOLEAN, 13), "google.protobuf.BoolValue")) + .addNullableField( + "string", + withMessageName(withFieldNumber(FieldType.STRING, 14), "google.protobuf.StringValue")) + .addNullableField( + "bytes", + withMessageName(withFieldNumber(FieldType.BYTES, 15), "google.protobuf.BytesValue")) + .addNullableField( + "timestamp", + withMessageName( + withFieldNumber(FieldType.logicalType(new TimestampNanos()), 16), + "google.protobuf.Timestamp")) .addNullableField( - "duration", withFieldNumber(FieldType.logicalType(new DurationNanos()), 17)) + "duration", + withMessageName( + withFieldNumber(FieldType.logicalType(new DurationNanos()), 17), + "google.protobuf.Duration")) .build(); // A sample instance of the row. static final Instant JAVA_NOW = Instant.now();