Skip to content

Commit

Permalink
Add Pojo support to Scala API
Browse files Browse the repository at this point in the history
  • Loading branch information
rmetzger committed Oct 8, 2014
1 parent aca6fbc commit 6b493fb
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public FlatFieldDescriptor(int keyPosition, TypeInformation<?> type) {
public int getPosition() {
return keyPosition;
}

public TypeInformation<?> getType() {
return type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,12 @@ public TupleSerializer<T> createSerializer() {
/**
* Comparator creation
*/
private TypeSerializer<?>[] fieldSerializers;
private TypeComparator<?>[] fieldComparators;
private int[] logicalKeyFields;
private int comparatorHelperIndex = 0;

@Override
protected void initializeNewComparator(int localKeyCount) {
fieldSerializers = new TypeSerializer[localKeyCount];
fieldComparators = new TypeComparator<?>[localKeyCount];
logicalKeyFields = new int[localKeyCount];
comparatorHelperIndex = 0;
Expand All @@ -78,7 +76,6 @@ protected void initializeNewComparator(int localKeyCount) {
@Override
protected void addCompareField(int fieldId, TypeComparator<?> comparator) {
fieldComparators[comparatorHelperIndex] = comparator;
fieldSerializers[comparatorHelperIndex] = types[fieldId].createSerializer();
logicalKeyFields[comparatorHelperIndex] = fieldId;
comparatorHelperIndex++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ else if (t instanceof GenericArrayType) {
return ObjectArrayTypeInfo.getInfoFor(t, componentInfo);
}
// objects with generics are treated as raw type
else if (t instanceof ParameterizedType) {
else if (t instanceof ParameterizedType) { //TODO
return privateGetForClass((Class<OUT>) ((ParameterizedType) t).getRawType(), typeHierarchy);
}
// no tuple, no TypeVariable, no generic type
Expand Down Expand Up @@ -936,14 +936,13 @@ private <X> TypeInformation<X> privateGetForClass(Class<X> clazz, ArrayList<Type
return pojoType;
}


// return a generic type
return new GenericTypeInfo<X>(clazz);
}

/**
* Checks if the given field is a valid pojo field:
* - it is public
* - it is public
* OR
* - there are getter and setter methods for the field.
*
Expand All @@ -968,8 +967,8 @@ private boolean isValidPojoField(Field f, Class<?> clazz, ArrayList<Type> typeHi
for(Method m : clazz.getMethods()) {
// check for getter

if( // The name should be "get<FieldName>".
m.getName().toLowerCase().contains("get"+fieldNameLow) &&
if( // The name should be "get<FieldName>" or "<fieldName>" (for scala).
(m.getName().toLowerCase().contains("get"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow)) &&
// no arguments for the getter
m.getParameterTypes().length == 0 &&
// return type is same as field type (or the generic variant of it)
Expand All @@ -980,20 +979,20 @@ private boolean isValidPojoField(Field f, Class<?> clazz, ArrayList<Type> typeHi
}
hasGetter = true;
}
// check for setters
if( m.getName().toLowerCase().contains("set"+fieldNameLow) &&
m.getParameterTypes().length == 1 && // one parameter of the field's type
( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&&
// return type is void.
m.getReturnType().equals(Void.TYPE)
// check for setters (<FieldName>_$eq for scala)
if((m.getName().toLowerCase().contains("set"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow+"_$eq")) &&
m.getParameterTypes().length == 1 && // one parameter of the field's type
( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&&
// return type is void.
m.getReturnType().equals(Void.TYPE)
) {
if(hasSetter) {
throw new IllegalStateException("Detected more than one getters");
}
hasSetter = true;
}
}
if( hasGetter && hasSetter) {
if(hasGetter && hasSetter) {
return true;
} else {
if(!hasGetter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ public PojoComparator<T> duplicate() {
public int extractKeys(Object record, Object[] target, int index) {
int localIndex = index;
for (int i = 0; i < comparators.length; i++) {
if(comparators[i] instanceof PojoComparator || comparators[i] instanceof TupleComparator) {
if(comparators[i] instanceof CompositeTypeComparator) {
localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1;
} else {
// non-composite case (= atomic). We can assume this to have only one key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ public static class MyObject<T> {
public static class InType extends MyObject<String> {}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
@Ignore
// @Ignore
public void testParamertizedCustomObject() {
RichMapFunction<?, ?> function = new RichMapFunction<InType, MyObject<String>>() {
private static final long serialVersionUID = 1L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.junit.Ignore;


@Ignore // TODO
//@Ignore // TODO
public class PojoComparatorTest extends ComparatorTestBase<PojoContainingTuple> {
TypeInformation<PojoContainingTuple> type = TypeExtractor.getForClass(PojoContainingTuple.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* This only works on CaseClass DataSets.
*/
def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
// val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)

new GroupedDataSet[T](
this,
new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType,false))
new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType))
}

// public UnsortedGrouping<T> groupBy(String... fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,68 +32,42 @@ abstract class CaseClassTypeInfo[T <: Product](
val fieldNames: Seq[String])
extends TupleTypeInfoBase[T](clazz, fieldTypes: _*) {

override def createComparator(logicalKeyFields: Array[Int],
orders: Array[Boolean], offset: Int): TypeComparator[T] = {
// sanity checks
if (logicalKeyFields == null || orders == null
|| logicalKeyFields.length != orders.length || logicalKeyFields.length > types.length) {
throw new IllegalArgumentException
}

// No special handling of leading Key field as in JavaTupleComparator for now

// --- general case ---
var maxKey: Int = -1
def getFieldIndices(fields: Array[String]): Array[Int] = {
fields map { x => fieldNames.indexOf(x) }
}

for (key <- logicalKeyFields) {
maxKey = Math.max(key, maxKey)
}
/*
* Comparator construction
*/
var fieldComparators: Array[TypeComparator[_]] = null
var logicalKeyFields : Array[Int] = null
var comparatorHelperIndex = 0

if (maxKey >= types.length) {
throw new IllegalArgumentException("The key position " + maxKey + " is out of range for " +
"Tuple" + types.length)
}
override protected def initializeNewComparator(localKeyCount: Int): Unit = {
fieldComparators = new Array(localKeyCount)
logicalKeyFields = new Array(localKeyCount)
comparatorHelperIndex = 0
}

// create the comparators for the individual fields
val fieldComparators: Array[TypeComparator[_]] = new Array(logicalKeyFields.length)
override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = {
fieldComparators(comparatorHelperIndex) = comparator
logicalKeyFields(comparatorHelperIndex) = fieldId
comparatorHelperIndex += 1
}

for (i <- 0 until logicalKeyFields.length) {
val keyPos = logicalKeyFields(i)
if (types(keyPos).isKeyType && types(keyPos).isInstanceOf[AtomicType[_]]) {
fieldComparators(i) = types(keyPos).asInstanceOf[AtomicType[_]].createComparator(orders(i))
} else {
throw new IllegalArgumentException(
"The field at position " + i + " (" + types(keyPos) + ") is no atomic key type.")
}
override protected def getNewComparator: TypeComparator[T] = {
val finalLogicalKeyFields = logicalKeyFields.take(comparatorHelperIndex)
val finalComparators = fieldComparators.take(comparatorHelperIndex)
var maxKey: Int = 0
for (key <- finalLogicalKeyFields) {
maxKey = Math.max(maxKey, key)
}

// create the serializers for the prefix up to highest key position
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](maxKey + 1)

for (i <- 0 to maxKey) {
fieldSerializers(i) = types(i).createSerializer
}

new CaseClassComparator[T](logicalKeyFields, fieldComparators, fieldSerializers)
}

def getFieldIndices(fields: Array[String]): Array[Int] = {
fields map { x => fieldNames.indexOf(x) }
}

override protected def initializeNewComparator(localKeyCount: Int): Unit = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
}

override protected def getNewComparator: TypeComparator[T] = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
}

override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers)
}

override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators

import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.apache.flink.api.scala._
import org.junit.runners.Parameterized.Parameters
import scala.collection.JavaConverters._

import scala.collection.mutable

// TODO case class Tuple2[T1, T2](_1: T1, _2: T2)
// TODO case class Foo(a: Int, b: String)

class Nested(var myLong: Long) {
def this() = {
this(0);
}
}
class Pojo(var myString: String, var myInt: Int, myLong: Long) {
var nested = new Nested(myLong)

def this() = {
this("", 0, 0)
}

override def toString() = "myString="+myString+" myInt="+myInt+" nested.myLong="+nested.myLong
}

object ExampleProgs {
var NUM_PROGRAMS: Int = 3

def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = {
progId match {
case 1 =>
/*
Test nested tuples with int offset
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) )

val grouped = ds.groupBy(0).reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)})
grouped.writeAsText(resultPath)
env.execute()
"((this,hello),3)\n((this,is),3)\n"
case 2 =>
/*
Test nested tuples with int offset
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( (("this","is"), 1), (("this", "is"),2), (("this","hello"),3) )

val grouped = ds.groupBy("f0.f0").reduce( { (e1, e2) => ((e1._1._1,e1._1._2), e1._2+e2._2)})
grouped.writeAsText(resultPath)
env.execute()
"((this,is),6)\n"
case 3 =>
/*
Test nested pojos
*/
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements( new Pojo("one", 1, 1L),new Pojo("one", 1, 1L),new Pojo("two", 666, 2L) )

val grouped = ds.groupBy("nested.myLong").reduce {
(p1, p2) =>
p1.myInt += p2.myInt
p1
}
grouped.writeAsText(resultPath)
env.execute()
"myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n"
}
}
}

@RunWith(classOf[Parameterized])
class ExamplesITCase(config: Configuration) extends JavaProgramTestBase(config) {

private var curProgId: Int = config.getInteger("ProgramId", -1)
private var resultPath: String = null
private var expectedResult: String = null

protected override def preSubmit(): Unit = {
resultPath = getTempDirPath("result")
}

protected def testProgram(): Unit = {
expectedResult = ExampleProgs.runProgram(curProgId, resultPath, isCollectionExecution)
}

protected override def postSubmit(): Unit = {
compareResultsByLinesInMemory(expectedResult, resultPath)
}
}

object ExamplesITCase {
@Parameters
def getConfigurations: java.util.Collection[Array[AnyRef]] = {
val configs = mutable.MutableList[Array[AnyRef]]()
for (i <- 1 to ExampleProgs.NUM_PROGRAMS) {
val config = new Configuration()
config.setInteger("ProgramId", i)
configs += Array(config)
}

configs.asJavaCollection
}
}

0 comments on commit 6b493fb

Please sign in to comment.