Skip to content

Commit

Permalink
[FLINK-21655][table-planner-blink] Fix incorrect simplification for c…
Browse files Browse the repository at this point in the history
…oalesce call on a groupingsets' result

This closes apache#15117
  • Loading branch information
lincoln-lil authored and godfreyhe committed Mar 24, 2021
1 parent d3d320b commit b975fda
Show file tree
Hide file tree
Showing 32 changed files with 578 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,9 @@ object FlinkLogicalRelFactories {
class ExpandFactoryImpl extends ExpandFactory {
def createExpand(
input: RelNode,
rowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): RelNode = {
FlinkLogicalExpand.create(input, rowType, projects, expandIdIndex)
FlinkLogicalExpand.create(input, projects, expandIdIndex)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.flink.table.runtime.operators.rank.{RankRange, RankType}
import com.google.common.collect.ImmutableList
import org.apache.calcite.plan._
import org.apache.calcite.rel.RelCollation
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rel.`type`.RelDataTypeField
import org.apache.calcite.rel.logical.LogicalAggregate
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlKind
Expand Down Expand Up @@ -84,11 +84,10 @@ class FlinkRelBuilder(
}

def expand(
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): RelBuilder = {
val input = build()
val expand = expandFactory.createExpand(input, outputRowType, projects, expandIdIndex)
val expand = expandFactory.createExpand(input, projects, expandIdIndex)
push(expand)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.flink.table.planner.plan.nodes.calcite.{LogicalExpand, Logical
import org.apache.flink.table.runtime.operators.rank.{RankRange, RankType}

import org.apache.calcite.plan.Contexts
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rel.`type`.RelDataTypeField
import org.apache.calcite.rel.core.RelFactories
import org.apache.calcite.rel.{RelCollation, RelNode}
import org.apache.calcite.rex.RexNode
Expand Down Expand Up @@ -62,7 +62,6 @@ object FlinkRelFactories {
trait ExpandFactory {
def createExpand(
input: RelNode,
rowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): RelNode
}
Expand All @@ -73,9 +72,10 @@ object FlinkRelFactories {
class ExpandFactoryImpl extends ExpandFactory {
def createExpand(
input: RelNode,
rowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): RelNode = LogicalExpand.create(input, rowType, projects, expandIdIndex)
expandIdIndex: Int): RelNode = {
LogicalExpand.create(input, projects, expandIdIndex)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@

package org.apache.flink.table.planner.plan.nodes.calcite

import org.apache.flink.table.api.DataTypes.NULL
import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.plan.utils.{ExpandUtil, RelExplainUtil}

import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.rex.{RexInputRef, RexLiteral, RexNode}
import org.apache.calcite.util.Litmus
import org.apache.flink.table.planner.plan.utils.RelExplainUtil

import java.util

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
* Relational expression that apply a number of projects to every input row,
Expand All @@ -39,7 +43,6 @@ import scala.collection.JavaConversions._
* @param cluster cluster that this relational expression belongs to
* @param traits the traits of this rel
* @param input input relational expression
* @param outputRowType output row type
* @param projects all projects, each project contains list of expressions for
* the output columns
* @param expandIdIndex expand_id('$e') field index
Expand All @@ -48,7 +51,6 @@ abstract class Expand(
cluster: RelOptCluster,
traits: RelTraitSet,
input: RelNode,
outputRowType: RelDataType,
val projects: util.List[util.List[RexNode]],
val expandIdIndex: Int)
extends SingleRel(cluster, traits, input) {
Expand All @@ -59,10 +61,20 @@ abstract class Expand(
if (projects.size() <= 1) {
return litmus.fail("Expand should output more than one rows, otherwise use Project.")
}
if (projects.exists(_.size != outputRowType.getFieldCount)) {
return litmus.fail("project filed count is not equal to output field count.")
val fieldLen = projects.get(0).size()
if (projects.exists(_.size != fieldLen)) {
return litmus.fail("all projects' field count should be equal.")
}

// do type check and derived row type info will be cached by framework
try {
deriveRowType()
} catch {
case exp: TableException =>
return litmus.fail(exp.getMessage)
}
if (expandIdIndex < 0 || expandIdIndex >= outputRowType.getFieldCount) {

if (expandIdIndex < 0 || expandIdIndex >= fieldLen) {
return litmus.fail(
"expand_id field index should be greater than 0 and less than output field count.")
}
Expand All @@ -79,7 +91,57 @@ abstract class Expand(
litmus.succeed()
}

override def deriveRowType(): RelDataType = outputRowType
override def deriveRowType(): RelDataType = {
val inputNames = input.getRowType.getFieldNames
val fieldNameSet = mutable.Set[String](inputNames: _*)
val rowTypes = mutable.ListBuffer[RelDataType]()
val outputNames = mutable.ListBuffer[String]()
val fieldLen = projects.get(0).size()
val inputNameRefCnt = mutable.Map[String, Int]()

for (fieldIndex <- 0 until fieldLen) {
val fieldTypes = mutable.ListBuffer[RelDataType]()
val fieldNames = mutable.ListBuffer[String]()
for (projectIndex <- 0 until projects.size()) {
val rexNode = projects.get(projectIndex).get(fieldIndex)
fieldTypes += rexNode.getType
rexNode match {
case ref: RexInputRef =>
fieldNames += inputNames.get(ref.getIndex)
case _: RexLiteral => // ignore
case exp@_ =>
throw new TableException(
"Expand node only support RexInputRef and RexLiteral, but got " + exp)
}
}
if (!fieldNames.isEmpty) {
val inputName = fieldNames(0)
val refCnt = inputNameRefCnt.getOrElse(inputName, 0) + 1
inputNameRefCnt.put(inputName, refCnt)
outputNames += ExpandUtil.buildDuplicateFieldName(
fieldNameSet,
inputName,
inputNameRefCnt.get(inputName).get)
} else if (fieldIndex == expandIdIndex) {
outputNames += ExpandUtil.buildUniqueFieldName(fieldNameSet, "$e")
} else {
outputNames += ExpandUtil.buildUniqueFieldName(fieldNameSet, "$f" + fieldIndex)
}

val leastRestrictive = input.getCluster.getTypeFactory.leastRestrictive(fieldTypes)
// if leastRestrictive type is null or type name is NULL means we can not support given
// projects with different column types (NULL type name is reserved for untyped literals only)
if (leastRestrictive == null || leastRestrictive.getSqlTypeName == NULL) {
throw new TableException(
"Expand node only support projects that have common types, but got a column with " +
"different types which can not derive a least restrictive common type: column index[" +
fieldIndex + "], column types[" + fieldTypes.mkString(",") + "]")
} else {
rowTypes += leastRestrictive
}
}
cluster.getTypeFactory.createStructType(rowTypes, outputNames)
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.flink.table.planner.plan.nodes.calcite

import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.RexNode

import java.util
Expand All @@ -34,25 +33,23 @@ final class LogicalExpand(
cluster: RelOptCluster,
traits: RelTraitSet,
input: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int)
extends Expand(cluster, traits, input, outputRowType, projects, expandIdIndex) {
extends Expand(cluster, traits, input, projects, expandIdIndex) {

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new LogicalExpand(cluster, traitSet, inputs.get(0), outputRowType, projects, expandIdIndex)
new LogicalExpand(cluster, traitSet, inputs.get(0), projects, expandIdIndex)
}

}

object LogicalExpand {
def create(
input: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): LogicalExpand = {
val traits = input.getCluster.traitSetOf(Convention.NONE)
new LogicalExpand(input.getCluster, traits, input, outputRowType, projects, expandIdIndex)
new LogicalExpand(input.getCluster, traits, input, projects, expandIdIndex)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, LogicalExpand}

import org.apache.calcite.plan.{Convention, RelOptCluster, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rex.RexNode

Expand All @@ -37,14 +36,18 @@ class FlinkLogicalExpand(
cluster: RelOptCluster,
traits: RelTraitSet,
input: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int)
extends Expand(cluster, traits, input, outputRowType, projects, expandIdIndex)
extends Expand(cluster, traits, input, projects, expandIdIndex)
with FlinkLogicalRel {

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new FlinkLogicalExpand(cluster, traitSet, inputs.get(0), outputRowType, projects, expandIdIndex)
new FlinkLogicalExpand(
cluster,
traitSet,
inputs.get(0),
projects,
expandIdIndex)
}

}
Expand All @@ -61,7 +64,6 @@ private class FlinkLogicalExpandConverter
val newInput = RelOptRule.convert(expand.getInput, FlinkConventions.LOGICAL)
FlinkLogicalExpand.create(
newInput,
expand.getRowType,
expand.projects,
expand.expandIdIndex)
}
Expand All @@ -72,11 +74,10 @@ object FlinkLogicalExpand {

def create(
input: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int): FlinkLogicalExpand = {
val cluster = input.getCluster
val traitSet = cluster.traitSetOf(FlinkConventions.LOGICAL).simplify()
new FlinkLogicalExpand(cluster, traitSet, input, outputRowType, projects, expandIdIndex)
new FlinkLogicalExpand(cluster, traitSet, input, projects, expandIdIndex)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecExpand
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex.RexNode

Expand All @@ -37,18 +36,16 @@ class BatchPhysicalExpand(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int)
extends Expand(cluster, traitSet, input, outputRowType, projects, expandIdIndex)
extends Expand(cluster, traitSet, input, projects, expandIdIndex)
with BatchPhysicalRel {

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new BatchPhysicalExpand(
cluster,
traitSet,
inputs.get(0),
outputRowType,
projects,
expandIdIndex
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.RexNode

import java.util
Expand All @@ -36,17 +35,16 @@ class StreamPhysicalExpand(
cluster: RelOptCluster,
traitSet: RelTraitSet,
inputRel: RelNode,
outputRowType: RelDataType,
projects: util.List[util.List[RexNode]],
expandIdIndex: Int)
extends Expand(cluster, traitSet, inputRel, outputRowType, projects, expandIdIndex)
extends Expand(cluster, traitSet, inputRel, projects, expandIdIndex)
with StreamPhysicalRel {

override def requireWatermark: Boolean = false

override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new StreamPhysicalExpand(
cluster, traitSet, inputs.get(0), outputRowType, projects, expandIdIndex)
cluster, traitSet, inputs.get(0), projects, expandIdIndex)
}

override def translateToExecNode(): ExecNode[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class DecomposeGroupingSetsRule extends RelOptRule(

val (newGroupSet, duplicateFieldMap) = if (needExpand) {
val (duplicateFieldMap, expandIdIdxInExpand) = ExpandUtil.buildExpandNode(
cluster, relBuilder, agg.getAggCallList, agg.getGroupSet, agg.getGroupSets)
relBuilder, agg.getAggCallList, agg.getGroupSet, agg.getGroupSets)

// new groupSet contains original groupSet and expand_id('$e') field
val newGroupSet = agg.getGroupSet.union(ImmutableBitSet.of(expandIdIdxInExpand))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class SplitAggregateRule extends RelOptRule(
val needExpand = newGroupSetsNum > 1
val duplicateFieldMap = if (needExpand) {
val (duplicateFieldMap, _) = ExpandUtil.buildExpandNode(
cluster, relBuilder, partialAggCalls, fullGroupSet, groupSets)
relBuilder, partialAggCalls, fullGroupSet, groupSets)
duplicateFieldMap
} else {
Map.empty[Integer, Integer]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class BatchPhysicalExpandRule
rel.getCluster,
newTrait,
newInput,
rel.getRowType,
expand.projects,
expand.expandIdIndex)
}
Expand Down
Loading

0 comments on commit b975fda

Please sign in to comment.