Skip to content

Commit

Permalink
[FLINK-9384] [table] Fix KafkaAvroTableSource type mismatch
Browse files Browse the repository at this point in the history
This closes apache#6026.
  • Loading branch information
jerryjzhang authored and twalthr committed May 25, 2018
1 parent cb48019 commit bc8d1b1
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.formats.avro.utils.AvroTestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Types;

import org.apache.avro.Schema;
Expand Down Expand Up @@ -72,6 +73,10 @@ public void testSameFieldsAvroClass() {

// check field mapping
assertNull(source.getFieldMapping());

// check if DataStream type matches with TableSource.getReturnType()
assertEquals(source.getReturnType(),
source.getDataStream(StreamExecutionEnvironment.getExecutionEnvironment()).getType());
}

@Test
Expand Down Expand Up @@ -117,6 +122,10 @@ public void testDifferentFieldsAvroClass() {
assertEquals("otherField1", fieldMapping.get("field1"));
assertEquals("otherField2", fieldMapping.get("field2"));
assertEquals("otherField3", fieldMapping.get("field3"));

// check if DataStream type matches with TableSource.getReturnType()
assertEquals(source.getReturnType(),
source.getDataStream(StreamExecutionEnvironment.getExecutionEnvironment()).getType());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.flink.formats.avro;

import org.apache.flink.api.common.serialization.AbstractDeserializationSchema;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.formats.avro.typeutils.AvroRecordClassConverter;
import org.apache.flink.formats.avro.utils.MutableByteArrayInputStream;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
Expand All @@ -30,6 +32,7 @@
import org.apache.avro.specific.SpecificData;
import org.apache.avro.specific.SpecificDatumReader;
import org.apache.avro.specific.SpecificRecord;
import org.apache.avro.specific.SpecificRecordBase;
import org.apache.avro.util.Utf8;

import java.io.IOException;
Expand Down Expand Up @@ -76,19 +79,25 @@ public class AvroRowDeserializationSchema extends AbstractDeserializationSchema<
*/
private SpecificRecord record;

/**
* Type information describing the result type.
*/
private transient TypeInformation<Row> typeInfo;

/**
* Creates an Avro deserialization schema for the given record class.
*
* <p>Derives the Avro schema from the record class and initializes the datum reader,
* a record instance, the input stream with its binary decoder, and the type
* information of the produced rows.
*
* @param recordClazz Avro record class used to deserialize Avro's record to Flink's row; must not be null
*/
public AvroRowDeserializationSchema(Class<? extends SpecificRecord> recordClazz) {
public AvroRowDeserializationSchema(Class<? extends SpecificRecordBase> recordClazz) {
Preconditions.checkNotNull(recordClazz, "Avro record class must not be null.");
this.recordClazz = recordClazz;
this.schema = SpecificData.get().getSchema(recordClazz);
this.datumReader = new SpecificDatumReader<>(schema);
this.record = (SpecificRecord) SpecificData.newInstance(recordClazz, schema);
this.inputStream = new MutableByteArrayInputStream();
this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
// Row type information is derived from the record class via AvroRecordClassConverter.
this.typeInfo = AvroRecordClassConverter.convert(recordClazz);
}

@Override
Expand Down Expand Up @@ -120,6 +129,11 @@ private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IO
this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
}

/**
 * Returns the type information describing the rows produced by this schema.
 *
 * <p>NOTE(review): the backing {@code typeInfo} field is declared {@code transient};
 * confirm that it is re-initialized after Java deserialization, otherwise this
 * method may return {@code null} on a deserialized instance.
 *
 * @return the row type information held by this schema
 */
@Override
public TypeInformation<Row> getProducedType() {
	return this.typeInfo;
}

/**
* Converts a (nested) Avro {@link SpecificRecord} into Flink's Row type.
* Avro's {@link Utf8} fields are converted into regular Java strings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.util.InstantiationUtil;

import org.apache.avro.specific.SpecificRecord;
import org.apache.avro.specific.SpecificRecordBase;
import org.junit.Test;

import java.io.IOException;
Expand All @@ -37,7 +38,7 @@ public class AvroRowDeSerializationSchemaTest {

@Test
public void testSerializeDeserializeSimpleRow() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -50,7 +51,7 @@ public void testSerializeDeserializeSimpleRow() throws IOException {

@Test
public void testSerializeSimpleRowSeveralTimes() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -65,7 +66,7 @@ public void testSerializeSimpleRowSeveralTimes() throws IOException {

@Test
public void testDeserializeRowSeveralTimes() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getSimpleTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -80,7 +81,7 @@ public void testDeserializeRowSeveralTimes() throws IOException {

@Test
public void testSerializeDeserializeComplexRow() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -93,7 +94,7 @@ public void testSerializeDeserializeComplexRow() throws IOException {

@Test
public void testSerializeComplexRowSeveralTimes() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -108,7 +109,7 @@ public void testSerializeComplexRowSeveralTimes() throws IOException {

@Test
public void testDeserializeComplexRowSeveralTimes() throws IOException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();

final AvroRowSerializationSchema serializationSchema = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserializationSchema = new AvroRowDeserializationSchema(testData.f0);
Expand All @@ -123,7 +124,7 @@ public void testDeserializeComplexRowSeveralTimes() throws IOException {

@Test
public void testSerializability() throws IOException, ClassNotFoundException {
final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> testData = AvroTestUtils.getComplexTestData();

final AvroRowSerializationSchema serOrig = new AvroRowSerializationSchema(testData.f0);
final AvroRowDeserializationSchema deserOrig = new AvroRowDeserializationSchema(testData.f0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.specific.SpecificRecord;
import org.apache.avro.specific.SpecificRecordBase;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand Down Expand Up @@ -70,7 +71,7 @@ public static Schema createFlatAvroSchema(String[] fieldNames, TypeInformation[]
/**
* Tests simple Avro data types without nesting.
*/
public static Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> getSimpleTestData() {
public static Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> getSimpleTestData() {
final Address addr = Address.newBuilder()
.setNum(42)
.setStreet("Main Street 42")
Expand All @@ -86,7 +87,7 @@ public static Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> getSi
rowAddr.setField(3, "Test State");
rowAddr.setField(4, "12345");

final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> t = new Tuple3<>();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> t = new Tuple3<>();
t.f0 = Address.class;
t.f1 = addr;
t.f2 = rowAddr;
Expand All @@ -97,7 +98,7 @@ public static Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> getSi
/**
* Tests all Avro data types as well as nested types.
*/
public static Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> getComplexTestData() {
public static Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> getComplexTestData() {
final Address addr = Address.newBuilder()
.setNum(42)
.setStreet("Main Street 42")
Expand Down Expand Up @@ -148,7 +149,7 @@ public static Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> getCo
rowUser.setField(13, null);
rowUser.setField(14, rowAddr);

final Tuple3<Class<? extends SpecificRecord>, SpecificRecord, Row> t = new Tuple3<>();
final Tuple3<Class<? extends SpecificRecordBase>, SpecificRecord, Row> t = new Tuple3<>();
t.f0 = User.class;
t.f1 = user;
t.f2 = rowUser;
Expand Down

0 comments on commit bc8d1b1

Please sign in to comment.