diff --git a/flink-addons/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroInputFormatTypeExtractionTest.java b/flink-addons/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroInputFormatTypeExtractionTest.java
index fe3b6c86f2551..23fbab3a671f7 100644
--- a/flink-addons/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroInputFormatTypeExtractionTest.java
+++ b/flink-addons/flink-avro/src/test/java/org/apache/flink/api/java/io/AvroInputFormatTypeExtractionTest.java
@@ -23,7 +23,7 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.core.fs.Path;
 import org.junit.Assert;
@@ -43,8 +43,8 @@ public void testTypeExtraction() {
 
             TypeInformation typeInfoDataSet = input.getType();
 
-            Assert.assertTrue(typeInfoDirect instanceof GenericTypeInfo);
-            Assert.assertTrue(typeInfoDataSet instanceof GenericTypeInfo);
+            Assert.assertTrue(typeInfoDirect instanceof PojoTypeInfo);
+            Assert.assertTrue(typeInfoDataSet instanceof PojoTypeInfo);
 
             Assert.assertEquals(MyAvroType.class, typeInfoDirect.getTypeClass());
             Assert.assertEquals(MyAvroType.class, typeInfoDataSet.getTypeClass());
diff --git a/flink-addons/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java b/flink-addons/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java
index 5d24fad2a38a0..7d004f9737e5c 100644
--- a/flink-addons/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java
+++ b/flink-addons/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java
@@ -26,7 +26,6 @@
 
 import org.junit.Assert;
 
-import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.tuple.Tuple5;
 import org.junit.After;
 import org.junit.AfterClass;
diff --git a/flink-compiler/src/main/java/org/apache/flink/compiler/postpass/JavaApiPostPass.java b/flink-compiler/src/main/java/org/apache/flink/compiler/postpass/JavaApiPostPass.java
index 4756bc9f9c088..acce51fdf704d 100644
--- a/flink-compiler/src/main/java/org/apache/flink/compiler/postpass/JavaApiPostPass.java
+++ b/flink-compiler/src/main/java/org/apache/flink/compiler/postpass/JavaApiPostPass.java
@@ -31,8 +31,8 @@
 import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
 import org.apache.flink.api.common.operators.util.FieldList;
 import org.apache.flink.api.common.typeinfo.AtomicType;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
 import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
@@ -40,6 +40,7 @@
 import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
 import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
 import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
 import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory;
 import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory;
 import org.apache.flink.api.java.typeutils.runtime.RuntimeStatefulSerializerFactory;
@@ -292,7 +293,7 @@ private static TypeComparatorFactory createComparator(TypeInformation typeInfo, FieldList keys, boolean[] sortOrder) {
         TypeComparator comparator;
         if (typeInfo instanceof CompositeType) {
-            comparator = ((CompositeType) typeInfo).createComparator(keys.toArray(), sortOrder);
+            comparator = ((CompositeType) typeInfo).createComparator(keys.toArray(), sortOrder, 0);
         }
         else if (typeInfo instanceof AtomicType) {
             // handle grouping of atomic types
@@ -306,8 +307,8 @@ else if (typeInfo instanceof AtomicType) {
     }
 
     private static TypePairComparatorFactory createPairComparator(TypeInformation typeInfo1, TypeInformation typeInfo2) {
-        if (!(typeInfo1.isTupleType() && typeInfo2.isTupleType())) {
-            throw new RuntimeException("The runtime currently supports only keyed binary operations on tuples.");
+        if (!((typeInfo1.isTupleType() || typeInfo1 instanceof PojoTypeInfo) && (typeInfo2.isTupleType() || typeInfo2 instanceof PojoTypeInfo))) {
+            throw new RuntimeException("The runtime currently supports only keyed binary operations (such as joins) on tuples and POJO types.");
         }
 
//        @SuppressWarnings("unchecked")
diff --git a/flink-compiler/src/test/java/org/apache/flink/compiler/WorksetIterationsJavaApiCompilerTest.java b/flink-compiler/src/test/java/org/apache/flink/compiler/WorksetIterationsJavaApiCompilerTest.java
index d07743e993998..ef756d01f05d2 100644
--- a/flink-compiler/src/test/java/org/apache/flink/compiler/WorksetIterationsJavaApiCompilerTest.java
+++ b/flink-compiler/src/test/java/org/apache/flink/compiler/WorksetIterationsJavaApiCompilerTest.java
@@ -78,7 +78,7 @@ public void testJavaApiWithDeferredSoltionSetUpdateWithMapper() {
         // verify joinWithSolutionSet
         assertEquals(ShipStrategyType.PARTITION_HASH, joinWithSolutionSetNode.getInput1().getShipStrategy());
         assertEquals(ShipStrategyType.FORWARD, joinWithSolutionSetNode.getInput2().getShipStrategy());
-        assertEquals(new FieldList(0, 1), joinWithSolutionSetNode.getKeysForInput1());
+        assertEquals(new FieldList(1, 0), joinWithSolutionSetNode.getKeysForInput1());
 
         // verify reducer
@@ -125,7 +125,7 @@ public void testRecordApiWithDeferredSoltionSetUpdateWithNonPreservingJoin() {
         // verify joinWithSolutionSet
         assertEquals(ShipStrategyType.PARTITION_HASH, joinWithSolutionSetNode.getInput1().getShipStrategy());
         assertEquals(ShipStrategyType.FORWARD, joinWithSolutionSetNode.getInput2().getShipStrategy());
-        assertEquals(new FieldList(0, 1), joinWithSolutionSetNode.getKeysForInput1());
+        assertEquals(new FieldList(1, 0), joinWithSolutionSetNode.getKeysForInput1());
 
         // verify reducer
         assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());
@@ -170,7 +170,7 @@ public void testRecordApiWithDirectSoltionSetUpdate() {
         // verify joinWithSolutionSet
         assertEquals(ShipStrategyType.PARTITION_HASH, joinWithSolutionSetNode.getInput1().getShipStrategy());
         assertEquals(ShipStrategyType.FORWARD, joinWithSolutionSetNode.getInput2().getShipStrategy());
-        assertEquals(new FieldList(0, 1), joinWithSolutionSetNode.getKeysForInput1());
+        assertEquals(new FieldList(1, 0), joinWithSolutionSetNode.getKeysForInput1());
 
         // verify reducer
         assertEquals(ShipStrategyType.FORWARD, worksetReducer.getInput().getShipStrategy());
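Note on the `createComparator` calls above: the new third argument is the logical field offset at which the type's first flattened field sits. A minimal usage sketch, assuming a hypothetical composite `typeInfo` (names are illustrative, not part of the patch):

    // Top-level callers pass offset 0, as in the hunks above; the recursion in
    // the new CompositeType class (further down) passes the logical position of
    // a nested field when it descends into a nested composite type.
    TypeComparator comparator = ((CompositeType) typeInfo).createComparator(
            new int[] { 1, 0 },           // logical key fields
            new boolean[] { true, true }, // one sort order per key field
            0);                           // logical field offset of this type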
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/InvalidProgramException.java b/flink-core/src/main/java/org/apache/flink/api/common/InvalidProgramException.java
index 73162ee0d301f..675398259de36 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/InvalidProgramException.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/InvalidProgramException.java
@@ -43,4 +43,8 @@ public InvalidProgramException() {
     public InvalidProgramException(String message) {
         super(message);
     }
+
+    public InvalidProgramException(String message, Throwable e) {
+        super(message, e);
+    }
 }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/InvalidTypesException.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/InvalidTypesException.java
index 05015054ae310..611765e542411 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/InvalidTypesException.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/InvalidTypesException.java
@@ -43,4 +43,8 @@ public InvalidTypesException() {
     public InvalidTypesException(String message) {
         super(message);
     }
+
+    public InvalidTypesException(String message, Throwable e) {
+        super(message, e);
+    }
 }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
index ac8e1f3eb223a..ea68554d47b74 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
@@ -45,8 +45,8 @@
 import org.apache.flink.api.common.operators.base.DeltaIterationBase.SolutionSetPlaceHolder;
 import org.apache.flink.api.common.operators.base.DeltaIterationBase.WorksetPlaceHolder;
 import org.apache.flink.api.common.operators.util.TypeComparable;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.types.Value;
 import org.apache.flink.util.Visitor;
@@ -349,7 +349,7 @@ private List executeDeltaIteration(DeltaIterationBase iteration) throws Exception {
         int[] keyColumns = iteration.getSolutionSetKeyFields();
         boolean[] inputOrderings = new boolean[keyColumns.length];
-        TypeComparator inputComparator = ((CompositeType) solutionType).createComparator(keyColumns, inputOrderings);
+        TypeComparator inputComparator = ((CompositeType) solutionType).createComparator(keyColumns, inputOrderings, 0);
 
         Map<TypeComparable<T>, T> solutionMap = new HashMap<TypeComparable<T>, T>(solutionInputData.size());
         // fill the solution from the initial input
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java
index 6aa3da0f02b0b..bca909f1e6f30 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java
@@ -31,8 +31,8 @@
 import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeWrapper;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.GenericPairComparator;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.api.common.typeutils.TypePairComparator;
@@ -227,13 +227,12 @@ protected List executeOnCollections(List input1, List input2, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
         return result;
     }
 
-    @SuppressWarnings("unchecked")
     private TypeComparator getTypeComparator(TypeInformation inputType, int[] inputKeys, boolean[] inputSortDirections) {
         if (!(inputType instanceof CompositeType)) {
             throw new InvalidProgramException("Input types of coGroup must be composite types.");
         }
 
-        return ((CompositeType) inputType).createComparator(inputKeys, inputSortDirections);
+        return ((CompositeType) inputType).createComparator(inputKeys, inputSortDirections, 0);
     }
 
     private static class CoGroupSortListIterator {
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
index 8a3bf6597b54c..f500717e68586 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java
@@ -33,11 +33,13 @@
 import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeWrapper;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 
+import com.google.common.base.Preconditions;
+
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
@@ -154,15 +156,18 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
         }
 
         if (inputType instanceof CompositeType) {
-            @SuppressWarnings("unchecked")
-            final TypeComparator sortComparator = ((CompositeType) inputType).createComparator(sortColumns, sortOrderings);
-
-            Collections.sort(inputData, new Comparator() {
-                @Override
-                public int compare(IN o1, IN o2) {
-                    return sortComparator.compare(o1, o2);
-                }
-            });
+            if(sortColumns.length == 0) { // => all reduce. No comparator
+                Preconditions.checkArgument(sortOrderings.length == 0);
+            } else {
+                final TypeComparator sortComparator = ((CompositeType) inputType).createComparator(sortColumns, sortOrderings, 0);
+
+                Collections.sort(inputData, new Comparator() {
+                    @Override
+                    public int compare(IN o1, IN o2) {
+                        return sortComparator.compare(o1, o2);
+                    }
+                });
+            }
         }
 
         FunctionUtils.setFunctionRuntimeContext(function, ctx);
@@ -188,7 +193,7 @@ public int compare(IN o1, IN o2) {
         } else {
             final TypeSerializer inputSerializer = inputType.createSerializer();
             boolean[] keyOrderings = new boolean[keyColumns.length];
-            final TypeComparator comparator = ((CompositeType) inputType).createComparator(keyColumns, keyOrderings);
+            final TypeComparator comparator = ((CompositeType) inputType).createComparator(keyColumns, keyOrderings, 0);
 
             ListKeyGroupedIterator keyedIterator = new ListKeyGroupedIterator(inputData, inputSerializer, comparator, mutableObjectSafeMode);
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
index 3d5cf725075e8..7bfe39fe812d6 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java
@@ -29,8 +29,8 @@
 import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeWrapper;
 import org.apache.flink.api.common.typeinfo.AtomicType;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.GenericPairComparator;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.api.common.typeutils.TypePairComparator;
@@ -86,7 +86,7 @@ else if(leftInformation instanceof CompositeType){
             boolean[] orders = new boolean[keyPositions.length];
             Arrays.fill(orders, true);
 
-            leftComparator = ((CompositeType) leftInformation).createComparator(keyPositions, orders);
+            leftComparator = ((CompositeType) leftInformation).createComparator(keyPositions, orders, 0);
         }else{
             throw new RuntimeException("Type information for left input of type " + leftInformation.getClass()
                     .getCanonicalName() + " is not supported. Could not generate a comparator.");
@@ -99,7 +99,7 @@ else if(leftInformation instanceof CompositeType){
             boolean[] orders = new boolean[keyPositions.length];
             Arrays.fill(orders, true);
 
-            rightComparator = ((CompositeType) rightInformation).createComparator(keyPositions, orders);
+            rightComparator = ((CompositeType) rightInformation).createComparator(keyPositions, orders, 0);
         }else{
             throw new RuntimeException("Type information for right input of type " + rightInformation.getClass()
                     .getCanonicalName() + " is not supported. Could not generate a comparator.");
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
index dc0163777ad8b..30ff1768a1f0b 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/ReduceOperatorBase.java
@@ -28,8 +28,8 @@
 import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
 import org.apache.flink.api.common.operators.util.UserCodeWrapper;
-import org.apache.flink.api.common.typeinfo.CompositeType;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.common.typeutils.TypeComparator;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -125,7 +125,6 @@ public ReduceOperatorBase(Class udf, UnaryOperatorInformation ...)
     protected List executeOnCollections(List inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
         // make sure we can handle empty inputs
@@ -151,7 +150,7 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, boolean mutableObjectSafeMode) throws Exception {
         if (inputColumns.length > 0) {
             boolean[] inputOrderings = new boolean[inputColumns.length];
-            TypeComparator inputComparator = ((CompositeType) inputType).createComparator(inputColumns, inputOrderings);
+            TypeComparator inputComparator = ((CompositeType) inputType).createComparator(inputColumns, inputOrderings, 0);
 
             Map<TypeComparable<T>, T> aggregateMap = new HashMap<TypeComparable<T>, T>(inputData.size() / 10);
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicArrayTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicArrayTypeInfo.java
index bbf4c2a4129af..646a549eb4ee7 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicArrayTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicArrayTypeInfo.java
@@ -68,6 +68,11 @@ public boolean isTupleType() {
     public int getArity() {
         return 1;
     }
+
+    @Override
+    public int getTotalFields() {
+        return 1;
+    }
 
     @Override
     public Class getTypeClass() {
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicTypeInfo.java
index cb2247d77a7e5..a152b4a45c133 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/BasicTypeInfo.java
@@ -91,6 +91,11 @@ public int getArity() {
         return 1;
     }
 
+    @Override
+    public int getTotalFields() {
+        return 1;
+    }
+
     @Override
     public Class getTypeClass() {
         return this.clazz;
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/NothingTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/NothingTypeInfo.java
index 5cd5db39be01c..dba0e6f8c3184 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/NothingTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/NothingTypeInfo.java
@@ -38,6 +38,11 @@ public int getArity() {
         return 0;
     }
 
+    @Override
+    public int getTotalFields() {
+        return 0;
+    }
+
     @Override
     public Class getTypeClass() {
         return Nothing.class;
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/PrimitiveArrayTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/PrimitiveArrayTypeInfo.java
index 19ef6636a2220..5163801f80f0f 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/PrimitiveArrayTypeInfo.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/PrimitiveArrayTypeInfo.java
@@ -69,6 +69,11 @@ public boolean isTupleType() {
     public int getArity() {
         return 1;
     }
+
+    @Override
+    public int getTotalFields() {
+        return 1;
+    }
 
     @Override
     public Class getTypeClass() {
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/TypeInformation.java b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/TypeInformation.java
index 6eb08b8d78aa2..0f86486882974 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/TypeInformation.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/TypeInformation.java
@@ -33,4 +33,10 @@ public abstract class TypeInformation {
     public abstract boolean isKeyType();
 
     public abstract TypeSerializer createSerializer();
+
+    /**
+     * @return The number of fields in this type, including its sub-fields (for composite types)
+     */
+    public abstract int getTotalFields();
+
 }
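The distinction between `getArity()` and the new `getTotalFields()` only shows up for nested types. A small illustration with an assumed nested tuple type (not part of the patch):

    // Tuple2<Integer, Tuple2<String, Double>>:
    //   getArity()       == 2   -- the immediate fields f0 and f1
    //   getTotalFields() == 3   -- the flattened atomic fields: Integer, String, Double
    // For the atomic and array types above, both methods return 1;
    // NothingTypeInfo returns 0 for both.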
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java
new file mode 100644
index 0000000000000..60e9eab5807e0
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeType.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.typeutils;
+
+import java.util.List;
+
+import org.apache.flink.api.common.typeinfo.AtomicType;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
+
+/**
+ * Type information for Tuple and POJO types.
+ *
+ * The class also takes care of serialization and comparators for Tuples.
+ * See the {@link Keys} class for how key fields are specified.
+ */
+public abstract class CompositeType extends TypeInformation {
+
+    protected final Class typeClass;
+
+    public CompositeType(Class typeClass) {
+        this.typeClass = typeClass;
+    }
+
+    /**
+     * Computes the flat field descriptors for the given field expression, with the
+     * key positions offset by the given offset, and adds them to the result list.
+     */
+    public abstract void getKey(String fieldExpression, int offset, List result);
+
+    public abstract TypeInformation getTypeAt(int pos);
+
+    /**
+     * Initializes the internal state inside a composite type to create a new comparator
+     * (such as the lists / arrays for the fields and field comparators).
+     * @param localKeyCount The number of key fields that are local to this type.
+     */
+    protected abstract void initializeNewComparator(int localKeyCount);
+
+    /**
+     * Adds a field of this type for comparison.
+     */
+    protected abstract void addCompareField(int fieldId, TypeComparator comparator);
+
+    /**
+     * Gets the comparator that was initialized and assembled by the methods above.
+     */
+    protected abstract TypeComparator getNewComparator();
+
+
+    /**
+     * Generic implementation of the comparator creation. The composite type supplies
+     * the infrastructure to create the actual comparators.
+     * @return The comparator for the given logical key fields.
+     */
+    public TypeComparator createComparator(int[] logicalKeyFields, boolean[] orders, int logicalFieldOffset) {
+        initializeNewComparator(logicalKeyFields.length);
+
+        for(int logicalKeyFieldIndex = 0; logicalKeyFieldIndex < logicalKeyFields.length; logicalKeyFieldIndex++) {
+            int logicalKeyField = logicalKeyFields[logicalKeyFieldIndex];
+            int logicalField = logicalFieldOffset; // this is the global/logical field number
+            for(int localFieldId = 0; localFieldId < this.getArity(); localFieldId++) {
+                TypeInformation localFieldType = this.getTypeAt(localFieldId);
+
+                if(localFieldType instanceof AtomicType && logicalField == logicalKeyField) {
+                    // we found an atomic key --> create comparator
+                    addCompareField(localFieldId, ((AtomicType) localFieldType).createComparator(orders[logicalKeyFieldIndex]) );
+                } else if(localFieldType instanceof CompositeType && // must be a composite type
+                        ( logicalField <= logicalKeyField // check if keyField can be at or behind the current logicalField
+                        && logicalKeyField <= logicalField + (localFieldType.getTotalFields() - 1) ) // check if logical field + lookahead could contain our key
+                    ) {
+                    // we found a compositeType that is containing the logicalKeyField we are looking for --> create comparator
+                    addCompareField(localFieldId, ((CompositeType) localFieldType).createComparator(new int[] {logicalKeyField}, orders, logicalField));
+                }
+
+                // maintain logicalField
+                if(localFieldType instanceof CompositeType) {
+                    // we need to subtract 1 because we are not accounting for the local field (not accessible for the user)
+                    logicalField += localFieldType.getTotalFields() - 1;
+                }
+                logicalField++;
+            }
+        }
+        return getNewComparator();
+    }
+
+
+
+    public static class FlatFieldDescriptor {
+        private int keyPosition;
+        private TypeInformation type;
+
+        public FlatFieldDescriptor(int keyPosition, TypeInformation type) {
+            if( !(type instanceof AtomicType)) {
+                throw new IllegalArgumentException("A flattened field can only be an atomic type");
+            }
+            this.keyPosition = keyPosition;
+            this.type = type;
+        }
+
+
+        public int getPosition() {
+            return keyPosition;
+        }
+
+        public TypeInformation getType() {
+            return type;
+        }
+
+        @Override
+        public String toString() {
+            return "FlatFieldDescriptor [position="+keyPosition+" typeInfo="+type+"]";
+        }
+    }
+}
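To make the logical field numbering in `createComparator` concrete, here is a walk-through with a hypothetical nested POJO (the types are assumptions for illustration):

    // Outer { int a; Inner b; int c; }    Inner { String x; String y; }
    // Flattened logical fields: a=0, b.x=1, b.y=2, c=3
    //
    // createComparator(new int[]{2}, orders, 0) scans Outer's local fields:
    //   local 0: a, atomic, logicalField=0         -> not the requested key
    //   local 1: b, composite, spans logical 1..2  -> contains key 2, so recurse:
    //            innerType.createComparator(new int[]{2}, orders, 1)
    //   local 2: c, atomic, logicalField=3         -> not the requested key
    // The recursion bottoms out at the atomic field b.y (logical field 2) and
    // registers its comparator via addCompareField(...).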
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/CompositeType.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeTypeComparator.java
similarity index 58%
rename from flink-core/src/main/java/org/apache/flink/api/common/typeinfo/CompositeType.java
rename to flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeTypeComparator.java
index 075b528a8ce9e..8323dc3d4e67f 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeinfo/CompositeType.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/CompositeTypeComparator.java
@@ -16,15 +16,21 @@
  * limitations under the License.
  */
 
-package org.apache.flink.api.common.typeinfo;
+package org.apache.flink.api.common.typeutils;
 
-import org.apache.flink.api.common.typeutils.TypeComparator;
+import java.util.LinkedList;
+import java.util.List;
 
+public abstract class CompositeTypeComparator extends TypeComparator {
+
+    private static final long serialVersionUID = 1L;
 
-/**
- * 
- */
-public interface CompositeType {
-
-    TypeComparator createComparator(int[] logicalKeyFields, boolean[] orders);
+    @Override
+    public TypeComparator[] getFlatComparators() {
+        List flatComparators = new LinkedList();
+        this.getFlatComparator(flatComparators);
+        return flatComparators.toArray(new TypeComparator[flatComparators.size()]);
+    }
+
+    public abstract void getFlatComparator(List flatComparators);
 }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java
index b64a3a3bc4590..09d4aeef499e7 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java
@@ -36,13 +36,15 @@ public class GenericPairComparator extends TypePairComparator
     private final TypeComparator[] comparators2;
 
     private final Object[] referenceKeyFields;
+
+    private final Object[] candidateKeyFields;
 
     @SuppressWarnings("unchecked")
     public GenericPairComparator(TypeComparator comparator1, TypeComparator comparator2) {
         this.comparator1 = comparator1;
         this.comparator2 = comparator2;
-        this.comparators1 = comparator1.getComparators();
-        this.comparators2 = comparator2.getComparators();
+        this.comparators1 = comparator1.getFlatComparators();
+        this.comparators2 = comparator2.getFlatComparators();
 
         if(comparators1.length != comparators2.length) {
             throw new IllegalArgumentException("Number of key fields and comparators differ.");
@@ -56,19 +58,19 @@ public GenericPairComparator(TypeComparator comparator1, TypeComparator
         }
 
         this.referenceKeyFields = new Object[numKeys];
+        this.candidateKeyFields = new Object[numKeys];
     }
 
     @Override
     public void setReference(T1 reference) {
-        Object[] keys = comparator1.extractKeys(reference);
-        System.arraycopy(keys, 0, referenceKeyFields, 0, keys.length);
+        comparator1.extractKeys(reference, referenceKeyFields, 0);
     }
 
     @Override
     public boolean equalToReference(T2 candidate) {
-        Object[] keys = comparator2.extractKeys(candidate);
+        comparator2.extractKeys(candidate, candidateKeyFields, 0);
         for (int i = 0; i < this.comparators1.length; i++) {
-            if (this.comparators1[i].compare(referenceKeyFields[i], keys[i]) != 0) {
+            if (this.comparators1[i].compare(referenceKeyFields[i], candidateKeyFields[i]) != 0) {
                 return false;
             }
         }
@@ -77,11 +79,11 @@ public boolean equalToReference(T2 candidate) {
 
     @Override
     public int compareToReference(T2 candidate) {
-        Object[] keys = comparator2.extractKeys(candidate);
+        comparator2.extractKeys(candidate, candidateKeyFields, 0);
         for (int i = 0; i < this.comparators1.length; i++) {
             // We reverse ordering here because our "compareToReference" does work in a mirrored
             // way compared to Comparable.compareTo
-            int res = this.comparators1[i].compare(keys[i], referenceKeyFields[i]);
+            int res = this.comparators1[i].compare(candidateKeyFields[i], referenceKeyFields[i]);
             if(res != 0) {
                 return res;
             }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
index 27487ed16ed37..f98f6e0b09288 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeComparator.java
@@ -289,15 +289,16 @@ public boolean supportsCompareAgainstReference() {
     /**
      * Extracts the key fields from a record. This is for use by the PairComparator to provide
      * interoperability between different record types.
+     * @return The number of keys added to target.
      */
-    public abstract Object[] extractKeys(T record);
+    public abstract int extractKeys(Object record, Object[] target, int index);
 
     /**
      * Get the field comparators. This is used together with {@link #extractKeys(Object)} to provide
      * interoperability between different record types.
      */
     @SuppressWarnings("rawtypes")
-    public abstract TypeComparator[] getComparators();
+    public abstract TypeComparator[] getFlatComparators();
 
     // --------------------------------------------------------------------------------------------
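The new `extractKeys` contract writes the key fields into a caller-provided array slice instead of allocating a fresh array per call. A sketch of an implementation for a single-field comparator (the type and field names are assumptions, not from the patch):

    @Override
    public int extractKeys(Object record, Object[] target, int index) {
        target[index] = ((MyPojo) record).key; // place the key in the given slot
        return 1;                              // number of slots that were filled
    }

A composite comparator would call its field comparators in order, advancing `index` by each returned count, so that nested keys land in consecutive slots of `target`.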
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypePairComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypePairComparator.java
index a11da737ea9e7..fe278fd577508 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypePairComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypePairComparator.java
@@ -15,7 +15,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.flink.api.common.typeutils;
 
 /**
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/BasicTypeComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/BasicTypeComparator.java
index 4e44326902588..d1d8e0f4e93fd 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/BasicTypeComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/BasicTypeComparator.java
@@ -33,10 +33,6 @@ public abstract class BasicTypeComparator<T extends Comparable<T>> extends TypeComparator<T>
 
     protected final boolean ascendingComparison;
 
-    // This is used in extractKeys, so that we don't create a new array for every call.
-    @SuppressWarnings("rawtypes")
-    private final Comparable[] extractedKey = new Comparable[1];
-
     // For use by getComparators
     @SuppressWarnings("rawtypes")
     private final TypeComparator[] comparators = new TypeComparator[] {this};
@@ -89,14 +85,14 @@ public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException {
     }
 
     @Override
-    public Object[] extractKeys(T record) {
-        extractedKey[0] = record;
-        return extractedKey;
+    public int extractKeys(Object record, Object[] target, int index) {
+        target[index] = record;
+        return 1;
     }
 
-    @Override
     @SuppressWarnings("rawtypes")
-    public TypeComparator[] getComparators() {
+    @Override
+    public TypeComparator[] getFlatComparators() {
         return comparators;
     }
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/record/RecordComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/record/RecordComparator.java
index 4e1f64ae8c448..605d6a1356570 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/record/RecordComparator.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/record/RecordComparator.java
@@ -391,13 +391,14 @@ public final Key[] getKeysAsCopy(Record record) {
     }
 
     @Override
-    public Object[] extractKeys(Record record) {
+    public int extractKeys(Object record, Object[] target, int index) {
         throw new UnsupportedOperationException("Record does not support extactKeys and " +
                 "getComparators. This cannot be used with the GenericPairComparator.");
     }
+
     @Override
-    public TypeComparator[] getComparators() {
+    public TypeComparator[] getFlatComparators() {
         throw new UnsupportedOperationException("Record does not support extactKeys and " +
                 "getComparators. This cannot be used with the GenericPairComparator.");
     }
diff --git a/flink-core/src/main/java/org/apache/flink/types/KeyFieldOutOfBoundsException.java b/flink-core/src/main/java/org/apache/flink/types/KeyFieldOutOfBoundsException.java
index 1b9f1089ce1c8..3d04efafe980d 100644
--- a/flink-core/src/main/java/org/apache/flink/types/KeyFieldOutOfBoundsException.java
+++ b/flink-core/src/main/java/org/apache/flink/types/KeyFieldOutOfBoundsException.java
@@ -62,6 +62,11 @@ public KeyFieldOutOfBoundsException(int fieldNumber) {
         this.fieldNumber = fieldNumber;
     }
 
+    public KeyFieldOutOfBoundsException(int fieldNumber, Throwable parent) {
+        super("Field " + fieldNumber + " is accessed for a key, but out of bounds in the record.", parent);
+        this.fieldNumber = fieldNumber;
+    }
+
     /**
      * Gets the field number that was attempted to access. If the number is not set, this method returns
      * {@code -1}.
diff --git a/flink-examples/flink-java-examples/pom.xml b/flink-examples/flink-java-examples/pom.xml
index 836b0bc313d79..3ebd85e7ba26a 100644
--- a/flink-examples/flink-java-examples/pom.xml
+++ b/flink-examples/flink-java-examples/pom.xml
@@ -296,6 +296,30 @@ under the License.
 				</execution>
+
+				<execution>
+					<id>WordCountPOJO</id>
+					<phase>package</phase>
+					<goals>
+						<goal>jar</goal>
+					</goals>
+					<configuration>
+						<classifier>WordCountPOJO</classifier>
+						<archive>
+							<manifestEntries>
+								<program-class>org.apache.flink.examples.java.wordcount.PojoExample</program-class>
+							</manifestEntries>
+						</archive>
+						<includes>
+							<include>**/java/wordcount/PojoExample.class</include>
+							<include>**/java/wordcount/PojoExample$*.class</include>
+							<include>**/java/wordcount/util/WordCountData.class</include>
+						</includes>
+					</configuration>
+				</execution>
+
diff --git a/flink-examples/flink-java-examples/src/main/java/org/apache/flink/example/java/environments/CollectionExecutionExample.java b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/environments/CollectionExecutionExample.java
similarity index 96%
rename from flink-examples/flink-java-examples/src/main/java/org/apache/flink/example/java/environments/CollectionExecutionExample.java
rename to flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/environments/CollectionExecutionExample.java
index bc260d59f7a17..1ce3e7a6f5e8f 100644
--- a/flink-examples/flink-java-examples/src/main/java/org/apache/flink/example/java/environments/CollectionExecutionExample.java
+++ b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/environments/CollectionExecutionExample.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.flink.example.java.environments;
+package org.apache.flink.examples.java.environments;
 
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.CollectionEnvironment;
diff --git a/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/example/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java
similarity index 98%
rename from flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/example/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java
rename to flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java
index 30c8fc465611c..36b5c8257f49d 100644
--- a/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/example/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java
+++ b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/remotecollectoroutputformat/RemoteCollectorOutputFormatExample.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.flink.example.java.remotecollectoroutputformat;
+package org.apache.flink.examples.java.remotecollectoroutputformat;
 
 import java.util.HashSet;
 import java.util.Set;
@@ -111,4 +111,4 @@ public void flatMap(String value, Collector out) {
             }
         }
     }
-}
+}
\ No newline at end of file
diff --git a/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/wordcount/PojoExample.java b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/wordcount/PojoExample.java
new file mode 100644
index 0000000000000..a79462e5f7f10
--- /dev/null
+++ b/flink-examples/flink-java-examples/src/main/java/org/apache/flink/examples/java/wordcount/PojoExample.java
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.examples.java.wordcount;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.core.fs.FileSystem.WriteMode;
+import org.apache.flink.examples.java.wordcount.util.WordCountData;
+import org.apache.flink.util.Collector;
+
+
+/**
+ * This example shows an implementation of WordCount that uses a custom
+ * POJO class instead of the Tuple2 type.
+ */
+@SuppressWarnings("serial")
+public class PojoExample {
+
+    /**
+     * This is the POJO (Plain Old Java Object) that is being used
+     * for all the operations.
+     * As long as all fields are public or have a getter/setter, the system can handle them.
+     */
+    public static class Word {
+        // fields
+        private String word;
+        private Integer frequency;
+
+        // constructors
+        public Word() {
+        }
+
+        public Word(String word, int i) {
+            this.word = word;
+            this.frequency = i;
+        }
+
+        // getters / setters
+        public String getWord() {
+            return word;
+        }
+
+        public void setWord(String word) {
+            this.word = word;
+        }
+
+        public Integer getFrequency() {
+            return frequency;
+        }
+
+        public void setFrequency(Integer frequency) {
+            this.frequency = frequency;
+        }
+
+        @Override
+        public String toString() {
+            return "Word=" + word + " freq=" + frequency;
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+
+        if(!parseParameters(args)) {
+            return;
+        }
+
+        // set up the execution environment
+        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+        // get input data
+        DataSet<String> text = getTextDataSet(env);
+
+        DataSet<Word> counts =
+            // split up the lines into Word objects (with frequency = 1)
+            text.flatMap(new Tokenizer())
+            // group by the field "word" and sum up the frequencies
+            .groupBy("word")
+            .reduce(new ReduceFunction<Word>() {
+                @Override
+                public Word reduce(Word value1, Word value2) throws Exception {
+                    return new Word(value1.word, value1.frequency + value2.frequency);
+                }
+            });
+
+        if(fileOutput) {
+            counts.writeAsText(outputPath, WriteMode.OVERWRITE);
+        } else {
+            counts.print();
+        }
+
+        // execute program
+        env.execute("WordCount-Pojo Example");
+    }
+
+    // *************************************************************************
+    //     USER FUNCTIONS
+    // *************************************************************************
+
+    /**
+     * Implements the string tokenizer that splits sentences into words as a user-defined
+     * FlatMapFunction. The function takes a line (String) and splits it into
+     * multiple Word objects, each with a frequency of 1.
+     */
+    public static final class Tokenizer implements FlatMapFunction<String, Word> {
+        private static final long serialVersionUID = 1L;
+
+        @Override
+        public void flatMap(String value, Collector<Word> out) {
+            // normalize and split the line
+            String[] tokens = value.toLowerCase().split("\\W+");
+
+            // emit the words
+            for (String token : tokens) {
+                if (token.length() > 0) {
+                    out.collect(new Word(token, 1));
+                }
+            }
+        }
+    }
+
+    // *************************************************************************
+    //     UTIL METHODS
+    // *************************************************************************
+
+    private static boolean fileOutput = false;
+    private static String textPath;
+    private static String outputPath;
+
+    private static boolean parseParameters(String[] args) {
+
+        if(args.length > 0) {
+            // parse input arguments
+            fileOutput = true;
+            if(args.length == 2) {
+                textPath = args[0];
+                outputPath = args[1];
+            } else {
+                System.err.println("Usage: WordCount <text path> <result path>");
+                return false;
+            }
+        } else {
+            System.out.println("Executing WordCount example with built-in default data.");
+            System.out.println("  Provide parameters to read input data from a file.");
+            System.out.println("  Usage: WordCount <text path> <result path>");
+        }
+        return true;
+    }
+
+    private static DataSet<String> getTextDataSet(ExecutionEnvironment env) {
+        if(fileOutput) {
+            // read the text file from given input path
+            return env.readTextFile(textPath);
+        } else {
+            // get default test text data
+            return WordCountData.getDefaultTextLineDataSet(env);
+        }
+    }
+}
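For contrast with the expression-keys route that `PojoExample` uses, the same grouping can be written with a `KeySelector`, which also works for POJOs independently of this patch. A sketch (the reducer is a stand-in, not part of the patch):

    DataSet<Word> counts = words
            .groupBy(new KeySelector<Word, String>() {
                @Override
                public String getKey(Word w) { return w.getWord(); }
            })
            .reduce(new MyWordReducer()); // hypothetical ReduceFunction<Word>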
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index d2939d8cd66e7..424d30b71dfa1 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -430,7 +430,7 @@ public DistinctOperator distinct(KeySelector keyExtractor) {
     * @return A DistinctOperator that represents the distinct DataSet.
     */
    public DistinctOperator distinct(int... fields) {
-       return new DistinctOperator(this, new Keys.FieldPositionKeys(fields, getType(), true));
+       return new DistinctOperator(this, new Keys.ExpressionKeys(fields, getType(), true));
    }
 
    /**
@@ -500,7 +500,7 @@ public UnsortedGrouping groupBy(KeySelector keyExtractor) {
     * @see DataSet
     */
    public UnsortedGrouping groupBy(int... fields) {
-       return new UnsortedGrouping(this, new Keys.FieldPositionKeys(fields, getType(), false));
+       return new UnsortedGrouping(this, new Keys.ExpressionKeys(fields, getType(), false));
    }
 
    /**
@@ -526,9 +526,9 @@ public UnsortedGrouping groupBy(int... fields) {
     * @see org.apache.flink.api.java.operators.GroupReduceOperator
     * @see DataSet
     */
-// public UnsortedGrouping groupBy(String... fields) {
-//     return new UnsortedGrouping(this, new Keys.ExpressionKeys(fields, getType()));
-// }
+   public UnsortedGrouping groupBy(String... fields) {
+       return new UnsortedGrouping(this, new Keys.ExpressionKeys(fields, getType()));
+   }
 
    // --------------------------------------------------------------------------------------------
    //  Joining
   // --------------------------------------------------------------------------------------------
@@ -541,7 +541,7 @@ public UnsortedGrouping groupBy(int... fields) {
     * joining elements into one DataSet.<br/>
     * 
     * This method returns a {@link JoinOperatorSets} on which
-    * {@link JoinOperatorSets#where(int...)} needs to be called to define the join key of the first
+    * {@link JoinOperatorSets#where()} needs to be called to define the join key of the first
     * joining (i.e., this) DataSet.
     * 
     * @param other The other DataSet with which this DataSet is joined.
@@ -562,7 +562,7 @@ public JoinOperatorSets join(DataSet other) {
     * This method also gives the hint to the optimizer that the second DataSet to join is much
     * smaller than the first one.<br/>
     * This method returns a {@link JoinOperatorSets} on which
-    * {@link JoinOperatorSets#where(int...)} needs to be called to define the join key of the first
+    * {@link JoinOperatorSets#where()} needs to be called to define the join key of the first
     * joining (i.e., this) DataSet.
     * 
     * @param other The other DataSet with which this DataSet is joined.
@@ -583,7 +583,7 @@ public JoinOperatorSets joinWithTiny(DataSet other) {
     * This method also gives the hint to the optimizer that the second DataSet to join is much
     * larger than the first one.<br/>
     * This method returns a {@link JoinOperatorSets JoinOperatorSet} on which
-    * {@link JoinOperatorSets#where(int...)} needs to be called to define the join key of the first
+    * {@link JoinOperatorSets#where()} needs to be called to define the join key of the first
     * joining (i.e., this) DataSet.
     * 
     * @param other The other DataSet with which this DataSet is joined.
@@ -610,7 +610,7 @@ public JoinOperatorSets joinWithHuge(DataSet other) {
     * The CoGroupFunction can iterate over the elements of both groups and return any number
     * of elements including none.<br/>
     * This method returns a {@link CoGroupOperatorSets} on which
-    * {@link CoGroupOperatorSets#where(int...)} needs to be called to define the grouping key of the first
+    * {@link CoGroupOperatorSets#where()} needs to be called to define the grouping key of the first
     * (i.e., this) DataSet.
     * 
     * @param other The other DataSet of the CoGroup transformation.
@@ -814,7 +814,7 @@ public IterativeDataSet iterate(int maxIterations) {
     * @see org.apache.flink.api.java.operators.DeltaIteration
     */
    public DeltaIteration iterateDelta(DataSet workset, int maxIterations, int... keyPositions) {
-       Keys.FieldPositionKeys keys = new Keys.FieldPositionKeys(keyPositions, getType(), false);
+       Keys.ExpressionKeys keys = new Keys.ExpressionKeys(keyPositions, getType(), false);
        return new DeltaIteration(getExecutionEnvironment(), getType(), this, workset, keys, maxIterations);
    }
@@ -863,7 +863,7 @@ public UnionOperator union(DataSet other){
     * @return The partitioned DataSet.
     */
    public PartitionOperator partitionByHash(int... fields) {
-       return new PartitionOperator(this, PartitionMethod.HASH, new Keys.FieldPositionKeys(fields, getType(), false));
+       return new PartitionOperator(this, PartitionMethod.HASH, new Keys.ExpressionKeys(fields, getType(), false));
    }
 
    /**
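Usage sketch for the newly enabled String-expression keys on `DataSet` (the `Word` POJO is the one from the example above; the reducer is hypothetical):

    words.groupBy("word")              // key by POJO field name
         .reduce(new MyWordReducer());

    tuples.groupBy(0, 1);              // integer positions keep working for tuples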
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java
index c1d072872c6cb..54e36c03e97c0 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java
@@ -791,4 +791,5 @@ protected static void disableLocalExecution() {
     public static boolean localExecutionIsAllowed() {
         return allowLocalExecution;
     }
+
 }
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/SerializationSpeedTest.java b/flink-java/src/main/java/org/apache/flink/api/java/SerializationSpeedTest.java
new file mode 100644
index 0000000000000..a81f55db60850
--- /dev/null
+++ b/flink-java/src/main/java/org/apache/flink/api/java/SerializationSpeedTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.java;
+
+import java.lang.management.ManagementFactory;
+import java.lang.management.RuntimeMXBean;
+import java.lang.reflect.Field;
+
+public class SerializationSpeedTest {
+
+    static Field wordDescField;
+    static Field wordField;
+
+    static {
+        try {
+            wordDescField = WC.class.getField("wordDesc");
+            wordField = ComplexWordDescriptor.class.getField("word");
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    public static class ComplexWordDescriptor {
+        public String word;
+
+        public String getWord() {
+            return word;
+        }
+    }
+
+    public static class WC {
+        public int count;
+        public ComplexWordDescriptor wordDesc;
+
+        public WC(int c, String s) throws NoSuchFieldException, SecurityException {
+            this.count = c;
+            this.wordDesc = new ComplexWordDescriptor();
+            this.wordDesc.word = s;
+        }
+
+        public ComplexWordDescriptor getWordDesc() {
+            return wordDesc;
+        }
+    }
+
+    public static int compareCodeGenPublicFields(WC w1, WC w2) {
+        return w1.wordDesc.word.compareTo(w2.wordDesc.word);
+    }
+
+    public static int compareCodeGenMethods(WC w1, WC w2) {
+        return w1.getWordDesc().getWord().compareTo(w2.getWordDesc().getWord());
+    }
+
+    public static int compareReflective(WC w1, WC w2) throws IllegalArgumentException, IllegalAccessException {
+        // get String of w1
+        Object wordDesc1 = wordDescField.get(w1);
+        String word2cmp1 = (String) wordField.get(wordDesc1);
+
+        // get String of w2
+        Object wordDesc2 = wordDescField.get(w2);
+        String word2cmp2 = (String) wordField.get(wordDesc2);
+
+        return word2cmp1.compareTo(word2cmp2);
+    }
+
+    /**
+     * Results on a Core i7 2600k:
+     *
+     *   warming up
+     *   Code gen    5019
+     *   Reflection 20364
+     *   Factor = 4.057382
+     */
+    public static void main(String[] args) throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException {
+        final long RUNS = 1000000000L;
+
+        final RuntimeMXBean bean = ManagementFactory.getRuntimeMXBean();
+        String jvm = bean.getVmName() + " - " + bean.getVmVendor() + " - " + bean.getSpecVersion() + '/' + bean.getVmVersion();
+        System.err.println("Jvm info : " + jvm);
+
+        WC word0 = new WC(14, "Hallo");
+        WC word1 = new WC(3, "Hola");
+
+        System.err.println("warming up");
+        for (long i = 0; i < 100000000; i++) {
+            compareCodeGenPublicFields(word0, word1);
+            compareCodeGenMethods(word0, word1);
+            compareReflective(word0, word1);
+        }
+
+        System.err.println("Code gen public fields");
+        long startTime = System.currentTimeMillis();
+        for (long i = 0; i < RUNS; i++) {
+            int a = compareCodeGenPublicFields(word0, word1);
+            if (a == 0) {
+                System.err.println("hah");
+            }
+        }
+        long stopTime = System.currentTimeMillis();
+        long elapsedTimeGen = stopTime - startTime;
+        System.err.println(elapsedTimeGen);
+
+        System.err.println("Code gen methods");
+        startTime = System.currentTimeMillis();
+        for (long i = 0; i < RUNS; i++) {
+            int a = compareCodeGenMethods(word0, word1);
+            if (a == 0) {
+                System.err.println("hah");
+            }
+        }
+        stopTime = System.currentTimeMillis();
+        long elapsedTimeGenMethods = stopTime - startTime;
+        System.err.println(elapsedTimeGenMethods);
+
+        System.err.println("Reflection");
+
+        startTime = System.currentTimeMillis();
+        for (long i = 0; i < RUNS; i++) {
+            int a = compareReflective(word0, word1);
+            if (a == 0) {
+                System.err.println("hah");
+            }
+        }
+        stopTime = System.currentTimeMillis();
+        long elapsedTimeRef = stopTime - startTime;
+        System.err.println(elapsedTimeRef);
+
+        System.err.println("Factor vs public = " + (elapsedTimeRef / (float) elapsedTimeGen));
+        System.err.println("Factor vs methods = " + (elapsedTimeRef / (float) elapsedTimeGenMethods));
+    }
+}
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
index ee9b65f0155a8..041dc75ca9942 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/AggregateOperator.java
@@ -195,7 +195,7 @@ protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase ...
             UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType());
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java
index 9ecb9a2e2584a..56f90f4072e21 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/CoGroupOperator.java
@@ -32,7 +32,8 @@
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.operators.DeltaIteration.SolutionSetPlaceHolder;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.api.java.operators.Keys.FieldPositionKeys;
+import org.apache.flink.api.java.operators.Keys.ExpressionKeys;
+import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException;
 import org.apache.flink.api.java.operators.translation.KeyExtractingMapper;
 import org.apache.flink.api.java.operators.translation.PlanBothUnwrappingCoGroupOperator;
 import org.apache.flink.api.java.operators.translation.PlanLeftUnwrappingCoGroupOperator;
@@ -42,6 +43,7 @@
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 
+
 /**
  * A {@link DataSet} that is the result of a CoGroup transformation.
  *
@@ -74,16 +76,16 @@ public CoGroupOperator(DataSet input1, DataSet input2,
        // sanity check solution set key mismatches
        if (input1 instanceof SolutionSetPlaceHolder) {
-           if (keys1 instanceof FieldPositionKeys) {
-               int[] positions = ((FieldPositionKeys) keys1).computeLogicalKeyPositions();
+           if (keys1 instanceof ExpressionKeys) {
+               int[] positions = ((ExpressionKeys) keys1).computeLogicalKeyPositions();
                ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions);
            } else {
                throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
            }
        }
        if (input2 instanceof SolutionSetPlaceHolder) {
-           if (keys2 instanceof FieldPositionKeys) {
-               int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions();
+           if (keys2 instanceof ExpressionKeys) {
+               int[] positions = ((ExpressionKeys) keys2).computeLogicalKeyPositions();
                ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions);
            } else {
                throw new InvalidProgramException("Currently, the solution set may only be CoGrouped with using tuple field positions.");
@@ -108,10 +110,10 @@ protected Keys getKeys2() {
    protected org.apache.flink.api.common.operators.base.CoGroupOperatorBase translateToDataFlow(Operator input1, Operator input2) {
 
        String name = getName() != null ? getName() : function.getClass().getName();
-
-       if (!keys1.areCompatibale(keys2)) {
-           throw new InvalidProgramException("The types of the key fields do not match. Left:" +
-                   " " + keys1 + " Right: " + keys2);
+       try {
+           keys1.areCompatible(keys2);
+       } catch (IncompatibleKeysException e) {
+           throw new InvalidProgramException("The types of the key fields do not match.", e);
        }
 
        if (keys1 instanceof Keys.SelectorFunctionKeys
@@ -166,15 +168,13 @@ else if (keys1 instanceof Keys.SelectorFunctionKeys) {
            return po;
        }
-       else if ((keys1 instanceof Keys.FieldPositionKeys
-                   && keys2 instanceof Keys.FieldPositionKeys) ||
-               ((keys1 instanceof Keys.ExpressionKeys
-                   && keys2 instanceof Keys.ExpressionKeys)))
+       else if ( keys1 instanceof Keys.ExpressionKeys && keys2 instanceof Keys.ExpressionKeys)
        {
-
-           if (!keys1.areCompatibale(keys2)) {
-               throw new InvalidProgramException("The types of the key fields do not match.");
-           }
+           try {
+               keys1.areCompatible(keys2);
+           } catch (IncompatibleKeysException e) {
+               throw new InvalidProgramException("The types of the key fields do not match.", e);
+           }
 
            int[] logicalKeyPositions1 = keys1.computeLogicalKeyPositions();
            int[] logicalKeyPositions2 = keys2.computeLogicalKeyPositions();
@@ -364,7 +364,7 @@ public CoGroupOperatorSets(DataSet input1, DataSet input2) {
         * @see DataSet
         */
        public CoGroupOperatorSetsPredicate where(int... fields) {
-           return new CoGroupOperatorSetsPredicate(new Keys.FieldPositionKeys(fields, input1.getType()));
+           return new CoGroupOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType()));
        }
 
        /**
@@ -380,9 +380,9 @@ public CoGroupOperatorSetsPredicate where(int... fields) {
         * @see Tuple
         * @see DataSet
         */
-//     public CoGroupOperatorSetsPredicate where(String... fields) {
-//         return new CoGroupOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType()));
-//     }
+       public CoGroupOperatorSetsPredicate where(String... fields) {
+           return new CoGroupOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType()));
+       }
 
        /**
        * Continues a CoGroup transformation and defines a {@link KeySelector} function for the first co-grouped {@link DataSet}.<br/>
@@ -436,7 +436,7 @@ private CoGroupOperatorSetsPredicate(Keys keys1) { * Call {@link org.apache.flink.api.java.operators.CoGroupOperator.CoGroupOperatorSets.CoGroupOperatorSetsPredicate.CoGroupOperatorWithoutFunction#with(org.apache.flink.api.common.functions.CoGroupFunction)} to finalize the CoGroup transformation. */ public CoGroupOperatorWithoutFunction equalTo(int... fields) { - return createCoGroupOperator(new Keys.FieldPositionKeys(fields, input2.getType())); + return createCoGroupOperator(new Keys.ExpressionKeys(fields, input2.getType())); } /** @@ -448,9 +448,9 @@ public CoGroupOperatorWithoutFunction equalTo(int... fields) { * @return An incomplete CoGroup transformation. * Call {@link org.apache.flink.api.java.operators.CoGroupOperator.CoGroupOperatorSets.CoGroupOperatorSetsPredicate.CoGroupOperatorWithoutFunction#with(org.apache.flink.api.common.functions.CoGroupFunction)} to finalize the CoGroup transformation. */ -// public CoGroupOperatorWithoutFunction equalTo(String... fields) { -// return createCoGroupOperator(new Keys.ExpressionKeys(fields, input2.getType())); -// } + public CoGroupOperatorWithoutFunction equalTo(String... fields) { + return createCoGroupOperator(new Keys.ExpressionKeys(fields, input2.getType())); + } /** * Continues a CoGroup transformation and defines a {@link KeySelector} function for the second co-grouped {@link DataSet}.
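With where(String...) and equalTo(String...) re-enabled, co-group keys can be declared by field name on composite types. A hedged sketch with two invented public POJOs (public fields and default constructors keep them within the new POJO rules):

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;

public class ExpressionKeyCoGroup {
	public static class Visit {
		public String url;
		public long timestamp;
		public Visit() {}
	}
	public static class Page {
		public String url;
		public Page() {}
	}
	public static void main(String[] args) {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<Visit> visits = env.fromElements(new Visit());
		DataSet<Page> pages = env.fromElements(new Page());
		// both sides keyed by the flattened 'url' field; the ExpressionKeys
		// compatibility check compares the flat field types (String vs. String)
		visits.coGroup(pages).where("url").equalTo("url");
	}
}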
@@ -480,9 +480,10 @@ private CoGroupOperatorWithoutFunction createCoGroupOperator(Keys keys2) { if (keys2.isEmpty()) { throw new InvalidProgramException("The co-group keys must not be empty."); } - - if (!keys1.areCompatibale(keys2)) { - throw new InvalidProgramException("The pair of co-group keys are not compatible with each other."); + try { + keys1.areCompatible(keys2); + } catch(IncompatibleKeysException ike) { + throw new InvalidProgramException("The pair of co-group keys are not compatible with each other.", ike); } return new CoGroupOperatorWithoutFunction(keys2); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java index 18cb8f69404cc..54a65a9cf1bed 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/DistinctOperator.java @@ -58,7 +58,7 @@ public DistinctOperator(DataSet input, Keys keys) { for(int i = 0; i < tupleType.getArity(); i++) { allFields[i] = i; } - keys = new Keys.FieldPositionKeys(allFields, input.getType(), true); + keys = new Keys.ExpressionKeys(allFields, input.getType(), true); } else { throw new InvalidProgramException("Distinction on all fields is only possible on tuple data types."); @@ -67,7 +67,7 @@ public DistinctOperator(DataSet input, Keys keys) { // FieldPositionKeys can only be applied on Tuples - if (keys instanceof Keys.FieldPositionKeys && !input.getType().isTupleType()) { + if (keys instanceof Keys.ExpressionKeys && !input.getType().isTupleType()) { throw new InvalidProgramException("Distinction on field positions is only possible on tuple data types."); } @@ -81,7 +81,7 @@ public DistinctOperator(DataSet input, Keys keys) { String name = function.getClass().getName(); - if (keys instanceof Keys.FieldPositionKeys) { + if (keys instanceof Keys.ExpressionKeys) { int[] logicalKeyPositions = keys.computeLogicalKeyPositions(); UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType()); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java index edae3c8b0a111..cbcc367e70fc8 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/GroupReduceOperator.java @@ -139,7 +139,8 @@ public void setCombinable(boolean combinable) { return po; } - else if (grouper.getKeys() instanceof Keys.FieldPositionKeys) { + else if (grouper.getKeys() instanceof Keys.ExpressionKeys) { //was field position key + //TODO ask stephan int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions(); UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType()); @@ -166,19 +167,19 @@ else if (grouper.getKeys() instanceof Keys.FieldPositionKeys) { return po; } - else if (grouper.getKeys() instanceof Keys.ExpressionKeys) { - - int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions(); - UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType()); - GroupReduceOperatorBase> po = - new GroupReduceOperatorBase>(function, operatorInfo, logicalKeyPositions, name); - - po.setCombinable(combinable); - po.setInput(input); - po.setDegreeOfParallelism(this.getParallelism()); - - 
return po; - } +// else if (grouper.getKeys() instanceof Keys.ExpressionKeys) { +// +// int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions(); +// UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType()); +// GroupReduceOperatorBase> po = +// new GroupReduceOperatorBase>(function, operatorInfo, logicalKeyPositions, name); +// +// po.setCombinable(combinable); +// po.setInput(input); +// po.setDegreeOfParallelism(this.getParallelism()); +// +// return po; +// } else { throw new UnsupportedOperationException("Unrecognized key type."); } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java index bc35c1409a147..caa27dc4d2321 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/JoinOperator.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputSemanticProperties; import org.apache.flink.api.common.operators.Operator; @@ -33,23 +34,22 @@ import org.apache.flink.api.common.operators.base.MapOperatorBase; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.operators.DeltaIteration.SolutionSetPlaceHolder; -import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.functions.SemanticPropUtil; -import org.apache.flink.api.java.operators.Keys.FieldPositionKeys; +import org.apache.flink.api.java.operators.DeltaIteration.SolutionSetPlaceHolder; +import org.apache.flink.api.java.operators.Keys.ExpressionKeys; +import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException; import org.apache.flink.api.java.operators.translation.KeyExtractingMapper; import org.apache.flink.api.java.operators.translation.PlanBothUnwrappingJoinOperator; import org.apache.flink.api.java.operators.translation.PlanLeftUnwrappingJoinOperator; import org.apache.flink.api.java.operators.translation.PlanRightUnwrappingJoinOperator; import org.apache.flink.api.java.operators.translation.WrappingFunction; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; -import org.apache.flink.api.java.typeutils.TypeExtractor; - //CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator import org.apache.flink.api.java.tuple.*; -import org.apache.flink.util.Collector; //CHECKSTYLE.ON: AvoidStarImport +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.util.Collector; /** * A {@link DataSet} that is the result of a Join transformation. 
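Since all position-based keys now go through ExpressionKeys, a nested tuple key expands into all of its flat fields, which is what the grouping translation above relies on. A sketch of the flattening, using the (internal) Keys API directly; the generic signature shown is an assumption for illustration:

import java.util.Arrays;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.Keys;
import org.apache.flink.api.java.tuple.Tuple2;

public class FlattenedKeyPositions {
	public static void main(String[] args) {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<Tuple2<Tuple2<Integer, String>, Double>> data = env.fromElements(
				new Tuple2<Tuple2<Integer, String>, Double>(new Tuple2<Integer, String>(1, "a"), 1.0));
		// key field 0 is itself a Tuple2, so it expands to the flat positions {0, 1}
		Keys.ExpressionKeys<Tuple2<Tuple2<Integer, String>, Double>> keys =
				new Keys.ExpressionKeys<Tuple2<Tuple2<Integer, String>, Double>>(new int[] {0}, data.getType());
		System.out.println(Arrays.toString(keys.computeLogicalKeyPositions())); // expected: [0, 1]
	}
}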
@@ -124,16 +124,16 @@ protected JoinOperator(DataSet input1, DataSet input2, // sanity check solution set key mismatches if (input1 instanceof SolutionSetPlaceHolder) { - if (keys1 instanceof FieldPositionKeys) { - int[] positions = ((FieldPositionKeys) keys1).computeLogicalKeyPositions(); + if (keys1 instanceof ExpressionKeys) { + int[] positions = ((ExpressionKeys) keys1).computeLogicalKeyPositions(); ((SolutionSetPlaceHolder) input1).checkJoinKeyFields(positions); } else { throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); } } if (input2 instanceof SolutionSetPlaceHolder) { - if (keys2 instanceof FieldPositionKeys) { - int[] positions = ((FieldPositionKeys) keys2).computeLogicalKeyPositions(); + if (keys2 instanceof ExpressionKeys) { + int[] positions = ((ExpressionKeys) keys2).computeLogicalKeyPositions(); ((SolutionSetPlaceHolder) input2).checkJoinKeyFields(positions); } else { throw new InvalidProgramException("Currently, the solution set may only be joined with using tuple field positions."); @@ -247,12 +247,12 @@ public void generateProjectionProperties(ProjectFlatJoinFunction pjf) { protected JoinOperatorBase translateToDataFlow( Operator input1, Operator input2) { - - String name = getName() != null ? getName() : function.getClass().getName(); - if (!keys1.areCompatibale(keys2)) { - throw new InvalidProgramException("The types of the key fields do not match. Left:" + - " " + keys1 + " Right: " + keys2); + String name = getName() != null ? getName() : function.getClass().getName(); + try { + keys1.areCompatible(super.keys2); + } catch(IncompatibleKeysException ike) { + throw new InvalidProgramException("The types of the key fields do not match.", ike); } if (keys1 instanceof Keys.SelectorFunctionKeys @@ -315,10 +315,7 @@ else if (keys1 instanceof Keys.SelectorFunctionKeys) { return po; } - else if ((super.keys1 instanceof Keys.FieldPositionKeys - && super.keys2 instanceof Keys.FieldPositionKeys) || - ((super.keys1 instanceof Keys.ExpressionKeys - && super.keys2 instanceof Keys.ExpressionKeys))) + else if (super.keys1 instanceof Keys.ExpressionKeys && super.keys2 instanceof Keys.ExpressionKeys) { // Neither side needs the tuple wrapping/unwrapping @@ -765,7 +762,7 @@ public JoinOperatorSets(DataSet input1, DataSet input2, JoinHint hint) { * @see DataSet */ public JoinOperatorSetsPredicate where(int... fields) { - return new JoinOperatorSetsPredicate(new Keys.FieldPositionKeys(fields, input1.getType())); + return new JoinOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType())); } /** @@ -782,9 +779,9 @@ public JoinOperatorSetsPredicate where(int... fields) { * @see Tuple * @see DataSet */ -// public JoinOperatorSetsPredicate where(String... fields) { -// return new JoinOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType())); -// } + public JoinOperatorSetsPredicate where(String... fields) { + return new JoinOperatorSetsPredicate(new Keys.ExpressionKeys(fields, input1.getType())); + } /** * Continues a Join transformation and defines a {@link KeySelector} function for the first join {@link DataSet}.
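The same eager validation applies to joins, and a type mismatch now carries the offending TypeInformation pair in the exception. A minimal sketch (values illustrative):

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

public class KeyTypeMismatch {
	public static void main(String[] args) {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<Tuple2<String, Integer>> left = env.fromElements(new Tuple2<String, Integer>("a", 1));
		DataSet<Tuple2<Integer, String>> right = env.fromElements(new Tuple2<Integer, String>(1, "a"));
		// String vs. Integer key types: an IncompatibleKeysException naming both
		// types is thrown and wrapped into an InvalidProgramException when
		// equalTo() builds the join
		left.join(right).where(0).equalTo(0);
	}
}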
@@ -843,7 +840,7 @@ private JoinOperatorSetsPredicate(Keys keys1) { * @return A DefaultJoin that represents the joined DataSet. */ public DefaultJoin equalTo(int... fields) { - return createJoinOperator(new Keys.FieldPositionKeys(fields, input2.getType())); + return createJoinOperator(new Keys.ExpressionKeys(fields, input2.getType())); } /** @@ -857,9 +854,9 @@ public DefaultJoin equalTo(int... fields) { * @param fields The fields of the second join DataSet that should be used as keys. * @return A DefaultJoin that represents the joined DataSet. */ -// public DefaultJoin equalTo(String... fields) { -// return createJoinOperator(new Keys.ExpressionKeys(fields, input2.getType())); -// } + public DefaultJoin equalTo(String... fields) { + return createJoinOperator(new Keys.ExpressionKeys(fields, input2.getType())); + } /** * Continues a Join transformation and defines a {@link KeySelector} function for the second join {@link DataSet}.
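Because compatibility is checked on the flattened field types, an expression key on a POJO can be matched against a tuple position. A sketch with an invented POJO:

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;

public class MixedKeyJoin {
	public static class Order {
		public int userId;
		public double amount;
		public Order() {}
	}
	public static void main(String[] args) {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<Order> orders = env.fromElements(new Order());
		DataSet<Tuple2<Integer, String>> users = env.fromElements(new Tuple2<Integer, String>(1, "alice"));
		// "userId" flattens to a single int field and position 0 to a single
		// Integer field; the flat types match, so areCompatible() succeeds
		orders.join(users).where("userId").equalTo(0);
	}
}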
@@ -887,8 +884,10 @@ protected DefaultJoin createJoinOperator(Keys keys2) { throw new InvalidProgramException("The join keys may not be empty."); } - if (!keys1.areCompatibale(keys2)) { - throw new InvalidProgramException("The pair of join keys are not compatible with each other."); + try { + keys1.areCompatible(keys2); + } catch (IncompatibleKeysException e) { + throw new InvalidProgramException("The pair of join keys are not compatible with each other.",e); } return new DefaultJoin(input1, input2, keys1, keys2, joinHint); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java index b19d5c3335e6d..653a04ac72499 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java @@ -18,116 +18,53 @@ package org.apache.flink.api.java.operators; +import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Ints; -public abstract class Keys { +public abstract class Keys { + private static final Logger LOG = LoggerFactory.getLogger(Keys.class); public abstract int getNumberOfKeyFields(); public boolean isEmpty() { return getNumberOfKeyFields() == 0; } - - public abstract boolean areCompatibale(Keys other); - + + /** + * Check if two sets of keys are compatible to each other (matching types, key counts) + */ + public abstract boolean areCompatible(Keys other) throws IncompatibleKeysException; + public abstract int[] computeLogicalKeyPositions(); - + + // -------------------------------------------------------------------------------------------- - // Specializations for field indexed / expression-based / extractor-based grouping + // Specializations for expression-based / extractor-based grouping // -------------------------------------------------------------------------------------------- - - public static class FieldPositionKeys extends Keys { - - private final int[] fieldPositions; - private final TypeInformation[] types; - - public FieldPositionKeys(int[] groupingFields, TypeInformation type) { - this(groupingFields, type, false); - } - - public FieldPositionKeys(int[] groupingFields, TypeInformation type, boolean allowEmpty) { - if (!type.isTupleType()) { - throw new InvalidProgramException("Specifying keys via field positions is only valid" + - "for tuple data types. 
Type: " + type); - } - - if (!allowEmpty && (groupingFields == null || groupingFields.length == 0)) { - throw new IllegalArgumentException("The grouping fields must not be empty."); - } - - TupleTypeInfoBase tupleType = (TupleTypeInfoBase)type; - - this.fieldPositions = makeFields(groupingFields, (TupleTypeInfoBase) type); - - types = new TypeInformation[this.fieldPositions.length]; - for(int i = 0; i < this.fieldPositions.length; i++) { - types[i] = tupleType.getTypeAt(this.fieldPositions[i]); - } - - } - - @Override - public int getNumberOfKeyFields() { - return this.fieldPositions.length; - } - - @Override - public boolean areCompatibale(Keys other) { - - if (other instanceof FieldPositionKeys) { - FieldPositionKeys oKey = (FieldPositionKeys) other; - - if(oKey.types.length != this.types.length) { - return false; - } - for(int i=0; i sfk = (SelectorFunctionKeys) other; - - return sfk.keyType.equals(this.types[0]); - } - else { - return false; - } - } - - @Override - public int[] computeLogicalKeyPositions() { - return this.fieldPositions; - } - - @Override - public String toString() { - String fieldsString = Arrays.toString(fieldPositions); - String typesString = Arrays.toString(types); - return "Tuple position key (Fields: " + fieldsString + " Types: " + typesString + ")"; - } - } - - // -------------------------------------------------------------------------------------------- - + + public static class SelectorFunctionKeys extends Keys { private final KeySelector keyExtractor; private final TypeInformation keyType; + private final int[] logicalKeyFields; public SelectorFunctionKeys(KeySelector keyExtractor, TypeInformation inputType, TypeInformation keyType) { if (keyExtractor == null) { @@ -136,6 +73,15 @@ public SelectorFunctionKeys(KeySelector keyExtractor, TypeInformation i this.keyExtractor = keyExtractor; this.keyType = keyType; + + // we have to handle a special case here: + // if the keyType is a tuple type, we need to select the full tuple with all its fields. 
+ if(keyType.isTupleType()) { + ExpressionKeys ek = new ExpressionKeys(new String[] {ExpressionKeys.SELECT_ALL_CHAR}, keyType); + logicalKeyFields = ek.computeLogicalKeyPositions(); + } else { + logicalKeyFields = new int[] {0}; + } if (!this.keyType.isKeyType()) { throw new IllegalArgumentException("Invalid type of KeySelector keys"); @@ -152,35 +98,53 @@ public KeySelector getKeyExtractor() { @Override public int getNumberOfKeyFields() { - return 1; + return logicalKeyFields.length; } @Override - public boolean areCompatibale(Keys other) { - + public boolean areCompatible(Keys other) throws IncompatibleKeysException { + if (other instanceof SelectorFunctionKeys) { @SuppressWarnings("unchecked") SelectorFunctionKeys sfk = (SelectorFunctionKeys) other; return sfk.keyType.equals(this.keyType); } - else if (other instanceof FieldPositionKeys) { - FieldPositionKeys fpk = (FieldPositionKeys) other; - - if(fpk.types.length != 1) { - return false; + else if (other instanceof ExpressionKeys) { + ExpressionKeys expressionKeys = (ExpressionKeys) other; + + if(keyType.isTupleType()) { + // special case again: + TupleTypeInfo tupleKeyType = (TupleTypeInfo) keyType; + List keyTypeFields = new ArrayList(tupleKeyType.getTotalFields()); + tupleKeyType.getKey(ExpressionKeys.SELECT_ALL_CHAR, 0, keyTypeFields); + if(expressionKeys.keyFields.size() != keyTypeFields.size()) { + throw new IncompatibleKeysException(IncompatibleKeysException.SIZE_MISMATCH_MESSAGE); + } + for(int i=0; i < expressionKeys.keyFields.size(); i++) { + if(!expressionKeys.keyFields.get(i).getType().equals(keyTypeFields.get(i).getType())) { + throw new IncompatibleKeysException(expressionKeys.keyFields.get(i).getType(), keyTypeFields.get(i).getType() ); + } + } + return true; } - - return fpk.types[0].equals(this.keyType); - } - else { - return false; + if(expressionKeys.getNumberOfKeyFields() != 1) { + throw new IncompatibleKeysException("Key selector functions are only compatible to one key"); + } + + if(expressionKeys.keyFields.get(0).getType().equals(this.keyType)) { + return true; + } else { + throw new IncompatibleKeysException(expressionKeys.keyFields.get(0).getType(), this.keyType); + } + } else { + throw new IncompatibleKeysException("The key is not compatible with "+other); } } @Override public int[] computeLogicalKeyPositions() { - return new int[] {0}; + return logicalKeyFields; } @Override @@ -188,93 +152,178 @@ public String toString() { return "Key function (Type: " + keyType + ")"; } } - - // -------------------------------------------------------------------------------------------- - + + + /** + * Represents (nested) field access through string and integer-based keys for Composite Types (Tuple or Pojo) + */ public static class ExpressionKeys extends Keys { + + public static final String SELECT_ALL_CHAR = "*"; + + /** + * Flattened fields representing keys fields + */ + private List keyFields; + + /** + * two constructors for field-based (tuple-type) keys + */ + public ExpressionKeys(int[] groupingFields, TypeInformation type) { + this(groupingFields, type, false); + } - private int[] logicalPositions; - - private final TypeInformation[] types; - - @SuppressWarnings("unused") - private PojoTypeInfo type; + // int-defined field + public ExpressionKeys(int[] groupingFields, TypeInformation type, boolean allowEmpty) { + if (!type.isTupleType()) { + throw new InvalidProgramException("Specifying keys via field positions is only valid" + + "for tuple data types. 
Type: " + type); + } - public ExpressionKeys(String[] expressions, TypeInformation type) { - if (!(type instanceof PojoTypeInfo)) { - throw new UnsupportedOperationException("Key expressions can only be used on POJOs." + " " + - "A POJO must have a default constructor without arguments and not have readObject" + - " and/or writeObject methods. A current restriction is that it can only have nested POJOs or primitive (also boxed)" + - " fields."); + if (!allowEmpty && (groupingFields == null || groupingFields.length == 0)) { + throw new IllegalArgumentException("The grouping fields must not be empty."); } - PojoTypeInfo pojoType = (PojoTypeInfo) type; - this.type = pojoType; - logicalPositions = pojoType.getLogicalPositions(expressions); - types = pojoType.getTypes(expressions); - - for (int i = 0; i < logicalPositions.length; i++) { - if (logicalPositions[i] < 0) { - throw new IllegalArgumentException("Expression '" + expressions[i] + "' is not a valid key for POJO" + - " type " + type.toString() + "."); + // select all fields. Therefore, set all fields on this tuple level and let the logic handle the rest + // (makes type assignment easier). + if (groupingFields == null || groupingFields.length == 0) { + groupingFields = new int[type.getArity()]; + for (int i = 0; i < groupingFields.length; i++) { + groupingFields[i] = i; } + } else { + groupingFields = rangeCheckFields(groupingFields, type.getArity() -1); } + TupleTypeInfoBase tupleType = (TupleTypeInfoBase)type; + Preconditions.checkArgument(groupingFields.length > 0, "Grouping fields can not be empty at this point"); + + keyFields = new ArrayList(type.getTotalFields()); + // for each key, find the field: + for(int j = 0; j < groupingFields.length; j++) { + for(int i = 0; i < type.getArity(); i++) { + TypeInformation fieldType = tupleType.getTypeAt(i); + + if(groupingFields[j] == i) { // check if user set the key + int keyId = countNestedElementsBefore(tupleType, i) + i; + if(fieldType instanceof TupleTypeInfoBase) { + TupleTypeInfoBase tupleFieldType = (TupleTypeInfoBase) fieldType; + tupleFieldType.addAllFields(keyId, keyFields); + } else { + Preconditions.checkArgument(fieldType instanceof AtomicType, "Wrong field type"); + keyFields.add(new FlatFieldDescriptor(keyId, fieldType)); + } + + } + } + } + keyFields = removeNullElementsFromList(keyFields); } - + + private static int countNestedElementsBefore(TupleTypeInfoBase tupleType, int pos) { + if( pos == 0) { + return 0; + } + int ret = 0; + for (int i = 0; i < pos; i++) { + TypeInformation fieldType = tupleType.getTypeAt(i); + ret += fieldType.getTotalFields() -1; + } + return ret; + } + + public static List removeNullElementsFromList(List in) { + List elements = new ArrayList(); + for(R e: in) { + if(e != null) { + elements.add(e); + } + } + return elements; + } + + /** + * Create ExpressionKeys from String-expressions + */ + public ExpressionKeys(String[] expressionsIn, TypeInformation type) { + if(!(type instanceof CompositeType)) { + throw new IllegalArgumentException("Type "+type+" is not a composite type. Key expressions are not supported."); + } + CompositeType cType = (CompositeType) type; + + String[] expressions = removeDuplicates(expressionsIn); + if(expressionsIn.length != expressions.length) { + LOG.warn("The key expressions contained duplicates. 
They are now unique"); + } + // extract the keys on their flat position + keyFields = new ArrayList(expressions.length); + for (int i = 0; i < expressions.length; i++) { + List keys = new ArrayList(); // use separate list to do a size check + cType.getKey(expressions[i], 0, keys); + if(keys.size() == 0) { + throw new IllegalArgumentException("Unable to extract key from expression "+expressions[i]+" on key "+cType); + } + keyFields.addAll(keys); + } + } + @Override public int getNumberOfKeyFields() { - return logicalPositions.length; + if(keyFields == null) { + return 0; + } + return keyFields.size(); } @Override - public boolean areCompatibale(Keys other) { + public boolean areCompatible(Keys other) throws IncompatibleKeysException { if (other instanceof ExpressionKeys) { ExpressionKeys oKey = (ExpressionKeys) other; - if(oKey.types.length != this.types.length) { - return false; + if(oKey.getNumberOfKeyFields() != this.getNumberOfKeyFields() ) { + throw new IncompatibleKeysException(IncompatibleKeysException.SIZE_MISMATCH_MESSAGE); } - for(int i=0; i) { + return other.areCompatible(this); } else { - return false; + throw new IncompatibleKeysException("The key is not compatible with "+other); } } @Override public int[] computeLogicalKeyPositions() { - return logicalPositions; + List logicalKeys = new LinkedList(); + for(FlatFieldDescriptor kd : keyFields) { + logicalKeys.addAll( Ints.asList(kd.getPosition())); + } + return Ints.toArray(logicalKeys); } + + } + + private static String[] removeDuplicates(String[] in) { + List ret = new LinkedList(); + for(String el : in) { + if(!ret.contains(el)) { + ret.add(el); + } + } + return ret.toArray(new String[ret.size()]); } - - + // -------------------------------------------------------------------------------------------- + + // -------------------------------------------------------------------------------------------- // Utilities // -------------------------------------------------------------------------------------------- - private static int[] makeFields(int[] fields, TupleTypeInfoBase type) { - int inLength = type.getArity(); - - // null parameter means all fields are considered - if (fields == null || fields.length == 0) { - fields = new int[inLength]; - for (int i = 0; i < inLength; i++) { - fields[i] = i; - } - return fields; - } else { - return rangeCheckAndOrderFields(fields, inLength-1); - } - } - private static final int[] rangeCheckAndOrderFields(int[] fields, int maxAllowedField) { - // order - Arrays.sort(fields); + private static final int[] rangeCheckFields(int[] fields, int maxAllowedField) { // range check and duplicate eliminate int i = 1, k = 0; @@ -285,12 +334,12 @@ private static final int[] rangeCheckAndOrderFields(int[] fields, int maxAllowed } for (; i < fields.length; i++) { - if (fields[i] < 0 || i > maxAllowedField) { + if (fields[i] < 0 || fields[i] > maxAllowedField) { throw new IllegalArgumentException("Tuple position is out of range."); } - if (fields[i] != last) { k++; + last = fields[i]; fields[k] = fields[i]; } } @@ -299,7 +348,20 @@ private static final int[] rangeCheckAndOrderFields(int[] fields, int maxAllowed if (k == fields.length - 1) { return fields; } else { - return Arrays.copyOfRange(fields, 0, k); + return Arrays.copyOfRange(fields, 0, k+1); + } + } + + public static class IncompatibleKeysException extends Exception { + private static final long serialVersionUID = 1L; + public static final String SIZE_MISMATCH_MESSAGE = "The number of specified keys is different."; + + public 
IncompatibleKeysException(String message) { + super(message); + } + + public IncompatibleKeysException(TypeInformation typeInformation, TypeInformation typeInformation2) { + super(typeInformation+" and "+typeInformation2+" are not compatible"); } } } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java index f0931b51c7c9b..532e464ad04aa 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/PartitionOperator.java @@ -50,7 +50,7 @@ public PartitionOperator(DataSet input, PartitionMethod pMethod, Keys pKey throw new UnsupportedOperationException("Range Partitioning not yet supported"); } - if(pKeys instanceof Keys.FieldPositionKeys && !input.getType().isTupleType()) { + if(pKeys instanceof Keys.ExpressionKeys && !input.getType().isTupleType()) { throw new IllegalArgumentException("Hash Partitioning with key fields only possible on Tuple DataSets"); } @@ -83,7 +83,7 @@ public PartitionOperator(DataSet input, PartitionMethod pMethod) { } else if (pMethod == PartitionMethod.HASH) { - if (pKeys instanceof Keys.FieldPositionKeys) { + if (pKeys instanceof Keys.ExpressionKeys) { int[] logicalKeyPositions = pKeys.computeLogicalKeyPositions(); UnaryOperatorInformation operatorInfo = new UnaryOperatorInformation(getType(), getType()); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java index 463b31c83d8d5..8cb64ba6d2b58 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/ReduceOperator.java @@ -101,8 +101,7 @@ public ReduceOperator(Grouping input, ReduceFunction function) { MapOperatorBase po = translateSelectorFunctionReducer(selectorKeys, function, getInputType(), name, input, this.getParallelism()); return po; } - else if (grouper.getKeys() instanceof Keys.FieldPositionKeys || - grouper.getKeys() instanceof Keys.ExpressionKeys) { + else if (grouper.getKeys() instanceof Keys.ExpressionKeys) { // reduce with field positions int[] logicalKeyPositions = grouper.getKeys().computeLogicalKeyPositions(); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/GenericTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/GenericTypeInfo.java index e862b5a8f609e..6272538462e26 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/GenericTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/GenericTypeInfo.java @@ -52,6 +52,11 @@ public boolean isTupleType() { public int getArity() { return 1; } + + @Override + public int getTotalFields() { + return 1; + } @Override public Class getTypeClass() { @@ -65,6 +70,9 @@ public boolean isKeyType() { @Override public TypeSerializer createSerializer() { + // NOTE: The TypeExtractor / pojo logic is assuming that we are using a Avro Serializer here + // in particular classes implementing GenericContainer are handled as GenericTypeInfos + // (this will probably not work with Kryo) return new AvroSerializer(this.typeClass); } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java index 
226557f011597..55128e68d5325 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java @@ -64,6 +64,11 @@ public boolean isTupleType() { public int getArity() { return 1; } + + @Override + public int getTotalFields() { + return 1; + } @SuppressWarnings("unchecked") @Override diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoField.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoField.java index 1b8ef3523fe04..bf0e25af2f8bc 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoField.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoField.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; -class PojoField { +public class PojoField { public Field field; public TypeInformation type; @@ -30,4 +30,9 @@ public PojoField(Field field, TypeInformation type) { this.field = field; this.type = type; } -} + + @Override + public String toString() { + return "PojoField " + field.getDeclaringClass() + "." + field.getName() + " (" + type + ")"; + } +} \ No newline at end of file diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoTypeInfo.java index 51ed507def826..fba1f24c5a3f2 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/PojoTypeInfo.java @@ -18,33 +18,41 @@ package org.apache.flink.api.java.typeutils; -import com.google.common.base.Joiner; - import java.lang.reflect.Field; +import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; +import org.apache.commons.lang3.Validate; import org.apache.flink.api.common.typeinfo.AtomicType; -import org.apache.flink.api.common.typeinfo.CompositeType; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.operators.Keys.ExpressionKeys; import org.apache.flink.api.java.typeutils.runtime.PojoComparator; import org.apache.flink.api.java.typeutils.runtime.PojoSerializer; +import com.google.common.base.Joiner; + /** - * + * TypeInformation for arbitrary (they have to be java-beans-style) java objects (what we call POJO). 
+ * */ -public class PojoTypeInfo extends TypeInformation implements CompositeType { +public class PojoTypeInfo extends CompositeType{ private final Class typeClass; private PojoField[] fields; + + private int totalFields; public PojoTypeInfo(Class typeClass, List fields) { + super(typeClass); this.typeClass = typeClass; List tempFields = new ArrayList(fields); Collections.sort(tempFields, new Comparator() { @@ -54,6 +62,14 @@ public int compare(PojoField o1, PojoField o2) { } }); this.fields = tempFields.toArray(new PojoField[tempFields.size()]); + + // check if POJO is public + if(!Modifier.isPublic(typeClass.getModifiers())) { + throw new RuntimeException("POJO "+typeClass+" is not public"); + } + for(PojoField field : fields) { + totalFields += field.type.getTotalFields(); + } } @Override @@ -71,6 +87,11 @@ public boolean isTupleType() { public int getArity() { return fields.length; } + + @Override + public int getTotalFields() { + return totalFields; + } @Override public Class getTypeClass() { @@ -82,18 +103,6 @@ public boolean isKeyType() { return Comparable.class.isAssignableFrom(typeClass); } - @Override - public TypeSerializer createSerializer() { - TypeSerializer[] fieldSerializers = new TypeSerializer[fields.length]; - Field[] reflectiveFields = new Field[fields.length]; - - for (int i = 0; i < fields.length; i++) { - fieldSerializers[i] = fields[i].type.createSerializer(); - reflectiveFields[i] = fields[i].field; - } - - return new PojoSerializer(this.typeClass, fieldSerializers, reflectiveFields); - } @Override public String toString() { @@ -105,73 +114,124 @@ public String toString() { + ", fields = [" + Joiner.on(", ").join(fieldStrings) + "]" + ">"; } - - public int getLogicalPosition(String fieldExpression) { - for (int i = 0; i < fields.length; i++) { - if (fields[i].field.getName().equals(fieldExpression)) { - return i; + + @Override + public void getKey(String fieldExpression, int offset, List result) { + // handle 'select all' first + if(fieldExpression.equals(ExpressionKeys.SELECT_ALL_CHAR)) { + int keyPosition = 0; + for(PojoField field : fields) { + if(field.type instanceof AtomicType) { + result.add(new FlatFieldDescriptor(offset + keyPosition, field.type)); + } else if(field.type instanceof CompositeType) { + CompositeType cType = (CompositeType)field.type; + cType.getKey(String.valueOf(ExpressionKeys.SELECT_ALL_CHAR), offset + keyPosition, result); + keyPosition += cType.getTotalFields()-1; + } else { + throw new RuntimeException("Unexpected key type: "+field.type); + } + keyPosition++; } + return; + } + Validate.notEmpty(fieldExpression, "Field expression must not be empty."); + // if there is a dot try getting the field from that sub field + int firstDot = fieldExpression.indexOf('.'); + if (firstDot == -1) { + // this is the last field (or only field) in the field expression + int fieldId = 0; + for (int i = 0; i < fields.length; i++) { + if(fields[i].type instanceof CompositeType) { + fieldId += fields[i].type.getTotalFields()-1; + } + if (fields[i].field.getName().equals(fieldExpression)) { + result.add(new FlatFieldDescriptor(offset + fieldId, fields[i].type)); + return; + } + fieldId++; + } + } else { + // split and go deeper + String firstField = fieldExpression.substring(0, firstDot); + String rest = fieldExpression.substring(firstDot + 1); + int fieldId = 0; + for (int i = 0; i < fields.length; i++) { + if (fields[i].field.getName().equals(firstField)) { + if (!(fields[i].type instanceof CompositeType)) { + throw new RuntimeException("Field 
"+fields[i].type+" is not composite type"); + } + CompositeType cType = (CompositeType) fields[i].type; + cType.getKey(rest, offset + fieldId, result); // recurse + return; + } + fieldId++; + } + throw new RuntimeException("Unable to find field "+fieldExpression+" in type "+this+" (looking for '"+firstField+"')"); } - return -1; } - public int[] getLogicalPositions(String[] fieldExpression) { - int[] result = new int[fieldExpression.length]; - for (int i = 0; i < fieldExpression.length; i++) { - result[i] = getLogicalPosition(fieldExpression[i]); + @Override + public TypeInformation getTypeAt(int pos) { + if (pos < 0 || pos >= this.fields.length) { + throw new IndexOutOfBoundsException(); } - return result; + @SuppressWarnings("unchecked") + TypeInformation typed = (TypeInformation) fields[pos].type; + return typed; } - public TypeInformation getType(String fieldExpression) { - for (int i = 0; i < fields.length; i++) { - if (fields[i].field.getName().equals(fieldExpression)) { - return fields[i].type; - } + // used for testing. Maybe use mockito here + public PojoField getPojoFieldAt(int pos) { + if (pos < 0 || pos >= this.fields.length) { + throw new IndexOutOfBoundsException(); } - return null; + return this.fields[pos]; } - public TypeInformation[] getTypes(String[] fieldExpression) { - TypeInformation[] result = new TypeInformation[fieldExpression.length]; - for (int i = 0; i < fieldExpression.length; i++) { - result[i] = getType(fieldExpression[i]); - } - return result; + /** + * Comparator creation + */ + private TypeComparator[] fieldComparators; + private Field[] keyFields; + private int comparatorHelperIndex = 0; + @Override + protected void initializeNewComparator(int keyCount) { + fieldComparators = new TypeComparator[keyCount]; + keyFields = new Field[keyCount]; + comparatorHelperIndex = 0; } @Override - public TypeComparator createComparator(int[] logicalKeyFields, boolean[] orders) { - // sanity checks - if (logicalKeyFields == null || orders == null || logicalKeyFields.length != orders.length || - logicalKeyFields.length > fields.length) - { - throw new IllegalArgumentException(); - } + protected void addCompareField(int fieldId, TypeComparator comparator) { + fieldComparators[comparatorHelperIndex] = comparator; + keyFields[comparatorHelperIndex] = fields[fieldId].field; + comparatorHelperIndex++; + } -// if (logicalKeyFields.length == 1) { -// return createSinglefieldComparator(logicalKeyFields[0], orders[0], types[logicalKeyFields[0]]); -// } + @Override + protected TypeComparator getNewComparator() { + // first remove the null array fields + final Field[] finalKeyFields = Arrays.copyOf(keyFields, comparatorHelperIndex); + @SuppressWarnings("rawtypes") + final TypeComparator[] finalFieldComparators = Arrays.copyOf(fieldComparators, comparatorHelperIndex); + if(finalFieldComparators.length == 0 || finalKeyFields.length == 0 || finalFieldComparators.length != finalKeyFields.length) { + throw new IllegalArgumentException("Pojo comparator creation has a bug"); + } + return new PojoComparator(finalKeyFields, finalFieldComparators, createSerializer(), typeClass); + } - // create the comparators for the individual fields - TypeComparator[] fieldComparators = new TypeComparator[logicalKeyFields.length]; - Field[] keyFields = new Field[logicalKeyFields.length]; - for (int i = 0; i < logicalKeyFields.length; i++) { - int field = logicalKeyFields[i]; + @Override + public TypeSerializer createSerializer() { + TypeSerializer[] fieldSerializers = new TypeSerializer[fields.length ]; + 
Field[] reflectiveFields = new Field[fields.length]; - if (field < 0 || field >= fields.length) { - throw new IllegalArgumentException("The field position " + field + " is out of range [0," + fields.length + ")"); - } - if (fields[field].type.isKeyType() && fields[field].type instanceof AtomicType) { - fieldComparators[i] = ((AtomicType) fields[field].type).createComparator(orders[i]); - keyFields[i] = fields[field].field; - keyFields[i].setAccessible(true); - } else { - throw new IllegalArgumentException("The field at position " + field + " (" + fields[field].type + ") is no atomic key type."); - } + for (int i = 0; i < fields.length; i++) { + fieldSerializers[i] = fields[i].type.createSerializer(); + reflectiveFields[i] = fields[i].field; } - return new PojoComparator(keyFields, fieldComparators, createSerializer(), typeClass); + return new PojoSerializer(this.typeClass, fieldSerializers, reflectiveFields); } + } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/RecordTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/RecordTypeInfo.java index f069eed16001b..2464f255a5cfc 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/RecordTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/RecordTypeInfo.java @@ -43,6 +43,11 @@ public boolean isTupleType() { public int getArity() { return 1; } + + @Override + public int getTotalFields() { + return 1; + } @Override public Class getTypeClass() { diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java index 6edb08c25249b..82f9c50817619 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfo.java @@ -20,22 +20,21 @@ import java.util.Arrays; -import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.typeutils.runtime.TupleComparator; -import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; - //CHECKSTYLE.OFF: AvoidStarImport - Needed for TupleGenerator import org.apache.flink.api.java.tuple.*; //CHECKSTYLE.ON: AvoidStarImport +import org.apache.flink.api.java.typeutils.runtime.TupleComparator; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; -public final class TupleTypeInfo extends TupleTypeInfoBase { +public final class TupleTypeInfo extends TupleTypeInfoBase { + @SuppressWarnings("unchecked") public TupleTypeInfo(TypeInformation... 
types) { this((Class) CLASSES[types.length - 1], types); @@ -60,59 +59,72 @@ public TupleSerializer createSerializer() { return new TupleSerializer(tupleClass, fieldSerializers); } + /** + * Comparator creation + */ + private TypeSerializer[] fieldSerializers; + private TypeComparator[] fieldComparators; + private int[] logicalKeyFields; + private int comparatorHelperIndex = 0; + @Override - public TypeComparator createComparator(int[] logicalKeyFields, boolean[] orders) { - // sanity checks - if (logicalKeyFields == null || orders == null || logicalKeyFields.length != orders.length || - logicalKeyFields.length > types.length) - { - throw new IllegalArgumentException(); - } + protected void initializeNewComparator(int localKeyCount) { + fieldSerializers = new TypeSerializer[localKeyCount]; + fieldComparators = new TypeComparator[localKeyCount]; + logicalKeyFields = new int[localKeyCount]; + comparatorHelperIndex = 0; + } - int maxKey = -1; - for (int key : logicalKeyFields){ - maxKey = Math.max(key, maxKey); - } - - if (maxKey >= this.types.length) { - throw new IllegalArgumentException("The key position " + maxKey + " is out of range for Tuple" + types.length); - } - - // create the comparators for the individual fields - TypeComparator[] fieldComparators = new TypeComparator[logicalKeyFields.length]; - for (int i = 0; i < logicalKeyFields.length; i++) { - int keyPos = logicalKeyFields[i]; - if (types[keyPos].isKeyType() && types[keyPos] instanceof AtomicType) { - fieldComparators[i] = ((AtomicType) types[keyPos]).createComparator(orders[i]); - } else if(types[keyPos].isTupleType() && types[keyPos] instanceof TupleTypeInfo){ // Check for tuple - TupleTypeInfo tupleType = (TupleTypeInfo) types[keyPos]; - - // All fields are key - int[] allFieldsKey = new int[tupleType.types.length]; - for(int h = 0; h < tupleType.types.length; h++){ - allFieldsKey[h]=h; - } - - // Prepare order - boolean[] tupleOrders = new boolean[tupleType.types.length]; - Arrays.fill(tupleOrders, orders[i]); - fieldComparators[i] = tupleType.createComparator(allFieldsKey, tupleOrders); - } else { - throw new IllegalArgumentException("The field at position " + i + " (" + types[keyPos] + ") is no atomic key type nor tuple type."); - } - } - + @Override + protected void addCompareField(int fieldId, TypeComparator comparator) { + fieldComparators[comparatorHelperIndex] = comparator; + fieldSerializers[comparatorHelperIndex] = types[fieldId].createSerializer(); + logicalKeyFields[comparatorHelperIndex] = fieldId; + comparatorHelperIndex++; + } + + @Override + protected TypeComparator getNewComparator() { + @SuppressWarnings("rawtypes") + final TypeComparator[] finalFieldComparators = Arrays.copyOf(fieldComparators, comparatorHelperIndex); + final int[] finalLogicalKeyFields = Arrays.copyOf(logicalKeyFields, comparatorHelperIndex); + //final TypeSerializer[] finalFieldSerializers = Arrays.copyOf(fieldSerializers, comparatorHelperIndex); // create the serializers for the prefix up to highest key position + int maxKey = 0; + for(int key : finalLogicalKeyFields) { + maxKey = Math.max(maxKey, key); + } TypeSerializer[] fieldSerializers = new TypeSerializer[maxKey + 1]; for (int i = 0; i <= maxKey; i++) { fieldSerializers[i] = types[i].createSerializer(); } - - return new TupleComparator(logicalKeyFields, fieldComparators, fieldSerializers); + if(finalFieldComparators.length == 0 || finalLogicalKeyFields.length == 0 || fieldSerializers.length == 0 + || finalFieldComparators.length != finalLogicalKeyFields.length) { + throw new 
IllegalArgumentException("Tuple comparator creation has a bug"); + } + return new TupleComparator(finalLogicalKeyFields, finalFieldComparators, fieldSerializers); } // -------------------------------------------------------------------------------------------- - + + @Override + public boolean equals(Object obj) { + if (obj instanceof TupleTypeInfo) { + @SuppressWarnings("unchecked") + TupleTypeInfo other = (TupleTypeInfo) obj; + return ((this.tupleType == null && other.tupleType == null) || this.tupleType.equals(other.tupleType)) && + Arrays.deepEquals(this.types, other.types); + + } else { + return false; + } + } + + @Override + public int hashCode() { + return this.types.hashCode() ^ Arrays.deepHashCode(this.types); + } + @Override public String toString() { return "Java " + super.toString(); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfoBase.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfoBase.java index 3e1b646bd3e02..4babbd7889890 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfoBase.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TupleTypeInfoBase.java @@ -19,19 +19,31 @@ package org.apache.flink.api.java.typeutils; import java.util.Arrays; +import java.util.List; -import org.apache.flink.api.common.typeinfo.CompositeType; +import org.apache.commons.lang3.StringUtils; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.operators.Keys.ExpressionKeys; -public abstract class TupleTypeInfoBase extends TypeInformation implements CompositeType { +import com.google.common.base.Preconditions; + +public abstract class TupleTypeInfoBase extends CompositeType { protected final TypeInformation[] types; protected final Class tupleType; + + private int totalFields; public TupleTypeInfoBase(Class tupleType, TypeInformation... types) { + super(tupleType); this.tupleType = tupleType; this.types = types; + for(TypeInformation type : types) { + totalFields += type.getTotalFields(); + } } @Override @@ -48,6 +60,11 @@ public boolean isTupleType() { public int getArity() { return types.length; } + + @Override + public int getTotalFields() { + return totalFields; + } @Override public Class getTypeClass() { @@ -55,6 +72,95 @@ public Class getTypeClass() { } + /** + * Recursively add all fields in this tuple type. We need this in particular to get all + * the types. 
+ * @param keyId + * @param keyFields + */ + public void addAllFields(int startKeyId, List keyFields) { + for(int i = 0; i < this.getArity(); i++) { + TypeInformation type = this.types[i]; + if(type instanceof AtomicType) { + keyFields.add(new FlatFieldDescriptor(startKeyId, type)); + } else if(type instanceof TupleTypeInfoBase) { + TupleTypeInfoBase ttb = (TupleTypeInfoBase) type; + ttb.addAllFields(startKeyId, keyFields); + } + startKeyId += type.getTotalFields(); + } + } + + + @Override + public void getKey(String fieldExpression, int offset, List result) { + // handle 'select all' + if(fieldExpression.equals(ExpressionKeys.SELECT_ALL_CHAR)) { + int keyPosition = 0; + for(TypeInformation type : types) { + if(type instanceof AtomicType) { + result.add(new FlatFieldDescriptor(offset + keyPosition, type)); + } else if(type instanceof CompositeType) { + CompositeType cType = (CompositeType)type; + cType.getKey(String.valueOf(ExpressionKeys.SELECT_ALL_CHAR), offset + keyPosition, result); + keyPosition += cType.getTotalFields()-1; + } else { + throw new RuntimeException("Unexpected key type: "+type); + } + keyPosition++; + } + return; + } + // check input + if(fieldExpression.length() < 2) { + throw new IllegalArgumentException("The field expression '"+fieldExpression+"' is incorrect. The length must be at least 2"); + } + if(fieldExpression.charAt(0) != 'f') { + throw new IllegalArgumentException("The field expression '"+fieldExpression+"' is incorrect for a Tuple type. It has to start with an 'f'"); + } + // get first component of nested expression + int dotPos = fieldExpression.indexOf('.'); + String nestedSplitFirst = fieldExpression; + if(dotPos != -1 ) { + Preconditions.checkArgument(dotPos != fieldExpression.length()-1, "The field expression can never end with a dot."); + nestedSplitFirst = fieldExpression.substring(0, dotPos); + } + String fieldNumStr = nestedSplitFirst.substring(1, nestedSplitFirst.length()); + if(!StringUtils.isNumeric(fieldNumStr)) { + throw new IllegalArgumentException("The field expression '"+fieldExpression+"' is incorrect. Field number '"+fieldNumStr+" is not numeric"); + } + int pos = -1; + try { + pos = Integer.valueOf(fieldNumStr); + } catch(NumberFormatException nfe) { + throw new IllegalArgumentException("The field expression '"+fieldExpression+"' is incorrect. Field number '"+fieldNumStr+" is not numeric", nfe); + } + if(pos < 0) { + throw new IllegalArgumentException("Negative position is not possible"); + } + // pass down the remainder (after the dot) of the fieldExpression to the type at that position. + if(dotPos != -1) { // we need to go deeper + String rem = fieldExpression.substring(dotPos+1); + if( !(types[pos] instanceof CompositeType) ) { + throw new RuntimeException("Element at position "+pos+" is not a composite type. There are no nested types to select"); + } + CompositeType cType = (CompositeType) types[pos]; + cType.getKey(rem, offset + pos, result); + return; + } + + if(pos >= types.length) { + throw new IllegalArgumentException("The specified tuple position does not exist"); + } + + // count nested fields before "pos". + for(int i = 0; i < pos; i++) { + offset += types[i].getTotalFields() - 1; // this adds only something to offset if its a composite type. 
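+ // e.g. in Tuple2<Tuple2<Integer, String>, Double>, selecting "f1" skips the two flat fields of f0 (the offset grows by getTotalFields() - 1 == 1), so the resulting flat position is offset + pos == 2.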
+ } + + result.add(new FlatFieldDescriptor(offset + pos, types[pos])); + } + public TypeInformation getTypeAt(int pos) { if (pos < 0 || pos >= this.types.length) { throw new IndexOutOfBoundsException(); @@ -115,4 +221,5 @@ public String toString() { bld.append('>'); return bld.toString(); } + } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java index 55f6b1f1c9957..bc92cdd459c95 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractor.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Set; +import org.apache.avro.generic.GenericContainer; import org.apache.commons.lang3.Validate; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.CrossFunction; @@ -47,13 +48,19 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.types.Value; import org.apache.flink.util.Collector; import org.apache.hadoop.io.Writable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; public class TypeExtractor { + private static final Logger LOG = LoggerFactory.getLogger(TypeExtractor.class); // We need this to detect recursive types and not get caught // in an endless recursion @@ -220,6 +227,29 @@ private TypeInformation privateCreateTypeInfo(Type returnTy // get info from hierarchy return (TypeInformation) createTypeInfoWithTypeHierarchy(typeHierarchy, returnType, in1Type, in2Type); } + + + /** + * @param curT : start type + * @return Type The immediate child of the top class + */ + private Type recursivelyGetTypeHierarchy(ArrayList typeHierarchy, Type curT, Class stopAtClass) { + while (!(curT instanceof ParameterizedType && ((Class) ((ParameterizedType) curT).getRawType()).equals( + stopAtClass)) + && !(curT instanceof Class && ((Class) curT).equals(stopAtClass))) { + typeHierarchy.add(curT); + + // parameterized type + if (curT instanceof ParameterizedType) { + curT = ((Class) ((ParameterizedType) curT).getRawType()).getGenericSuperclass(); + } + // class + else { + curT = ((Class) curT).getGenericSuperclass(); + } + } + return curT; + } @SuppressWarnings({ "unchecked", "rawtypes" }) private TypeInformation createTypeInfoWithTypeHierarchy(ArrayList typeHierarchy, Type t, @@ -227,7 +257,7 @@ private TypeInformation createTypeInfoWithTypeHierarchy(Arr // check if type is a subclass of tuple if ((t instanceof Class && Tuple.class.isAssignableFrom((Class) t)) - || (t instanceof ParameterizedType && Tuple.class.isAssignableFrom((Class) ((ParameterizedType) t).getRawType()))) { + || (t instanceof ParameterizedType && Tuple.class.isAssignableFrom((Class) ((ParameterizedType) t).getRawType()))) { Type curT = t; @@ -236,7 +266,7 @@ private TypeInformation createTypeInfoWithTypeHierarchy(Arr throw new InvalidTypesException( "Usage of class Tuple as a type is not allowed. Use a concrete subclass (e.g. Tuple1, Tuple2, etc.) 
instead."); } - + // go up the hierarchy until we reach immediate child of Tuple (with or without generics) // collect the types while moving up for a later top-down while (!(curT instanceof ParameterizedType && ((Class) ((ParameterizedType) curT).getRawType()).getSuperclass().equals( @@ -295,15 +325,23 @@ private TypeInformation createTypeInfoWithTypeHierarchy(Arr } } - // TODO: Check that type that extends Tuple does not have additional fields. - // Right now, these fields are not be serialized by the TupleSerializer. - // We might want to add an ExtendedTupleSerializer for that. - + Class tAsClass = null; if (t instanceof Class) { - return new TupleTypeInfo(((Class) t), tupleSubTypes); + tAsClass = (Class) t; } else if (t instanceof ParameterizedType) { - return new TupleTypeInfo(((Class) ((ParameterizedType) t).getRawType()), tupleSubTypes); + tAsClass = (Class) ((ParameterizedType) t).getRawType(); } + Preconditions.checkNotNull(tAsClass, "t has a unexpected type"); + // check if the class we assumed to be a Tuple so far is actually a pojo because it contains additional fields. + // check for additional fields. + int fieldCount = countFieldsInClass(tAsClass); + if(fieldCount != tupleSubTypes.length) { + // the class is not a real tuple because it contains additional fields. treat as a pojo + return (TypeInformation) analyzePojo(tAsClass, new ArrayList() ); // the typeHierarchy here should be sufficient, even though it stops at the Tuple.class. + } + + return new TupleTypeInfo(tAsClass, tupleSubTypes); + } // type depends on another type // e.g. class MyMapper extends MapFunction @@ -361,16 +399,29 @@ else if (t instanceof GenericArrayType) { } // objects with generics are treated as raw type else if (t instanceof ParameterizedType) { - return privateGetForClass((Class) ((ParameterizedType) t).getRawType()); + return privateGetForClass((Class) ((ParameterizedType) t).getRawType(), new ArrayList()); // pass new type hierarchies here because + // while creating the TH here, we assumed a tuple type. 
} // no tuple, no TypeVariable, no generic type else if (t instanceof Class) { - return privateGetForClass((Class) t); + return privateGetForClass((Class) t, new ArrayList()); } throw new InvalidTypesException("Type Information could not be created."); } + private int countFieldsInClass(Class clazz) { + int fieldCount = 0; + for(Field field : clazz.getFields()) { // get all fields + if( !Modifier.isStatic(field.getModifiers()) && + !Modifier.isTransient(field.getModifiers()) + ) { + fieldCount++; + } + } + return fieldCount; + } + private TypeInformation createTypeInfoFromInputs(TypeVariable returnTypeVar, ArrayList returnTypeHierarchy, TypeInformation in1TypeInfo, TypeInformation in2TypeInfo) { @@ -383,7 +434,7 @@ private TypeInformation createTypeInfoFromInputs(TypeVariable r else { returnTypeVar = (TypeVariable) matReturnTypeVar; } - + TypeInformation info = null; if (in1TypeInfo != null) { // find the deepest type variable that describes the type of input 1 @@ -806,11 +857,11 @@ private static Type materializeTypeVariable(ArrayList typeHierarchy, TypeV } public static TypeInformation getForClass(Class clazz) { - return new TypeExtractor().privateGetForClass(clazz); + return new TypeExtractor().privateGetForClass(clazz, new ArrayList()); } @SuppressWarnings("unchecked") - private TypeInformation privateGetForClass(Class clazz) { + private TypeInformation privateGetForClass(Class clazz, ArrayList typeHierarchy) { Validate.notNull(clazz); // check for abstract classes or interfaces @@ -819,10 +870,8 @@ private TypeInformation privateGetForClass(Class clazz) { } if (clazz.equals(Object.class)) { - // this will occur when trying to analyze POJOs that have generic, this - // exception will be caught and a GenericTypeInfo will be created for the type. - // at some point we might support this using Kryo - throw new InvalidTypesException("Object is not a valid type."); + // TODO (merging): better to throw an exception here; the runtime does not support it yet + return new GenericTypeInfo(clazz); } // check for arrays @@ -879,62 +928,135 @@ private TypeInformation privateGetForClass(Class clazz) { // special case handling for Class, this should not be handled by the POJO logic return new GenericTypeInfo(clazz); } + if(GenericContainer.class.isAssignableFrom(clazz)) { + // this is a type generated by Avro. GenericTypeInfo is able to handle this case because it is using Avro. + return new GenericTypeInfo(clazz); + } + TypeInformation pojoType = analyzePojo(clazz, typeHierarchy); + if (pojoType != null) { + return pojoType; + } -// Disable POJO types for now (see https://mail-archives.apache.org/mod_mbox/incubator-flink-dev/201407.mbox/%3C53D96049.1060509%40cse.uta.edu%3E) -// -// TypeInformation pojoType = analyzePojo(clazz); -// if (pojoType != null) { -// return pojoType; -// } // return a generic type return new GenericTypeInfo(clazz); } + + /** + * Checks if the given field is a valid POJO field: + * - it is public, OR + * - there are getter and setter methods for the field.
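+ * For example, a private field {@code count} is accepted if the class also declares
+ * {@code getCount()} and {@code setCount(int)} (hypothetical accessor names).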
+ * + * @param f field to check + * @param clazz class of field + * @param typeHierarchy type hierarchy for materializing generic types + * @return true, if the field is a valid POJO field + */ + private boolean isValidPojoField(Field f, Class clazz, ArrayList typeHierarchy) { + if(Modifier.isPublic(f.getModifiers())) { + return true; + } else { + boolean hasGetter = false, hasSetter = false; + final String fieldNameLow = f.getName().toLowerCase(); + + Type fieldType = f.getGenericType(); + TypeVariable fieldTypeGeneric = null; + if(fieldType instanceof TypeVariable) { + fieldTypeGeneric = (TypeVariable) fieldType; + fieldType = materializeTypeVariable(typeHierarchy, (TypeVariable)fieldType); + } + for(Method m : clazz.getMethods()) { + // check for getter + + if( // the method name must contain "get" followed by the field name + m.getName().toLowerCase().contains("get"+fieldNameLow) && + // no arguments for the getter + m.getParameterTypes().length == 0 && + // return type is same as field type (or the generic variant of it) + ( m.getReturnType().equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericReturnType().equals(fieldTypeGeneric) ) ) + ) { + if(hasGetter) { + throw new IllegalStateException("Detected more than one getter"); + } + hasGetter = true; + } + // check for setters + if( m.getName().toLowerCase().contains("set"+fieldNameLow) && + m.getParameterTypes().length == 1 && // one parameter of the field's type + ( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) ) && + // return type is void. + m.getReturnType().equals(Void.TYPE) + ) { + if(hasSetter) { + throw new IllegalStateException("Detected more than one setter"); + } + hasSetter = true; + } + } + if( hasGetter && hasSetter) { + return true; + } else { + if(!hasGetter) { + LOG.warn("Class "+clazz+" does not contain a getter for field "+f.getName() ); + } + if(!hasSetter) { + LOG.warn("Class "+clazz+" does not contain a setter for field "+f.getName() ); + } + return false; + } + } + } - @SuppressWarnings("unused") - private TypeInformation analyzePojo(Class clazz) { - List fields = getAllDeclaredFields(clazz); + private TypeInformation analyzePojo(Class clazz, ArrayList typeHierarchy) { + // build the type hierarchy if the incoming one is empty + if(typeHierarchy.size() == 0) { + recursivelyGetTypeHierarchy(typeHierarchy, clazz, Object.class); + } + + List fields = removeNonObjectFields(getAllDeclaredFields(clazz)); List pojoFields = new ArrayList(); for (Field field : fields) { + Type fieldType = field.getGenericType(); + if(!isValidPojoField(field, clazz, typeHierarchy)) { + LOG.warn("Class "+clazz+" is not a valid POJO type"); + return null; + } try { - if (!Modifier.isTransient(field.getModifiers()) && !Modifier.isStatic(field.getModifiers())) { - pojoFields.add(new PojoField(field, privateCreateTypeInfo(field.getType()))); - } + typeHierarchy.add(fieldType); + pojoFields.add(new PojoField(field, createTypeInfoWithTypeHierarchy(typeHierarchy, fieldType, null, null) )); } catch (InvalidTypesException e) { - // If some of the fields cannot be analyzed, just return a generic type info - // right now this happens when a field is an interface (collections are the prominent case here) or - // when the POJO is generic, in which case the fields will have type Object. - // We might fix that in the future when we use Kryo.
- return new GenericTypeInfo(clazz); + //pojoFields.add(new PojoField(field, new GenericTypeInfo( Object.class ))); // we need Kryo to properly serialize this + throw new InvalidTypesException("Flink is currently unable to serialize this type: "+fieldType + + "\nThe system is using the Avro serializer, which is not able to handle all types.", e); } } - PojoTypeInfo pojoType = new PojoTypeInfo(clazz, pojoFields); + CompositeType pojoType = new PojoTypeInfo(clazz, pojoFields); + // + // Validate the correctness of the POJO. + // Returning null will result in a GenericTypeInfo being created for the class. + // List methods = getAllDeclaredMethods(clazz); - boolean containsReadObjectOrWriteObject = false; for (Method method : methods) { if (method.getName().equals("readObject") || method.getName().equals("writeObject")) { - containsReadObjectOrWriteObject = true; - break; + LOG.warn("Class "+clazz+" contains custom serialization methods (readObject/writeObject) which Flink does not call."); + return null; } } // Try retrieving the default constructor, if it does not have one // we cannot use this because the serializer uses it. - boolean hasDefaultCtor = true; try { clazz.getDeclaredConstructor(); } catch (NoSuchMethodException e) { - hasDefaultCtor = false; + LOG.warn("Class "+clazz+" does not have a default constructor. You cannot use it as a POJO"); + return null; } - - - if (!containsReadObjectOrWriteObject && hasDefaultCtor) { - return pojoType; - } - - return null; + + // all checks passed, return the POJO type + return pojoType; } // recursively determine all declared fields @@ -950,6 +1072,19 @@ private static List getAllDeclaredFields(Class clazz) { return result; } + /** + * Remove transient and static fields from a list of fields. + */ + private static List removeNonObjectFields(List fields) { + List result = new ArrayList(); + for(Field field: fields) { + if (!Modifier.isTransient(field.getModifiers()) && !Modifier.isStatic(field.getModifiers())) { + result.add(field); + } + } + return result; + } + // recursively determine all declared methods private static List getAllDeclaredMethods(Class clazz) { List result = new ArrayList(); @@ -976,6 +1111,11 @@ private TypeInformation privateGetForObject(X value) { if (value instanceof Tuple) { Tuple t = (Tuple) value; int numFields = t.getArity(); + if(numFields != countFieldsInClass(value.getClass())) { + // not a tuple since it has more fields. + return analyzePojo((Class) value.getClass(), new ArrayList()); // we immediately call analyzePojo() here because + // there is currently no other type that can handle such a class.
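+ // Sketch, not part of the patch: a class that passes the analyzePojo() checks above
+ // has public or getter/setter fields, no readObject()/writeObject(), and a default
+ // constructor, e.g. (class name and fields are hypothetical):
+ //
+ //     public static class WordCount {
+ //         public String word;                  // public field: valid as-is
+ //         private int count;                   // private field: needs both accessors
+ //         public WordCount() {}                // default constructor is required
+ //         public int getCount() { return count; }
+ //         public void setCount(int c) { this.count = c; }
+ //     }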
+ } TypeInformation[] infos = new TypeInformation[numFields]; for (int i = 0; i < numFields; i++) { @@ -988,10 +1128,9 @@ private TypeInformation privateGetForObject(X value) { infos[i] = privateGetForObject(field); } - return (TypeInformation) new TupleTypeInfo(value.getClass(), infos); } else { - return privateGetForClass((Class) value.getClass()); + return privateGetForClass((Class) value.getClass(), new ArrayList()); } } } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java index 6894e5ac019c2..953b69c623339 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/TypeInfoParser.java @@ -87,7 +87,7 @@ public static TypeInformation parse(String infoString) { } return (TypeInformation) parse(new StringBuilder(clearedString)); } catch (Exception e) { - throw new IllegalArgumentException("String could not be parsed: " + e.getMessage()); + throw new IllegalArgumentException("String could not be parsed: " + e.getMessage(), e); } } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ValueTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ValueTypeInfo.java index 5045b382f25c6..b3c25e47680d6 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ValueTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ValueTypeInfo.java @@ -52,6 +52,11 @@ public int getArity() { return 1; } + @Override + public int getTotalFields() { + return 1; + } + @Override public Class getTypeClass() { return this.type; diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/WritableTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/WritableTypeInfo.java index 19bcf0bc10123..8c9e948d8c606 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/WritableTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/WritableTypeInfo.java @@ -67,6 +67,11 @@ public boolean isTupleType() { public int getArity() { return 1; } + + @Override + public int getTotalFields() { + return 1; + } @Override public Class getTypeClass() { @@ -88,6 +93,20 @@ public String toString() { return "WritableType<" + typeClass.getName() + ">"; } + @Override + public int hashCode() { + return typeClass.hashCode() ^ 0xd3a2646c; + } + + @Override + public boolean equals(Object obj) { + // the null check keeps equals(null) from throwing, as required by the equals() contract + if (obj != null && obj.getClass() == WritableTypeInfo.class) { + return typeClass == ((WritableTypeInfo) obj).typeClass; + } else { + return false; + } + } + // -------------------------------------------------------------------------------------------- static final TypeInformation getWritableTypeInfo(Class typeClass) { diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/CopyableValueComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/CopyableValueComparator.java index 0b7890fa0a1a3..9b3b191a9642a 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/CopyableValueComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/CopyableValueComparator.java @@ -43,8 +43,6 @@ public class CopyableValueComparator & Comparable> private transient T tempReference; - private final Comparable[] extractedKey = new Comparable[1]; - private final TypeComparator[] comparators = new
TypeComparator[] {this}; public CopyableValueComparator(boolean ascending, Class type) { @@ -126,13 +124,13 @@ public TypeComparator duplicate() { } @Override - public Object[] extractKeys(T record) { - extractedKey[0] = record; - return extractedKey; + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; } @Override - public TypeComparator[] getComparators() { + public TypeComparator[] getFlatComparators() { return comparators; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericTypeComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericTypeComparator.java index 66cbdf4da3d7c..2d3ce392b2741 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericTypeComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericTypeComparator.java @@ -52,9 +52,6 @@ public class GenericTypeComparator> extends TypeComparat private transient Kryo kryo; - @SuppressWarnings("rawtypes") - private final Comparable[] extractedKey = new Comparable[1]; - @SuppressWarnings("rawtypes") private final TypeComparator[] comparators = new TypeComparator[] {this}; @@ -171,14 +168,14 @@ private final void checkKryoInitialized() { } @Override - public Object[] extractKeys(T record) { - extractedKey[0] = record; - return extractedKey; + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; } - @Override @SuppressWarnings("rawtypes") - public TypeComparator[] getComparators() { + @Override + public TypeComparator[] getFlatComparators() { return comparators; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java index 9ae2d6b32218c..9d7eed44b161a 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java @@ -22,7 +22,9 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.lang.reflect.Field; +import java.util.List; +import org.apache.flink.api.common.typeutils.CompositeTypeComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputView; @@ -32,10 +34,11 @@ import org.apache.flink.util.InstantiationUtil; -public final class PojoComparator extends TypeComparator implements java.io.Serializable { - +public final class PojoComparator extends CompositeTypeComparator implements java.io.Serializable { + private static final long serialVersionUID = 1L; + // Reflection fields for the comp fields private transient Field[] keyFields; private final TypeComparator[] comparators; @@ -52,8 +55,6 @@ public final class PojoComparator extends TypeComparator implements java.i private final Class type; - private final Comparable[] extractedKeys; - @SuppressWarnings("unchecked") public PojoComparator(Field[] keyFields, TypeComparator[] comparators, TypeSerializer serializer, Class type) { this.keyFields = keyFields; @@ -70,6 +71,12 @@ public PojoComparator(Field[] keyFields, TypeComparator[] comparators, TypeSe for (int i = 0; i < this.comparators.length; i++) { TypeComparator k = this.comparators[i]; + if(k == null) { + throw new IllegalArgumentException("One of the passed 
comparators is null"); + } + if(keyFields[i] == null) { + throw new IllegalArgumentException("One of the passed reflection fields is null"); + } // as long as the leading keys support normalized keys, we can build up the composite key if (k.supportsNormalizedKey()) { @@ -102,8 +109,6 @@ else if (k.invertNormalizedKey() != inverted) { this.numLeadingNormalizableKeys = nKeys; this.normalizableKeyPrefixLen = nKeyLen; this.invertNormKey = inverted; - - extractedKeys = new Comparable[keyFields.length]; } @SuppressWarnings("unchecked") @@ -130,8 +135,6 @@ private PojoComparator(PojoComparator toClone) { } catch (ClassNotFoundException e) { throw new RuntimeException("Cannot copy serializer", e); } - - extractedKeys = new Comparable[keyFields.length]; } private void writeObject(ObjectOutputStream out) @@ -150,87 +153,87 @@ private void readObject(ObjectInputStream in) int numKeyFields = in.readInt(); keyFields = new Field[numKeyFields]; for (int i = 0; i < numKeyFields; i++) { - Class clazz = (Class)in.readObject(); + Class clazz = (Class) in.readObject(); String fieldName = in.readUTF(); - keyFields[i] = null; // try superclasses as well while (clazz != null) { try { - keyFields[i] = clazz.getDeclaredField(fieldName); - keyFields[i].setAccessible(true); + Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + keyFields[i] = field; break; } catch (NoSuchFieldException e) { clazz = clazz.getSuperclass(); } } - if (keyFields[i] == null) { + if (keyFields[i] == null ) { throw new RuntimeException("Class resolved at TaskManager is not compatible with class read during Plan setup." + " (" + fieldName + ")"); } } } - public Field[] getKeyFields() { return this.keyFields; } - public TypeComparator[] getComparators() { - return this.comparators; + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Override + public void getFlatComparator(List flatComparators) { + for(int i = 0; i < comparators.length; i++) { + if(comparators[i] instanceof CompositeTypeComparator) { + ((CompositeTypeComparator)comparators[i]).getFlatComparator(flatComparators); + } else { + flatComparators.add(comparators[i]); + } + } + } + + /** + * This method is handling the IllegalAccess exceptions of Field.get() + */ + private final Object accessField(Field field, Object object) { + try { + object = field.get(object); + } catch (NullPointerException npex) { + throw new NullKeyFieldException("Unable to access field "+field+" on object "+object); + } catch (IllegalAccessException iaex) { + throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo." 
+ + " fiels: " + field + " obj: " + object); + } + return object; } @Override public int hash(T value) { int i = 0; - try { - int code = 0; - for (; i < this.keyFields.length; i++) { - code ^= this.comparators[i].hash(this.keyFields[i].get(value)); - code *= HASH_SALT[i & 0x1F]; // salt code with (i % HASH_SALT.length)-th salt component - } - return code; - } - catch (NullPointerException npex) { - throw new NullKeyFieldException(this.keyFields[i].toString()); - } - catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); + int code = 0; + for (; i < this.keyFields.length; i++) { + code *= TupleComparatorBase.HASH_SALT[i & 0x1F]; + code += this.comparators[i].hash(accessField(keyFields[i], value)); + } + return code; + } @Override public void setReference(T toCompare) { int i = 0; - try { - for (; i < this.keyFields.length; i++) { - this.comparators[i].setReference(this.keyFields[i].get(toCompare)); - } - } - catch (NullPointerException npex) { - throw new NullKeyFieldException(this.keyFields[i].toString()); - } - catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); + for (; i < this.keyFields.length; i++) { + this.comparators[i].setReference(accessField(keyFields[i], toCompare)); } } @Override public boolean equalToReference(T candidate) { int i = 0; - try { - for (; i < this.keyFields.length; i++) { - if (!this.comparators[i].equalToReference(this.keyFields[i].get(candidate))) { - return false; - } + for (; i < this.keyFields.length; i++) { + if (!this.comparators[i].equalToReference(accessField(keyFields[i], candidate))) { + return false; } - return true; - } - catch (NullPointerException npex) { - throw new NullKeyFieldException(this.keyFields[i].toString()); - } - catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); } + return true; } @Override @@ -255,22 +258,17 @@ public int compareToReference(TypeComparator referencedComparator) { @Override public int compare(T first, T second) { int i = 0; - try { - for (; i < keyFields.length; i++) { - int cmp = comparators[i].compare(keyFields[i].get(first),keyFields[i].get(second)); - if (cmp != 0) { - return cmp; - } + for (; i < keyFields.length; i++) { + int cmp = comparators[i].compare(accessField(keyFields[i], first), accessField(keyFields[i], second)); + if (cmp != 0) { + return cmp; } - - return 0; - } catch (NullPointerException npex) { - throw new NullKeyFieldException(keyFields[i].toString() + " " + first.toString() + " " + second.toString()); - } catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); } + + return 0; } + @Override public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { T first = this.serializer.createInstance(); @@ -302,21 +300,13 @@ public boolean isNormalizedKeyPrefixOnly(int keyBytes) { @Override public void putNormalizedKey(T value, MemorySegment target, int offset, int numBytes) { int i = 0; - try { - for (; i < this.numLeadingNormalizableKeys & numBytes > 0; i++) - { - int len = this.normalizedKeyLengths[i]; - len = numBytes >= len ? 
len : numBytes; - this.comparators[i].putNormalizedKey(this.keyFields[i].get(value), target, offset, len); - numBytes -= len; - offset += len; - } - } - catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); - } - catch (NullPointerException npex) { - throw new NullKeyFieldException(this.keyFields[i].toString()); + for (; i < this.numLeadingNormalizableKeys & numBytes > 0; i++) + { + int len = this.normalizedKeyLengths[i]; + len = numBytes >= len ? len : numBytes; + this.comparators[i].putNormalizedKey(accessField(keyFields[i], value), target, offset, len); + numBytes -= len; + offset += len; } } @@ -347,36 +337,21 @@ public PojoComparator duplicate() { } @Override - public Object[] extractKeys(T record) { - int i = 0; - try { - for (; i < keyFields.length; i++) { - extractedKeys[i] = (Comparable) keyFields[i].get(record); + public int extractKeys(Object record, Object[] target, int index) { + int localIndex = index; + for (int i = 0; i < comparators.length; i++) { + if(comparators[i] instanceof PojoComparator || comparators[i] instanceof TupleComparator) { + localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1; + } else { + // non-composite case (= atomic). We can assume this to have only one key. + // comparators[i].extractKeys(accessField(keyFields[i], record), target, i); + target[localIndex] = accessField(keyFields[i], record); } + localIndex++; } - catch (IllegalAccessException iaex) { - throw new RuntimeException("This should not happen since we call setAccesssible(true) in PojoTypeInfo."); - } - catch (NullPointerException npex) { - throw new NullKeyFieldException(this.keyFields[i].toString()); - } - return extractedKeys; + return localIndex - index; } // -------------------------------------------------------------------------------------------- - - /** - * A sequence of prime numbers to be used for salting the computed hash values. - * Based on some empirical evidence, we are using a 32-element subsequence of the - * OEIS sequence #A068652 (numbers such that every cyclic permutation is a prime). 
- * - * @see: http://en.wikipedia.org/wiki/List_of_prime_numbers - * @see: http://oeis.org/A068652 - */ - private static final int[] HASH_SALT = new int[] { - 73 , 79 , 97 , 113 , 131 , 197 , 199 , 311 , - 337 , 373 , 719 , 733 , 919 , 971 , 991 , 1193 , - 1931 , 3119 , 3779 , 7793 , 7937 , 9311 , 9377 , 11939 , - 19391, 19937, 37199, 39119, 71993, 91193, 93719, 93911 }; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java index 71f2cd8af8fef..99b9f6551271a 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java @@ -174,11 +174,22 @@ public int getLength() { @Override public void serialize(T value, DataOutputView target) throws IOException { + // handle null values + if (value == null) { + target.writeBoolean(true); + return; + } else { + target.writeBoolean(false); + } try { - for (int i = 0; i < numFields; i++) { Object o = fields[i].get(value); - fieldSerializers[i].serialize(o, target); + if(o == null) { + target.writeBoolean(true); // null field handling + } else { + target.writeBoolean(false); + fieldSerializers[i].serialize(o, target); + } } } catch (IllegalAccessException e) { throw new RuntimeException("Error during POJO copy, this should not happen since we check the fields" + @@ -188,6 +199,10 @@ public void serialize(T value, DataOutputView target) throws IOException { @Override public T deserialize(DataInputView source) throws IOException { + boolean isNull = source.readBoolean(); + if(isNull) { + return null; + } T target; try { target = clazz.newInstance(); @@ -198,8 +213,13 @@ public T deserialize(DataInputView source) throws IOException { try { for (int i = 0; i < numFields; i++) { - Object field = fieldSerializers[i].deserialize(source); - fields[i].set(target, field); + isNull = source.readBoolean(); + if(isNull) { + fields[i].set(target, null); + } else { + Object field = fieldSerializers[i].deserialize(source); + fields[i].set(target, field); + } } } catch (IllegalAccessException e) { throw new RuntimeException("Error during POJO copy, this should not happen since we check the fields" + @@ -210,10 +230,20 @@ public T deserialize(DataInputView source) throws IOException { @Override public T deserialize(T reuse, DataInputView source) throws IOException { + // handle null values + boolean isNull = source.readBoolean(); + if (isNull) { + return null; + } try { for (int i = 0; i < numFields; i++) { - Object field = fieldSerializers[i].deserialize(fields[i].get(reuse), source); - fields[i].set(reuse, field); + isNull = source.readBoolean(); + if(isNull) { + fields[i].set(reuse, null); + } else { + Object field = fieldSerializers[i].deserialize(fields[i].get(reuse), source); + fields[i].set(reuse, field); + } } } catch (IllegalAccessException e) { throw new RuntimeException("Error during POJO copy, this should not happen since we check the fields" + @@ -224,7 +254,10 @@ public T deserialize(T reuse, DataInputView source) throws IOException { @Override public void copy(DataInputView source, DataOutputView target) throws IOException { + // copy the Non-Null/Null tag of the record; a null record carries no field data + boolean isNull = source.readBoolean(); + target.writeBoolean(isNull); + if (isNull) { + return; + } for (int i = 0; i < numFields; i++) { + // copy the per-field Non-Null/Null tag; only non-null fields carry data + isNull = source.readBoolean(); + target.writeBoolean(isNull); + if (!isNull) { fieldSerializers[i].copy(source, target); + } } } diff --git
a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java index eee6643ab5ea7..31e28f7a5d7ae 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java @@ -28,7 +28,6 @@ public final class RuntimePairComparatorFactory private static final long serialVersionUID = 1L; - @SuppressWarnings("unchecked") @Override public TypePairComparator createComparator12( TypeComparator comparator1, @@ -36,7 +35,6 @@ public TypePairComparator createComparator12( return new GenericPairComparator(comparator1, comparator2); } - @SuppressWarnings("unchecked") @Override public TypePairComparator createComparator21( TypeComparator comparator1, diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java index f9b40849870ef..61a1567c8a66f 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java @@ -30,26 +30,20 @@ public final class TupleComparator extends TupleComparatorBase { private static final long serialVersionUID = 1L; - - private final Object[] extractedKeys; - - @SuppressWarnings("unchecked") + public TupleComparator(int[] keyPositions, TypeComparator[] comparators, TypeSerializer[] serializers) { super(keyPositions, comparators, serializers); - extractedKeys = new Object[keyPositions.length]; } - @SuppressWarnings("unchecked") private TupleComparator(TupleComparator toClone) { super(toClone); - extractedKeys = new Object[keyPositions.length]; - } // -------------------------------------------------------------------------------------------- // Comparator Methods // -------------------------------------------------------------------------------------------- + @SuppressWarnings("unchecked") @Override public int hash(T value) { int i = 0; @@ -70,6 +64,7 @@ public int hash(T value) { } } + @SuppressWarnings("unchecked") @Override public void setReference(T toCompare) { int i = 0; @@ -86,6 +81,7 @@ public void setReference(T toCompare) { } } + @SuppressWarnings("unchecked") @Override public boolean equalToReference(T candidate) { int i = 0; @@ -105,6 +101,7 @@ public boolean equalToReference(T candidate) { } } + @SuppressWarnings("unchecked") @Override public int compare(T first, T second) { int i = 0; @@ -127,6 +124,7 @@ public int compare(T first, T second) { } } + @SuppressWarnings("unchecked") @Override public void putNormalizedKey(T value, MemorySegment target, int offset, int numBytes) { int i = 0; @@ -146,11 +144,19 @@ public void putNormalizedKey(T value, MemorySegment target, int offset, int numB } @Override - public Object[] extractKeys(T record) { - for (int i = 0; i < keyPositions.length; i++) { - extractedKeys[i] = record.getField(keyPositions[i]); + public int extractKeys(Object record, Object[] target, int index) { + int localIndex = index; + for(int i = 0; i < comparators.length; i++) { + // handle nested case + if(comparators[i] instanceof TupleComparator || comparators[i] instanceof PojoComparator) { + localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex) -1; 
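+ // the nested comparator wrote several flat keys starting at localIndex; subtract one
+ // here because localIndex is incremented once more at the end of each loop iteration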
+ } else { + // flat + target[localIndex] = ((Tuple) record).getField(keyPositions[i]); + } + localIndex++; } - return extractedKeys; + return localIndex - index; } public TypeComparator duplicate() { diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorBase.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorBase.java index 4c263181ff686..abcf89c2a0b57 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorBase.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparatorBase.java @@ -15,9 +15,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.flink.api.java.typeutils.runtime; +import java.io.IOException; +import java.util.List; + +import org.apache.flink.api.common.typeutils.CompositeTypeComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerFactory; @@ -26,15 +29,16 @@ import org.apache.flink.types.KeyFieldOutOfBoundsException; import org.apache.flink.types.NullKeyFieldException; -import java.io.IOException; +public abstract class TupleComparatorBase extends CompositeTypeComparator implements java.io.Serializable { -public abstract class TupleComparatorBase extends TypeComparator implements java.io.Serializable { + private static final long serialVersionUID = 1L; /** key positions describe which fields are keys in what order */ protected int[] keyPositions; /** comparators for the key fields, in the same order as the key fields */ + @SuppressWarnings("rawtypes") protected TypeComparator[] comparators; /** serializer factories to duplicate non thread-safe serializers */ @@ -51,6 +55,7 @@ public abstract class TupleComparatorBase extends TypeComparator implement /** serializers to deserialize the first n fields for comparison */ + @SuppressWarnings("rawtypes") protected transient TypeSerializer[] serializers; // cache for the deserialized field objects @@ -115,7 +120,6 @@ else if (k.invertNormalizedKey() != inverted) { this.invertNormKey = inverted; } - @SuppressWarnings("unchecked") protected TupleComparatorBase(TupleComparatorBase toClone) { privateDuplicate(toClone); } @@ -146,14 +150,23 @@ protected int[] getKeyPositions() { return this.keyPositions; } - public TypeComparator[] getComparators() { - return this.comparators; - } + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Override + public void getFlatComparator(List flatComparators) { + for(int i = 0; i < comparators.length; i++) { + if(comparators[i] instanceof CompositeTypeComparator) { + ((CompositeTypeComparator)comparators[i]).getFlatComparator(flatComparators); + } else { + flatComparators.add(comparators[i]); + } + } + } // -------------------------------------------------------------------------------------------- // Comparator Methods // -------------------------------------------------------------------------------------------- + @Override public int compareToReference(TypeComparator referencedComparator) { TupleComparatorBase other = (TupleComparatorBase) referencedComparator; @@ -161,6 +174,7 @@ public int compareToReference(TypeComparator referencedComparator) { int i = 0; try { for (; i < this.keyPositions.length; i++) { + @SuppressWarnings("unchecked") int cmp = this.comparators[i].compareToReference(other.comparators[i]); if (cmp != 0) { 
return cmp; @@ -176,6 +190,7 @@ public int compareToReference(TypeComparator referencedComparator) { } } + @SuppressWarnings("unchecked") @Override public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { if (deserializedFields1 == null) { @@ -201,7 +216,7 @@ public int compareSerialized(DataInputView firstSource, DataInputView secondSour } catch (NullPointerException npex) { throw new NullKeyFieldException(keyPositions[i]); } catch (IndexOutOfBoundsException iobex) { - throw new KeyFieldOutOfBoundsException(keyPositions[i]); + throw new KeyFieldOutOfBoundsException(keyPositions[i], iobex); } } @@ -245,7 +260,6 @@ public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOExce // -------------------------------------------------------------------------------------------- - @SuppressWarnings("unchecked") protected final void instantiateDeserializationUtils() { if (this.serializers == null) { this.serializers = new TypeSerializer[this.serializerFactories.length]; diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/ValueComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/ValueComparator.java index a51086374b0bf..eca1e6c0b59e0 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/ValueComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/ValueComparator.java @@ -47,8 +47,7 @@ public class ValueComparator> extends TypeCompar private transient Kryo kryo; - private final Comparable[] extractedKey = new Comparable[1]; - + @SuppressWarnings("rawtypes") private final TypeComparator[] comparators = new TypeComparator[] {this}; public ValueComparator(boolean ascending, Class type) { @@ -145,13 +144,14 @@ private final void checkKryoInitialized() { } @Override - public Object[] extractKeys(T record) { - extractedKey[0] = record; - return extractedKey; + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; } + @SuppressWarnings("rawtypes") @Override - public TypeComparator[] getComparators() { + public TypeComparator[] getFlatComparators() { return comparators; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/WritableComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/WritableComparator.java index a8c4ef52b42e0..88985bb10859e 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/WritableComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/WritableComparator.java @@ -44,8 +44,7 @@ public class WritableComparator> extends Type private transient Kryo kryo; - private final Comparable[] extractedKey = new Comparable[1]; - + @SuppressWarnings("rawtypes") private final TypeComparator[] comparators = new TypeComparator[] {this}; public WritableComparator(boolean ascending, Class type) { @@ -129,12 +128,13 @@ public TypeComparator duplicate() { } @Override - public Object[] extractKeys(T record) { - extractedKey[0] = record; - return extractedKey; + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; } - @Override public TypeComparator[] getComparators() { + @SuppressWarnings("rawtypes") + @Override public TypeComparator[] getFlatComparators() { return comparators; } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java 
b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java index 898eae503a2ec..d686633fc6431 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java @@ -18,7 +18,6 @@ package org.apache.flink.api.java.operator; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; @@ -29,10 +28,10 @@ import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType; @SuppressWarnings("serial") public class CoGroupOperatorTest { @@ -127,7 +126,6 @@ public void testCoGroupKeyFields6() { ds1.coGroup(ds2).where(5).equalTo(0); } - @Ignore @Test public void testCoGroupKeyExpressions1() { @@ -137,13 +135,12 @@ public void testCoGroupKeyExpressions1() { // should work try { -// ds1.coGroup(ds2).where("myInt").equalTo("myInt"); + ds1.coGroup(ds2).where("myInt").equalTo("myInt"); } catch(Exception e) { Assert.fail(); } } - @Ignore @Test(expected = InvalidProgramException.class) public void testCoGroupKeyExpressions2() { @@ -152,10 +149,9 @@ public void testCoGroupKeyExpressions2() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, incompatible cogroup key types -// ds1.coGroup(ds2).where("myInt").equalTo("myString"); + ds1.coGroup(ds2).where("myInt").equalTo("myString"); } - @Ignore @Test(expected = InvalidProgramException.class) public void testCoGroupKeyExpressions3() { @@ -164,10 +160,9 @@ public void testCoGroupKeyExpressions3() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, incompatible number of cogroup keys -// ds1.coGroup(ds2).where("myInt", "myString").equalTo("myString"); + ds1.coGroup(ds2).where("myInt", "myString").equalTo("myString"); } - @Ignore @Test(expected = IllegalArgumentException.class) public void testCoGroupKeyExpressions4() { @@ -176,9 +171,58 @@ public void testCoGroupKeyExpressions4() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, cogroup key non-existent -// ds1.coGroup(ds2).where("myNonExistent").equalTo("myInt"); + ds1.coGroup(ds2).where("myNonExistent").equalTo("myInt"); + } + + @Test + public void testCoGroupKeyExpressions1Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should work + try { + ds1.coGroup(ds2).where("nested.myInt").equalTo("nested.myInt"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } } + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyExpressions2Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should not work, incompatible cogroup key types + ds1.coGroup(ds2).where("nested.myInt").equalTo("nested.myString"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyExpressions3Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = 
env.fromCollection(customTypeData); + + // should not work, incompatible number of cogroup keys + ds1.coGroup(ds2).where("nested.myInt", "nested.myString").equalTo("nested.myString"); + } + + @Test(expected = IllegalArgumentException.class) + public void testCoGroupKeyExpressions4Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should not work, cogroup key non-existent + ds1.coGroup(ds2).where("nested.myNonExistent").equalTo("nested.myInt"); + } + @Test public void testCoGroupKeySelectors1() { @@ -304,26 +348,4 @@ public Long getKey(CustomType value) { } ); } - - public static class CustomType implements Serializable { - - private static final long serialVersionUID = 1L; - - public int myInt; - public long myLong; - public String myString; - - public CustomType() {}; - - public CustomType(int i, long l, String s) { - myInt = i; - myLong = l; - myString = s; - } - - @Override - public String toString() { - return myInt+","+myLong+","+myString; - } - } } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java index adc8917e8359e..c9586802776b4 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java @@ -32,7 +32,6 @@ import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; public class GroupingTest { @@ -112,7 +111,6 @@ public void testGroupByKeyFields5() { tupleDs.groupBy(-1); } - @Ignore @Test public void testGroupByKeyExpressions1() { @@ -124,24 +122,22 @@ public void testGroupByKeyExpressions1() { // should work try { -// ds.groupBy("myInt"); + ds.groupBy("myInt"); } catch(Exception e) { Assert.fail(); } } - @Ignore - @Test(expected = UnsupportedOperationException.class) + @Test(expected = IllegalArgumentException.class) public void testGroupByKeyExpressions2() { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet longDs = env.fromCollection(emptyLongData, BasicTypeInfo.LONG_TYPE_INFO); // should not work: groups on basic type -// longDs.groupBy("myInt"); + longDs.groupBy("myInt"); } - @Ignore @Test(expected = InvalidProgramException.class) public void testGroupByKeyExpressions3() { @@ -150,12 +146,11 @@ public void testGroupByKeyExpressions3() { this.customTypeData.add(new CustomType()); DataSet customDs = env.fromCollection(customTypeData); - // should not work: groups on custom type + // should not work: tuple selector on custom type customDs.groupBy(0); } - @Ignore @Test(expected = IllegalArgumentException.class) public void testGroupByKeyExpressions4() { @@ -163,7 +158,34 @@ public void testGroupByKeyExpressions4() { DataSet ds = env.fromCollection(customTypeData); // should not work, key out of tuple bounds -// ds.groupBy("myNonExistent"); + ds.groupBy("myNonExistent"); + } + + @Test + public void testGroupByKeyExpressions1Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + this.customTypeData.add(new CustomType()); + + DataSet ds = env.fromCollection(customTypeData); + + // should work + try { + ds.groupBy("nested.myInt"); + } catch(Exception e) { + Assert.fail(); + } + } + + @Test(expected = 
IllegalArgumentException.class) + public void testGroupByKeyExpressions2Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds = env.fromCollection(customTypeData); + + // should not work, key out of tuple bounds + ds.groupBy("nested.myNonExistent"); } @@ -309,11 +331,15 @@ public void testChainedGroupSortKeyFields() { public static class CustomType implements Serializable { + public static class Nest { + public int myInt; + } private static final long serialVersionUID = 1L; public int myInt; public long myLong; public String myString; + public Nest nested; public CustomType() {}; diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java index 962755e2e7a42..de50fd8371083 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java @@ -22,17 +22,20 @@ import java.util.ArrayList; import java.util.List; -import org.junit.Assert; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; @SuppressWarnings("serial") public class JoinOperatorTest { @@ -41,20 +44,51 @@ public class JoinOperatorTest { private static final List> emptyTupleData = new ArrayList>(); - private final TupleTypeInfo> tupleTypeInfo = new - TupleTypeInfo>( + private final TupleTypeInfo> tupleTypeInfo = + new TupleTypeInfo>( BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO ); + // TUPLE DATA with nested Tuple2 + private static final List, Long, String, Long, Integer>> emptyNestedTupleData = + new ArrayList, Long, String, Long, Integer>>(); + private final TupleTypeInfo, Long, String, Long, Integer>> nestedTupleTypeInfo = + new TupleTypeInfo, Long, String, Long, Integer>>( + new TupleTypeInfo> (BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO + ); + + // TUPLE DATA with nested CustomType + private static final List> emptyNestedCustomTupleData = + new ArrayList>(); + + private final TupleTypeInfo> nestedCustomTupleTypeInfo = + new TupleTypeInfo>( + TypeExtractor.getForClass(CustomType.class), + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO + ); + + private static List customTypeWithTupleData = new ArrayList(); private static List customTypeData = new ArrayList(); + private static List customNestedTypeData = new ArrayList(); + + @BeforeClass public static void insertCustomData() { 
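+ // a single instance per collection suffices: fromCollection() uses the elements
+ // only to extract the type information exercised by these tests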
customTypeData.add(new CustomType()); + customTypeWithTupleData.add(new CustomTypeWithTuple()); + customNestedTypeData.add(new NestedCustomType()); } @Test @@ -127,7 +161,6 @@ public void testJoinKeyFields6() { ds1.join(ds2).where(5).equalTo(0); } - @Ignore @Test public void testJoinKeyExpressions1() { @@ -137,13 +170,29 @@ public void testJoinKeyExpressions1() { // should work try { -// ds1.join(ds2).where("myInt").equalTo("myInt"); + ds1.join(ds2).where("myInt").equalTo("myInt"); + } catch(Exception e) { + Assert.fail(); + } + } + + @Test + public void testJoinKeyExpressionsNested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customNestedTypeData); + DataSet ds2 = env.fromCollection(customNestedTypeData); + + // should work + try { + ds1.join(ds2).where("myInt").equalTo("myInt"); } catch(Exception e) { Assert.fail(); } } + + - @Ignore @Test(expected = InvalidProgramException.class) public void testJoinKeyExpressions2() { @@ -152,10 +201,9 @@ public void testJoinKeyExpressions2() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, incompatible join key types -// ds1.join(ds2).where("myInt").equalTo("myString"); + ds1.join(ds2).where("myInt").equalTo("myString"); } - @Ignore @Test(expected = InvalidProgramException.class) public void testJoinKeyExpressions3() { @@ -164,10 +212,9 @@ public void testJoinKeyExpressions3() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, incompatible number of join keys -// ds1.join(ds2).where("myInt", "myString").equalTo("myString"); + ds1.join(ds2).where("myInt", "myString").equalTo("myString"); } - @Ignore @Test(expected = IllegalArgumentException.class) public void testJoinKeyExpressions4() { @@ -176,7 +223,230 @@ public void testJoinKeyExpressions4() { DataSet ds2 = env.fromCollection(customTypeData); // should not work, join key non-existent -// ds1.join(ds2).where("myNonExistent").equalTo("myInt"); + ds1.join(ds2).where("myNonExistent").equalTo("myInt"); + } + + /** + * Test if mixed types of key selectors are properly working. 
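+ * For example, {@code where("myInt")} on the first input can be combined with a
+ * {@code KeySelector} returning {@code myInt} on the second input, as exercised below.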
+ */ + @Test + public void testJoinKeyMixedKeySelector() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + try { + ds1.join(ds2).where("myInt").equalTo(new KeySelector() { + @Override + public Integer getKey(CustomType value) throws Exception { + return value.myInt; + } + }); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyMixedKeySelectorTurned() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + try { + ds1.join(ds2).where(new KeySelector() { + @Override + public Integer getKey(CustomType value) throws Exception { + return value.myInt; + } + }).equalTo("myInt"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyMixedTupleIndex() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + try { + ds1.join(ds2).where("f0").equalTo(4); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyNestedTuples() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet, Long, String, Long, Integer>> ds1 = env.fromCollection(emptyNestedTupleData, nestedTupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + try { + ds1.join(ds2).where("f0.f0").equalTo(4); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyNestedTuplesWithCustom() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyNestedCustomTupleData, nestedCustomTupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + try { + TypeInformation t = ds1.join(ds2).where("f0.myInt").equalTo(4).getType(); + Assert.assertTrue("not a composite type", t instanceof CompositeType); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyWithCustomContainingTuple0() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeWithTupleData); + DataSet ds2 = env.fromCollection(customTypeWithTupleData); + try { + ds1.join(ds2).where("intByString.f0").equalTo("myInt"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyWithCustomContainingTuple1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeWithTupleData); + DataSet ds2 = env.fromCollection(customTypeWithTupleData); + try { + ds1.join(ds2).where("nested.myInt").equalTo("intByString.f0"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test + public void testJoinKeyWithCustomContainingTuple2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeWithTupleData); + DataSet ds2 = env.fromCollection(customTypeWithTupleData); + try { + ds1.join(ds2).where("nested.myInt", "myInt", 
"intByString.f1").equalTo("intByString.f0","myInt", "myString"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyNestedTuplesWrongType() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet, Long, String, Long, Integer>> ds1 = env.fromCollection(emptyNestedTupleData, nestedTupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + ds1.join(ds2).where("f0.f1").equalTo(4); // f0.f1 is a String + } + + @Test + public void testJoinKeyMixedTupleIndexTurned() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + try { + ds1.join(ds2).where(0).equalTo("f0"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyMixedTupleIndexWrongType() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + ds1.join(ds2).where("f0").equalTo(3); // 3 is of type long, so it should fail + } + + @Test + public void testJoinKeyMixedTupleIndex2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + try { + ds1.join(ds2).where("myInt").equalTo(4); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyMixedWrong() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + // wrongly mix String and Integer + ds1.join(ds2).where("myString").equalTo(new KeySelector() { + @Override + public Integer getKey(CustomType value) throws Exception { + return value.myInt; + } + }); + } + + @Test + public void testJoinKeyExpressions1Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should work + try { + ds1.join(ds2).where("nested.myInt").equalTo("nested.myInt"); + } catch(Exception e) { + e.printStackTrace(); + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyExpressions2Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should not work, incompatible join key types + ds1.join(ds2).where("nested.myInt").equalTo("nested.myString"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyExpressions3Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should not work, incompatible number of join keys + ds1.join(ds2).where("nested.myInt", "nested.myString").equalTo("nested.myString"); + } + + @Test(expected = 
IllegalArgumentException.class) + public void testJoinKeyExpressions4Nested() { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromCollection(customTypeData); + + // should not work, join key non-existent + ds1.join(ds2).where("nested.myNonExistent").equalTo("nested.myInt"); } @@ -235,6 +505,7 @@ public Long getKey(CustomType value) { ) .equalTo(3); } catch(Exception e) { + e.printStackTrace(); Assert.fail(); } } @@ -403,7 +674,6 @@ public void testJoinProjection6() { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet ds1 = env.fromCollection(customTypeData); DataSet ds2 = env.fromCollection(customTypeData); - // should work try { ds1.join(ds2) @@ -430,6 +700,7 @@ public Long getKey(CustomType value) { .types(CustomType.class, CustomType.class); } catch(Exception e) { System.out.println("FAILED: " + e); + e.printStackTrace(); Assert.fail(); } } @@ -550,13 +821,58 @@ public void testJoinProjection14() { * #################################################################### */ + public static class Nested implements Serializable { + + private static final long serialVersionUID = 1L; + + public int myInt; + + public Nested() {}; + + public Nested(int i, long l, String s) { + myInt = i; + } + + @Override + public String toString() { + return ""+myInt; + } + } + // a simple nested type (only basic types) + public static class NestedCustomType implements Serializable { + + private static final long serialVersionUID = 1L; + + public int myInt; + public long myLong; + public String myString; + public Nested nest; + + public NestedCustomType() {}; + + public NestedCustomType(int i, long l, String s) { + myInt = i; + myLong = l; + myString = s; + } + + @Override + public String toString() { + return myInt+","+myLong+","+myString+","+nest; + } + } + public static class CustomType implements Serializable { private static final long serialVersionUID = 1L; public int myInt; public long myLong; + public NestedCustomType nested; public String myString; + public Object nothing; + // public List countries; need Kryo to support this + // public Writable interfaceTest; need kryo public CustomType() {}; @@ -564,6 +880,7 @@ public CustomType(int i, long l, String s) { myInt = i; myLong = l; myString = s; + nested = new NestedCustomType(i, l, s); } @Override @@ -571,4 +888,33 @@ public String toString() { return myInt+","+myLong+","+myString; } } + + + public static class CustomTypeWithTuple implements Serializable { + + private static final long serialVersionUID = 1L; + + public int myInt; + public long myLong; + public NestedCustomType nested; + public String myString; + public Tuple2 intByString; + + public CustomTypeWithTuple() {}; + + public CustomTypeWithTuple(int i, long l, String s) { + myInt = i; + myLong = l; + myString = s; + nested = new NestedCustomType(i, l, s); + intByString = new Tuple2(i, s); + } + + @Override + public String toString() { + return myInt+","+myLong+","+myString; + } + } + + } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operators/KeysTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operators/KeysTest.java new file mode 100644 index 0000000000000..00ab5206eb630 --- /dev/null +++ b/flink-java/src/test/java/org/apache/flink/api/java/operators/KeysTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.java.operators; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.operators.Keys.ExpressionKeys; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple7; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +@RunWith(PowerMockRunner.class) +public class KeysTest { + + @Test + public void testTupleRangeCheck() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { + + // test private static final int[] rangeCheckFields(int[] fields, int maxAllowedField) + Method rangeCheckFieldsMethod = Whitebox.getMethod(Keys.class, "rangeCheckFields", int[].class, int.class); + int[] result = (int[]) rangeCheckFieldsMethod.invoke(null, new int[] {1,2,3,4}, 4); + Assert.assertArrayEquals(new int[] {1,2,3,4}, result); + + // test duplicate elimination + result = (int[]) rangeCheckFieldsMethod.invoke(null, new int[] {1,2,2,3,4}, 4); + Assert.assertArrayEquals(new int[] {1,2,3,4}, result); + + result = (int[]) rangeCheckFieldsMethod.invoke(null, new int[] {1,2,2,2,2,2,2,3,3,4}, 4); + Assert.assertArrayEquals(new int[] {1,2,3,4}, result); + + // corner case tests + result = (int[]) rangeCheckFieldsMethod.invoke(null, new int[] {0}, 0); + Assert.assertArrayEquals(new int[] {0}, result); + + Throwable ex = null; + try { + // throws illegal argument. + result = (int[]) rangeCheckFieldsMethod.invoke(null, new int[] {5}, 0); + } catch(Throwable iae) { + ex = iae; + } + Assert.assertNotNull(ex); + } + + @Test + public void testStandardTupleKeys() { + TupleTypeInfo> typeInfo = new TupleTypeInfo>( + BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO); + + ExpressionKeys> ek; + + for( int i = 1; i < 8; i++) { + int[] ints = new int[i]; + for( int j = 0; j < i; j++) { + ints[j] = j; + } + int[] inInts = Arrays.copyOf(ints, ints.length); // copy, just to make sure that the code is not cheating by changing the ints. 
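+			// A minimal illustrative sketch (not part of this patch): for a flat
+			// Tuple7, ExpressionKeys maps field positions to logical key positions
+			// one-to-one and keeps the order the caller gave, so e.g.
+			//   new ExpressionKeys<Tuple7<...>>(new int[] {2, 0}, typeInfo)
+			//       .computeLogicalKeyPositions()   // -> {2, 0}
+			// which is what the identity and reversed arrays below verify.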
+ ek = new ExpressionKeys>(inInts, typeInfo); + Assert.assertArrayEquals(ints, ek.computeLogicalKeyPositions()); + Assert.assertEquals(ints.length, ek.computeLogicalKeyPositions().length); + + ArrayUtils.reverse(ints); + inInts = Arrays.copyOf(ints, ints.length); + ek = new ExpressionKeys>(inInts, typeInfo); + Assert.assertArrayEquals(ints, ek.computeLogicalKeyPositions()); + Assert.assertEquals(ints.length, ek.computeLogicalKeyPositions().length); + } + } + + @Test + public void testInvalid() throws Throwable { + TupleTypeInfo, String>> typeInfo = new TupleTypeInfo,String>>( + BasicTypeInfo.STRING_TYPE_INFO, + new TupleTypeInfo>(BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO), + BasicTypeInfo.STRING_TYPE_INFO); + ExpressionKeys, String>> fpk; + + String[][] tests = new String[][] { + new String[] {"f11"},new String[] {"f-35"}, new String[] {"f0.f33"}, new String[] {"f1.f33"} + }; + for(int i = 0; i < tests.length; i++) { + Throwable e = null; + try { + fpk = new ExpressionKeys, String>>(tests[i], typeInfo); + } catch(Throwable t) { + // System.err.println("Message: "+t.getMessage()); t.printStackTrace(); + e = t; + } + Assert.assertNotNull(e); + } + } + + @Test + public void testTupleKeyExpansion() { + TupleTypeInfo, String>> typeInfo = new TupleTypeInfo,String>>( + BasicTypeInfo.STRING_TYPE_INFO, + new TupleTypeInfo>(BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO), + BasicTypeInfo.STRING_TYPE_INFO); + ExpressionKeys, String>> fpk = + new ExpressionKeys, String>>(new int[] {0}, typeInfo); + Assert.assertArrayEquals(new int[] {0}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new int[] {1}, typeInfo); + Assert.assertArrayEquals(new int[] {1,2,3}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new int[] {2}, typeInfo); + Assert.assertArrayEquals(new int[] {4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new int[] {0,1,2}, typeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(null, typeInfo, true); // empty case + Assert.assertArrayEquals(new int[] {0,1,2,3,4}, fpk.computeLogicalKeyPositions()); + + // duplicate elimination + fpk = new ExpressionKeys, String>>(new int[] {0,1,1,1,2}, typeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new String[] {"*"}, typeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4}, fpk.computeLogicalKeyPositions()); + + // this was a bug: + fpk = new ExpressionKeys, String>>(new String[] {"f2"}, typeInfo); + Assert.assertArrayEquals(new int[] {4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new String[] {"f0","f1.f0","f1.f1", "f1.f2", "f2"}, typeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new String[] {"f0","f1.f0","f1.f1", "f2"}, typeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,4}, fpk.computeLogicalKeyPositions()); + + fpk = new ExpressionKeys, String>>(new String[] {"f2", "f0"}, typeInfo); + Assert.assertArrayEquals(new int[] {4,0}, fpk.computeLogicalKeyPositions()); + + // duplicate elimination + fpk = new ExpressionKeys, String>>(new String[] {"f2","f2","f2", "f0"}, typeInfo); + Assert.assertArrayEquals(new int[] {4,0}, fpk.computeLogicalKeyPositions()); + + + TupleTypeInfo, String, String>, String>> 
complexTypeInfo = new TupleTypeInfo,String,String>,String>>( + BasicTypeInfo.STRING_TYPE_INFO, + new TupleTypeInfo, String, String>>(new TupleTypeInfo>(BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO),BasicTypeInfo.STRING_TYPE_INFO,BasicTypeInfo.STRING_TYPE_INFO), + BasicTypeInfo.STRING_TYPE_INFO); + + ExpressionKeys, String, String>, String>> complexFpk = + new ExpressionKeys, String, String>, String>>(new int[] {0}, complexTypeInfo); + Assert.assertArrayEquals(new int[] {0}, complexFpk.computeLogicalKeyPositions()); + + complexFpk = new ExpressionKeys, String, String>, String>>(new int[] {0,1,2}, complexTypeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4,5,6}, complexFpk.computeLogicalKeyPositions()); + + complexFpk = new ExpressionKeys, String, String>, String>>(new String[] {"*"}, complexTypeInfo); + Assert.assertArrayEquals(new int[] {0,1,2,3,4,5,6}, complexFpk.computeLogicalKeyPositions()); + + complexFpk = new ExpressionKeys, String, String>, String>>(new String[] {"f1.f0.*"}, complexTypeInfo); + Assert.assertArrayEquals(new int[] {1,2,3}, complexFpk.computeLogicalKeyPositions()); + + complexFpk = new ExpressionKeys, String, String>, String>>(new String[] {"f2"}, complexTypeInfo); + Assert.assertArrayEquals(new int[] {6}, complexFpk.computeLogicalKeyPositions()); + } +} diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeInformationTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeInformationTest.java index 62f1fad28f30e..c1cb249c8470d 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeInformationTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/PojoTypeInformationTest.java @@ -22,82 +22,81 @@ import static org.junit.Assert.assertTrue; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.junit.Ignore; +import org.junit.Assert; import org.junit.Test; @SuppressWarnings("unused") public class PojoTypeInformationTest { - static class SimplePojo { - String str; - Boolean Bl; - boolean bl; - Byte Bt; - byte bt; - Short Shrt; - short shrt; - Integer Intgr; - int intgr; - Long Lng; - long lng; - Float Flt; - float flt; - Double Dbl; - double dbl; - Character Ch; - char ch; - int[] primIntArray; - Integer[] intWrapperArray; + public static class SimplePojo { + public String str; + public Boolean Bl; + public boolean bl; + public Byte Bt; + public byte bt; + public Short Shrt; + public short shrt; + public Integer Intgr; + public int intgr; + public Long Lng; + public long lng; + public Float Flt; + public float flt; + public Double Dbl; + public double dbl; + public Character Ch; + public char ch; + public int[] primIntArray; + public Integer[] intWrapperArray; } - @Ignore @Test public void testSimplePojoTypeExtraction() { TypeInformation type = TypeExtractor.getForClass(SimplePojo.class); - assertTrue("Extracted type is not a Pojo type but should be.", type instanceof PojoTypeInfo); + assertTrue("Extracted type is not a composite/pojo type but should be.", type instanceof CompositeType); } - static class NestedPojoInner { - private String field; + public static class NestedPojoInner { + public String field; } - static class NestedPojoOuter { - private Integer 
intField; - NestedPojoInner inner; + public static class NestedPojoOuter { + public Integer intField; + public NestedPojoInner inner; } - @Ignore @Test public void testNestedPojoTypeExtraction() { TypeInformation type = TypeExtractor.getForClass(NestedPojoOuter.class); - assertTrue("Extracted type is not a Pojo type but should be.", type instanceof PojoTypeInfo); + assertTrue("Extracted type is not a Pojo type but should be.", type instanceof CompositeType); } - static class Recursive1Pojo { - private Integer intField; - Recursive2Pojo rec; + public static class Recursive1Pojo { + public Integer intField; + public Recursive2Pojo rec; } - static class Recursive2Pojo { - private String strField; - Recursive1Pojo rec; + public static class Recursive2Pojo { + public String strField; + public Recursive1Pojo rec; } - @Ignore @Test public void testRecursivePojoTypeExtraction() { // This one tests whether a recursive pojo is detected using the set of visited // types in the type extractor. The recursive field will be handled using the generic serializer. TypeInformation type = TypeExtractor.getForClass(Recursive1Pojo.class); - assertTrue("Extracted type is not a Pojo type but should be.", type instanceof PojoTypeInfo); + assertTrue("Extracted type is not a Pojo type but should be.", type instanceof CompositeType); } - - @Ignore + @Test public void testRecursivePojoObjectTypeExtraction() { TypeInformation type = TypeExtractor.getForObject(new Recursive1Pojo()); - assertTrue("Extracted type is not a Pojo type but should be.", type instanceof PojoTypeInfo); + assertTrue("Extracted type is not a Pojo type but should be.", type instanceof CompositeType); } + } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java index eb23338a2e01f..60d41d38225cd 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/type/extractor/TypeExtractorTest.java @@ -21,21 +21,26 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Date; +import java.util.List; import org.apache.flink.api.common.functions.InvalidTypesException; import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.functions.RichCrossFunction; +import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; -import org.apache.flink.api.common.functions.RichFlatJoinFunction; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple; import 
org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; @@ -43,6 +48,8 @@ import org.apache.flink.api.java.tuple.Tuple9; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.api.java.typeutils.PojoField; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; @@ -57,10 +64,14 @@ import org.apache.flink.util.Collector; import org.apache.hadoop.io.Writable; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; +import com.google.common.collect.HashMultiset; + public class TypeExtractorTest { + @SuppressWarnings({ "rawtypes", "unchecked" }) @Test public void testBasicType() { @@ -140,7 +151,10 @@ public Tuple9 ffd = new ArrayList(); + ((TupleTypeInfo) ti).getKey("f3", 0, ffd); + Assert.assertTrue(ffd.size() == 1); + Assert.assertEquals(3, ffd.get(0).getPosition() ); TupleTypeInfo tti = (TupleTypeInfo) ti; Assert.assertEquals(Tuple9.class, tti.getTypeClass()); @@ -203,6 +217,20 @@ public void flatMap(Tuple3, Tuple1, Tuple2> Assert.assertTrue(ti.isTupleType()); Assert.assertEquals(3, ti.getArity()); Assert.assertTrue(ti instanceof TupleTypeInfo); + List ffd = new ArrayList(); + + ((TupleTypeInfo) ti).getKey("f0.f0", 0, ffd); + Assert.assertEquals(0, ffd.get(0).getPosition() ); + ffd.clear(); + + ((TupleTypeInfo) ti).getKey("f0.f0", 0, ffd); + Assert.assertTrue( ffd.get(0).getType() instanceof BasicTypeInfo ); + Assert.assertTrue( ffd.get(0).getType().getTypeClass().equals(String.class) ); + ffd.clear(); + + ((TupleTypeInfo) ti).getKey("f1.f0", 0, ffd); + Assert.assertEquals(1, ffd.get(0).getPosition() ); + ffd.clear(); TupleTypeInfo tti = (TupleTypeInfo) ti; Assert.assertEquals(Tuple3.class, tti.getTypeClass()); @@ -305,11 +333,11 @@ public CustomType cross(CustomType first, Integer second) throws Exception { Assert.assertFalse(ti.isBasicType()); Assert.assertFalse(ti.isTupleType()); - Assert.assertTrue(ti instanceof GenericTypeInfo); + Assert.assertTrue(ti instanceof PojoTypeInfo); Assert.assertEquals(ti.getTypeClass(), CustomType.class); // use getForClass() - Assert.assertTrue(TypeExtractor.getForClass(CustomType.class) instanceof GenericTypeInfo); + Assert.assertTrue(TypeExtractor.getForClass(CustomType.class) instanceof PojoTypeInfo); Assert.assertEquals(TypeExtractor.getForClass(CustomType.class).getTypeClass(), ti.getTypeClass()); // use getForObject() @@ -318,10 +346,454 @@ public CustomType cross(CustomType first, Integer second) throws Exception { Assert.assertFalse(ti2.isBasicType()); Assert.assertFalse(ti2.isTupleType()); - Assert.assertTrue(ti2 instanceof GenericTypeInfo); + Assert.assertTrue(ti2 instanceof PojoTypeInfo); Assert.assertEquals(ti2.getTypeClass(), CustomType.class); } + + // + // Pojo Type tests + // A Pojo is a bean-style class with getters, setters and empty ctor + // OR a class with all fields public (or for every private field, there has to be a public getter/setter) + // everything else is a generic type (that can't be used for field selection) + // + + + // test with correct pojo types + public static class WC { // is a pojo + public ComplexNestedClass complex; // is a pojo + private int count; // is a BasicType + public WC() { + } + public int getCount() { + return count; + } + public void setCount(int c) { + this.count = c; + } + } + 
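+	// A hedged sketch (illustration, not part of this patch) of how the rules
+	// above classify types: WC counts as a POJO because its private 'count'
+	// field has a public getter/setter pair, while a class without a no-argument
+	// constructor cannot be instantiated by the serializer and falls back to a
+	// generic type whose fields are not addressable in key expressions:
+	//
+	//   TypeInformation<?> wc = TypeExtractor.createTypeInfo(WC.class);
+	//   // wc instanceof PojoTypeInfo  -> keys like "complex.word.f0" work
+	//   TypeInformation<?> bad = TypeExtractor.createTypeInfo(WrongCtorPojo.class);
+	//   // bad instanceof GenericTypeInfo -> no field selection possible
+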
public static class ComplexNestedClass { // pojo + public static int ignoreStaticField; + public transient int ignoreTransientField; + public Date date; // generic type + public Integer someNumber; // BasicType + public float someFloat; // BasicType + public Tuple3 word; //Tuple Type with three basic types + public Object nothing; // generic type + public MyWritable hadoopCitizen; // writableType + } + + // all public test + public static class AllPublic extends ComplexNestedClass { + public ArrayList somethingFancy; // generic type + public HashMultiset fancyIds; // generic type + public String[] fancyArray; // generic type + } + + public static class ParentSettingGenerics extends PojoWithGenerics { + public String field3; + } + public static class PojoWithGenerics { + public int key; + public T1 field1; + public T2 field2; + } + + public static class ComplexHierarchyTop extends ComplexHierarchy> {} + public static class ComplexHierarchy extends PojoWithGenerics {} + + // extends from Tuple and adds a field + public static class FromTuple extends Tuple3 { + private static final long serialVersionUID = 1L; + public int special; + } + + public static class IncorrectPojo { + private int isPrivate; + public int getIsPrivate() { + return isPrivate; + } + // setter is missing (intentional) + } + + // correct pojo + public static class BeanStylePojo { + public String abc; + private int field; + public int getField() { + return this.field; + } + public void setField(int f) { + this.field = f; + } + } + public static class WrongCtorPojo { + public int a; + public WrongCtorPojo(int a) { + this.a = a; + } + } + + // in this test, the location of the getters and setters is mixed across the type hierarchy. + public static class TypedPojoGetterSetterCheck extends GenericPojoGetterSetterCheck { + public void setPackageProtected(String in) { + this.packageProtected = in; + } + } + public static class GenericPojoGetterSetterCheck { + T packageProtected; + public T getPackageProtected() { + return packageProtected; + } + } + + @Test + public void testIncorrectPojos() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(IncorrectPojo.class); + Assert.assertTrue(typeForClass instanceof GenericTypeInfo); + + typeForClass = TypeExtractor.createTypeInfo(WrongCtorPojo.class); + Assert.assertTrue(typeForClass instanceof GenericTypeInfo); + } + + @Test + public void testCorrectPojos() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(BeanStylePojo.class); + Assert.assertTrue(typeForClass instanceof PojoTypeInfo); + + typeForClass = TypeExtractor.createTypeInfo(TypedPojoGetterSetterCheck.class); + Assert.assertTrue(typeForClass instanceof PojoTypeInfo); + } + + @Test + public void testPojoWC() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(WC.class); + checkWCPojoAsserts(typeForClass); + + WC t = new WC(); + t.complex = new ComplexNestedClass(); + TypeInformation typeForObject = TypeExtractor.getForObject(t); + checkWCPojoAsserts(typeForObject); + } + + private void checkWCPojoAsserts(TypeInformation typeInfo) { + Assert.assertFalse(typeInfo.isBasicType()); + Assert.assertFalse(typeInfo.isTupleType()); + Assert.assertEquals(9, typeInfo.getTotalFields()); + Assert.assertTrue(typeInfo instanceof PojoTypeInfo); + PojoTypeInfo pojoType = (PojoTypeInfo) typeInfo; + + List ffd = new ArrayList(); + String[] fields = {"count","complex.date", "complex.hadoopCitizen", "complex.nothing", + "complex.someFloat", "complex.someNumber", "complex.word.f0", + "complex.word.f1", 
"complex.word.f2"}; + int[] positions = {8,0,1,2, + 3,4,5, + 6,7}; + Assert.assertEquals(fields.length, positions.length); + for(int i = 0; i < fields.length; i++) { + pojoType.getKey(fields[i], 0, ffd); + Assert.assertEquals("Too many keys returned", 1, ffd.size()); + Assert.assertEquals("position of field "+fields[i]+" wrong", positions[i], ffd.get(0).getPosition()); + ffd.clear(); + } + + pojoType.getKey("complex.word.*", 0, ffd); + Assert.assertEquals(3, ffd.size()); + // check if it returns 5,6,7 + for(FlatFieldDescriptor ffdE : ffd) { + final int pos = ffdE.getPosition(); + Assert.assertTrue(pos <= 7 ); + Assert.assertTrue(5 <= pos ); + if(pos == 5) { + Assert.assertEquals(Long.class, ffdE.getType().getTypeClass()); + } + if(pos == 6) { + Assert.assertEquals(Long.class, ffdE.getType().getTypeClass()); + } + if(pos == 7) { + Assert.assertEquals(String.class, ffdE.getType().getTypeClass()); + } + } + ffd.clear(); + + + pojoType.getKey("complex.*", 0, ffd); + Assert.assertEquals(8, ffd.size()); + // check if it returns 0-7 + for(FlatFieldDescriptor ffdE : ffd) { + final int pos = ffdE.getPosition(); + Assert.assertTrue(ffdE.getPosition() <= 7 ); + Assert.assertTrue(0 <= ffdE.getPosition() ); + if(pos == 0) { + Assert.assertEquals(Date.class, ffdE.getType().getTypeClass()); + } + if(pos == 1) { + Assert.assertEquals(MyWritable.class, ffdE.getType().getTypeClass()); + } + if(pos == 2) { + Assert.assertEquals(Object.class, ffdE.getType().getTypeClass()); + } + if(pos == 3) { + Assert.assertEquals(Float.class, ffdE.getType().getTypeClass()); + } + if(pos == 4) { + Assert.assertEquals(Integer.class, ffdE.getType().getTypeClass()); + } + if(pos == 5) { + Assert.assertEquals(Long.class, ffdE.getType().getTypeClass()); + } + if(pos == 6) { + Assert.assertEquals(Long.class, ffdE.getType().getTypeClass()); + } + if(pos == 7) { + Assert.assertEquals(String.class, ffdE.getType().getTypeClass()); + } + } + ffd.clear(); + + pojoType.getKey("*", 0, ffd); + Assert.assertEquals(9, ffd.size()); + // check if it returns 0-8 + for(FlatFieldDescriptor ffdE : ffd) { + Assert.assertTrue(ffdE.getPosition() <= 8 ); + Assert.assertTrue(0 <= ffdE.getPosition() ); + if(ffdE.getPosition() == 8) { + Assert.assertEquals(Integer.class, ffdE.getType().getTypeClass()); + } + } + ffd.clear(); + + TypeInformation typeComplexNested = pojoType.getTypeAt(0); // ComplexNestedClass complex + Assert.assertTrue(typeComplexNested instanceof PojoTypeInfo); + + Assert.assertEquals(6, typeComplexNested.getArity()); + Assert.assertEquals(8, typeComplexNested.getTotalFields()); + PojoTypeInfo pojoTypeComplexNested = (PojoTypeInfo) typeComplexNested; + + boolean dateSeen = false, intSeen = false, floatSeen = false, + tupleSeen = false, objectSeen = false, writableSeen = false; + for(int i = 0; i < pojoTypeComplexNested.getArity(); i++) { + PojoField field = pojoTypeComplexNested.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("date")) { + if(dateSeen) { + Assert.fail("already seen"); + } + dateSeen = true; + Assert.assertEquals(new GenericTypeInfo(Date.class), field.type); + Assert.assertEquals(Date.class, field.type.getTypeClass()); + } else if(name.equals("someNumber")) { + if(intSeen) { + Assert.fail("already seen"); + } + intSeen = true; + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + Assert.assertEquals(Integer.class, field.type.getTypeClass()); + } else if(name.equals("someFloat")) { + if(floatSeen) { + Assert.fail("already seen"); + } + floatSeen = true; + 
Assert.assertEquals(BasicTypeInfo.FLOAT_TYPE_INFO, field.type); + Assert.assertEquals(Float.class, field.type.getTypeClass()); + } else if(name.equals("word")) { + if(tupleSeen) { + Assert.fail("already seen"); + } + tupleSeen = true; + Assert.assertTrue(field.type instanceof TupleTypeInfo); + Assert.assertEquals(Tuple3.class, field.type.getTypeClass()); + // do some more advanced checks on the tuple + TupleTypeInfo tupleTypeFromComplexNested = (TupleTypeInfo) field.type; + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, tupleTypeFromComplexNested.getTypeAt(0)); + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, tupleTypeFromComplexNested.getTypeAt(1)); + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, tupleTypeFromComplexNested.getTypeAt(2)); + } else if(name.equals("nothing")) { + if(objectSeen) { + Assert.fail("already seen"); + } + objectSeen = true; + Assert.assertEquals(new GenericTypeInfo(Object.class), field.type); + Assert.assertEquals(Object.class, field.type.getTypeClass()); + } else if(name.equals("hadoopCitizen")) { + if(writableSeen) { + Assert.fail("already seen"); + } + writableSeen = true; + Assert.assertEquals(new WritableTypeInfo(MyWritable.class), field.type); + Assert.assertEquals(MyWritable.class, field.type.getTypeClass()); + } else { + Assert.fail("field "+field+" is not expected"); + } + } + Assert.assertTrue("Field was not present", dateSeen); + Assert.assertTrue("Field was not present", intSeen); + Assert.assertTrue("Field was not present", floatSeen); + Assert.assertTrue("Field was not present", tupleSeen); + Assert.assertTrue("Field was not present", objectSeen); + Assert.assertTrue("Field was not present", writableSeen); + + TypeInformation typeAtOne = pojoType.getTypeAt(1); // int count + Assert.assertTrue(typeAtOne instanceof BasicTypeInfo); + + Assert.assertEquals(typeInfo.getTypeClass(), WC.class); + Assert.assertEquals(typeInfo.getArity(), 2); + } + + @Test + public void testPojoAllPublic() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(AllPublic.class); + checkAllPublicAsserts(typeForClass); + + TypeInformation typeForObject = TypeExtractor.getForObject(new AllPublic() ); + checkAllPublicAsserts(typeForObject); + } + + private void checkAllPublicAsserts(TypeInformation typeInformation) { + Assert.assertTrue(typeInformation instanceof PojoTypeInfo); + Assert.assertEquals(9, typeInformation.getArity()); + Assert.assertEquals(11, typeInformation.getTotalFields()); + // check if the three additional fields are identified correctly + boolean arrayListSeen = false, multisetSeen = false, strArraySeen = false; + PojoTypeInfo pojoTypeForClass = (PojoTypeInfo) typeInformation; + for(int i = 0; i < pojoTypeForClass.getArity(); i++) { + PojoField field = pojoTypeForClass.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("somethingFancy")) { + if(arrayListSeen) { + Assert.fail("already seen"); + } + arrayListSeen = true; + Assert.assertTrue(field.type instanceof GenericTypeInfo); + Assert.assertEquals(ArrayList.class, field.type.getTypeClass()); + } else if(name.equals("fancyIds")) { + if(multisetSeen) { + Assert.fail("already seen"); + } + multisetSeen = true; + Assert.assertTrue(field.type instanceof GenericTypeInfo); + Assert.assertEquals(HashMultiset.class, field.type.getTypeClass()); + } else if(name.equals("fancyArray")) { + if(strArraySeen) { + Assert.fail("already seen"); + } + strArraySeen = true; + Assert.assertEquals(BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO, field.type); + 
Assert.assertEquals(String[].class, field.type.getTypeClass()); + } else if(Arrays.asList("date", "someNumber", "someFloat", "word", "nothing", "hadoopCitizen").contains(name)) { + // ignore these, they are inherited from the ComplexNestedClass + } + else { + Assert.fail("field "+field+" is not expected"); + } + } + Assert.assertTrue("Field was not present", arrayListSeen); + Assert.assertTrue("Field was not present", multisetSeen); + Assert.assertTrue("Field was not present", strArraySeen); + } + + @Test + public void testPojoExtendingTuple() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(FromTuple.class); + checkFromTuplePojo(typeForClass); + + FromTuple ft = new FromTuple(); + ft.f0 = ""; ft.f1 = ""; ft.f2 = 0L; + TypeInformation typeForObject = TypeExtractor.getForObject(ft); + checkFromTuplePojo(typeForObject); + } + + private void checkFromTuplePojo(TypeInformation typeInformation) { + Assert.assertTrue(typeInformation instanceof PojoTypeInfo); + Assert.assertEquals(4, typeInformation.getTotalFields()); + PojoTypeInfo pojoTypeForClass = (PojoTypeInfo) typeInformation; + for(int i = 0; i < pojoTypeForClass.getArity(); i++) { + PojoField field = pojoTypeForClass.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("special")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else if(name.equals("f0") || name.equals("f1")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if(name.equals("f2")) { + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, field.type); + } else { + Assert.fail("unexpected field"); + } + } + } + + @Test + public void testPojoWithGenerics() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(ParentSettingGenerics.class); + Assert.assertTrue(typeForClass instanceof PojoTypeInfo); + PojoTypeInfo pojoTypeForClass = (PojoTypeInfo) typeForClass; + for(int i = 0; i < pojoTypeForClass.getArity(); i++) { + PojoField field = pojoTypeForClass.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("field1")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else if (name.equals("field2")) { + Assert.assertEquals(BasicTypeInfo.LONG_TYPE_INFO, field.type); + } else if (name.equals("field3")) { + Assert.assertEquals(BasicTypeInfo.STRING_TYPE_INFO, field.type); + } else if (name.equals("key")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + /** + * Test if the TypeExtractor is accepting untyped generics, + * making them GenericTypes + */ + @Test + @Ignore // kryo needed. 
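+	// Illustrative expectation (sketch, not from this patch): used with raw
+	// types, PojoWithGenerics leaves T1/T2 unresolved, so the extractor can only
+	// describe field1/field2 as new GenericTypeInfo(Object.class), which is
+	// exactly what the assertions below check once generic-type serialization
+	// (Kryo) is available.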
+ public void testPojoWithGenericsSomeFieldsGeneric() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(PojoWithGenerics.class); + Assert.assertTrue(typeForClass instanceof PojoTypeInfo); + PojoTypeInfo pojoTypeForClass = (PojoTypeInfo) typeForClass; + for(int i = 0; i < pojoTypeForClass.getArity(); i++) { + PojoField field = pojoTypeForClass.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("field1")) { + Assert.assertEquals(new GenericTypeInfo(Object.class), field.type); + } else if (name.equals("field2")) { + Assert.assertEquals(new GenericTypeInfo(Object.class), field.type); + } else if (name.equals("key")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + + @Test + public void testPojoWithComplexHierarchy() { + TypeInformation typeForClass = TypeExtractor.createTypeInfo(ComplexHierarchyTop.class); + Assert.assertTrue(typeForClass instanceof PojoTypeInfo); + PojoTypeInfo pojoTypeForClass = (PojoTypeInfo) typeForClass; + for(int i = 0; i < pojoTypeForClass.getArity(); i++) { + PojoField field = pojoTypeForClass.getPojoFieldAt(i); + String name = field.field.getName(); + if(name.equals("field1")) { + Assert.assertTrue(field.type instanceof PojoTypeInfo); // From tuple is pojo (not tuple type!) + } else if (name.equals("field2")) { + Assert.assertTrue(field.type instanceof TupleTypeInfo); + Assert.assertTrue( ((TupleTypeInfo)field.type).getTypeAt(0).equals(BasicTypeInfo.STRING_TYPE_INFO) ); + } else if (name.equals("key")) { + Assert.assertEquals(BasicTypeInfo.INT_TYPE_INFO, field.type); + } else { + Assert.fail("Unexpected field "+field); + } + } + } + + // End of Pojo type tests + public static class CustomType { public String myField1; public int myField2; @@ -356,9 +828,27 @@ public Tuple2 map(Tuple2 value) throws Excep TupleTypeInfo tti = (TupleTypeInfo) ti; Assert.assertEquals(Tuple2.class, tti.getTypeClass()); + List ffd = new ArrayList(); + + tti.getKey("f0", 0, ffd); + Assert.assertEquals(1, ffd.size()); + Assert.assertEquals(0, ffd.get(0).getPosition() ); // Long + Assert.assertTrue( ffd.get(0).getType().getTypeClass().equals(Long.class) ); + ffd.clear(); + + tti.getKey("f1.myField1", 0, ffd); + Assert.assertEquals(1, ffd.get(0).getPosition() ); + Assert.assertTrue( ffd.get(0).getType().getTypeClass().equals(String.class) ); + ffd.clear(); + + + tti.getKey("f1.myField2", 0, ffd); + Assert.assertEquals(2, ffd.get(0).getPosition() ); + Assert.assertTrue( ffd.get(0).getType().getTypeClass().equals(Integer.class) ); + Assert.assertEquals(Long.class, tti.getTypeAt(0).getTypeClass()); - Assert.assertTrue(tti.getTypeAt(1) instanceof GenericTypeInfo); + Assert.assertTrue(tti.getTypeAt(1) instanceof PojoTypeInfo); Assert.assertEquals(CustomType.class, tti.getTypeAt(1).getTypeClass()); // use getForObject() @@ -371,7 +861,7 @@ public Tuple2 map(Tuple2 value) throws Excep Assert.assertEquals(Tuple2.class, tti2.getTypeClass()); Assert.assertEquals(Long.class, tti2.getTypeAt(0).getTypeClass()); - Assert.assertTrue(tti2.getTypeAt(1) instanceof GenericTypeInfo); + Assert.assertTrue(tti2.getTypeAt(1) instanceof PojoTypeInfo); Assert.assertEquals(CustomType.class, tti2.getTypeAt(1).getTypeClass()); } @@ -1203,6 +1693,7 @@ public static class MyObject { @SuppressWarnings({ "rawtypes", "unchecked" }) @Test + @Ignore public void testParamertizedCustomObject() { RichMapFunction function = new RichMapFunction, MyObject>() { private static final long serialVersionUID = 1L; @@ 
-1214,7 +1705,7 @@ public MyObject map(MyObject value) throws Exception { }; TypeInformation ti = TypeExtractor.getMapReturnTypes(function, (TypeInformation) TypeInfoParser.parse("org.apache.flink.api.java.type.extractor.TypeExtractorTest$MyObject")); - Assert.assertTrue(ti instanceof GenericTypeInfo); + Assert.assertTrue(ti instanceof PojoTypeInfo); } @Test diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java index 382b00911a657..06330d373a632 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java @@ -28,7 +28,6 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple4; - import org.apache.flink.api.java.typeutils.runtime.tuple.base.TuplePairComparatorTestBase; public class GenericPairComparatorTest extends TuplePairComparatorTestBase, Tuple4> { @@ -57,7 +56,7 @@ public class GenericPairComparatorTest extends TuplePairComparatorTestBase(7, 0.88f, 34L, 15.2) }; - @SuppressWarnings("unchecked") + @SuppressWarnings("rawtypes") @Override protected GenericPairComparator, Tuple4> createComparator(boolean ascending) { int[] fields1 = new int[]{0, 2}; diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializerTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializerTest.java index fcc84f062e27b..5d0917e909206 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializerTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializerTest.java @@ -18,27 +18,37 @@ package org.apache.flink.api.java.typeutils.runtime; -import com.google.common.base.Objects; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.CompositeType.FlatFieldDescriptor; +import org.apache.flink.api.java.operators.Keys.ExpressionKeys; +import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.junit.Ignore; +import org.junit.Assert; +import org.junit.Test; -import java.util.Random; +import com.google.common.base.Objects; /** * A test for the {@link org.apache.flink.api.java.typeutils.runtime.PojoSerializer}. 
*/ -@Ignore public class PojoSerializerTest extends SerializerTestBase { private TypeInformation type = TypeExtractor.getForClass(TestUserClass.class); @Override protected TypeSerializer createSerializer() { TypeSerializer serializer = type.createSerializer(); - assert (serializer instanceof PojoSerializer); + assert(serializer instanceof PojoSerializer); return serializer; } @@ -67,12 +77,12 @@ protected TestUserClass[] getTestData() { // User code class for testing the serializer public static class TestUserClass { - private int dumm1; - protected String dumm2; + public int dumm1; + public String dumm2; public double dumm3; - private int[] dumm4; + public int[] dumm4; - private NestedTestUserClass nestedClass; + public NestedTestUserClass nestedClass; public TestUserClass() { } @@ -121,10 +131,10 @@ public boolean equals(Object other) { } public static class NestedTestUserClass { - private int dumm1; - protected String dumm2; + public int dumm1; + public String dumm2; public double dumm3; - private int[] dumm4; + public int[] dumm4; public NestedTestUserClass() { } @@ -167,4 +177,52 @@ public boolean equals(Object other) { return true; } } + + /** + * This tests if the hashes returned by the pojo and tuple comparators are the same + */ + @Test + public void testTuplePojoTestEquality() { + + // test with a simple, string-key first. + PojoTypeInfo pType = (PojoTypeInfo) type; + List result = new ArrayList(); + pType.getKey("nestedClass.dumm2", 0, result); + int[] fields = new int[1]; // see below + fields[0] = result.get(0).getPosition(); + TypeComparator pojoComp = pType.createComparator( fields, new boolean[]{true}, 0); + + TestUserClass pojoTestRecord = new TestUserClass(0, "abc", 3d, new int[] {1,2,3}, new NestedTestUserClass(1, "haha", 4d, new int[] {5,4,3})); + int pHash = pojoComp.hash(pojoTestRecord); + + Tuple1 tupleTest = new Tuple1("haha"); + TupleTypeInfo> tType = (TupleTypeInfo>)TypeExtractor.getForObject(tupleTest); + TypeComparator> tupleComp = tType.createComparator(new int[] {0}, new boolean[] {true}, 0); + + int tHash = tupleComp.hash(tupleTest); + + Assert.assertTrue("The hashing for tuples and pojos must be the same, so that they are mixable", pHash == tHash); + + Tuple3 multiTupleTest = new Tuple3(1, "haha", 4d); // its important here to use the same values. + TupleTypeInfo> multiTupleType = (TupleTypeInfo>)TypeExtractor.getForObject(multiTupleTest); + + ExpressionKeys fieldKey = new ExpressionKeys(new int[]{1,0,2}, multiTupleType); + ExpressionKeys expressKey = new ExpressionKeys(new String[] {"nestedClass.dumm2", "nestedClass.dumm1", "nestedClass.dumm3"}, pType); + try { + Assert.assertTrue("Expecting the keys to be compatible", fieldKey.areCompatible(expressKey)); + } catch (IncompatibleKeysException e) { + e.printStackTrace(); + Assert.fail("Keys must be compatible: "+e.getMessage()); + } + TypeComparator multiPojoComp = pType.createComparator( expressKey.computeLogicalKeyPositions(), new boolean[]{true, true, true}, 0); + int multiPojoHash = multiPojoComp.hash(pojoTestRecord); + + + // pojo order is: dumm2 (str), dumm1 (int), dumm3 (double). + TypeComparator> multiTupleComp = multiTupleType.createComparator(fieldKey.computeLogicalKeyPositions(), new boolean[] {true, true,true}, 0); + int multiTupleHash = multiTupleComp.hash(multiTupleTest); + + Assert.assertTrue("The hashing for tuples and pojos must be the same, so that they are mixable. 
Also for those with multiple key fields", multiPojoHash == multiTupleHash); + + } } diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java index 1cbc1d023794a..ce442d1b26b02 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/TupleSerializerTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Random; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; @@ -208,7 +209,7 @@ public void testTuple5CustomObjects() { private final void runTests(T... instances) { try { TupleTypeInfo tupleTypeInfo = (TupleTypeInfo) TypeExtractor.getForObject(instances[0]); - TupleSerializer serializer = tupleTypeInfo.createSerializer(); + TypeSerializer serializer = tupleTypeInfo.createSerializer(); Class tupleClass = tupleTypeInfo.getTypeClass(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/GroupReduceDriverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/GroupReduceDriverTest.java index ecc5259b39a14..d249e9ac9b13a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/GroupReduceDriverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/GroupReduceDriverTest.java @@ -50,7 +50,7 @@ public void testAllReduceDriverImmutableEmpty() { List> data = DriverTestData.createReduceImmutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = EmptyMutableObjectIterator.get(); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); context.setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -82,7 +82,7 @@ public void testAllReduceDriverImmutable() { List> data = DriverTestData.createReduceImmutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -118,7 +118,7 @@ public void testAllReduceDriverMutable() { List> data = DriverTestData.createReduceMutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -154,7 +154,7 @@ public void testAllReduceDriverIncorrectlyAccumulatingMutable() { List> data = 
DriverTestData.createReduceMutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -196,7 +196,7 @@ public void testAllReduceDriverAccumulatingImmutable() { List> data = DriverTestData.createReduceMutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceCombineDriverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceCombineDriverTest.java index e189790cfa902..44cbe169ba0ce 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceCombineDriverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceCombineDriverTest.java @@ -55,7 +55,7 @@ public void testImmutableEmpty() { MutableObjectIterator> input = EmptyMutableObjectIterator.get(); context.setDriverStrategy(DriverStrategy.SORTED_PARTIAL_REDUCE); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -90,7 +90,7 @@ public void testReduceDriverImmutable() { TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -121,7 +121,7 @@ public void testReduceDriverImmutable() { TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0); GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer()); @@ -160,7 +160,7 @@ public void testReduceDriverMutable() { List> data = DriverTestData.createReduceMutableData(); TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0)); MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer()); - TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}); + TypeComparator> comparator = 
 			typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
@@ -188,7 +188,7 @@ public void testReduceDriverMutable() {
 			List> data = DriverTestData.createReduceMutableData();
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer());
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceDriverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceDriverTest.java
index d7c14fba7974c..ae4e54cf49ac6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceDriverTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/ReduceDriverTest.java
@@ -50,7 +50,7 @@ public void testReduceDriverImmutableEmpty() {
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = EmptyMutableObjectIterator.get();
 			context.setDriverStrategy(DriverStrategy.SORTED_REDUCE);
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
@@ -82,7 +82,7 @@ public void testReduceDriverImmutable() {
 			List> data = DriverTestData.createReduceImmutableData();
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer());
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
@@ -110,7 +110,7 @@ public void testReduceDriverImmutable() {
 			List> data = DriverTestData.createReduceImmutableData();
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer());
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
@@ -148,7 +148,7 @@ public void testReduceDriverMutable() {
 			List> data = DriverTestData.createReduceMutableData();
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer());
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
@@ -175,7 +175,7 @@ public void testReduceDriverMutable() {
 			List> data = DriverTestData.createReduceMutableData();
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) TypeExtractor.getForObject(data.get(0));
 			MutableObjectIterator> input = new RegularToMutableObjectIterator>(data.iterator(), typeInfo.createSerializer());
-			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true});
+			TypeComparator> comparator = typeInfo.createComparator(new int[]{0}, new boolean[] {true}, 0);
 			GatheringCollector> result = new GatheringCollector>(typeInfo.createSerializer());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/sort/MassiveStringSortingITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/sort/MassiveStringSortingITCase.java
index a9553028ea6d0..257eb8785130c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/sort/MassiveStringSortingITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/sort/MassiveStringSortingITCase.java
@@ -180,7 +180,7 @@ public void testStringTuplesSorting() {
 			TupleTypeInfo> typeInfo = (TupleTypeInfo>) (TupleTypeInfo) TypeInfoParser.parse("Tuple2");
 			TypeSerializer> serializer = typeInfo.createSerializer();
-			TypeComparator> comparator = typeInfo.createComparator(new int[] { 0 }, new boolean[] { true } );
+			TypeComparator> comparator = typeInfo.createComparator(new int[] { 0 }, new boolean[] { true }, 0);
 			reader = new BufferedReader(new FileReader(input));
 			MutableObjectIterator> inputIterator = new StringTupleReaderMutableObjectIterator(reader);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntListComparator.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntListComparator.java
index 49d31a2a17907..b5822a814bb63 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntListComparator.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntListComparator.java
@@ -33,8 +33,6 @@ public class IntListComparator extends TypeComparator {
 	private int reference;
 	
-	private Comparable[] extractedKey = new Comparable[1];
-	
 	private final TypeComparator[] comparators = new TypeComparator[] {new IntComparator(true)};
 	
 	@Override
@@ -142,12 +140,12 @@ public TypeComparator duplicate() {
 	}
 	
 	@Override
-	public Object[] extractKeys(IntList record) {
-		extractedKey[0] = record.getKey();
-		return extractedKey;
+	public int extractKeys(Object record, Object[] target, int index) {
+		target[index] = (Comparable) record;
+		return 1;
 	}
 
-	@Override public TypeComparator[] getComparators() {
+	@Override public TypeComparator[] getFlatComparators() {
 		return comparators;
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntPairComparator.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntPairComparator.java
index 98c5b694d027a..ab20b7f88083e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntPairComparator.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/IntPairComparator.java
@@ -33,8 +33,6 @@ public class IntPairComparator extends TypeComparator {
 	private int reference;
 	
-	private final Comparable[] extractedKey = new Comparable[1];
-	
 	private final TypeComparator[] comparators = new TypeComparator[] {new IntComparator(true)};
 	
 	@Override
@@ -117,11 +115,12 @@ public IntPairComparator duplicate() {
 	}
 	
 	@Override
-	public Object[] extractKeys(IntPair pair) {
-		extractedKey[0] = pair.getKey();
-		return extractedKey;
+	public int extractKeys(Object record, Object[] target, int index) {
+		target[index] = ((IntPair) record).getKey();
+		return 1;
 	}
-	@Override public TypeComparator[] getComparators() {
+
+	@Override public TypeComparator[] getFlatComparators() {
 		return comparators;
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/StringPairComparator.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/StringPairComparator.java
index a5355dd307bff..0b0446fd2f410 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/StringPairComparator.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/types/StringPairComparator.java
@@ -34,8 +34,6 @@ public class StringPairComparator extends TypeComparator {
 	private String reference;
 	
-	private Comparable[] extractedKey = new Comparable[1];
-	
 	private final TypeComparator[] comparators = new TypeComparator[] {new StringComparator(true)};
 	
 	@Override
@@ -118,12 +116,12 @@ public TypeComparator duplicate() {
 	}
 	
 	@Override
-	public Object[] extractKeys(StringPair record) {
-		extractedKey[0] = record.getKey();
-		return extractedKey;
+	public int extractKeys(Object record, Object[] target, int index) {
+		target[index] = ((StringPair) record).getKey();
+		return 1;
 	}
 
-	@Override public TypeComparator[] getComparators() {
+	@Override public TypeComparator[] getFlatComparators() {
 		return comparators;
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java
index 765d1214e97d5..a717f5504e9c2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java
@@ -382,8 +382,6 @@ public void testWrongKeyClass() {
 	@SuppressWarnings({"serial", "rawtypes"})
 	private static class TestIntComparator extends TypeComparator {
 
-		private final Comparable[] extractedKey = new Comparable[1];
-
 		private TypeComparator[] comparators = new TypeComparator[]{new IntComparator(true)};
 
 		@Override
@@ -444,13 +442,13 @@ public Integer readWithKeyDenormalization(Integer reuse, DataInputView source) t
 		public TypeComparator duplicate() { throw new UnsupportedOperationException(); }
 
 		@Override
-		public Object[] extractKeys(Integer record) {
-			extractedKey[0] = record;
-			return extractedKey;
+		public int extractKeys(Object record, Object[] target, int index) {
+			target[index] = record;
+			return 1;
 		}
 
 		@Override
-		public TypeComparator[] getComparators() {
+		public TypeComparator[] getFlatComparators() {
 			return comparators;
 		}
 	}
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
index 310fc17ac02c4..534ef45d0166a 100644
--- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java
@@ -198,7 +198,7 @@ protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase
 		operatorInfo = new UnaryOperatorInformation(getInputType(), getResultType());
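
For readers following the API change above: createComparator now takes a third offset argument (the logical position of the comparator's first key within the overall flat key schema), and key extraction writes into a caller-provided array and returns the number of slots written instead of allocating its own array. A minimal sketch of driving the new calls; the class name and sample data are ours, while the Flink calls appear verbatim in the hunks above:

    import org.apache.flink.api.common.typeutils.TypeComparator;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.api.java.typeutils.TupleTypeInfo;
    import org.apache.flink.api.java.typeutils.TypeExtractor;

    public class FlatKeyExtractionSketch {
        @SuppressWarnings("unchecked")
        public static void main(String[] args) {
            Tuple2<String, Integer> record = new Tuple2<String, Integer>("hello", 42);
            TupleTypeInfo<Tuple2<String, Integer>> typeInfo =
                    (TupleTypeInfo<Tuple2<String, Integer>>) TypeExtractor.getForObject(record);

            // New signature: key positions, sort orders, and the offset of the
            // first key in the flat key schema (0 for a top-level comparator).
            TypeComparator<Tuple2<String, Integer>> comparator =
                    typeInfo.createComparator(new int[] {0}, new boolean[] {true}, 0);

            // extractKeys() fills the shared target array starting at the given
            // index; getFlatComparators() matches the written slots one-for-one.
            Object[] keys = new Object[comparator.getFlatComparators().length];
            int written = comparator.extractKeys(record, keys, 0);
            System.out.println(written + " key(s), first = " + keys[0]); // 1 key(s), first = hello
        }
    }

The shared target array is what lets composite types (tuples nested in POJOs, and so on) flatten all of their keys into one schema without per-record allocations.
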
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index 8b452ae745053..28624bccbfc5a 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -27,7 +27,7 @@ import org.apache.flink.api.java.aggregation.Aggregations
 import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
 import org.apache.flink.api.java.io.{PrintingOutputFormat, TextOutputFormat}
 import org.apache.flink.api.java.operators.JoinOperator.JoinHint
-import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
+import org.apache.flink.api.java.operators.Keys.ExpressionKeys
 import org.apache.flink.api.java.operators._
 import org.apache.flink.api.java.{DataSet => JavaDataSet}
 import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
@@ -544,7 +544,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
   def distinct(fields: Int*): DataSet[T] = {
     wrap(new DistinctOperator[T](
       javaSet,
-      new Keys.FieldPositionKeys[T](fields.toArray, javaSet.getType, true)))
+      new Keys.ExpressionKeys[T](fields.toArray, javaSet.getType, true)))
   }
 
   /**
@@ -557,7 +557,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
     val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
     wrap(new DistinctOperator[T](
       javaSet,
-      new Keys.FieldPositionKeys[T](fieldIndices, javaSet.getType, true)))
+      new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, true)))
   }
 
   /**
@@ -602,7 +602,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
   def groupBy(fields: Int*): GroupedDataSet[T] = {
     new GroupedDataSet[T](
       this,
-      new Keys.FieldPositionKeys[T](fields.toArray, javaSet.getType,false))
+      new Keys.ExpressionKeys[T](fields.toArray, javaSet.getType,false))
   }
 
   /**
@@ -619,7 +619,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
     new GroupedDataSet[T](
       this,
-      new Keys.FieldPositionKeys[T](fieldIndices, javaSet.getType,false))
+      new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType,false))
   }
 
   //	public UnsortedGrouping groupBy(String... fields) {
@@ -835,7 +835,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
    */
  def iterateDelta[R: ClassTag](workset: DataSet[R], maxIterations: Int, keyFields: Array[Int])(
      stepFunction: (DataSet[T], DataSet[R]) => (DataSet[T], DataSet[R])) = {
-    val key = new FieldPositionKeys[T](keyFields, javaSet.getType, false)
+    val key = new ExpressionKeys[T](keyFields, javaSet.getType, false)
 
     val iterativeSet = new DeltaIteration[T, R](
       javaSet.getExecutionEnvironment,
@@ -864,8 +864,8 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
      stepFunction: (DataSet[T], DataSet[R]) => (DataSet[T], DataSet[R])) = {
     val fieldIndices = fieldNames2Indices(javaSet.getType, keyFields)
 
-    val key = new FieldPositionKeys[T](fieldIndices, javaSet.getType, false)
+    val key = new ExpressionKeys[T](fieldIndices, javaSet.getType, false)
 
     val iterativeSet = new DeltaIteration[T, R](
       javaSet.getExecutionEnvironment,
       javaSet.getType,
@@ -920,7 +920,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
     val op = new PartitionOperator[T](
       javaSet,
       PartitionMethod.HASH,
-      new Keys.FieldPositionKeys[T](fields.toArray, javaSet.getType, false))
+      new Keys.ExpressionKeys[T](fields.toArray, javaSet.getType, false))
     wrap(op)
   }
 
@@ -936,7 +936,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
     val op = new PartitionOperator[T](
       javaSet,
       PartitionMethod.HASH,
-      new Keys.FieldPositionKeys[T](fieldIndices, javaSet.getType, false))
+      new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType, false))
     wrap(op)
   }
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
index fd1985d2c3f81..f67f10cf34cc4 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
@@ -19,7 +19,7 @@ package org.apache.flink.api.scala
 
 import org.apache.commons.lang3.Validate
 import org.apache.flink.api.common.functions.{RichCrossFunction, CrossFunction}
-import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
 import org.apache.flink.api.java.operators._
 import org.apache.flink.api.java.{DataSet => JavaDataSet}
 import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala
index 83fe5cfc097c2..1353b44420b87 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala
@@ -139,10 +139,10 @@ class CaseClassComparator[T <: Product](
     }
   }
 
-  def extractKeys(value: T) = {
+  def extractKeys(value: AnyRef, target: Array[AnyRef], index: Int) = {
     for (i <- 0 until keyPositions.length ) {
-      extractedKeys(i) = value.productElement(keyPositions(i)).asInstanceOf[AnyRef]
+      target(index + i) = value.asInstanceOf[T].productElement(keyPositions(i)).asInstanceOf[AnyRef]
     }
-    extractedKeys
+    keyPositions.length
   }
 }
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
index 153f3a8c0c94a..c9a3bbff038ce 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassTypeInfo.scala
@@ -32,7 +32,8 @@ abstract class CaseClassTypeInfo[T <: Product](
     val fieldNames: Seq[String])
   extends TupleTypeInfoBase[T](clazz, fieldTypes: _*) {
 
-  def createComparator(logicalKeyFields: Array[Int], orders: Array[Boolean]): TypeComparator[T] = {
+  override def createComparator(logicalKeyFields: Array[Int],
+      orders: Array[Boolean], offset: Int): TypeComparator[T] = {
     // sanity checks
     if (logicalKeyFields == null || orders == null
       || logicalKeyFields.length != orders.length || logicalKeyFields.length > types.length) {
@@ -80,6 +81,21 @@ abstract class CaseClassTypeInfo[T <: Product](
     fields map { x => fieldNames.indexOf(x) }
   }
 
+  override protected def initializeNewComparator(localKeyCount: Int): Unit = {
+    throw new UnsupportedOperationException("The Scala API is not using the composite " +
+      "type comparator creation")
+  }
+
+  override protected def getNewComparator: TypeComparator[T] = {
+    throw new UnsupportedOperationException("The Scala API is not using the composite " +
+      "type comparator creation")
+  }
+
+  override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = {
+    throw new UnsupportedOperationException("The Scala API is not using the composite " +
+      "type comparator creation")
+  }
+
   override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map {
     case (n, t) => n + ": " + t}
     .mkString(", ") + ")"
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
index f8a5d034cb379..9d9a19f0a37db 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
@@ -22,7 +22,7 @@ import org.apache.flink.api.common.InvalidProgramException
 
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.operators.Keys
-import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
+import org.apache.flink.api.java.operators.Keys.ExpressionKeys
 import org.apache.flink.api.common.typeinfo.TypeInformation
 
 /**
@@ -56,7 +56,7 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
    * This only works on Tuple [[DataSet]].
    */
   def where(leftKeys: Int*) = {
-    val leftKey = new FieldPositionKeys[L](leftKeys.toArray, leftInput.getType)
+    val leftKey = new ExpressionKeys[L](leftKeys.toArray, leftInput.getType)
     new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
   }
 
@@ -73,7 +73,7 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
       leftInput.getType,
       firstLeftField +: otherLeftFields.toArray)
 
-    val leftKey = new FieldPositionKeys[L](fieldIndices, leftInput.getType)
+    val leftKey = new ExpressionKeys[L](fieldIndices, leftInput.getType)
     new HalfUnfinishedKeyPairOperation[L, R, O](this, leftKey)
   }
 
@@ -103,8 +103,8 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
    * This only works on a Tuple [[DataSet]].
    */
   def equalTo(rightKeys: Int*): O = {
-    val rightKey = new FieldPositionKeys[R](rightKeys.toArray, unfinished.rightInput.getType)
-    if (!leftKey.areCompatibale(rightKey)) {
+    val rightKey = new ExpressionKeys[R](rightKeys.toArray, unfinished.rightInput.getType)
+    if (!leftKey.areCompatible(rightKey)) {
       throw new InvalidProgramException("The types of the key fields do not match. Left: " +
         leftKey + " Right: " + rightKey)
     }
@@ -122,8 +122,8 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
       unfinished.rightInput.getType,
       firstRightField +: otherRightFields.toArray)
 
-    val rightKey = new FieldPositionKeys[R](fieldIndices, unfinished.rightInput.getType)
-    if (!leftKey.areCompatibale(rightKey)) {
+    val rightKey = new ExpressionKeys[R](fieldIndices, unfinished.rightInput.getType)
+    if (!leftKey.areCompatible(rightKey)) {
       throw new InvalidProgramException("The types of the key fields do not match. Left: " +
         leftKey + " Right: " + rightKey)
     }
@@ -145,7 +145,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
       unfinished.rightInput.getType,
       keyType)
 
-    if (!leftKey.areCompatibale(rightKey)) {
+    if (!leftKey.areCompatible(rightKey)) {
       throw new InvalidProgramException("The types of the key fields do not match. Left: " +
         leftKey + " Right: " + rightKey)
     }
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
index a3bb34d42e21b..d962b76369757 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
@@ -19,6 +19,7 @@ package org.apache.flink.api.scala.operators
 
 import java.io.Serializable
 import org.apache.flink.api.common.InvalidProgramException
+import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
 import org.junit.Assert
 import org.junit.Ignore
 import org.junit.Test
@@ -44,7 +45,7 @@ class CoGroupOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyFields2(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -54,7 +55,7 @@ class CoGroupOperatorTest {
     ds1.coGroup(ds2).where(0).equalTo(2)
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyFields3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -109,7 +110,7 @@ class CoGroupOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyFieldNames2(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -119,7 +120,7 @@ class CoGroupOperatorTest {
     ds1.coGroup(ds2).where("_1").equalTo("_3")
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyFieldNames3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -254,7 +255,7 @@ class CoGroupOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyMixing3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -264,7 +265,7 @@ class CoGroupOperatorTest {
     ds1.coGroup(ds2).where(2).equalTo { _.l }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testCoGroupKeyMixing4(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
index 05e45d8518843..cae936da804bb 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.api.scala.operators
 
+import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
 import org.junit.Assert
 import org.apache.flink.api.common.InvalidProgramException
 import org.junit.Ignore
@@ -45,7 +46,7 @@ class JoinOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyIndices2(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -55,7 +56,7 @@ class JoinOperatorTest {
     ds1.join(ds2).where(0).equalTo(2)
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyIndices3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -110,7 +111,7 @@ class JoinOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyFields2(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -120,7 +121,7 @@ class JoinOperatorTest {
     ds1.join(ds2).where("_1").equalTo("_3")
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyFields3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -255,7 +256,7 @@ class JoinOperatorTest {
     }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyMixing3(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
@@ -265,7 +266,7 @@ class JoinOperatorTest {
     ds1.join(ds2).where(2).equalTo { _.l }
   }
 
-  @Test(expected = classOf[InvalidProgramException])
+  @Test(expected = classOf[IncompatibleKeysException])
   def testJoinKeyMixing4(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val ds1 = env.fromCollection(emptyTupleData)
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD2Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD2Test.scala
index e810b4568500e..adc46acace60e 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD2Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD2Test.scala
@@ -30,7 +30,7 @@ class TupleComparatorILD2Test extends TupleComparatorTestBase[(Int, Long, Double
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, Long, Double)] = {
     val ti = createTypeInformation[(Int, Long, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, Long, Double)]]
-      .createComparator(Array(0, 1), Array(ascending, ascending))
+      .createComparator(Array(0, 1), Array(ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, Long, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD3Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD3Test.scala
index 96cc42c2f4356..377f9857e4a06 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD3Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILD3Test.scala
@@ -28,7 +28,7 @@ class TupleComparatorILD3Test extends TupleComparatorTestBase[(Int, Long, Double
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, Long, Double)] = {
     val ti = createTypeInformation[(Int, Long, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, Long, Double)]]
-      .createComparator(Array(0, 1, 2), Array(ascending, ascending, ascending))
+      .createComparator(Array(0, 1, 2), Array(ascending, ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, Long, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDC3Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDC3Test.scala
index 30d55c9ca5e90..1578951df39af 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDC3Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDC3Test.scala
@@ -28,7 +28,7 @@ class TupleComparatorILDC3Test extends TupleComparatorTestBase[(Int, Long, Doubl
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, Long, Double)] = {
     val ti = createTypeInformation[(Int, Long, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, Long, Double)]]
-      .createComparator(Array(2, 0, 1), Array(ascending, ascending, ascending))
+      .createComparator(Array(2, 0, 1), Array(ascending, ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, Long, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDX1Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDX1Test.scala
index df0983cb1c8b6..51e08ccb5f267 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDX1Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDX1Test.scala
@@ -28,7 +28,7 @@ class TupleComparatorILDX1Test extends TupleComparatorTestBase[(Int, Long, Doubl
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, Long, Double)] = {
     val ti = createTypeInformation[(Int, Long, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, Long, Double)]]
-      .createComparator(Array(1), Array(ascending))
+      .createComparator(Array(1), Array(ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, Long, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDXC2Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDXC2Test.scala
index c4e3883c74e4a..3dbaabfdd784d 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDXC2Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorILDXC2Test.scala
@@ -28,7 +28,7 @@ class TupleComparatorILDXC2Test extends TupleComparatorTestBase[(Int, Long, Doub
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, Long, Double)] = {
     val ti = createTypeInformation[(Int, Long, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, Long, Double)]]
-      .createComparator(Array(2, 1), Array(ascending, ascending))
+      .createComparator(Array(2, 1), Array(ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, Long, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD1Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD1Test.scala
index 01c3ab839410f..252ae79e604f7 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD1Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD1Test.scala
@@ -27,7 +27,7 @@ class TupleComparatorISD1Test extends TupleComparatorTestBase[(Int, String, Doub
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, String, Double)] = {
     val ti = createTypeInformation[(Int, String, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, String, Double)]]
-      .createComparator(Array(0), Array(ascending))
+      .createComparator(Array(0), Array(ascending),0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, String, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD2Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD2Test.scala
index 20c1955d5d3c6..37e775edf3ff8 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD2Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD2Test.scala
@@ -27,7 +27,7 @@ class TupleComparatorISD2Test extends TupleComparatorTestBase[(Int, String, Doub
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, String, Double)] = {
     val ti = createTypeInformation[(Int, String, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, String, Double)]]
-      .createComparator(Array(0, 1), Array(ascending, ascending))
+      .createComparator(Array(0, 1), Array(ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, String, Double)] = {
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD3Test.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD3Test.scala
index 6d945a9bd8bcd..227b041fc14d1 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD3Test.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/TupleComparatorISD3Test.scala
@@ -27,7 +27,7 @@ class TupleComparatorISD3Test extends TupleComparatorTestBase[(Int, String, Doub
   protected def createComparator(ascending: Boolean): TypeComparator[(Int, String, Double)] = {
     val ti = createTypeInformation[(Int, String, Double)]
     ti.asInstanceOf[TupleTypeInfoBase[(Int, String, Double)]]
-      .createComparator(Array(0, 1, 2), Array(ascending, ascending, ascending))
+      .createComparator(Array(0, 1, 2), Array(ascending, ascending, ascending), 0)
   }
 
   protected def createSerializer: TypeSerializer[(Int, String, Double)] = {
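
The test expectation change above pins the more specific failure type: mismatched key types now surface as Keys.IncompatibleKeysException rather than a bare InvalidProgramException, and they do so while the plan is assembled. A minimal sketch with our own data; the catch is kept broad because the hunks do not show whether the exception is checked:

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class IncompatibleKeysSketch {
        public static void main(String[] args) {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Tuple2<Integer, String>> left = env.fromElements(new Tuple2<Integer, String>(1, "a"));
            DataSet<Tuple2<Integer, String>> right = env.fromElements(new Tuple2<Integer, String>(1, "a"));
            try {
                // Integer key (f0) against String key (f1): the key types differ.
                left.join(right).where(0).equalTo(1);
            } catch (Exception e) {
                // At this commit the failure is reported as Keys.IncompatibleKeysException.
                System.out.println(e.getClass().getSimpleName() + ": " + e.getMessage());
            }
        }
    }
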
diff --git a/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountNestedPOJOITCase.java b/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountNestedPOJOITCase.java
new file mode 100644
index 0000000000000..bdd5a7dc73a4f
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountNestedPOJOITCase.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.exampleJavaPrograms;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.test.testdata.WordCountData;
+import org.apache.flink.test.util.JavaProgramTestBase;
+import org.apache.flink.util.Collector;
+
+import java.io.Serializable;
+import java.util.Date;
+
+
+public class WordCountNestedPOJOITCase extends JavaProgramTestBase implements Serializable {
+	private static final long serialVersionUID = 1L;
+	protected String textPath;
+	protected String resultPath;
+
+
+	@Override
+	protected void preSubmit() throws Exception {
+		textPath = createTempFile("text.txt", WordCountData.TEXT);
+		resultPath = getTempDirPath("result");
+	}
+
+	@Override
+	protected void postSubmit() throws Exception {
+		compareResultsByLinesInMemory(WordCountData.COUNTS, resultPath);
+	}
+
+	@Override
+	protected void testProgram() throws Exception {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<String> text = env.readTextFile(textPath);
+
+		DataSet<WC> counts = text
+				.flatMap(new Tokenizer())
+				.groupBy("complex.someTest")
+				.reduce(new ReduceFunction<WC>() {
+					private static final long serialVersionUID = 1L;
+					public WC reduce(WC value1, WC value2) {
+						return new WC(value1.complex.someTest, value1.count + value2.count);
+					}
+				});
+
+		counts.writeAsText(resultPath);
+
+		env.execute("WordCount with custom data types example");
+	}
+
+	public static final class Tokenizer implements FlatMapFunction<String, WC> {
+
+		@Override
+		public void flatMap(String value, Collector<WC> out) {
+			// normalize and split the line
+			String[] tokens = value.toLowerCase().split("\\W+");
+
+			// emit the pairs
+			for (String token : tokens) {
+				if (token.length() > 0) {
+					out.collect(new WC(token, 1));
+				}
+			}
+		}
+	}
+
+	public static class WC { // is a pojo
+		public ComplexNestedClass complex; // is a pojo
+		public int count; // is a BasicType
+
+		public WC() {
+		}
+		public WC(String t, int c) {
+			this.count = c;
+			this.complex = new ComplexNestedClass();
+			this.complex.word = new Tuple3<Long, Long, String>(0L, 0L, "egal");
+			this.complex.date = new Date();
+			this.complex.someFloat = 0.0f;
+			this.complex.someNumber = 666;
+			this.complex.someTest = t;
+		}
+		@Override
+		public String toString() {
+			return this.complex.someTest+" "+count;
+		}
+	}
+
+	public static class ComplexNestedClass { // pojo
+		public static int ignoreStaticField;
+		public transient int ignoreTransientField;
+		public Date date; // generic type
+		public Integer someNumber; // BasicType
+		public float someFloat; // BasicType
+		public Tuple3<Long, Long, String> word; //Tuple Type with three basic types
+		public String someTest;
+	}
+
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountPOJOITCase.java b/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountPOJOITCase.java
deleted file mode 100644
index 0a5732ca20d39..0000000000000
--- a/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountPOJOITCase.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-//
-//package org.apache.flink.test.exampleJavaPrograms;
-//
-////import org.apache.flink.examples.java.wordcount.WordCountPOJO;
-//import org.apache.flink.test.testdata.WordCountData;
-//import org.apache.flink.test.util.JavaProgramTestBase;
-//import org.junit.Ignore;
-//
-//@Ignore
-//public class WordCountPOJOITCase extends JavaProgramTestBase {
-//
-//	protected String textPath;
-//	protected String resultPath;
-//
-//
-//	@Override
-//	protected void preSubmit() throws Exception {
-//		textPath = createTempFile("text.txt", WordCountData.TEXT);
-//		resultPath = getTempDirPath("result");
-//	}
-//
-//	@Override
-//	protected void postSubmit() throws Exception {
-//		compareResultsByLinesInMemory(WordCountData.COUNTS, resultPath);
-//	}
-//
-//	@Override
-//	protected void testProgram() throws Exception {
-//		WordCountPOJO.main(new String[]{textPath, resultPath});
-//	}
-//}
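
The comments in the new test above ("is a pojo", "ignoreStaticField", "ignoreTransientField") spell out what the type extractor needs from a class before field-expression keys such as "complex.someTest" work on it. A compact restatement with a hypothetical class of ours:

    // Hypothetical POJO, restating the rules the WC / ComplexNestedClass comments describe:
    public class SensorReading {
        public int sensorId;              // basic-type field: usable as groupBy("sensorId")
        public java.util.Date time;       // not analyzable field-by-field: treated as a generic type
        public static int ignoredStatic;  // static fields are skipped by the analyzer
        public transient int ignoredTmp;  // transient fields are skipped as well
        public SensorReading() {}         // a public no-argument constructor is required
    }
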
diff --git a/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountSimplePOJOITCase.java b/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountSimplePOJOITCase.java
new file mode 100644
index 0000000000000..7d20597c45dfb
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/exampleJavaPrograms/WordCountSimplePOJOITCase.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.exampleJavaPrograms;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.test.testdata.WordCountData;
+import org.apache.flink.test.util.JavaProgramTestBase;
+import org.apache.flink.util.Collector;
+
+import java.io.Serializable;
+
+
+public class WordCountSimplePOJOITCase extends JavaProgramTestBase implements Serializable {
+	private static final long serialVersionUID = 1L;
+	protected String textPath;
+	protected String resultPath;
+
+
+	@Override
+	protected void preSubmit() throws Exception {
+		textPath = createTempFile("text.txt", WordCountData.TEXT);
+		resultPath = getTempDirPath("result");
+	}
+
+	@Override
+	protected void postSubmit() throws Exception {
+		compareResultsByLinesInMemory(WordCountData.COUNTS, resultPath);
+	}
+
+	@Override
+	protected void testProgram() throws Exception {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<String> text = env.readTextFile(textPath);
+
+		DataSet<WC> counts = text
+				.flatMap(new Tokenizer())
+				.groupBy("word")
+				.reduce(new ReduceFunction<WC>() {
+					private static final long serialVersionUID = 1L;
+
+					public WC reduce(WC value1, WC value2) {
+						return new WC(value1.word, value1.count + value2.count);
+					}
+				});
+
+		counts.writeAsText(resultPath);
+
+		env.execute("WordCount with custom data types example");
+	}
+
+	public static final class Tokenizer implements FlatMapFunction<String, WC> {
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void flatMap(String value, Collector<WC> out) {
+			// normalize and split the line
+			String[] tokens = value.toLowerCase().split("\\W+");
+
+			// emit the pairs
+			for (String token : tokens) {
+				if (token.length() > 0) {
+					out.collect(new WC(token, 1));
+				}
+			}
+		}
+	}
+
+	public static class WC {
+		public WC() {}
+		public WC(String w, int c) {
+			word = w;
+			count = c;
+		}
+		public String word;
+		public int count;
+		@Override
+		public String toString() {
+			return word + " " + count;
+		}
+	}
+
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/EnumTriangleOptITCase.java b/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/EnumTriangleOptITCase.java
index 9701086ca7d6c..4f00b1dc9a0b7 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/EnumTriangleOptITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/EnumTriangleOptITCase.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.test.exampleScalaPrograms;
 
-import org.apache.flink.examples.scala.graph.EnumTrianglesOpt;
+import org.apache.flink.examples.java.graph.EnumTrianglesOpt;
 import org.apache.flink.test.testdata.EnumTriangleData;
 import org.apache.flink.test.util.JavaProgramTestBase;
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/PageRankITCase.java b/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/PageRankITCase.java
index d14c2e192d06e..42c6a8e7d2f7b 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/PageRankITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/exampleScalaPrograms/PageRankITCase.java
@@ -18,19 +18,19 @@
 
 package org.apache.flink.test.exampleScalaPrograms;
 
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.LinkedList;
+
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.examples.scala.graph.PageRankBasic;
+import org.apache.flink.examples.java.graph.PageRankBasic;
 import org.apache.flink.test.testdata.PageRankData;
 import org.apache.flink.test.util.JavaProgramTestBase;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
 
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.LinkedList;
-
 @RunWith(Parameterized.class)
 public class PageRankITCase extends JavaProgramTestBase {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/types/VertexWithAdjacencyListComparator.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/types/VertexWithAdjacencyListComparator.java
index ce31ea23d3116..e607ba73309e4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/types/VertexWithAdjacencyListComparator.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/types/VertexWithAdjacencyListComparator.java
@@ -32,8 +32,6 @@ public final class VertexWithAdjacencyListComparator extends TypeComparator getKey(Tuple3 t) {
 				"5,3,HIJ\n" +
 				"5,3,IJK\n";
 			}
+			case 10: {
+				/*
+				 * CoGroup on two custom type inputs using expression keys
+				 */
+				
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<CustomType> ds = CollectionDataSets.getCustomTypeDataSet(env);
+				DataSet<CustomType> ds2 = CollectionDataSets.getCustomTypeDataSet(env);
+				DataSet<CustomType> coGroupDs = ds.coGroup(ds2).where("myInt").equalTo("myInt").with(new CustomTypeCoGroup());
+				
+				coGroupDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "1,0,test\n" +
+						"2,6,test\n" +
+						"3,24,test\n" +
+						"4,60,test\n" +
+						"5,120,test\n" +
+						"6,210,test\n";
+			}
+			case 11: {
+				/*
+				 * CoGroup on two custom type inputs using expression keys
+				 */
+				
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<CustomType> coGroupDs = ds.coGroup(ds2)
+						.where("nestedPojo.longNumber").equalTo(6).with(new CoGroupFunction<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>, CustomType>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public void coGroup(
+									Iterable<POJO> first,
+									Iterable<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> second,
+									Collector<CustomType> out) throws Exception {
+								for(POJO p : first) {
+									for(Tuple7<Integer, String, Integer, Integer, Long, String, Long> t: second) {
+										Assert.assertTrue(p.nestedPojo.longNumber == t.f6);
+										out.collect(new CustomType(-1, p.nestedPojo.longNumber, "Flink"));
+									}
+								}
+							}
+						});
+				coGroupDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "-1,20000,Flink\n" +
+						"-1,10000,Flink\n" +
+						"-1,30000,Flink\n";
+			}
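
Cases 10 and 11 above pair expression keys across different input shapes; the next two cases mix in a KeySelector on one side. A self-contained sketch of the same idea using types of our own (the where/equalTo/with call pattern mirrors the test code):

    import org.apache.flink.api.common.functions.CoGroupFunction;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.util.Collector;

    public class PojoCoGroupSketch {
        // Hypothetical POJO with a nested POJO field.
        public static class Order { public Nested nested = new Nested(); public Order() {} }
        public static class Nested { public long id; public Nested() {} }

        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            Order o = new Order(); o.nested.id = 1L;
            DataSet<Order> orders = env.fromElements(o);
            DataSet<Tuple2<Long, String>> names = env.fromElements(new Tuple2<Long, String>(1L, "first"));
            // Expression key into the nested POJO on the left, tuple position on
            // the right; both flatten to a single Long key, so they are compatible.
            orders.coGroup(names).where("nested.id").equalTo(0)
                  .with(new CoGroupFunction<Order, Tuple2<Long, String>, String>() {
                      @Override
                      public void coGroup(Iterable<Order> l, Iterable<Tuple2<Long, String>> r, Collector<String> out) {
                          for (Order ignored : l) { for (Tuple2<Long, String> t : r) { out.collect(t.f1); } }
                      }
                  })
                  .print();
            env.execute();
        }
    }
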
+			case 12: {
+				/*
+				 * CoGroup field-selector (expression keys) + key selector function
+				 * The key selector is unnecessary complicated (Tuple1) ;)
+				 */
+				
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<CustomType> coGroupDs = ds.coGroup(ds2)
+						.where(new KeySelector<POJO, Tuple1<Long>>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public Tuple1<Long> getKey(POJO value)
+									throws Exception {
+								return new Tuple1<Long>(value.nestedPojo.longNumber);
+							}
+						}).equalTo(6).with(new CoGroupFunction<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>, CustomType>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public void coGroup(
+									Iterable<POJO> first,
+									Iterable<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> second,
+									Collector<CustomType> out) throws Exception {
+								for(POJO p : first) {
+									for(Tuple7<Integer, String, Integer, Integer, Long, String, Long> t: second) {
+										Assert.assertTrue(p.nestedPojo.longNumber == t.f6);
+										out.collect(new CustomType(-1, p.nestedPojo.longNumber, "Flink"));
+									}
+								}
+							}
+						});
+				coGroupDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "-1,20000,Flink\n" +
+						"-1,10000,Flink\n" +
+						"-1,30000,Flink\n";
+			}
+			case 13: {
+				/*
+				 * CoGroup field-selector (expression keys) + key selector function
+				 * The key selector is simple here
+				 */
+				
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<CustomType> coGroupDs = ds.coGroup(ds2)
+						.where(new KeySelector<POJO, Long>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public Long getKey(POJO value)
+									throws Exception {
+								return value.nestedPojo.longNumber;
+							}
+						}).equalTo(6).with(new CoGroupFunction<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>, CustomType>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public void coGroup(
+									Iterable<POJO> first,
+									Iterable<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> second,
+									Collector<CustomType> out) throws Exception {
+								for(POJO p : first) {
+									for(Tuple7<Integer, String, Integer, Integer, Long, String, Long> t: second) {
+										Assert.assertTrue(p.nestedPojo.longNumber == t.f6);
+										out.collect(new CustomType(-1, p.nestedPojo.longNumber, "Flink"));
+									}
+								}
+							}
+						});
+				coGroupDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "-1,20000,Flink\n" +
+						"-1,10000,Flink\n" +
+						"-1,30000,Flink\n";
+			}
 			default:
 				throw new IllegalArgumentException("Invalid program id");
 		}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java
index 00b68fc1225c2..fce56d1594b94 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java
@@ -35,7 +35,12 @@ import org.apache.flink.compiler.PactCompiler;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CrazyNested;
 import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CustomType;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.FromTuple;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.FromTupleWithCTor;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.POJO;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.PojoContainingTupleAndWritable;
 import org.apache.flink.test.util.JavaProgramTestBase;
 import org.apache.flink.util.Collector;
 import org.junit.runner.RunWith;
@@ -48,7 +53,7 @@
 @RunWith(Parameterized.class)
 public class GroupReduceITCase extends JavaProgramTestBase {
 
-	private static int NUM_PROGRAMS = 15;
+	private static int NUM_PROGRAMS = 19;
 
 	private int curProgId = config.getInteger("ProgramId", -1);
 	private String resultPath;
@@ -406,7 +411,6 @@ public Integer getKey(CustomType in) {
 			/*
 			 * check correctness of groupReduce with descending group sort
 			 */
-
 			final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 			env.setDegreeOfParallelism(1);
@@ -484,6 +488,117 @@ public Tuple2 getKey(Tuple5 t)
 				"16,6,Comment#10\n";
 			}
+			case 16: {
+				/*
+				 * Deep nesting test
+				 * + null value in pojo
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<CrazyNested> ds = CollectionDataSets.getCrazyNestedDataSet(env);
+				DataSet<Tuple2<String, Integer>> reduceDs = ds.groupBy("nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal")
+						.reduceGroup(new GroupReduceFunction<CrazyNested, Tuple2<String, Integer>>() {
+							private static final long serialVersionUID = 1L;
+
+							@Override
+							public void reduce(Iterable<CrazyNested> values,
+									Collector<Tuple2<String, Integer>> out)
+									throws Exception {
+								int c = 0; String n = null;
+								for(CrazyNested v : values) {
+									c++; // haha
+									n = v.nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal;
+								}
+								out.collect(new Tuple2<String, Integer>(n,c));
+							}});
+				
+				reduceDs.writeAsCsv(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "aa,1\nbb,2\ncc,3\n";
+			}
+			case 17: {
+				/*
+				 * Test Pojo extending from tuple WITH custom fields
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<FromTuple> ds = CollectionDataSets.getPojoExtendingFromTuple(env);
+				DataSet<Integer> reduceDs = ds.groupBy("special", "f2")
+						.reduceGroup(new GroupReduceFunction<FromTuple, Integer>() {
+							private static final long serialVersionUID = 1L;
+							@Override
+							public void reduce(Iterable<FromTuple> values,
+									Collector<Integer> out)
+									throws Exception {
+								int c = 0;
+								for(FromTuple v : values) {
+									c++;
+								}
+								out.collect(c);
+							}});
+				
+				reduceDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "3\n2\n";
+			}
+			case 18: {
+				/*
+				 * Test Pojo containing a Writable and Tuples
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<PojoContainingTupleAndWritable> ds = CollectionDataSets.getPojoContainingTupleAndWritable(env);
+				DataSet<Integer> reduceDs = ds.groupBy("hadoopFan", "theTuple.*") // full tuple selection
+						.reduceGroup(new GroupReduceFunction<PojoContainingTupleAndWritable, Integer>() {
+							private static final long serialVersionUID = 1L;
+							@Override
+							public void reduce(Iterable<PojoContainingTupleAndWritable> values,
+									Collector<Integer> out)
+									throws Exception {
+								int c = 0;
+								for(PojoContainingTupleAndWritable v : values) {
+									c++;
+								}
+								out.collect(c);
+							}});
+				
+				reduceDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "1\n5\n";
+			}
+			case 19: {
+				/*
+				 * Test Tuple containing pojos and regular fields
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<Tuple3<Integer, CrazyNested, POJO>> ds = CollectionDataSets.getTupleContainingPojos(env);
+				DataSet<Integer> reduceDs = ds.groupBy("f0", "f1.*") // nested full tuple selection
+						.reduceGroup(new GroupReduceFunction<Tuple3<Integer, CrazyNested, POJO>, Integer>() {
+							private static final long serialVersionUID = 1L;
+							@Override
+							public void reduce(Iterable<Tuple3<Integer, CrazyNested, POJO>> values,
+									Collector<Integer> out)
+									throws Exception {
+								int c = 0;
+								for(Tuple3<Integer, CrazyNested, POJO> v : values) {
+									c++;
+								}
+								out.collect(c);
+							}});
+				
+				reduceDs.writeAsText(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "3\n1\n";
+			}
 			default: {
 				throw new IllegalArgumentException("Invalid program id");
 			}
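
Cases 16-19 above exercise grouping keys that descend through nested types, including the "theTuple.*" wildcard that expands to every flat field underneath the selected member. A minimal sketch with our own two-level type; first()/print() are again just plan terminators:

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;

    public class NestedKeySketch {
        public static class Inner { public long a; public long b; public Inner() {} }
        public static class Outer {
            public int id;
            public Inner inner = new Inner();
            public Outer() {}
            public Outer(int id, long a, long b) { this.id = id; inner.a = a; inner.b = b; }
        }

        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Outer> data = env.fromElements(new Outer(1, 10L, 20L), new Outer(1, 10L, 20L));
            // Dotted expressions descend into nested POJOs; by analogy with
            // "theTuple.*" above, "inner.*" would select both a and b at once.
            data.groupBy("id", "inner.a").first(1).print();
            env.execute();
        }
    }
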
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
index 433a076844107..e8d8be951d107 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
@@ -31,9 +31,11 @@ import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.tuple.Tuple5;
 import org.apache.flink.api.java.tuple.Tuple6;
+import org.apache.flink.api.java.tuple.Tuple7;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
 import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CustomType;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.POJO;
 import org.apache.flink.test.util.JavaProgramTestBase;
 import org.apache.flink.util.Collector;
 import org.junit.runner.RunWith;
@@ -46,7 +48,7 @@
 @RunWith(Parameterized.class)
 public class JoinITCase extends JavaProgramTestBase {
 
-	private static int NUM_PROGRAMS = 14;
+	private static int NUM_PROGRAMS = 21;
 
 	private int curProgId = config.getInteger("ProgramId", -1);
 	private String resultPath;
@@ -493,8 +495,175 @@ public Tuple2 getKey(Tuple5 t)
 				"I am fine.,HIJ\n" +
 				"I am fine.,IJK\n";
 			}
+			/**
+			 * Joins with POJOs
+			 */
+			case 15: {
+				/*
+				 * Join nested pojo against tuple (selected using a string)
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds1 = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<Tuple2<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>>> joinDs =
+						ds1.join(ds2).where("nestedPojo.longNumber").equalTo("f6");
+				
+				joinDs.writeAsCsv(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One,10000)\n" +
+						"2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two,20000)\n" +
+						"3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n";
+			}
+			
+			case 16: {
+				/*
+				 * Join nested pojo against tuple (selected as an integer)
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds1 = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<Tuple2<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>>> joinDs =
+						ds1.join(ds2).where("nestedPojo.longNumber").equalTo(6); // <--- difference!
+				
+				joinDs.writeAsCsv(resultPath);
+				env.execute();
+				
+				// return expected result
+				return "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One,10000)\n" +
+						"2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two,20000)\n" +
+						"3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n";
+			}
+			case 17: {
+				/*
+				 * selecting multiple fields using expression language
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds1 = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<Tuple2<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>>> joinDs =
+						ds1.join(ds2).where("nestedPojo.longNumber", "number", "str").equalTo("f6","f0","f1");
+				
+				joinDs.writeAsCsv(resultPath);
+				env.setDegreeOfParallelism(1);
+				env.execute();
+				
+				// return expected result
+				return "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One,10000)\n" +
+						"2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two,20000)\n" +
+						"3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n";
+				
+			}
+			case 18: {
+				/*
+				 * nested into tuple
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds1 = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<Tuple2<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>>> joinDs =
+						ds1.join(ds2).where("nestedPojo.longNumber", "number","nestedTupleWithCustom.f0").equalTo("f6","f0","f2");
+				
+				joinDs.writeAsCsv(resultPath);
+				env.setDegreeOfParallelism(1);
+				env.execute();
+				
+				// return expected result
+				return "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One,10000)\n" +
+						"2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two,20000)\n" +
+						"3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n";
+				
+			}
+			case 19: {
+				/*
+				 * nested into tuple into pojo
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<POJO> ds1 = CollectionDataSets.getSmallPojoDataSet(env);
+				DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> ds2 = CollectionDataSets.getSmallTuplebasedPojoMatchingDataSet(env);
+				DataSet<Tuple2<POJO, Tuple7<Integer, String, Integer, Integer, Long, String, Long>>> joinDs =
+						ds1.join(ds2).where("nestedTupleWithCustom.f0","nestedTupleWithCustom.f1.myInt","nestedTupleWithCustom.f1.myLong").equalTo("f2","f3","f4");
+				
+				joinDs.writeAsCsv(resultPath);
+				env.setDegreeOfParallelism(1);
+				env.execute();
+				
+				// return expected result
+				return "1 First (10,100,1000,One) 10000,(1,First,10,100,1000,One,10000)\n" +
+						"2 Second (20,200,2000,Two) 20000,(2,Second,20,200,2000,Two,20000)\n" +
+						"3 Third (30,300,3000,Three) 30000,(3,Third,30,300,3000,Three,30000)\n";
+				
+			}
+			case 20: {
+				/*
+				 * Non-POJO test to verify that full-tuple keys are working.
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<Tuple2<Tuple2<Integer, Integer>, String>> ds1 = CollectionDataSets.getSmallNestedTupleDataSet(env);
+				DataSet<Tuple2<Tuple2<Integer, Integer>, String>> ds2 = CollectionDataSets.getSmallNestedTupleDataSet(env);
+				DataSet<Tuple2<Tuple2<Tuple2<Integer, Integer>, String>, Tuple2<Tuple2<Integer, Integer>, String>>> joinDs =
+						ds1.join(ds2).where(0).equalTo("f0.f0", "f0.f1"); // key is now Tuple2<Integer, Integer>
+				
+				joinDs.writeAsCsv(resultPath);
+				env.setDegreeOfParallelism(1);
+				env.execute();
+				
+				// return expected result
+				return "((1,1),one),((1,1),one)\n" +
+						"((2,2),two),((2,2),two)\n" +
+						"((3,3),three),((3,3),three)\n";
+				
+			}
+			case 21: {
+				/*
+				 * Non-POJO test to verify "nested" tuple-element selection.
+				 */
+				final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+				
+				DataSet<Tuple2<Tuple2<Integer, Integer>, String>> ds1 = CollectionDataSets.getSmallNestedTupleDataSet(env);
+				DataSet<Tuple2<Tuple2<Integer, Integer>, String>> ds2 = CollectionDataSets.getSmallNestedTupleDataSet(env);
+				DataSet<Tuple2<Tuple2<Tuple2<Integer, Integer>, String>, Tuple2<Tuple2<Integer, Integer>, String>>> joinDs =
+						ds1.join(ds2).where("f0.f0").equalTo("f0.f0"); // key is now Integer from Tuple2
+				
+				joinDs.writeAsCsv(resultPath);
+				env.setDegreeOfParallelism(1);
+				env.execute();
+				
+				// return expected result
+				return "((1,1),one),((1,1),one)\n" +
+						"((2,2),two),((2,2),two)\n" +
+						"((3,3),three),((3,3),three)\n";
+				
+			}
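
Cases 20 and 21 above show that a whole nested tuple can serve as a key, either by position (where(0)) or by spelling out its flat fields (equalTo("f0.f0", "f0.f1")). A sketch joining a small set against itself, with data values of our own:

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class NestedTupleKeySketch {
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Tuple2<Tuple2<Integer, Integer>, String>> pairs = env.fromElements(
                    new Tuple2<Tuple2<Integer, Integer>, String>(new Tuple2<Integer, Integer>(1, 1), "one"),
                    new Tuple2<Tuple2<Integer, Integer>, String>(new Tuple2<Integer, Integer>(2, 2), "two"));
            // Position 0 selects the nested Tuple2 as a whole; the expression side
            // lists the same key as its two flat fields, so the keys are compatible.
            pairs.join(pairs).where(0).equalTo("f0.f0", "f0.f1").print();
            env.execute();
        }
    }
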
ExecutionEnvironment.getExecutionEnvironment(); + + DataSet> ds = CollectionDataSets.get5TupleDataSet(env); + DataSet> reduceDs = ds. + groupBy("f4","f0").reduce(new Tuple5Reduce()); + + reduceDs.writeAsCsv(resultPath); + env.execute(); + + // return expected result + return "1,1,0,Hallo,1\n" + + "2,3,2,Hallo Welt wie,1\n" + + "2,2,1,Hallo Welt,2\n" + + "3,9,0,P-),2\n" + + "3,6,5,BCD,3\n" + + "4,17,0,P-),1\n" + + "4,17,0,P-),2\n" + + "5,11,10,GHI,1\n" + + "5,29,0,P-),2\n" + + "5,25,0,P-),3\n"; + } + default: throw new IllegalArgumentException("Invalid program id"); } diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java index 3ca8f31bf535f..b657545fd497e 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java @@ -24,11 +24,14 @@ import java.util.List; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.api.java.tuple.Tuple7; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.hadoop.io.IntWritable; /** * ####################################################################################################### @@ -136,6 +139,22 @@ public static DataSet> getSmall5Tup return env.fromCollection(data, type); } + public static DataSet, String>> getSmallNestedTupleDataSet(ExecutionEnvironment env) { + + List, String>> data = new ArrayList, String>>(); + data.add(new Tuple2, String>(new Tuple2(1,1), "one")); + data.add(new Tuple2, String>(new Tuple2(2,2), "two")); + data.add(new Tuple2, String>(new Tuple2(3,3), "three")); + + TupleTypeInfo, String>> type = new + TupleTypeInfo, String>>( + new TupleTypeInfo>(BasicTypeInfo.INT_TYPE_INFO,BasicTypeInfo.INT_TYPE_INFO), + BasicTypeInfo.STRING_TYPE_INFO + ); + + return env.fromCollection(data, type); + } + public static DataSet getStringDataSet(ExecutionEnvironment env) { List data = new ArrayList(); @@ -241,7 +260,150 @@ public CustomType(int i, long l, String s) { public String toString() { return myInt+","+myLong+","+myString; } - + } + + public static DataSet> getSmallTuplebasedPojoMatchingDataSet(ExecutionEnvironment env) { + List> data = new ArrayList>(); + data.add(new Tuple7(1, "First",10, 100, 1000L, "One", 10000L)); + data.add(new Tuple7(2, "Second",20, 200, 2000L, "Two", 20000L)); + data.add(new Tuple7(3, "Third",30, 300, 3000L, "Three", 30000L)); + return env.fromCollection(data); + } + + public static DataSet getSmallPojoDataSet(ExecutionEnvironment env) { + List data = new ArrayList(); + data.add(new POJO(1, "First",10, 100, 1000L, "One", 10000L)); + data.add(new POJO(2, "Second",20, 200, 2000L, "Two", 20000L)); + data.add(new POJO(3, "Third",30, 300, 3000L, "Three", 30000L)); + return env.fromCollection(data); + } + + public static class POJO { + public int number; + public String str; + public Tuple2 nestedTupleWithCustom; + public NestedPojo nestedPojo; + public transient Long ignoreMe; + public POJO(int i0, String s0, + int i1, int i2, long l0, String s1, + long l1) { + this.number = i0; + this.str = s0; + this.nestedTupleWithCustom = new Tuple2(i1, new 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java
index 3ca8f31bf535f..b657545fd497e 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/util/CollectionDataSets.java
@@ -24,11 +24,14 @@
 import java.util.List;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.tuple.Tuple5;
+import org.apache.flink.api.java.tuple.Tuple7;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.hadoop.io.IntWritable;
 
 /**
  * #######################################################################################################
@@ -136,6 +139,22 @@ public static DataSet<Tuple5<Integer, Long, Integer, String, Long>> getSmall5Tup
 		return env.fromCollection(data, type);
 	}
+	
+	public static DataSet<Tuple2<Tuple2<Integer, Integer>, String>> getSmallNestedTupleDataSet(ExecutionEnvironment env) {
+		
+		List<Tuple2<Tuple2<Integer, Integer>, String>> data = new ArrayList<Tuple2<Tuple2<Integer, Integer>, String>>();
+		data.add(new Tuple2<Tuple2<Integer, Integer>, String>(new Tuple2<Integer, Integer>(1,1), "one"));
+		data.add(new Tuple2<Tuple2<Integer, Integer>, String>(new Tuple2<Integer, Integer>(2,2), "two"));
+		data.add(new Tuple2<Tuple2<Integer, Integer>, String>(new Tuple2<Integer, Integer>(3,3), "three"));
+		
+		TupleTypeInfo<Tuple2<Tuple2<Integer, Integer>, String>> type = new
+				TupleTypeInfo<Tuple2<Tuple2<Integer, Integer>, String>>(
+				new TupleTypeInfo<Tuple2<Integer, Integer>>(BasicTypeInfo.INT_TYPE_INFO,BasicTypeInfo.INT_TYPE_INFO),
+				BasicTypeInfo.STRING_TYPE_INFO
+				);
+		
+		return env.fromCollection(data, type);
+	}
 	
 	public static DataSet<String> getStringDataSet(ExecutionEnvironment env) {
 		
 		List<String> data = new ArrayList<String>();
@@ -241,7 +260,150 @@ public CustomType(int i, long l, String s) {
 		public String toString() {
 			return myInt+","+myLong+","+myString;
 		}
-		
+	}
+	
+	public static DataSet<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> getSmallTuplebasedPojoMatchingDataSet(ExecutionEnvironment env) {
+		List<Tuple7<Integer, String, Integer, Integer, Long, String, Long>> data = new ArrayList<Tuple7<Integer, String, Integer, Integer, Long, String, Long>>();
+		data.add(new Tuple7<Integer, String, Integer, Integer, Long, String, Long>(1, "First",10, 100, 1000L, "One", 10000L));
+		data.add(new Tuple7<Integer, String, Integer, Integer, Long, String, Long>(2, "Second",20, 200, 2000L, "Two", 20000L));
+		data.add(new Tuple7<Integer, String, Integer, Integer, Long, String, Long>(3, "Third",30, 300, 3000L, "Three", 30000L));
+		return env.fromCollection(data);
+	}
+	
+	public static DataSet<POJO> getSmallPojoDataSet(ExecutionEnvironment env) {
+		List<POJO> data = new ArrayList<POJO>();
+		data.add(new POJO(1, "First",10, 100, 1000L, "One", 10000L));
+		data.add(new POJO(2, "Second",20, 200, 2000L, "Two", 20000L));
+		data.add(new POJO(3, "Third",30, 300, 3000L, "Three", 30000L));
+		return env.fromCollection(data);
+	}
+	
+	public static class POJO {
+		public int number;
+		public String str;
+		public Tuple2<Integer, CustomType> nestedTupleWithCustom;
+		public NestedPojo nestedPojo;
+		public transient Long ignoreMe;
+		public POJO(int i0, String s0,
+				int i1, int i2, long l0, String s1,
+				long l1) {
+			this.number = i0;
+			this.str = s0;
+			this.nestedTupleWithCustom = new Tuple2<Integer, CustomType>(i1, new CustomType(i2, l0, s1));
+			this.nestedPojo = new NestedPojo();
+			this.nestedPojo.longNumber = l1;
+		}
+		public POJO() {}
+		@Override
+		public String toString() {
+			return number+" "+str+" "+nestedTupleWithCustom+" "+nestedPojo.longNumber;
+		}
+	}
+	
+	public static class NestedPojo {
+		public static Object ignoreMe;
+		public long longNumber;
+		public NestedPojo() {}
+	}
+	
+	public static DataSet<CrazyNested> getCrazyNestedDataSet(ExecutionEnvironment env) {
+		List<CrazyNested> data = new ArrayList<CrazyNested>();
+		data.add(new CrazyNested("aa"));
+		data.add(new CrazyNested("bb"));
+		data.add(new CrazyNested("bb"));
+		data.add(new CrazyNested("cc"));
+		data.add(new CrazyNested("cc"));
+		data.add(new CrazyNested("cc"));
+		return env.fromCollection(data);
+	}
+	
+	public static class CrazyNested {
+		public CrazyNestedL1 nest_Lvl1;
+		public Long something; // test proper null-value handling
+		public CrazyNested() {}
+		public CrazyNested(String set, String second, long s) { // additional CTor to set all fields to non-null values
+			this(set);
+			something = s;
+			nest_Lvl1.a = second;
+		}
+		public CrazyNested(String set) {
+			nest_Lvl1 = new CrazyNestedL1();
+			nest_Lvl1.nest_Lvl2 = new CrazyNestedL2();
+			nest_Lvl1.nest_Lvl2.nest_Lvl3 = new CrazyNestedL3();
+			nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4 = new CrazyNestedL4();
+			nest_Lvl1.nest_Lvl2.nest_Lvl3.nest_Lvl4.f1nal = set;
+		}
+	}
+	public static class CrazyNestedL1 {
+		public String a;
+		public int b;
+		public CrazyNestedL2 nest_Lvl2;
+	}
+	public static class CrazyNestedL2 {
+		public CrazyNestedL3 nest_Lvl3;
+	}
+	public static class CrazyNestedL3 {
+		public CrazyNestedL4 nest_Lvl4;
+	}
+	public static class CrazyNestedL4 {
+		public String f1nal;
+	}
+	
+	// Copied from TypeExtractorTest
+	public static class FromTuple extends Tuple3<String, String, Long> {
+		private static final long serialVersionUID = 1L;
+		public int special;
+	}
+	
+	public static class FromTupleWithCTor extends FromTuple {
+		public FromTupleWithCTor() {}
+		public FromTupleWithCTor(int special, long tupleField) {
+			this.special = special;
+			this.setField(tupleField, 2);
+		}
+	}
+	public static DataSet<FromTupleWithCTor> getPojoExtendingFromTuple(ExecutionEnvironment env) {
+		List<FromTupleWithCTor> data = new ArrayList<FromTupleWithCTor>();
+		data.add(new FromTupleWithCTor(1, 10L)); // 3x
+		data.add(new FromTupleWithCTor(1, 10L));
+		data.add(new FromTupleWithCTor(1, 10L));
+		data.add(new FromTupleWithCTor(2, 20L)); // 2x
+		data.add(new FromTupleWithCTor(2, 20L));
+		return env.fromCollection(data);
+	}
+	
+	public static class PojoContainingTupleAndWritable {
+		public int someInt;
+		public String someString;
+		public IntWritable hadoopFan;
+		public Tuple2<Long, Long> theTuple;
+		public PojoContainingTupleAndWritable() {}
+		public PojoContainingTupleAndWritable(int i, long l1, long l2) {
+			hadoopFan = new IntWritable(i);
+			someInt = i;
+			theTuple = new Tuple2<Long, Long>(l1, l2);
+		}
+	}
+	
+	public static DataSet<PojoContainingTupleAndWritable> getPojoContainingTupleAndWritable(ExecutionEnvironment env) {
+		List<PojoContainingTupleAndWritable> data = new ArrayList<PojoContainingTupleAndWritable>();
+		data.add(new PojoContainingTupleAndWritable(1, 10L, 100L)); // 1x
+		data.add(new PojoContainingTupleAndWritable(2, 20L, 200L)); // 5x
+		data.add(new PojoContainingTupleAndWritable(2, 20L, 200L));
+		data.add(new PojoContainingTupleAndWritable(2, 20L, 200L));
+		data.add(new PojoContainingTupleAndWritable(2, 20L, 200L));
+		data.add(new PojoContainingTupleAndWritable(2, 20L, 200L));
+		return env.fromCollection(data);
+	}
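All of the helper types above follow one pattern: public fields plus a public no-argument constructor, which is what lets Flink's type extractor analyze them as composite POJO types rather than opaque generic types; the static and transient "ignoreMe" members are included deliberately because field extraction is expected to skip them. A hypothetical minimal example of those conventions (not a class from the patch):

    // Assumed POJO rules: public no-arg constructor and accessible fields;
    // static and transient members are ignored by field extraction.
    public class WordCountPojo {
        public String word;           // addressable in key expressions as "word"
        public int frequency;         // addressable as "frequency"
        public transient int cache;   // transient: skipped
        public static int instances;  // static: skipped

        public WordCountPojo() {}     // required default constructor

        public WordCountPojo(String word, int frequency) {
            this.word = word;
            this.frequency = frequency;
        }
    }

With such a type, dataSet.groupBy("word") plays the same role that groupBy("f0") plays for a tuple.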
1000L, "One", 10000L) )); // 3x + data.add(new Tuple3(1, new CrazyNested("one", "uno", 1L), new POJO(1, "First",10, 100, 1000L, "One", 10000L) )); + data.add(new Tuple3(1, new CrazyNested("one", "uno", 1L), new POJO(1, "First",10, 100, 1000L, "One", 10000L) )); + // POJO is not initialized according to the first two fields. + data.add(new Tuple3(2, new CrazyNested("two", "duo", 2L), new POJO(1, "First",10, 100, 1000L, "One", 10000L) )); // 1x + return env.fromCollection(data); } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/recordJobs/kmeans/udfs/ComputeDistance.java b/flink-tests/src/test/java/org/apache/flink/test/recordJobs/kmeans/udfs/ComputeDistance.java index 6dedcc146f194..16267f65512ec 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/recordJobs/kmeans/udfs/ComputeDistance.java +++ b/flink-tests/src/test/java/org/apache/flink/test/recordJobs/kmeans/udfs/ComputeDistance.java @@ -25,7 +25,6 @@ import org.apache.flink.types.DoubleValue; import org.apache.flink.types.IntValue; import org.apache.flink.types.Record; -import org.apache.flink.util.Collector; /** * Cross PACT computes the distance of all data points to all cluster diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/testjar/KMeansForTest.java b/flink-tests/src/test/java/org/apache/flink/test/util/testjar/KMeansForTest.java index 1925a94a86fe8..0821af5878f5e 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/util/testjar/KMeansForTest.java +++ b/flink-tests/src/test/java/org/apache/flink/test/util/testjar/KMeansForTest.java @@ -24,17 +24,20 @@ import org.apache.flink.api.common.Plan; import org.apache.flink.api.common.Program; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichReduceFunction; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; - +import org.apache.flink.test.localDistributed.PackagedProgramEndToEndITCase; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.IterativeDataSet; +/** + * This class belongs to the @see {@link PackagedProgramEndToEndITCase} test + * + */ @SuppressWarnings("serial") public class KMeansForTest implements Program { @@ -80,12 +83,8 @@ public Plan getPlan(String... args) { .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids") // count and sum point coordinates for each centroid .map(new CountAppender()) - .groupBy(new KeySelector() { - @Override - public Integer getKey(DummyTuple3IntPointLong value) throws Exception { - return value.f0; - } - }).reduce(new CentroidAccumulator()) + // !test if key expressions are working! 
+ .groupBy("field0").reduce(new CentroidAccumulator()) // compute new centroids from point counts and coordinate sums .map(new CentroidAverager()); @@ -228,16 +227,16 @@ public Tuple2 map(Point p) throws Exception { // Use this so that we can check whether POJOs and the POJO comparator also work public static final class DummyTuple3IntPointLong { - public Integer f0; - public Point f1; - public Long f2; + public Integer field0; + public Point field1; + public Long field2; public DummyTuple3IntPointLong() {} DummyTuple3IntPointLong(Integer f0, Point f1, Long f2) { - this.f0 = f0; - this.f1 = f1; - this.f2 = f2; + this.field0 = f0; + this.field1 = f1; + this.field2 = f2; } } @@ -255,7 +254,7 @@ public static final class CentroidAccumulator extends RichReduceFunction