Skip to content

Commit

Permalink
[FLINK-1367] [scala] [streaming] Field aggregations added to streamin…
Browse files Browse the repository at this point in the history
…g scala api
  • Loading branch information
gyfora authored and mbalassi committed Jan 8, 2015
1 parent 10a8186 commit 06503c8
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object TopSpeedWindowing {
.window(Time.of(evictionSec, SECONDS))
.every(Delta.of[CarEvent](triggerMeters,
(oldSp,newSp) => newSp.distance-oldSp.distance, CarEvent(0,0,0,0)))
.reduce((x, y) => if (x.speed > y.speed) x else y)
.maxBy("speed")

cars print

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
/**
* Sets the partitioning of the DataStream so that the output values all go to
* the first instance of the next processing operator. Use this setting with care
* since it might cause a serious performance bottlenect in the application.
* since it might cause a serious performance bottleneck in the application.
*/
def global: DataStream[T] = javaStream.global()

Expand Down Expand Up @@ -203,39 +203,78 @@ class DataStream[T](javaStream: JavaStream[T]) {
*
*/
def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position)


/**
* Applies an aggregation that that gives the current maximum of the data stream at
* the given field.
*
*/
def max(field: String): DataStream[T] = aggregate(AggregationType.MAX, field)

/**
* Applies an aggregation that that gives the current minimum of the data stream at
* the given position.
*
*/
def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position)

/**
* Applies an aggregation that that gives the current minimum of the data stream at
* the given field.
*
*/
def min(field: String): DataStream[T] = aggregate(AggregationType.MIN, field)

/**
* Applies an aggregation that sums the data stream at the given position.
*
*/
def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position)

/**
* Applies an aggregation that sums the data stream at the given field.
*
*/
def sum(field: String): DataStream[T] = aggregate(AggregationType.SUM, field)

/**
* Applies an aggregation that that gives the current minimum element of the data stream by
* the given position. When equality, the user can set to get the first or last element with
* the minimal value.
* the given position. When equality, the first element is returned with the minimal value.
*
*/
def minBy(position: Int): DataStream[T] = aggregate(AggregationType
.MINBY, position)

/**
* Applies an aggregation that that gives the current minimum element of the data stream by
* the given field. When equality, the first element is returned with the minimal value.
*
*/
def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType
.MINBY, position, first)
def minBy(field: String): DataStream[T] = aggregate(AggregationType
.MINBY, field )

/**
/**
* Applies an aggregation that that gives the current maximum element of the data stream by
* the given position. When equality, the first element is returned with the maximal value.
*
*/
def maxBy(position: Int): DataStream[T] =
aggregate(AggregationType.MAXBY, position)

/**
* Applies an aggregation that that gives the current maximum element of the data stream by
* the given position. When equality, the user can set to get the first or last element with
* the maximal value.
* the given field. When equality, the first element is returned with the maximal value.
*
*/
def maxBy(position: Int, first: Boolean = true): DataStream[T] =
aggregate(AggregationType.MAXBY, position, first)
def maxBy(field: String): DataStream[T] =
aggregate(AggregationType.MAXBY, field)

private def aggregate(aggregationType: AggregationType, field: String): DataStream[T] = {
val position = fieldNames2Indices(javaStream.getType(), Array(field))(0)
aggregate(aggregationType, position)
}

private def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true):
private def aggregate(aggregationType: AggregationType, position: Int):
DataStream[T] = {

val jStream = javaStream.asInstanceOf[JavaStream[Product]]
Expand All @@ -246,7 +285,7 @@ class DataStream[T](javaStream: JavaStream[T]) {
val reducer = aggregationType match {
case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).
getTypeClass()));
case _ => new agg.ProductComparableAggregator(aggregationType, first)
case _ => new agg.ProductComparableAggregator(aggregationType, true)
}

val invokable = jStream match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,37 +157,78 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
*
*/
def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position)

/**
* Applies an aggregation that that gives the maximum of the elements in the window at
* the given field.
*
*/
def max(field: String): DataStream[T] = aggregate(AggregationType.MAX, field)

/**
* Applies an aggregation that that gives the minimum of the elements in the window at
* the given position.
*
*/
def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position)

/**
* Applies an aggregation that that gives the minimum of the elements in the window at
* the given field.
*
*/
def min(field: String): DataStream[T] = aggregate(AggregationType.MIN, field)

/**
* Applies an aggregation that sums the elements in the window at the given position.
*
*/
def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position)

/**
* Applies an aggregation that sums the elements in the window at the given field.
*
*/
def sum(field: String): DataStream[T] = aggregate(AggregationType.SUM, field)

/**
* Applies an aggregation that that gives the maximum element of the window by
* the given position. When equality, returns the first.
*
*/
def maxBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MAXBY,
position, first)
def maxBy(position: Int): DataStream[T] = aggregate(AggregationType.MAXBY,
position)

/**
* Applies an aggregation that that gives the maximum element of the window by
* the given field. When equality, returns the first.
*
*/
def maxBy(field: String): DataStream[T] = aggregate(AggregationType.MAXBY,
field)

/**
* Applies an aggregation that that gives the minimum element of the window by
* the given position. When equality, returns the first.
*
*/
def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MINBY,
position, first)
def minBy(position: Int): DataStream[T] = aggregate(AggregationType.MINBY,
position)

/**
* Applies an aggregation that that gives the minimum element of the window by
* the given field. When equality, returns the first.
*
*/
def minBy(field: String): DataStream[T] = aggregate(AggregationType.MINBY,
field)

private def aggregate(aggregationType: AggregationType, field: String): DataStream[T] = {
val position = fieldNames2Indices(javaStream.getType(), Array(field))(0)
aggregate(aggregationType, position)
}

def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true):
def aggregate(aggregationType: AggregationType, position: Int):
DataStream[T] = {

val jStream = javaStream.asInstanceOf[JavaWStream[Product]]
Expand All @@ -198,7 +239,7 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
val reducer = aggregationType match {
case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(
outType.getTypeAt(position).getTypeClass()));
case _ => new agg.ProductComparableAggregator(aggregationType, first)
case _ => new agg.ProductComparableAggregator(aggregationType, true)
}

new DataStream[Product](jStream.reduce(reducer)).asInstanceOf[DataStream[T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,24 @@ package object scala {

implicit def javaToScalaConnectedStream[IN1, IN2](javaStream: JavaConStream[IN1, IN2]):
ConnectedDataStream[IN1, IN2] = new ConnectedDataStream[IN1, IN2](javaStream)

private[flink] def fieldNames2Indices(
typeInfo: TypeInformation[_],
fields: Array[String]): Array[Int] = {
typeInfo match {
case ti: CaseClassTypeInfo[_] =>
val result = ti.getFieldIndices(fields)

if (result.contains(-1)) {
throw new IllegalArgumentException("Fields '" + fields.mkString(", ") +
"' are not valid for '" + ti.toString + "'.")
}

result

case _ =>
throw new UnsupportedOperationException("Specifying fields by name is only" +
"supported on Case Classes (for now).")
}
}
}

0 comments on commit 06503c8

Please sign in to comment.