Skip to content

Commit

Permalink
[scala] Add and* methods to AggregateDataSet
Browse files Browse the repository at this point in the history
And create AggregateDataSet in the first place, to add the methods to.
  • Loading branch information
aljoscha committed Sep 22, 2014
1 parent b8a8780 commit c778d28
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala

import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.scala.operators.ScalaAggregateOperator

import scala.reflect.ClassTag

/**
 * A [[DataSet]] that wraps a [[ScalaAggregateOperator]] and allows further aggregations to be
 * chained onto that same operator through the `and*` methods.
 *
 * Every `and*` method forwards the additional aggregation to the wrapped operator and then
 * returns this very instance, so calls can be chained fluently.
 *
 * @param set the aggregate operator that the additional aggregations are added to
 * @tparam T the element type of the data set
 */
class AggregateDataSet[T: ClassTag](set: ScalaAggregateOperator[T])
  extends DataSet[T](set) {

  /**
   * Adds the given aggregation on the given field to the previous aggregation operation.
   *
   * This only works on Tuple DataSets.
   *
   * @param agg   the aggregation to add
   * @param field the position of the field to aggregate
   * @return this data set, to allow chaining of further `and*` calls
   */
  def and(agg: Aggregations, field: Int): AggregateDataSet[T] = {
    // The operator's own return value is not needed; this wrapper is handed back instead.
    set.and(agg, field)
    this
  }

  /**
   * Adds the given aggregation on the given field to the previous aggregation operation.
   *
   * This only works on CaseClass DataSets.
   *
   * @param agg   the aggregation to add
   * @param field the name of the field to aggregate
   * @return this data set, to allow chaining of further `and*` calls
   */
  def and(agg: Aggregations, field: String): AggregateDataSet[T] = {
    // Translate the single field name into its index before delegating to the operator.
    val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
    set.and(agg, fieldIndex)
    this
  }

  /**
   * Syntactic sugar for [[and]] with `SUM`, addressing the field by position.
   */
  def andSum(field: Int): AggregateDataSet[T] = {
    and(Aggregations.SUM, field)
  }

  /**
   * Syntactic sugar for [[and]] with `MAX`, addressing the field by position.
   */
  def andMax(field: Int): AggregateDataSet[T] = {
    and(Aggregations.MAX, field)
  }

  /**
   * Syntactic sugar for [[and]] with `MIN`, addressing the field by position.
   */
  def andMin(field: Int): AggregateDataSet[T] = {
    and(Aggregations.MIN, field)
  }

  /**
   * Syntactic sugar for [[and]] with `SUM`, addressing the field by name.
   */
  def andSum(field: String): AggregateDataSet[T] = {
    and(Aggregations.SUM, field)
  }

  /**
   * Syntactic sugar for [[and]] with `MAX`, addressing the field by name.
   */
  def andMax(field: String): AggregateDataSet[T] = {
    and(Aggregations.MAX, field)
  }

  /**
   * Syntactic sugar for [[and]] with `MIN`, addressing the field by name.
   */
  def andMin(field: String): AggregateDataSet[T] = {
    and(Aggregations.MIN, field)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.core.fs.{FileSystem, Path}
import org.apache.flink.types.TypeInformation
Expand Down Expand Up @@ -367,12 +366,8 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
*
* This only works on Tuple DataSets.
*/
def aggregate(agg: Aggregations, field: Int): DataSet[T] = set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, field)
wrap(aggregation)

case _ => wrap(new ScalaAggregateOperator[T](set, agg, field))
def aggregate(agg: Aggregations, field: Int): AggregateDataSet[T] = {
new AggregateDataSet(new ScalaAggregateOperator[T](set, agg, field))
}

/**
Expand All @@ -382,16 +377,10 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): DataSet[T] = {
def aggregate(agg: Aggregations, field: String): AggregateDataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)

set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, fieldIndex)
wrap(aggregation)

case _ => wrap(new ScalaAggregateOperator[T](set, agg, fieldIndex))
}
new AggregateDataSet(new ScalaAggregateOperator[T](set, agg, fieldIndex))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.flink.api.scala

import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -67,7 +66,7 @@ trait GroupedDataSet[T] {
*
* This only works on Tuple DataSets.
*/
def aggregate(agg: Aggregations, field: Int): DataSet[T]
def aggregate(agg: Aggregations, field: Int): AggregateDataSet[T]

/**
* Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
Expand All @@ -76,37 +75,37 @@ trait GroupedDataSet[T] {
*
* This only works on CaseClass DataSets.
*/
def aggregate(agg: Aggregations, field: String): DataSet[T]
def aggregate(agg: Aggregations, field: String): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int): DataSet[T]
def sum(field: Int): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: Int): DataSet[T]
def max(field: Int): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: Int): DataSet[T]
def min(field: Int): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: String): DataSet[T]
def sum(field: String): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `MAX`
*/
def max(field: String): DataSet[T]
def max(field: String): AggregateDataSet[T]

/**
* Syntactic sugar for [[aggregate]] with `MIN`
*/
def min(field: String): DataSet[T]
def min(field: String): AggregateDataSet[T]

/**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
Expand Down Expand Up @@ -194,47 +193,37 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
/** Convenience methods for creating the [[UnsortedGrouping]] */
private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set, keys)

def aggregate(agg: Aggregations, field: String): DataSet[T] = {
def aggregate(agg: Aggregations, field: String): AggregateDataSet[T] = {
val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)

set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, fieldIndex)
wrap(aggregation)

case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, fieldIndex))
}
new AggregateDataSet(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, fieldIndex))
}

def aggregate(agg: Aggregations, field: Int): DataSet[T] = set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, field)
wrap(aggregation)

case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, field))
def aggregate(agg: Aggregations, field: Int): AggregateDataSet[T] = {
new AggregateDataSet(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, field))
}

def sum(field: Int): DataSet[T] = {
def sum(field: Int) = {
aggregate(Aggregations.SUM, field)
}

def max(field: Int): DataSet[T] = {
def max(field: Int) = {
aggregate(Aggregations.MAX, field)
}

def min(field: Int): DataSet[T] = {
def min(field: Int) = {
aggregate(Aggregations.MIN, field)
}

def sum(field: String): DataSet[T] = {
def sum(field: String) = {
aggregate(Aggregations.SUM, field)
}

def max(field: String): DataSet[T] = {
def max(field: String) = {
aggregate(Aggregations.MAX, field)
}

def min(field: String): DataSet[T] = {
def min(field: String) = {
aggregate(Aggregations.MIN, field)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ class ScalaAPICompletenessTest {
checkMethods("SortedGrouping", "GroupedDataSet",
classOf[SortedGrouping[_]], classOf[GroupedDataSet[_]])

checkMethods("AggregateOperator", "AggregateDataSet",
classOf[AggregateOperator[_]], classOf[AggregateDataSet[_]])

checkMethods("SingleInputOperator", "DataSet",
classOf[SingleInputOperator[_, _, _]], classOf[DataSet[_]])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ class AggregateOperatorTest {
val tupleDs = env.fromCollection(emptyTupleData)

// should work: multiple aggregates
tupleDs.aggregate(Aggregations.SUM, 0).aggregate(Aggregations.MIN, 4)
tupleDs.aggregate(Aggregations.SUM, 0).and(Aggregations.MIN, 4)

// should work: nested aggregates
tupleDs.aggregate(Aggregations.MIN, 2).aggregate(Aggregations.SUM, 1)
tupleDs.aggregate(Aggregations.MIN, 2).and(Aggregations.SUM, 1)

// should not work: average on string
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AggregateTranslationTest {

val initialData = env.fromElements((3.141592, "foobar", 77L))

initialData.groupBy(0).aggregate(Aggregations.MIN, 1).aggregate(Aggregations.SUM, 2).print()
initialData.groupBy(0).aggregate(Aggregations.MIN, 1).and(Aggregations.SUM, 2).print()

val p: Plan = env.createProgramPlan()
val sink = p.getDataSinks.iterator.next
Expand Down

0 comments on commit c778d28

Please sign in to comment.