Skip to content

Commit

Permalink
[FLINK-22148][table] Planner rules should use RexCall#equsls to check…
Browse files Browse the repository at this point in the history
… whether two rexCalls are equivalent

This closes apache#15529
  • Loading branch information
cshuo committed Apr 9, 2021
1 parent 3f4dd82 commit b17e7b5
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class FlinkCalcMergeRule[C <: Calc](calcClass: Class[C]) extends RelOptRule(
val newMergedProgram = if (mergedProgram.getCondition != null) {
val condition = mergedProgram.expandLocalRef(mergedProgram.getCondition)
val simplifiedCondition = FlinkRexUtil.simplify(rexBuilder, condition)
if (simplifiedCondition.toString == condition.toString) {
if (simplifiedCondition.equals(condition)) {
mergedProgram
} else {
val programBuilder = RexProgramBuilder.forProgram(mergedProgram, rexBuilder, true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class JoinDependentConditionDerivationRule
builder.getRexBuilder,
builder.and(conjunctions ++ additionalConditions))

if (!newCondExp.toString.equals(join.getCondition.toString)) {
if (!newCondExp.equals(join.getCondition)) {
val newJoin = join.copy(
join.getTraitSet,
newCondExp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ class RewriteMultiJoinConditionRule extends RelOptRule(
val multiJoin: MultiJoin = call.rel(0)
val (equiJoinFilters, nonEquiJoinFilters) = partitionJoinFilters(multiJoin)
// there is no `equals` method in RexCall, so the key of this map should be String
val equiJoinFilterMap = mutable.HashMap[String, mutable.ListBuffer[RexNode]]()
val equiJoinFilterMap = mutable.HashMap[RexNode, mutable.ListBuffer[RexNode]]()
equiJoinFilters.foreach {
case c: RexCall =>
require(c.isA(SqlKind.EQUALS))
val left = c.operands.head
val right = c.operands(1)
equiJoinFilterMap.getOrElseUpdate(left.toString, mutable.ListBuffer[RexNode]()) += right
equiJoinFilterMap.getOrElseUpdate(right.toString, mutable.ListBuffer[RexNode]()) += left
equiJoinFilterMap.getOrElseUpdate(left, mutable.ListBuffer[RexNode]()) += right
equiJoinFilterMap.getOrElseUpdate(right, mutable.ListBuffer[RexNode]()) += left
}

val candidateJoinFilters = equiJoinFilterMap.values.filter(_.size > 1)
Expand All @@ -72,7 +72,7 @@ class RewriteMultiJoinConditionRule extends RelOptRule(

val newEquiJoinFilters = mutable.ListBuffer[RexNode](equiJoinFilters: _*)
def containEquiJoinFilter(joinFilter: RexNode): Boolean = {
newEquiJoinFilters.exists { f => f.toString.equals(joinFilter.toString) }
newEquiJoinFilters.exists { f => f.equals(joinFilter) }
}

val rexBuilder = multiJoin.getCluster.getRexBuilder
Expand All @@ -82,11 +82,9 @@ class RewriteMultiJoinConditionRule extends RelOptRule(
val op1 = candidate(startIndex)
candidate.subList(startIndex + 1, candidate.size).foreach {
op2 =>
// `a = b` and `b = a` are the same
val newFilter1 = rexBuilder.makeCall(EQUALS, op1, op2)
val newFilter2 = rexBuilder.makeCall(EQUALS, op2, op1)
if (!containEquiJoinFilter(newFilter1) && !containEquiJoinFilter(newFilter2)) {
newEquiJoinFilters += newFilter1
val newFilter = rexBuilder.makeCall(EQUALS, op1, op2)
if (!containEquiJoinFilter(newFilter)) {
newEquiJoinFilters += newFilter
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SimplifyFilterConditionRule(
val simplifiedCondition = FlinkRexUtil.simplify(rexBuilder, condition)
val newCondition = RexUtil.pullFactors(rexBuilder, simplifiedCondition)

if (!changed.head && !RexUtil.eq(condition, newCondition)) {
if (!changed.head && !condition.equals(newCondition)) {
changed(0) = true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SimplifyJoinConditionRule
val simpleCondExp = FlinkRexUtil.simplify(join.getCluster.getRexBuilder, condition)
val newCondExp = RexUtil.pullFactors(join.getCluster.getRexBuilder, simpleCondExp)

if (newCondExp.toString.equals(condition.toString)) {
if (newCondExp.equals(condition)) {
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,39 @@ HashJoin(joinType=[LeftSemiJoin], where=[((a = $0) AND (b = $1))], select=[a, b,
+- OverAggregate(window#0=[MAX(d) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[d, e, f, w0$o0])
+- Exchange(distribution=[single])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
]]>
</Resource>
</TestCase>
<TestCase name="testNotSimplifyJoinConditionWithSameDigest">
<Resource name="sql">
<![CDATA[
SELECT a
FROM l
WHERE c NOT IN (
SELECT f FROM r WHERE f = c)
]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0])
+- LogicalFilter(condition=[NOT(IN($2, {
LogicalProject(f=[$2])
LogicalFilter(condition=[=($2, $cor0.c)])
LogicalTableScan(table=[[default_catalog, default_database, r, source: [TestTableSource(d, e, f)]]])
}))], variablesSet=[[$cor0]])
+- LogicalTableScan(table=[[default_catalog, default_database, l, source: [TestTableSource(a, b, c)]]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
Calc(select=[a])
+- HashJoin(joinType=[LeftAntiJoin], where=[AND(OR(IS NULL(c), IS NULL(f), =(c, f)), =(f, c))], select=[a, c], build=[right])
:- Exchange(distribution=[hash[c]])
: +- Calc(select=[a, c])
: +- LegacyTableSourceScan(table=[[default_catalog, default_database, l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+- Exchange(distribution=[hash[f]])
+- Calc(select=[f])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ LogicalProject(a1=[$0], a2=[$1], b1=[$2], b2=[$3], c1=[$4], c2=[$5], d1=[$6], d2
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
MultiJoin(joinFilter=[AND(=($4, $6), =($2, $4), =($0, $2), =($6, $2), =($4, $0), =($6, $0))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[{0, 1}, {0, 1}, {0, 1}, {0, 1}]])
MultiJoin(joinFilter=[AND(=($4, $6), =($2, $4), =($0, $2), =($4, $0), =($6, $2), =($0, $6))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[{0, 1}, {0, 1}, {0, 1}, {0, 1}]])
:- LogicalTableScan(table=[[default_catalog, default_database, A, source: [TestTableSource(a1, a2)]]])
:- LogicalTableScan(table=[[default_catalog, default_database, B, source: [TestTableSource(b1, b2)]]])
:- LogicalTableScan(table=[[default_catalog, default_database, C, source: [TestTableSource(c1, c2)]]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ LogicalProject(a1=[$0], b1=[$1], c1=[$2], a2=[$3], b2=[$4], c2=[$5], a3=[$6], b3
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, b4, c4, a5, b5, c5])
+- Join(joinType=[InnerJoin], where=[((b3 = b5) AND (b2 = b3) AND (b3 = b1) AND (b4 = b3))], select=[a1, b1, c1, a5, b5, c5, a2, b2, c2, a4, b4, c4, a3, b3, c3], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
+- Join(joinType=[InnerJoin], where=[((b3 = b5) AND (b2 = b3) AND (b3 = b1) AND (b3 = b4))], select=[a1, b1, c1, a5, b5, c5, a2, b2, c2, a4, b4, c4, a3, b3, c3], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
:- Exchange(distribution=[hash[b5, b2, b1, b4]])
: +- Join(joinType=[InnerJoin], where=[((b1 = b4) AND (b1 = b2) AND (b5 = b1))], select=[a1, b1, c1, a5, b5, c5, a2, b2, c2, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
: +- Join(joinType=[InnerJoin], where=[((b1 = b4) AND (b1 = b2) AND (b1 = b5))], select=[a1, b1, c1, a5, b5, c5, a2, b2, c2, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
: :- Exchange(distribution=[hash[b1, b1, b1]])
: : +- LegacyTableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a1, b1, c1)]]], fields=[a1, b1, c1])
: +- Exchange(distribution=[hash[b4, b2, b5]])
Expand Down Expand Up @@ -453,7 +453,7 @@ LogicalProject(a4=[$0], b4=[$1], c4=[$2], a1=[$3], b1=[$4], c1=[$5], a2=[$6], b2
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[a4, b4, c4, a1, b1, c1, a2, b2, c2, a3, b3, c3, a5, b5, c5])
+- Join(joinType=[InnerJoin], where=[((a4 = a5) AND (a3 = a4) AND (a4 = a2) AND (a1 = a4) AND ((b2 + b4) > 100))], select=[a2, b2, c2, a5, b5, c5, a1, b1, c1, a3, b3, c3, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
+- Join(joinType=[InnerJoin], where=[((a4 = a5) AND (a3 = a4) AND (a4 = a2) AND (a4 = a1) AND ((b2 + b4) > 100))], select=[a2, b2, c2, a5, b5, c5, a1, b1, c1, a3, b3, c3, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
:- Exchange(distribution=[hash[a5, a3, a2, a1]])
: +- Join(joinType=[InnerJoin], where=[((a2 = a3) AND (a1 = a2) AND (a2 = a5) AND ((b1 * b2) > 10))], select=[a2, b2, c2, a5, b5, c5, a1, b1, c1, a3, b3, c3], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
: :- Exchange(distribution=[hash[a2, a2, a2]])
Expand Down Expand Up @@ -654,7 +654,7 @@ Calc(select=[a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, b4, c4, a5, b5, c5])
:- Exchange(distribution=[hash[c2, c2, c2]])
: +- LegacyTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2])
+- Exchange(distribution=[hash[c4, c5, c3]])
+- Join(joinType=[InnerJoin], where=[((c4 = c5) AND (c4 = c3))], select=[a5, b5, c5, a3, b3, c3, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
+- Join(joinType=[InnerJoin], where=[((c5 = c4) AND (c3 = c4))], select=[a5, b5, c5, a3, b3, c3, a4, b4, c4], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
:- Exchange(distribution=[hash[c5, c3]])
: +- Join(joinType=[InnerJoin], where=[(c5 = c3)], select=[a5, b5, c5, a3, b3, c3], leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
: :- Exchange(distribution=[hash[c5]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,24 @@

package org.apache.flink.table.planner.plan.batch.sql.join

import org.junit.Test

/**
* Test SEMI/ANTI Join, the join operators are chose based on cost.
*/
class SemiAntiJoinTest extends SemiAntiJoinTestBase {

@Test
def testNotSimplifyJoinConditionWithSameDigest(): Unit = {
// The new condition generated by the rule is digest-equaling
// (with normalization) to the old one
val sqlQuery =
"""
|SELECT a
|FROM l
|WHERE c NOT IN (
| SELECT f FROM r WHERE f = c)
|""".stripMargin
util.verifyRelPlan(sqlQuery)
}
}

0 comments on commit b17e7b5

Please sign in to comment.