diff --git a/docs/_includes/generated/rocks_db_configuration.html b/docs/_includes/generated/rocks_db_configuration.html
index 8983f8b41dde9..81f6b53f11793 100644
--- a/docs/_includes/generated/rocks_db_configuration.html
+++ b/docs/_includes/generated/rocks_db_configuration.html
@@ -7,6 +7,11 @@
+
+ state.backend.rocksdb.checkpoint.restore.thread.num |
+ 1 |
+ The number of threads used to download files from DFS in RocksDBStateBackend. |
+
state.backend.rocksdb.localdir |
(none) |
diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
index fff8f02285f48..96a28903792a8 100644
--- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
+++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
@@ -82,6 +82,7 @@ public void testListSerialization() throws Exception {
new KeyGroupRange(0, 0),
new ExecutionConfig(),
false,
+ 1,
TestLocalRecoveryConfig.disabled(),
RocksDBStateBackend.PriorityQueueStateType.HEAP,
TtlTimeProvider.DEFAULT,
@@ -126,6 +127,7 @@ public void testMapSerialization() throws Exception {
new KeyGroupRange(0, 0),
new ExecutionConfig(),
false,
+ 1,
TestLocalRecoveryConfig.disabled(),
RocksDBStateBackend.PriorityQueueStateType.HEAP,
TtlTimeProvider.DEFAULT,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index a37f8aa8df827..700c5468c8c40 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -37,7 +37,6 @@
import org.apache.flink.contrib.streaming.state.snapshot.RocksFullSnapshotStrategy;
import org.apache.flink.contrib.streaming.state.snapshot.RocksIncrementalSnapshotStrategy;
import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.core.fs.FileStatus;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
@@ -127,6 +126,7 @@
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
+import static org.apache.flink.contrib.streaming.state.RocksDbStateDataTransfer.transferAllStateDataToDirectory;
import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
@@ -217,6 +217,9 @@ IS createState(
/** True if incremental checkpointing is enabled. */
private final boolean enableIncrementalCheckpointing;
+ /** Thread number used to download from DFS when restore. */
+ private final int restoringThreadNum;
+
/** The configuration of local recovery. */
private final LocalRecoveryConfig localRecoveryConfig;
@@ -251,6 +254,7 @@ public RocksDBKeyedStateBackend(
KeyGroupRange keyGroupRange,
ExecutionConfig executionConfig,
boolean enableIncrementalCheckpointing,
+ int restoringThreadNum,
LocalRecoveryConfig localRecoveryConfig,
RocksDBStateBackend.PriorityQueueStateType priorityQueueStateType,
TtlTimeProvider ttlTimeProvider,
@@ -264,6 +268,7 @@ public RocksDBKeyedStateBackend(
this.operatorIdentifier = Preconditions.checkNotNull(operatorIdentifier);
this.enableIncrementalCheckpointing = enableIncrementalCheckpointing;
+ this.restoringThreadNum = restoringThreadNum;
this.rocksDBResourceGuard = new ResourceGuard();
// ensure that we use the right merge operator, because other code relies on this
@@ -494,7 +499,7 @@ public void restore(Collection restoreState) throws Exception
LOG.info("Initializing RocksDB keyed state backend.");
if (LOG.isDebugEnabled()) {
- LOG.debug("Restoring snapshot from state handles: {}.", restoreState);
+ LOG.debug("Restoring snapshot from state handles: {}, will use {} thread(s) to download files from DFS.", restoreState, restoringThreadNum);
}
// clear all meta data
@@ -876,7 +881,7 @@ void restoreWithoutRescaling(KeyedStateHandle rawStateHandle) throws Exception {
IncrementalKeyedStateHandle restoreStateHandle = (IncrementalKeyedStateHandle) rawStateHandle;
// read state data.
- transferAllStateDataToDirectory(restoreStateHandle, temporaryRestoreInstancePath);
+ transferAllStateDataToDirectory(restoreStateHandle, temporaryRestoreInstancePath, stateBackend.restoringThreadNum, stateBackend.cancelStreamRegistry);
stateMetaInfoSnapshots = readMetaData(restoreStateHandle.getMetaStateHandle());
columnFamilyDescriptors = createAndRegisterColumnFamilyDescriptors(stateMetaInfoSnapshots);
@@ -1029,7 +1034,7 @@ private RestoredDBInstance restoreDBInstanceFromStateHandle(
IncrementalKeyedStateHandle restoreStateHandle,
Path temporaryRestoreInstancePath) throws Exception {
- transferAllStateDataToDirectory(restoreStateHandle, temporaryRestoreInstancePath);
+ transferAllStateDataToDirectory(restoreStateHandle, temporaryRestoreInstancePath, stateBackend.restoringThreadNum, stateBackend.cancelStreamRegistry);
// read meta data
List stateMetaInfoSnapshots =
@@ -1274,74 +1279,6 @@ private List readMetaData(
}
}
}
-
- private void transferAllStateDataToDirectory(
- IncrementalKeyedStateHandle restoreStateHandle,
- Path dest) throws IOException {
-
- final Map sstFiles =
- restoreStateHandle.getSharedState();
- final Map miscFiles =
- restoreStateHandle.getPrivateState();
-
- transferAllDataFromStateHandles(sstFiles, dest);
- transferAllDataFromStateHandles(miscFiles, dest);
- }
-
- /**
- * Copies all the files from the given stream state handles to the given path, renaming the files w.r.t. their
- * {@link StateHandleID}.
- */
- private void transferAllDataFromStateHandles(
- Map stateHandleMap,
- Path restoreInstancePath) throws IOException {
-
- for (Map.Entry entry : stateHandleMap.entrySet()) {
- StateHandleID stateHandleID = entry.getKey();
- StreamStateHandle remoteFileHandle = entry.getValue();
- copyStateDataHandleData(new Path(restoreInstancePath, stateHandleID.toString()), remoteFileHandle);
- }
-
- }
-
- /**
- * Copies the file from a single state handle to the given path.
- */
- private void copyStateDataHandleData(
- Path restoreFilePath,
- StreamStateHandle remoteFileHandle) throws IOException {
-
- FileSystem restoreFileSystem = restoreFilePath.getFileSystem();
-
- FSDataInputStream inputStream = null;
- FSDataOutputStream outputStream = null;
-
- try {
- inputStream = remoteFileHandle.openInputStream();
- stateBackend.cancelStreamRegistry.registerCloseable(inputStream);
-
- outputStream = restoreFileSystem.create(restoreFilePath, FileSystem.WriteMode.OVERWRITE);
- stateBackend.cancelStreamRegistry.registerCloseable(outputStream);
-
- byte[] buffer = new byte[8 * 1024];
- while (true) {
- int numBytes = inputStream.read(buffer);
- if (numBytes == -1) {
- break;
- }
-
- outputStream.write(buffer, 0, numBytes);
- }
- } finally {
- if (stateBackend.cancelStreamRegistry.unregisterCloseable(inputStream)) {
- inputStream.close();
- }
-
- if (stateBackend.cancelStreamRegistry.unregisterCloseable(outputStream)) {
- outputStream.close();
- }
- }
- }
}
// ------------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
index c85a7b2077caa..9b15bf13fd1c4 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOptions.java
@@ -45,4 +45,12 @@ public class RocksDBOptions {
.withDescription(String.format("This determines the factory for timer service state implementation. Options " +
"are either %s (heap-based, default) or %s for an implementation based on RocksDB .",
HEAP.name(), ROCKSDB.name()));
+
+ /**
+ * The number of threads used to download files from DFS in RocksDBStateBackend.
+ */
+ public static final ConfigOption CHECKPOINT_RESTORE_THREAD_NUM = ConfigOptions
+ .key("state.backend.rocksdb.checkpoint.restore.thread.num")
+ .defaultValue(1)
+ .withDescription("The number of threads used to download files from DFS in RocksDBStateBackend.");
}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 794e22160a25e..080e7cfda7258 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -40,6 +40,7 @@
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
import org.apache.flink.util.AbstractID;
+import org.apache.flink.util.Preconditions;
import org.apache.flink.util.TernaryBoolean;
import org.rocksdb.ColumnFamilyOptions;
@@ -97,6 +98,8 @@ public enum PriorityQueueStateType {
/** Flag whether the native library has been loaded. */
private static boolean rocksDbInitialized = false;
+ private static final int UNDEFINED_NUMBER_OF_RESTORING_THREADS = -1;
+
// ------------------------------------------------------------------------
// -- configuration values, set in the application / configuration
@@ -120,6 +123,9 @@ public enum PriorityQueueStateType {
/** This determines if incremental checkpointing is enabled. */
private final TernaryBoolean enableIncrementalCheckpointing;
+ /** Thread number used to download from DFS when restore, default value: 1. */
+ private int numberOfRestoringThreads;
+
/** This determines the type of priority queue state. */
private final PriorityQueueStateType priorityQueueStateType;
@@ -238,6 +244,7 @@ public RocksDBStateBackend(StateBackend checkpointStreamBackend) {
public RocksDBStateBackend(StateBackend checkpointStreamBackend, TernaryBoolean enableIncrementalCheckpointing) {
this.checkpointStreamBackend = checkNotNull(checkpointStreamBackend);
this.enableIncrementalCheckpointing = enableIncrementalCheckpointing;
+ this.numberOfRestoringThreads = UNDEFINED_NUMBER_OF_RESTORING_THREADS;
// for now, we use still the heap-based implementation as default
this.priorityQueueStateType = PriorityQueueStateType.HEAP;
this.defaultMetricOptions = new RocksDBNativeMetricOptions();
@@ -276,6 +283,12 @@ private RocksDBStateBackend(RocksDBStateBackend original, Configuration config)
this.enableIncrementalCheckpointing = original.enableIncrementalCheckpointing.resolveUndefined(
config.getBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS));
+ if (original.numberOfRestoringThreads == UNDEFINED_NUMBER_OF_RESTORING_THREADS) {
+ this.numberOfRestoringThreads = config.getInteger(RocksDBOptions.CHECKPOINT_RESTORE_THREAD_NUM);
+ } else {
+ this.numberOfRestoringThreads = original.numberOfRestoringThreads;
+ }
+
final String priorityQueueTypeString = config.getString(TIMER_SERVICE_FACTORY);
this.priorityQueueStateType = priorityQueueTypeString.length() > 0 ?
@@ -452,6 +465,7 @@ public AbstractKeyedStateBackend createKeyedStateBackend(
keyGroupRange,
env.getExecutionConfig(),
isIncrementalCheckpointsEnabled(),
+ getNumberOfRestoringThreads(),
localRecoveryConfig,
priorityQueueStateType,
ttlTimeProvider,
@@ -686,6 +700,20 @@ public RocksDBNativeMetricOptions getMemoryWatcherOptions() {
return options;
}
+ /**
+ * Gets the thread number will used for downloading files from DFS when restore.
+ */
+ public int getNumberOfRestoringThreads() {
+ return numberOfRestoringThreads == UNDEFINED_NUMBER_OF_RESTORING_THREADS ?
+ RocksDBOptions.CHECKPOINT_RESTORE_THREAD_NUM.defaultValue() : numberOfRestoringThreads;
+ }
+
+ public void setNumberOfRestoringThreads(int numberOfRestoringThreads) {
+ Preconditions.checkArgument(numberOfRestoringThreads > 0,
+ "The number of threads used to download files from DFS in RocksDBStateBackend should > 0.");
+ this.numberOfRestoringThreads = numberOfRestoringThreads;
+ }
+
// ------------------------------------------------------------------------
// utilities
// ------------------------------------------------------------------------
@@ -696,6 +724,7 @@ public String toString() {
"checkpointStreamBackend=" + checkpointStreamBackend +
", localRocksDbDirectories=" + Arrays.toString(localRocksDbDirectories) +
", enableIncrementalCheckpointing=" + enableIncrementalCheckpointing +
+ ", numberOfRestoringThreads=" + numberOfRestoringThreads +
'}';
}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java
new file mode 100644
index 0000000000000..03e114da282bf
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDbStateDataTransfer.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.concurrent.FutureUtils;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
+import org.apache.flink.runtime.state.StateHandleID;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.function.ThrowingRunnable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import static org.apache.flink.runtime.concurrent.Executors.newDirectExecutorService;
+
+/**
+ * Data transfer utils for {@link RocksDBKeyedStateBackend}.
+ */
+class RocksDbStateDataTransfer {
+
+ static void transferAllStateDataToDirectory(
+ IncrementalKeyedStateHandle restoreStateHandle,
+ Path dest,
+ int restoringThreadNum,
+ CloseableRegistry closeableRegistry) throws Exception {
+
+ final Map sstFiles =
+ restoreStateHandle.getSharedState();
+ final Map miscFiles =
+ restoreStateHandle.getPrivateState();
+
+ downloadDataForAllStateHandles(sstFiles, dest, restoringThreadNum, closeableRegistry);
+ downloadDataForAllStateHandles(miscFiles, dest, restoringThreadNum, closeableRegistry);
+ }
+
+ /**
+ * Copies all the files from the given stream state handles to the given path, renaming the files w.r.t. their
+ * {@link StateHandleID}.
+ */
+ private static void downloadDataForAllStateHandles(
+ Map stateHandleMap,
+ Path restoreInstancePath,
+ int restoringThreadNum,
+ CloseableRegistry closeableRegistry) throws Exception {
+
+ final ExecutorService executorService = createExecutorService(restoringThreadNum);
+
+ try {
+ List runnables = createDownloadRunnables(stateHandleMap, restoreInstancePath, closeableRegistry);
+ List> futures = new ArrayList<>(runnables.size());
+ for (Runnable runnable : runnables) {
+ futures.add(CompletableFuture.runAsync(runnable, executorService));
+ }
+ FutureUtils.waitForAll(futures).get();
+ } catch (ExecutionException e) {
+ Throwable throwable = ExceptionUtils.stripExecutionException(e);
+ throwable = ExceptionUtils.stripException(throwable, RuntimeException.class);
+ if (throwable instanceof IOException) {
+ throw (IOException) throwable;
+ } else {
+ throw new FlinkRuntimeException("Failed to download data for state handles.", e);
+ }
+ } finally {
+ executorService.shutdownNow();
+ }
+ }
+
+ private static ExecutorService createExecutorService(int threadNum) {
+ if (threadNum > 1) {
+ return Executors.newFixedThreadPool(threadNum);
+ } else {
+ return newDirectExecutorService();
+ }
+ }
+
+ private static List createDownloadRunnables(
+ Map stateHandleMap,
+ Path restoreInstancePath,
+ CloseableRegistry closeableRegistry) {
+ List runnables = new ArrayList<>(stateHandleMap.size());
+ for (Map.Entry entry : stateHandleMap.entrySet()) {
+ StateHandleID stateHandleID = entry.getKey();
+ StreamStateHandle remoteFileHandle = entry.getValue();
+
+ Path path = new Path(restoreInstancePath, stateHandleID.toString());
+
+ runnables.add(ThrowingRunnable.unchecked(
+ () -> downloadDataForStateHandle(path, remoteFileHandle, closeableRegistry)));
+ }
+ return runnables;
+ }
+
+ /**
+ * Copies the file from a single state handle to the given path.
+ */
+ private static void downloadDataForStateHandle(
+ Path restoreFilePath,
+ StreamStateHandle remoteFileHandle,
+ CloseableRegistry closeableRegistry) throws IOException {
+
+ FSDataInputStream inputStream = null;
+ FSDataOutputStream outputStream = null;
+
+ try {
+ FileSystem restoreFileSystem = restoreFilePath.getFileSystem();
+ inputStream = remoteFileHandle.openInputStream();
+ closeableRegistry.registerCloseable(inputStream);
+
+ outputStream = restoreFileSystem.create(restoreFilePath, FileSystem.WriteMode.OVERWRITE);
+ closeableRegistry.registerCloseable(outputStream);
+
+ byte[] buffer = new byte[8 * 1024];
+ while (true) {
+ int numBytes = inputStream.read(buffer);
+ if (numBytes == -1) {
+ break;
+ }
+
+ outputStream.write(buffer, 0, numBytes);
+ }
+ } finally {
+ if (closeableRegistry.unregisterCloseable(inputStream)) {
+ inputStream.close();
+ }
+
+ if (closeableRegistry.unregisterCloseable(outputStream)) {
+ outputStream.close();
+ }
+ }
+ }
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index d7d6bdea879d8..0796c4f00fe9b 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -249,6 +249,7 @@ public void testCorrectMergeOperatorSet() throws IOException {
new KeyGroupRange(0, 0),
new ExecutionConfig(),
enableIncrementalCheckpointing,
+ 1,
TestLocalRecoveryConfig.disabled(),
RocksDBStateBackend.PriorityQueueStateType.HEAP,
TtlTimeProvider.DEFAULT,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java
new file mode 100644
index 0000000000000..5b01e438006c6
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDataTransferTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.StateHandleID;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests for {@link RocksDbStateDataTransfer}.
+ */
+public class RocksDBStateDataTransferTest extends TestLogger {
+ @Rule
+ public final TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+ /**
+ * Test that the exception arose in the thread pool will rethrow to the main thread.
+ */
+ @Test
+ public void testThreadPoolExceptionRethrow() {
+ SpecifiedException expectedException = new SpecifiedException("throw exception while multi thread restore.");
+ StreamStateHandle stateHandle = new StreamStateHandle() {
+ @Override
+ public FSDataInputStream openInputStream() throws IOException {
+ throw expectedException;
+ }
+
+ @Override
+ public void discardState() {
+
+ }
+
+ @Override
+ public long getStateSize() {
+ return 0;
+ }
+ };
+
+ Map stateHandles = new HashMap<>(1);
+ stateHandles.put(new StateHandleID("state1"), stateHandle);
+
+ IncrementalKeyedStateHandle incrementalKeyedStateHandle =
+ new IncrementalKeyedStateHandle(
+ UUID.randomUUID(),
+ KeyGroupRange.EMPTY_KEY_GROUP_RANGE,
+ 1,
+ stateHandles,
+ stateHandles,
+ stateHandle);
+
+ try {
+ RocksDbStateDataTransfer.transferAllStateDataToDirectory(incrementalKeyedStateHandle, new Path(temporaryFolder.newFolder().toURI()), 5, new CloseableRegistry());
+ fail();
+ } catch (Exception e) {
+ assertEquals(expectedException, e);
+ }
+ }
+
+ /**
+ * Tests that download files with multi-thread correctly.
+ */
+ @Test
+ public void testMultiThreadRestoreCorrectly() throws Exception {
+ Random random = new Random();
+ int contentNum = 6;
+ byte[][] contents = new byte[contentNum][];
+ for (int i = 0; i < contentNum; ++i) {
+ contents[i] = new byte[random.nextInt(100000) + 1];
+ random.nextBytes(contents[i]);
+ }
+
+ List handles = new ArrayList<>(contentNum);
+ for (int i = 0; i < contentNum; ++i) {
+ handles.add(new ByteStreamStateHandle(String.format("state%d", i), contents[i]));
+ }
+
+ Map sharedStates = new HashMap<>(contentNum);
+ Map privateStates = new HashMap<>(contentNum);
+ for (int i = 0; i < contentNum; ++i) {
+ sharedStates.put(new StateHandleID(String.format("sharedState%d", i)), handles.get(i));
+ privateStates.put(new StateHandleID(String.format("privateState%d", i)), handles.get(i));
+ }
+
+ IncrementalKeyedStateHandle incrementalKeyedStateHandle =
+ new IncrementalKeyedStateHandle(
+ UUID.randomUUID(),
+ KeyGroupRange.of(0, 1),
+ 1,
+ sharedStates,
+ privateStates,
+ handles.get(0));
+
+ Path dstPath = new Path(temporaryFolder.newFolder().toURI());
+ RocksDbStateDataTransfer.transferAllStateDataToDirectory(incrementalKeyedStateHandle, dstPath, contentNum - 1, new CloseableRegistry());
+
+ for (int i = 0; i < contentNum; ++i) {
+ assertStateContentEqual(contents[i], new Path(dstPath, String.format("sharedState%d", i)));
+ }
+ }
+
+ private void assertStateContentEqual(byte[] expected, Path path) throws IOException {
+ byte[] actual = Files.readAllBytes(Paths.get(path.toUri()));
+ assertArrayEquals(expected, actual);
+ }
+
+ private static class SpecifiedException extends IOException {
+ SpecifiedException(String message) {
+ super(message);
+ }
+ }
+}