Making side inputs work
pabloem committed Apr 18, 2020
1 parent 049f45d commit af74e21
Showing 3 changed files with 33 additions and 64 deletions.
@@ -304,32 +304,42 @@ def __init__(self,
   @staticmethod
   def _build_data_side_inputs_map(stages):
     # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
 
     """Builds an index mapping stages to side input descriptors.
+
     A side input descriptor is a map of side input IDs to side input access
     patterns for all of the outputs of a stage that will be consumed as a
     side input.
     """
-    data_side_inputs_by_stage = {}
-    all_side_inputs, stage_consumers, transform_consumers = (
+    data_side_inputs_by_producing_stage = {}
+    all_side_inputs, _, transform_consumers = (
         translations.get_all_side_inputs_and_consumers(stages))
 
+    producing_stages_by_pcoll = {}
+
     for s in stages:
-      data_side_inputs_by_stage[s.name] = {}
+      data_side_inputs_by_producing_stage[s.name] = {}
       for transform in s.transforms:
-        for output in transform.outputs.values():
-          if output in all_side_inputs:
-            for consuming_transform in transform_consumers[output]:
-              payload = proto_utils.parse_Bytes(consuming_transform.spec.payload,
-                                                beam_runner_api_pb2.ParDoPayload)
-              for si_tag in payload.side_inputs:
-                side_input_id = (consuming_transform.unique_name, si_tag)
-                data_side_inputs_by_stage[s.name][side_input_id] = (
-                    consuming_transform.inputs[si_tag],
-                    payload.side_inputs[si_tag].access_pattern)
-
-    return data_side_inputs_by_stage
-
+        for o in transform.outputs.values():
+          producing_stages_by_pcoll[o] = s
+
+    for side_pc in all_side_inputs:
+      for consuming_transform in transform_consumers[side_pc]:
+        if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
+          continue
+        producing_stage = producing_stages_by_pcoll[side_pc]
+        payload = proto_utils.parse_Bytes(
+            consuming_transform.spec.payload,
+            beam_runner_api_pb2.ParDoPayload)
+        for si_tag in payload.side_inputs:
+          if consuming_transform.inputs[si_tag] == side_pc:
+            side_input_id = (consuming_transform.unique_name, si_tag)
+            data_side_inputs_by_producing_stage[
+                producing_stage.name][side_input_id] = (
+                    translations.create_buffer_id(side_pc),
+                    payload.side_inputs[si_tag].access_pattern)
+
+    return data_side_inputs_by_producing_stage
 
   @property
   def state_servicer(self):
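Note on the rewritten index above: _build_data_side_inputs_map now keys the descriptors by the producing stage's name, and stores a buffer id (via translations.create_buffer_id) instead of the consumer's raw input PCollection id. Below is a minimal sketch of the resulting shape, assuming buffer ids follow the b'<kind>:<pcollection_id>' convention of create_buffer_id in translations.py; the stage and transform names are hypothetical, and the URN string stands in for the access_pattern FunctionSpec whose urn it is:

    # Sketch of the returned mapping's shape; not runner code.
    def create_buffer_id(name, kind='materialize'):
      # type: (str, str) -> bytes
      # Assumed behavior, mirroring translations.create_buffer_id.
      return ('%s:%s' % (kind, name)).encode('utf-8')

    side_input_descriptors_by_stage = {
        'stage_that_produces_side_pcoll': {
            # key: (consuming ParDo's unique_name, side input tag)
            ('MyParDo', 'side0'): (
                create_buffer_id('side_pcoll'),  # where the data is buffered
                'beam:side_input:iterable:v1'),  # access pattern (urn shown)
        },
    }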
@@ -43,7 +43,6 @@
 from typing import Tuple
 from typing import TypeVar
 
-import apache_beam as beam  # pylint: disable=ungrouped-imports
 from apache_beam.coders.coder_impl import create_OutputStream
 from apache_beam.metrics import metric
 from apache_beam.metrics import monitoring_infos
@@ -59,10 +58,8 @@
 from apache_beam.runners.portability.fn_api_runner import execution
 from apache_beam.runners.portability.fn_api_runner import translations
 from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
-from apache_beam.runners.portability.fn_api_runner.execution import WindowGroupingBuffer
 from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
 from apache_beam.runners.portability.fn_api_runner.translations import only_element
-from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
 from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager
 from apache_beam.transforms import environments
 from apache_beam.utils import profiler
@@ -85,6 +82,7 @@
 
 # This module is experimental. No backwards-compatibility guarantees.
 
+
 class FnApiRunner(runner.PipelineRunner):
 
   def __init__(
@@ -338,45 +336,6 @@ def run_stages(self,
     worker_handler_manager.close_all()
     return RunnerResult(runner.PipelineState.DONE, monitoring_infos_by_stage)
 
-  def _store_side_inputs_in_state(self,
-                                  runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
-                                  data_side_input,  # type: DataSideInput
-                                 ):
-    # type: (...) -> None
-    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
-      _, pcoll_id = split_buffer_id(buffer_id)
-      value_coder = runner_execution_context.pipeline_context.coders[
-          runner_execution_context.safe_coders[
-              runner_execution_context.data_channel_coders[pcoll_id]]]
-      elements_by_window = WindowGroupingBuffer(si, value_coder)
-      if buffer_id not in runner_execution_context.pcoll_buffers:
-        runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
-            coder_impl=value_coder.get_impl())
-      for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
-        elements_by_window.append(element_data)
-
-      if si.urn == common_urns.side_inputs.ITERABLE.urn:
-        for _, window, elements_data in elements_by_window.encoded_items():
-          state_key = beam_fn_api_pb2.StateKey(
-              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
-                  transform_id=transform_id, side_input_id=tag, window=window))
-          (
-              runner_execution_context.worker_handler_manager.state_servicer.
-              append_raw(state_key, elements_data))
-      elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
-        for key, window, elements_data in elements_by_window.encoded_items():
-          state_key = beam_fn_api_pb2.StateKey(
-              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
-                  transform_id=transform_id,
-                  side_input_id=tag,
-                  window=window,
-                  key=key))
-          (
-              runner_execution_context.worker_handler_manager.state_servicer.
-              append_raw(state_key, elements_data))
-      else:
-        raise ValueError("Unknown access pattern: '%s'" % si.urn)
-
   def _run_bundle_multiple_times_for_testing(
       self,
       runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
@@ -559,8 +518,9 @@ def merge_results(last_result):
 
     # Store the required downstream side inputs into state so they are
    # accessible to the worker when it runs bundles that consume this
    # stage's output.
-    data_side_input = runner_execution_context.side_input_descriptors_by_stage[
-        bundle_context_manager.stage]
+    data_side_input = (
+        runner_execution_context.side_input_descriptors_by_stage.get(
+            bundle_context_manager.stage.name, {}))
     runner_execution_context.commit_side_inputs_to_state(data_side_input)
 
     return final_result
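With _store_side_inputs_in_state deleted, its work moves behind runner_execution_context.commit_side_inputs_to_state(...); judging from the WindowGroupingBuffer and split_buffer_id imports dropped above, that logic presumably now lives on execution.FnApiRunnerExecutionContext. The .get(..., {}) lookup also tolerates stages that produce no side inputs. As a condensed restatement of what the call must still do, taken only from the removed helper above (pcoll_buffers, value_coder, and state_servicer stand in for the corresponding members of the execution context):

    # Condensed restatement of the removed helper; not the committed code.
    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
      # Group this side input's buffered, encoded elements by window.
      elements_by_window = WindowGroupingBuffer(si, value_coder)
      for element_data in pcoll_buffers[buffer_id]:
        elements_by_window.append(element_data)
      # Append each window's bytes under the StateKey variant matching
      # the side input's access pattern URN.
      for key, window, elements_data in elements_by_window.encoded_items():
        if si.urn == common_urns.side_inputs.ITERABLE.urn:
          state_key = beam_fn_api_pb2.StateKey(
              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                  transform_id=transform_id, side_input_id=tag, window=window))
        elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
          state_key = beam_fn_api_pb2.StateKey(
              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                  transform_id=transform_id,
                  side_input_id=tag,
                  window=window,
                  key=key))
        else:
          raise ValueError("Unknown access pattern: '%s'" % si.urn)
        state_servicer.append_raw(state_key, elements_data)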
@@ -26,7 +26,6 @@
 import functools
 import logging
 from builtins import object
-from typing import Any
 from typing import Container
 from typing import DefaultDict
 from typing import Dict
@@ -1044,7 +1043,7 @@ def expand_gbk(stages, pipeline_context):
                       urn=bundle_processor.DATA_OUTPUT_URN,
                       payload=grouping_buffer))
           ],
-          downstream_side_inputs={},
+          downstream_side_inputs=frozenset(),
           must_follow=stage.must_follow)
       yield gbk_write
 
@@ -1097,7 +1096,7 @@ def fix_flatten_coders(stages, pipeline_context):
                       urn=bundle_processor.IDENTITY_DOFN_URN),
                   environment_id=transform.environment_id)
           ],
-          downstream_side_inputs={},
+          downstream_side_inputs=frozenset(),
           must_follow=stage.must_follow)
       pcollections[transcoded_pcollection].CopyFrom(pcollections[pcoll_in])
       pcollections[transcoded_pcollection].unique_name = (
@@ -1135,7 +1134,7 @@ def sink_flattens(stages, pipeline_context):
                       urn=bundle_processor.DATA_OUTPUT_URN,
                       payload=buffer_id))
           ],
-          downstream_side_inputs={},
+          downstream_side_inputs=frozenset(),
           must_follow=stage.must_follow)
       flatten_writes.append(flatten_write)
       yield flatten_write
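A note on the three translations.py hunks above: downstream_side_inputs was previously initialized with {}, which is an empty dict, not an empty set. Elsewhere in translations.py the field is combined across stages with set unions during fusion, so frozenset() is the type that operation actually expects; the dict only passed before because it happens to be falsy and iterable. A quick illustration, assuming set-union combination as used by stage fusion:

    # Why frozenset() rather than {}: a dict is the wrong type for set union.
    a = frozenset(['side_pcoll_1'])
    b = frozenset()
    assert a | b == frozenset(['side_pcoll_1'])  # union is well-defined for sets

    # With an empty dict the same combination raises:
    #   {} | frozenset()  ->  TypeError: unsupported operand type(s)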
