
Commit

[FLINK-25719][python] Support General Python UDF in Thread Mode

This closes apache#18418.
HuangXingBo committed Jan 27, 2022
1 parent ca519f6 commit 51eb386
Showing 72 changed files with 1,795 additions and 572 deletions.
6 changes: 6 additions & 0 deletions docs/layouts/shortcodes/generated/python_configuration.html
@@ -26,6 +26,12 @@
<td>String</td>
<td>Specify the path of the python interpreter used to execute the python UDF worker. The python UDF worker depends on Python 3.6+, Apache Beam (version == 2.27.0), Pip (version &gt;= 7.1.0) and SetupTools (version &gt;= 37.0.0). Please ensure that the specified environment meets the above requirements. The option is equivalent to the command line option "-pyexec".</td>
</tr>
<tr>
<td><h5>python.execution-mode</h5></td>
<td style="word-wrap: break-word;">"process"</td>
<td>String</td>
<td>Specify the Python runtime execution mode. The optional values are `process`, `multi-thread` and `sub-interpreter`. The `process` mode means that the Python user-defined functions will be executed in a separate Python process. The `multi-thread` mode means that the Python user-defined functions will be executed in the same thread as the Java operator, but performance will be limited by the GIL. The `sub-interpreter` mode means that the Python user-defined functions will be executed in different Python sub-interpreters rather than in different threads of one interpreter, which can largely overcome the effects of the GIL, but it may fail with some CPython extension libraries, such as numpy and tensorflow. Note that if the Python operator does not support `multi-thread` or `sub-interpreter` mode, `process` mode will still be used.</td>
</tr>
<tr>
<td><h5>python.files</h5></td>
<td style="word-wrap: break-word;">(none)</td>
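As an aside, here is a minimal sketch of how the new python.execution-mode option could be used from a PyFlink program (illustrative only, not part of this commit: the option key and its values come from the description above, while the table setup and the add_one function are assumptions):

from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import call, col
from pyflink.table.udf import udf

# Any TableEnvironment works; streaming mode is used here only for illustration.
t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

# Request thread-mode execution for general Python UDFs; per the description
# above, the runtime falls back to 'process' mode for unsupported operators.
t_env.get_config().get_configuration().set_string(
    "python.execution-mode", "multi-thread")

@udf(result_type=DataTypes.BIGINT())
def add_one(x):
    return x + 1

t_env.create_temporary_function("add_one", add_one)
result = t_env.from_elements([(1,), (2,)], ['v']).select(call("add_one", col("v")))
result.execute().print()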
1 change: 1 addition & 0 deletions flink-python/dev/dev-requirements.txt
@@ -27,3 +27,4 @@ numpy>=1.14.3,<1.20
fastavro>=0.21.4,<0.24
grpcio>=1.29.0,<2
grpcio-tools>=1.3.5,<=1.14.2
pemja==0.1.2; python_version >= '3.7'
8 changes: 8 additions & 0 deletions flink-python/pom.xml
@@ -99,6 +99,13 @@ under the License.
<artifactId>beam-runners-core-java</artifactId>
</dependency>

<!-- PemJa dependencies -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>pemja</artifactId>
<version>0.1.2</version>
</dependency>

<!-- Protobuf dependencies -->

<dependency>
@@ -375,6 +382,7 @@ under the License.
<include>org.apache.arrow:*</include>
<include>io.netty:*</include>
<include>com.google.flatbuffers:*</include>
<include>com.alibaba:pemja</include>
</includes>
</artifactSet>
<filters>
17 changes: 9 additions & 8 deletions flink-python/pyflink/fn_execution/beam/beam_operations.py
@@ -151,12 +151,13 @@ def _create_user_defined_function_operation(factory, transform_proto, consumers,
side_inputs=None,
output_coders=[output_coders[tag] for tag in output_tags])

if hasattr(spec.serialized_fn, "key_type"):
serialized_fn = spec.serialized_fn
if hasattr(serialized_fn, "key_type"):
# keyed operation, need to create the KeyedStateBackend.
row_schema = spec.serialized_fn.key_type.row_schema
row_schema = serialized_fn.key_type.row_schema
key_row_coder = FlattenRowCoder([from_proto(f.type) for f in row_schema.fields])
if spec.serialized_fn.HasField('group_window'):
if spec.serialized_fn.group_window.is_time_window:
if serialized_fn.HasField('group_window'):
if serialized_fn.group_window.is_time_window:
window_coder = TimeWindowCoder()
else:
window_coder = CountWindowCoder()
@@ -166,9 +167,9 @@ def _create_user_defined_function_operation(factory, transform_proto, consumers,
factory.state_handler,
key_row_coder,
window_coder,
spec.serialized_fn.state_cache_size,
spec.serialized_fn.map_state_read_cache_size,
spec.serialized_fn.map_state_write_cache_size)
serialized_fn.state_cache_size,
serialized_fn.map_state_read_cache_size,
serialized_fn.map_state_write_cache_size)

return beam_operation_cls(
transform_proto.unique_name,
@@ -179,7 +180,7 @@ def _create_user_defined_function_operation(factory, transform_proto, consumers,
internal_operation_cls,
keyed_state_backend)
elif internal_operation_cls == datastream_operations.StatefulOperation:
key_row_coder = from_type_info_proto(spec.serialized_fn.key_type_info)
key_row_coder = from_type_info_proto(serialized_fn.key_type_info)
keyed_state_backend = RemoteKeyedStateBackend(
factory.state_handler,
key_row_coder,
@@ -199,7 +199,7 @@ cdef class StatelessFunctionOperation(FunctionOperation):
name, spec, counter_factory, sampler, consumers, operation_cls)

cdef object generate_operation(self):
return self.operation_cls(self.spec)
return self.operation_cls(self.spec.serialized_fn)


cdef class StatefulFunctionOperation(FunctionOperation):
@@ -211,7 +211,7 @@ cdef class StatefulFunctionOperation(FunctionOperation):
name, spec, counter_factory, sampler, consumers, operation_cls)

cdef object generate_operation(self):
return self.operation_cls(self.spec, self._keyed_state_backend)
return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend)

cpdef void add_timer_info(self, timer_family_id, timer_info):
# ignore timer_family_id
@@ -149,7 +149,7 @@ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cl
name, spec, counter_factory, sampler, consumers, operation_cls)

def generate_operation(self):
return self.operation_cls(self.spec)
return self.operation_cls(self.spec.serialized_fn)


class StatefulFunctionOperation(FunctionOperation):
@@ -161,7 +161,7 @@ def __init__(self, name, spec, counter_factory, sampler, consumers, operation_cl
name, spec, counter_factory, sampler, consumers, operation_cls)

def generate_operation(self):
return self.operation_cls(self.spec, self._keyed_state_backend)
return self.operation_cls(self.spec.serialized_fn, self._keyed_state_backend)

def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo):
# ignore timer_family_id
5 changes: 4 additions & 1 deletion flink-python/pyflink/fn_execution/coder_impl_fast.pyx
@@ -28,7 +28,6 @@ import pickle
from typing import List, Union

import cloudpickle
import pyarrow as pa

from pyflink.common import Row, RowKind
from pyflink.common.time import Instant
@@ -437,6 +436,8 @@ cdef class ArrowCoderImpl(FieldCoderImpl):
self._batch_reader = self._load_from_stream(self._resettable_io)

cpdef encode_to_stream(self, cols, OutputStream out_stream):
import pyarrow as pa

self._resettable_io.set_output_stream(out_stream)
batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
batch_writer.write_batch(
@@ -451,6 +452,8 @@ cdef class ArrowCoderImpl(FieldCoderImpl):
return arrow_to_pandas(self._timezone, self._field_types, [next(self._batch_reader)])

def _load_from_stream(self, stream):
import pyarrow as pa

while stream.readable():
reader = pa.ipc.open_stream(stream)
yield reader.read_next_batch()
5 changes: 4 additions & 1 deletion flink-python/pyflink/fn_execution/coder_impl_slow.py
@@ -22,7 +22,6 @@
from typing import List

import cloudpickle
import pyarrow as pa

from pyflink.common import Row, RowKind
from pyflink.common.time import Instant
@@ -281,6 +280,8 @@ def __init__(self, schema, row_type, timezone):
self._batch_reader = ArrowCoderImpl._load_from_stream(self._resettable_io)

def encode_to_stream(self, cols, out_stream: OutputStream):
import pyarrow as pa

self._resettable_io.set_output_stream(out_stream)
batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
batch_writer.write_batch(
@@ -296,6 +297,8 @@ def decode_one_batch_from_stream(self, in_stream: InputStream, size: int) -> Lis

@staticmethod
def _load_from_stream(stream):
import pyarrow as pa

while stream.readable():
reader = pa.ipc.open_stream(stream)
yield reader.read_next_batch()
90 changes: 48 additions & 42 deletions flink-python/pyflink/fn_execution/coders.py
@@ -19,14 +19,12 @@
import os
from abc import ABC, abstractmethod

import pyarrow as pa
import pytz

from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, DateTypeInfo, \
TimeTypeInfo, TimestampTypeInfo, PrimitiveArrayTypeInfo, BasicArrayTypeInfo, TupleTypeInfo, \
MapTypeInfo, ListTypeInfo, RowTypeInfo, PickledBytesTypeInfo, ObjectArrayTypeInfo, \
ExternalTypeInfo
from pyflink.fn_execution import flink_fn_execution_pb2
from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType
@@ -59,6 +57,8 @@ def get_impl(self):

@classmethod
def from_coder_info_descriptor_proto(cls, coder_info_descriptor_proto):
from pyflink.fn_execution import flink_fn_execution_pb2

field_coder = cls._to_field_coder(coder_info_descriptor_proto)
mode = coder_info_descriptor_proto.mode
separated_with_end_message = coder_info_descriptor_proto.separated_with_end_message
@@ -98,11 +98,15 @@ def _to_field_coder(cls, coder_info_descriptor_proto):

@classmethod
def _to_arrow_schema(cls, row_type):
import pyarrow as pa

return pa.schema([pa.field(n, to_arrow_type(t), t._nullable)
for n, t in zip(row_type.field_names(), row_type.field_types())])

@classmethod
def _to_data_type(cls, field_type):
from pyflink.fn_execution import flink_fn_execution_pb2

if field_type.type_name == flink_fn_execution_pb2.Schema.TINYINT:
return TinyIntType(field_type.nullable)
elif field_type.type_name == flink_fn_execution_pb2.Schema.SMALLINT:
@@ -593,31 +597,32 @@ def get_impl(self):
return coder_impl.DataViewFilterCoderImpl(self._udf_data_view_specs)


type_name = flink_fn_execution_pb2.Schema
_type_name_mappings = {
type_name.TINYINT: TinyIntCoder(),
type_name.SMALLINT: SmallIntCoder(),
type_name.INT: IntCoder(),
type_name.BIGINT: BigIntCoder(),
type_name.BOOLEAN: BooleanCoder(),
type_name.FLOAT: FloatCoder(),
type_name.DOUBLE: DoubleCoder(),
type_name.BINARY: BinaryCoder(),
type_name.VARBINARY: BinaryCoder(),
type_name.CHAR: CharCoder(),
type_name.VARCHAR: CharCoder(),
type_name.DATE: DateCoder(),
type_name.TIME: TimeCoder(),
}


def from_proto(field_type):
"""
Creates the corresponding :class:`Coder` given the protocol representation of the field type.
:param field_type: the protocol representation of the field type
:return: :class:`Coder`
"""
from pyflink.fn_execution import flink_fn_execution_pb2

type_name = flink_fn_execution_pb2.Schema
_type_name_mappings = {
type_name.TINYINT: TinyIntCoder(),
type_name.SMALLINT: SmallIntCoder(),
type_name.INT: IntCoder(),
type_name.BIGINT: BigIntCoder(),
type_name.BOOLEAN: BooleanCoder(),
type_name.FLOAT: FloatCoder(),
type_name.DOUBLE: DoubleCoder(),
type_name.BINARY: BinaryCoder(),
type_name.VARBINARY: BinaryCoder(),
type_name.CHAR: CharCoder(),
type_name.VARCHAR: CharCoder(),
type_name.DATE: DateCoder(),
type_name.TIME: TimeCoder(),
}

field_type_name = field_type.type_name
coder = _type_name_mappings.get(field_type_name)
if coder is not None:
@@ -642,29 +647,30 @@ def from_proto(field_type):
raise ValueError("field_type %s is not supported." % field_type)


# for data stream type information.
type_info_name = flink_fn_execution_pb2.TypeInfo
_type_info_name_mappings = {
type_info_name.STRING: CharCoder(),
type_info_name.BYTE: TinyIntCoder(),
type_info_name.BOOLEAN: BooleanCoder(),
type_info_name.SHORT: SmallIntCoder(),
type_info_name.INT: IntCoder(),
type_info_name.LONG: BigIntCoder(),
type_info_name.FLOAT: FloatCoder(),
type_info_name.DOUBLE: DoubleCoder(),
type_info_name.CHAR: CharCoder(),
type_info_name.BIG_INT: BigIntCoder(),
type_info_name.BIG_DEC: BigDecimalCoder(),
type_info_name.SQL_DATE: DateCoder(),
type_info_name.SQL_TIME: TimeCoder(),
type_info_name.SQL_TIMESTAMP: TimestampCoder(3),
type_info_name.PICKLED_BYTES: CloudPickleCoder(),
type_info_name.INSTANT: InstantCoder()
}


def from_type_info_proto(type_info):
# for data stream type information.
from pyflink.fn_execution import flink_fn_execution_pb2

type_info_name = flink_fn_execution_pb2.TypeInfo
_type_info_name_mappings = {
type_info_name.STRING: CharCoder(),
type_info_name.BYTE: TinyIntCoder(),
type_info_name.BOOLEAN: BooleanCoder(),
type_info_name.SHORT: SmallIntCoder(),
type_info_name.INT: IntCoder(),
type_info_name.LONG: BigIntCoder(),
type_info_name.FLOAT: FloatCoder(),
type_info_name.DOUBLE: DoubleCoder(),
type_info_name.CHAR: CharCoder(),
type_info_name.BIG_INT: BigIntCoder(),
type_info_name.BIG_DEC: BigDecimalCoder(),
type_info_name.SQL_DATE: DateCoder(),
type_info_name.SQL_TIME: TimeCoder(),
type_info_name.SQL_TIMESTAMP: TimestampCoder(3),
type_info_name.PICKLED_BYTES: CloudPickleCoder(),
type_info_name.INSTANT: InstantCoder()
}

field_type_name = type_info.type_name
try:
return _type_info_name_mappings[field_type_name]
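A recurring pattern in coders.py (and in the coder_impl_fast.pyx and coder_impl_slow.py hunks above) is moving the pyarrow and flink_fn_execution_pb2 imports from module scope into the functions that use them. Below is a minimal sketch of that deferred-import pattern, under the assumption that the intent is to keep the module importable in runtimes where those dependencies are not needed; make_arrow_schema is a hypothetical helper, not code from this commit.

def make_arrow_schema(field_names, field_types):
    # pyarrow is imported lazily, so merely importing this module does not
    # require pyarrow; the dependency is resolved only when an Arrow schema
    # is actually requested.
    import pyarrow as pa

    return pa.schema([pa.field(name, data_type)
                      for name, data_type in zip(field_names, field_types)])

# Example usage; only this call actually pulls in pyarrow:
#   import pyarrow as pa
#   schema = make_arrow_schema(["id", "name"], [pa.int64(), pa.string()])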