Skip to content

Commit

Permalink
[FLINK-8903] [table] Fix VAR_SAMP, VAR_POP, STDEV_SAMP, STDEV_POP fun…
Browse files Browse the repository at this point in the history
…ctions on GROUP BY windows.

This closes apache#5706.
  • Loading branch information
fhueske committed Mar 22, 2018
1 parent 893fabf commit 8c042e3
Show file tree
Hide file tree
Showing 10 changed files with 888 additions and 5 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.ImmutableBitSet
import org.apache.flink.table.plan.nodes.FlinkConventions

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

class FlinkLogicalAggregate(
cluster: RelOptCluster,
Expand Down Expand Up @@ -74,8 +74,11 @@ private class FlinkLogicalAggregateConverter

// we do not support these functions natively
// they have to be converted using the AggregateReduceFunctionsRule
val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall {
case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false
val supported = agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall {
// we support AVG
case SqlKind.AVG => true
// but none of the other AVG agg functions
case k if SqlKind.AVG_AGG_FUNCTIONS.contains(k) => false
case _ => true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelShuttle}
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.ImmutableBitSet
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.logical.LogicalWindow
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
import org.apache.flink.table.plan.nodes.FlinkConventions

import scala.collection.JavaConverters._

class FlinkLogicalWindowAggregate(
window: LogicalWindow,
namedProperties: Seq[NamedWindowProperty],
Expand Down Expand Up @@ -103,6 +106,20 @@ class FlinkLogicalWindowAggregateConverter
FlinkConventions.LOGICAL,
"FlinkLogicalWindowAggregateConverter") {

override def matches(call: RelOptRuleCall): Boolean = {
val agg = call.rel(0).asInstanceOf[LogicalWindowAggregate]

// we do not support these functions natively
// they have to be converted using the WindowAggregateReduceFunctionsRule
agg.getAggCallList.asScala.map(_.getAggregation.getKind).forall {
// we support AVG
case SqlKind.AVG => true
// but none of the other AVG agg functions
case k if SqlKind.AVG_AGG_FUNCTIONS.contains(k) => false
case _ => true
}
}

override def convert(rel: RelNode): RelNode = {
val agg = rel.asInstanceOf[LogicalWindowAggregate]
val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ object FlinkRuleSets {

// reduce aggregate functions like AVG, STDDEV_POP etc.
AggregateReduceFunctionsRule.INSTANCE,
WindowAggregateReduceFunctionsRule.INSTANCE,

// remove unnecessary sort rule
SortRemoveRule.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.plan.rules.common

import java.util

import org.apache.calcite.plan.RelOptRule
import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories}
import org.apache.calcite.rel.logical.LogicalAggregate
import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule
import org.apache.calcite.rex.RexNode
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate

/**
* Rule to convert complex aggregation functions into simpler ones.
* Have a look at [[AggregateReduceFunctionsRule]] for details.
*/
class WindowAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule(
RelOptRule.operand(classOf[LogicalWindowAggregate], RelOptRule.any()),
RelFactories.LOGICAL_BUILDER) {

override def newAggregateRel(
relBuilder: RelBuilder,
oldAgg: Aggregate,
newCalls: util.List[AggregateCall]): Unit = {

// create a LogicalAggregate with simpler aggregation functions
super.newAggregateRel(relBuilder, oldAgg, newCalls)
// pop LogicalAggregate from RelBuilder
val newAgg = relBuilder.build().asInstanceOf[LogicalAggregate]

// create a new LogicalWindowAggregate (based on the new LogicalAggregate) and push it on the
// RelBuilder
val oldWindowAgg = oldAgg.asInstanceOf[LogicalWindowAggregate]
relBuilder.push(LogicalWindowAggregate.create(
oldWindowAgg.getWindow,
oldWindowAgg.getNamedProperties,
newAgg))
}

override def newCalcRel(
relBuilder: RelBuilder,
oldAgg: Aggregate,
exprs: util.List[RexNode]): Unit = {

// add all named properties of the window to the selection
val oldWindowAgg = oldAgg.asInstanceOf[LogicalWindowAggregate]
oldWindowAgg.getNamedProperties.foreach(np => exprs.add(relBuilder.field(np.name)))

// create a LogicalCalc that computes the complex aggregates and forwards the window properties
relBuilder.project(exprs, oldAgg.getRowType.getFieldNames)
}

}

object WindowAggregateReduceFunctionsRule {
val INSTANCE = new WindowAggregateReduceFunctionsRule
}
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ object AggregateUtil {
}
}

case _: SqlAvgAggFunction =>
case a: SqlAvgAggFunction if a.kind == SqlKind.AVG =>
aggregates(index) = sqlTypeName match {
case TINYINT =>
new ByteAvgAggFunction
Expand Down Expand Up @@ -1413,7 +1413,7 @@ object AggregateUtil {
accTypes(index) = udagg.accType

case unSupported: SqlAggFunction =>
throw new TableException(s"unsupported Function: '${unSupported.getName}'")
throw new TableException(s"Unsupported Function: '${unSupported.getName}'")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,53 @@ class GroupWindowTest extends TableTestBase {

util.verifySql(sql, expected)
}

@Test
def testDecomposableAggFunctions() = {
val util = batchTestUtil()
util.addTable[(Int, String, Long, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime)

val sql =
"SELECT " +
" VAR_POP(c), VAR_SAMP(c), STDDEV_POP(c), STDDEV_SAMP(c), " +
" TUMBLE_START(rowtime, INTERVAL '15' MINUTE), " +
" TUMBLE_END(rowtime, INTERVAL '15' MINUTE)" +
"FROM MyTable " +
"GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"

val expected =
unaryNode(
"DataSetCalc",
unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "rowtime", "c",
"*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5")
),
term("window", TumblingGroupWindow('w$, 'rowtime, 900000.millis)),
term("select",
"SUM($f2) AS $f0",
"SUM(c) AS $f1",
"COUNT(c) AS $f2",
"SUM($f3) AS $f3",
"SUM($f4) AS $f4",
"SUM($f5) AS $f5",
"start('w$) AS w$start",
"end('w$) AS w$end",
"rowtime('w$) AS w$rowtime")
),
term("select",
"CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS EXPR$0",
"CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS EXPR$1",
"CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS EXPR$2",
"CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " +
"AS EXPR$3",
"CAST(w$start) AS EXPR$4",
"CAST(w$end) AS EXPR$5")
)

util.verifySql(sql, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -449,4 +449,49 @@ class GroupWindowTest extends TableTestBase {

util.verifyTable(windowedTable, expected)
}

@Test
def testDecomposableAggFunctions(): Unit = {
val util = batchTestUtil()
val table = util.addTable[(Long, Int, String, Long)]('rowtime, 'a, 'b, 'c)

val windowedTable = table
.window(Tumble over 15.minutes on 'rowtime as 'w)
.groupBy('w)
.select('c.varPop, 'c.varSamp, 'c.stddevPop, 'c.stddevSamp, 'w.start, 'w.end)

val expected =
unaryNode(
"DataSetCalc",
unaryNode(
"DataSetWindowAggregate",
unaryNode(
"DataSetCalc",
batchTableNode(0),
term("select", "c", "rowtime",
"*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5")
),
term("window", TumblingGroupWindow('w, 'rowtime, 900000.millis)),
term("select",
"SUM($f2) AS $f0",
"SUM(c) AS $f1",
"COUNT(c) AS $f2",
"SUM($f3) AS $f3",
"SUM($f4) AS $f4",
"SUM($f5) AS $f5",
"start('w) AS TMP_4",
"end('w) AS TMP_5")
),
term("select",
"CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS TMP_0",
"CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS TMP_1",
"CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS TMP_2",
"CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " +
"AS TMP_3",
"TMP_4",
"TMP_5")
)

util.verifyTable(windowedTable, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,50 @@ class GroupWindowTest extends TableTestBase {

streamUtil.verifySql(sql, expected)
}

@Test
def testDecomposableAggFunctions() = {

val sql =
"SELECT " +
" VAR_POP(c), VAR_SAMP(c), STDDEV_POP(c), STDDEV_SAMP(c), " +
" TUMBLE_START(rowtime, INTERVAL '15' MINUTE), " +
" TUMBLE_END(rowtime, INTERVAL '15' MINUTE)" +
"FROM MyTable " +
"GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"
val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamGroupWindowAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "rowtime", "c",
"*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5")
),
term("window", TumblingGroupWindow('w$, 'rowtime, 900000.millis)),
term("select",
"SUM($f2) AS $f0",
"SUM(c) AS $f1",
"COUNT(c) AS $f2",
"SUM($f3) AS $f3",
"SUM($f4) AS $f4",
"SUM($f5) AS $f5",
"start('w$) AS w$start",
"end('w$) AS w$end",
"rowtime('w$) AS w$rowtime",
"proctime('w$) AS w$proctime")
),
term("select",
"CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS EXPR$0",
"CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS EXPR$1",
"CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS EXPR$2",
"CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " +
"AS EXPR$3",
"w$start AS EXPR$4",
"w$end AS EXPR$5")
)
streamUtil.verifySql(sql, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -782,4 +782,49 @@ class GroupWindowTest extends TableTestBase {

util.verifyTable(windowedTable, expected)
}

@Test
def testDecomposableAggFunctions(): Unit = {
val util = streamTestUtil()
val table = util.addTable[(Long, Int, String, Long)]('rowtime.rowtime, 'a, 'b, 'c)

val windowedTable = table
.window(Tumble over 15.minutes on 'rowtime as 'w)
.groupBy('w)
.select('c.varPop, 'c.varSamp, 'c.stddevPop, 'c.stddevSamp, 'w.start, 'w.end)

val expected =
unaryNode(
"DataStreamCalc",
unaryNode(
"DataStreamGroupWindowAggregate",
unaryNode(
"DataStreamCalc",
streamTableNode(0),
term("select", "c", "rowtime",
"*(c, c) AS $f2", "*(c, c) AS $f3", "*(c, c) AS $f4", "*(c, c) AS $f5")
),
term("window", TumblingGroupWindow('w, 'rowtime, 900000.millis)),
term("select",
"SUM($f2) AS $f0",
"SUM(c) AS $f1",
"COUNT(c) AS $f2",
"SUM($f3) AS $f3",
"SUM($f4) AS $f4",
"SUM($f5) AS $f5",
"start('w) AS TMP_4",
"end('w) AS TMP_5")
),
term("select",
"CAST(/(-($f0, /(*($f1, $f1), $f2)), $f2)) AS TMP_0",
"CAST(/(-($f3, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1)))) AS TMP_1",
"CAST(POWER(/(-($f4, /(*($f1, $f1), $f2)), $f2), 0.5)) AS TMP_2",
"CAST(POWER(/(-($f5, /(*($f1, $f1), $f2)), CASE(=($f2, 1), null, -($f2, 1))), 0.5)) " +
"AS TMP_3",
"TMP_4",
"TMP_5")
)

util.verifyTable(windowedTable, expected)
}
}

0 comments on commit 8c042e3

Please sign in to comment.