
Merge pull request #11270 from [BEAM-9639][BEAM-9608] Improvements for FnApiRunner

[BEAM-9639][BEAM-9608] Improvements for FnApiRunner
pabloem authored Apr 21, 2020
2 parents 4a7f04c + cf821e5 commit 1fe543e
Showing 4 changed files with 365 additions and 255 deletions.
179 changes: 176 additions & 3 deletions sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -22,17 +22,29 @@
import collections
import itertools
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam.coders import BytesCoder
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.coders.coders import GlobalWindowCoder
from apache_beam.coders.coders import WindowedValueCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
@@ -45,8 +57,13 @@

if TYPE_CHECKING:
from apache_beam.coders.coder_impl import CoderImpl
-  from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
GlobalWindows.windowed_value(b''))
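
# Illustrative sketch (not part of the commit): ENCODED_IMPULSE_VALUE above is
# a single empty-bytes element in the global window, pre-encoded once so the
# runner can seed Impulse stages without re-running a coder at execution time.
# Assuming only an installed apache_beam, it round-trips like this:
_impulse_coder_impl = WindowedValueCoder(
    BytesCoder(), GlobalWindowCoder()).get_impl()
_decoded = _impulse_coder_impl.decode_nested(ENCODED_IMPULSE_VALUE)
assert _decoded.value == b''  # the one element fed to an Impulse stage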


class Buffer(Protocol):
@@ -204,7 +221,7 @@ class WindowGroupingBuffer(object):
def __init__(
self,
access_pattern,
-    coder  # type: coders.WindowedValueCoder
+    coder  # type: WindowedValueCoder
):
# type: (...) -> None
# Here's where we would use a different type of partitioning
@@ -251,11 +268,12 @@ def encoded_items(self):

class FnApiRunnerExecutionContext(object):
"""
-  :var pcoll_buffers: (collections.defaultdict of str: list): Mapping of
+  :var pcoll_buffers: (dict): Mapping of
      PCollection IDs to lists that function as buffers for the
      ``beam.PCollection``.
"""
def __init__(self,
stages, # type: List[translations.Stage]
worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
safe_coders,
@@ -268,6 +286,9 @@ def __init__(self,
:param safe_coders:
:param data_channel_coders:
"""
self.stages = stages
self.side_input_descriptors_by_stage = (
self._build_data_side_inputs_map(stages))
self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer]
self.timer_buffers = {} # type: MutableMapping[bytes, ListBuffer]
self.worker_handler_manager = worker_handler_manager
@@ -280,6 +301,63 @@
iterable_state_write=self._iterable_state_write)
self._last_uid = -1

@staticmethod
def _build_data_side_inputs_map(stages):
# type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]

"""Builds an index mapping stages to side input descriptors.
A side input descriptor is a map of side input IDs to side input access
patterns for all of the outputs of a stage that will be consumed as a
side input.
"""
transform_consumers = collections.defaultdict(
list) # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
stage_consumers = collections.defaultdict(
list) # type: DefaultDict[str, List[translations.Stage]]

def get_all_side_inputs():
# type: () -> Set[str]
all_side_inputs = set() # type: Set[str]
for stage in stages:
for transform in stage.transforms:
for input in transform.inputs.values():
transform_consumers[input].append(transform)
stage_consumers[input].append(stage)
for si in stage.side_inputs():
all_side_inputs.add(si)
return all_side_inputs

all_side_inputs = frozenset(get_all_side_inputs())
data_side_inputs_by_producing_stage = {}

producing_stages_by_pcoll = {}

for s in stages:
data_side_inputs_by_producing_stage[s.name] = {}
for transform in s.transforms:
for o in transform.outputs.values():
if o in s.side_inputs():
continue
producing_stages_by_pcoll[o] = s

for side_pc in all_side_inputs:
for consuming_transform in transform_consumers[side_pc]:
if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
continue
producing_stage = producing_stages_by_pcoll[side_pc]
payload = proto_utils.parse_Bytes(
consuming_transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for si_tag in payload.side_inputs:
if consuming_transform.inputs[si_tag] == side_pc:
side_input_id = (consuming_transform.unique_name, si_tag)
data_side_inputs_by_producing_stage[
producing_stage.name][side_input_id] = (
translations.create_buffer_id(side_pc),
payload.side_inputs[si_tag].access_pattern)

return data_side_inputs_by_producing_stage
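
# Illustrative sketch (not part of the commit): a hypothetical shape of the
# map returned above, for a stage 'stage1' whose output PCollection 'pc1' is
# read by transform 'ParDo(MyFn)' as an iterable side input under tag 'side0'
# (all names invented):
_iterable_access_pattern = beam_runner_api_pb2.FunctionSpec(
    urn=common_urns.side_inputs.ITERABLE.urn)
_example_side_input_descriptors = {
    'stage1': {
        ('ParDo(MyFn)', 'side0'): (
            translations.create_buffer_id('pc1'),
            _iterable_access_pattern),
    },
    'stage2': {},  # a stage that feeds no side inputs maps to an empty dict
}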

@property
def state_servicer(self):
# TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
@@ -301,6 +379,43 @@ def _iterable_state_write(self, values, element_coder_impl):
out.get())
return token
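
# Illustrative sketch (not part of the commit): _iterable_state_write encodes
# each value with the element coder into a byte stream and appends the raw
# bytes under a state key. The underlying encode/decode round trip, standalone
# (assumes only an installed apache_beam):
from apache_beam.coders import VarIntCoder

_impl = VarIntCoder().get_impl()
_out = create_OutputStream()
for _v in (1, 2, 3):
  _impl.encode_to_stream(_v, _out, True)  # nested=True, as for state values
_stream = create_InputStream(_out.get())
_decoded = []
while _stream.size() > 0:
  _decoded.append(_impl.decode_from_stream(_stream, True))
assert _decoded == [1, 2, 3]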

def commit_side_inputs_to_state(
self,
data_side_input, # type: DataSideInput
):
# type: (...) -> None
for (consuming_transform_id, tag), (buffer_id,
func_spec) in data_side_input.items():
_, pcoll_id = split_buffer_id(buffer_id)
value_coder = self.pipeline_context.coders[self.safe_coders[
self.data_channel_coders[pcoll_id]]]
elements_by_window = WindowGroupingBuffer(func_spec, value_coder)
if buffer_id not in self.pcoll_buffers:
self.pcoll_buffers[buffer_id] = ListBuffer(
coder_impl=value_coder.get_impl())
for element_data in self.pcoll_buffers[buffer_id]:
elements_by_window.append(element_data)

if func_spec.urn == common_urns.side_inputs.ITERABLE.urn:
for _, window, elements_data in elements_by_window.encoded_items():
state_key = beam_fn_api_pb2.StateKey(
iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
transform_id=consuming_transform_id,
side_input_id=tag,
window=window))
self.state_servicer.append_raw(state_key, elements_data)
elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn:
for key, window, elements_data in elements_by_window.encoded_items():
state_key = beam_fn_api_pb2.StateKey(
multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
transform_id=consuming_transform_id,
side_input_id=tag,
window=window,
key=key))
self.state_servicer.append_raw(state_key, elements_data)
else:
raise ValueError("Unknown access pattern: '%s'" % func_spec.urn)
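
# Illustrative sketch (not part of the commit): the two StateKey shapes
# written above differ only in the multimap 'key' field. With invented
# transform/tag names and placeholder encoded bytes:
_iterable_key = beam_fn_api_pb2.StateKey(
    iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
        transform_id='ParDo(MyFn)',
        side_input_id='side0',
        window=b'<encoded window>'))
_multimap_key = beam_fn_api_pb2.StateKey(
    multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
        transform_id='ParDo(MyFn)',
        side_input_id='side0',
        window=b'<encoded window>',
        key=b'<encoded user key>'))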


class BundleContextManager(object):

@@ -367,6 +482,64 @@ def _build_process_bundle_descriptor(self):
state_api_service_descriptor=self.state_api_service_descriptor(),
timer_api_service_descriptor=self.data_api_service_descriptor())

def extract_bundle_inputs_and_outputs(self):
# type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

"""Returns maps of transform names to PCollection identifiers.
Also mutates IO stages to point to the data ApiServiceDescriptor.
Returns:
A tuple of (data_input, data_output, expected_timer_output) dictionaries.
`data_input` is a dictionary mapping (transform_name, output_name) to a
PCollection buffer; `data_output` is a dictionary mapping
(transform_name, output_name) to a PCollection ID.
`expected_timer_output` is a dictionary mapping transform_id and
timer family ID to a buffer id for timers.
"""
data_input = {} # type: Dict[str, PartitionableBuffer]
data_output = {} # type: DataOutput
# A mapping of {(transform_id, timer_family_id) : buffer_id}
expected_timer_output = {} # type: Dict[Tuple[str, str], str]
for transform in self.stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
coder_id = self.execution_context.data_channel_coders[only_element(
transform.outputs.values())]
coder = self.execution_context.pipeline_context.coders[
self.execution_context.safe_coders.get(coder_id, coder_id)]
if pcoll_id == translations.IMPULSE_BUFFER:
data_input[transform.unique_name] = ListBuffer(
coder_impl=coder.get_impl())
data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
else:
if pcoll_id not in self.execution_context.pcoll_buffers:
self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
coder_impl=coder.get_impl())
data_input[transform.unique_name] = (
self.execution_context.pcoll_buffers[pcoll_id])
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
coder_id = self.execution_context.data_channel_coders[only_element(
transform.inputs.values())]
else:
raise NotImplementedError
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
data_api_service_descriptor = self.data_api_service_descriptor()
if data_api_service_descriptor:
data_spec.api_service_descriptor.url = (
data_api_service_descriptor.url)
transform.spec.payload = data_spec.SerializeToString()
elif transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for timer_family_id in payload.timer_family_specs.keys():
expected_timer_output[(transform.unique_name, timer_family_id)] = (
create_buffer_id(timer_family_id, 'timers'))
return data_input, data_output, expected_timer_output
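
# Illustrative sketch (not part of the commit): a hypothetical result of the
# method above for a stage that reads an impulse, writes one PCollection, and
# declares one timer family (all names invented):
#
#   data_input  == {'MyStage/Impulse': <ListBuffer holding ENCODED_IMPULSE_VALUE>}
#   data_output == {'MyStage/Write': <buffer id of the output PCollection>}
#   expected_timer_output == {
#       ('ParDo(MyFn)', 'my-timer'): create_buffer_id('my-timer', 'timers')}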

def get_input_coder_impl(self, transform_id):
# type: (str) -> CoderImpl
coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(
