Skip to content

Commit

Permalink
[FLINK-8802] [QS] Fix concurrent access to non-duplicated serializers.
Browse files Browse the repository at this point in the history
This closes apache#5691.
  • Loading branch information
kl0u committed Mar 29, 2018
1 parent c16e2c9 commit db8e1f0
Show file tree
Hide file tree
Showing 44 changed files with 1,074 additions and 276 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ public void operationComplete(ChannelFuture future) throws Exception {
* @param request the request to be sent.
* @return Future holding the serialized result
*/
public CompletableFuture<RESP> sendRequest(REQ request) {
CompletableFuture<RESP> sendRequest(REQ request) {
synchronized (connectLock) {
if (failureCause != null) {
return FutureUtils.getFailedFuture(failureCause);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,15 @@
import org.apache.flink.queryablestate.network.AbstractServerHandler;
import org.apache.flink.queryablestate.network.messages.MessageSerializer;
import org.apache.flink.queryablestate.network.stats.KvStateRequestStats;
import org.apache.flink.runtime.query.KvStateEntry;
import org.apache.flink.runtime.query.KvStateInfo;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;

/**
Expand All @@ -50,8 +49,6 @@
@ChannelHandler.Sharable
public class KvStateServerHandler extends AbstractServerHandler<KvStateInternalRequest, KvStateResponse> {

private static final Logger LOG = LoggerFactory.getLogger(KvStateServerHandler.class);

/** KvState registry holding references to the KvState instances. */
private final KvStateRegistry registry;

Expand All @@ -78,13 +75,13 @@ public CompletableFuture<KvStateResponse> handleRequest(final long requestId, fi
final CompletableFuture<KvStateResponse> responseFuture = new CompletableFuture<>();

try {
final InternalKvState<?> kvState = registry.getKvState(request.getKvStateId());
final KvStateEntry<?, ?, ?> kvState = registry.getKvState(request.getKvStateId());
if (kvState == null) {
responseFuture.completeExceptionally(new UnknownKvStateIdException(getServerName(), request.getKvStateId()));
} else {
byte[] serializedKeyAndNamespace = request.getSerializedKeyAndNamespace();

byte[] serializedResult = kvState.getSerializedValue(serializedKeyAndNamespace);
byte[] serializedResult = getSerializedValue(kvState, serializedKeyAndNamespace);
if (serializedResult != null) {
responseFuture.complete(new KvStateResponse(serializedResult));
} else {
Expand All @@ -100,6 +97,21 @@ public CompletableFuture<KvStateResponse> handleRequest(final long requestId, fi
}
}

private static <K, N, V> byte[] getSerializedValue(
final KvStateEntry<K, N, V> entry,
final byte[] serializedKeyAndNamespace) throws Exception {

final InternalKvState<K, N, V> state = entry.getState();
final KvStateInfo<K, N, V> infoForCurrentThread = entry.getInfoForCurrentThread();

return state.getSerializedValue(
serializedKeyAndNamespace,
infoForCurrentThread.getKeySerializer(),
infoForCurrentThread.getNamespaceSerializer(),
infoForCurrentThread.getStateValueSerializer()
);
}

@Override
public CompletableFuture<Void> shutdown() {
return CompletableFuture.completedFuture(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,8 +685,8 @@ public void testClientServerIntegration() throws Throwable {

state.update(201 + i);

// we know it must be a KvStat but this is not exposed to the user via State
InternalKvState<?> kvState = (InternalKvState<?>) state;
// we know it must be a KvState but this is not exposed to the user via State
InternalKvState<Integer, ?, Integer> kvState = (InternalKvState<Integer, ?, Integer>) state;

// Register KvState (one state instance for all server)
ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), new KeyGroupRange(0, 0), "any", kvState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.rocksdb.DBOptions;

import java.io.File;
import java.util.Map;

import static org.mockito.Mockito.mock;

Expand Down Expand Up @@ -82,7 +83,7 @@ static final class RocksDBKeyedStateBackend2<K> extends RocksDBKeyedStateBackend
}

@Override
public <N, T> InternalListState<N, T> createListState(
public <N, T> InternalListState<K, N, T> createListState(
final TypeSerializer<N> namespaceSerializer,
final ListStateDescriptor<T> stateDesc) throws Exception {

Expand Down Expand Up @@ -120,7 +121,7 @@ public void testListSerialization() throws Exception {
longHeapKeyedStateBackend.restore(null);
longHeapKeyedStateBackend.setCurrentKey(key);

final InternalListState<VoidNamespace, Long> listState = longHeapKeyedStateBackend
final InternalListState<Long, VoidNamespace, Long> listState = longHeapKeyedStateBackend
.createListState(VoidNamespaceSerializer.INSTANCE,
new ListStateDescriptor<>("test", LongSerializer.INSTANCE));

Expand Down Expand Up @@ -159,11 +160,12 @@ public void testMapSerialization() throws Exception {
longHeapKeyedStateBackend.restore(null);
longHeapKeyedStateBackend.setCurrentKey(key);

final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>)
longHeapKeyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));
final InternalMapState<Long, VoidNamespace, Long, String, Map<Long, String>> mapState =
(InternalMapState<Long, VoidNamespace, Long, String, Map<Long, String>>)
longHeapKeyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));

KvStateRequestSerializerTest.testMapSerialization(key, mapState);
longHeapKeyedStateBackend.dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void testListSerialization() throws Exception {
);
longHeapKeyedStateBackend.setCurrentKey(key);

final InternalListState<VoidNamespace, Long> listState = longHeapKeyedStateBackend.createListState(
final InternalListState<Long, VoidNamespace, Long> listState = longHeapKeyedStateBackend.createListState(
VoidNamespaceSerializer.INSTANCE,
new ListStateDescriptor<>("test", LongSerializer.INSTANCE));

Expand All @@ -220,7 +220,7 @@ public void testListSerialization() throws Exception {
*/
public static void testListSerialization(
final long key,
final InternalListState<VoidNamespace, Long> listState) throws Exception {
final InternalListState<Long, VoidNamespace, Long> listState) throws Exception {

TypeSerializer<Long> valueSerializer = LongSerializer.INSTANCE;
listState.setCurrentNamespace(VoidNamespace.INSTANCE);
Expand All @@ -240,7 +240,11 @@ public static void testListSerialization(
key, LongSerializer.INSTANCE,
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);

final byte[] serializedValues = listState.getSerializedValue(serializedKey);
final byte[] serializedValues = listState.getSerializedValue(
serializedKey,
listState.getKeySerializer(),
listState.getNamespaceSerializer(),
listState.getValueSerializer());

List<Long> actualValues = KvStateSerializer.deserializeList(serializedValues, valueSerializer);
assertEquals(expectedValues, actualValues);
Expand Down Expand Up @@ -303,10 +307,12 @@ public void testMapSerialization() throws Exception {
);
longHeapKeyedStateBackend.setCurrentKey(key);

final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>) longHeapKeyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));
final InternalMapState<Long, VoidNamespace, Long, String, HashMap<Long, String>> mapState =
(InternalMapState<Long, VoidNamespace, Long, String, HashMap<Long, String>>)
longHeapKeyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));

testMapSerialization(key, mapState);
}
Expand All @@ -322,9 +328,9 @@ public void testMapSerialization() throws Exception {
*
* @throws Exception
*/
public static void testMapSerialization(
public static <M extends Map<Long, String>> void testMapSerialization(
final long key,
final InternalMapState<VoidNamespace, Long, String> mapState) throws Exception {
final InternalMapState<Long, VoidNamespace, Long, String, M> mapState) throws Exception {

TypeSerializer<Long> userKeySerializer = LongSerializer.INSTANCE;
TypeSerializer<String> userValueSerializer = StringSerializer.INSTANCE;
Expand All @@ -348,7 +354,11 @@ public static void testMapSerialization(
key, LongSerializer.INSTANCE,
VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);

final byte[] serializedValues = mapState.getSerializedValue(serializedKey);
final byte[] serializedValues = mapState.getSerializedValue(
serializedKey,
mapState.getKeySerializer(),
mapState.getNamespaceSerializer(),
mapState.getValueSerializer());

Map<Long, String> actualValues = KvStateSerializer.deserializeMap(serializedValues, userKeySerializer, userValueSerializer);
assertEquals(expectedValues.size(), actualValues.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.queryablestate.KvStateID;
Expand Down Expand Up @@ -70,9 +72,6 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Tests for {@link KvStateServerHandler}.
Expand Down Expand Up @@ -286,7 +285,7 @@ public void testQueryUnknownKey() throws Exception {
}

/**
* Tests the failure response on a failure on the {@link InternalKvState#getSerializedValue(byte[])} call.
* Tests the failure response on a failure on the {@link InternalKvState#getSerializedValue(byte[], TypeSerializer, TypeSerializer, TypeSerializer)} call.
*/
@Test
public void testFailureOnGetSerializedValue() throws Exception {
Expand All @@ -300,9 +299,42 @@ public void testFailureOnGetSerializedValue() throws Exception {
EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);

// Failing KvState
InternalKvState<?> kvState = mock(InternalKvState.class);
when(kvState.getSerializedValue(any(byte[].class)))
.thenThrow(new RuntimeException("Expected test Exception"));
InternalKvState<Integer, VoidNamespace, Long> kvState =
new InternalKvState<Integer, VoidNamespace, Long>() {
@Override
public TypeSerializer<Integer> getKeySerializer() {
return IntSerializer.INSTANCE;
}

@Override
public TypeSerializer<VoidNamespace> getNamespaceSerializer() {
return VoidNamespaceSerializer.INSTANCE;
}

@Override
public TypeSerializer<Long> getValueSerializer() {
return LongSerializer.INSTANCE;
}

@Override
public void setCurrentNamespace(VoidNamespace namespace) {
// do nothing
}

@Override
public byte[] getSerializedValue(
final byte[] serializedKeyAndNamespace,
final TypeSerializer<Integer> safeKeySerializer,
final TypeSerializer<VoidNamespace> safeNamespaceSerializer,
final TypeSerializer<Long> safeValueSerializer) throws Exception {
throw new RuntimeException("Expected test Exception");
}

@Override
public void clear() {

}
};

KvStateID kvStateId = registry.registerKvState(
new JobID(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.query;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.util.Preconditions;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* An entry holding the {@link InternalKvState} along with its {@link KvStateInfo}.
*
* @param <K> The type of key the state is associated to
* @param <N> The type of the namespace the state is associated to
* @param <V> The type of values kept internally in state
*/
@Internal
public class KvStateEntry<K, N, V> {

private final InternalKvState<K, N, V> state;
private final KvStateInfo<K, N, V> stateInfo;

private final boolean areSerializersStateless;

private final ConcurrentMap<Thread, KvStateInfo<K, N, V>> serializerCache;

public KvStateEntry(final InternalKvState<K, N, V> state) {
this.state = Preconditions.checkNotNull(state);
this.stateInfo = new KvStateInfo<>(
state.getKeySerializer(),
state.getNamespaceSerializer(),
state.getValueSerializer()
);
this.serializerCache = new ConcurrentHashMap<>();
this.areSerializersStateless = stateInfo.duplicate() == stateInfo;
}

public InternalKvState<K, N, V> getState() {
return state;
}

public KvStateInfo<K, N, V> getInfoForCurrentThread() {
return areSerializersStateless
? stateInfo
: serializerCache.computeIfAbsent(Thread.currentThread(), t -> stateInfo.duplicate());
}

public void clear() {
serializerCache.clear();
}

@VisibleForTesting
public int getCacheSize() {
return serializerCache.size();
}
}
Loading

0 comments on commit db8e1f0

Please sign in to comment.