Commit c5353c4: "Addressing comments"
Committed by pabloem on Apr 15, 2020 (1 parent: 82828b3)
Showing 4 changed files with 46 additions and 15 deletions.
@@ -58,17 +58,13 @@
 if TYPE_CHECKING:
   from apache_beam.coders.coder_impl import CoderImpl
   from apache_beam.runners.portability.fn_api_runner import worker_handlers
+  from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
   from apache_beam.transforms.window import BoundedWindow

 ENCODED_IMPULSE_VALUE = WindowedValueCoder(
     BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
         GlobalWindows.windowed_value(b''))

-DataOutput = Dict[str, bytes]
-
-DataSideInput = Dict[translations.SideInputId,
-                     Tuple[bytes, beam_runner_api_pb2.FunctionSpec]]
-

 class Buffer(Protocol):
   def __iter__(self):
@@ -425,7 +421,7 @@ def _build_process_bundle_descriptor(self):
         state_api_service_descriptor=self.state_api_service_descriptor(),
         timer_api_service_descriptor=self.data_api_service_descriptor())

-  def commit_output_views_to_state(self):
+  def commit_side_inputs_to_state(self):
     """Commit bundle outputs to state to be consumed as side inputs later.
     Only the outputs that should be side inputs are committed to state.
@@ -437,18 +433,20 @@ def commit_output_views_to_state(self):
           translations.create_buffer_id(pcoll), access_pattern)
     self.execution_context.commit_side_inputs_to_state(data_side_input)

-  def extract_bundle_inputs(self):
+  def extract_bundle_inputs_and_outputs(self):
     # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput]

     """Returns maps of transform names to PCollection identifiers.
     Also mutates IO stages to point to the data ApiServiceDescriptor.
     Returns:
-      A tuple of (data_input, data_output) dictionaries.
+      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
       `data_input` is a dictionary mapping (transform_name, output_name) to a
       PCollection buffer; `data_output` is a dictionary mapping
       (transform_name, output_name) to a PCollection ID.
+      `expected_timer_output` is a dictionary mapping transform_id and
+      timer family ID to a buffer id for timers.
     """
     data_input = {}  # type: Dict[str, PartitionableBuffer]
     data_output = {}  # type: DataOutput
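The renamed extract_bundle_inputs_and_outputs (this first file appears to be the fn_api_runner's execution module) now documents a three-part return value. A minimal sketch of that shape only, with made-up keys and simplified stand-in types rather than the runner's real buffers:

from typing import Dict, Tuple

# Illustrative stand-ins; the runner's real types are richer than these.
PartitionableBuffer = list           # stands in for the buffer protocol
DataOutput = Dict[str, bytes]        # transform name -> output buffer id

def sketch_extract_bundle_inputs_and_outputs():
  # type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], bytes]]
  data_input = {'read/Impulse': [b'']}                # transform -> buffer
  data_output = {'write/Write': b'materialize:pc_1'}  # transform -> buffer id
  # (transform_id, timer_family_id) -> buffer id, per the new docstring.
  expected_timer_output = {('stage_1/ParDo', 'timer_family_0'): b'timers:pc_2'}
  return data_input, data_output, expected_timer_output

data_input, data_output, expected_timers = sketch_extract_bundle_inputs_and_outputs()
assert ('stage_1/ParDo', 'timer_family_0') in expected_timers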
@@ -513,7 +513,7 @@ def _run_stage(self,
         the stage to execute, and its context.
     """
     data_input, data_output, expected_timer_output = (
-        bundle_context_manager.extract_bundle_inputs())
+        bundle_context_manager.extract_bundle_inputs_and_outputs())
     input_timers = {}

     worker_handler_manager = runner_execution_context.worker_handler_manager
@@ -564,7 +564,7 @@ def merge_results(last_result):

     # Store the required downstream side inputs into state so it is accessible
     # for the worker when it runs bundles that consume this stage's output.
-    bundle_context_manager.commit_output_views_to_state()
+    bundle_context_manager.commit_side_inputs_to_state()
     return final_result

   def _run_bundle(
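Taken together, the two renames in this file (presumably the fn_runner module) bracket the stage-execution flow: inputs and outputs are extracted up front, and side-input outputs are committed to state at the end. A condensed sketch of that sequence, using a stub in place of Beam's real BundleContextManager:

class BundleContextManagerStub(object):
  """Stub exposing the two renamed entry points that _run_stage relies on."""
  def extract_bundle_inputs_and_outputs(self):
    # Formerly extract_bundle_inputs; now also yields expected timer outputs.
    return {}, {}, {}

  def commit_side_inputs_to_state(self):
    # Formerly commit_output_views_to_state; persists only the outputs that
    # downstream stages will read back as side inputs.
    pass

def run_stage_sketch(bundle_context_manager):
  data_input, data_output, expected_timer_output = (
      bundle_context_manager.extract_bundle_inputs_and_outputs())
  # ... run the bundle(s) for this stage and collect results ...
  bundle_context_manager.commit_side_inputs_to_state()
  return data_output

run_stage_sketch(BundleContextManagerStub())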
@@ -240,6 +240,30 @@ def test_multimap_side_input(self):
             lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
         equal_to([('a', [1, 3]), ('b', [2])]))

+  def test_multimap_multiside_input(self):
+    # A test where two transforms in the same stage consume the same PCollection
+    # twice as side input.
+    with self.create_pipeline() as p:
+      main = p | 'main' >> beam.Create(['a', 'b'])
+      side = (
+          p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
+          # TODO(BEAM-4782): Obviate the need for this map.
+          | beam.Map(lambda kv: (kv[0], kv[1])))
+      assert_that(
+          main | 'first map' >> beam.Map(
+              lambda k,
+              d,
+              l: (k, sorted(d[k]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side))
+          | 'second map' >> beam.Map(
+              lambda k,
+              d,
+              l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side)),
+          equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]))
+
   def test_multimap_side_input_type_coercion(self):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create(['a', 'b'])
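The new test_multimap_multiside_input asserts specific view contents. As a plain-Python cross-check (no Beam needed): AsMultiMap groups the side input by key and AsList exposes every element, so the expected tuples can be derived by hand:

from collections import defaultdict

side = [('a', 1), ('b', 2), ('a', 3)]

# What beam.pvalue.AsMultiMap(side) presents: key -> list of values.
multimap = defaultdict(list)
for k, v in side:
  multimap[k].append(v)

# What beam.pvalue.AsList(side) presents: all (key, value) elements.
results = [(k, sorted(multimap[k]), sorted(v for _, v in side))
           for k in ['a', 'b']]
assert results == [('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]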
@@ -78,14 +78,22 @@

 # SideInputId is identified by a consumer ParDo + tag.
 SideInputId = Tuple[str, str]
+SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec
+
+DataOutput = Dict[str, bytes]
+
+# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side
+# input content, and a payload specification regarding the type of side input
+# (MultiMap / Iterable).
+DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]]


 class Stage(object):
   """A set of Transforms that can be sent to the worker for processing."""
   def __init__(self,
               name,  # type: str
               transforms,  # type: List[beam_runner_api_pb2.PTransform]
-              downstream_side_inputs=None,  # type: Optional[Dict[str, SideInputId]]
+              downstream_side_inputs=None,  # type: Optional[Dict[str, Dict[SideInputId, SideInputAccessPattern]]]
               must_follow=frozenset(),  # type: FrozenSet[Stage]
               parent=None,  # type: Optional[Stage]
               environment=None,  # type: Optional[str]
@@ -158,11 +166,12 @@ def no_overlap(a, b):
         {i
          for i, _, _ in consumer.side_inputs()}))

-  def _fuse_downstream_side_inputs(self, other):
+  def _get_fused_downstream_side_inputs(self, other):
+    # type: (Dict[str, Dict[SideInputId, SideInputAccessPattern]]) -> Dict[str, Dict[SideInputId, SideInputAccessPattern]]
     res = dict(self.downstream_side_inputs)
     for si, other_si_ids in other.downstream_side_inputs.items():
       if si in res:
-        res[si] = union(res[si], other_si_ids)
+        res[si].update(other_si_ids)
       else:
         res[si] = other_si_ids
     return res
@@ -172,7 +181,7 @@ def fuse(self, other):
     return Stage(
         "(%s)+(%s)" % (self.name, other.name),
         self.transforms + other.transforms,
-        self._fuse_downstream_side_inputs(other),
+        self._get_fused_downstream_side_inputs(other),
         union(self.must_follow, other.must_follow),
         environment=self._merge_environments(
             self.environment, other.environment),
@@ -645,7 +654,7 @@ def get_all_side_inputs():
   all_side_inputs = get_all_side_inputs()

   downstream_side_inputs_by_stage = {
-  }  # type: Dict[Stage, DefaultDict[str, Dict[SideInputId, Any]]]
+  }  # type: Dict[Stage, Dict[str, Dict[SideInputId, SideInputAccessPattern]]]

   def compute_downstream_side_inputs(stage):
     # type: (Stage) -> Dict[str, Dict[SideInputId, Any]]
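With downstream_side_inputs now a dict of dicts (PCollection id to a map from SideInputId to its access pattern, as translations.py's new DataSideInput alias suggests), merging during stage fusion becomes a per-key dict update rather than a set union. A standalone sketch of that merge, with strings standing in for the real FunctionSpec access patterns; the sketch copies the inner dicts so the merge cannot mutate its first argument:

# pcoll id -> {(consumer_transform, tag): access_pattern}; strings stand in
# for the beam_runner_api_pb2.FunctionSpec values the runner actually stores.
def get_fused_downstream_side_inputs(mine, other):
  res = {si: dict(si_ids) for si, si_ids in mine.items()}
  for si, other_si_ids in other.items():
    if si in res:
      res[si].update(other_si_ids)  # merge the SideInputId -> pattern maps
    else:
      res[si] = dict(other_si_ids)
  return res

a = {'pc_1': {('ParDo1', 'side0'): 'multimap'}}
b = {'pc_1': {('ParDo2', 'side0'): 'iterable'},
     'pc_2': {('ParDo3', 'side0'): 'iterable'}}
fused = get_fused_downstream_side_inputs(a, b)
assert fused['pc_1'] == {('ParDo1', 'side0'): 'multimap',
                         ('ParDo2', 'side0'): 'iterable'}
assert fused['pc_2'] == {('ParDo3', 'side0'): 'iterable'}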
