diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java index dbebeb4dde9fe..7be5650ef4821 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java @@ -32,6 +32,7 @@ import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.GenericPairComparator; @@ -273,12 +274,15 @@ protected List executeOnCollections(List input1, List input2, Run return result; } + @SuppressWarnings("unchecked") private TypeComparator getTypeComparator(ExecutionConfig executionConfig, TypeInformation inputType, int[] inputKeys, boolean[] inputSortDirections) { - if (!(inputType instanceof CompositeType)) { - throw new InvalidProgramException("Input types of coGroup must be composite types."); + if (inputType instanceof CompositeType) { + return ((CompositeType) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); + } else if (inputType instanceof AtomicType) { + return ((AtomicType) inputType).createComparator(inputSortDirections[0], executionConfig); } - return ((CompositeType) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); + throw new InvalidProgramException("Input type of coGroup must be one of composite types or atomic types."); } private static class CoGroupSortListIterator { diff --git 
a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java index 27fbc1c3eca42..c7ba92b70b0b4 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java @@ -31,6 +31,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation; import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -77,6 +78,16 @@ public Ordering getGroupOrder() { return this.groupOrder; } + private TypeComparator getTypeComparator(TypeInformation typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) { + if (typeInfo instanceof CompositeType) { + return ((CompositeType) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig); + } else if (typeInfo instanceof AtomicType) { + return ((AtomicType) typeInfo).createComparator(sortOrderings[0], executionConfig); + } + + throw new InvalidProgramException("Input type of GroupCombine must be one of composite types or atomic types."); + } + // -------------------------------------------------------------------------------------------- @Override @@ -87,11 +98,6 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, TypeInformation inputType = operatorInfo.getInputType(); int[] keyColumns = getKeyColumns(0); - - if (!(inputType instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) { - throw new InvalidProgramException("Grouping or 
group-sorting is only possible on composite type."); - } - int[] sortColumns = keyColumns; boolean[] sortOrderings = new boolean[sortColumns.length]; @@ -100,19 +106,17 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections()); } - if (inputType instanceof CompositeType) { - if(sortColumns.length == 0) { // => all reduce. No comparator - Preconditions.checkArgument(sortOrderings.length == 0); - } else { - final TypeComparator sortComparator = ((CompositeType) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig); - - Collections.sort(inputData, new Comparator() { - @Override - public int compare(IN o1, IN o2) { - return sortComparator.compare(o1, o2); - } - }); - } + if(sortColumns.length == 0) { // => all reduce. No comparator + Preconditions.checkArgument(sortOrderings.length == 0); + } else { + final TypeComparator sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig); + + Collections.sort(inputData, new Comparator() { + @Override + public int compare(IN o1, IN o2) { + return sortComparator.compare(o1, o2); + } + }); } FunctionUtils.setFunctionRuntimeContext(function, ctx); @@ -133,7 +137,7 @@ public int compare(IN o1, IN o2) { } else { final TypeSerializer inputSerializer = inputType.createSerializer(executionConfig); boolean[] keyOrderings = new boolean[keyColumns.length]; - final TypeComparator comparator = ((CompositeType) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig); + final TypeComparator comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig); ListKeyGroupedIterator keyedIterator = new ListKeyGroupedIterator(inputData, inputSerializer, comparator); diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java 
b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java index 57f07f38be91a..3056fe7b00dfa 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java @@ -34,6 +34,7 @@ import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -148,6 +149,16 @@ public void setCustomPartitioner(Partitioner customPartitioner) { public Partitioner getCustomPartitioner() { return customPartitioner; } + + private TypeComparator getTypeComparator(TypeInformation typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) { + if (typeInfo instanceof CompositeType) { + return ((CompositeType) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig); + } else if (typeInfo instanceof AtomicType) { + return ((AtomicType) typeInfo).createComparator(sortOrderings[0], executionConfig); + } + + throw new InvalidProgramException("Input type of GroupReduce must be one of composite types or atomic types."); + } // -------------------------------------------------------------------------------------------- @@ -159,11 +170,6 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, TypeInformation inputType = operatorInfo.getInputType(); int[] keyColumns = getKeyColumns(0); - - if (!(inputType instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) { - throw new InvalidProgramException("Grouping or group-sorting is only possible on 
composite type."); - } - int[] sortColumns = keyColumns; boolean[] sortOrderings = new boolean[sortColumns.length]; @@ -172,19 +178,16 @@ protected List executeOnCollections(List inputData, RuntimeContext ctx, sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections()); } - if (inputType instanceof CompositeType) { - if(sortColumns.length == 0) { // => all reduce. No comparator - Preconditions.checkArgument(sortOrderings.length == 0); - } else { - final TypeComparator sortComparator = ((CompositeType) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig); - - Collections.sort(inputData, new Comparator() { - @Override - public int compare(IN o1, IN o2) { - return sortComparator.compare(o1, o2); - } - }); - } + if(sortColumns.length == 0) { // => all reduce. No comparator + Preconditions.checkArgument(sortOrderings.length == 0); + } else { + final TypeComparator sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig); + Collections.sort(inputData, new Comparator() { + @Override + public int compare(IN o1, IN o2) { + return sortComparator.compare(o1, o2); + } + }); } FunctionUtils.setFunctionRuntimeContext(function, ctx); @@ -205,7 +208,7 @@ public int compare(IN o1, IN o2) { } else { final TypeSerializer inputSerializer = inputType.createSerializer(executionConfig); boolean[] keyOrderings = new boolean[keyColumns.length]; - final TypeComparator comparator = ((CompositeType) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig); + final TypeComparator comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig); ListKeyGroupedIterator keyedIterator = new ListKeyGroupedIterator(inputData, inputSerializer, comparator); diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java index a2cde0794c2b9..ee233e8e48ffd 100644 --- 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java @@ -274,24 +274,33 @@ public static List removeNullElementsFromList(List in) { * Create ExpressionKeys from String-expressions */ public ExpressionKeys(String[] expressionsIn, TypeInformation type) { - if(!(type instanceof CompositeType)) { - throw new IllegalArgumentException("Key expressions are only supported on POJO types and Tuples. " - + "A type is considered a POJO if all its fields are public, or have both getters and setters defined"); - } - CompositeType cType = (CompositeType) type; - - String[] expressions = removeDuplicates(expressionsIn); - if(expressionsIn.length != expressions.length) { - LOG.warn("The key expressions contained duplicates. They are now unique"); - } - // extract the keys on their flat position - keyFields = new ArrayList(expressions.length); - for (int i = 0; i < expressions.length; i++) { - List keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check - if(keys.size() == 0) { - throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); + Preconditions.checkNotNull(expressionsIn, "Field expression cannot be null."); + + if (type instanceof AtomicType) { + if (!type.isKeyType()) { + throw new InvalidProgramException("This type (" + type + ") cannot be used as key."); + } else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) { + throw new InvalidProgramException("Field expression for atomic type must be equal to '*' or '_'."); + } + + keyFields = new ArrayList(1); + keyFields.add(new FlatFieldDescriptor(0, type)); + } else { + CompositeType cType = (CompositeType) type; + + String[] expressions = removeDuplicates(expressionsIn); + if(expressionsIn.length != expressions.length) { + 
LOG.warn("The key expressions contained duplicates. They are now unique"); + } + // extract the keys on their flat position + keyFields = new ArrayList(expressions.length); + for (int i = 0; i < expressions.length; i++) { + List keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check + if(keys.size() == 0) { + throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); + } + keyFields.addAll(keys); } - keyFields.addAll(keys); } } @@ -410,7 +419,7 @@ private static final int[] rangeCheckFields(int[] fields, int maxAllowedField) { return Arrays.copyOfRange(fields, 0, k+1); } } - + public static class IncompatibleKeysException extends Exception { private static final long serialVersionUID = 1L; public static final String SIZE_MISMATCH_MESSAGE = "The number of specified keys is different."; diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java index 60754e661d2e7..f32f6a953220f 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java @@ -18,26 +18,26 @@ package org.apache.flink.api.java.operator; -import java.util.ArrayList; -import java.util.List; - +import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.operators.SemanticProperties; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.functions.FunctionAnnotation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType; import 
org.apache.flink.api.java.operators.CoGroupOperator; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.util.Collector; -import org.junit.Assert; -import org.apache.flink.api.common.InvalidProgramException; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.util.Collector; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType; + +import java.util.ArrayList; +import java.util.List; import static org.junit.Assert.assertTrue; @@ -181,6 +181,78 @@ public void testCoGroupKeyExpressions4() { // should not work, cogroup key non-existent ds1.coGroup(ds2).where("myNonExistent").equalTo("myInt"); } + + @Test + public void testCoGroupKeyAtomicExpression1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("*"); + } + + @Test + public void testCoGroupKeyAtomicExpression2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 1); + DataSet ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("*").equalTo("myInt"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 1); + DataSet ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void 
testCoGroupKeyAtomicInvalidExpression2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 1); + DataSet ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression4() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromCollection(customTypeData); + DataSet ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression5() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromElements(new ArrayList()); + DataSet ds2 = env.fromElements(0, 0, 0); + + ds1.coGroup(ds2).where("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression6() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 0); + DataSet> ds2 = env.fromElements(new ArrayList()); + + ds1.coGroup(ds2).where("*").equalTo("*"); + } @Test public void testCoGroupKeyExpressions1Nested() { diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java index 314695fdb83d4..b3922b3974439 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java +++ 
b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java @@ -143,7 +143,7 @@ public void testGroupByKeyExpressions1() { } } - @Test(expected = IllegalArgumentException.class) + @Test(expected = InvalidProgramException.class) public void testGroupByKeyExpressions2() { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @@ -551,6 +551,38 @@ public Long[] getKey(Tuple4 value) throws Exc }, Order.ASCENDING); } + @Test + public void testGroupingAtomicType() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet dataSet = env.fromElements(0, 1, 1, 2, 0, 0); + + dataSet.groupBy("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet dataSet = env.fromElements(0, 1, 2, 3); + + dataSet.groupBy("*", "invalidField"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet dataSet = env.fromElements(0, 1, 2, 3); + + dataSet.groupBy("invalidField"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> dataSet = env.fromElements(new ArrayList()); + + dataSet.groupBy("*"); + } + public static class CustomType implements Serializable { diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java index f1aadca25e7aa..be964ccc67d84 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java @@ -585,6 +585,78 @@ public Long 
getKey(CustomType value) { } ); } + + @Test + public void testJoinKeyAtomic1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 0); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("*").equalTo(0); + } + + @Test + public void testJoinKeyAtomic2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 0); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 0); + DataSet> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic4() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("invalidKey"); + } + + @Test(expected = 
InvalidProgramException.class) + public void testJoinKeyInvalidAtomic5() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = env.fromElements(new ArrayList()); + DataSet ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where("*").equalTo("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic6() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 0, 0); + DataSet> ds2 = env.fromElements(new ArrayList()); + + ds1.join(ds2).where("*").equalTo("*"); + } @Test public void testJoinProjection1() { diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java index 5fdf3dd0cd415..a685ff43c2f1c 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java @@ -41,7 +41,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializerFactory; import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator; import org.apache.flink.api.java.tuple.Tuple; -import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory; @@ -305,10 +304,6 @@ else if (typeInfo instanceof AtomicType) { } private static TypePairComparatorFactory createPairComparator(TypeInformation typeInfo1, TypeInformation typeInfo2) { - if (!(typeInfo1.isTupleType() || typeInfo1 instanceof PojoTypeInfo) && (typeInfo2.isTupleType() || typeInfo2 instanceof PojoTypeInfo)) { - throw new RuntimeException("The runtime currently supports only keyed binary 
operations (such as joins) on tuples and POJO types."); - } - // @SuppressWarnings("unchecked") // TupleTypeInfo info1 = (TupleTypeInfo) typeInfo1; // @SuppressWarnings("unchecked") diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java index 99f568e5241c7..84c05d6cc0920 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java @@ -18,10 +18,6 @@ package org.apache.flink.test.javaApiOperators; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.java.DataSet; @@ -47,6 +43,10 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + @RunWith(Parameterized.class) public class CoGroupITCase extends MultipleProgramsTestBase { @@ -488,6 +488,36 @@ public void testCoGroupFieldSelectorAndKeySelector() throws Exception { "-1,30000,Flink\n"; } + @Test + public void testCoGroupWithAtomicType1() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); + DataSet ds2 = env.fromElements(0, 1, 2); + + DataSet> coGroupDs = ds1.coGroup(ds2).where(0).equalTo("*").with(new CoGroupAtomic1()); + + coGroupDs.writeAsText(resultPath); + env.execute(); + + expected = "(1,1,Hi)\n" + + "(2,2,Hello)"; + } + + @Test + public void testCoGroupWithAtomicType2() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds1 = env.fromElements(0, 1, 2); + DataSet> ds2 = 
CollectionDataSets.getSmall3TupleDataSet(env); + + DataSet> coGroupDs = ds1.coGroup(ds2).where("*").equalTo(0).with(new CoGroupAtomic2()); + + coGroupDs.writeAsText(resultPath); + env.execute(); + + expected = "(1,1,Hi)\n" + + "(2,2,Hello)"; + } + public static class KeySelector1 implements KeySelector { private static final long serialVersionUID = 1L; @@ -719,4 +749,48 @@ public void coGroup(Iterable> first } } } + + public static class CoGroupAtomic1 implements CoGroupFunction, Integer, Tuple3> { + + private static final long serialVersionUID = 1L; + + @Override + public void coGroup(Iterable> first, Iterable second, Collector> out) throws Exception { + List ints = new ArrayList(); + + for (Integer i : second) { + ints.add(i); + } + + for (Tuple3 t : first) { + for (Integer i : ints) { + if (t.f0.equals(i)) { + out.collect(t); + } + } + } + } + } + + public static class CoGroupAtomic2 implements CoGroupFunction, Tuple3> { + + private static final long serialVersionUID = 1L; + + @Override + public void coGroup(Iterable first, Iterable> second, Collector> out) throws Exception { + List ints = new ArrayList(); + + for (Integer i : first) { + ints.add(i); + } + + for (Tuple3 t : second) { + for (Integer i : ints) { + if (t.f0.equals(i)) { + out.collect(t); + } + } + } + } + } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java index 9eb9a378d5f45..cf6b529d6675e 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java @@ -18,21 +18,20 @@ package org.apache.flink.test.javaApiOperators; -import java.util.Collection; -import java.util.Iterator; - import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.GroupReduceFunction; -import 
org.apache.flink.api.common.operators.Order; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple5; -import org.apache.flink.optimizer.Optimizer; import org.apache.flink.configuration.Configuration; +import org.apache.flink.optimizer.Optimizer; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CrazyNested; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CustomType; @@ -49,10 +48,13 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; import scala.math.BigInt; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + @SuppressWarnings("serial") @RunWith(Parameterized.class) public class GroupReduceITCase extends MultipleProgramsTestBase { @@ -1063,6 +1065,26 @@ public void reduce(Iterable> values } + @Test + public void testGroupReduceWithAtomicValue() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet ds = env.fromElements(1, 1, 2, 3, 4); + DataSet reduceDs = ds.groupBy("*").reduceGroup(new GroupReduceFunction() { + @Override + public void reduce(Iterable values, Collector out) throws Exception { + out.collect(values.iterator().next()); + } + }); + + 
reduceDs.writeAsText(resultPath); + env.execute(); + + expected = "1\n" + + "2\n" + + "3\n" + + "4"; + } + public static class GroupReducer8 implements GroupReduceFunction { @Override public void reduce( diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java index 0080fb1845152..fe436a3fc2480 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java @@ -663,6 +663,38 @@ public void testNonPojoToVerifyNestedTupleElementSelectionWithFirstKeyFieldGreat "((3,2,Hello world),(3,2,Hello world)),((3,2,Hello world),(3,2,Hello world))\n"; } + @Test + public void testJoinWithAtomicType1() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); + DataSet ds2 = env.fromElements(1, 2); + + DataSet, Integer>> joinDs = ds1.join(ds2).where(0).equalTo("*"); + + joinDs.writeAsCsv(resultPath); + env.execute(); + + expected = "(1,1,Hi),1\n" + + "(2,2,Hello),2"; + } + + @Test + public void testJoinWithAtomicType2() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet ds1 = env.fromElements(1, 2); + DataSet> ds2 = CollectionDataSets.getSmall3TupleDataSet(env); + + DataSet>> joinDs = ds1.join(ds2).where("*").equalTo(0); + + joinDs.writeAsCsv(resultPath); + env.execute(); + + expected = "1,(1,1,Hi)\n" + + "2,(2,2,Hello)"; + } + public static class T3T5FlatJoin implements FlatJoinFunction, Tuple5, Tuple2> { @Override