Fix Bug in ScalaAggregate Operator and add ITCase
aljoscha committed Sep 29, 2014
1 parent 3e5ab89 commit f0ed58c
Showing 3 changed files with 175 additions and 2 deletions.
8 changes: 8 additions & 0 deletions flink-scala/pom.xml
@@ -97,6 +97,14 @@ under the License.
 			<version>2.2.0</version>
 			<scope>test</scope>
 		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-test-utils</artifactId>
+			<version>${project.version}</version>
+			<scope>test</scope>
+		</dependency>
+
 	</dependencies>

 	<build>
@@ -297,9 +297,10 @@ public void reduce(Iterable<T> records, Collector<T> out) {
 			}

 			Object[] fields = new Object[serializer.getArity()];
+			int length = serializer.getArity();
 			// First copy all tuple fields, then overwrite the aggregated ones
-			for (int i = 0; i < fieldPositions.length; i++) {
-				fields[0] = current.productElement(i);
+			for (int i = 0; i < length; i++) {
+				fields[i] = current.productElement(i);
 			}
 			for (int i = 0; i < fieldPositions.length; i++) {
 				Object aggVal = aggFunctions[i].getAggregate();
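The bug was in the field-copy loop of the reduce step: it iterated only over the aggregated field positions and always wrote into index 0, so the non-aggregated tuple fields never made it into the output tuple. The fix copies every field of the current tuple by index and only then overwrites the aggregated positions. Below is a minimal standalone sketch of that copy-then-overwrite logic, not the operator code itself; the tuple, positions, and aggregate value are made up for illustration.

// Standalone illustration of copy-then-overwrite (plain Scala, no Flink dependencies).
object CopyThenOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val current: Product = (4, 3L, "Hello world")  // hypothetical current tuple
    val fieldPositions = Array(0)                  // positions being aggregated
    val aggregates = Array[Any](231)               // hypothetical aggregate results

    val fields = new Array[Any](current.productArity)
    // First copy all tuple fields (the buggy loop only visited fieldPositions and always wrote fields(0))
    for (i <- 0 until current.productArity) {
      fields(i) = current.productElement(i)
    }
    // ... then overwrite the aggregated positions
    for (i <- fieldPositions.indices) {
      fields(fieldPositions(i)) = aggregates(i)
    }

    println(fields.mkString(","))                  // prints: 231,3,Hello world
  }
}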
@@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators

import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.flink.api.scala._


object AggregateProgs {
  var NUM_PROGRAMS: Int = 3

  val tupleInput = Array(
    (1,1L,"Hi"),
    (2,2L,"Hello"),
    (3,2L,"Hello world"),
    (4,3L,"Hello world, how are you?"),
    (5,3L,"I am fine."),
    (6,3L,"Luke Skywalker"),
    (7,4L,"Comment#1"),
    (8,4L,"Comment#2"),
    (9,4L,"Comment#3"),
    (10,4L,"Comment#4"),
    (11,5L,"Comment#5"),
    (12,5L,"Comment#6"),
    (13,5L,"Comment#7"),
    (14,5L,"Comment#8"),
    (15,5L,"Comment#9"),
    (16,6L,"Comment#10"),
    (17,6L,"Comment#11"),
    (18,6L,"Comment#12"),
    (19,6L,"Comment#13"),
    (20,6L,"Comment#14"),
    (21,6L,"Comment#15")
  )


  def runProgram(progId: Int, resultPath: String): String = {
    progId match {
      case 1 =>
        // Full aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
        env.setDegreeOfParallelism(10)
        val ds = env.fromCollection(tupleInput)

        val aggregateDs = ds
          .aggregate(Aggregations.SUM, 0)
          .and(Aggregations.MAX, 1)
          // Ensure aggregate operator correctly copies other fields
          .filter(_._3 != null)
          .map { t => (t._1, t._2) }

        aggregateDs.writeAsCsv(resultPath)

        env.execute()

        // return expected result
        "231,6\n"

      case 2 =>
        // Grouped aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
        val ds = env.fromCollection(tupleInput)

        val aggregateDs = ds
          .groupBy(1)
          .aggregate(Aggregations.SUM, 0)
          // Ensure aggregate operator correctly copies other fields
          .filter(_._3 != null)
          .map { t => (t._2, t._1) }

        aggregateDs.writeAsCsv(resultPath)

        env.execute()

        // return expected result
        "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n"

      case 3 =>
        // Nested aggregate
        val env = ExecutionEnvironment.getExecutionEnvironment
        val ds = env.fromCollection(tupleInput)

        val aggregateDs = ds
          .groupBy(1)
          .aggregate(Aggregations.MIN, 0)
          .aggregate(Aggregations.MIN, 0)
          // Ensure aggregate operator correctly copies other fields
          .filter(_._3 != null)
          .map { t => Tuple1(t._1) }

        aggregateDs.writeAsCsv(resultPath)

        env.execute()

        // return expected result
        "1\n"

      case _ =>
        throw new IllegalArgumentException("Invalid program id")
    }
  }
}
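The expected strings returned by each case follow directly from tupleInput. Here is a small sketch, using plain Scala collections rather than Flink, that recomputes them; it assumes it can see AggregateProgs.tupleInput from the same package.

// Recompute the expected results of the three test programs from the raw data.
object ExpectedResultsSketch {
  def main(args: Array[String]): Unit = {
    val data = AggregateProgs.tupleInput

    // Program 1: SUM of field 0 and MAX of field 1 over the whole data set
    println((data.map(_._1).sum, data.map(_._2).max))               // (231,6)

    // Program 2: per group on field 1, SUM of field 0, emitted as (key, sum)
    data.groupBy(_._2).toSeq.sortBy(_._1).foreach { case (key, group) =>
      println((key, group.map(_._1).sum))                           // (1,1) (2,5) (3,15) (4,34) (5,65) (6,111)
    }

    // Program 3: MIN of field 0 within each group, then MIN over those minima
    println(data.groupBy(_._2).values.map(_.map(_._1).min).min)     // 1
  }
}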


@RunWith(classOf[Parameterized])
class AggregateITCase(config: Configuration) extends JavaProgramTestBase(config) {

  private var curProgId: Int = config.getInteger("ProgramId", -1)
  private var resultPath: String = null
  private var expectedResult: String = null

  protected override def preSubmit(): Unit = {
    resultPath = getTempDirPath("result")
  }

  protected def testProgram(): Unit = {
    expectedResult = AggregateProgs.runProgram(curProgId, resultPath)
  }

  protected override def postSubmit(): Unit = {
    compareResultsByLinesInMemory(expectedResult, resultPath)
  }
}

object AggregateITCase {
  @Parameters
  def getConfigurations: java.util.Collection[Array[AnyRef]] = {
    val configs = mutable.MutableList[Array[AnyRef]]()
    for (i <- 1 to AggregateProgs.NUM_PROGRAMS) {
      val config = new Configuration()
      config.setInteger("ProgramId", i)
      configs += Array(config)
    }

    configs.asJavaCollection
  }
}
