Skip to content

Commit

Permalink
[FLINK-19134][python] Introduce BasicArrayTypeInfo and PrimitiveArray…
Browse files Browse the repository at this point in the history
…TypeInfo for Python DataStream API

This closes apache#13327.
  • Loading branch information
shuiqiangchen authored and dianfu committed Sep 25, 2020
1 parent f547f4c commit 8525c5c
Show file tree
Hide file tree
Showing 14 changed files with 355 additions and 47 deletions.
110 changes: 110 additions & 0 deletions flink-python/pyflink/common/typeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,66 @@ def CHAR_PRIMITIVE_ARRAY_TYPE_INFO():
.PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO)


class BasicArrayTypeInfo(TypeInformation, ABC):
    """
    Type information for arrays whose elements are boxed primitives (Integer, Long,
    Double, ...) or Strings. Using these dedicated types allows efficient, specialised
    serializers to be created for such arrays.

    Every factory below returns a WrapperTypeInfo around the Java constant of the same
    name on org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo.
    """

    @staticmethod
    def _wrap_java_constant(constant_name):
        # Resolve the Java class on every call so the gateway is only touched on demand,
        # then look the requested static field up by name.
        j_basic_array_type_info = \
            get_gateway().jvm.org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo
        return WrapperTypeInfo(getattr(j_basic_array_type_info, constant_name))

    @staticmethod
    def BOOLEAN_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('BOOLEAN_ARRAY_TYPE_INFO')

    @staticmethod
    def BYTE_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('BYTE_ARRAY_TYPE_INFO')

    @staticmethod
    def SHORT_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('SHORT_ARRAY_TYPE_INFO')

    @staticmethod
    def INT_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('INT_ARRAY_TYPE_INFO')

    @staticmethod
    def LONG_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('LONG_ARRAY_TYPE_INFO')

    @staticmethod
    def FLOAT_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('FLOAT_ARRAY_TYPE_INFO')

    @staticmethod
    def DOUBLE_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('DOUBLE_ARRAY_TYPE_INFO')

    @staticmethod
    def CHAR_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('CHAR_ARRAY_TYPE_INFO')

    @staticmethod
    def STRING_ARRAY_TYPE_INFO():
        return BasicArrayTypeInfo._wrap_java_constant('STRING_ARRAY_TYPE_INFO')


class PickledBytesTypeInfo(TypeInformation, ABC):
"""
A PickledBytesTypeInfo indicates the data is a primitive byte array generated by pickle
Expand Down Expand Up @@ -384,6 +444,36 @@ def PRIMITIVE_ARRAY(element_type: TypeInformation):
else:
raise TypeError("Invalid element type for a primitive array.")

@staticmethod
def BASIC_ARRAY(element_type: TypeInformation):
    """
    Returns type information for arrays of boxed primitive type (such as Integer[]).

    :param element_type: element type of the array (e.g. Types.BOOLEAN(), Types.INT(),
                         Types.DOUBLE())
    :return: the BasicArrayTypeInfo wrapper matching the given element type.
    :raises TypeError: if the element type is not one of the supported element types.
    """
    # (element type factory, array type factory) pairs, probed top to bottom. Only the
    # factories are stored here; they are invoked lazily inside the loop so no type info
    # is created beyond the first match.
    supported = (
        (Types.BOOLEAN, BasicArrayTypeInfo.BOOLEAN_ARRAY_TYPE_INFO),
        (Types.BYTE, BasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO),
        (Types.SHORT, BasicArrayTypeInfo.SHORT_ARRAY_TYPE_INFO),
        (Types.INT, BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO),
        (Types.LONG, BasicArrayTypeInfo.LONG_ARRAY_TYPE_INFO),
        (Types.FLOAT, BasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO),
        (Types.DOUBLE, BasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO),
        (Types.CHAR, BasicArrayTypeInfo.CHAR_ARRAY_TYPE_INFO),
        (Types.STRING, BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO),
    )
    for make_element_type, make_array_type in supported:
        if element_type == make_element_type():
            return make_array_type()

    raise TypeError("Invalid element type for a boxed primitive array: %s" %
                    str(element_type))


def _from_java_type(j_type_info: JavaObject) -> TypeInformation:
gateway = get_gateway()
Expand Down Expand Up @@ -440,6 +530,26 @@ def _from_java_type(j_type_info: JavaObject) -> TypeInformation:
elif _is_instance_of(j_type_info, JPrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO):
return Types.PRIMITIVE_ARRAY(Types.CHAR())

# Map the Java BasicArrayTypeInfo constants back to their Python Types.BASIC_ARRAY wrappers.
# NOTE(review): the chain below covers BOOLEAN through CHAR but has no branch for
# JBasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO, although BasicArrayTypeInfo and
# Types.BASIC_ARRAY both support STRING. Unless that case is handled in a part of this
# function not shown here, a Java String[] type info is not converted to
# Types.BASIC_ARRAY(Types.STRING()) -- TODO confirm and add the missing branch.
JBasicArrayTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo \
.BasicArrayTypeInfo

if _is_instance_of(j_type_info, JBasicArrayTypeInfo.BOOLEAN_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.BOOLEAN())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.BYTE())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.SHORT_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.SHORT())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.INT_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.INT())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.LONG_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.LONG())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.FLOAT())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.DOUBLE())
elif _is_instance_of(j_type_info, JBasicArrayTypeInfo.CHAR_ARRAY_TYPE_INFO):
return Types.BASIC_ARRAY(Types.CHAR())

JPickledBytesTypeInfo = gateway.jvm \
.org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo\
.PICKLED_BYTE_ARRAY_TYPE_INFO
Expand Down
37 changes: 37 additions & 0 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,43 @@ def assert_chainable(j_stream_graph, expected_upstream_chainable,
# upstream and downstream operators.
assert_chainable(j_generated_stream_graph, False, False)

def test_primitive_array_type_info(self):
    """
    Sends rows of (int, primitive float array) through an identity map and checks the
    string form of the rows collected by the test sink.
    """
    def row_type():
        # Built afresh for the source and for the map output, as two separate objects.
        return Types.ROW([Types.INT(),
                          Types.PRIMITIVE_ARRAY(Types.FLOAT())])

    records = [(1, [1.1, 1.2, 1.30]),
               (2, [2.1, 2.2, 2.3]),
               (3, [3.1, 3.2, 3.3])]
    source = self.env.from_collection(records, type_info=row_type())

    mapped = source.map(lambda x: x, output_type=row_type())
    mapped.add_sink(self.test_sink)
    self.env.execute("test primitive array type info")

    results = self.test_sink.get_results()
    expected = ['1,[1.1, 1.2, 1.3]', '2,[2.1, 2.2, 2.3]', '3,[3.1, 3.2, 3.3]']
    # Result order is not asserted: both lists are sorted before being compared.
    results.sort()
    expected.sort()
    self.assertEqual(expected, results)

def test_basic_array_type_info(self):
    """
    Sends rows of (int, boxed float array, string array) containing None elements through
    an identity map and checks that the sink output renders those elements as 'null'.
    """
    def row_type():
        # Built afresh for the source and for the map output, as two separate objects.
        return Types.ROW([Types.INT(),
                          Types.BASIC_ARRAY(Types.FLOAT()),
                          Types.BASIC_ARRAY(Types.STRING())])

    # Each row places None at a different position in both arrays.
    records = [(1, [1.1, None, 1.30], [None, 'hi', 'flink']),
               (2, [None, 2.2, 2.3], ['hello', None, 'flink']),
               (3, [3.1, 3.2, None], ['hello', 'hi', None])]
    source = self.env.from_collection(records, type_info=row_type())

    mapped = source.map(lambda x: x, output_type=row_type())
    mapped.add_sink(self.test_sink)
    self.env.execute("test basic array type info")

    results = self.test_sink.get_results()
    expected = ['1,[1.1, null, 1.3],[null, hi, flink]',
                '2,[null, 2.2, 2.3],[hello, null, flink]',
                '3,[3.1, 3.2, null],[hello, hi, null]']
    # Result order is not asserted: both lists are sorted before being compared.
    results.sort()
    expected.sort()
    self.assertEqual(expected, results)

def tearDown(self) -> None:
    """Clear the shared test sink after each test."""
    self.test_sink.clear()

Expand Down
22 changes: 20 additions & 2 deletions flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __repr__(self):
return 'TableFunctionRowCoderImpl[%s]' % repr(self._flatten_row_coder)


class ArrayCoderImpl(StreamCoderImpl):
class BasicArrayCoderImpl(StreamCoderImpl):

def __init__(self, elem_coder):
self._elem_coder = elem_coder
Expand All @@ -208,7 +208,25 @@ def decode_from_stream(self, in_stream, nested):
return elements

def __repr__(self):
return 'ArrayCoderImpl[%s]' % repr(self._elem_coder)
return 'BasicArrayCoderImpl[%s]' % repr(self._elem_coder)


class PrimitiveArrayCoderImpl(StreamCoderImpl):
    """
    Coder for arrays that are written as an int32 element count followed by the elements
    themselves, each one encoded by the given element coder. No per-element null marker
    is written by this class.
    """

    def __init__(self, elem_coder):
        self._elem_coder = elem_coder

    def encode_to_stream(self, value, out_stream, nested):
        # Length prefix first, then every element back to back.
        out_stream.write_bigendian_int32(len(value))
        encode_elem = self._elem_coder.encode_to_stream
        for item in value:
            encode_elem(item, out_stream, nested)

    def decode_from_stream(self, in_stream, nested):
        # The length prefix tells how many elements to read back.
        count = in_stream.read_bigendian_int32()
        decode_elem = self._elem_coder.decode_from_stream
        decoded = []
        for _ in range(count):
            decoded.append(decode_elem(in_stream, nested))
        return decoded

    def __repr__(self):
        return 'PrimitiveArrayCoderImpl[%s]' % repr(self._elem_coder)


class PickledBytesCoderImpl(StreamCoderImpl):
Expand Down
2 changes: 1 addition & 1 deletion flink-python/pyflink/fn_execution/beam/beam_coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _to_data_type(field_type):
field_type.nullable)
elif field_type.type_name == flink_fn_execution_pb2.Schema.TIMESTAMP:
return TimestampType(field_type.timestamp_info.precision, field_type.nullable)
elif field_type.type_name == flink_fn_execution_pb2.Schema.ARRAY:
elif field_type.type_name == flink_fn_execution_pb2.Schema.BASIC_ARRAY:
return ArrayType(_to_data_type(field_type.collection_element_type),
field_type.nullable)
elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.ROW:
Expand Down
8 changes: 6 additions & 2 deletions flink-python/pyflink/fn_execution/coder_impl_fast.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ cdef enum TypeName:
BOOLEAN = 11
BINARY = 12
CHAR = 13
ARRAY = 14
BASIC_ARRAY = 14
MAP = 15
LOCAL_ZONED_TIMESTAMP = 16
PICKLED_BYTES = 17
BIG_DEC = 18
TUPLE = 19
PRIMITIVE_ARRAY = 20

cdef class FieldCoder:
cpdef CoderType coder_type(self)
Expand Down Expand Up @@ -183,7 +184,10 @@ cdef class TimestampCoderImpl(FieldCoder):
# Timestamp coder variant that additionally carries a read-only `timezone` object.
cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
cdef readonly object timezone

cdef class ArrayCoderImpl(FieldCoder):
cdef class BasicArrayCoderImpl(FieldCoder):
cdef readonly FieldCoder elem_coder

# Field coder for primitive arrays; exposes the read-only coder used for each element.
cdef class PrimitiveArrayCoderImpl(FieldCoder):
cdef readonly FieldCoder elem_coder

cdef class MapCoderImpl(FieldCoder):
Expand Down
48 changes: 38 additions & 10 deletions flink-python/pyflink/fn_execution/coder_impl_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ cdef class DataStreamStatelessMapCoderImpl(FlattenRowCoderImpl):
output_stream.write(self._tmp_output_data, self._tmp_output_pos)
self._tmp_output_pos = 0

cpdef void _encode_field(self, CoderType coder_type, TypeName field_type, FieldCoder field_coder,
cdef void _encode_field(self, CoderType coder_type, TypeName field_type, FieldCoder field_coder,
item):
if coder_type == SIMPLE:
self._encode_field_simple(field_type, item)
Expand All @@ -89,7 +89,7 @@ cdef class DataStreamStatelessMapCoderImpl(FlattenRowCoderImpl):
self._encode_field_complex(field_type, field_coder, item)
self._encode_data_stream_field_complex(field_type, field_coder, item)

cpdef object _decode_field(self, CoderType coder_type, TypeName field_type,
cdef object _decode_field(self, CoderType coder_type, TypeName field_type,
FieldCoder field_coder):
if coder_type == SIMPLE:
decoded_obj = self._decode_field_simple(field_type)
Expand Down Expand Up @@ -339,15 +339,23 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
milliseconds % 1000 * 1000 + nanoseconds // 1000)
return (<LocalZonedTimestampCoderImpl> field_coder).timezone.localize(
datetime.datetime.utcfromtimestamp(seconds).replace(microsecond=microseconds))
elif field_type == ARRAY:
# Array
elif field_type == BASIC_ARRAY:
# Basic Array
length = self._decode_int()
value_coder = (<ArrayCoderImpl> field_coder).elem_coder
value_coder = (<BasicArrayCoderImpl> field_coder).elem_coder
value_type = value_coder.type_name()
value_coder_type = value_coder.coder_type()
return [
self._decode_field(value_coder_type, value_type, value_coder) if self._decode_byte()
else None for _ in range(length)]
elif field_type == PRIMITIVE_ARRAY:
# Primitive Array
length = self._decode_int()
value_coder = (<PrimitiveArrayCoderImpl> field_coder).elem_coder
value_type = value_coder.type_name()
value_coder_type = value_coder.coder_type()
return [self._decode_field(value_coder_type, value_type, value_coder)
for _ in range(length)]
elif field_type == MAP:
# Map
key_coder = (<MapCoderImpl> field_coder).key_coder
Expand Down Expand Up @@ -510,10 +518,10 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
else:
self._encode_bigint(timestamp_milliseconds)
self._encode_int(nanoseconds)
elif field_type == ARRAY:
# Array
elif field_type == BASIC_ARRAY:
# Basic Array
length = len(item)
value_coder = (<ArrayCoderImpl> field_coder).elem_coder
value_coder = (<BasicArrayCoderImpl> field_coder).elem_coder
value_type = value_coder.type_name()
value_coder_type = value_coder.coder_type()
self._encode_int(length)
Expand All @@ -524,6 +532,16 @@ cdef class FlattenRowCoderImpl(BaseCoderImpl):
else:
self._encode_byte(True)
self._encode_field(value_coder_type, value_type, value_coder, value)
elif field_type == PRIMITIVE_ARRAY:
# Primitive Array
length = len(item)
value_coder = (<PrimitiveArrayCoderImpl> field_coder).elem_coder
value_type = value_coder.type_name()
value_coder_type = value_coder.coder_type()
self._encode_int(length)
for i in range(length):
value = item[i]
self._encode_field(value_coder_type, value_type, value_coder, value)
elif field_type == MAP:
# Map
length = len(item)
Expand Down Expand Up @@ -771,15 +789,25 @@ cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
cpdef TypeName type_name(self):
return LOCAL_ZONED_TIMESTAMP

cdef class ArrayCoderImpl(FieldCoder):
cdef class BasicArrayCoderImpl(FieldCoder):
def __cinit__(self, elem_coder):
self.elem_coder = elem_coder

cpdef CoderType coder_type(self):
return COMPLEX

cpdef TypeName type_name(self):
return BASIC_ARRAY

cdef class PrimitiveArrayCoderImpl(FieldCoder):
def __cinit__(self, elem_coder):
self.elem_coder = elem_coder

cpdef CoderType coder_type(self):
return COMPLEX

cpdef TypeName type_name(self):
return ARRAY
return PRIMITIVE_ARRAY

cdef class MapCoderImpl(FieldCoder):
def __cinit__(self, key_coder, value_coder):
Expand Down
Loading

0 comments on commit 8525c5c

Please sign in to comment.