Skip to content

Commit

Permalink
[FLINK-3087] [Table API] support multi count in aggregation.
Browse files Browse the repository at this point in the history
  • Loading branch information
chengxiang li authored and aljoscha committed Dec 2, 2015
1 parent 9215b72 commit 20fe2af
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,17 @@ abstract class ExpressionCodeGenerator[R](
}
}

val cleanedExpr = expr match {
case expressions.Naming(namedExpr, _) => namedExpr
case _ => expr
def cleanedExpr(e: Expression): Expression = {
e match {
case expressions.Naming(namedExpr, _) => cleanedExpr(namedExpr)
case _ => e
}
}

val resultTpe = typeTermForTypeInfo(cleanedExpr.typeInfo)

val code: String = cleanedExpr match {
val cleanedExpression = cleanedExpr(expr)
val resultTpe = typeTermForTypeInfo(cleanedExpression.typeInfo)

val code: String = cleanedExpression match {

case expressions.Literal(null, typeInfo) =>
if (nullCheck) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,23 @@ object ExpandAggregations {
val aggregationIntermediates = mutable.HashMap[Aggregation, Seq[Expression]]()

var intermediateCount = 0
var resultCount = 0
selection foreach { f =>
f.transformPre {
case agg: Aggregation =>
val intermediateReferences = agg.getIntermediateFields.zip(agg.getAggregations) map {
case (expr, basicAgg) =>
resultCount += 1
val resultName = s"result.$resultCount"
aggregations.get((expr, basicAgg)) match {
case Some(intermediateName) =>
ResolvedFieldReference(intermediateName, expr.typeInfo)
Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
case None =>
intermediateCount = intermediateCount + 1
val intermediateName = s"intermediate.$intermediateCount"
intermediateFields += Naming(expr, intermediateName)
aggregations((expr, basicAgg)) = intermediateName
ResolvedFieldReference(intermediateName, expr.typeInfo)
Naming(ResolvedFieldReference(intermediateName, expr.typeInfo), resultName)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,29 @@ public void testAggregationWithArithmetic() throws Exception {
compareResultAsText(results, expected);
}

@Test
public void testAggregationWithTwoCount() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();

DataSource<Tuple2<Float, String>> input =
env.fromElements(
new Tuple2<>(1f, "Hello"),
new Tuple2<>(2f, "Ciao"));

Table table =
tableEnv.fromDataSet(input);

Table result =
table.select("f0.count, f1.count");


DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "2,2";
compareResultAsText(results, expected);
}

@Test(expected = ExpressionException.class)
public void testNonWorkingDataTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testAggregationWithTwoCount(): Unit = {

val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements((1f, "Hello"), (2f, "Ciao")).toTable
.select('_1.count, '_2.count).toDataSet[Row]
val expected = "2,2"
val results = ds.collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test(expected = classOf[ExpressionException])
def testNonWorkingAggregationDataTypes(): Unit = {

Expand Down

0 comments on commit 20fe2af

Please sign in to comment.