From f0ed58c6a580db0966104b81491d08d25d1ff57e Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Mon, 29 Sep 2014 17:26:31 +0200 Subject: [PATCH] Fix Bug in ScalaAggregate Operator and add ITCase --- flink-scala/pom.xml | 8 + .../operators/ScalaAggregateOperator.java | 5 +- .../api/scala/operators/AggregateITCase.scala | 164 ++++++++++++++++++ 3 files changed, 175 insertions(+), 2 deletions(-) create mode 100644 flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala diff --git a/flink-scala/pom.xml b/flink-scala/pom.xml index c0a156d644aaa..e0b9aeb11a298 100644 --- a/flink-scala/pom.xml +++ b/flink-scala/pom.xml @@ -97,6 +97,14 @@ under the License. 2.2.0 test + + + org.apache.flink + flink-test-utils + ${project.version} + test + + diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java index 8a58ce52c99fc..310fc17ac02c4 100644 --- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java +++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaAggregateOperator.java @@ -297,9 +297,10 @@ public void reduce(Iterable records, Collector out) { } Object[] fields = new Object[serializer.getArity()]; + int length = serializer.getArity(); // First copy all tuple fields, then overwrite the aggregated ones - for (int i = 0; i < fieldPositions.length; i++) { - fields[0] = current.productElement(i); + for (int i = 0; i < length; i++) { + fields[i] = current.productElement(i); } for (int i = 0; i < fieldPositions.length; i++) { Object aggVal = aggFunctions[i].getAggregate(); diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala new file mode 100644 index 0000000000000..631e68a7fec1a --- /dev/null +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateITCase.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.java.aggregation.Aggregations +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.configuration.Configuration +import org.apache.flink.test.util.JavaProgramTestBase +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.flink.api.scala._ + + +object AggregateProgs { + var NUM_PROGRAMS: Int = 3 + + val tupleInput = Array( + (1,1L,"Hi"), + (2,2L,"Hello"), + (3,2L,"Hello world"), + (4,3L,"Hello world, how are you?"), + (5,3L,"I am fine."), + (6,3L,"Luke Skywalker"), + (7,4L,"Comment#1"), + (8,4L,"Comment#2"), + (9,4L,"Comment#3"), + (10,4L,"Comment#4"), + (11,5L,"Comment#5"), + (12,5L,"Comment#6"), + (13,5L,"Comment#7"), + (14,5L,"Comment#8"), + (15,5L,"Comment#9"), + (16,6L,"Comment#10"), + (17,6L,"Comment#11"), + (18,6L,"Comment#12"), + (19,6L,"Comment#13"), + (20,6L,"Comment#14"), + (21,6L,"Comment#15") + ) + + + def runProgram(progId: Int, resultPath: String): String = { + progId match { + case 1 => + // Full aggregate + val env = ExecutionEnvironment.getExecutionEnvironment + env.setDegreeOfParallelism(10) + val ds = env.fromCollection(tupleInput) + + val aggregateDs = ds + .aggregate(Aggregations.SUM,0) + .and(Aggregations.MAX, 1) + // Ensure aggregate operator correctly copies other fields + .filter(_._3 != null) + .map{ t => (t._1, t._2) } + + aggregateDs.writeAsCsv(resultPath) + + env.execute() + + // return expected result + "231,6\n" + + case 2 => + // Grouped aggregate + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromCollection(tupleInput) + + val aggregateDs = ds + .groupBy(1) + .aggregate(Aggregations.SUM, 0) + // Ensure aggregate operator correctly copies other fields + .filter(_._3 != null) + .map { t => (t._2, t._1) } + + aggregateDs.writeAsCsv(resultPath) + + env.execute() + + // return expected result + "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n" + + case 3 => + // Nested aggregate + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = env.fromCollection(tupleInput) + + val aggregateDs = ds + .groupBy(1) + .aggregate(Aggregations.MIN, 0) + .aggregate(Aggregations.MIN, 0) + // Ensure aggregate operator correctly copies other fields + .filter(_._3 != null) + .map { t => Tuple1(t._1) } + + aggregateDs.writeAsCsv(resultPath) + + env.execute() + + // return expected result + "1\n" + + + case _ => + throw new IllegalArgumentException("Invalid program id") + } + } +} + + +@RunWith(classOf[Parameterized]) +class AggregateITCase(config: Configuration) extends JavaProgramTestBase(config) { + + private var curProgId: Int = config.getInteger("ProgramId", -1) + private var resultPath: String = null + private var expectedResult: String = null + + protected override def preSubmit(): Unit = { + resultPath = getTempDirPath("result") + } + + protected def testProgram(): Unit = { + expectedResult = AggregateProgs.runProgram(curProgId, resultPath) + } + + protected override def postSubmit(): Unit = { + compareResultsByLinesInMemory(expectedResult, resultPath) + } +} + +object AggregateITCase { + @Parameters + def getConfigurations: java.util.Collection[Array[AnyRef]] = { + val configs = mutable.MutableList[Array[AnyRef]]() + for (i <- 1 to AggregateProgs.NUM_PROGRAMS) { + val config = new Configuration() + config.setInteger("ProgramId", i) + configs += Array(config) + } + + configs.asJavaCollection + } +} +