Skip to content

Commit

Permalink
[FLINK-18884][python] Add chaining strategy and slot sharing group in…
Browse files Browse the repository at this point in the history
…terfaces for Python DataStream API. (apache#13140)
  • Loading branch information
shuiqiangchen committed Aug 14, 2020
1 parent dc73dbf commit d3aa4f3
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 6 deletions.
49 changes: 49 additions & 0 deletions flink-python/pyflink/datastream/data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,46 @@ def set_buffer_timeout(self, timeout_millis: int):
self._j_data_stream.setBufferTimeout(timeout_millis)
return self

def start_new_chain(self) -> 'DataStream':
"""
Starts a new task chain beginning at this operator. This operator will be chained (thread
co-located for increased performance) to any previous tasks even if possible.
:return: The operator with chaining set.
"""
self._j_data_stream.startNewChain()
return self

def disable_chaining(self) -> 'DataStream':
"""
Turns off chaining for this operator so thread co-location will not be used as an
optimization.
Chaining can be turned off for the whole job by
StreamExecutionEnvironment.disableOperatorChaining() however it is not advised for
performance consideration.
:return: The operator with chaining disabled.
"""
self._j_data_stream.disableChaining()
return self

def slot_sharing_group(self, slot_sharing_group: str) -> 'DataStream':
"""
Sets the slot sharing group of this operation. Parallel instances of operations that are in
the same slot sharing group will be co-located in the same TaskManager slot, if possible.
Operations inherit the slot sharing group of input operations if all input operations are in
the same slot sharing group and no slot sharing group was explicitly specified.
Initially an operation is in the default slot sharing group. An operation can be put into
the default group explicitly by setting the slot sharing group to 'default'.
:param slot_sharing_group: The slot sharing group name.
:return: This operator.
"""
self._j_data_stream.slotSharingGroup(slot_sharing_group)
return self

def map(self, func: Union[Callable, MapFunction], type_info: TypeInformation = None) \
-> 'DataStream':
"""
Expand Down Expand Up @@ -714,3 +754,12 @@ def force_non_parallel(self):

def set_buffer_timeout(self, timeout_millis: int):
raise Exception("Set buffer timeout for KeyedStream is not supported.")

def start_new_chain(self) -> 'DataStream':
raise Exception("Start new chain for KeyedStream is not supported.")

def disable_chaining(self) -> 'DataStream':
raise Exception("Disable chaining for KeyedStream is not supported.")

def slot_sharing_group(self, slot_sharing_group: str) -> 'DataStream':
raise Exception("Setting slot sharing group for KeyedStream is not supported.")
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def _from_collection(self, elements: List[Any],
j_input_format,
out_put_type_info.get_java_type_info()
)
j_data_stream_source.forceNonParallel()
return DataStream(j_data_stream=j_data_stream_source)
finally:
os.unlink(temp_file.name)
Expand Down
94 changes: 88 additions & 6 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.java_gateway import get_gateway
from pyflink.testing.test_case_utils import PyFlinkTestCase


Expand All @@ -40,19 +41,18 @@ def test_data_stream_name(self):

def test_set_parallelism(self):
parallelism = 3
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
ds.set_parallelism(parallelism)
ds.add_sink(self.test_sink)
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x)
ds.set_parallelism(parallelism).add_sink(self.test_sink)
plan = eval(str(self.env.get_execution_plan()))
self.assertEqual(parallelism, plan['nodes'][0]['parallelism'])
self.assertEqual(parallelism, plan['nodes'][1]['parallelism'])

def test_set_max_parallelism(self):
max_parallelism = 4
self.env.set_parallelism(8)
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x)
ds.set_parallelism(max_parallelism).add_sink(self.test_sink)
plan = eval(str(self.env.get_execution_plan()))
self.assertEqual(max_parallelism, plan['nodes'][0]['parallelism'])
self.assertEqual(max_parallelism, plan['nodes'][1]['parallelism'])

def test_force_non_parallel(self):
self.env.set_parallelism(8)
Expand Down Expand Up @@ -322,6 +322,88 @@ def test_keyed_stream_partitioning(self):
with self.assertRaises(Exception):
keyed_stream.forward()

def test_slot_sharing_group(self):
source_operator_name = 'collection source'
map_operator_name = 'map_operator'
slot_sharing_group_1 = 'slot_sharing_group_1'
slot_sharing_group_2 = 'slot_sharing_group_2'
ds_1 = self.env.from_collection([1, 2, 3]).name(source_operator_name)
ds_1.slot_sharing_group(slot_sharing_group_1).map(lambda x: x + 1).set_parallelism(3)\
.name(map_operator_name).slot_sharing_group(slot_sharing_group_2)\
.add_sink(self.test_sink)

j_generated_stream_graph = self.env._j_stream_execution_environment \
.getStreamGraph("test start new_chain", True)

j_stream_nodes = list(j_generated_stream_graph.getStreamNodes().toArray())
for j_stream_node in j_stream_nodes:
if j_stream_node.getOperatorName() == source_operator_name:
self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_1)
elif j_stream_node.getOperatorName() == map_operator_name:
self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_2)

def test_chaining_strategy(self):
chained_operator_name_0 = "map_operator_0"
chained_operator_name_1 = "map_operator_1"
chained_operator_name_2 = "map_operator_2"

ds = self.env.from_collection([1, 2, 3])
ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0)\
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_1)\
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_2)\
.add_sink(self.test_sink)

def assert_chainable(j_stream_graph, expected_upstream_chainable,
expected_downstream_chainable):
j_stream_nodes = list(j_stream_graph.getStreamNodes().toArray())
for j_stream_node in j_stream_nodes:
if j_stream_node.getOperatorName() == chained_operator_name_1:
JStreamingJobGraphGenerator = get_gateway().jvm \
.org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator

j_in_stream_edge = j_stream_node.getInEdges().get(0)
upstream_chainable = JStreamingJobGraphGenerator.isChainable(j_in_stream_edge,
j_stream_graph)
self.assertEqual(expected_upstream_chainable, upstream_chainable)

j_out_stream_edge = j_stream_node.getOutEdges().get(0)
downstream_chainable = JStreamingJobGraphGenerator.isChainable(
j_out_stream_edge, j_stream_graph)
self.assertEqual(expected_downstream_chainable, downstream_chainable)

# The map_operator_1 has the same parallelism with map_operator_0 and map_operator_2, and
# ship_strategy for map_operator_0 and map_operator_1 is FORWARD, so the map_operator_1
# can be chained with map_operator_0 and map_operator_2.
j_generated_stream_graph = self.env._j_stream_execution_environment\
.getStreamGraph("test start new_chain", True)
assert_chainable(j_generated_stream_graph, True, True)

ds = self.env.from_collection([1, 2, 3])
# Start a new chain for map_operator_1
ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).start_new_chain() \
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \
.add_sink(self.test_sink)

j_generated_stream_graph = self.env._j_stream_execution_environment \
.getStreamGraph("test start new_chain", True)
# We start a new chain for map operator, therefore, it cannot be chained with upstream
# operator, but can be chained with downstream operator.
assert_chainable(j_generated_stream_graph, False, True)

ds = self.env.from_collection([1, 2, 3])
# Disable chaining for map_operator_1
ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).disable_chaining() \
.map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \
.add_sink(self.test_sink)

j_generated_stream_graph = self.env._j_stream_execution_environment \
.getStreamGraph("test start new_chain", True)
# We disable chaining for map_operator_1, therefore, it cannot be chained with
# upstream and downstream operators.
assert_chainable(j_generated_stream_graph, False, False)

def tearDown(self) -> None:
self.test_sink.clear()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pyflink.datastream.functions import SourceFunction
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.java_gateway import get_gateway
from pyflink.pyflink_gateway_server import on_windows
from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, StreamTableEnvironment
from pyflink.testing.test_case_utils import PyFlinkTestCase

Expand Down Expand Up @@ -396,6 +397,7 @@ def add_from_file(i):
expected.sort()
self.assertEqual(expected, result)

@unittest.skipIf(on_windows(), "Symbolic link is not supported on Windows, skipping.")
def test_set_stream_env(self):
import sys
python_exec = sys.executable
Expand Down

0 comments on commit d3aa4f3

Please sign in to comment.