Skip to content

Commit

Permalink
[FLINK-19675][python] Fix PythonCalcExpandProjectRule to handle cases…
Browse files Browse the repository at this point in the history
… when the calc node contains WHERE clause, composite fields access and Python UDF at the same time (apache#13746)
  • Loading branch information
dianfu committed Oct 23, 2020
1 parent 1f8627b commit ae4080c
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ object PythonCalcExpandProjectRule extends PythonCalcSplitRuleBase(

override def split(program: RexProgram, splitter: ScalarFunctionSplitter)
: (Option[RexNode], Option[RexNode], Seq[RexNode]) = {
(None, None, program.getProjectList.map(program.expandLocalRef(_).accept(splitter)))
(Option(program.getCondition).map(program.expandLocalRef),
None,
program.getProjectList.map(program.expandLocalRef(_).accept(splitter)))
}

private def containsFieldAccessInputs(node: RexNode): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,27 @@ LogicalProject(EXPR$0=[pyFunc1($0, $1)], EXPR$1=[+($2, 1)])
FlinkLogicalCalc(select=[f0 AS EXPR$0, +(c, 1) AS EXPR$1])
+- FlinkLogicalCalc(select=[c, pyFunc1(a, b) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
<TestCase name="testPythonFunctionMixedWithJavaFunctionInWhereClause">
<Resource name="sql">
<![CDATA[SELECT pyFunc1(a, b), c + 1 FROM MyTable WHERE pyFunc2(a, c) > 0]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EXPR$0=[pyFunc1($0, $1)], EXPR$1=[+($2, 1)])
+- LogicalFilter(condition=[>(pyFunc2($0, $2), 0)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[f0 AS EXPR$0, +(c, 1) AS EXPR$1])
+- FlinkLogicalCalc(select=[c, pyFunc1(a, b) AS f0])
+- FlinkLogicalCalc(select=[c, a, b], where=[>(f0, 0)])
+- FlinkLogicalCalc(select=[a, b, c, pyFunc2(a, c) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
Expand All @@ -405,24 +426,22 @@ FlinkLogicalCalc(select=[a, f0 AS EXPR$1, b])
]]>
</Resource>
</TestCase>
<TestCase name="testPythonFunctionMixedWithJavaFunctionInWhereClause">
<TestCase name="testPythonFunctionWithCompositeInputsAndWhereClause">
<Resource name="sql">
<![CDATA[SELECT pyFunc1(a, b), c + 1 FROM MyTable WHERE pyFunc2(a, c) > 0]]>
<![CDATA[SELECT a, pyFunc1(b, d._1) FROM MyTable WHERE a + 1 > 0]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EXPR$0=[pyFunc1($0, $1)], EXPR$1=[+($2, 1)])
+- LogicalFilter(condition=[>(pyFunc2($0, $2), 0)])
LogicalProject(a=[$0], EXPR$1=[pyFunc1($1, $3._1)])
+- LogicalFilter(condition=[>(+($0, 1), 0)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
FlinkLogicalCalc(select=[f0 AS EXPR$0, +(c, 1) AS EXPR$1])
+- FlinkLogicalCalc(select=[c, pyFunc1(a, b) AS f0])
+- FlinkLogicalCalc(select=[c, a, b], where=[>(f0, 0)])
+- FlinkLogicalCalc(select=[a, b, c, pyFunc2(a, c) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
FlinkLogicalCalc(select=[a, pyFunc1(b, f0) AS EXPR$1])
+- FlinkLogicalCalc(select=[a, b, d._1 AS f0], where=[>(+(a, 1), 0)])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ class PythonCalcSplitRuleTest extends TableTestBase {
util.verifyPlan(sqlQuery)
}

@Test
def testPythonFunctionWithCompositeInputsAndWhereClause(): Unit = {
val sqlQuery = "SELECT a, pyFunc1(b, d._1) FROM MyTable WHERE a + 1 > 0"
util.verifyPlan(sqlQuery)
}

@Test
def testChainingPythonFunctionWithCompositeInputs(): Unit = {
val sqlQuery = "SELECT a, pyFunc1(b, pyFunc1(c, d._1)) FROM MyTable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ object PythonCalcExpandProjectRule extends PythonCalcSplitRuleBase(

override def split(program: RexProgram, splitter: ScalarFunctionSplitter)
: (Option[RexNode], Option[RexNode], Seq[RexNode]) = {
(None, None, program.getProjectList.map(program.expandLocalRef(_).accept(splitter)))
(Option(program.getCondition).map(program.expandLocalRef),
None,
program.getProjectList.map(program.expandLocalRef(_).accept(splitter)))
}

private def containsFieldAccessInputs(node: RexNode): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,30 @@ class PythonCalcSplitRuleTest extends TableTestBase {
util.verifyTable(resultTable, expected)
}

@Test
def testPythonFunctionWithCompositeInputsAndWhereClause(): Unit = {
val util = streamTestUtil()
val table = util.addTable[(Int, Int, (Int, Int))]("MyTable", 'a, 'b, 'c)
util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1"))

val resultTable = table.select('a, 'b, 'c.flatten())
.select($"a", call("pyFunc1", $"b", $"c$$_1"))
.where($"a".plus(lit(1)).isGreater(lit(0)))

val expected = unaryNode(
"DataStreamPythonCalc",
unaryNode(
"DataStreamCalc",
streamTableNode(table),
term("select", "a", "b", "c._1 AS f0"),
term("where", ">(+(a, 1), 0)")
),
term("select", "a", "pyFunc1(b, f0) AS _c1")
)

util.verifyTable(resultTable, expected)
}

@Test
def testPandasFunctionWithCompositeInputs(): Unit = {
val util = streamTestUtil()
Expand Down

0 comments on commit ae4080c

Please sign in to comment.