[FLINK-20528][python] Support table aggregation for group aggregation in streaming mode

This closes apache#14389.
HuangXingBo committed Dec 17, 2020
1 parent 98ed08b commit 9c486d1
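
For context, here is a rough sketch of the user-facing pattern this change enables: a Python TableAggregateFunction applied to a grouped streaming table via flat_aggregate. The Top2 class, the sample data, and the environment setup are illustrative only; they follow the documented PyFlink table-aggregate API rather than code in this commit, and the exact setup (udtaf, in_streaming_mode, alias/select syntax) varies by Flink version:

from pyflink.common import Row
from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import udtaf, TableAggregateFunction


class Top2(TableAggregateFunction):
    """Emits the two largest values seen so far for each group."""

    def create_accumulator(self):
        # acc[0] holds the largest value, acc[1] the second largest
        return [None, None]

    def accumulate(self, accumulator, value):
        if accumulator[0] is None or value > accumulator[0]:
            accumulator[1] = accumulator[0]
            accumulator[0] = value
        elif accumulator[1] is None or value > accumulator[1]:
            accumulator[1] = value

    def emit_value(self, accumulator):
        # The runtime calls this to (re-)emit all result rows for a group.
        yield Row(accumulator[0])
        if accumulator[1] is not None:
            yield Row(accumulator[1])

    def get_accumulator_type(self):
        return DataTypes.ARRAY(DataTypes.BIGINT())

    def get_result_type(self):
        return DataTypes.ROW([DataTypes.FIELD("v", DataTypes.BIGINT())])


t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
top2 = udtaf(Top2())
t = t_env.from_elements([(1, 'a'), (3, 'a'), (5, 'a'), (2, 'b')], ['v', 'k'])
# flat_aggregate routes each group's accumulator through emit_value.
result = t.group_by(t.k).flat_aggregate(top2(t.v).alias("v")).select(t.k, col("v"))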
Showing 16 changed files with 881 additions and 143 deletions.
207 changes: 171 additions & 36 deletions flink-python/pyflink/fn_execution/aggregate.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 ################################################################################
 from abc import ABC, abstractmethod
-from typing import List, Dict
+from typing import List, Dict, Iterable

 from apache_beam.coders import PickleCoder, Coder
@@ -25,8 +25,9 @@
 from pyflink.fn_execution.coders import from_proto
 from pyflink.fn_execution.operation_utils import is_built_in_function, load_aggregate_function
 from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
-from pyflink.table import AggregateFunction, FunctionContext
+from pyflink.table import AggregateFunction, FunctionContext, TableAggregateFunction
 from pyflink.table.data_view import ListView, MapView
+from pyflink.table.udf import ImperativeAggregateFunction


def join_row(left: Row, right: Row):
@@ -219,9 +220,9 @@ def get_state_map_view(self, state_name, key_coder, value_coder):
             self._keyed_state_backend.get_map_state(state_name, key_coder, value_coder))


-class AggsHandleFunction(ABC):
+class AggsHandleFunctionBase(ABC):
     """
-    The base class for handling aggregate functions.
+    The base class for handling aggregate or table aggregate functions.
     """

     @abstractmethod
@@ -300,6 +301,20 @@ def cleanup(self):
         """
         pass

+    @abstractmethod
+    def close(self):
+        """
+        Tear-down method for this function. It can be used for clean up work.
+        By default, this method does nothing.
+        """
+        pass
+
+
+class AggsHandleFunction(AggsHandleFunctionBase):
+    """
+    The base class for handling aggregate functions.
+    """
+
     @abstractmethod
     def get_value(self) -> Row:
         """
@@ -309,36 +324,35 @@ def get_value(self) -> Row:
         """
         pass

+
+class TableAggsHandleFunction(AggsHandleFunctionBase):
+    """
+    The base class for handling table aggregate functions.
+    """
+
     @abstractmethod
-    def close(self):
+    def emit_value(self, current_key: Row, is_retract: bool) -> Iterable[Row]:
         """
-        Tear-down method for this function. It can be used for clean up work.
-        By default, this method does nothing.
+        Emit the result of the table aggregation.
         """
         pass


-class SimpleAggsHandleFunction(AggsHandleFunction):
+class SimpleAggsHandleFunctionBase(AggsHandleFunctionBase):
     """
-    A simple AggsHandleFunction implementation which provides the basic functionality.
+    A simple AggsHandleFunctionBase implementation which provides the basic functionality.
     """

     def __init__(self,
-                 udfs: List[AggregateFunction],
+                 udfs: List[ImperativeAggregateFunction],
                  input_extractors: List,
-                 index_of_count_star: int,
-                 count_star_inserted: bool,
                  udf_data_view_specs: List[List[DataViewSpec]],
                  filter_args: List[int],
                  distinct_indexes: List[int],
                  distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
         self._udfs = udfs
         self._input_extractors = input_extractors
         self._accumulators = None  # type: Row
-        self._get_value_indexes = [i for i in range(len(udfs))]
-        if index_of_count_star >= 0 and count_star_inserted:
-            # The record count is used internally, should be ignored by the get_value method.
-            self._get_value_indexes.remove(index_of_count_star)
         self._udf_data_view_specs = udf_data_view_specs
         self._udf_data_views = []
         self._filter_args = filter_args
@@ -451,13 +465,64 @@ def cleanup(self):
                 for data_view in self._udf_data_views[i].values():
                     data_view.clear()

+    def close(self):
+        for udf in self._udfs:
+            udf.close()
+
+
+class SimpleAggsHandleFunction(SimpleAggsHandleFunctionBase, AggsHandleFunction):
+    """
+    A simple AggsHandleFunction implementation which provides the basic functionality.
+    """
+
+    def __init__(self,
+                 udfs: List[AggregateFunction],
+                 input_extractors: List,
+                 index_of_count_star: int,
+                 count_star_inserted: bool,
+                 udf_data_view_specs: List[List[DataViewSpec]],
+                 filter_args: List[int],
+                 distinct_indexes: List[int],
+                 distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+        super(SimpleAggsHandleFunction, self).__init__(
+            udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+            distinct_view_descriptors)
+        self._get_value_indexes = [i for i in range(len(udfs))]
+        if index_of_count_star >= 0 and count_star_inserted:
+            # The record count is used internally, should be ignored by the get_value method.
+            self._get_value_indexes.remove(index_of_count_star)

     def get_value(self):
         return Row(*[self._udfs[i].get_value(self._accumulators[i])
                      for i in self._get_value_indexes])

-    def close(self):
-        for udf in self._udfs:
-            udf.close()
+
+
+class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase, TableAggsHandleFunction):
+    """
+    A simple TableAggsHandleFunction implementation which provides the basic functionality.
+    """
+
+    def __init__(self,
+                 udfs: List[TableAggregateFunction],
+                 input_extractors: List,
+                 udf_data_view_specs: List[List[DataViewSpec]],
+                 filter_args: List[int],
+                 distinct_indexes: List[int],
+                 distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+        super(SimpleTableAggsHandleFunction, self).__init__(
+            udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+            distinct_view_descriptors)
+
+    def emit_value(self, current_key: Row, is_retract: bool):
+        udf = self._udfs[0]  # type: TableAggregateFunction
+        results = udf.emit_value(self._accumulators[0])
+        for x in results:
+            result = join_row(current_key, x)
+            if is_retract:
+                result.set_row_kind(RowKind.DELETE)
+            else:
+                result.set_row_kind(RowKind.INSERT)
+            yield result


 class RecordCounter(ABC):
@@ -494,10 +559,10 @@ def record_count_is_zero(self, acc):
         return acc is None or acc[self._index_of_count_star][0] == 0


-class GroupAggFunction(object):
+class GroupAggFunctionBase(object):

     def __init__(self,
-                 aggs_handle: AggsHandleFunction,
+                 aggs_handle: AggsHandleFunctionBase,
                  key_selector: RowKeySelector,
                  state_backend: RemoteKeyedStateBackend,
                  state_value_coder: Coder,
@@ -518,6 +583,41 @@ def open(self, function_context: FunctionContext):
     def close(self):
         self.aggs_handle.close()

+    def on_timer(self, key):
+        if self.state_cleaning_enabled:
+            self.state_backend.set_current_key(key)
+            accumulator_state = self.state_backend.get_value_state(
+                "accumulators", self.state_value_coder)
+            accumulator_state.clear()
+            self.aggs_handle.cleanup()
+
+    @staticmethod
+    def is_retract_msg(data: Row):
+        return data.get_row_kind() == RowKind.UPDATE_BEFORE \
+            or data.get_row_kind() == RowKind.DELETE
+
+    @staticmethod
+    def is_accumulate_msg(data: Row):
+        return data.get_row_kind() == RowKind.UPDATE_AFTER \
+            or data.get_row_kind() == RowKind.INSERT
+
+    @abstractmethod
+    def process_element(self, input_data: Row):
+        pass
+
+
+class GroupAggFunction(GroupAggFunctionBase):
+
+    def __init__(self,
+                 aggs_handle: AggsHandleFunction,
+                 key_selector: RowKeySelector,
+                 state_backend: RemoteKeyedStateBackend,
+                 state_value_coder: Coder,
+                 generate_update_before: bool,
+                 state_cleaning_enabled: bool,
+                 index_of_count_star: int):
+        super(GroupAggFunction, self).__init__(
+            aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+            state_cleaning_enabled, index_of_count_star)

     def process_element(self, input_data: Row):
         key = self.key_selector.get_key(input_data)
         self.state_backend.set_current_key(key)
@@ -597,20 +697,55 @@ def process_element(self, input_data: Row):
             # cleanup dataview under current key
             self.aggs_handle.cleanup()

-    def on_timer(self, key):
-        if self.state_cleaning_enabled:
-            self.state_backend.set_current_key(key)
-            accumulator_state = self.state_backend.get_value_state(
-                "accumulators", self.state_value_coder)
-            accumulator_state.clear()
-            self.aggs_handle.cleanup()
-
-    @staticmethod
-    def is_retract_msg(data: Row):
-        return data.get_row_kind() == RowKind.UPDATE_BEFORE \
-            or data.get_row_kind() == RowKind.DELETE
+
+class GroupTableAggFunction(GroupAggFunctionBase):
+    def __init__(self,
+                 aggs_handle: TableAggsHandleFunction,
+                 key_selector: RowKeySelector,
+                 state_backend: RemoteKeyedStateBackend,
+                 state_value_coder: Coder,
+                 generate_update_before: bool,
+                 state_cleaning_enabled: bool,
+                 index_of_count_star: int):
+        super(GroupTableAggFunction, self).__init__(
+            aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+            state_cleaning_enabled, index_of_count_star)

-    @staticmethod
-    def is_accumulate_msg(data: Row):
-        return data.get_row_kind() == RowKind.UPDATE_AFTER \
-            or data.get_row_kind() == RowKind.INSERT
+    def process_element(self, input_data: Row):
+        key = self.key_selector.get_key(input_data)
+        self.state_backend.set_current_key(key)
+        self.state_backend.clear_cached_iterators()
+        accumulator_state = self.state_backend.get_value_state(
+            "accumulators", self.state_value_coder)
+        accumulators = accumulator_state.value()
+        if accumulators is None:
+            first_row = True
+            accumulators = self.aggs_handle.create_accumulators()
+        else:
+            first_row = False
+
+        # set accumulators to handler first
+        self.aggs_handle.set_accumulators(accumulators)
+
+        if not first_row and self.generate_update_before:
+            yield from self.aggs_handle.emit_value(key, True)
+
+        # update aggregate result and set to the newRow
+        if self.is_accumulate_msg(input_data):
+            # accumulate input
+            self.aggs_handle.accumulate(input_data)
+        else:
+            # retract input
+            self.aggs_handle.retract(input_data)
+
+        # get accumulator
+        accumulators = self.aggs_handle.get_accumulators()
+
+        if not self.record_counter.record_count_is_zero(accumulators):
+            yield from self.aggs_handle.emit_value(key, False)
+            accumulator_state.update(accumulators)
+        else:
+            # and clear all state
+            accumulator_state.clear()
+            # cleanup dataview under current key
+            self.aggs_handle.cleanup()
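
To make the retract-then-re-emit protocol above concrete, here is a hand-traced changelog for a hypothetical Top2 table aggregate receiving the values 1, 3, 5 under key 'a', with generate_update_before enabled (+I marks RowKind.INSERT, -D marks RowKind.DELETE; the trace illustrates the control flow above and is not test output):

# process_element(1): first_row is True, so nothing is retracted.
#     emit_value('a', False)  -> +I[a, 1]
# process_element(3): the previous result is retracted before the new input
# is accumulated, then the updated result is emitted.
#     emit_value('a', True)   -> -D[a, 1]
#     emit_value('a', False)  -> +I[a, 3], +I[a, 1]
# process_element(5):
#     emit_value('a', True)   -> -D[a, 3], -D[a, 1]
#     emit_value('a', False)  -> +I[a, 5], +I[a, 3]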
10 changes: 10 additions & 0 deletions flink-python/pyflink/fn_execution/beam/beam_operations.py
@@ -110,6 +110,16 @@ def create_data_stream_keyed_process_function(factory, transform_id, transform_proto, parameter, consumers):
         operations.KeyedProcessFunctionOperation)


+@bundle_processor.BeamTransformFactory.register_urn(
+    operations.STREAM_GROUP_TABLE_AGGREGATE_URN,
+    flink_fn_execution_pb2.UserDefinedAggregateFunctions)
+def create_table_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
+    return _create_user_defined_function_operation(
+        factory, transform_proto, consumers, parameter,
+        beam_operations.StatefulFunctionOperation,
+        operations.StreamGroupTableAggregateOperation)
+
+
 def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                             beam_operation_cls, internal_operation_cls):
     output_tags = list(transform_proto.outputs.keys())
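
For orientation, the new registration follows the same factory pattern as the other URNs in this file: the Beam transform factory binds operations.STREAM_GROUP_TABLE_AGGREGATE_URN to a StatefulFunctionOperation whose internal operation is StreamGroupTableAggregateOperation, which presumably drives the GroupTableAggFunction added in aggregate.py above (the call chain past this file is an inference, not shown in this excerpt).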
4 changes: 2 additions & 2 deletions flink-python/pyflink/fn_execution/operation_utils.py
@@ -26,7 +26,7 @@
 from pyflink.serializers import PickleSerializer
 from pyflink.table import functions
 from pyflink.table.udf import DelegationTableFunction, DelegatingScalarFunction, \
-    AggregateFunction, PandasAggregateFunctionWrapper
+    ImperativeAggregateFunction, PandasAggregateFunctionWrapper

 _func_num = 0
 _constant_num = 0
@@ -147,7 +147,7 @@ def extract_user_defined_aggregate_function(
         user_defined_function_proto,
         distinct_info_dict: Dict[Tuple[List[str]], Tuple[List[int], List[int]]]):
     user_defined_agg = load_aggregate_function(user_defined_function_proto.payload)
-    assert isinstance(user_defined_agg, AggregateFunction)
+    assert isinstance(user_defined_agg, ImperativeAggregateFunction)
     args_str = []
     local_variable_dict = {}
     for arg in user_defined_function_proto.inputs:
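
The widened assertion matches the new hierarchy: in pyflink.table.udf, ImperativeAggregateFunction is the common base class of AggregateFunction and TableAggregateFunction, so extract_user_defined_aggregate_function now accepts both kinds of UDF.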