Commit 82b3669

[FLINK-18885][python] Add partitioning interfaces for Python DataStream API. (apache#13119)
shuiqiangchen authored Aug 12, 2020
1 parent 808ec56 commit 82b3669
Showing 2 changed files with 195 additions and 0 deletions.
115 changes: 115 additions & 0 deletions flink-python/pyflink/datastream/data_stream.py
@@ -298,6 +298,100 @@ def flat_map(self, value):
filtered_stream.name("Filter")
return filtered_stream

def union(self, *streams) -> 'DataStream':
"""
Creates a new DataStream by merging DataStream outputs of the same type with each other. The
DataStreams merged using this operator will be transformed simultaneously.
:param streams: The DataStreams to union output with.
:return: The DataStream.
"""
j_data_streams = []
for data_stream in streams:
j_data_streams.append(data_stream._j_data_stream)
gateway = get_gateway()
j_data_stream_class = gateway.jvm.org.apache.flink.streaming.api.datastream.DataStream
j_data_stream_arr = gateway.new_array(j_data_stream_class, len(j_data_streams))
for i in range(len(j_data_streams)):
j_data_stream_arr[i] = j_data_streams[i]
j_united_stream = self._j_data_stream.union(j_data_stream_arr)
return DataStream(j_data_stream=j_united_stream)
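
For reference, a minimal usage sketch of the new union method; the environment setup, collections and output type below are illustrative assumptions, not taken from the commit.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds_a = env.from_collection([1, 2, 3])
ds_b = env.from_collection([4, 5, 6])
ds_c = env.from_collection([7, 8, 9])
# All three sources feed the same downstream map operator after the union.
ds_a.union(ds_b, ds_c).map(lambda x: x * 2, Types.INT()).print()
env.execute("union_sketch")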

def shuffle(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are shuffled uniformly
at random to the next operation.
:return: The DataStream with shuffle partitioning set.
"""
return DataStream(self._j_data_stream.shuffle())

def project(self, *field_indexes) -> 'DataStream':
"""
Initiates a Project transformation on a Tuple DataStream.
Note that only Tuple DataStreams can be projected.
:param field_indexes: The field indexes of the input tuples that are retained. The order of
fields in the output tuple corresponds to the order of field indexes.
:return: The projected DataStream.
"""
if not isinstance(self.get_type(), typeinfo.TupleTypeInfo):
raise Exception('Only Tuple DataStreams can be projected.')

gateway = get_gateway()
j_index_arr = gateway.new_array(gateway.jvm.int, len(field_indexes))
for i in range(len(field_indexes)):
j_index_arr[i] = field_indexes[i]
return DataStream(self._j_data_stream.project(j_index_arr))
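
A sketch of project on a Tuple stream follows; the element values and declared field types are illustrative, mirroring the shape used in the tests further down.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection(
    [(1, 'a', 10), (2, 'b', 20)],
    type_info=Types.TUPLE([Types.INT(), Types.STRING(), Types.INT()]))
# Keep fields 2 and 0; the output tuple order follows the given indexes.
ds.project(2, 0).print()
env.execute("project_sketch")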

def rescale(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are distributed evenly
to a subset of instances of the next operation in a round-robin fashion.
The subset of downstream operations to which the upstream operation sends elements depends
on the degree of parallelism of both the upstream and downstream operation. For example, if
the upstream operation has parallelism 2 and the downstream operation has parallelism 4,
then one upstream operation would distribute elements to two downstream operations while
the other upstream operation would distribute to the other two downstream operations. If,
on the other hand, the downstream operation has parallelism 2 while the upstream operation
has parallelism 4, then two upstream operations will distribute to one downstream operation
while the other two upstream operations will distribute to the other downstream operation.
In cases where the different parallelisms are not multiples of each other, one or several
downstream operations will have a differing number of inputs from upstream operations.
:return: The DataStream with rescale partitioning set.
"""
return DataStream(self._j_data_stream.rescale())
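
To make the parallelism discussion in the docstring concrete, here is a sketch with explicit parallelism settings; the mapped functions, values and output types are illustrative.

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
# Upstream parallelism 2, downstream parallelism 4: each upstream sub-task
# round-robins only over its own pair of downstream sub-tasks.
env.from_collection(list(range(100))) \
    .map(lambda x: x, Types.INT()).set_parallelism(2) \
    .rescale() \
    .map(lambda x: x + 1, Types.INT()).set_parallelism(4) \
    .print()
env.execute("rescale_sketch")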

def rebalance(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are distributed evenly
to instances of the next operation in a round-robin fashion.
:return: The DataStream with rebalance partitioning set.
"""
return DataStream(self._j_data_stream.rebalance())

def forward(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are forwarded to the
local sub-task of the next operation.
:return: The DataStream with forward partitioning set.
"""
return DataStream(self._j_data_stream.forward())

def broadcast(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are broadcasted to every
parallel instance of the next operation.
:return: The DataStream with broadcast partitioning set.
"""
return DataStream(self._j_data_stream.broadcast())
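
The shuffle, rebalance, forward and broadcast partitioners differ only in the ship strategy they select; a side-by-side sketch (environment, values and output types are illustrative):

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection([1, 2, 3, 4])
ds.shuffle().map(lambda x: x, Types.INT()).set_parallelism(2).print()    # uniform random
ds.rebalance().map(lambda x: x, Types.INT()).set_parallelism(2).print()  # round-robin over all sub-tasks
ds.broadcast().map(lambda x: x, Types.INT()).set_parallelism(2).print()  # every element to every sub-task
# forward() requires matching upstream and downstream parallelism (the source runs with parallelism 1).
ds.forward().map(lambda x: x, Types.INT()).set_parallelism(1).print()
env.execute("partitioning_sketch")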

def _get_java_python_function_operator(self, func: Union[Function, FunctionWrapper],
type_info: TypeInformation, func_name: str,
func_type: int):
@@ -563,6 +657,27 @@ def key_by(self, key_selector: Union[Callable, KeySelector],
key_type_info: TypeInformation = None) -> 'KeyedStream':
return self._origin_stream.key_by(key_selector, key_type_info)

def union(self, *streams) -> 'DataStream':
return self._values().union(*streams)

def shuffle(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')

def project(self, *field_indexes) -> 'DataStream':
return self._values().project(*field_indexes)

def rescale(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')

def rebalance(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')

def forward(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')

def broadcast(self) -> 'DataStream':
raise Exception('Cannot override partitioning for KeyedStream.')
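
The overrides above mean that re-partitioning a KeyedStream is rejected at call time, while union and project fall back to the underlying values stream; a short sketch (key selector and data are illustrative):

from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
keyed = env.from_collection([('a', 1), ('b', 2)]).key_by(lambda x: x[0])
try:
    keyed.shuffle()
except Exception as e:
    print(e)  # Cannot override partitioning for KeyedStream.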

def print(self, sink_identifier=None):
return self._values().print(sink_identifier)

80 changes: 80 additions & 0 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -242,6 +242,86 @@ def test_print_with_align_output(self):
self.assertEqual(3, len(plan['nodes']))
self.assertEqual("Sink: Print to Std. Out", plan['nodes'][2]['type'])

def test_union_stream(self):
ds_1 = self.env.from_collection([1, 2, 3])
ds_2 = self.env.from_collection([4, 5, 6])
ds_3 = self.env.from_collection([7, 8, 9])

united_stream = ds_3.union(ds_1, ds_2)

united_stream.map(lambda x: x + 1).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
source_ids = []
union_node_pre_ids = []
for node in exec_plan['nodes']:
if node['pact'] == 'Data Source':
source_ids.append(node['id'])
if node['pact'] == 'Operator':
for pre in node['predecessors']:
union_node_pre_ids.append(pre['id'])

source_ids.sort()
union_node_pre_ids.sort()
self.assertEqual(source_ids, union_node_pre_ids)

def test_project(self):
ds = self.env.from_collection([[1, 2, 3, 4], [5, 6, 7, 8]],
type_info=Types.TUPLE(
[Types.INT(), Types.INT(), Types.INT(), Types.INT()]))
ds.project(1, 3).map(lambda x: (x[0], x[1] + 1)).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
self.assertEqual(exec_plan['nodes'][1]['type'], 'Projection')

def test_broadcast(self):
ds_1 = self.env.from_collection([1, 2, 3])
ds_1.broadcast().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
broadcast_node = exec_plan['nodes'][1]
pre_ship_strategy = broadcast_node['predecessors'][0]['ship_strategy']
self.assertEqual(pre_ship_strategy, 'BROADCAST')

def test_rebalance(self):
ds_1 = self.env.from_collection([1, 2, 3])
ds_1.rebalance().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
rebalance_node = exec_plan['nodes'][1]
pre_ship_strategy = rebalance_node['predecessors'][0]['ship_strategy']
self.assertEqual(pre_ship_strategy, 'REBALANCE')

def test_rescale(self):
ds_1 = self.env.from_collection([1, 2, 3])
ds_1.rescale().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
rescale_node = exec_plan['nodes'][1]
pre_ship_strategy = rescale_node['predecessors'][0]['ship_strategy']
self.assertEqual(pre_ship_strategy, 'RESCALE')

def test_shuffle(self):
ds_1 = self.env.from_collection([1, 2, 3])
ds_1.shuffle().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink)
exec_plan = eval(self.env.get_execution_plan())
shuffle_node = exec_plan['nodes'][1]
pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy']
self.assertEqual(pre_ship_strategy, 'SHUFFLE')

def test_keyed_stream_partitioning(self):
ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)])
keyed_stream = ds.key_by(lambda x: x[1])
with self.assertRaises(Exception):
keyed_stream.shuffle()

with self.assertRaises(Exception):
keyed_stream.rebalance()

with self.assertRaises(Exception):
keyed_stream.rescale()

with self.assertRaises(Exception):
keyed_stream.broadcast()

with self.assertRaises(Exception):
keyed_stream.forward()

def tearDown(self) -> None:
self.test_sink.get_results()
