Skip to content

Commit

Permalink
Merge pull request apache#10114 Cleanup: move direct runner test to c…
Browse files Browse the repository at this point in the history
…orrect location.
  • Loading branch information
robertwb authored Nov 15, 2019
2 parents e35e6b8 + 8f86340 commit 8c3af8a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 83 deletions.
83 changes: 0 additions & 83 deletions sdks/python/apache_beam/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import unittest
from builtins import object
from builtins import range
from collections import defaultdict

import mock

Expand All @@ -40,9 +39,6 @@
from apache_beam.pvalue import AsSingleton
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.runners.direct.evaluation_context import _ExecutionContext
from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
Expand Down Expand Up @@ -812,84 +808,5 @@ def expand(self, p):
p.transforms_stack[0])


class DirectRunnerRetryTests(unittest.TestCase):

def test_retry_fork_graph(self):
# TODO(BEAM-3642): The FnApiRunner currently does not currently support
# retries.
p = beam.Pipeline(runner='BundleBasedDirectRunner')

# TODO(mariagh): Remove the use of globals from the test.
global count_b, count_c # pylint: disable=global-variable-undefined
count_b, count_c = 0, 0

def f_b(x):
global count_b # pylint: disable=global-variable-undefined
count_b += 1
raise Exception('exception in f_b')

def f_c(x):
global count_c # pylint: disable=global-variable-undefined
count_c += 1
raise Exception('exception in f_c')

names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])

fork_b = names | 'SendToB' >> beam.Map(f_b) # pylint: disable=unused-variable
fork_c = names | 'SendToC' >> beam.Map(f_c) # pylint: disable=unused-variable

with self.assertRaises(Exception):
p.run().wait_until_finish()
assert count_b == count_c == 4

def test_no_partial_writeouts(self):

class TestTransformEvaluator(_TransformEvaluator):

def __init__(self):
self._execution_context = _ExecutionContext(None, {})

def start_bundle(self):
self.step_context = self._execution_context.get_step_context()

def process_element(self, element):
k, v = element
state = self.step_context.get_keyed_state(k)
state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)

# Create instance and add key/value, key/value2
evaluator = TestTransformEvaluator()
evaluator.start_bundle()
self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))

evaluator.process_element(['key', 'value'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value']}})

evaluator.process_element(['key', 'value2'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value', 'value2']}})

# Simulate an exception (redo key/value)
evaluator._execution_context.reset()
evaluator.start_bundle()
evaluator.process_element(['key', 'value'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value']}})


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
83 changes: 83 additions & 0 deletions sdks/python/apache_beam/runners/direct/direct_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import threading
import unittest
from collections import defaultdict

import hamcrest as hc

Expand All @@ -33,6 +34,9 @@
from apache_beam.runners import DirectRunner
from apache_beam.runners import TestDirectRunner
from apache_beam.runners import create_runner
from apache_beam.runners.direct.evaluation_context import _ExecutionContext
from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
Expand Down Expand Up @@ -129,5 +133,84 @@ def test_type_hints(self):
| beam.combiners.Count.Globally())


class DirectRunnerRetryTests(unittest.TestCase):

def test_retry_fork_graph(self):
# TODO(BEAM-3642): The FnApiRunner currently does not currently support
# retries.
p = beam.Pipeline(runner='BundleBasedDirectRunner')

# TODO(mariagh): Remove the use of globals from the test.
global count_b, count_c # pylint: disable=global-variable-undefined
count_b, count_c = 0, 0

def f_b(x):
global count_b # pylint: disable=global-variable-undefined
count_b += 1
raise Exception('exception in f_b')

def f_c(x):
global count_c # pylint: disable=global-variable-undefined
count_c += 1
raise Exception('exception in f_c')

names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])

fork_b = names | 'SendToB' >> beam.Map(f_b) # pylint: disable=unused-variable
fork_c = names | 'SendToC' >> beam.Map(f_c) # pylint: disable=unused-variable

with self.assertRaises(Exception):
p.run().wait_until_finish()
assert count_b == count_c == 4

def test_no_partial_writeouts(self):

class TestTransformEvaluator(_TransformEvaluator):

def __init__(self):
self._execution_context = _ExecutionContext(None, {})

def start_bundle(self):
self.step_context = self._execution_context.get_step_context()

def process_element(self, element):
k, v = element
state = self.step_context.get_keyed_state(k)
state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)

# Create instance and add key/value, key/value2
evaluator = TestTransformEvaluator()
evaluator.start_bundle()
self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))

evaluator.process_element(['key', 'value'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value']}})

evaluator.process_element(['key', 'value2'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value', 'value2']}})

# Simulate an exception (redo key/value)
evaluator._execution_context.reset()
evaluator.start_bundle()
evaluator.process_element(['key', 'value'])
self.assertEqual(
evaluator.step_context.existing_keyed_state['key'].state,
defaultdict(lambda: defaultdict(list)))
self.assertEqual(
evaluator.step_context.partial_keyed_state['key'].state,
{None: {'elements':['value']}})


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()

0 comments on commit 8c3af8a

Please sign in to comment.