Skip to content

Commit

Permalink
Merge pull request apache#11653 from robertwb/split-points
Browse files Browse the repository at this point in the history
[BEAM-9935] Respect allowed split points in Python.
  • Loading branch information
lukecwik authored May 11, 2020
2 parents c541641 + 773854c commit cad0333
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 30 deletions.
3 changes: 3 additions & 0 deletions model/fn-execution/src/main/proto/beam_fn_api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ message ProcessBundleSplitRequest {

// A set of allowed element indices where the SDK may split. When this is
// empty, there are no constraints on where to split.
// Specifically, the first_residual_element of a split result must be an
// allowed split point, and the last_primary_element must immediately
// precede an allowed split point.
repeated int64 allowed_split_points = 3;

// (Required for GrpcRead operations) Number of total elements expected
Expand Down
101 changes: 71 additions & 30 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from __future__ import print_function

import base64
import bisect
import collections
import json
import logging
Expand Down Expand Up @@ -216,15 +217,12 @@ def process_encoded(self, encoded_windowed_values):
input_stream, True)
self.output(decoded_value)

def try_split(self, fraction_of_remainder, total_buffer_size):
def try_split(
self, fraction_of_remainder, total_buffer_size, allowed_split_points):
# type: (...) -> Optional[Tuple[int, Optional[operations.SdfSplitResultsPrimary], Optional[operations.SdfSplitResultsResidual], int]]
with self.splitting_lock:
if not self.started:
return None
if total_buffer_size < self.index + 1:
total_buffer_size = self.index + 1
elif self.stop and total_buffer_size > self.stop:
total_buffer_size = self.stop
if self.index == -1:
# We are "finished" with the (non-existent) previous element.
current_element_progress = 1.0
Expand All @@ -237,30 +235,72 @@ def try_split(self, fraction_of_remainder, total_buffer_size):
current_element_progress = (
current_element_progress_object.fraction_completed)
# Now figure out where to split.
# The units here (except for keep_of_element_remainder) are all in
# terms of number of (possibly fractional) elements.
remainder = total_buffer_size - self.index - current_element_progress
keep = remainder * fraction_of_remainder
if current_element_progress < 1:
keep_of_element_remainder = keep / (1 - current_element_progress)
# If it's less than what's left of the current element,
# try splitting at the current element.
if keep_of_element_remainder < 1:
split = self.receivers[0].try_split(
keep_of_element_remainder
) # type: Optional[Tuple[operations.SdfSplitResultsPrimary, operations.SdfSplitResultsResidual]]
if split:
element_primary, element_residual = split
self.stop = self.index + 1
return self.index - 1, element_primary, element_residual, self.stop
# Otherwise, split at the closest element boundary.
# pylint: disable=round-builtin
stop_index = (
self.index + max(1, int(round(current_element_progress + keep))))
if stop_index < self.stop:
self.stop = stop_index
return self.stop - 1, None, None, self.stop
return None
split = self._compute_split(
self.index,
current_element_progress,
self.stop,
fraction_of_remainder,
total_buffer_size,
allowed_split_points,
self.receivers[0].try_split)
if split:
self.stop = split[-1]
return split

@staticmethod
def _compute_split(
        index,
        current_element_progress,
        stop,
        fraction_of_remainder,
        total_buffer_size,
        allowed_split_points=(),
        try_split=lambda fraction: None):
    """Decides where the input should be split, without mutating any state.

    Args:
      index: index of the element currently being processed.
      current_element_progress: fraction of the current element that has
        already been completed.
      stop: current stop index; a boundary split is only returned when it
        falls strictly below this value.
      fraction_of_remainder: fraction of the remaining work to keep.
      total_buffer_size: estimated total number of elements; it is clamped
        to lie between index + 1 and stop before use.
      allowed_split_points: when non-empty, the only indices that may serve
        as the first residual index of a split.
      try_split: callable taking the fraction of the current element's
        remainder to keep and returning a (primary, residual) pair, or a
        falsy value when the element cannot be split.

    Returns:
      A tuple (last primary index, element primary, element residual,
      first residual index), or None when no acceptable split exists.
    """
    def may_split_at(position):
        return not allowed_split_points or position in allowed_split_points

    # Clamp the size estimate to what has been seen and what is still ours.
    smallest_size = index + 1
    if total_buffer_size < smallest_size:
        effective_size = smallest_size
    else:
        effective_size = min(total_buffer_size, stop)

    # The units here (except for keep_of_element_remainder) are all in
    # terms of number of (possibly fractional) elements.
    remainder = effective_size - index - current_element_progress
    keep = remainder * fraction_of_remainder

    if current_element_progress < 1:
        keep_of_element_remainder = keep / (1 - current_element_progress)
        # Splitting inside the current element turns it into the last
        # primary and first residual at once, so both of its boundaries
        # must be allowed split points.
        inside_current_element = (
            keep_of_element_remainder < 1 and may_split_at(index) and
            may_split_at(index + 1))
        element_result = (
            try_split(keep_of_element_remainder)
            if inside_current_element else None
        )  # type: Optional[Tuple[operations.SdfSplitResultsPrimary, operations.SdfSplitResultsResidual]]
        if element_result:
            element_primary, element_residual = element_result
            return index - 1, element_primary, element_residual, index + 1

    # Otherwise, split at the closest element boundary.
    # pylint: disable=round-builtin
    stop_index = index + max(1, int(round(current_element_progress + keep)))
    if allowed_split_points and stop_index not in allowed_split_points:
        # Snap to the nearest allowed split point.
        candidates = sorted(allowed_split_points)
        position = bisect.bisect(candidates, stop_index)
        if position == 0:
            stop_index = candidates[0]
        elif position == len(candidates):
            stop_index = candidates[-1]
        else:
            below = candidates[position - 1]
            above = candidates[position]
            # The lower neighbour only wins when it lies ahead of the current
            # element and is strictly closer; ties go to the upper neighbour.
            lower_is_better = (
                index < below and stop_index - below < above - stop_index)
            stop_index = below if lower_is_better else above

    if not index < stop_index < stop:
        return None
    return stop_index - 1, None, None, stop_index

def finish(self):
# type: () -> None
Expand Down Expand Up @@ -955,7 +995,8 @@ def try_split(self, bundle_split_request):
if desired_split:
split = op.try_split(
desired_split.fraction_of_remainder,
desired_split.estimated_input_elements)
desired_split.estimated_input_elements,
desired_split.allowed_split_points)
if split:
(
primary_end,
Expand Down
142 changes: 142 additions & 0 deletions sdks/python/apache_beam/runners/worker/bundle_processor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for bundle processing."""
# pytype: skip-file

from __future__ import absolute_import

import unittest

from apache_beam.runners.worker.bundle_processor import DataInputOperation


def simple_split(first_residual_index):
    """Builds the result expected for a split on an element boundary.

    The element immediately before first_residual_index is the last primary
    element, and no element-level primary or residual is produced.
    """
    last_primary_index = first_residual_index - 1
    return last_primary_index, None, None, first_residual_index


def element_split(frac, index):
    """Builds the result expected when the element at index is itself split.

    frac is the portion of the element kept in the primary; the residual
    marker carries the complementary portion. The neighbouring indices
    index - 1 and index + 1 bracket the split element.
    """
    primary_marker = 'Primary(%0.1f)' % frac
    residual_marker = 'Residual(%0.1f)' % (1 - frac)
    return index - 1, primary_marker, residual_marker, index + 1


class SplitTest(unittest.TestCase):
    """Checks where DataInputOperation._compute_split places a split."""

    def split(
            self,
            index,
            current_element_progress,
            fraction_of_remainder,
            buffer_size,
            allowed=(),
            sdf=False):
        """Runs _compute_split with no upper bound on the stop index.

        When sdf is true the current element is treated as splittable: the
        element-level callback answers with the primary/residual markers built
        by element_split. Otherwise the callback always declines.
        """
        def split_current_element(frac):
            if not sdf:
                return None
            return element_split(frac, 0)[1:3]

        return DataInputOperation._compute_split(
            index,
            current_element_progress,
            float('inf'),
            fraction_of_remainder,
            buffer_size,
            allowed_split_points=allowed,
            try_split=split_current_element)

    def sdf_split(self, *args, **kwargs):
        """Same as split(), but with a splittable current element."""
        return self.split(*args, sdf=True, **kwargs)

    def test_simple_split(self):
        cases = [
            # Split as close to the beginning as possible.
            ((0, 0, 0, 16), 1),
            # The closest split is at 4, even when just above or below it.
            ((0, 0, 0.24, 16), 4),
            ((0, 0, 0.25, 16), 4),
            ((0, 0, 0.26, 16), 4),
            # Split the *remainder* in half.
            ((0, 0, 0.5, 16), 8),
            ((2, 0, 0.5, 16), 9),
            ((6, 0, 0.5, 16), 11),
        ]
        for args, first_residual in cases:
            self.assertEqual(self.split(*args), simple_split(first_residual))

    def test_split_with_element_progress(self):
        # Progress into the active element influences where the split of the
        # remainder falls.
        cases = [
            ((0, 0.5, 0.25, 4), 1),
            ((0, 0.9, 0.25, 4), 2),
            ((1, 0.0, 0.25, 4), 2),
            ((1, 0.1, 0.25, 4), 2),
        ]
        for args, first_residual in cases:
            self.assertEqual(self.split(*args), simple_split(first_residual))

    def test_split_with_element_allowed_splits(self):
        # With these arguments the desired split point is at 4.
        args = (0, 0, 0.25, 16)
        cases = [
            # 4 itself is allowed.
            ((2, 3, 4, 5), 4),
            # If we can't split at 4, choose the closest possible split point.
            ((2, 3, 5), 5),
            ((2, 3, 6), 3),
            # Also test the case where all possible split points lie above or
            # below the desired split point.
            ((5, 6, 7), 5),
            ((1, 2, 3), 3),
        ]
        for allowed, first_residual in cases:
            self.assertEqual(
                self.split(*args, allowed=allowed),
                simple_split(first_residual))

        # We have progressed beyond all possible split points, so can't split.
        self.assertEqual(self.split(5, 0, 0.25, 16, allowed=(1, 2, 3)), None)

    def test_sdf_split(self):
        cases = [
            # Split between future elements at element boundaries.
            ((0, 0, 0.51, 4), simple_split(2)),
            ((0, 0, 0.49, 4), simple_split(2)),
            ((0, 0, 0.26, 4), simple_split(1)),
            ((0, 0, 0.25, 4), simple_split(1)),
            # If the split falls inside the first, splittable element, split
            # there.
            ((0, 0, 0.20, 4), (-1, 'Primary(0.8)', 'Residual(0.2)', 1)),
            # The choice of split depends on the progress into the first
            # element.
            ((0, 0, .125, 4), (-1, 'Primary(0.5)', 'Residual(0.5)', 1)),
            # Here we are far enough into the first element that splitting at
            # 0.2 of the remainder falls outside the first element.
            ((0, .5, 0.2, 4), simple_split(1)),
            # Verify the above logic when we are partially through the stream.
            ((2, 0, 0.6, 4), simple_split(3)),
            ((2, 0.9, 0.6, 4), simple_split(4)),
            ((2, 0.5, 0.2, 4), (1, 'Primary(0.6)', 'Residual(0.4)', 3)),
        ]
        for args, expected in cases:
            self.assertEqual(self.sdf_split(*args), expected)

    def test_sdf_split_with_allowed_splits(self):
        args = (2, 0, 0.2, 5)
        cases = [
            # This is where we would like to split, when all split points are
            # available.
            ((1, 2, 3, 4, 5), (1, 'Primary(0.6)', 'Residual(0.4)', 3)),
            # We can't split element at index 2, because 3 is not a split
            # point.
            ((1, 2, 4, 5), simple_split(4)),
            # We can't fall back to a boundary split at 4 as above either,
            # because 4 is also not a split point.
            ((1, 2, 5), simple_split(5)),
            # We can't split element at index 2, because 2 is not a split
            # point.
            ((1, 3, 4, 5), simple_split(3)),
        ]
        for allowed, expected in cases:
            self.assertEqual(self.sdf_split(*args, allowed=allowed), expected)


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit cad0333

Please sign in to comment.