From 8f858f398de3c47fe5045fbe0f1497022bfb1c15 Mon Sep 17 00:00:00 2001
From: Jark Wu
Date: Thu, 4 Jun 2020 20:01:47 +0800
Subject: [PATCH] [FLINK-18073][avro] Fix AvroRowDataSerializationSchema not
 being serializable

This closes #12471
---
 .../avro/AvroFileSystemFormatFactory.java     |  3 +-
 .../avro/AvroRowDataSerializationSchema.java  | 92 +++++++++++++++----
 .../avro/typeutils/AvroSchemaConverter.java   | 23 +++--
 .../typeutils/AvroSchemaConverterTest.java    | 42 +++++++++
 4 files changed, 127 insertions(+), 33 deletions(-)

diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
index a033739f9ad2e..c60c42c0f41c6 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroFileSystemFormatFactory.java
@@ -243,11 +243,12 @@ public BulkWriter<RowData> create(FSDataOutputStream out) throws IOException {
             BulkWriter<GenericRecord> writer = factory.create(out);
             AvroRowDataSerializationSchema.SerializationRuntimeConverter converter =
                 AvroRowDataSerializationSchema.createRowConverter(rowType);
+            Schema schema = AvroSchemaConverter.convertToSchema(rowType);
             return new BulkWriter<RowData>() {
 
                 @Override
                 public void addElement(RowData element) throws IOException {
-                    GenericRecord record = (GenericRecord) converter.convert(element);
+                    GenericRecord record = (GenericRecord) converter.convert(schema, element);
                     writer.addElement(record);
                 }
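Context for the hunk above: with the Avro release that flink-avro depended on at the time (1.8.x), org.apache.avro.Schema does not implement java.io.Serializable, so a Serializable converter lambda that captures a Schema fails as soon as Flink serializes the job for distribution. A minimal, self-contained sketch of that failure mode (the demo class and names are illustrative, not part of this patch):

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

import java.io.ByteArrayOutputStream;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.function.Function;

public class SchemaCaptureDemo {

    // A serializable lambda type, analogous to SerializationRuntimeConverter.
    interface SerializableFn extends Function<Object, Object>, Serializable {}

    public static void main(String[] args) throws Exception {
        Schema schema = SchemaBuilder.builder().stringType();
        // The lambda captures the non-serializable Schema instance.
        SerializableFn fn = value -> schema.getType().getName() + ":" + value;
        try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
            out.writeObject(fn); // NotSerializableException: org.apache.avro.Schema
        } catch (NotSerializableException e) {
            System.out.println("Failed as expected: " + e.getMessage());
        }
    }
}

Passing the schema as an argument to convert(schema, element), as the writer above now does, keeps the Schema out of the lambda's captured state.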
diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
index 00b7ac5d18b79..5b1fbbecc886f 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDataSerializationSchema.java
@@ -74,6 +74,11 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowData> {
      */
     private final SerializationRuntimeConverter runtimeConverter;
 
+    /**
+     * Avro serialization schema.
+     */
+    private transient Schema schema;
+
     /**
      * Writer to serialize Avro record.
      */
@@ -100,7 +105,7 @@ public AvroRowDataSerializationSchema(RowType rowType) {
 
     @Override
     public void open(InitializationContext context) throws Exception {
-        final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
+        this.schema = AvroSchemaConverter.convertToSchema(rowType);
         datumWriter = new SpecificDatumWriter<>(schema);
         arrayOutputStream = new ByteArrayOutputStream();
         encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
@@ -109,7 +114,7 @@ public void open(InitializationContext context) throws Exception {
     public byte[] serialize(RowData row) {
         try {
             // convert to record
-            final GenericRecord record = (GenericRecord) runtimeConverter.convert(row);
+            final GenericRecord record = (GenericRecord) runtimeConverter.convert(schema, row);
             arrayOutputStream.reset();
             datumWriter.write(record, encoder);
             encoder.flush();
@@ -145,33 +150,43 @@ public int hashCode() {
     /**
      * Runtime converter that converts objects of Flink Table & SQL internal data structures
      * to corresponding Avro data structures.
      */
     interface SerializationRuntimeConverter extends Serializable {
-        Object convert(Object object);
+        Object convert(Schema schema, Object object);
     }
 
     static SerializationRuntimeConverter createRowConverter(RowType rowType) {
         final SerializationRuntimeConverter[] fieldConverters = rowType.getChildren().stream()
             .map(AvroRowDataSerializationSchema::createConverter)
             .toArray(SerializationRuntimeConverter[]::new);
-        final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
         final LogicalType[] fieldTypes = rowType.getFields().stream()
             .map(RowType.RowField::getType)
             .toArray(LogicalType[]::new);
+        final RowData.FieldGetter[] fieldGetters = new RowData.FieldGetter[fieldTypes.length];
+        for (int i = 0; i < fieldTypes.length; i++) {
+            fieldGetters[i] = RowData.createFieldGetter(fieldTypes[i], i);
+        }
         final int length = rowType.getFieldCount();
 
-        return object -> {
+        return (schema, object) -> {
             final RowData row = (RowData) object;
+            final List<Schema.Field> fields = schema.getFields();
             final GenericRecord record = new GenericData.Record(schema);
             for (int i = 0; i < length; ++i) {
-                record.put(i, fieldConverters[i].convert(RowData.get(row, i, fieldTypes[i])));
+                final Schema.Field schemaField = fields.get(i);
+                Object avroObject = fieldConverters[i].convert(
+                    schemaField.schema(),
+                    fieldGetters[i].getFieldOrNull(row));
+                record.put(i, avroObject);
             }
             return record;
         };
     }
 
     private static SerializationRuntimeConverter createConverter(LogicalType type) {
+        final SerializationRuntimeConverter converter;
         switch (type.getTypeRoot()) {
             case NULL:
-                return object -> null;
+                converter = (schema, object) -> null;
+                break;
             case BOOLEAN: // boolean
             case INTEGER: // int
             case INTERVAL_YEAR_MONTH: // long
@@ -181,39 +196,74 @@ private static SerializationRuntimeConverter createConverter(LogicalType type) {
             case DOUBLE: // double
             case TIME_WITHOUT_TIME_ZONE: // int
             case DATE: // int
-                return avroObject -> avroObject;
+                converter = (schema, object) -> object;
+                break;
             case CHAR:
             case VARCHAR:
-                return object -> new Utf8(object.toString());
+                converter = (schema, object) -> new Utf8(object.toString());
+                break;
             case BINARY:
             case VARBINARY:
-                return object -> ByteBuffer.wrap((byte[]) object);
+                converter = (schema, object) -> ByteBuffer.wrap((byte[]) object);
+                break;
             case TIMESTAMP_WITHOUT_TIME_ZONE:
-                return object -> ((TimestampData) object).toTimestamp().getTime();
+                converter = (schema, object) -> ((TimestampData) object).toTimestamp().getTime();
+                break;
             case DECIMAL:
-                return object -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                converter = (schema, object) -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                break;
             case ARRAY:
-                return createArrayConverter((ArrayType) type);
+                converter = createArrayConverter((ArrayType) type);
+                break;
             case ROW:
-                return createRowConverter((RowType) type);
+                converter = createRowConverter((RowType) type);
+                break;
             case MAP:
             case MULTISET:
-                return createMapConverter(type);
+                converter = createMapConverter(type);
+                break;
             case RAW:
             default:
                 throw new UnsupportedOperationException("Unsupported type: " + type);
         }
+
+        // wrap into nullable converter
+        return (schema, object) -> {
+            if (object == null) {
+                return null;
+            }
+
+            // get actual schema if it is a nullable schema
+            Schema actualSchema;
+            if (schema.getType() == Schema.Type.UNION) {
+                List<Schema> types = schema.getTypes();
+                int size = types.size();
+                if (size == 2 && types.get(1).getType() == Schema.Type.NULL) {
+                    actualSchema = types.get(0);
+                } else if (size == 2 && types.get(0).getType() == Schema.Type.NULL) {
+                    actualSchema = types.get(1);
+                } else {
+                    throw new IllegalArgumentException(
+                        "The Avro schema is not a nullable type: " + schema.toString());
+                }
+            } else {
+                actualSchema = schema;
+            }
+            return converter.convert(actualSchema, object);
+        };
     }
 
     private static SerializationRuntimeConverter createArrayConverter(ArrayType arrayType) {
+        LogicalType elementType = arrayType.getElementType();
+        final ArrayData.ElementGetter elementGetter = ArrayData.createElementGetter(elementType);
         final SerializationRuntimeConverter elementConverter = createConverter(arrayType.getElementType());
-        final LogicalType elementType = arrayType.getElementType();
 
-        return object -> {
+        return (schema, object) -> {
+            final Schema elementSchema = schema.getElementType();
             ArrayData arrayData = (ArrayData) object;
             List<Object> list = new ArrayList<>();
             for (int i = 0; i < arrayData.size(); ++i) {
-                list.add(elementConverter.convert(ArrayData.get(arrayData, i, elementType)));
+                list.add(elementConverter.convert(elementSchema, elementGetter.getElementOrNull(arrayData, i)));
             }
             return list;
         };
@@ -221,16 +271,18 @@ private static SerializationRuntimeConverter createArrayConverter(ArrayType arrayType) {
 
     private static SerializationRuntimeConverter createMapConverter(LogicalType type) {
         LogicalType valueType = extractValueTypeToAvroMap(type);
+        final ArrayData.ElementGetter valueGetter = ArrayData.createElementGetter(valueType);
         final SerializationRuntimeConverter valueConverter = createConverter(valueType);
 
-        return object -> {
+        return (schema, object) -> {
+            final Schema valueSchema = schema.getValueType();
             final MapData mapData = (MapData) object;
             final ArrayData keyArray = mapData.keyArray();
             final ArrayData valueArray = mapData.valueArray();
             final Map<Object, Object> map = new HashMap<>(mapData.size());
             for (int i = 0; i < mapData.size(); ++i) {
                 final String key = keyArray.getString(i).toString();
-                final Object value = valueConverter.convert(ArrayData.get(valueArray, i, valueType));
+                final Object value = valueConverter.convert(valueSchema, valueGetter.getElementOrNull(valueArray, i));
                 map.put(key, value);
             }
             return map;
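The "wrap into nullable converter" block above leans on Avro's convention of modelling a nullable type as a two-branch union containing NULL. The same unwrapping rule in isolation (a hypothetical demo class that mirrors, but does not call, the patch code):

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

import java.util.List;

public class NullableUnionDemo {

    // Returns the non-null branch of a ["type", "null"] or ["null", "type"] union,
    // or the schema itself if it is not a union (the same rule as the wrapper above).
    static Schema actualSchema(Schema schema) {
        if (schema.getType() != Schema.Type.UNION) {
            return schema;
        }
        List<Schema> types = schema.getTypes();
        if (types.size() == 2 && types.get(1).getType() == Schema.Type.NULL) {
            return types.get(0);
        } else if (types.size() == 2 && types.get(0).getType() == Schema.Type.NULL) {
            return types.get(1);
        }
        throw new IllegalArgumentException("Not a nullable schema: " + schema);
    }

    public static void main(String[] args) {
        Schema nullableLong = SchemaBuilder.builder().unionOf().nullType().and().longType().endUnion();
        System.out.println(actualSchema(nullableLong)); // prints "long"
    }
}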
diff --git a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
index 774fadfcfd03e..37745e555a161 100644
--- a/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
+++ b/flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverter.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.MultisetType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.types.Row;
@@ -179,6 +180,7 @@ public static Schema convertToSchema(LogicalType logicalType) {
     }
 
     public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
+        int precision;
         switch (logicalType.getTypeRoot()) {
             case NULL:
                 return SchemaBuilder.builder().nullType();
@@ -201,20 +203,25 @@ public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
             case TIMESTAMP_WITHOUT_TIME_ZONE:
                 // use long to represent Timestamp
                 final TimestampType timestampType = (TimestampType) logicalType;
-                int precision = timestampType.getPrecision();
+                precision = timestampType.getPrecision();
                 org.apache.avro.LogicalType avroLogicalType;
                 if (precision <= 3) {
                     avroLogicalType = LogicalTypes.timestampMillis();
                 } else {
-                    throw new IllegalArgumentException("Avro Timestamp does not support Timestamp with precision: " +
-                        precision +
-                        ", it only supports precision of 3 or 9.");
+                    throw new IllegalArgumentException("Avro does not support TIMESTAMP type " +
+                        "with precision: " + precision + ", it only supports precision less than or equal to 3.");
                 }
                 return avroLogicalType.addToSchema(SchemaBuilder.builder().longType());
             case DATE:
                 // use int to represent Date
                 return LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType());
             case TIME_WITHOUT_TIME_ZONE:
+                precision = ((TimeType) logicalType).getPrecision();
+                if (precision > 3) {
+                    throw new IllegalArgumentException(
+                        "Avro does not support TIME type with precision: " + precision +
+                            ", it only supports precision less than or equal to 3.");
+                }
                 // use int to represent Time, we only support millisecond precision in deserialization
                 return LogicalTypes.timeMillis().addToSchema(SchemaBuilder.builder().intType());
             case DECIMAL:
@@ -254,14 +261,6 @@ public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
                 .array()
                 .items(convertToSchema(arrayType.getElementType(), rowTypeCounter));
             case RAW:
-                // if the union type has more than 2 types, it will be recognized a generic type
-                // see AvroRowDeserializationSchema#convertAvroType and AvroRowSerializationSchema#convertFlinkType
-                return SchemaBuilder.builder().unionOf()
-                    .nullType().and()
-                    .booleanType().and()
-                    .longType().and()
-                    .doubleType()
-                    .endUnion();
             case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
             default:
                 throw new UnsupportedOperationException("Unsupported to derive Schema for type: " + logicalType);
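The new precision guards match what Avro 1.8's built-in logical types can express: timestamp-millis and time-millis carry millisecond precision only, so a Flink TIMESTAMP(p) or TIME(p) with p > 3 cannot round-trip without losing digits. The schemas the converter now emits can be inspected directly (illustrative snippet, not part of the patch):

import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class TimestampSchemaDemo {
    public static void main(String[] args) {
        // TIMESTAMP(p) with p <= 3 maps to a long annotated with timestamp-millis.
        Schema ts = LogicalTypes.timestampMillis().addToSchema(SchemaBuilder.builder().longType());
        System.out.println(ts); // {"type":"long","logicalType":"timestamp-millis"}

        // TIME(p) with p <= 3 maps to an int annotated with time-millis.
        Schema time = LogicalTypes.timeMillis().addToSchema(SchemaBuilder.builder().intType());
        System.out.println(time); // {"type":"int","logicalType":"time-millis"}
    }
}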
diff --git a/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java b/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
index be0ddc48386a8..fa499b79c4a69 100644
--- a/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
+++ b/flink-formats/flink-avro/src/test/java/org/apache/flink/formats/avro/typeutils/AvroSchemaConverterTest.java
@@ -22,9 +22,14 @@ import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.formats.avro.generated.User;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableSchema;
+import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.types.Row;
 
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -34,6 +39,9 @@
  */
 public class AvroSchemaConverterTest {
 
+    @Rule
+    public ExpectedException thrown = ExpectedException.none();
+
     @Test
     public void testAvroClassConversion() {
         validateUserSchema(AvroSchemaConverter.convertToTypeInfo(User.class));
@@ -45,6 +53,40 @@ public void testAvroSchemaConversion() {
         validateUserSchema(AvroSchemaConverter.convertToTypeInfo(schema));
     }
 
+    @Test
+    public void testInvalidRawTypeAvroSchemaConversion() {
+        RowType rowType = (RowType) TableSchema.builder()
+            .field("a", DataTypes.STRING())
+            .field("b", DataTypes.RAW(Types.GENERIC(AvroSchemaConverterTest.class)))
+            .build().toRowDataType().getLogicalType();
+        thrown.expect(UnsupportedOperationException.class);
+        thrown.expectMessage("Unsupported to derive Schema for type: RAW");
+        AvroSchemaConverter.convertToSchema(rowType);
+    }
+
thrown.expectMessage("Unsupported to derive Schema for type: RAW"); + AvroSchemaConverter.convertToSchema(rowType); + } + + @Test + public void testInvalidTimestampTypeAvroSchemaConversion() { + RowType rowType = (RowType) TableSchema.builder() + .field("a", DataTypes.STRING()) + .field("b", DataTypes.TIMESTAMP(9)) + .build().toRowDataType().getLogicalType(); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Avro does not support TIMESTAMP type with precision: 9, " + + "it only supports precision less than 3."); + AvroSchemaConverter.convertToSchema(rowType); + } + + @Test + public void testInvalidTimeTypeAvroSchemaConversion() { + RowType rowType = (RowType) TableSchema.builder() + .field("a", DataTypes.STRING()) + .field("b", DataTypes.TIME(6)) + .build().toRowDataType().getLogicalType(); + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Avro does not support TIME type with precision: 6, it only supports precision less than 3."); + AvroSchemaConverter.convertToSchema(rowType); + } + private void validateUserSchema(TypeInformation actual) { final TypeInformation address = Types.ROW_NAMED( new String[]{