Skip to content

Commit

Permalink
[FLINK-9060][state] Fix concurrent modification exception when iterating keys.
Browse files Browse the repository at this point in the history

This closes apache#5751.
  • Loading branch information
sihuazhou authored and kl0u committed Mar 29, 2018
1 parent d3c489d commit 62bbada
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,16 @@ public <N, S extends State, T> void applyToAllKeys(
final KeyedStateFunction<K, S> function) throws Exception {

try (Stream<K> keyStream = getKeys(stateDescriptor.getName(), namespace)) {

final S state = getPartitionedState(
namespace,
namespaceSerializer,
stateDescriptor);

keyStream.forEach((K key) -> {
setCurrentKey(key);
try {
function.process(
key,
getPartitionedState(
namespace,
namespaceSerializer,
stateDescriptor)
);
function.process(key, state);
} catch (Throwable e) {
// we wrap the checked exception in an unchecked
// one and catch it (and re-throw it) later.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ public abstract class KeyedStateFunction<K, S extends State> {
/**
* The actual method to be applied on each of the states.
*
* @param key a safe copy of the key (see {@link KeyedStateBackend#getCurrentKeySafe()})
* whose state is being processed.
* @param key the key whose state is being processed.
* @param state the state associated with the aforementioned key.
*/
public abstract void process(K key, S state) throws Exception;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.CompatibilityResult;
Expand All @@ -49,6 +50,7 @@
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
import org.apache.flink.runtime.state.KeyedStateFunction;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.LocalRecoveryConfig;
import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
Expand Down Expand Up @@ -85,6 +87,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.RunnableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
Expand Down Expand Up @@ -470,6 +473,31 @@ public void notifyCheckpointComplete(long checkpointId) {
//Nothing to do
}

/**
 * Applies the given function to the state of every key returned by
 * {@code getKeys(stateName, namespace)} for the described state.
 *
 * <p>The keys are first copied out of the key stream into a list, so that the function is
 * free to modify the state (e.g. call {@code state.clear()}) while the keys are processed.
 *
 * @param namespace the namespace whose keys are visited.
 * @param namespaceSerializer the serializer for the namespace.
 * @param stateDescriptor the descriptor of the state handed to the function.
 * @param function the function invoked once per key, with that key set as the current key.
 * @throws Exception if looking up the state or processing a key fails.
 */
@Override
public <N, S extends State, T> void applyToAllKeys(
		final N namespace,
		final TypeSerializer<N> namespaceSerializer,
		final StateDescriptor<S, T> stateDescriptor,
		final KeyedStateFunction<K, S> function) throws Exception {

	final String stateName = stateDescriptor.getName();

	try (Stream<K> allKeys = getKeys(stateName, namespace)) {

		// Snapshot the keys up front: iterating the stream directly would clash with
		// modifications such as state.clear() performed inside function.process().
		final List<K> keySnapshot = allKeys.collect(Collectors.toList());

		final S partitionedState =
			getPartitionedState(namespace, namespaceSerializer, stateDescriptor);

		for (K currentKey : keySnapshot) {
			setCurrentKey(currentKey);
			function.process(currentKey, partitionedState);
		}
	}
}

@Override
public String toString() {
return "HeapKeyedStateBackend";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
Expand Down Expand Up @@ -3449,6 +3450,83 @@ public void testAsyncSnapshot() throws Exception {
}
}

/**
 * Verifies that {@link AbstractKeyedStateBackend#applyToAllKeys(Object, TypeSerializer, StateDescriptor,
 * KeyedStateFunction)} tolerates state modifications made by the applied function.
 *
 * <p>{@code applyToAllKeys} relies on {@link AbstractKeyedStateBackend#getKeys(String, Object)} to get the
 * keys from the backend, and {@code getKeys} doesn't support concurrent modification. This test therefore
 * clears and re-adds state from inside the applied function.
 */
@Test
public void testConcurrentModificationWithApplyToAllKeys() throws Exception {
	final AbstractKeyedStateBackend<Integer> keyedBackend = createKeyedBackend(IntSerializer.INSTANCE);

	try {
		final ListStateDescriptor<String> descriptor =
			new ListStateDescriptor<>("foo", StringSerializer.INSTANCE);

		final ListState<String> listState = keyedBackend.getPartitionedState(
			VoidNamespace.INSTANCE,
			VoidNamespaceSerializer.INSTANCE,
			descriptor);

		// give each of the keys 0..99 a single list element
		for (int k = 0; k < 100; k++) {
			keyedBackend.setCurrentKey(k);
			listState.add("Hello" + k);
		}

		// validate the state value via applyToAllKeys().
		final KeyedStateFunction<Integer, ListState<String>> checkInitialValue =
			new KeyedStateFunction<Integer, ListState<String>>() {
				@Override
				public void process(Integer key, ListState<String> state) throws Exception {
					assertEquals("Hello" + key, state.get().iterator().next());
				}
			};
		keyedBackend.applyToAllKeys(
			VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, checkInitialValue);

		// clear the state via applyToAllKeys().
		final KeyedStateFunction<Integer, ListState<String>> clearState =
			new KeyedStateFunction<Integer, ListState<String>>() {
				@Override
				public void process(Integer key, ListState<String> state) throws Exception {
					state.clear();
				}
			};
		keyedBackend.applyToAllKeys(
			VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, clearState);

		// validate that the state has been cleared.
		final KeyedStateFunction<Integer, ListState<String>> checkCleared =
			new KeyedStateFunction<Integer, ListState<String>>() {
				@Override
				public void process(Integer key, ListState<String> state) throws Exception {
					assertFalse(state.get().iterator().hasNext());
				}
			};
		keyedBackend.applyToAllKeys(
			VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, checkCleared);

		// mix clear() with add() inside applyToAllKeys().
		final KeyedStateFunction<Integer, ListState<String>> addClearAdd =
			new KeyedStateFunction<Integer, ListState<String>>() {
				@Override
				public void process(Integer key, ListState<String> state) throws Exception {
					state.add("Hello" + key);
					state.clear();
					state.add("Hello_" + key);
				}
			};
		keyedBackend.applyToAllKeys(
			VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, addClearAdd);

		// validate the state value via applyToAllKeys().
		final KeyedStateFunction<Integer, ListState<String>> checkFinalValue =
			new KeyedStateFunction<Integer, ListState<String>>() {
				@Override
				public void process(Integer key, ListState<String> state) throws Exception {
					final Iterator<String> elements = state.get().iterator();
					assertEquals("Hello_" + key, elements.next());
					assertFalse(elements.hasNext()); // finally verify we have no more elements
				}
			};
		keyedBackend.applyToAllKeys(
			VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor, checkFinalValue);
	} finally {
		IOUtils.closeQuietly(keyedBackend);
		keyedBackend.dispose();
	}
}

@Test
public void testAsyncSnapshotCancellation() throws Exception {
OneShotLatch blocker = new OneShotLatch();
Expand Down

0 comments on commit 62bbada

Please sign in to comment.