Skip to content

Commit

Permalink
[FLINK-22131][python] Fix the bug of general udf and pandas udf chained together in map operation
Browse files Browse the repository at this point in the history

This closes apache#15502.
  • Loading branch information
HuangXingBo committed Apr 13, 2021
1 parent fad4874 commit ba5cc58
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 11 deletions.
11 changes: 10 additions & 1 deletion flink-python/pyflink/table/tests/test_row_based_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def func(x):
def func2(x):
return x * 2

def func3(x):
assert isinstance(x, Row)
return x

pandas_udf = udf(func,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
Expand All @@ -86,7 +90,12 @@ def func2(x):
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')

t.map(pandas_udf).map(pandas_udf_2).execute_insert("Results").wait()
general_udf = udf(func3,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]))

t.map(pandas_udf).map(pandas_udf_2).map(general_udf).execute_insert("Results").wait()
actual = source_sink_utils.results()
self.assert_equals(
actual,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;

Expand Down Expand Up @@ -179,14 +180,29 @@ public void onMatch(RelOptRuleCall call) {
Collections.singletonList("f0"),
call.builder().getRexBuilder()));

// merge bottomCalc
RexBuilder rexBuilder = call.builder().getRexBuilder();
RexProgram mergedProgram =
RexProgramBuilder.mergePrograms(
topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
Calc newCalc =
topMiddleMergedCalc.copy(
topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram);
call.transformTo(newCalc);
RexProgram bottomProgram = bottomCalc.getProgram();
List<RexCall> bottomProjects =
bottomProgram.getProjectList().stream()
.map(bottomProgram::expandLocalRef)
.map(x -> (RexCall) x)
.collect(Collectors.toList());
RexCall bottomPythonCall = bottomProjects.get(0);
// Only Python Functions with same Python function kind can be merged together.
if (PythonUtil.isPythonCall(topPythonCall, PythonFunctionKind.GENERAL)
^ PythonUtil.isPythonCall(bottomPythonCall, PythonFunctionKind.GENERAL)) {
call.transformTo(topMiddleMergedCalc);
} else {
// merge bottomCalc
RexBuilder rexBuilder = call.builder().getRexBuilder();
RexProgram mergedProgram =
RexProgramBuilder.mergePrograms(
topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
Calc newCalc =
topMiddleMergedCalc.copy(
topMiddleMergedCalc.getTraitSet(),
bottomCalc.getInput(),
mergedProgram);
call.transformTo(newCalc);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -312,4 +312,17 @@ public PythonFunctionKind getPythonFunctionKind() {
return PythonFunctionKind.PANDAS;
}
}

/**
 * Test variant of {@link RowPythonScalarFunction} that declares itself as a Pandas function.
 *
 * <p>Everything except the reported {@link PythonFunctionKind} is inherited unchanged from
 * {@link RowPythonScalarFunction}.
 */
public static class RowPandasScalarFunction extends RowPythonScalarFunction {

    /**
     * Creates the function by handing {@code name} straight to the parent constructor.
     *
     * @param name value forwarded to {@link RowPythonScalarFunction}; presumably the function's
     *     display name (TODO confirm against the parent class)
     */
    public RowPandasScalarFunction(String name) {
        super(name);
    }

    /**
     * Reports the Pandas function kind.
     *
     * @return always {@link PythonFunctionKind#PANDAS}
     */
    @Override
    public PythonFunctionKind getPythonFunctionKind() {
        return PythonFunctionKind.PANDAS;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,20 @@ FlinkLogicalCalc(select=[f0.f0 AS _c0, f0.f1 AS _c1])
]]>
</Resource>
</TestCase>
<TestCase name="testMapOperationMixedWithPandasUDFAndGeneralUDF">
<Resource name="ast">
<![CDATA[
LogicalProject(_c0=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPandasScalarFunction$55ec12d1188da02d641be38b6cf77e21(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$19523bda2dba321ac79aaa8d4b9febb0($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$19523bda2dba321ac79aaa8d4b9febb0($0, $1, $2).f1).f0], _c1=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPandasScalarFunction$55ec12d1188da02d641be38b6cf77e21(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$19523bda2dba321ac79aaa8d4b9febb0($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$19523bda2dba321ac79aaa8d4b9febb0($0, $1, $2).f1).f1])
+- LogicalTableScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalCalc(select=[f0.f0 AS _c0, f0.f1 AS _c1])
+- FlinkLogicalCalc(select=[pandas_func(f0) AS f0])
+- FlinkLogicalCalc(select=[general_func(a, b, c) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.flink.table.api._
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.optimize.program._
import org.apache.flink.table.planner.plan.rules.{FlinkBatchRuleSets, FlinkStreamRuleSets}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.RowPythonScalarFunction
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{RowPandasScalarFunction, RowPythonScalarFunction}
import org.apache.flink.table.planner.utils.TableTestBase

import org.apache.calcite.plan.hep.HepMatchOrder
Expand Down Expand Up @@ -64,4 +64,15 @@ class PythonMapMergeRuleTest extends TableTestBase {
.map(func(withColumns('*)))
util.verifyRelPlan(result)
}

/**
 * Plans a `map` with a general Python UDF immediately followed by a `map` with a Pandas UDF
 * over a three-column source, then hands the resulting table to `util.verifyRelPlan`.
 */
@Test
def testMapOperationMixedWithPandasUDFAndGeneralUDF(): Unit = {
  val source = util.addTableSource[(Int, Int, Int)]("source", 'a, 'b, 'c)
  val generalUdf = new RowPythonScalarFunction("general_func")
  val pandasUdf = new RowPandasScalarFunction("pandas_func")

  // First map: the general UDF consumes every source column.
  val afterGeneral = source.map(generalUdf(withColumns('*)))
  // Second map: the Pandas UDF consumes every column produced by the first map.
  val afterPandas = afterGeneral.map(pandasUdf(withColumns('*)))

  util.verifyRelPlan(afterPandas)
}
}

0 comments on commit ba5cc58

Please sign in to comment.