[BEAM-9639][BEAM-9608] Improvements for FnApiRunner #11270

Merged: 4 commits, Apr 21, 2020
179 changes: 176 additions & 3 deletions sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -22,17 +22,29 @@
import collections
import itertools
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Set
from typing import Tuple

from typing_extensions import Protocol

from apache_beam import coders
from apache_beam.coders import BytesCoder
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.coders.coders import GlobalWindowCoder
from apache_beam.coders.coders import WindowedValueCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
@@ -45,8 +57,13 @@

if TYPE_CHECKING:
  from apache_beam.coders.coder_impl import CoderImpl
-  from apache_beam.runners.portability.fn_api_runner import translations
  from apache_beam.runners.portability.fn_api_runner import worker_handlers
  from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
  from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
    BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
        GlobalWindows.windowed_value(b''))
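For illustration (not part of the diff): a minimal round-trip sketch, assuming only a local Apache Beam install, showing that the impulse payload decodes back to an empty bytestring in the global window.

    from apache_beam.coders import BytesCoder
    from apache_beam.coders.coders import GlobalWindowCoder
    from apache_beam.coders.coders import WindowedValueCoder
    from apache_beam.transforms.window import GlobalWindow
    from apache_beam.transforms.window import GlobalWindows

    # Same coder stack as ENCODED_IMPULSE_VALUE above.
    impulse_coder_impl = WindowedValueCoder(
        BytesCoder(), GlobalWindowCoder()).get_impl()
    encoded = impulse_coder_impl.encode_nested(GlobalWindows.windowed_value(b''))
    decoded = impulse_coder_impl.decode_nested(encoded)
    assert decoded.value == b''  # the impulse element is an empty bytestring
    assert decoded.windows == (GlobalWindow(), )  # ...in the global window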


class Buffer(Protocol):
@@ -204,7 +221,7 @@ class WindowGroupingBuffer(object):
  def __init__(
      self,
      access_pattern,
-      coder  # type: coders.WindowedValueCoder
+      coder  # type: WindowedValueCoder
  ):
    # type: (...) -> None
    # Here's where we would use a different type of partitioning
@@ -251,11 +268,12 @@ def encoded_items(self):

class FnApiRunnerExecutionContext(object):
  """
-  :var pcoll_buffers: (collections.defaultdict of str: list): Mapping of
+  :var pcoll_buffers: (dict): Mapping of
    PCollection IDs to lists that function as buffers for the
    ``beam.PCollection``.
  """
  def __init__(self,
               stages,  # type: List[translations.Stage]
               worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
               pipeline_components,  # type: beam_runner_api_pb2.Components
               safe_coders,
@@ -268,6 +286,9 @@ def __init__(self,
    :param safe_coders:
    :param data_channel_coders:
    """
    self.stages = stages
    self.side_input_descriptors_by_stage = (
        self._build_data_side_inputs_map(stages))
    self.pcoll_buffers = {}  # type: MutableMapping[bytes, PartitionableBuffer]
    self.timer_buffers = {}  # type: MutableMapping[bytes, ListBuffer]
    self.worker_handler_manager = worker_handler_manager
@@ -280,6 +301,63 @@ def __init__(self,
        iterable_state_write=self._iterable_state_write)
    self._last_uid = -1

  @staticmethod
  def _build_data_side_inputs_map(stages):
    # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]

    """Builds an index mapping stages to side input descriptors.

    A side input descriptor is a map of side input IDs to side input access
    patterns for all of the outputs of a stage that will be consumed as a
    side input.
    """
    transform_consumers = collections.defaultdict(
        list)  # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
    stage_consumers = collections.defaultdict(
        list)  # type: DefaultDict[str, List[translations.Stage]]

    def get_all_side_inputs():
      # type: () -> Set[str]
      all_side_inputs = set()  # type: Set[str]
      for stage in stages:
        for transform in stage.transforms:
          for input in transform.inputs.values():
            transform_consumers[input].append(transform)
            stage_consumers[input].append(stage)
        for si in stage.side_inputs():
          all_side_inputs.add(si)
      return all_side_inputs

    all_side_inputs = frozenset(get_all_side_inputs())
    data_side_inputs_by_producing_stage = {}

    producing_stages_by_pcoll = {}

    for s in stages:
      data_side_inputs_by_producing_stage[s.name] = {}
      for transform in s.transforms:
        for o in transform.outputs.values():
          if o in s.side_inputs():
            continue
          producing_stages_by_pcoll[o] = s

    for side_pc in all_side_inputs:
      for consuming_transform in transform_consumers[side_pc]:
        if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
          continue
        producing_stage = producing_stages_by_pcoll[side_pc]
        payload = proto_utils.parse_Bytes(
            consuming_transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for si_tag in payload.side_inputs:
          if consuming_transform.inputs[si_tag] == side_pc:
            side_input_id = (consuming_transform.unique_name, si_tag)
            data_side_inputs_by_producing_stage[
                producing_stage.name][side_input_id] = (
                    translations.create_buffer_id(side_pc),
                    payload.side_inputs[si_tag].access_pattern)

    return data_side_inputs_by_producing_stage
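Not in the diff, but as a concrete picture of the index this returns: a hypothetical result for a two-stage pipeline (stage, transform, and PCollection names invented; the buffer-id format follows translations.create_buffer_id).

    from apache_beam.portability import common_urns

    side_input_descriptors_by_stage = {
        'stage_producing_a_side_input': {
            # (consumer transform name, side input tag) ->
            #     (buffer id, access pattern URN)
            ('ParDo(MyDoFn)', 'side0'): (
                b'materialize:pcoll_side',
                common_urns.side_inputs.ITERABLE.urn),
        },
        # Every stage gets an entry, even when it produces no side inputs.
        'stage_without_side_outputs': {},
    }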

  @property
  def state_servicer(self):
    # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
@@ -301,6 +379,43 @@ def _iterable_state_write(self, values, element_coder_impl):
        out.get())
    return token

  def commit_side_inputs_to_state(
      self,
      data_side_input,  # type: DataSideInput
  ):
    # type: (...) -> None
    for (consuming_transform_id, tag), (buffer_id,
                                        func_spec) in data_side_input.items():
      _, pcoll_id = split_buffer_id(buffer_id)
      value_coder = self.pipeline_context.coders[self.safe_coders[
          self.data_channel_coders[pcoll_id]]]
      elements_by_window = WindowGroupingBuffer(func_spec, value_coder)
      if buffer_id not in self.pcoll_buffers:
        self.pcoll_buffers[buffer_id] = ListBuffer(
            coder_impl=value_coder.get_impl())
      for element_data in self.pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)

      if func_spec.urn == common_urns.side_inputs.ITERABLE.urn:
        for _, window, elements_data in elements_by_window.encoded_items():
          state_key = beam_fn_api_pb2.StateKey(
              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                  transform_id=consuming_transform_id,
                  side_input_id=tag,
                  window=window))
          self.state_servicer.append_raw(state_key, elements_data)
      elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn:
        for key, window, elements_data in elements_by_window.encoded_items():
          state_key = beam_fn_api_pb2.StateKey(
              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                  transform_id=consuming_transform_id,
                  side_input_id=tag,
                  window=window,
                  key=key))
          self.state_servicer.append_raw(state_key, elements_data)
      else:
        raise ValueError("Unknown access pattern: '%s'" % func_spec.urn)

Review thread on the wrapped tuple unpacking above:

Contributor: I wonder if

    ((consuming_transform_id, tag), (buffer_id, func_spec))

would make both yapf and humans happy.

Member (author): it did not :( hehe

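For reference (not in the diff), the two access-pattern URNs this dispatch distinguishes, with their values from the Beam portability protos:

    from apache_beam.portability import common_urns

    assert common_urns.side_inputs.ITERABLE.urn == 'beam:side_input:iterable:v1'
    assert common_urns.side_inputs.MULTIMAP.urn == 'beam:side_input:multimap:v1'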

class BundleContextManager(object):

@@ -367,6 +482,64 @@ def _build_process_bundle_descriptor(self):
        state_api_service_descriptor=self.state_api_service_descriptor(),
        timer_api_service_descriptor=self.data_api_service_descriptor())

  def extract_bundle_inputs_and_outputs(self):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]

    """Returns maps of transform names to PCollection identifiers.

    Also mutates IO stages to point to the data ApiServiceDescriptor.

    Returns:
      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
        `data_input` is a dictionary mapping a transform name to a
        PCollection buffer; `data_output` is a dictionary mapping a
        transform name to a PCollection ID.
        `expected_timer_output` is a dictionary mapping a
        (transform name, timer family ID) pair to a buffer id for timers.
    """
    data_input = {}  # type: Dict[str, PartitionableBuffer]
    data_output = {}  # type: DataOutput
    # A mapping of {(transform_id, timer_family_id): buffer_id}
    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
    for transform in self.stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.outputs.values())]
          coder = self.execution_context.pipeline_context.coders[
              self.execution_context.safe_coders.get(coder_id, coder_id)]
          if pcoll_id == translations.IMPULSE_BUFFER:
            data_input[transform.unique_name] = ListBuffer(
                coder_impl=coder.get_impl())
            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
          else:
            if pcoll_id not in self.execution_context.pcoll_buffers:
              self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
                  coder_impl=coder.get_impl())
            data_input[transform.unique_name] = (
                self.execution_context.pcoll_buffers[pcoll_id])
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          data_output[transform.unique_name] = pcoll_id
          coder_id = self.execution_context.data_channel_coders[only_element(
              transform.inputs.values())]
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        data_api_service_descriptor = self.data_api_service_descriptor()
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn in translations.PAR_DO_URNS:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for timer_family_id in payload.timer_family_specs.keys():
          expected_timer_output[(transform.unique_name, timer_family_id)] = (
              create_buffer_id(timer_family_id, 'timers'))
    return data_input, data_output, expected_timer_output
Review thread on the docstring above:

Contributor: Update docs to match.

Member (author): done
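As a sketch of the buffer-id convention used for timers above (not in the diff; transform and timer names invented, and assuming create_buffer_id's default kind is 'materialize'):

    from apache_beam.runners.portability.fn_api_runner.translations import (
        create_buffer_id)

    # Buffer ids are the kind prefixed onto the PCollection or timer family name.
    assert create_buffer_id('my_timer_family', 'timers') == b'timers:my_timer_family'
    assert create_buffer_id('pcoll_out') == b'materialize:pcoll_out'

    # So a single expected_timer_output entry would look like:
    expected_timer_output = {
        ('MyStage/StatefulParDo', 'my_timer_family'):
            create_buffer_id('my_timer_family', 'timers'),
    }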


  def get_input_coder_impl(self, transform_id):
    # type: (str) -> CoderImpl
    coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(