[FLINK-18073][avro] Fix AvroRowDataSerializationSchema is not serializable

This closes apache#12471
wuchong committed Jun 8, 2020
1 parent 674817c commit 8f858f3
Showing 4 changed files with 127 additions and 33 deletions.
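Context for the change: in the Avro version Flink depended on at the time, org.apache.avro.Schema does not implement java.io.Serializable. The old code captured a Schema inside the converter lambdas (which are declared Serializable), so AvroRowDataSerializationSchema failed Java serialization when Flink shipped it to the task managers. The fix keeps the Schema out of the serialized state entirely: it becomes a transient field rebuilt from the serializable RowType in open(), and the converters receive the Schema as a call-time argument instead of capturing it. The first hunk below patches the filesystem format's BulkWriter factory accordingly. A minimal, self-contained sketch of the transient-and-rebuild pattern; the class and method names are illustrative, not Flink's:

import java.io.Serializable;

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

// Sketch: keep only serializable state (here, the schema's JSON string) in
// regular fields, mark the non-serializable Avro Schema transient, and
// rebuild it in an open()-style hook after deserialization on the worker.
public class TransientSchemaHolder implements Serializable {

    private static final long serialVersionUID = 1L;

    // A plain String survives Java serialization.
    private final String schemaString;

    // Avro's Schema is excluded from the serialized form; it stays null
    // until open() runs.
    private transient Schema schema;

    public TransientSchemaHolder(Schema schema) {
        this.schemaString = schema.toString();
    }

    // Mirrors SerializationSchema#open in the commit: called once per task
    // before any record is processed.
    public void open() {
        this.schema = new Schema.Parser().parse(schemaString);
    }

    public Schema getSchema() {
        return schema;
    }

    public static void main(String[] args) {
        TransientSchemaHolder holder =
            new TransientSchemaHolder(SchemaBuilder.builder().stringType());
        holder.open(); // the framework would call this after deserialization
        System.out.println(holder.getSchema()); // "string"
    }
}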
@@ -243,11 +243,12 @@ public BulkWriter<RowData> create(FSDataOutputStream out) throws IOException {
     BulkWriter<GenericRecord> writer = factory.create(out);
     AvroRowDataSerializationSchema.SerializationRuntimeConverter converter =
         AvroRowDataSerializationSchema.createRowConverter(rowType);
+    Schema schema = AvroSchemaConverter.convertToSchema(rowType);
     return new BulkWriter<RowData>() {

         @Override
         public void addElement(RowData element) throws IOException {
-            GenericRecord record = (GenericRecord) converter.convert(element);
+            GenericRecord record = (GenericRecord) converter.convert(schema, element);
             writer.addElement(record);
         }

AvroRowDataSerializationSchema.java
@@ -74,6 +74,11 @@ public class AvroRowDataSerializationSchema implements SerializationSchema<RowData> {
      */
     private final SerializationRuntimeConverter runtimeConverter;

+    /**
+     * Avro serialization schema.
+     */
+    private transient Schema schema;
+
     /**
      * Writer to serialize Avro record into a Avro bytes.
      */
@@ -99,7 +104,7 @@ public AvroRowDataSerializationSchema(RowType rowType) {

     @Override
     public void open(InitializationContext context) throws Exception {
-        final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
+        this.schema = AvroSchemaConverter.convertToSchema(rowType);
         datumWriter = new SpecificDatumWriter<>(schema);
         arrayOutputStream = new ByteArrayOutputStream();
         encoder = EncoderFactory.get().binaryEncoder(arrayOutputStream, null);
@@ -109,7 +114,7 @@ public void open(InitializationContext context) throws Exception {
     public byte[] serialize(RowData row) {
         try {
             // convert to record
-            final GenericRecord record = (GenericRecord) runtimeConverter.convert(row);
+            final GenericRecord record = (GenericRecord) runtimeConverter.convert(schema, row);
             arrayOutputStream.reset();
             datumWriter.write(record, encoder);
             encoder.flush();
@@ -145,33 +150,43 @@ public int hashCode() {
      * to corresponding Avro data structures.
      */
    interface SerializationRuntimeConverter extends Serializable {
-        Object convert(Object object);
+        Object convert(Schema schema, Object object);
    }

    static SerializationRuntimeConverter createRowConverter(RowType rowType) {
        final SerializationRuntimeConverter[] fieldConverters = rowType.getChildren().stream()
            .map(AvroRowDataSerializationSchema::createConverter)
            .toArray(SerializationRuntimeConverter[]::new);
-        final Schema schema = AvroSchemaConverter.convertToSchema(rowType);
        final LogicalType[] fieldTypes = rowType.getFields().stream()
            .map(RowType.RowField::getType)
            .toArray(LogicalType[]::new);
+        final RowData.FieldGetter[] fieldGetters = new RowData.FieldGetter[fieldTypes.length];
+        for (int i = 0; i < fieldTypes.length; i++) {
+            fieldGetters[i] = RowData.createFieldGetter(fieldTypes[i], i);
+        }
        final int length = rowType.getFieldCount();

-        return object -> {
+        return (schema, object) -> {
            final RowData row = (RowData) object;
+            final List<Schema.Field> fields = schema.getFields();
            final GenericRecord record = new GenericData.Record(schema);
            for (int i = 0; i < length; ++i) {
-                record.put(i, fieldConverters[i].convert(RowData.get(row, i, fieldTypes[i])));
+                final Schema.Field schemaField = fields.get(i);
+                Object avroObject = fieldConverters[i].convert(
+                    schemaField.schema(),
+                    fieldGetters[i].getFieldOrNull(row));
+                record.put(i, avroObject);
            }
            return record;
        };
    }

    private static SerializationRuntimeConverter createConverter(LogicalType type) {
+        final SerializationRuntimeConverter converter;
        switch (type.getTypeRoot()) {
            case NULL:
-                return object -> null;
+                converter = (schema, object) -> null;
+                break;
            case BOOLEAN: // boolean
            case INTEGER: // int
            case INTERVAL_YEAR_MONTH: // long
@@ -181,56 +196,93 @@ private static SerializationRuntimeConverter createConverter(LogicalType type) {
            case DOUBLE: // double
            case TIME_WITHOUT_TIME_ZONE: // int
            case DATE: // int
-                return avroObject -> avroObject;
+                converter = (schema, object) -> object;
+                break;
            case CHAR:
            case VARCHAR:
-                return object -> new Utf8(object.toString());
+                converter = (schema, object) -> new Utf8(object.toString());
+                break;
            case BINARY:
            case VARBINARY:
-                return object -> ByteBuffer.wrap((byte[]) object);
+                converter = (schema, object) -> ByteBuffer.wrap((byte[]) object);
+                break;
            case TIMESTAMP_WITHOUT_TIME_ZONE:
-                return object -> ((TimestampData) object).toTimestamp().getTime();
+                converter = (schema, object) -> ((TimestampData) object).toTimestamp().getTime();
+                break;
            case DECIMAL:
-                return object -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                converter = (schema, object) -> ByteBuffer.wrap(((DecimalData) object).toUnscaledBytes());
+                break;
            case ARRAY:
-                return createArrayConverter((ArrayType) type);
+                converter = createArrayConverter((ArrayType) type);
+                break;
            case ROW:
-                return createRowConverter((RowType) type);
+                converter = createRowConverter((RowType) type);
+                break;
            case MAP:
            case MULTISET:
-                return createMapConverter(type);
+                converter = createMapConverter(type);
+                break;
            case RAW:
            default:
                throw new UnsupportedOperationException("Unsupported type: " + type);
        }

+        // wrap into nullable converter
+        return (schema, object) -> {
+            if (object == null) {
+                return null;
+            }
+
+            // get actual schema if it is a nullable schema
+            Schema actualSchema;
+            if (schema.getType() == Schema.Type.UNION) {
+                List<Schema> types = schema.getTypes();
+                int size = types.size();
+                if (size == 2 && types.get(1).getType() == Schema.Type.NULL) {
+                    actualSchema = types.get(0);
+                } else if (size == 2 && types.get(0).getType() == Schema.Type.NULL) {
+                    actualSchema = types.get(1);
+                } else {
+                    throw new IllegalArgumentException(
+                        "The Avro schema is not a nullable type: " + schema.toString());
+                }
+            } else {
+                actualSchema = schema;
+            }
+            return converter.convert(actualSchema, object);
+        };
    }

    private static SerializationRuntimeConverter createArrayConverter(ArrayType arrayType) {
+        LogicalType elementType = arrayType.getElementType();
+        final ArrayData.ElementGetter elementGetter = ArrayData.createElementGetter(elementType);
        final SerializationRuntimeConverter elementConverter = createConverter(arrayType.getElementType());
-        final LogicalType elementType = arrayType.getElementType();

-        return object -> {
+        return (schema, object) -> {
+            final Schema elementSchema = schema.getElementType();
            ArrayData arrayData = (ArrayData) object;
            List<Object> list = new ArrayList<>();
            for (int i = 0; i < arrayData.size(); ++i) {
-                list.add(elementConverter.convert(ArrayData.get(arrayData, i, elementType)));
+                list.add(elementConverter.convert(elementSchema, elementGetter.getElementOrNull(arrayData, i)));
            }
            return list;
        };
    }

    private static SerializationRuntimeConverter createMapConverter(LogicalType type) {
        LogicalType valueType = extractValueTypeToAvroMap(type);
+        final ArrayData.ElementGetter valueGetter = ArrayData.createElementGetter(valueType);
        final SerializationRuntimeConverter valueConverter = createConverter(valueType);

-        return object -> {
+        return (schema, object) -> {
+            final Schema valueSchema = schema.getValueType();
            final MapData mapData = (MapData) object;
            final ArrayData keyArray = mapData.keyArray();
            final ArrayData valueArray = mapData.valueArray();
            final Map<Object, Object> map = new HashMap<>(mapData.size());
            for (int i = 0; i < mapData.size(); ++i) {
                final String key = keyArray.getString(i).toString();
-                final Object value = valueConverter.convert(ArrayData.get(valueArray, i, valueType));
+                final Object value = valueConverter.convert(valueSchema, valueGetter.getElementOrNull(valueArray, i));
                map.put(key, value);
            }
            return map;
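A note on the nullable wrapper added above: AvroSchemaConverter represents a nullable Flink type as an Avro UNION of the payload type and null, so the Schema a field converter receives may be a two-branch union rather than the payload schema itself, and the null branch can sit at either index depending on how the union was built, which is why the wrapper checks both orders. A small stand-alone sketch of the schema shape involved; the class name is illustrative:

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

// Sketch: what a nullable field's schema looks like to the converter.
public class NullableUnionDemo {

    public static void main(String[] args) {
        // A nullable string built with the null branch second; other
        // builders may equally put it first.
        Schema nullableString = SchemaBuilder.unionOf()
            .stringType().and()
            .nullType()
            .endUnion();

        System.out.println(nullableString.getType());         // UNION
        System.out.println(nullableString.getTypes().size()); // 2
        // The wrapper in the commit unwraps to the non-null branch:
        System.out.println(nullableString.getTypes().get(0)); // "string"
    }
}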
AvroSchemaConverter.java
@@ -31,6 +31,7 @@
 import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.MultisetType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.types.Row;
@@ -179,6 +180,7 @@ public static Schema convertToSchema(LogicalType logicalType) {
    }

    public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
+        int precision;
        switch (logicalType.getTypeRoot()) {
            case NULL:
                return SchemaBuilder.builder().nullType();
@@ -201,20 +203,25 @@ public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
            case TIMESTAMP_WITHOUT_TIME_ZONE:
                // use long to represents Timestamp
                final TimestampType timestampType = (TimestampType) logicalType;
-                int precision = timestampType.getPrecision();
+                precision = timestampType.getPrecision();
                org.apache.avro.LogicalType avroLogicalType;
                if (precision <= 3) {
                    avroLogicalType = LogicalTypes.timestampMillis();
                } else {
-                    throw new IllegalArgumentException("Avro Timestamp does not support Timestamp with precision: " +
-                        precision +
-                        ", it only supports precision of 3 or 9.");
+                    throw new IllegalArgumentException("Avro does not support TIMESTAMP type " +
+                        "with precision: " + precision + ", it only supports precision less than 3.");
                }
                return avroLogicalType.addToSchema(SchemaBuilder.builder().longType());
            case DATE:
                // use int to represents Date
                return LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType());
            case TIME_WITHOUT_TIME_ZONE:
+                precision = ((TimeType) logicalType).getPrecision();
+                if (precision > 3) {
+                    throw new IllegalArgumentException(
+                        "Avro does not support TIME type with precision: " + precision +
+                        ", it only supports precision less than 3.");
+                }
                // use int to represents Time, we only support millisecond when deserialization
                return LogicalTypes.timeMillis().addToSchema(SchemaBuilder.builder().intType());
            case DECIMAL:
@@ -254,14 +261,6 @@ public static Schema convertToSchema(LogicalType logicalType, int rowTypeCounter) {
                .array()
                .items(convertToSchema(arrayType.getElementType(), rowTypeCounter));
            case RAW:
-                // if the union type has more than 2 types, it will be recognized a generic type
-                // see AvroRowDeserializationSchema#convertAvroType and AvroRowSerializationSchema#convertFlinkType
-                return SchemaBuilder.builder().unionOf()
-                    .nullType().and()
-                    .booleanType().and()
-                    .longType().and()
-                    .doubleType()
-                    .endUnion();
            case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
            default:
                throw new UnsupportedOperationException("Unsupported to derive Schema for type: " + logicalType);
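The new precision checks follow from the Avro logical types being targeted: the converter maps TIMESTAMP to timestamp-millis and TIME to time-millis, both of which carry at most millisecond resolution, so a precision above 3 would silently lose sub-millisecond digits and is rejected up front instead. A stand-alone sketch of the schemas produced for the supported precisions; the class name is illustrative:

import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

// Sketch: the Avro schemas that TIMESTAMP(p <= 3) and TIME(p <= 3) map to.
public class MillisLogicalTypesDemo {

    public static void main(String[] args) {
        Schema timestamp = LogicalTypes.timestampMillis()
            .addToSchema(SchemaBuilder.builder().longType());
        Schema time = LogicalTypes.timeMillis()
            .addToSchema(SchemaBuilder.builder().intType());

        // {"type":"long","logicalType":"timestamp-millis"}
        System.out.println(timestamp);
        // {"type":"int","logicalType":"time-millis"}
        System.out.println(time);
    }
}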
AvroSchemaConverterTest.java
@@ -22,9 +22,14 @@
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.formats.avro.generated.User;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableSchema;
+import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.types.Row;

+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;

 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -34,6 +39,9 @@
  */
 public class AvroSchemaConverterTest {

+    @Rule
+    public ExpectedException thrown = ExpectedException.none();
+
    @Test
    public void testAvroClassConversion() {
        validateUserSchema(AvroSchemaConverter.convertToTypeInfo(User.class));
@@ -45,6 +53,40 @@ public void testAvroSchemaConversion() {
        validateUserSchema(AvroSchemaConverter.convertToTypeInfo(schema));
    }

+    @Test
+    public void testInvalidRawTypeAvroSchemaConversion() {
+        RowType rowType = (RowType) TableSchema.builder()
+            .field("a", DataTypes.STRING())
+            .field("b", DataTypes.RAW(Types.GENERIC(AvroSchemaConverterTest.class)))
+            .build().toRowDataType().getLogicalType();
+        thrown.expect(UnsupportedOperationException.class);
+        thrown.expectMessage("Unsupported to derive Schema for type: RAW");
+        AvroSchemaConverter.convertToSchema(rowType);
+    }
+
+    @Test
+    public void testInvalidTimestampTypeAvroSchemaConversion() {
+        RowType rowType = (RowType) TableSchema.builder()
+            .field("a", DataTypes.STRING())
+            .field("b", DataTypes.TIMESTAMP(9))
+            .build().toRowDataType().getLogicalType();
+        thrown.expect(IllegalArgumentException.class);
+        thrown.expectMessage("Avro does not support TIMESTAMP type with precision: 9, " +
+            "it only supports precision less than 3.");
+        AvroSchemaConverter.convertToSchema(rowType);
+    }
+
+    @Test
+    public void testInvalidTimeTypeAvroSchemaConversion() {
+        RowType rowType = (RowType) TableSchema.builder()
+            .field("a", DataTypes.STRING())
+            .field("b", DataTypes.TIME(6))
+            .build().toRowDataType().getLogicalType();
+        thrown.expect(IllegalArgumentException.class);
+        thrown.expectMessage("Avro does not support TIME type with precision: 6, it only supports precision less than 3.");
+        AvroSchemaConverter.convertToSchema(rowType);
+    }
+
    private void validateUserSchema(TypeInformation<?> actual) {
        final TypeInformation<Row> address = Types.ROW_NAMED(
            new String[]{
