diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index 649c6d03a6a2c..2634268c947ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -117,6 +117,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RunnableFuture; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; @@ -2916,6 +2917,52 @@ public void testMapState() throws Exception { backend.dispose(); } + /** + * Verify iterator of {@link MapState} supporting arbitrary access, see [FLINK-10267] to know more details. + */ + @Test + public void testMapStateIteratorArbitraryAccess() throws Exception { + MapStateDescriptor kvId = new MapStateDescriptor<>("id", Integer.class, Long.class); + + AbstractKeyedStateBackend backend = createKeyedBackend(IntSerializer.INSTANCE); + + try { + MapState state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); + backend.setCurrentKey(1); + int stateSize = 4096; + for (int i = 0; i < stateSize; i++) { + state.put(i, i * 2L); + } + Iterator> iterator = state.iterator(); + int iteratorCount = 0; + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + assertEquals(iteratorCount, (int) entry.getKey()); + switch (ThreadLocalRandom.current().nextInt() % 3) { + case 0: // remove twice + iterator.remove(); + try { + iterator.remove(); + fail(); + } catch (IllegalStateException e) { + // ignore expected exception + } + break; + case 1: // hasNext -> remove + iterator.hasNext(); + iterator.remove(); + break; + case 2: // nothing to do + break; + } + iteratorCount++; + } + assertEquals(stateSize, iteratorCount); + } finally { + backend.dispose(); + } + } + /** * Verify that {@link ValueStateDescriptor} allows {@code null} as default. */ diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java index 5c9f7f9f30c2c..cb656b53b1b15 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java @@ -498,6 +498,7 @@ private abstract class RocksDBMapIterator implements Iterator { * have the same prefix, hence we can stop iterating once coming across an * entry with a different prefix. */ + @Nonnull private final byte[] keyPrefixBytes; /** @@ -508,6 +509,9 @@ private abstract class RocksDBMapIterator implements Iterator { /** A in-memory cache for the entries in the rocksdb. */ private ArrayList cacheEntries = new ArrayList<>(); + + /** The entry pointing to the current position which is last returned by calling {@link #nextEntry()}. */ + private RocksDBMapEntry currentEntry; private int cacheIndex = 0; private final TypeSerializer keySerializer; @@ -537,12 +541,11 @@ public boolean hasNext() { @Override public void remove() { - if (cacheIndex == 0 || cacheIndex > cacheEntries.size()) { + if (currentEntry == null || currentEntry.deleted) { throw new IllegalStateException("The remove operation must be called after a valid next operation."); } - RocksDBMapEntry lastEntry = cacheEntries.get(cacheIndex - 1); - lastEntry.remove(); + currentEntry.remove(); } final RocksDBMapEntry nextEntry() { @@ -556,10 +559,10 @@ final RocksDBMapEntry nextEntry() { return null; } - RocksDBMapEntry entry = cacheEntries.get(cacheIndex); + this.currentEntry = cacheEntries.get(cacheIndex); cacheIndex++; - return entry; + return currentEntry; } private void loadCache() { @@ -577,12 +580,11 @@ private void loadCache() { try (RocksIteratorWrapper iterator = RocksDBKeyedStateBackend.getRocksIterator(db, columnFamily)) { /* - * The iteration starts from the prefix bytes at the first loading. The cache then is - * reloaded when the next entry to return is the last one in the cache. At that time, - * we will start the iterating from the last returned entry. - */ - RocksDBMapEntry lastEntry = cacheEntries.size() == 0 ? null : cacheEntries.get(cacheEntries.size() - 1); - byte[] startBytes = (lastEntry == null ? keyPrefixBytes : lastEntry.rawKeyBytes); + * The iteration starts from the prefix bytes at the first loading. After #nextEntry() is called, + * the currentEntry points to the last returned entry, and at that time, we will start + * the iterating from currentEntry if reloading cache is needed. + */ + byte[] startBytes = (currentEntry == null ? keyPrefixBytes : currentEntry.rawKeyBytes); cacheEntries.clear(); cacheIndex = 0; @@ -590,10 +592,10 @@ private void loadCache() { iterator.seek(startBytes); /* - * If the last returned entry is not deleted, it will be the first entry in the - * iterating. Skip it to avoid redundant access in such cases. + * If the entry pointing to the current position is not removed, it will be the first entry in the + * new iterating. Skip it to avoid redundant access in such cases. */ - if (lastEntry != null && !lastEntry.deleted) { + if (currentEntry != null && !currentEntry.deleted) { iterator.next(); }