From 6b493fb080d4a589396c1754f5d89ba802c828d1 Mon Sep 17 00:00:00 2001 From: Robert Metzger Date: Mon, 6 Oct 2014 11:45:24 +0200 Subject: [PATCH] Add Pojo support to Scala API --- .../api/common/typeutils/CompositeType.java | 1 + .../api/java/typeutils/TupleTypeInfo.java | 3 - .../api/java/typeutils/TypeExtractor.java | 23 ++-- .../typeutils/runtime/PojoComparator.java | 2 +- .../type/extractor/TypeExtractorTest.java | 2 +- .../typeutils/runtime/PojoComparatorTest.java | 2 +- .../org/apache/flink/api/scala/DataSet.scala | 4 +- .../scala/typeutils/CaseClassTypeInfo.scala | 78 ++++------- .../api/scala/operators/ExamplesITCase.scala | 127 ++++++++++++++++++ 9 files changed, 170 insertions(+), 72 deletions(-) create mode 100644 flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java index 60e9eab5807e0..1522ed10c1fae 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java @@ -117,6 +117,7 @@ public FlatFieldDescriptor(int keyPosition, TypeInformation type) { public int getPosition() { return keyPosition; } + public TypeInformation getType() { return type; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java index 82f9c50817619..177f03390f7bf 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java @@ -62,14 +62,12 @@ public TupleSerializer createSerializer() { /** * Comparator creation */ - private TypeSerializer[] fieldSerializers; private TypeComparator[] fieldComparators; private int[] 
logicalKeyFields; private int comparatorHelperIndex = 0; @Override protected void initializeNewComparator(int localKeyCount) { - fieldSerializers = new TypeSerializer[localKeyCount]; fieldComparators = new TypeComparator[localKeyCount]; logicalKeyFields = new int[localKeyCount]; comparatorHelperIndex = 0; @@ -78,7 +76,6 @@ protected void initializeNewComparator(int localKeyCount) { @Override protected void addCompareField(int fieldId, TypeComparator comparator) { fieldComparators[comparatorHelperIndex] = comparator; - fieldSerializers[comparatorHelperIndex] = types[fieldId].createSerializer(); logicalKeyFields[comparatorHelperIndex] = fieldId; comparatorHelperIndex++; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index 6231a74fa11d2..5d216e9def006 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -398,7 +398,7 @@ else if (t instanceof GenericArrayType) { return ObjectArrayTypeInfo.getInfoFor(t, componentInfo); } // objects with generics are treated as raw type - else if (t instanceof ParameterizedType) { + else if (t instanceof ParameterizedType) { //TODO return privateGetForClass((Class) ((ParameterizedType) t).getRawType(), typeHierarchy); } // no tuple, no TypeVariable, no generic type @@ -936,14 +936,13 @@ private TypeInformation privateGetForClass(Class clazz, ArrayList(clazz); } /** * Checks if the given field is a valid pojo field: - * - it is public + * - it is public * OR * - there are getter and setter methods for the field. * @@ -968,8 +967,8 @@ private boolean isValidPojoField(Field f, Class clazz, ArrayList typeHi for(Method m : clazz.getMethods()) { // check for getter - if( // The name should be "get". 
- m.getName().toLowerCase().contains("get"+fieldNameLow) && + if( // The name should be "get<FieldName>" or "<fieldName>" (for scala). + (m.getName().toLowerCase().contains("get"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow)) && // no arguments for the getter m.getParameterTypes().length == 0 && // return type is same as field type (or the generic variant of it) @@ -980,12 +979,12 @@ private boolean isValidPojoField(Field f, Class clazz, ArrayList typeHi } hasGetter = true; } - // check for setters - if( m.getName().toLowerCase().contains("set"+fieldNameLow) && - m.getParameterTypes().length == 1 && // one parameter of the field's type - ( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&& - // return type is void. - m.getReturnType().equals(Void.TYPE) + // check for setters (_$eq for scala) + if((m.getName().toLowerCase().contains("set"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow+"_$eq")) && + m.getParameterTypes().length == 1 && // one parameter of the field's type + ( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&& + // return type is void.
+ m.getReturnType().equals(Void.TYPE) ) { if(hasSetter) { throw new IllegalStateException("Detected more than one getters"); @@ -993,7 +992,7 @@ private boolean isValidPojoField(Field f, Class clazz, ArrayList typeHi hasSetter = true; } } - if( hasGetter && hasSetter) { + if(hasGetter && hasSetter) { return true; } else { if(!hasGetter) { diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java index 9d7eed44b161a..51d8090ad2ef9 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java @@ -340,7 +340,7 @@ public PojoComparator duplicate() { public int extractKeys(Object record, Object[] target, int index) { int localIndex = index; for (int i = 0; i < comparators.length; i++) { - if(comparators[i] instanceof PojoComparator || comparators[i] instanceof TupleComparator) { + if(comparators[i] instanceof CompositeTypeComparator) { localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1; } else { // non-composite case (= atomic). We can assume this to have only one key. 
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java index a092a532752c3..b4b1c1917c8ad 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java @@ -1247,7 +1247,7 @@ public static class MyObject { public static class InType extends MyObject {} @SuppressWarnings({ "rawtypes", "unchecked" }) @Test - @Ignore +// @Ignore public void testParamertizedCustomObject() { RichMapFunction function = new RichMapFunction>() { private static final long serialVersionUID = 1L; diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoComparatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoComparatorTest.java index 5cb31ca3cfcda..61f6167a79b1b 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoComparatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoComparatorTest.java @@ -31,7 +31,7 @@ import org.junit.Ignore; -@Ignore // TODO +//@Ignore // TODO public class PojoComparatorTest extends ComparatorTestBase { TypeInformation type = TypeExtractor.getForClass(PojoContainingTuple.class); diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 28624bccbfc5a..895b96407895a 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -615,11 +615,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { * This only works on CaseClass DataSets. 
*/ def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = { - val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray) + // val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray) new GroupedDataSet[T]( this, - new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType,false)) + new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType)) } // public UnsortedGrouping groupBy(String... fields) { diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala index c9a3bbff038ce..3e9d4c63008f6 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala @@ -32,68 +32,42 @@ abstract class CaseClassTypeInfo[T <: Product]( val fieldNames: Seq[String]) extends TupleTypeInfoBase[T](clazz, fieldTypes: _*) { - override def createComparator(logicalKeyFields: Array[Int], - orders: Array[Boolean], offset: Int): TypeComparator[T] = { - // sanity checks - if (logicalKeyFields == null || orders == null - || logicalKeyFields.length != orders.length || logicalKeyFields.length > types.length) { - throw new IllegalArgumentException - } - - // No special handling of leading Key field as in JavaTupleComparator for now - - // --- general case --- - var maxKey: Int = -1 + def getFieldIndices(fields: Array[String]): Array[Int] = { + fields map { x => fieldNames.indexOf(x) } + } - for (key <- logicalKeyFields) { - maxKey = Math.max(key, maxKey) - } + /* + * Comparator construction + */ + var fieldComparators: Array[TypeComparator[_]] = null + var logicalKeyFields : Array[Int] = null + var comparatorHelperIndex = 0 - if (maxKey >= types.length) { - throw new IllegalArgumentException("The key position " + maxKey + " is out of range 
for " + - "Tuple" + types.length) - } + override protected def initializeNewComparator(localKeyCount: Int): Unit = { + fieldComparators = new Array(localKeyCount) + logicalKeyFields = new Array(localKeyCount) + comparatorHelperIndex = 0 + } - // create the comparators for the individual fields - val fieldComparators: Array[TypeComparator[_]] = new Array(logicalKeyFields.length) + override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = { + fieldComparators(comparatorHelperIndex) = comparator + logicalKeyFields(comparatorHelperIndex) = fieldId + comparatorHelperIndex += 1 + } - for (i <- 0 until logicalKeyFields.length) { - val keyPos = logicalKeyFields(i) - if (types(keyPos).isKeyType && types(keyPos).isInstanceOf[AtomicType[_]]) { - fieldComparators(i) = types(keyPos).asInstanceOf[AtomicType[_]].createComparator(orders(i)) - } else { - throw new IllegalArgumentException( - "The field at position " + i + " (" + types(keyPos) + ") is no atomic key type.") - } + override protected def getNewComparator: TypeComparator[T] = { + val finalLogicalKeyFields = logicalKeyFields.take(comparatorHelperIndex) + val finalComparators = fieldComparators.take(comparatorHelperIndex) + var maxKey: Int = 0 + for (key <- finalLogicalKeyFields) { + maxKey = Math.max(maxKey, key) } - - // create the serializers for the prefix up to highest key position val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](maxKey + 1) for (i <- 0 to maxKey) { fieldSerializers(i) = types(i).createSerializer } - - new CaseClassComparator[T](logicalKeyFields, fieldComparators, fieldSerializers) - } - - def getFieldIndices(fields: Array[String]): Array[Int] = { - fields map { x => fieldNames.indexOf(x) } - } - - override protected def initializeNewComparator(localKeyCount: Int): Unit = { - throw new UnsupportedOperationException("The Scala API is not using the composite " + - "type comparator creation") - } - - override protected def getNewComparator: 
TypeComparator[T] = { - throw new UnsupportedOperationException("The Scala API is not using the composite " + - "type comparator creation") - } - - override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = { - throw new UnsupportedOperationException("The Scala API is not using the composite " + - "type comparator creation") + new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers) } override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map { diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala new file mode 100644 index 0000000000000..d5ae6b6ba9852 --- /dev/null +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/ExamplesITCase.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.apache.flink.api.scala._ +import org.junit.runners.Parameterized.Parameters +import scala.collection.JavaConverters._ + +import scala.collection.mutable + +// TODO case class Tuple2[T1, T2](_1: T1, _2: T2) +// TODO case class Foo(a: Int, b: String) + +class Nested(var myLong: Long) { + def this() = { + this(0); + } +} +class Pojo(var myString: String, var myInt: Int, myLong: Long) { + var nested = new Nested(myLong) + + def this() = { + this("", 0, 0) + } + + override def toString() = "myString="+myString+" myInt="+myInt+" nested.myLong="+nested.myLong +} + +object ExampleProgs { + var NUM_PROGRAMS: Int = 3 + + def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = { + progId match { + case 1 => + /* + Test nested tuples with int offset + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) + + val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) + grouped.writeAsText(resultPath) + env.execute() + "((this,hello),3)\n((this,is),3)\n" + case 2 => + /* + Test nested tuples with a string field expression key ("f0.f0") + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) ) + + val grouped = ds.groupBy("f0.f0").reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)}) + grouped.writeAsText(resultPath) + env.execute() + "((this,is),6)\n" + case 3 => + /* + Test nested pojos + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromElements( new Pojo("one", 1, 1L),new Pojo("one", 1, 1L),new Pojo("two", 666, 2L) ) + + val grouped 
= ds.groupBy("nested.myLong").reduce { + (p1, p2) => + p1.myInt += p2.myInt + p1 + } + grouped.writeAsText(resultPath) + env.execute() + "myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n" + } + } +} + +@RunWith(classOf[Parameterized]) +class ExamplesITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = ExampleProgs.runProgram(curProgId, resultPath, isCollectionExecution) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object ExamplesITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to ExampleProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} \ No newline at end of file