[FLINK-25720][python] Support Python UDTF in Thread Mode
This closes apache#20175.
HuangXingBo committed Jul 8, 2022
1 parent 6c0dba2 commit 915efb2
Showing 13 changed files with 353 additions and 93 deletions.
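With this change, Python user-defined table functions (UDTFs) can run in the embedded thread execution mode, where the Python code executes inside the JVM process via pemja instead of in a separate Python worker process. A minimal sketch of what this enables, assuming a standard PyFlink streaming setup (table and column names are illustrative):

from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import udtf

t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
# Switch from the default "process" mode to the embedded "thread" mode
# (requires Python 3.7+ and pemja, per the dependency changes below).
t_env.get_config().set("python.execution-mode", "thread")

@udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
def multi_emit(x, y):
    # Emits y rows for each input row, mirroring the MultiEmit test below.
    for i in range(y):
        yield x, i

t = t_env.from_elements([(1, 2), (2, 3)], ['a', 'b'])
t.join_lateral(multi_emit(col('a'), col('b')).alias('m', 'n')).execute().print()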
2 changes: 1 addition & 1 deletion flink-python/dev/dev-requirements.txt
@@ -31,6 +31,6 @@ numpy>=1.14.3,<1.20; python_version < '3.7'
fastavro>=1.1.0,<1.4.8
grpcio>=1.29.0,<1.47
grpcio-tools>=1.3.5,<=1.14.2
-pemja==0.1.5; python_version >= '3.7' and platform_system != 'Windows'
+pemja==0.2.0; python_version >= '3.7' and platform_system != 'Windows'
httplib2>=0.19.0,<=0.20.4
protobuf<3.18
2 changes: 1 addition & 1 deletion flink-python/pom.xml
@@ -104,7 +104,7 @@ under the License.
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>pemja</artifactId>
-<version>0.1.5</version>
+<version>0.2.0</version>
</dependency>

<!-- Protobuf dependencies -->
17 changes: 15 additions & 2 deletions flink-python/pyflink/fn_execution/utils/operation_utils.py
@@ -48,7 +48,10 @@ def normalize_one_row(value):
        return [value]

    if it is None:
-        return []
+        def func():
+            for i in []:
+                yield i
+        return func()

    if isinstance(it, (list, range, Generator)):
        def func():
@@ -57,7 +60,9 @@ def func():

        return func()
    else:
-        return [normalize_one_row(it)]
+        def func():
+            yield normalize_one_row(it)
+        return func()


def normalize_pandas_result(it):
@@ -303,6 +308,14 @@ def create_scalar_operation_from_proto(proto, one_arg_optimization=False,
    return scalar_operation


+def create_table_operation_from_proto(proto):
+    from pyflink.fn_execution.table.operations import TableFunctionOperation
+
+    serialized_fn = parse_function_proto(proto)
+    table_operation = TableFunctionOperation(serialized_fn)
+    return table_operation


def create_serialized_scalar_operation_from_proto(proto, one_arg_optimization=False,
                                                  one_result_optimization=False):
    """
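The changes above tighten the contract of the result-normalization helper: every branch now returns a generator, so the runtime can iterate any UDTF result (None, a scalar, a list, a range, or a generator) uniformly, and the new create_table_operation_from_proto gives the embedded thread-mode runtime an entry point that builds the Python-side TableFunctionOperation from the serialized function proto. A condensed, self-contained sketch of the normalization contract (the names normalize_result and the simplified normalize_one_row are illustrative, not the actual helpers):

from collections.abc import Generator

def normalize_one_row(value):
    # Simplified stand-in for the nested helper above.
    return list(value) if isinstance(value, (list, tuple)) else [value]

def normalize_result(it):
    if it is None:
        def func():
            yield from ()              # an empty generator instead of []
        return func()
    if isinstance(it, (list, range, Generator)):
        def func():
            for row in it:
                yield normalize_one_row(row)
        return func()
    def func():
        yield normalize_one_row(it)    # a scalar becomes a one-row result
    return func()

assert list(normalize_result(None)) == []
assert list(normalize_result(7)) == [[7]]
assert list(normalize_result([(1, 2)])) == [[1, 2]]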
4 changes: 2 additions & 2 deletions flink-python/pyflink/table/tests/test_udf.py
@@ -341,7 +341,7 @@ def timestamp_func(timestamp_param):

        @udf(result_type=DataTypes.ARRAY(DataTypes.BIGINT()))
        def array_func(array_param):
-            assert array_param == [[1, 2, 3]], \
+            assert array_param == [[1, 2, 3]] or array_param == ((1, 2, 3),), \
                'array_param is wrong value %s !' % array_param
            return array_param[0]

@@ -500,7 +500,7 @@ def timestamp_func(timestamp_param):
            return timestamp_param

        def array_func(array_param):
-            assert array_param == [[1, 2, 3]], \
+            assert array_param == [[1, 2, 3]] or array_param == ((1, 2, 3),), \
                'array_param is wrong value %s !' % array_param
            return array_param[0]

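The widened assertions reflect a representational difference between execution modes: in thread mode, an ARRAY parameter can cross the pemja boundary as a tuple ((1, 2, 3),) rather than the list [[1, 2, 3]] seen in process mode. Where many such checks accumulate, a small coercion helper keeps them readable — a sketch, with normalize_nested a hypothetical name:

def normalize_nested(value):
    # Recursively coerce tuples to lists so one literal covers both modes.
    if isinstance(value, (list, tuple)):
        return [normalize_nested(v) for v in value]
    return value

assert normalize_nested(((1, 2, 3),)) == [[1, 2, 3]]
assert normalize_nested([[1, 2, 3]]) == [[1, 2, 3]]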
13 changes: 10 additions & 3 deletions flink-python/pyflink/table/tests/test_udtf.py
@@ -15,8 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+import sys
import unittest

+import pytest

from pyflink.table import DataTypes
from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf
from pyflink.table.expressions import col
@@ -129,15 +132,19 @@ class PyFlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,
    pass


+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7")
+class PyFlinkEmbeddedThreadTests(UserDefinedTableFunctionTests, PyFlinkStreamTableTestCase):
+    def setUp(self):
+        super(PyFlinkEmbeddedThreadTests, self).setUp()
+        self.t_env.get_config().set("python.execution-mode", "thread")


class MultiEmit(TableFunction, unittest.TestCase):

    def open(self, function_context):
        mg = function_context.get_metric_group()
        self.counter = mg.add_group("key", "value").counter("my_counter")
        self.counter_sum = 0

    def eval(self, x, y):
        self.counter.inc(y)
        self.counter_sum += y
        for i in range(y):
            yield x, i
2 changes: 1 addition & 1 deletion flink-python/setup.py
@@ -306,7 +306,7 @@ def extracted_output_files(base_dir, file_path, output_directory):
'cloudpickle==2.1.0', 'avro-python3>=1.8.1,!=1.9.2,<1.10.0',
'pytz>=2018.3', 'fastavro>=1.1.0,<1.4.8', 'requests>=2.26.0',
'protobuf<3.18',
-'pemja==0.1.5;'
+'pemja==0.2.0;'
'python_full_version >= "3.7" and platform_system != "Windows"',
'httplib2>=0.19.0,<=0.20.4', apache_flink_libraries_dependency]

109 changes: 109 additions & 0 deletions .../src/main/java/org/apache/flink/table/runtime/operators/python/AbstractEmbeddedStatelessFunctionOperator.java (new file)
@@ -0,0 +1,109 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.operators.python;

import org.apache.flink.annotation.Internal;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.python.AbstractEmbeddedPythonFunctionOperator;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.operators.python.utils.StreamRecordRowDataWrappingCollector;
import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

import java.util.Arrays;
import java.util.stream.Collectors;

/**
 * Base class for all stream operators that execute Python stateless functions in an embedded
 * Python environment.
 */
@Internal
public abstract class AbstractEmbeddedStatelessFunctionOperator
        extends AbstractEmbeddedPythonFunctionOperator<RowData>
        implements OneInputStreamOperator<RowData, RowData>, BoundedOneInput {

    private static final long serialVersionUID = 1L;

    /** The offsets of user-defined function inputs. */
    protected final int[] udfInputOffsets;

    /** The input logical type. */
    protected final RowType inputType;

    /** The user-defined function input logical type. */
    protected final RowType udfInputType;

    /** The user-defined function output logical type. */
    protected final RowType udfOutputType;

    /** The reused GenericRowData holding the execution result of the Python UDF. */
    protected transient GenericRowData reuseResultRowData;

    /** The collector used to collect records. */
    protected transient StreamRecordRowDataWrappingCollector rowDataWrapper;

    protected transient PythonTypeUtils.DataConverter[] userDefinedFunctionInputConverters;
    protected transient Object[] userDefinedFunctionInputArgs;
    protected transient PythonTypeUtils.DataConverter[] userDefinedFunctionOutputConverters;

    public AbstractEmbeddedStatelessFunctionOperator(
            Configuration config,
            RowType inputType,
            RowType udfInputType,
            RowType udfOutputType,
            int[] udfInputOffsets) {
        super(config);
        this.inputType = Preconditions.checkNotNull(inputType);
        this.udfInputType = Preconditions.checkNotNull(udfInputType);
        this.udfOutputType = Preconditions.checkNotNull(udfOutputType);
        this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets);
    }

    @Override
    public void open() throws Exception {
        super.open();
        rowDataWrapper = new StreamRecordRowDataWrappingCollector(output);
        reuseResultRowData = new GenericRowData(udfOutputType.getFieldCount());
        RowType userDefinedFunctionInputType =
                new RowType(
                        Arrays.stream(udfInputOffsets)
                                .mapToObj(i -> inputType.getFields().get(i))
                                .collect(Collectors.toList()));
        userDefinedFunctionInputConverters =
                userDefinedFunctionInputType.getFields().stream()
                        .map(RowType.RowField::getType)
                        .map(PythonTypeUtils::toDataConverter)
                        .toArray(PythonTypeUtils.DataConverter[]::new);
        userDefinedFunctionInputArgs = new Object[udfInputOffsets.length];
        userDefinedFunctionOutputConverters =
                udfOutputType.getFields().stream()
                        .map(RowType.RowField::getType)
                        .map(PythonTypeUtils::toDataConverter)
                        .toArray(PythonTypeUtils.DataConverter[]::new);
    }

    @Override
    protected void invokeFinishBundle() throws Exception {
        // TODO: Support invoking in batches.
    }
}
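The base class wires up one DataConverter per user-defined function input and output field; subclasses use these to shuttle values across the Java/Python boundary. A Python-flavoured sketch of the per-record flow that these converters enable (all names are illustrative, not the actual pemja call surface):

def process_element(row, udf_input_offsets, in_converters, out_converters, udtf):
    # Project the UDF arguments out of the input row, converting each
    # field to its external (Python) representation ...
    args = [conv.to_external(row[offset])
            for conv, offset in zip(in_converters, udf_input_offsets)]
    # ... invoke the table function, which yields zero or more rows ...
    for result in udtf(*args):
        # ... and convert every output field back to Flink's internal format.
        yield [conv.to_internal(field)
               for conv, field in zip(out_converters, result)]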
.../EmbeddedPythonScalarFunctionOperator.java
@@ -21,12 +21,8 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.api.operators.python.AbstractEmbeddedPythonFunctionOperator;
import org.apache.flink.streaming.api.utils.ProtoUtils;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.utils.JoinedRowData;
@@ -35,56 +31,30 @@
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.generated.Projection;
-import org.apache.flink.table.runtime.operators.python.utils.StreamRecordRowDataWrappingCollector;
-import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
+import org.apache.flink.table.runtime.operators.python.AbstractEmbeddedStatelessFunctionOperator;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

-import java.util.Arrays;
import java.util.Map;
-import java.util.stream.Collectors;

import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED;

/** The Python {@link ScalarFunction} operator in an embedded Python environment. */
@Internal
public class EmbeddedPythonScalarFunctionOperator
-        extends AbstractEmbeddedPythonFunctionOperator<RowData>
-        implements OneInputStreamOperator<RowData, RowData>, BoundedOneInput {
+        extends AbstractEmbeddedStatelessFunctionOperator {

    private static final long serialVersionUID = 1L;

    /** The Python {@link ScalarFunction}s to be executed. */
    private final PythonFunctionInfo[] scalarFunctions;

-    /** The offsets of user-defined function inputs. */
-    private final int[] udfInputOffsets;

-    /** The input logical type. */
-    protected final RowType inputType;

-    /** The user-defined function input logical type. */
-    protected final RowType udfInputType;

-    /** The user-defined function output logical type. */
-    protected final RowType udfOutputType;

    private GeneratedProjection forwardedFieldGeneratedProjection;

-    /** The GenericRowData reused holding the execution result of python udf. */
-    private GenericRowData reuseResultRowData;

-    /** The collector used to collect records. */
-    private transient StreamRecordRowDataWrappingCollector rowDataWrapper;

    /** The Projection which projects the forwarded fields from the input row. */
    private transient Projection<RowData, BinaryRowData> forwardedFieldProjection;

-    private transient PythonTypeUtils.DataConverter[] userDefinedFunctionInputConverters;
-    private transient Object[] userDefinedFunctionInputArgs;
-    private transient PythonTypeUtils.DataConverter[] userDefinedFunctionOutputConverters;

    /** Whether there is only one input argument. */
    private transient boolean isOneArg;

@@ -98,11 +68,7 @@ public EmbeddedPythonScalarFunctionOperator(
            RowType udfInputType,
            RowType udfOutputType,
            int[] udfInputOffsets) {
-        super(config);
-        this.inputType = Preconditions.checkNotNull(inputType);
-        this.udfInputType = Preconditions.checkNotNull(udfInputType);
-        this.udfOutputType = Preconditions.checkNotNull(udfOutputType);
-        this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets);
+        super(config, inputType, udfInputType, udfOutputType, udfInputOffsets);
        this.scalarFunctions = Preconditions.checkNotNull(scalarFunctions);
    }

@@ -125,24 +91,6 @@ public void open() throws Exception {
        isOneArg = udfInputOffsets.length == 1;
        isOneFieldResult = udfOutputType.getFieldCount() == 1;
        super.open();
-        rowDataWrapper = new StreamRecordRowDataWrappingCollector(output);
-        reuseResultRowData = new GenericRowData(udfOutputType.getFieldCount());
-        RowType userDefinedFunctionInputType =
-                new RowType(
-                        Arrays.stream(udfInputOffsets)
-                                .mapToObj(i -> inputType.getFields().get(i))
-                                .collect(Collectors.toList()));
-        userDefinedFunctionInputConverters =
-                userDefinedFunctionInputType.getFields().stream()
-                        .map(RowType.RowField::getType)
-                        .map(PythonTypeUtils::toDataConverter)
-                        .toArray(PythonTypeUtils.DataConverter[]::new);
-        userDefinedFunctionInputArgs = new Object[udfInputOffsets.length];
-        userDefinedFunctionOutputConverters =
-                udfOutputType.getFields().stream()
-                        .map(RowType.RowField::getType)
-                        .map(PythonTypeUtils::toDataConverter)
-                        .toArray(PythonTypeUtils.DataConverter[]::new);

        if (forwardedFieldGeneratedProjection != null) {
            forwardedFieldProjection =
[Diffs for the remaining changed files did not load.]
