Skip to content

Commit

Permalink
[FLINK-5266] [table] Inject projection of unused fields before aggreg…
Browse files Browse the repository at this point in the history
…ations.

This closes apache#2961.
  • Loading branch information
KurtYoung authored and fhueske committed Dec 15, 2016
1 parent 5dab934 commit 15e7f0a
Show file tree
Hide file tree
Showing 6 changed files with 551 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,19 @@ object ProjectionTranslator {

/**
* Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
* from all expressions and replaces the original expressions by field accesses expressions.
* from the given expressions.
*
* @param exprs a list of expressions to convert
* @param exprs a list of expressions to extract
* @param tableEnv the TableEnvironment
* @return a Tuple3, the first field contains the converted expressions, the second field the
* extracted and deduplicated aggregations, and the third field the extracted and
* deduplicated window properties.
* @return a Tuple2, the first field contains the extracted and deduplicated aggregations,
* and the second field contains the extracted and deduplicated window properties.
*/
def extractAggregationsAndProperties(
exprs: Seq[Expression],
tableEnv: TableEnvironment)
: (Seq[NamedExpression], Seq[NamedExpression], Seq[NamedExpression]) = {

val (aggNames, propNames) =
exprs.foldLeft( (Map[Expression, String](), Map[Expression, String]()) ) {
(x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
}

val replaced = exprs
.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
.map(UnresolvedAlias)
val aggs = aggNames.map( a => Alias(a._1, a._2)).toSeq
val props = propNames.map( p => Alias(p._1, p._2)).toSeq

(replaced, aggs, props)
tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
(x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
}
}

/** Identifies and deduplicates aggregation functions and window properties. */
Expand Down Expand Up @@ -106,7 +94,24 @@ object ProjectionTranslator {
}
}

/** Replaces aggregations and projections by named field references. */
/**
* Replaces expressions with deduplicated aggregations and properties.
*
* @param exprs a list of expressions to replace
* @param tableEnv the TableEnvironment
* @param aggNames the deduplicated aggregations
* @param propNames the deduplicated properties
* @return a list of replaced expressions
*/
def replaceAggregationsAndProperties(
exprs: Seq[Expression],
tableEnv: TableEnvironment,
aggNames: Map[Expression, String],
propNames: Map[Expression, String]): Seq[NamedExpression] = {
exprs.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
.map(UnresolvedAlias)
}

private def replaceAggregationsAndProperties(
exp: Expression,
tableEnv: TableEnvironment,
Expand Down Expand Up @@ -197,4 +202,62 @@ object ProjectionTranslator {
}
projectList
}

/**
* Extract all field references from the given expressions.
*
* @param exprs a list of expressions to extract
* @return a list of field references extracted from the given expressions
*/
def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.foldLeft(Set[NamedExpression]()) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}.toSeq
}

private def identifyFieldReferences(
expr: Expression,
fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match {

case f: UnresolvedFieldReference =>
fieldReferences + UnresolvedAlias(f)

case b: BinaryExpression =>
val l = identifyFieldReferences(b.left, fieldReferences)
identifyFieldReferences(b.right, l)

// Functions calls
case c @ Call(name, args) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}
case sfc @ ScalarFunctionCall(clazz, args) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}

// array constructor
case c @ ArrayConstructor(args) =>
args.foldLeft(fieldReferences) {
(fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
}

// ignore fields from window property
case w : WindowProperty =>
fieldReferences

// keep this case after all unwanted unary expressions
case u: UnaryExpression =>
identifyFieldReferences(u.child, fieldReferences)

// General expression
case e: Expression =>
e.productIterator.foldLeft(fieldReferences) {
(fieldReferences, expr) => expr match {
case e: Expression => identifyFieldReferences(e, fieldReferences)
case _ => fieldReferences
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ package org.apache.flink.api.table
import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.plan.logical.Minus
import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
import org.apache.flink.api.table.expressions._
import org.apache.flink.api.table.plan.ProjectionTranslator._
import org.apache.flink.api.table.plan.logical._
import org.apache.flink.api.table.plan.logical.{Minus, _}
import org.apache.flink.api.table.sinks.TableSink

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -77,21 +76,27 @@ class Table(
* }}}
*/
def select(fields: Expression*): Table = {

val expandedFields = expandProjectList(fields, logicalPlan, tableEnv)
val (projection, aggs, props) = extractAggregationsAndProperties(expandedFields, tableEnv)

if (props.nonEmpty) {
val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv)
if (propNames.nonEmpty) {
throw ValidationException("Window properties can only be used on windowed tables.")
}

if (aggs.nonEmpty) {
if (aggNames.nonEmpty) {
val projectsOnAgg = replaceAggregationsAndProperties(
expandedFields, tableEnv, aggNames, propNames)
val projectFields = extractFieldReferences(expandedFields)

new Table(tableEnv,
Project(projection,
Aggregate(Nil, aggs, logicalPlan).validate(tableEnv)).validate(tableEnv))
Project(projectsOnAgg,
Aggregate(Nil, aggNames.map(a => Alias(a._1, a._2)).toSeq,
Project(projectFields, logicalPlan).validate(tableEnv)
).validate(tableEnv)
).validate(tableEnv)
)
} else {
new Table(tableEnv,
Project(projection, logicalPlan).validate(tableEnv))
Project(expandedFields.map(UnresolvedAlias), logicalPlan).validate(tableEnv))
}
}

Expand Down Expand Up @@ -806,24 +811,21 @@ class GroupedTable(
* }}}
*/
def select(fields: Expression*): Table = {

val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)

if (props.nonEmpty) {
val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
if (propNames.nonEmpty) {
throw ValidationException("Window properties can only be used on windowed tables.")
}

val logical =
Project(
projection,
Aggregate(
groupKey,
aggs,
table.logicalPlan
).validate(table.tableEnv)
).validate(table.tableEnv)
val projectsOnAgg = replaceAggregationsAndProperties(
fields, table.tableEnv, aggNames, propNames)
val projectFields = extractFieldReferences(fields ++ groupKey)

new Table(table.tableEnv, logical)
new Table(table.tableEnv,
Project(projectsOnAgg,
Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq,
Project(projectFields, table.logicalPlan).validate(table.tableEnv)
).validate(table.tableEnv)
).validate(table.tableEnv))
}

/**
Expand Down Expand Up @@ -877,24 +879,29 @@ class GroupWindowedTable(
* }}}
*/
def select(fields: Expression*): Table = {
val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
val projectsOnAgg = replaceAggregationsAndProperties(
fields, table.tableEnv, aggNames, propNames)

val projectFields = (table.tableEnv, window) match {
// event time can be arbitrary field in batch environment
case (_: BatchTableEnvironment, w: EventTimeWindow) =>
extractFieldReferences(fields ++ groupKey ++ Seq(w.timeField))
case (_, _) =>
extractFieldReferences(fields ++ groupKey)
}

val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)

val groupWindow = window.toLogicalWindow

val logical =
new Table(table.tableEnv,
Project(
projection,
projectsOnAgg,
WindowAggregate(
groupKey,
groupWindow,
props,
aggs,
table.logicalPlan
window.toLogicalWindow,
propNames.map(a => Alias(a._1, a._2)).toSeq,
aggNames.map(a => Alias(a._1, a._2)).toSeq,
Project(projectFields, table.logicalPlan).validate(table.tableEnv)
).validate(table.tableEnv)
).validate(table.tableEnv)

new Table(table.tableEnv, logical)
).validate(table.tableEnv))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ trait GroupWindow {
* @param timeField defines the time mode for streaming tables. For batch table it defines the
* time attribute on which is grouped.
*/
abstract class EventTimeWindow(timeField: Expression) extends GroupWindow {
abstract class EventTimeWindow(val timeField: Expression) extends GroupWindow {

protected var name: Option[Expression] = None

Expand Down
Loading

0 comments on commit 15e7f0a

Please sign in to comment.