[BEAM-10249] Populate state cache with initial values before appending
When new values are added to a bag state, the cache handler in the Python SDK
does not check whether the underlying state already holds values. The append
operation is still sent to the backend, but the cache becomes corrupted: get
operations return only the newly appended values, not the values that were
present before the append.

This is a concern when restoring from a checkpoint, but also when items have
been evicted from the cache.
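
Below is a minimal sketch of the failure mode, using plain dicts as stand-ins for the SDK-side state cache and the runner's state backend (simplified; not Beam's actual StateCache/CachingStateHandler APIs):

```python
# Simplified illustration of the bug described above; the dict-based "cache"
# and "backend" are stand-ins, not Beam code.
backend = {"key": [42]}   # values already present in the runner's state
cache = {}                # SDK-side cache, empty (e.g. after restore or eviction)

def buggy_extend(key, values):
    # Appends to the cache without first loading what the backend already holds.
    cache.setdefault(key, []).extend(values)
    backend.setdefault(key, []).extend(values)

def cached_get(key):
    # Once the key is present in the cache, reads are served from the cache.
    return cache[key] if key in cache else list(backend.get(key, []))

buggy_extend("key", [43])
print(cached_get("key"))  # [43]     -- the pre-existing 42 is lost from cached reads
print(backend["key"])     # [42, 43] -- the backend itself is still correct
```

The change below makes the handler load the backend's current values into the cache before the first append to a key the cache has not seen.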
mxm committed Jun 15, 2020
1 parent 0588260 commit c8eb5ac
Showing 3 changed files with 68 additions and 8 deletions.
16 changes: 8 additions & 8 deletions sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -300,9 +300,9 @@ def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
          # depends on the cache token which is lazily initialized by the
          # Runner's StateRequestHandlers.
          'stateful.beam.metric:statecache:size: 20',
-         'stateful.beam.metric:statecache:get: 10',
+         'stateful.beam.metric:statecache:get: 20',
          'stateful.beam.metric:statecache:miss: 0',
-         'stateful.beam.metric:statecache:hit: 10',
+         'stateful.beam.metric:statecache:hit: 20',
          'stateful.beam.metric:statecache:put: 0',
          'stateful.beam.metric:statecache:extend: 10',
          'stateful.beam.metric:statecache:evict: 0',
@@ -313,9 +313,9 @@ def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
          # initialized by the Runner's StateRequestHandlers).
          # If cross-bundle caching is not requested, caching is done
          # at the bundle level.
-         'stateful.beam.metric:statecache:get_total: 110',
+         'stateful.beam.metric:statecache:get_total: 220',
          'stateful.beam.metric:statecache:miss_total: 20',
-         'stateful.beam.metric:statecache:hit_total: 90',
+         'stateful.beam.metric:statecache:hit_total: 200',
          'stateful.beam.metric:statecache:put_total: 20',
          'stateful.beam.metric:statecache:extend_total: 110',
          'stateful.beam.metric:statecache:evict_total: 0',
@@ -330,17 +330,17 @@ def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
          # It's lazily initialized after first access in StateRequestHandlers
          'stateful).beam.metric:statecache:size: 10',
          # We have 11 here because there are 110 / 10 elements per key
-         'stateful).beam.metric:statecache:get: 11',
+         'stateful).beam.metric:statecache:get: 12',
          'stateful).beam.metric:statecache:miss: 1',
-         'stateful).beam.metric:statecache:hit: 10',
+         'stateful).beam.metric:statecache:hit: 11',
          # State is flushed back once per key
          'stateful).beam.metric:statecache:put: 1',
          'stateful).beam.metric:statecache:extend: 1',
          'stateful).beam.metric:statecache:evict: 0',
          # Counters
-         'stateful).beam.metric:statecache:get_total: 110',
+         'stateful).beam.metric:statecache:get_total: 120',
          'stateful).beam.metric:statecache:miss_total: 10',
-         'stateful).beam.metric:statecache:hit_total: 100',
+         'stateful).beam.metric:statecache:hit_total: 110',
          'stateful).beam.metric:statecache:put_total: 10',
          'stateful).beam.metric:statecache:extend_total: 10',
          'stateful).beam.metric:statecache:evict_total: 0',
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -944,6 +944,10 @@ def extend(self,
    if cache_token:
      # Update the cache
      cache_key = self._convert_to_cache_key(state_key)
+     if self._state_cache.get(cache_key, cache_token) is None:
+       # We have never cached this key before, first initialize cache
+       self.blocking_get(state_key, coder, is_cached=True)
+     # Now update the values in the cache
      self._state_cache.extend(cache_key, cache_token, elements)
    # Write to state handler
    out = coder_impl.create_OutputStream()
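
The guard above amounts to a read-through step before the first append to a key the cache has not seen. A standalone sketch of that pattern, with a plain dict as the cache and hypothetical backend_fetch/backend_append callables standing in for the underlying state handler (an illustration of the pattern, not the actual CachingStateHandler code):

```python
# Sketch of the guarded append, under the assumptions stated above.
def extend_with_readthrough(cache, backend_fetch, backend_append,
                            cache_key, cache_token, elements):
    elements = list(elements)  # materialize once; the input may be any iterable
    entry = (cache_key, cache_token)
    if cache.get(entry) is None:
        # The cache has never seen this key: load the backend's current values
        # first, so that later cached reads return the full bag.
        cache[entry] = list(backend_fetch(cache_key))
    # The cached copy is now complete and can safely be extended.
    cache[entry].extend(elements)
    # The append is still sent to the backend, exactly as before this change.
    backend_append(cache_key, elements)
```

One consequence of the guard is an extra cache lookup (and, for a cold key, an extra blocking read) per append, which appears consistent with the increased get/hit counters in the Flink runner test above.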
56 changes: 56 additions & 0 deletions sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -229,6 +229,62 @@ def get_as_list(key):
    self.assertEqual(get_as_list(side2), [502]) # uncached
    self.assertEqual(get_as_list(side2), [502]) # cached on bundle

+  def test_extend_fetches_initial_state(self):
+    coder = VarIntCoder()
+    coder_impl = coder.get_impl()
+
+    class UnderlyingStateHandler(object):
+      """Simply returns an incremented counter as the state "value."
+      """
+      def set_value(self, value):
+        self._encoded_values = coder.encode(value)
+
+      def get_raw(self, *args):
+        return self._encoded_values, None
+
+      def append_raw(self, _key, bytes):
+        self._encoded_values += bytes
+
+      def clear(self, *args):
+        self._encoded_values = bytes()
+
+      @contextlib.contextmanager
+      def process_instruction_id(self, bundle_id):
+        yield
+
+    underlying_state_handler = UnderlyingStateHandler()
+    state_cache = statecache.StateCache(100)
+    handler = sdk_worker.CachingStateHandler(
+        state_cache, underlying_state_handler)
+
+    state = beam_fn_api_pb2.StateKey(
+        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
+            user_state_id='state1'))
+
+    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
+        token=b'state_token1',
+        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())
+
+    def get():
+      return list(handler.blocking_get(state, coder_impl, True))
+
+    def append(value):
+      handler.extend(state, coder_impl, [value], True)
+
+    def clear():
+      handler.clear(state, True)
+
+    # Initialize state
+    underlying_state_handler.set_value(42)
+    with handler.process_instruction_id('bundle', [cache_token]):
+      # Append without reading beforehand
+      append(43)
+      self.assertEqual(get(), [42, 43])
+      clear()
+      self.assertEqual(get(), [])
+      append(44)
+      self.assertEqual(get(), [44])
+

 class ShortIdCacheTest(unittest.TestCase):
   def testShortIdAssignment(self):
