Skip to content

Commit

Permalink
[FLINK-18866][python] Support filter() operation for Python DataStrea…
Browse files Browse the repository at this point in the history
…m API. (apache#13098)
  • Loading branch information
shuiqiangchen authored Aug 11, 2020
1 parent 5debd15 commit f1e34e6
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 5 deletions.
54 changes: 49 additions & 5 deletions flink-python/pyflink/datastream/data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from pyflink.common.typeinfo import RowTypeInfo, PickledBytesTypeInfo, Types
from pyflink.common.typeinfo import TypeInformation
from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, \
KeySelectorFunctionWrapper, KeySelector, ReduceFunction, ReduceFunctionWrapper
MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \
FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \
ReduceFunctionWrapper
from pyflink.java_gateway import get_gateway


Expand Down Expand Up @@ -262,10 +263,41 @@ def key_by(self, key_selector: Union[Callable, KeySelector],
output_type_info]))
._j_data_stream
.keyBy(PickledKeySelector(is_key_pickled_byte_array),
key_type_info.get_java_type_info()))
key_type_info.get_java_type_info()), self)
generated_key_stream._original_data_type_info = output_type_info
return generated_key_stream

def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
"""
Applies a Filter transformation on a DataStream. The transformation calls a FilterFunction
for each element of the DataStream and retains only those element for which the function
returns true. Elements for which the function returns false are filtered. The user can also
extend RichFilterFunction to gain access to other features provided by the RichFunction
interface.
:param func: The FilterFunction that is called for each element of the DataStream.
:return: The filtered DataStream.
"""
class FilterFlatMap(FlatMapFunction):
def __init__(self, filter_func):
self._func = filter_func

def flat_map(self, value):
if self._func.filter(value):
yield value

if isinstance(func, Callable):
func = FilterFunctionWrapper(func)
elif not isinstance(func, FilterFunction):
raise TypeError("func must be a Callable or instance of FilterFunction.")

j_input_type = self._j_data_stream.getTransformation().getOutputType()
type_info = typeinfo._from_java_type(j_input_type)
j_data_stream = self.flat_map(FilterFlatMap(func), type_info=type_info)._j_data_stream
filtered_stream = DataStream(j_data_stream)
filtered_stream.name("Filter")
return filtered_stream

def _get_java_python_function_operator(self, func: Union[Function, FunctionWrapper],
type_info: TypeInformation, func_name: str,
func_type: int):
Expand Down Expand Up @@ -432,14 +464,16 @@ class KeyedStream(DataStream):
Reduce-style operations, such as reduce and sum work on elements that have the same key.
"""

def __init__(self, j_keyed_stream):
def __init__(self, j_keyed_stream, origin_stream: DataStream):
"""
Constructor of KeyedStream.
:param j_keyed_stream: A java KeyedStream object.
:param origin_stream: The DataStream before key by.
"""
super(KeyedStream, self).__init__(j_data_stream=j_keyed_stream)
self._original_data_type_info = None
self._origin_stream = origin_stream

def map(self, func: Union[Callable, MapFunction], type_info: TypeInformation = None) \
-> 'DataStream':
Expand Down Expand Up @@ -483,7 +517,17 @@ def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream':
j_python_data_stream_scalar_function_operator
))

def _values(self):
def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
return self._values().filter(func)

def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
return self._values().add_sink(sink_func)

def key_by(self, key_selector: Union[Callable, KeySelector],
key_type_info: TypeInformation = None) -> 'KeyedStream':
return self._origin_stream.key_by(key_selector, key_type_info)

def _values(self) -> 'DataStream':
"""
Since python KeyedStream is in the format of Row(key_value, original_data), it is used for
getting the original_data.
Expand Down
36 changes: 36 additions & 0 deletions flink-python/pyflink/datastream/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,29 @@ def get_key(self, value):
pass


class FilterFunction(Function):
"""
A filter function is a predicate applied individually to each record. The predicate decides
whether to keep the element, or to discard it.
The basic syntax for using a FilterFunction is as follows:
:
>>> ds = ...
>>> result = ds.filter(MyFilterFunction())
Note that the system assumes that the function does not modify the elements on which the
predicate is applied. Violating this assumption can lead to incorrect results.
"""

@abc.abstractmethod
def filter(self, value):
"""
The filter function that evaluates the predicate.
:param value: The value to be filtered.
:return: True for values that should be retained, false for values to be filtered out.
"""
pass


class FunctionWrapper(object):
"""
A basic wrapper class for user defined function.
Expand Down Expand Up @@ -188,6 +211,19 @@ def flat_map(self, value):
return self._func(value)


class FilterFunctionWrapper(FunctionWrapper):
"""
A wrapper class for FilterFunction. It's used for wrapping up user defined function in a
FilterFunction when user does not implement a FilterFunction but directly pass a function
object or a lambda function to filter() function.
"""
def __init__(self, func):
super(FilterFunctionWrapper, self).__init__(func)

def filter(self, value):
return self._func(value)


class ReduceFunctionWrapper(FunctionWrapper):
"""
A wrapper class for ReduceFunction. It's used for wrapping up user defined function in a
Expand Down
43 changes: 43 additions & 0 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.functions import FilterFunction
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
Expand Down Expand Up @@ -148,6 +149,29 @@ def flat_map(value):
expected.sort()
self.assertEqual(expected, results)

def test_filter_without_data_types(self):
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
ds.filter(MyFilterFunction()).add_sink(self.test_sink)
self.env.execute("test filter")
results = self.test_sink.get_results(True)
expected = ["(2, 'Hello', 'Hi')"]
results.sort()
expected.sort()
self.assertEqual(expected, results)

def test_filter_with_data_types(self):
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')],
type_info=Types.ROW(
[Types.INT(), Types.STRING(), Types.STRING()])
)
ds.filter(lambda x: x[0] % 2 == 0).add_sink(self.test_sink)
self.env.execute("test filter")
results = self.test_sink.get_results(False)
expected = ['2,Hello,Hi']
results.sort()
expected.sort()
self.assertEqual(expected, results)

def test_add_sink(self):
ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
type_info=Types.ROW([Types.STRING(), Types.INT()]))
Expand Down Expand Up @@ -188,6 +212,19 @@ def map(self, value):
expected.sort()
self.assertEqual(expected, results)

def test_multi_key_by(self):
ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2)],
type_info=Types.ROW([Types.STRING(), Types.INT()]))
ds.key_by(MyKeySelector(), key_type_info=Types.INT()).key_by(lambda x: x[0])\
.add_sink(self.test_sink)

self.env.execute("test multi key by")
results = self.test_sink.get_results(False)
expected = ['d,1', 'c,1', 'a,0', 'b,0', 'e,2']
results.sort()
expected.sort()
self.assertEqual(expected, results)

def tearDown(self) -> None:
self.test_sink.get_results()

Expand All @@ -209,3 +246,9 @@ def flat_map(self, value):
class MyKeySelector(KeySelector):
def get_key(self, value):
return value[1]


class MyFilterFunction(FilterFunction):

def filter(self, value):
return value[0] % 2 == 0

0 comments on commit f1e34e6

Please sign in to comment.