Skip to content

Commit

Permalink
[FLINK-703] [java api] Use complete element as join key
Browse files Browse the repository at this point in the history
  • Loading branch information
chiwanpark authored and fhueske committed Apr 21, 2015
1 parent e1618e2 commit 30a74c7
Show file tree
Hide file tree
Showing 11 changed files with 406 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
Expand Down Expand Up @@ -273,12 +274,15 @@ protected List<OUT> executeOnCollections(List<IN1> input1, List<IN2> input2, Run
return result;
}

@SuppressWarnings("unchecked")
private <T> TypeComparator<T> getTypeComparator(ExecutionConfig executionConfig, TypeInformation<T> inputType, int[] inputKeys, boolean[] inputSortDirections) {
if (!(inputType instanceof CompositeType)) {
throw new InvalidProgramException("Input types of coGroup must be composite types.");
if (inputType instanceof CompositeType) {
return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig);
} else if (inputType instanceof AtomicType) {
return ((AtomicType<T>) inputType).createComparator(inputSortDirections[0], executionConfig);
}

return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig);
throw new InvalidProgramException("Input type of coGroup must be one of composite types or atomic types.");
}

private static class CoGroupSortListIterator<IN1, IN2> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
Expand Down Expand Up @@ -77,6 +78,16 @@ public Ordering getGroupOrder() {
return this.groupOrder;
}

private TypeComparator<IN> getTypeComparator(TypeInformation<IN> typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) {
if (typeInfo instanceof CompositeType) {
return ((CompositeType<IN>) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig);
} else if (typeInfo instanceof AtomicType) {
return ((AtomicType<IN>) typeInfo).createComparator(sortOrderings[0], executionConfig);
}

throw new InvalidProgramException("Input type of GroupCombine must be one of composite types or atomic types.");
}

// --------------------------------------------------------------------------------------------

@Override
Expand All @@ -87,11 +98,6 @@ protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx,
TypeInformation<IN> inputType = operatorInfo.getInputType();

int[] keyColumns = getKeyColumns(0);

if (!(inputType instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) {
throw new InvalidProgramException("Grouping or group-sorting is only possible on composite type.");
}

int[] sortColumns = keyColumns;
boolean[] sortOrderings = new boolean[sortColumns.length];

Expand All @@ -100,19 +106,17 @@ protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx,
sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections());
}

if (inputType instanceof CompositeType) {
if(sortColumns.length == 0) { // => all reduce. No comparator
Preconditions.checkArgument(sortOrderings.length == 0);
} else {
final TypeComparator<IN> sortComparator = ((CompositeType<IN>) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig);

Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
return sortComparator.compare(o1, o2);
}
});
}
if(sortColumns.length == 0) { // => all reduce. No comparator
Preconditions.checkArgument(sortOrderings.length == 0);
} else {
final TypeComparator<IN> sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig);

Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
return sortComparator.compare(o1, o2);
}
});
}

FunctionUtils.setFunctionRuntimeContext(function, ctx);
Expand All @@ -133,7 +137,7 @@ public int compare(IN o1, IN o2) {
} else {
final TypeSerializer<IN> inputSerializer = inputType.createSerializer(executionConfig);
boolean[] keyOrderings = new boolean[keyColumns.length];
final TypeComparator<IN> comparator = ((CompositeType<IN>) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig);
final TypeComparator<IN> comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig);

ListKeyGroupedIterator<IN> keyedIterator = new ListKeyGroupedIterator<IN>(inputData, inputSerializer, comparator);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
Expand Down Expand Up @@ -148,6 +149,16 @@ public void setCustomPartitioner(Partitioner<?> customPartitioner) {
public Partitioner<?> getCustomPartitioner() {
return customPartitioner;
}

private TypeComparator<IN> getTypeComparator(TypeInformation<IN> typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) {
if (typeInfo instanceof CompositeType) {
return ((CompositeType<IN>) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig);
} else if (typeInfo instanceof AtomicType) {
return ((AtomicType<IN>) typeInfo).createComparator(sortOrderings[0], executionConfig);
}

throw new InvalidProgramException("Input type of GroupReduce must be one of composite types or atomic types.");
}

// --------------------------------------------------------------------------------------------

Expand All @@ -159,11 +170,6 @@ protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx,
TypeInformation<IN> inputType = operatorInfo.getInputType();

int[] keyColumns = getKeyColumns(0);

if (!(inputType instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) {
throw new InvalidProgramException("Grouping or group-sorting is only possible on composite type.");
}

int[] sortColumns = keyColumns;
boolean[] sortOrderings = new boolean[sortColumns.length];

Expand All @@ -172,19 +178,16 @@ protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx,
sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections());
}

if (inputType instanceof CompositeType) {
if(sortColumns.length == 0) { // => all reduce. No comparator
Preconditions.checkArgument(sortOrderings.length == 0);
} else {
final TypeComparator<IN> sortComparator = ((CompositeType<IN>) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig);

Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
return sortComparator.compare(o1, o2);
}
});
}
if(sortColumns.length == 0) { // => all reduce. No comparator
Preconditions.checkArgument(sortOrderings.length == 0);
} else {
final TypeComparator<IN> sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig);
Collections.sort(inputData, new Comparator<IN>() {
@Override
public int compare(IN o1, IN o2) {
return sortComparator.compare(o1, o2);
}
});
}

FunctionUtils.setFunctionRuntimeContext(function, ctx);
Expand All @@ -205,7 +208,7 @@ public int compare(IN o1, IN o2) {
} else {
final TypeSerializer<IN> inputSerializer = inputType.createSerializer(executionConfig);
boolean[] keyOrderings = new boolean[keyColumns.length];
final TypeComparator<IN> comparator = ((CompositeType<IN>) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig);
final TypeComparator<IN> comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig);

ListKeyGroupedIterator<IN> keyedIterator = new ListKeyGroupedIterator<IN>(inputData, inputSerializer, comparator);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,24 +274,33 @@ public static <R> List<R> removeNullElementsFromList(List<R> in) {
* Create ExpressionKeys from String-expressions
*/
public ExpressionKeys(String[] expressionsIn, TypeInformation<T> type) {
if(!(type instanceof CompositeType<?>)) {
throw new IllegalArgumentException("Key expressions are only supported on POJO types and Tuples. "
+ "A type is considered a POJO if all its fields are public, or have both getters and setters defined");
}
CompositeType<T> cType = (CompositeType<T>) type;

String[] expressions = removeDuplicates(expressionsIn);
if(expressionsIn.length != expressions.length) {
LOG.warn("The key expressions contained duplicates. They are now unique");
}
// extract the keys on their flat position
keyFields = new ArrayList<FlatFieldDescriptor>(expressions.length);
for (int i = 0; i < expressions.length; i++) {
List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check
if(keys.size() == 0) {
throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
Preconditions.checkNotNull(expressionsIn, "Field expression cannot be null.");

if (type instanceof AtomicType) {
if (!type.isKeyType()) {
throw new InvalidProgramException("This type (" + type + ") cannot be used as key.");
} else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) {
throw new IllegalArgumentException("Field expression for atomic type must be equal to '*' or '_'.");
}

keyFields = new ArrayList<FlatFieldDescriptor>(1);
keyFields.add(new FlatFieldDescriptor(0, type));
} else {
CompositeType<T> cType = (CompositeType<T>) type;

String[] expressions = removeDuplicates(expressionsIn);
if(expressionsIn.length != expressions.length) {
LOG.warn("The key expressions contained duplicates. They are now unique");
}
// extract the keys on their flat position
keyFields = new ArrayList<FlatFieldDescriptor>(expressions.length);
for (int i = 0; i < expressions.length; i++) {
List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check
if(keys.size() == 0) {
throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
}
keyFields.addAll(keys);
}
keyFields.addAll(keys);
}
}

Expand Down Expand Up @@ -410,7 +419,7 @@ private static final int[] rangeCheckFields(int[] fields, int maxAllowedField) {
return Arrays.copyOfRange(fields, 0, k+1);
}
}

public static class IncompatibleKeysException extends Exception {
private static final long serialVersionUID = 1L;
public static final String SIZE_MISMATCH_MESSAGE = "The number of specified keys is different.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,26 @@

package org.apache.flink.api.java.operator;

import java.util.ArrayList;
import java.util.List;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.operators.SemanticProperties;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType;
import org.apache.flink.api.java.operators.CoGroupOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -181,6 +181,78 @@ public void testCoGroupKeyExpressions4() {
// should not work, cogroup key non-existent
ds1.coGroup(ds2).where("myNonExistent").equalTo("myInt");
}

@Test
public void testCoGroupKeyAtomicExpression1() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<CustomType> ds1 = env.fromCollection(customTypeData);
DataSet<Integer> ds2 = env.fromElements(0, 0, 1);

ds1.coGroup(ds2).where("myInt").equalTo("*");
}

@Test
public void testCoGroupKeyAtomicExpression2() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Integer> ds1 = env.fromElements(0, 0, 1);
DataSet<CustomType> ds2 = env.fromCollection(customTypeData);

ds1.coGroup(ds2).where("*").equalTo("myInt");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression1() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Integer> ds1 = env.fromElements(0, 0, 1);
DataSet<CustomType> ds2 = env.fromCollection(customTypeData);

ds1.coGroup(ds2).where("*", "invalidKey");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression2() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Integer> ds1 = env.fromElements(0, 0, 1);
DataSet<CustomType> ds2 = env.fromCollection(customTypeData);

ds1.coGroup(ds2).where("invalidKey");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression3() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<CustomType> ds1 = env.fromCollection(customTypeData);
DataSet<Integer> ds2 = env.fromElements(0, 0, 1);

ds1.coGroup(ds2).where("myInt").equalTo("invalidKey");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression4() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<CustomType> ds1 = env.fromCollection(customTypeData);
DataSet<Integer> ds2 = env.fromElements(0, 0, 1);

ds1.coGroup(ds2).where("myInt").equalTo("*", "invalidKey");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression5() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<ArrayList<Integer>> ds1 = env.fromElements(new ArrayList<Integer>());
DataSet<Integer> ds2 = env.fromElements(0, 0, 0);

ds1.coGroup(ds2).where("*");
}

@Test(expected = InvalidProgramException.class)
public void testCoGroupKeyAtomicInvalidExpression6() {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Integer> ds1 = env.fromElements(0, 0, 0);
DataSet<ArrayList<Integer>> ds2 = env.fromElements(new ArrayList<Integer>());

ds1.coGroup(ds2).where("*").equalTo("*");
}

@Test
public void testCoGroupKeyExpressions1Nested() {
Expand Down
Loading

0 comments on commit 30a74c7

Please sign in to comment.