Skip to content

Commit

Permalink
[FLINK-20702][python] Support map operation chained together in Pytho…
Browse files Browse the repository at this point in the history
…n Table API

This closes apache#14473.
  • Loading branch information
HuangXingBo authored and dianfu committed Dec 25, 2020
1 parent b474d28 commit febce35
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 6 deletions.
2 changes: 2 additions & 0 deletions flink-python/pyflink/fn_execution/operation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def wrap_inputs_as_row(*args):
import pandas as pd
if type(args[0]) == pd.Series:
return pd.concat(args, axis=1)
elif len(args) == 1 and isinstance(args[0], (pd.DataFrame, Row, Tuple)):
return args[0]
else:
return Row(*args)

Expand Down
14 changes: 12 additions & 2 deletions flink-python/pyflink/table/tests/test_row_based_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,24 @@ def func(x):
res = pd.concat([x.a, x.c + x.d], axis=1)
return res

def func2(x):
return x * 2

pandas_udf = udf(func,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')
t.map(pandas_udf).execute_insert("Results").wait()

pandas_udf_2 = udf(func2,
result_type=DataTypes.ROW(
[DataTypes.FIELD("c", DataTypes.BIGINT()),
DataTypes.FIELD("d", DataTypes.BIGINT())]),
func_type='pandas')

t.map(pandas_udf).map(pandas_udf_2).execute_insert("Results").wait()
actual = source_sink_utils.results()
self.assert_equals(actual, ["2,4", "1,5", "1,14", "1,9", "2,7"])
self.assert_equals(actual, ["4,8", "2,10", "2,28", "2,18", "4,14"])

def test_flat_map(self):
t = self.t_env.from_elements(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Rule that merges two Python map operations chained together into a single
 * {@link FlinkLogicalCalc}.
 *
 * <p>It matches three stacked {@link FlinkLogicalCalc} nodes:
 * <ul>
 *   <li>top: no condition, projects exactly one Python call that takes the whole row as input;
 *   <li>middle: a condition-free "flatten" calc that expands the single row-typed input
 *       column into its fields, in field order;
 *   <li>bottom: projects exactly one Python call.
 * </ul>
 *
 * <p>On a match the top call is rewritten to consume input column 0 of the bottom calc's output
 * directly, the flatten calc is dropped, and the rewritten program is merged with the bottom
 * calc's program.
 */
public class PythonMapMergeRule extends RelOptRule {

    public static final PythonMapMergeRule INSTANCE = new PythonMapMergeRule();

    private PythonMapMergeRule() {
        super(operand(FlinkLogicalCalc.class,
                operand(FlinkLogicalCalc.class,
                    operand(FlinkLogicalCalc.class, none()))),
            "PythonMapMergeRule");
    }

    /**
     * Decides whether the three matched calcs have the top / flatten / bottom shape this rule
     * is able to merge.
     *
     * @param call the rule call; rel(0) is the top calc, rel(1) the middle calc and rel(2) the
     *     bottom calc
     * @return {@code true} only if {@link #onMatch(RelOptRuleCall)} can rewrite the plan without
     *     losing a filter or hitting an invalid cast
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);

        RexProgram topProgram = topCalc.getProgram();
        // onMatch rebuilds the top program with a null condition, so a filter on the top calc
        // would be silently dropped by the rewrite. Refuse to match in that case.
        if (topProgram.getCondition() != null) {
            return false;
        }

        List<RexNode> topProjects = expandProjects(topProgram);
        if (topProjects.size() != 1) {
            return false;
        }
        RexNode topProject = topProjects.get(0);
        // Verify the node really is a call before casting; the cast below (and the one in
        // onMatch) would otherwise fail with a ClassCastException while the planner is matching.
        if (!(topProject instanceof RexCall)
                || PythonUtil.isNonPythonCall(topProject)
                || !PythonUtil.takesRowAsInput((RexCall) topProject)) {
            return false;
        }

        List<RexNode> bottomProjects = expandProjects(bottomCalc.getProgram());
        if (bottomProjects.size() != 1 || PythonUtil.isNonPythonCall(bottomProjects.get(0))) {
            return false;
        }

        RexProgram middleProgram = middleCalc.getProgram();
        if (middleProgram.getCondition() != null) {
            return false;
        }

        // The flatten calc reads the fields of input column 0, so that column has to exist and
        // be a struct; getFieldList() must not be called on a non-struct type.
        if (middleProgram.getInputRowType().getFieldList().isEmpty()
                || !middleProgram.getInputRowType().getFieldList().get(0).getValue().isStruct()) {
            return false;
        }

        List<RexNode> middleProjects = expandProjects(middleProgram);
        int inputRowFieldCount = middleProgram.getInputRowType()
            .getFieldList()
            .get(0)
            .getValue()
            .getFieldList().size();

        return isFlattenCalc(middleProjects, inputRowFieldCount)
            && isTopCalcTakesWholeMiddleCalcAsInputs((RexCall) topProject, middleProjects.size());
    }

    /**
     * Returns the projections of {@code program} with all local references expanded.
     */
    private static List<RexNode> expandProjects(RexProgram program) {
        return program.getProjectList()
            .stream()
            .map(program::expandLocalRef)
            .collect(Collectors.toList());
    }

    /**
     * Checks that the operands of {@code pythonCall} are exactly the input references
     * {@code $0, $1, ..., $(inputColumnCount - 1)}, in that order.
     */
    private boolean isTopCalcTakesWholeMiddleCalcAsInputs(RexCall pythonCall, int inputColumnCount) {
        List<RexNode> pythonCallInputs = pythonCall.getOperands();
        if (pythonCallInputs.size() != inputColumnCount) {
            return false;
        }
        for (int i = 0; i < pythonCallInputs.size(); i++) {
            RexNode input = pythonCallInputs.get(i);
            if (!(input instanceof RexInputRef) || ((RexInputRef) input).getIndex() != i) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks that the i-th projection is a field access to field {@code i} of input column 0,
     * for every field of that column, and that there are no other projections.
     */
    private boolean isFlattenCalc(List<RexNode> middleProjects, int inputRowFieldCount) {
        if (inputRowFieldCount != middleProjects.size()) {
            return false;
        }
        for (int i = 0; i < inputRowFieldCount; i++) {
            RexNode middleProject = middleProjects.get(i);
            if (!(middleProject instanceof RexFieldAccess)) {
                return false;
            }
            RexFieldAccess rexField = (RexFieldAccess) middleProject;
            if (rexField.getField().getIndex() != i) {
                return false;
            }
            RexNode expr = rexField.getReferenceExpr();
            if (!(expr instanceof RexInputRef) || ((RexInputRef) expr).getIndex() != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Replaces the three matched calcs by one calc that evaluates the top Python call directly
     * on the result of the bottom Python call.
     *
     * @param call the rule call accepted by {@link #matches(RelOptRuleCall)}
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);
        RexBuilder rexBuilder = call.builder().getRexBuilder();

        // matches() guarantees exactly one projection and that it is a RexCall.
        RexProgram topProgram = topCalc.getProgram();
        RexCall topPythonCall =
            (RexCall) topProgram.expandLocalRef(topProgram.getProjectList().get(0));

        // merge topCalc and middleCalc: feed column 0 of the bottom calc's output (the row
        // produced by the bottom call) to the top call instead of the flattened columns
        RexCall newPythonCall = topPythonCall.clone(topPythonCall.getType(),
            Collections.singletonList(RexInputRef.of(0, bottomCalc.getRowType())));
        List<RexCall> topMiddleMergedProjects = Collections.singletonList(newPythonCall);
        FlinkLogicalCalc topMiddleMergedCalc = new FlinkLogicalCalc(
            middleCalc.getCluster(),
            middleCalc.getTraitSet(),
            bottomCalc,
            RexProgram.create(
                bottomCalc.getRowType(),
                topMiddleMergedProjects,
                null,
                Collections.singletonList("f0"),
                rexBuilder));

        // merge bottomCalc
        RexProgram mergedProgram = RexProgramBuilder.mergePrograms(
            topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
        Calc newCalc = topMiddleMergedCalc.copy(
            topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram);
        call.transformTo(newCalc);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,9 @@ object FlinkBatchRuleSets {
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
PythonCalcSplitRule.EXPAND_PROJECT,
PythonCalcSplitRule.PUSH_CONDITION,
PythonCalcSplitRule.REWRITE_PROJECT
)
PythonCalcSplitRule.REWRITE_PROJECT,
PythonMapMergeRule.INSTANCE
)

/**
* RuleSet to do physical optimize for batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,9 @@ object FlinkStreamRuleSets {
PythonCalcSplitRule.SPLIT_PANDAS_IN_PROJECT,
PythonCalcSplitRule.EXPAND_PROJECT,
PythonCalcSplitRule.PUSH_CONDITION,
PythonCalcSplitRule.REWRITE_PROJECT
)
PythonCalcSplitRule.REWRITE_PROJECT,
PythonMapMergeRule.INSTANCE
)

/**
* RuleSet to do physical optimize for stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ object PythonUtil {
}
}

/**
  * Tells whether the function behind the given call takes the whole row as input.
  *
  * The function object is taken from the call's operator ([[ScalarSqlFunction]],
  * [[TableSqlFunction]] or [[BridgingSqlFunction]]), cast to [[PythonFunction]] and asked
  * via `takesRowAsInput()`. Any other operator type fails the match with a `MatchError`,
  * and a function object that is not a [[PythonFunction]] fails the cast.
  *
  * @param call the call whose operator wraps the function
  * @return the value reported by the function's `takesRowAsInput()`
  */
def takesRowAsInput(call: RexCall): Boolean = {
  val functionDefinition = call.getOperator match {
    case scalarFunc: ScalarSqlFunction => scalarFunc.scalarFunction
    case tableFunc: TableSqlFunction => tableFunc.udtf
    case bridgingFunc: BridgingSqlFunction => bridgingFunc.getDefinition
  }
  functionDefinition.asInstanceOf[PythonFunction].takesRowAsInput()
}

private[this] def isPythonFunction(
function: FunctionDefinition,
pythonFunctionKind: PythonFunctionKind): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,20 @@ public Row eval(int a) {
return Row.of(a + 1, Row.of(a * a));
}

/**
 * Variadic overload that accepts any number of arguments of any type.
 *
 * @param args ignored
 * @return always the constant row {@code (1, (2))}
 */
public Row eval(Object... args) {
return Row.of(1, Row.of(2));
}

/**
 * Declares the result type as {@code ROW(INT, ROW(INT))}.
 *
 * @param signature not consulted; the same type is returned for every signature
 */
@Override
public TypeInformation<?> getResultType(Class<?>[] signature) {
return Types.ROW(BasicTypeInfo.INT_TYPE_INFO, Types.ROW(BasicTypeInfo.INT_TYPE_INFO));
}

/**
 * Always returns {@code true}.
 *
 * <p>NOTE(review): presumably this flags the function as receiving the whole input row, as in
 * a map operation - confirm against the overridden declaration.
 */
@Override
public boolean takesRowAsInput() {
return true;
}

@Override
public String toString() {
return name;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<?xml version="1.0" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to you under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<Root>
<TestCase name="testMapOperationsChained">
<Resource name="ast">
<![CDATA[
LogicalProject(_c0=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f1).f0], _c1=[org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f0, org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7(org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f0, 
org$apache$flink$table$planner$runtime$utils$JavaUserDefinedScalarFunctions$RowPythonScalarFunction$92c809bc96452fbf4a7f26bbd91364c7($0, $1, $2).f1).f1).f1])
+- LogicalTableScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalCalc(select=[f0.f0 AS _c0, f0.f1 AS _c1])
+- FlinkLogicalCalc(select=[pyFunc2(pyFunc2(pyFunc2(a, b, c))) AS f0])
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, source, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
</Root>
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical

import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.optimize.program._
import org.apache.flink.table.planner.plan.rules.{FlinkBatchRuleSets, FlinkStreamRuleSets}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.RowPythonScalarFunction
import org.apache.flink.table.planner.utils.TableTestBase

import org.apache.calcite.plan.hep.HepMatchOrder

import org.junit.{Before, Test}

/**
* Test for [[PythonMapMergeRule]].
*
* Runs against a batch test util whose optimizer program is replaced by a two-stage chain
* ("logical" followed by "logical_rewrite") before every test.
*/
class PythonMapMergeRuleTest extends TableTestBase {
private val util = batchTestUtil()

// Installs the reduced optimizer program used by the tests in this class.
@Before
def setup(): Unit = {
val programs = new FlinkChainedProgram[BatchOptimizeContext]()
// Stage 1: Volcano program with the batch logical optimization rules, required to
// produce nodes in the LOGICAL convention.
programs.addLast(
"logical",
FlinkVolcanoProgramBuilder.newBuilder
.add(FlinkBatchRuleSets.LOGICAL_OPT_RULES)
.setRequiredOutputTraits(Array(FlinkConventions.LOGICAL))
.build())
// Stage 2: Hep program applying the rewrite rules as a rule sequence, matching bottom-up.
// NOTE(review): this adds FlinkStreamRuleSets.LOGICAL_REWRITE to a batch program while
// stage 1 uses FlinkBatchRuleSets - confirm the stream rule set is intended here.
programs.addLast(
"logical_rewrite",
FlinkHepRuleSetProgramBuilder.newBuilder
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
.add(FlinkStreamRuleSets.LOGICAL_REWRITE)
.build())
util.replaceBatchProgram(programs)
}

// Chains three map operations that all use the same function instance over a
// three-column source and hands the resulting table to verifyRelPlan.
@Test
def testMapOperationsChained(): Unit = {
val sourceTable = util.addTableSource[(Int, Int, Int)]("source", 'a, 'b, 'c)
val func = new RowPythonScalarFunction("pyFunc2")
val result = sourceTable.map(func(withColumns('*)))
.map(func(withColumns('*)))
.map(func(withColumns('*)))
util.verifyRelPlan(result)
}
}

0 comments on commit febce35

Please sign in to comment.