Skip to content

Commit

Permalink
[FLINK-1188] [streaming] Updated aggregations to work also on arrays by default
Browse files (browse the repository at this point in the history)
  • Loading branch information
gyfora authored and mbalassi committed Oct 27, 2014
1 parent a221796 commit 7709a3a
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public BatchedDataStream<OUT> groupBy(int keyPosition) {
public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
dataStream.checkFieldRange(positionToSum);
return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
positionToSum, dataStream.getClassAtPos(positionToSum)));
positionToSum, dataStream.getClassAtPos(positionToSum), dataStream.getOutputType()));
}

/**
Expand All @@ -159,7 +159,7 @@ public BatchedDataStream<OUT> groupBy(int keyPosition) {
*/
public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
dataStream.checkFieldRange(positionToMin);
return aggregate(new MinAggregationFunction<OUT>(positionToMin));
return aggregate(new MinAggregationFunction<OUT>(positionToMin, dataStream.getOutputType()));
}

/**
Expand Down Expand Up @@ -191,7 +191,8 @@ public BatchedDataStream<OUT> groupBy(int keyPosition) {
*/
public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first) {
dataStream.checkFieldRange(positionToMinBy);
return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first));
return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first,
dataStream.getOutputType()));
}

/**
Expand All @@ -213,7 +214,7 @@ public BatchedDataStream<OUT> groupBy(int keyPosition) {
*/
public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
dataStream.checkFieldRange(positionToMax);
return aggregate(new MaxAggregationFunction<OUT>(positionToMax));
return aggregate(new MaxAggregationFunction<OUT>(positionToMax, dataStream.getOutputType()));
}

/**
Expand Down Expand Up @@ -244,7 +245,8 @@ public BatchedDataStream<OUT> groupBy(int keyPosition) {
*/
public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first) {
dataStream.checkFieldRange(positionToMaxBy);
return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first));
return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first,
dataStream.getOutputType()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple;
Expand Down Expand Up @@ -201,6 +203,31 @@ protected Class<?> getClassAtPos(int pos) {
TypeInformation<OUT> outTypeInfo = outTypeWrapper.getTypeInfo();
if (outTypeInfo.isTupleType()) {
type = ((TupleTypeInfo) outTypeInfo).getTypeAt(pos).getTypeClass();

} else if (outTypeInfo instanceof BasicArrayTypeInfo) {

type = ((BasicArrayTypeInfo) outTypeInfo).getComponentTypeClass();

} else if (outTypeInfo instanceof PrimitiveArrayTypeInfo) {
Class<?> clazz = outTypeInfo.getTypeClass();
if (clazz == boolean[].class) {
type = Boolean.class;
} else if (clazz == short[].class) {
type = Short.class;
} else if (clazz == int[].class) {
type = Integer.class;
} else if (clazz == long[].class) {
type = Long.class;
} else if (clazz == float[].class) {
type = Float.class;
} else if (clazz == double[].class) {
type = Double.class;
} else if (clazz == char[].class) {
type = Character.class;
} else {
throw new IndexOutOfBoundsException("Type could not be determined for array");
}

} else if (pos == 0) {
type = outTypeInfo.getTypeClass();
} else {
Expand Down Expand Up @@ -594,7 +621,7 @@ public WindowDataStream<OUT> window(long windowSize) {
public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
checkFieldRange(positionToSum);
return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
positionToSum, getClassAtPos(positionToSum)));
positionToSum, getClassAtPos(positionToSum), getOutputType()));
}

/**
Expand All @@ -616,7 +643,7 @@ public WindowDataStream<OUT> window(long windowSize) {
*/
public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
checkFieldRange(positionToMin);
return aggregate(new MinAggregationFunction<OUT>(positionToMin));
return aggregate(new MinAggregationFunction<OUT>(positionToMin, getOutputType()));
}

/**
Expand Down Expand Up @@ -648,7 +675,7 @@ public WindowDataStream<OUT> window(long windowSize) {
*/
public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first) {
checkFieldRange(positionToMinBy);
return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first));
return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first, getOutputType()));
}

/**
Expand All @@ -670,7 +697,7 @@ public WindowDataStream<OUT> window(long windowSize) {
*/
public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
checkFieldRange(positionToMax);
return aggregate(new MaxAggregationFunction<OUT>(positionToMax));
return aggregate(new MaxAggregationFunction<OUT>(positionToMax, getOutputType()));
}

/**
Expand Down Expand Up @@ -702,7 +729,7 @@ public WindowDataStream<OUT> window(long windowSize) {
*/
public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first) {
checkFieldRange(positionToMaxBy);
return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first));
return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first, getOutputType()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,23 @@
package org.apache.flink.streaming.api.function.aggregation;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple;

public abstract class AggregationFunction<T> implements ReduceFunction<T> {
private static final long serialVersionUID = 1L;

public int position;
protected Tuple returnTuple;
protected boolean isTuple;
protected boolean isArray;

public AggregationFunction(int pos) {
public AggregationFunction(int pos, TypeInformation<?> type) {
this.position = pos;
this.isTuple = type.isTupleType();
this.isArray = type instanceof BasicArrayTypeInfo || type instanceof PrimitiveArrayTypeInfo;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,56 @@

package org.apache.flink.streaming.api.function.aggregation;

import java.lang.reflect.Array;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple;

public abstract class ComparableAggregationFunction<T> extends AggregationFunction<T> {

private static final long serialVersionUID = 1L;

public ComparableAggregationFunction(int positionToAggregate) {
super(positionToAggregate);
public ComparableAggregationFunction(int positionToAggregate, TypeInformation<?> type) {
super(positionToAggregate, type);
}

@SuppressWarnings("unchecked")
@Override
public T reduce(T value1, T value2) throws Exception {
if (value1 instanceof Tuple) {
if (isTuple) {
Tuple t1 = (Tuple) value1;
Tuple t2 = (Tuple) value2;

compare(t1, t2);

return (T) returnTuple;
} else if (isArray) {
return compareArray(value1, value2);
} else if (value1 instanceof Comparable) {
if (isExtremal((Comparable<Object>) value1, value2)) {
return value1;
}else{
} else {
return value2;
}
} else {
throw new RuntimeException("The values " + value1 + " and "+ value2 + " cannot be compared.");
throw new RuntimeException("The values " + value1 + " and " + value2
+ " cannot be compared.");
}
}

/**
 * Reduces two array records by aggregating only the element at {@code position}.
 *
 * <p>The element at {@code position} of {@code array2} is overwritten in place with
 * the element of {@code array1} when {@code isExtremal} selects it; otherwise
 * {@code array2} is returned untouched. All other elements of {@code array2} keep
 * their values, and {@code array1} is only read, never modified.
 * {@link Array} is used for element access so that primitive arrays and object
 * arrays go through the same code path.
 *
 * @param array1 the first array record; only read
 * @param array2 the second array record; possibly modified at {@code position} and returned
 * @return {@code array2}, holding the selected element at {@code position}
 */
@SuppressWarnings("unchecked")
public T compareArray(T array1, T array2) {
	Object v1 = Array.get(array1, position);
	Object v2 = Array.get(array2, position);
	// array2 already holds v2 at this position, so a write is only needed when
	// the element of array1 wins; storing v2 back would be a redundant
	// reflective call.
	if (isExtremal((Comparable<Object>) v1, v2)) {
		Array.set(array2, position, v1);
	}
	return array2;
}

public <R> void compare(Tuple tuple1, Tuple tuple2) throws InstantiationException,
IllegalAccessException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

package org.apache.flink.streaming.api.function.aggregation;

import org.apache.flink.api.common.typeinfo.TypeInformation;

public class MaxAggregationFunction<T> extends ComparableAggregationFunction<T> {

private static final long serialVersionUID = 1L;

public MaxAggregationFunction(int pos) {
super(pos);
public MaxAggregationFunction(int pos, TypeInformation<?> type) {
super(pos, type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

package org.apache.flink.streaming.api.function.aggregation;

import org.apache.flink.api.common.typeinfo.TypeInformation;

public class MaxByAggregationFunction<T> extends MinByAggregationFunction<T> {

private static final long serialVersionUID = 1L;

public MaxByAggregationFunction(int pos, boolean first) {
super(pos, first);
public MaxByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
super(pos, first, type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

package org.apache.flink.streaming.api.function.aggregation;

import org.apache.flink.api.common.typeinfo.TypeInformation;

public class MinAggregationFunction<T> extends ComparableAggregationFunction<T> {

private static final long serialVersionUID = 1L;

public MinAggregationFunction(int pos) {
super(pos);
public MinAggregationFunction(int pos, TypeInformation<?> type) {
super(pos, type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@

package org.apache.flink.streaming.api.function.aggregation;

import java.lang.reflect.Array;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple;

public class MinByAggregationFunction<T> extends ComparableAggregationFunction<T> {

private static final long serialVersionUID = 1L;
protected boolean first;

public MinByAggregationFunction(int pos, boolean first) {
super(pos);
public MinByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
super(pos, type);
this.first = first;
}

Expand All @@ -43,6 +46,18 @@ public <R> void compare(Tuple tuple1, Tuple tuple2) throws InstantiationExceptio
}
}

/**
 * Selects one of the two array records as a whole, judged by the element each
 * holds at {@code position}.
 *
 * <p>Unlike the inherited element-wise variant, neither array is written to: the
 * elements at {@code position} are read reflectively, handed to
 * {@code isExtremal}, and the array that owns the winning element is returned.
 *
 * @param array1 the first array record
 * @param array2 the second array record
 * @return {@code array1} if {@code isExtremal} accepts its element over the one
 *         from {@code array2}, otherwise {@code array2}
 */
@Override
@SuppressWarnings("unchecked")
public T compareArray(T array1, T array2) {
	Object fromFirst = Array.get(array1, position);
	Object fromSecond = Array.get(array2, position);
	return isExtremal((Comparable<Object>) fromFirst, fromSecond) ? array1 : array2;
}

@Override
public <R> boolean isExtremal(Comparable<R> o1, R o2) {
if (first) {
Expand Down
Loading

0 comments on commit 7709a3a

Please sign in to comment.