getShardUrls() {
+ return shardUrls;
+ }
+
+ /**
+ * The url of the first shard.
+ *
+ */
+ public String getUrl() {
+ return getShardUrl(0);
+ }
+
+ /**
+ * The url of a specific shard.
+ *
+ */
+ public String getShardUrl(int shardIndex) {
+ validateShardIndex(shardIndex);
+ return shardUrls.get(shardIndex);
+ }
+
+ private void validateShardIndex(int i) {
+ if (i < 0) {
+ throw new IllegalArgumentException("Index must be positive.");
+ } else if (getNumberOfShards() <= i) {
+ throw new IllegalArgumentException("Index must be less then the total number of shards.");
+ }
+ }
+
+ /**
+ * Get the {@link DbAdapter} that will be used to operate on the database
+ * during checkpointing.
+ *
+ */
+ public DbAdapter getDbAdapter() {
+ return dbAdapter;
+ }
+
+ /**
+ * Set the {@link DbAdapter} that will be used to operate on the database
+ * during checkpointing.
+ *
+ */
+ public void setDbAdapter(DbAdapter adapter) {
+ this.dbAdapter = adapter;
+ }
+
+ /**
+ * The class name that should be used to load the JDBC driver using
+ * Class.forName(JDBCDriverClass).
+ */
+ public String getJDBCDriver() {
+ return JDBCDriver;
+ }
+
+ /**
+ * Set the class name that should be used to load the JDBC driver using
+ * Class.forName(JDBCDriverClass).
+ */
+ public void setJDBCDriver(String jDBCDriverClassName) {
+ JDBCDriver = jDBCDriverClassName;
+ }
+
+ /**
+ * The maximum number of key-value pairs stored in one task instance's cache
+ * before evicting to the underlying database.
+ *
+ */
+ public int getKvCacheSize() {
+ return kvStateCacheSize;
+ }
+
+ /**
+ * Set the maximum number of key-value pairs stored in one task instance's
+ * cache before evicting to the underlying database. When the cache is full
+ * the N least recently used keys will be evicted to the database, where N =
+ * maxKvEvictFraction*KvCacheSize.
+ *
+ */
+ public void setKvCacheSize(int size) {
+ kvStateCacheSize = size;
+ }
+
+ /**
+ * The maximum number of key-value pairs inserted in the database as one
+ * batch operation.
+ */
+ public int getMaxKvInsertBatchSize() {
+ return maxKvInsertBatchSize;
+ }
+
+ /**
+ * Set the maximum number of key-value pairs inserted in the database as one
+ * batch operation.
+ */
+ public void setMaxKvInsertBatchSize(int size) {
+ maxKvInsertBatchSize = size;
+ }
+
+ /**
+ * Sets the maximum fraction of key-value states evicted from the cache if
+ * the cache is full.
+ */
+ public void setMaxKvCacheEvictFraction(float fraction) {
+ if (fraction > 1 || fraction <= 0) {
+ throw new RuntimeException("Must be a number between 0 and 1");
+ } else {
+ maxKvEvictFraction = fraction;
+ }
+ }
+
+ /**
+ * The maximum fraction of key-value states evicted from the cache if the
+ * cache is full.
+ */
+ public float getMaxKvCacheEvictFraction() {
+ return maxKvEvictFraction;
+ }
+
+ /**
+ * The number of elements that will be evicted when the cache is full.
+ *
+ */
+ public int getNumElementsToEvict() {
+ return (int) Math.ceil(getKvCacheSize() * getMaxKvCacheEvictFraction());
+ }
+
+ /**
+ * Sets how often will automatic compaction be performed on the database to
+ * remove old overwritten state changes. The frequency is set in terms of
+ * number of successful checkpoints between two compactions and should take
+ * the state size and checkpoint frequency into account.
+ *
+ * By default automatic compaction is turned off.
+ */
+ public void setKvStateCompactionFrequency(int compactEvery) {
+ this.kvStateCompactionFreq = compactEvery;
+ }
+
+ /**
+ * Sets how often will automatic compaction be performed on the database to
+ * remove old overwritten state changes. The frequency is set in terms of
+ * number of successful checkpoints between two compactions and should take
+ * the state size and checkpoint frequency into account.
+ *
+ * By default automatic compaction is turned off.
+ */
+ public int getKvStateCompactionFrequency() {
+ return kvStateCompactionFreq;
+ }
+
+ /**
+ * The number of times each SQL command will be retried on failure.
+ */
+ public int getMaxNumberOfSqlRetries() {
+ return maxNumberOfSqlRetries;
+ }
+
+ /**
+ * Sets the number of times each SQL command will be retried on failure.
+ */
+ public void setMaxNumberOfSqlRetries(int maxNumberOfSqlRetries) {
+ this.maxNumberOfSqlRetries = maxNumberOfSqlRetries;
+ }
+
+ /**
+ * The number of milliseconds slept between two SQL retries. The actual
+ * sleep time will be chosen randomly between 1 and the given time.
+ *
+ */
+ public int getSleepBetweenSqlRetries() {
+ return sleepBetweenSqlRetries;
+ }
+
+ /**
+ * Sets the number of milliseconds slept between two SQL retries. The actual
+ * sleep time will be chosen randomly between 1 and the given time.
+ *
+ */
+ public void setSleepBetweenSqlRetries(int sleepBetweenSqlRetries) {
+ this.sleepBetweenSqlRetries = sleepBetweenSqlRetries;
+ }
+
+ /**
+ * Sets the partitioner used to assign keys to different database shards
+ */
+ public void setPartitioner(Partitioner partitioner) {
+ this.shardPartitioner = partitioner;
+ }
+
+ /**
+ * Creates a new {@link ShardedConnection} using the set parameters.
+ *
+ * @throws SQLException
+ */
+ public ShardedConnection createShardedConnection() throws SQLException {
+ if (JDBCDriver != null) {
+ try {
+ Class.forName(JDBCDriver);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Could not load JDBC driver class", e);
+ }
+ }
+ if (shardPartitioner == null) {
+ return new ShardedConnection(shardUrls, userName, userPassword);
+ } else {
+ return new ShardedConnection(shardUrls, userName, userPassword, shardPartitioner);
+ }
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java
new file mode 100644
index 0000000000000..72482aedeb0c0
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import java.util.Random;
+import java.util.concurrent.Callable;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.InstantiationUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry;
+
+/**
+ * {@link StateBackend} for storing checkpoints in JDBC supporting databases.
+ * Key-Value state is stored out-of-core and is lazily fetched using the
+ * {@link LazyDbKvState} implementation. A different backend can also be
+ * provided in the constructor to store the non-partitioned states. A common use
+ * case would be to store the key-value states in the database and store larger
+ * non-partitioned states on a distributed file system.
+ *
+ * This backend implementation also allows the sharding of the checkpointed
+ * states among multiple database instances, which can be enabled by passing
+ * multiple database urls to the {@link DbBackendConfig} instance.
+ *
+ * By default there are multiple tables created in the given databases: 1 table
+ * for non-partitioned checkpoints and 1 table for each key-value state in the
+ * streaming program.
+ *
+ * To control table creation, insert/lookup operations and to provide
+ * compatibility for different SQL implementations, a custom
+ * {@link MySqlAdapter} can be supplied in the {@link DbBackendConfig}.
+ *
+ */
+public class DbStateBackend extends StateBackend {
+
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(DbStateBackend.class);
+
+ private Random rnd;
+
+ // ------------------------------------------------------
+
+ private Environment env;
+
+ // ------------------------------------------------------
+
+ private final DbBackendConfig dbConfig;
+ private final DbAdapter dbAdapter;
+
+ private ShardedConnection connections;
+
+ private final int numSqlRetries;
+ private final int sqlRetrySleep;
+
+ private PreparedStatement insertStatement;
+
+ // ------------------------------------------------------
+
+ // We allow to use a different backend for storing non-partitioned states
+ private StateBackend> nonPartitionedStateBackend = null;
+
+ // ------------------------------------------------------
+
+ /**
+ * Create a new {@link DbStateBackend} using the provided
+ * {@link DbBackendConfig} configuration.
+ *
+ */
+ public DbStateBackend(DbBackendConfig backendConfig) {
+ this.dbConfig = backendConfig;
+ dbAdapter = backendConfig.getDbAdapter();
+ numSqlRetries = backendConfig.getMaxNumberOfSqlRetries();
+ sqlRetrySleep = backendConfig.getSleepBetweenSqlRetries();
+ }
+
+ /**
+ * Create a new {@link DbStateBackend} using the provided
+ * {@link DbBackendConfig} configuration and a different backend for storing
+ * non-partitioned state snapshots.
+ *
+ */
+ public DbStateBackend(DbBackendConfig backendConfig, StateBackend> backend) {
+ this(backendConfig);
+ this.nonPartitionedStateBackend = backend;
+ }
+
+ /**
+ * Get the database connections maintained by the backend.
+ */
+ public ShardedConnection getConnections() {
+ return connections;
+ }
+
+ /**
+ * Check whether the backend has been initialized.
+ *
+ */
+ public boolean isInitialized() {
+ return connections != null;
+ }
+
+ public Environment getEnvironment() {
+ return env;
+ }
+
+ /**
+ * Get the backend configuration object.
+ */
+ public DbBackendConfig getConfiguration() {
+ return dbConfig;
+ }
+
+ @Override
+ public StateHandle checkpointStateSerializable(final S state, final long checkpointID,
+ final long timestamp) throws Exception {
+
+ // If we set a different backend for non-partitioned checkpoints we use
+ // that otherwise write to the database.
+ if (nonPartitionedStateBackend == null) {
+ return retry(new Callable>() {
+ public DbStateHandle call() throws Exception {
+ // We create a unique long id for each handle, but we also
+ // store the checkpoint id and timestamp for bookkeeping
+ long handleId = rnd.nextLong();
+ String jobIdShort = env.getJobID().toShortString();
+
+ dbAdapter.setCheckpointInsertParams(jobIdShort, insertStatement,
+ checkpointID, timestamp, handleId,
+ InstantiationUtil.serializeObject(state));
+
+ insertStatement.executeUpdate();
+
+ return new DbStateHandle(jobIdShort, checkpointID, timestamp, handleId,
+ dbConfig);
+ }
+ }, numSqlRetries, sqlRetrySleep);
+ } else {
+ return nonPartitionedStateBackend.checkpointStateSerializable(state, checkpointID, timestamp);
+ }
+ }
+
+ @Override
+ public CheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp)
+ throws Exception {
+ if (nonPartitionedStateBackend == null) {
+ // We don't implement this functionality for the DbStateBackend as
+ // we cannot directly write a stream to the database anyways.
+ throw new UnsupportedOperationException("Use ceckpointStateSerializable instead.");
+ } else {
+ return nonPartitionedStateBackend.createCheckpointStateOutputStream(checkpointID, timestamp);
+ }
+ }
+
+ @Override
+ public LazyDbKvState createKvState(String stateId, String stateName,
+ TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws IOException {
+ return new LazyDbKvState(
+ stateId + "_" + env.getJobID().toShortString(),
+ env.getIndexInSubtaskGroup() == 0,
+ getConnections(),
+ getConfiguration(),
+ keySerializer,
+ valueSerializer,
+ defaultValue);
+ }
+
+ @Override
+ public void initializeForJob(final Environment env) throws Exception {
+ this.rnd = new Random();
+ this.env = env;
+
+ connections = dbConfig.createShardedConnection();
+
+ // We want the most light-weight transaction isolation level as we don't
+ // have conflicting reads/writes. We just want to be able to roll back
+ // batch inserts for k-v snapshots. This requirement might be removed in
+ // the future.
+ connections.setTransactionIsolation(Connection.TRANSACTION_READ_UNCOMMITTED);
+
+ // If we have a different backend for non-partitioned states we
+ // initialize that, otherwise create tables for storing the checkpoints.
+ //
+ // Currently all non-partitioned states are written to the first
+ // database shard
+ if (nonPartitionedStateBackend == null) {
+ insertStatement = retry(new Callable() {
+ public PreparedStatement call() throws SQLException {
+ dbAdapter.createCheckpointsTable(env.getJobID().toShortString(), getConnections().getFirst());
+ return dbAdapter.prepareCheckpointInsert(env.getJobID().toShortString(),
+ getConnections().getFirst());
+ }
+ }, numSqlRetries, sqlRetrySleep);
+ } else {
+ nonPartitionedStateBackend.initializeForJob(env);
+ }
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Database state backend successfully initialized");
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ // We first close the statement/non-partitioned backend, then we close
+ // the database connection
+ try (ShardedConnection c = connections) {
+ if (nonPartitionedStateBackend == null) {
+ insertStatement.close();
+ } else {
+ nonPartitionedStateBackend.close();
+ }
+ }
+ }
+
+ @Override
+ public void disposeAllStateForCurrentJob() throws Exception {
+ if (nonPartitionedStateBackend == null) {
+ dbAdapter.disposeAllStateForJob(env.getJobID().toShortString(), connections.getFirst());
+ } else {
+ nonPartitionedStateBackend.disposeAllStateForCurrentJob();
+ }
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java
new file mode 100644
index 0000000000000..2ecfcc4b7840d
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.concurrent.Callable;
+
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.InstantiationUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * State handle implementation for storing checkpoints as byte arrays in
+ * databases using the {@link MySqlAdapter} defined in the {@link DbBackendConfig}.
+ *
+ */
+public class DbStateHandle implements Serializable, StateHandle {
+
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG = LoggerFactory.getLogger(DbStateHandle.class);
+
+ private final String jobId;
+ private final DbBackendConfig dbConfig;
+
+ private final long checkpointId;
+ private final long checkpointTs;
+
+ private final long handleId;
+
+ public DbStateHandle(String jobId, long checkpointId, long checkpointTs, long handleId, DbBackendConfig dbConfig) {
+ this.checkpointId = checkpointId;
+ this.handleId = handleId;
+ this.jobId = jobId;
+ this.dbConfig = dbConfig;
+ this.checkpointTs = checkpointTs;
+ }
+
+ protected byte[] getBytes() throws IOException {
+ return retry(new Callable() {
+ public byte[] call() throws Exception {
+ try (ShardedConnection con = dbConfig.createShardedConnection()) {
+ return dbConfig.getDbAdapter().getCheckpoint(jobId, con.getFirst(), checkpointId, checkpointTs, handleId);
+ }
+ }
+ }, dbConfig.getMaxNumberOfSqlRetries(), dbConfig.getSleepBetweenSqlRetries());
+ }
+
+ @Override
+ public void discardState() {
+ try {
+ retry(new Callable() {
+ public Boolean call() throws Exception {
+ try (ShardedConnection con = dbConfig.createShardedConnection()) {
+ dbConfig.getDbAdapter().deleteCheckpoint(jobId, con.getFirst(), checkpointId, checkpointTs, handleId);
+ }
+ return true;
+ }
+ }, dbConfig.getMaxNumberOfSqlRetries(), dbConfig.getSleepBetweenSqlRetries());
+ } catch (IOException e) {
+ // We don't want to fail the job here, but log the error.
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Could not discard state.");
+ }
+ }
+ }
+
+ @Override
+ public S getState(ClassLoader userCodeClassLoader) throws IOException, ClassNotFoundException {
+ return InstantiationUtil.deserializeObject(getBytes(), userCodeClassLoader);
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java
new file mode 100644
index 0000000000000..3d7abff1c9ffd
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java
@@ -0,0 +1,624 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry;
+
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.contrib.streaming.state.ShardedConnection.ShardedStatement;
+import org.apache.flink.runtime.state.KvState;
+import org.apache.flink.runtime.state.KvStateSnapshot;
+import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
+import org.apache.flink.util.InstantiationUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Optional;
+
+/**
+ *
+ * Lazily fetched {@link KvState} using a SQL backend. Key-value pairs are
+ * cached on heap and are lazily retrieved on access.
+ *
+ */
+public class LazyDbKvState implements KvState, CheckpointNotifier {
+
+ private static final Logger LOG = LoggerFactory.getLogger(LazyDbKvState.class);
+
+ // ------------------------------------------------------
+
+ // Unique id for this state (jobID_operatorID_stateName)
+ private final String kvStateId;
+ private final boolean compact;
+
+ private K currentKey;
+ private final V defaultValue;
+
+ private final TypeSerializer keySerializer;
+ private final TypeSerializer valueSerializer;
+
+ // ------------------------------------------------------
+
+ // Max number of retries for failed database operations
+ private final int numSqlRetries;
+ // Sleep time between two retries
+ private final int sqlRetrySleep;
+ // Max number of key-value pairs inserted in one batch to the database
+ private final int maxInsertBatchSize;
+ // We will do database compaction every so many checkpoints
+ private final int compactEvery;
+ // Executor for automatic compactions
+ private ExecutorService executor = null;
+
+ // Database properties
+ private final DbBackendConfig conf;
+ private final ShardedConnection connections;
+ private final DbAdapter dbAdapter;
+
+ // Convenience object for handling inserts to the database
+ private final BatchInserter batchInsert;
+
+ // Statements for key-lookups and inserts as prepared by the dbAdapter
+ private ShardedStatement selectStatements;
+ private ShardedStatement insertStatements;
+
+ // ------------------------------------------------------
+
+ // LRU cache for the key-value states backed by the database
+ private final StateCache cache;
+
+ private long nextTs;
+ private Map completedCheckpoints = new HashMap<>();
+
+ private volatile long lastCompactedTs;
+
+ // ------------------------------------------------------
+
+ /**
+ * Constructor to initialize the {@link LazyDbKvState} the first time the
+ * job starts.
+ */
+ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, DbBackendConfig conf,
+ TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws IOException {
+ this(kvStateId, compact, cons, conf, keySerializer, valueSerializer, defaultValue, 1, 0);
+ }
+
+ /**
+ * Initialize the {@link LazyDbKvState} from a snapshot.
+ */
+ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, final DbBackendConfig conf,
+ TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue, long nextTs,
+ long lastCompactedTs)
+ throws IOException {
+
+ this.kvStateId = kvStateId;
+ this.compact = compact;
+ if (compact) {
+ // Compactions will run in a seperate thread
+ executor = Executors.newSingleThreadExecutor();
+ }
+
+ this.keySerializer = keySerializer;
+ this.valueSerializer = valueSerializer;
+ this.defaultValue = defaultValue;
+
+ this.maxInsertBatchSize = conf.getMaxKvInsertBatchSize();
+ this.conf = conf;
+ this.connections = cons;
+ this.dbAdapter = conf.getDbAdapter();
+ this.compactEvery = conf.getKvStateCompactionFrequency();
+ this.numSqlRetries = conf.getMaxNumberOfSqlRetries();
+ this.sqlRetrySleep = conf.getSleepBetweenSqlRetries();
+
+ this.nextTs = nextTs;
+ this.lastCompactedTs = lastCompactedTs;
+
+ this.cache = new StateCache(conf.getKvCacheSize(), conf.getNumElementsToEvict());
+
+ initDB(this.connections);
+
+ batchInsert = new BatchInserter(connections.getNumShards());
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Lazy database kv-state ({}) successfully initialized", kvStateId);
+ }
+ }
+
+ @Override
+ public void setCurrentKey(K key) {
+ this.currentKey = key;
+ }
+
+ @Override
+ public void update(V value) throws IOException {
+ try {
+ cache.put(currentKey, Optional.fromNullable(value));
+ } catch (RuntimeException e) {
+ // We need to catch the RuntimeExceptions thrown in the StateCache
+ // methods here
+ throw new IOException(e);
+ }
+ }
+
+ @Override
+ public V value() throws IOException {
+ try {
+ // We get the value from the cache (which will automatically load it
+ // from the database if necessary). If null, we return a copy of the
+ // default value
+ V val = cache.get(currentKey).orNull();
+ return val != null ? val : copyDefault();
+ } catch (RuntimeException e) {
+ // We need to catch the RuntimeExceptions thrown in the StateCache
+ // methods here
+ throw new IOException(e);
+ }
+ }
+
+ @Override
+ public DbKvStateSnapshot snapshot(long checkpointId, long timestamp) throws IOException {
+
+ // Validate timing assumptions
+ if (timestamp <= nextTs) {
+ throw new RuntimeException("Checkpoint timestamp is smaller than previous ts + 1, "
+ + "this should not happen.");
+ }
+
+ // If there are any modified states we perform the inserts
+ if (!cache.modified.isEmpty()) {
+ // We insert the modified elements to the database with the current
+ // timestamp then clear the modified states
+ for (Entry> state : cache.modified.entrySet()) {
+ batchInsert.add(state, timestamp);
+ }
+ batchInsert.flush(timestamp);
+ cache.modified.clear();
+ } else if (compact) {
+ // Otherwise we call the keep alive method to avoid dropped
+ // connections (only call this on the compactor instance)
+ for (final Connection c : connections.connections()) {
+ SQLRetrier.retry(new Callable() {
+ @Override
+ public Void call() throws Exception {
+ dbAdapter.keepAlive(c);
+ return null;
+ }
+ }, numSqlRetries, sqlRetrySleep);
+ }
+ }
+
+ nextTs = timestamp + 1;
+ completedCheckpoints.put(checkpointId, timestamp);
+ return new DbKvStateSnapshot(kvStateId, timestamp, lastCompactedTs);
+ }
+
+ /**
+ * Returns the number of elements currently stored in the task's cache. Note
+ * that the number of elements in the database is not counted here.
+ */
+ @Override
+ public int size() {
+ return cache.size();
+ }
+
+ /**
+ * Return a copy the default value or null if the default was null.
+ *
+ */
+ private V copyDefault() {
+ return defaultValue != null ? valueSerializer.copy(defaultValue) : null;
+ }
+
+ /**
+ * Create a table for the kvstate checkpoints (based on the kvStateId) and
+ * prepare the statements used during checkpointing.
+ */
+ private void initDB(final ShardedConnection cons) throws IOException {
+
+ retry(new Callable() {
+ public Void call() throws Exception {
+
+ for (Connection con : cons.connections()) {
+ dbAdapter.createKVStateTable(kvStateId, con);
+ }
+
+ insertStatements = cons.prepareStatement(dbAdapter.prepareKVCheckpointInsert(kvStateId));
+ selectStatements = cons.prepareStatement(dbAdapter.prepareKeyLookup(kvStateId));
+
+ return null;
+ }
+
+ }, numSqlRetries, sqlRetrySleep);
+ }
+
+ @Override
+ public void notifyCheckpointComplete(long checkpointId) {
+ final Long ts = completedCheckpoints.remove(checkpointId);
+ if (ts == null) {
+ LOG.warn("Complete notification for missing checkpoint: " + checkpointId);
+ } else {
+ // If compaction is turned on we compact on the compactor subtask
+ // asynchronously in the background
+ if (compactEvery > 0 && compact && checkpointId % compactEvery == 0) {
+ executor.execute(new Compactor(ts));
+ }
+ }
+ }
+
+ @Override
+ public void dispose() {
+ // We are only closing the statements here, the connection is borrowed
+ // from the state backend and will be closed there.
+ try {
+ selectStatements.close();
+ } catch (SQLException e) {
+ // There is not much to do about this
+ }
+ try {
+ insertStatements.close();
+ } catch (SQLException e) {
+ // There is not much to do about this
+ }
+
+ if (executor != null) {
+ executor.shutdown();
+ }
+ }
+
+ /**
+ * Return the Map of cached states.
+ *
+ */
+ public Map> getStateCache() {
+ return cache;
+ }
+
+ /**
+ * Return the Map of modified states that hasn't been written to the
+ * database yet.
+ *
+ */
+ public Map> getModified() {
+ return cache.modified;
+ }
+
+ /**
+ * Used for testing purposes
+ */
+ public boolean isCompactor() {
+ return compact;
+ }
+
+ /**
+ * Used for testing purposes
+ */
+ public ExecutorService getExecutor() {
+ return executor;
+ }
+
+ /**
+ * Snapshot that stores a specific checkpoint timestamp and state id, and
+ * also rolls back the database to that point upon restore. The rollback is
+ * done by removing all state checkpoints that have timestamps between the
+ * checkpoint and recovery timestamp.
+ *
+ */
+ private static class DbKvStateSnapshot implements KvStateSnapshot {
+
+ private static final long serialVersionUID = 1L;
+
+ private final String kvStateId;
+ private final long checkpointTimestamp;
+ private final long lastCompactedTimestamp;
+
+ public DbKvStateSnapshot(String kvStateId, long checkpointTimestamp, long lastCompactedTs) {
+ this.checkpointTimestamp = checkpointTimestamp;
+ this.kvStateId = kvStateId;
+ this.lastCompactedTimestamp = lastCompactedTs;
+ }
+
+ @Override
+ public LazyDbKvState restoreState(final DbStateBackend stateBackend,
+ final TypeSerializer keySerializer, final TypeSerializer valueSerializer, final V defaultValue,
+ ClassLoader classLoader, final long recoveryTimestamp) throws IOException {
+
+ // Validate timing assumptions
+ if (recoveryTimestamp <= checkpointTimestamp) {
+ throw new RuntimeException(
+ "Recovery timestamp is smaller or equal to checkpoint timestamp. "
+ + "This might happen if the job was started with a new JobManager "
+ + "and the clocks got really out of sync.");
+ }
+
+ // First we clean up the states written by partially failed
+ // snapshots
+ retry(new Callable() {
+ public Void call() throws Exception {
+
+ // We need to perform cleanup on all shards to be safe here
+ for (Connection c : stateBackend.getConnections().connections()) {
+ stateBackend.getConfiguration().getDbAdapter().cleanupFailedCheckpoints(kvStateId,
+ c, checkpointTimestamp, recoveryTimestamp);
+ }
+
+ return null;
+ }
+ }, stateBackend.getConfiguration().getMaxNumberOfSqlRetries(),
+ stateBackend.getConfiguration().getSleepBetweenSqlRetries());
+
+ boolean cleanup = stateBackend.getEnvironment().getIndexInSubtaskGroup() == 0;
+
+ // Restore the KvState
+ LazyDbKvState restored = new LazyDbKvState(kvStateId, cleanup,
+ stateBackend.getConnections(), stateBackend.getConfiguration(), keySerializer, valueSerializer,
+ defaultValue, recoveryTimestamp, lastCompactedTimestamp);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("KV state({},{}) restored.", kvStateId, recoveryTimestamp);
+ }
+
+ return restored;
+ }
+
+ @Override
+ public void discardState() throws Exception {
+ // Don't discard, it will be compacted by the LazyDbKvState
+ }
+
+ }
+
+ /**
+ * LRU cache implementation for storing the key-value states. When the cache
+ * is full elements are not evicted one by one but are evicted in a batch
+ * defined by the evictionSize parameter.
+ *
+ * Keys not found in the cached will be retrieved from the underlying
+ * database
+ */
+ private final class StateCache extends LinkedHashMap> {
+ private static final long serialVersionUID = 1L;
+
+ private final int cacheSize;
+ private final int evictionSize;
+
+ // We keep track the state modified since the last checkpoint
+ private final Map> modified = new HashMap<>();
+
+ public StateCache(int cacheSize, int evictionSize) {
+ super(cacheSize, 0.75f, true);
+ this.cacheSize = cacheSize;
+ this.evictionSize = evictionSize;
+ }
+
+ @Override
+ public Optional put(K key, Optional value) {
+ // Put kv pair in the cache and evict elements if the cache is full
+ Optional old = super.put(key, value);
+ modified.put(key, value);
+ evictIfFull();
+ return old;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public Optional get(Object key) {
+ // First we check whether the value is cached
+ Optional value = super.get(key);
+ if (value == null) {
+ // If it doesn't try to load it from the database
+ value = Optional.fromNullable(getFromDatabaseOrNull((K) key));
+ put((K) key, value);
+ }
+ return value;
+ }
+
+ @Override
+ protected boolean removeEldestEntry(Entry> eldest) {
+ // We need to remove elements manually if the cache becomes full, so
+ // we always return false here.
+ return false;
+ }
+
+ /**
+ * Fetch the current value from the database if exists or return null.
+ *
+ * @param key
+ * @return The value corresponding to the key and the last checkpointid
+ * from the database if exists or null.
+ */
+ private V getFromDatabaseOrNull(final K key) {
+ try {
+ return retry(new Callable() {
+ public V call() throws Exception {
+ byte[] serializedKey = InstantiationUtil.serializeToByteArray(keySerializer, key);
+ // We lookup using the adapter and serialize/deserialize
+ // with the TypeSerializers
+ byte[] serializedVal = dbAdapter.lookupKey(kvStateId,
+ selectStatements.getForKey(key), serializedKey, nextTs);
+
+ return serializedVal != null
+ ? InstantiationUtil.deserializeFromByteArray(valueSerializer, serializedVal) : null;
+ }
+ }, numSqlRetries, sqlRetrySleep);
+ } catch (IOException e) {
+ // We need to re-throw this exception to conform to the map
+ // interface, we will catch this when we call the the put/get
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * If the cache is full we remove the evictionSize least recently
+ * accessed elements and write them to the database if they were
+ * modified since the last checkpoint.
+ */
+ private void evictIfFull() {
+ if (size() > cacheSize) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("State cache is full for {}, evicting {} elements.", kvStateId, evictionSize);
+ }
+ try {
+ int numEvicted = 0;
+
+ Iterator>> entryIterator = entrySet().iterator();
+ while (numEvicted++ < evictionSize && entryIterator.hasNext()) {
+
+ Entry> next = entryIterator.next();
+
+ // We only need to write to the database if modified
+ if (modified.remove(next.getKey()) != null) {
+ batchInsert.add(next, nextTs);
+ }
+
+ entryIterator.remove();
+ }
+
+ batchInsert.flush(nextTs);
+
+ } catch (IOException e) {
+ // We need to re-throw this exception to conform to the map
+ // interface, we will catch this when we call the the
+ // put/get
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ @Override
+ public void putAll(Map extends K, ? extends Optional> m) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /**
+ * Object for handling inserts to the database by batching them together
+ * partitioned on the sharding key. The batches are written to the database
+ * when they are full or when the inserter is flushed.
+ *
+ */
+ private class BatchInserter {
+
+ // Map from shard index to the kv pairs to be inserted
+ // Map>> inserts = new HashMap<>();
+
+ List>[] inserts;
+
+ @SuppressWarnings("unchecked")
+ public BatchInserter(int numShards) {
+ inserts = new List[numShards];
+ for (int i = 0; i < numShards; i++) {
+ inserts[i] = new ArrayList<>();
+ }
+ }
+
+ public void add(Entry> next, long timestamp) throws IOException {
+
+ K key = next.getKey();
+ V value = next.getValue().orNull();
+
+ // Get the current partition if present or initialize empty list
+ int shardIndex = connections.getShardIndex(key);
+
+ List> insertPartition = inserts[shardIndex];
+
+ // Add the k-v pair to the partition
+ byte[] k = InstantiationUtil.serializeToByteArray(keySerializer, key);
+ byte[] v = value != null ? InstantiationUtil.serializeToByteArray(valueSerializer, value) : null;
+ insertPartition.add(Tuple2.of(k, v));
+
+ // If partition is full write to the database and clear
+ if (insertPartition.size() == maxInsertBatchSize) {
+ dbAdapter.insertBatch(kvStateId, conf,
+ connections.getForIndex(shardIndex),
+ insertStatements.getForIndex(shardIndex),
+ timestamp, insertPartition);
+
+ insertPartition.clear();
+ }
+ }
+
+ public void flush(long timestamp) throws IOException {
+ // We flush all non-empty partitions
+ for (int i = 0; i < inserts.length; i++) {
+ List> insertPartition = inserts[i];
+ if (!insertPartition.isEmpty()) {
+ dbAdapter.insertBatch(kvStateId, conf, connections.getForIndex(i),
+ insertStatements.getForIndex(i), timestamp, insertPartition);
+ insertPartition.clear();
+ }
+ }
+
+ }
+ }
+
+ private class Compactor implements Runnable {
+
+ private long upperBound;
+
+ public Compactor(long upperBound) {
+ this.upperBound = upperBound;
+ }
+
+ @Override
+ public void run() {
+ // We create new database connections to make sure we don't
+ // interfere with the checkpointing (connections are not thread
+ // safe)
+ try (ShardedConnection sc = conf.createShardedConnection()) {
+ for (final Connection c : sc.connections()) {
+ SQLRetrier.retry(new Callable() {
+ @Override
+ public Void call() throws Exception {
+ dbAdapter.compactKvStates(kvStateId, c, lastCompactedTs, upperBound);
+ return null;
+ }
+ }, numSqlRetries, sqlRetrySleep);
+ }
+ if (LOG.isInfoEnabled()) {
+ LOG.info("State succesfully compacted for {} between {} and {}.", kvStateId,
+ lastCompactedTs,
+ upperBound);
+ }
+ lastCompactedTs = upperBound;
+ } catch (SQLException | IOException e) {
+ LOG.warn("State compaction failed due: {}", e);
+ }
+ }
+
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java
new file mode 100644
index 0000000000000..9eaa2833d13de
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.sql.Types;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/**
+ *
+ * Adapter for bridging inconsistencies between the different SQL
+ * implementations. The default implementation has been tested to work well with
+ * MySQL
+ *
+ */
+public class MySqlAdapter implements DbAdapter {
+
+ private static final long serialVersionUID = 1L;
+
+ // -----------------------------------------------------------------------------
+ // Non-partitioned state checkpointing
+ // -----------------------------------------------------------------------------
+
+ @Override
+ public void createCheckpointsTable(String jobId, Connection con) throws SQLException {
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "CREATE TABLE IF NOT EXISTS checkpoints_" + jobId
+ + " ("
+ + "checkpointId bigint, "
+ + "timestamp bigint, "
+ + "handleId bigint,"
+ + "checkpoint blob,"
+ + "PRIMARY KEY (handleId)"
+ + ")");
+ }
+
+ }
+
+ @Override
+ public PreparedStatement prepareCheckpointInsert(String jobId, Connection con) throws SQLException {
+ return con.prepareStatement(
+ "INSERT INTO checkpoints_" + jobId
+ + " (checkpointId, timestamp, handleId, checkpoint) VALUES (?,?,?,?)");
+ }
+
+ @Override
+ public void setCheckpointInsertParams(String jobId, PreparedStatement insertStatement, long checkpointId,
+ long timestamp, long handleId, byte[] checkpoint) throws SQLException {
+ insertStatement.setLong(1, checkpointId);
+ insertStatement.setLong(2, timestamp);
+ insertStatement.setLong(3, handleId);
+ insertStatement.setBytes(4, checkpoint);
+ }
+
+ @Override
+ public byte[] getCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId)
+ throws SQLException {
+ try (Statement smt = con.createStatement()) {
+ ResultSet rs = smt.executeQuery(
+ "SELECT checkpoint FROM checkpoints_" + jobId
+ + " WHERE handleId = " + handleId);
+ if (rs.next()) {
+ return rs.getBytes(1);
+ } else {
+ throw new SQLException("Checkpoint cannot be found in the database.");
+ }
+ }
+ }
+
+ @Override
+ public void deleteCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId)
+ throws SQLException {
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "DELETE FROM checkpoints_" + jobId
+ + " WHERE handleId = " + handleId);
+ }
+ }
+
+ @Override
+ public void disposeAllStateForJob(String jobId, Connection con) throws SQLException {
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "DROP TABLE checkpoints_" + jobId);
+ }
+ }
+
+ // -----------------------------------------------------------------------------
+ // Partitioned state checkpointing
+ // -----------------------------------------------------------------------------
+
+ @Override
+ public void createKVStateTable(String stateId, Connection con) throws SQLException {
+ validateStateId(stateId);
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "CREATE TABLE IF NOT EXISTS " + stateId
+ + " ("
+ + "timestamp bigint, "
+ + "k varbinary(256), "
+ + "v blob, "
+ + "PRIMARY KEY (k, timestamp) "
+ + ")");
+ }
+ }
+
+ @Override
+ public String prepareKVCheckpointInsert(String stateId) throws SQLException {
+ validateStateId(stateId);
+ return "INSERT INTO " + stateId + " (timestamp, k, v) VALUES (?,?,?) "
+ + "ON DUPLICATE KEY UPDATE v=? ";
+ }
+
+ @Override
+ public String prepareKeyLookup(String stateId) throws SQLException {
+ validateStateId(stateId);
+ return "SELECT v"
+ + " FROM " + stateId
+ + " WHERE k = ?"
+ + " AND timestamp <= ?"
+ + " ORDER BY timestamp DESC LIMIT 1";
+ }
+
+ @Override
+ public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupTs)
+ throws SQLException {
+ lookupStatement.setBytes(1, key);
+ lookupStatement.setLong(2, lookupTs);
+
+ ResultSet res = lookupStatement.executeQuery();
+
+ if (res.next()) {
+ return res.getBytes(1);
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTs,
+ long recoveryTs) throws SQLException {
+ validateStateId(stateId);
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate("DELETE FROM " + stateId
+ + " WHERE timestamp > " + checkpointTs
+ + " AND timestamp < " + recoveryTs);
+ }
+ }
+
+ @Override
+ public void compactKvStates(String stateId, Connection con, long lowerId, long upperId)
+ throws SQLException {
+ validateStateId(stateId);
+
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate("DELETE state.* FROM " + stateId + " AS state"
+ + " JOIN"
+ + " ("
+ + " SELECT MAX(timestamp) AS maxts, k FROM " + stateId
+ + " WHERE timestamp BETWEEN " + lowerId + " AND " + upperId
+ + " GROUP BY k"
+ + " ) m"
+ + " ON state.k = m.k"
+ + " AND state.timestamp >= " + lowerId);
+ }
+ }
+
+ /**
+ * Tries to avoid SQL injection with weird state names.
+ *
+ */
+ protected static void validateStateId(String name) {
+ if (!name.matches("[a-zA-Z0-9_]+")) {
+ throw new RuntimeException("State name contains invalid characters.");
+ }
+ }
+
+ @Override
+ public void insertBatch(final String stateId, final DbBackendConfig conf,
+ final Connection con, final PreparedStatement insertStatement, final long checkpointTs,
+ final List> toInsert) throws IOException {
+
+ SQLRetrier.retry(new Callable() {
+ public Void call() throws Exception {
+ for (Tuple2 kv : toInsert) {
+ setKvInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1);
+ insertStatement.addBatch();
+ }
+ insertStatement.executeBatch();
+ insertStatement.clearBatch();
+ return null;
+ }
+ }, new Callable() {
+ public Void call() throws Exception {
+ insertStatement.clearBatch();
+ return null;
+ }
+ }, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries());
+ }
+
+ private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointTs,
+ byte[] key, byte[] value) throws SQLException {
+ insertStatement.setLong(1, checkpointTs);
+ insertStatement.setBytes(2, key);
+ if (value != null) {
+ insertStatement.setBytes(3, value);
+ insertStatement.setBytes(4, value);
+ } else {
+ insertStatement.setNull(3, Types.BLOB);
+ insertStatement.setNull(4, Types.BLOB);
+ }
+ }
+
+ @Override
+ public void keepAlive(Connection con) throws SQLException {
+ try(Statement smt = con.createStatement()) {
+ smt.executeQuery("SELECT 1");
+ }
+ }
+
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java
new file mode 100644
index 0000000000000..4ae3fd23ef8c3
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.sql.SQLException;
+import java.util.Random;
+import java.util.concurrent.Callable;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Simple utility to retry failed SQL commands for a predefined number of times
+ * before declaring failure. The retrier waits (randomly) between 2 retries.
+ *
+ */
+public final class SQLRetrier implements Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private static final Logger LOG = LoggerFactory.getLogger(SQLRetrier.class);
+ private static final Random rnd = new Random();
+
+ private static final int SLEEP_TIME = 10;
+
+ private SQLRetrier() {
+
+ }
+
+ /**
+ * Tries to run the given {@link Callable} the predefined number of times
+ * before throwing an {@link IOException}. This method will only retries
+ * calls ending in {@link SQLException}. Other exceptions will cause a
+ * {@link RuntimeException}.
+ *
+ * @param callable
+ * The callable to be retried.
+ * @param numRetries
+ * Max number of retries before throwing an {@link IOException}.
+ * @throws IOException
+ * The wrapped {@link SQLException}.
+ */
+ public static X retry(Callable callable, int numRetries) throws IOException {
+ return retry(callable, numRetries, SLEEP_TIME);
+ }
+
+ /**
+ * Tries to run the given {@link Callable} the predefined number of times
+ * before throwing an {@link IOException}. This method will only retries
+ * calls ending in {@link SQLException}. Other exceptions will cause a
+ * {@link RuntimeException}.
+ *
+ * @param callable
+ * The callable to be retried.
+ * @param numRetries
+ * Max number of retries before throwing an {@link IOException}.
+ * @param sleep
+ * The retrier will wait a random number of msecs between 1 and
+ * sleep.
+ * @return The result of the {@link Callable#call()}.
+ * @throws IOException
+ * The wrapped {@link SQLException}.
+ */
+ public static X retry(Callable callable, int numRetries, int sleep) throws IOException {
+ int numtries = 0;
+ while (true) {
+ try {
+ return callable.call();
+ } catch (SQLException e) {
+ handleSQLException(e, ++numtries, numRetries, sleep);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ /**
+ * Tries to run the given {@link Callable} the predefined number of times
+ * before throwing an {@link IOException}. This method will only retries
+ * calls ending in {@link SQLException}. Other exceptions will cause a
+ * {@link RuntimeException}.
+ *
+ * Additionally the user can supply a second callable which will be executed
+ * every time the first callable throws a {@link SQLException}.
+ *
+ * @param callable
+ * The callable to be retried.
+ * @param onException
+ * The callable to be executed when an {@link SQLException} was
+ * encountered. Exceptions thrown during this call are ignored.
+ * @param numRetries
+ * Max number of retries before throwing an {@link IOException}.
+ * @param sleep
+ * The retrier will wait a random number of msecs between 1 and
+ * sleep.
+ * @return The result of the {@link Callable#call()}.
+ * @throws IOException
+ * The wrapped {@link SQLException}.
+ */
+ public static X retry(Callable callable, Callable onException, int numRetries, int sleep)
+ throws IOException {
+ int numtries = 0;
+ while (true) {
+ try {
+ return callable.call();
+ } catch (SQLException se) {
+ try {
+ onException.call();
+ } catch (Exception e) {
+ // Exceptions thrown in this call will be ignored
+ }
+ handleSQLException(se, ++numtries, numRetries, sleep);
+ } catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+ }
+
+ /**
+ * Tries to run the given {@link Callable} the predefined number of times
+ * before throwing an {@link IOException}. This method will only retries
+ * calls ending in {@link SQLException}. Other exceptions will cause a
+ * {@link RuntimeException}.
+ *
+ * Additionally the user can supply a second callable which will be executed
+ * every time the first callable throws a {@link SQLException}.
+ *
+ * @param callable
+ * The callable to be retried.
+ * @param onException
+ * The callable to be executed when an {@link SQLException} was
+ * encountered. Exceptions thrown during this call are ignored.
+ * @param numRetries
+ * Max number of retries before throwing an {@link IOException}.
+ * @return The result of the {@link Callable#call()}.
+ * @throws IOException
+ * The wrapped {@link SQLException}.
+ */
+ public static X retry(Callable callable, Callable onException, int numRetries)
+ throws IOException {
+ return retry(callable, onException, numRetries, SLEEP_TIME);
+ }
+
+ private static void handleSQLException(SQLException e, int numTries, int maxRetries, int sleep) throws IOException {
+ if (numTries < maxRetries) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Error while executing SQL statement: {}\nRetrying...",
+ e.getMessage());
+ }
+ try {
+ Thread.sleep(numTries * rnd.nextInt(sleep));
+ } catch (InterruptedException ie) {
+ throw new RuntimeException("Thread has been interrupted.");
+ }
+ } else {
+ throw new IOException(
+ "Could not execute SQL statement after " + maxRetries + " retries.", e);
+ }
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java
new file mode 100644
index 0000000000000..44995de7d2b1a
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import java.io.Serializable;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import java.util.List;
+
+/**
+ * Helper class to maintain a sharded database connection and get
+ * {@link Connection}s and {@link PreparedStatement}s for keys.
+ *
+ */
+public class ShardedConnection implements AutoCloseable, Serializable {
+
+ private static final long serialVersionUID = 1L;
+ private final Connection[] connections;
+ private final int numShards;
+
+ private final Partitioner partitioner;
+
+ public ShardedConnection(List shardUrls, String user, String password, Partitioner partitioner)
+ throws SQLException {
+ numShards = shardUrls.size();
+ connections = new Connection[numShards];
+ for (int i = 0; i < numShards; i++) {
+ connections[i] = DriverManager.getConnection(shardUrls.get(i), user, password);
+ }
+ this.partitioner = partitioner;
+ }
+
+ public ShardedConnection(List shardUrls, String user, String password) throws SQLException {
+ this(shardUrls, user, password, new ModHashPartitioner());
+ }
+
+ public ShardedStatement prepareStatement(String sql) throws SQLException {
+ return new ShardedStatement(sql);
+ }
+
+ public Connection[] connections() {
+ return connections;
+ }
+
+ public Connection getForKey(Object key) {
+ return connections[getShardIndex(key)];
+ }
+
+ public Connection getForIndex(int index) {
+ if (index < numShards) {
+ return connections[index];
+ } else {
+ throw new RuntimeException("Index out of range");
+ }
+ }
+
+ public Connection getFirst() {
+ return connections[0];
+ }
+
+ public int getNumShards() {
+ return numShards;
+ }
+
+ @Override
+ public void close() throws SQLException {
+ if (connections != null) {
+ for (Connection c : connections) {
+ c.close();
+ }
+ }
+ }
+
+ public int getShardIndex(Object key) {
+ return partitioner.getShardIndex(key, numShards);
+ }
+
+ public void setTransactionIsolation(int level) throws SQLException {
+ for (Connection con : connections) {
+ con.setTransactionIsolation(level);
+ }
+ }
+
+ public class ShardedStatement implements AutoCloseable, Serializable {
+
+ private static final long serialVersionUID = 1L;
+ private final PreparedStatement[] statements = new PreparedStatement[numShards];
+
+ public ShardedStatement(final String sql) throws SQLException {
+ for (int i = 0; i < numShards; i++) {
+ statements[i] = connections[i].prepareStatement(sql);
+ }
+ }
+
+ public PreparedStatement getForKey(Object key) {
+ return statements[getShardIndex(key)];
+ }
+
+ public PreparedStatement getForIndex(int index) {
+ if (index < numShards) {
+ return statements[index];
+ } else {
+ throw new RuntimeException("Index out of range");
+ }
+ }
+
+ public PreparedStatement getFirst() {
+ return statements[0];
+ }
+
+ @Override
+ public void close() throws SQLException {
+ if (statements != null) {
+ for (PreparedStatement t : statements) {
+ t.close();
+ }
+ }
+ }
+
+ }
+
+ public interface Partitioner extends Serializable {
+ int getShardIndex(Object key, int numShards);
+ }
+
+ public static class ModHashPartitioner implements Partitioner {
+
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public int getShardIndex(Object key, int numShards) {
+ return Math.abs(key.hashCode() % numShards);
+ }
+
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java
new file mode 100644
index 0000000000000..337960ff71c30
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.derby.drda.NetworkServerControl;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.OperatorState;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.test.checkpointing.PartitionedStateCheckpointingITCase.IdentityKeySelector;
+import org.apache.flink.test.checkpointing.PartitionedStateCheckpointingITCase.NonSerializableLong;
+import org.apache.flink.test.checkpointing.StreamFaultToleranceTestBase;
+import org.junit.After;
+import org.junit.Before;
+
+@SuppressWarnings("serial")
+public class DBStateCheckpointingTest extends StreamFaultToleranceTestBase {
+
+ final long NUM_STRINGS = 1_000_000L;
+ final static int NUM_KEYS = 100;
+ private static NetworkServerControl server;
+ private static File tempDir;
+
+ @Before
+ public void startDerbyServer() throws UnknownHostException, Exception {
+ server = new NetworkServerControl(InetAddress.getByName("localhost"), 1526, "flink", "flink");
+ server.start(null);
+ tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
+ }
+
+ @After
+ public void stopDerbyServer() {
+ try {
+ server.shutdown();
+ FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB1"));
+ FileUtils.forceDelete(new File("derby.log"));
+ } catch (Exception ignore) {
+ }
+ }
+
+ @Override
+ public void testProgram(StreamExecutionEnvironment env) {
+ env.enableCheckpointing(500);
+
+ DbBackendConfig conf = new DbBackendConfig("flink", "flink",
+ "jdbc:derby://localhost:1526/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true");
+ conf.setDbAdapter(new DerbyAdapter());
+ conf.setKvStateCompactionFrequency(2);
+
+ // We store the non-partitioned states (source offset) in-memory
+ DbStateBackend backend = new DbStateBackend(conf, new MemoryStateBackend());
+
+ env.setStateBackend(backend);
+
+ DataStream stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
+ DataStream stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
+
+ stream1.union(stream2).keyBy(new IdentityKeySelector()).map(new OnceFailingPartitionedSum(NUM_STRINGS))
+ .keyBy(0).addSink(new CounterSink());
+ }
+
+ @Override
+ public void postSubmit() {
+ // verify that we counted exactly right
+ for (Entry sum : OnceFailingPartitionedSum.allSums.entrySet()) {
+ assertEquals(new Long(sum.getKey() * NUM_STRINGS / NUM_KEYS), sum.getValue());
+ }
+ for (Long count : CounterSink.allCounts.values()) {
+ assertEquals(new Long(NUM_STRINGS / NUM_KEYS), count);
+ }
+
+ assertEquals(NUM_KEYS, CounterSink.allCounts.size());
+ assertEquals(NUM_KEYS, OnceFailingPartitionedSum.allSums.size());
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Custom Functions
+ // --------------------------------------------------------------------------------------------
+
+ private static class IntGeneratingSourceFunction extends RichParallelSourceFunction
+ implements Checkpointed {
+
+ private final long numElements;
+
+ private int index;
+ private int step;
+
+ private Random rnd = new Random();
+
+ private volatile boolean isRunning = true;
+
+ static final long[] counts = new long[PARALLELISM];
+
+ @Override
+ public void close() throws IOException {
+ counts[getRuntimeContext().getIndexOfThisSubtask()] = index;
+ }
+
+ IntGeneratingSourceFunction(long numElements) {
+ this.numElements = numElements;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws IOException {
+ step = getRuntimeContext().getNumberOfParallelSubtasks();
+ if (index == 0) {
+ index = getRuntimeContext().getIndexOfThisSubtask();
+ }
+ }
+
+ @Override
+ public void run(SourceContext ctx) throws Exception {
+ final Object lockingObject = ctx.getCheckpointLock();
+
+ while (isRunning && index < numElements) {
+
+ synchronized (lockingObject) {
+ index += step;
+ ctx.collect(index % NUM_KEYS);
+ }
+
+ if (rnd.nextDouble() < 0.008) {
+ Thread.sleep(1);
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ isRunning = false;
+ }
+
+ @Override
+ public Integer snapshotState(long checkpointId, long checkpointTimestamp) {
+ return index;
+ }
+
+ @Override
+ public void restoreState(Integer state) {
+ index = state;
+ }
+ }
+
+ private static class OnceFailingPartitionedSum extends RichMapFunction> {
+
+ private static Map allSums = new ConcurrentHashMap();
+
+ private static volatile boolean hasFailed = false;
+
+ private final long numElements;
+
+ private long failurePos;
+ private long count;
+
+ private OperatorState sum;
+
+ OnceFailingPartitionedSum(long numElements) {
+ this.numElements = numElements;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws IOException {
+ long failurePosMin = (long) (0.6 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
+ long failurePosMax = (long) (0.8 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
+
+ failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
+ count = 0;
+ sum = getRuntimeContext().getKeyValueState("my_state", Long.class, 0L);
+ }
+
+ @Override
+ public Tuple2 map(Integer value) throws Exception {
+ count++;
+ if (!hasFailed && count >= failurePos) {
+ hasFailed = true;
+ throw new Exception("Test Failure");
+ }
+
+ long currentSum = sum.value() + value;
+ sum.update(currentSum);
+ allSums.put(value, currentSum);
+ return new Tuple2(value, currentSum);
+ }
+ }
+
+ private static class CounterSink extends RichSinkFunction> {
+
+ private static Map allCounts = new ConcurrentHashMap();
+
+ private OperatorState aCounts;
+ private OperatorState bCounts;
+
+ @Override
+ public void open(Configuration parameters) throws IOException {
+ aCounts = getRuntimeContext().getKeyValueState("a", NonSerializableLong.class, NonSerializableLong.of(0L));
+ bCounts = getRuntimeContext().getKeyValueState("b", Long.class, 0L);
+ }
+
+ @Override
+ public void invoke(Tuple2 value) throws Exception {
+ long ac = aCounts.value().value;
+ long bc = bCounts.value();
+ assertEquals(ac, bc);
+
+ long currentCount = ac + 1;
+ aCounts.update(NonSerializableLong.of(currentCount));
+ bCounts.update(currentCount);
+
+ allCounts.put(value.f0, currentCount);
+ }
+ }
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java
new file mode 100644
index 0000000000000..209086f8d94a0
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java
@@ -0,0 +1,478 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.derby.drda.NetworkServerControl;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.core.testutils.CommonTestUtils;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.KvState;
+import org.apache.flink.runtime.state.KvStateSnapshot;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+import org.apache.flink.shaded.com.google.common.collect.Lists;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import com.google.common.base.Optional;
+
+public class DbStateBackendTest {
+
+ private static NetworkServerControl server;
+ private static File tempDir;
+ private static DbBackendConfig conf;
+ private static String url1;
+ private static String url2;
+
+ @BeforeClass
+ public static void startDerbyServer() throws UnknownHostException, Exception {
+ server = new NetworkServerControl(InetAddress.getByName("localhost"), 1527, "flink", "flink");
+ server.start(null);
+ tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
+ conf = new DbBackendConfig("flink", "flink",
+ "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true");
+ conf.setDbAdapter(new DerbyAdapter());
+ conf.setKvStateCompactionFrequency(1);
+
+ url1 = "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true";
+ url2 = "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB2;create=true";
+ }
+
+ @AfterClass
+ public static void stopDerbyServer() throws Exception {
+ try {
+ server.shutdown();
+ FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB1"));
+ FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB2"));
+ FileUtils.forceDelete(new File("derby.log"));
+ } catch (Exception ignore) {
+ }
+ }
+
+ @Test
+ public void testSetupAndSerialization() throws Exception {
+ DbStateBackend dbBackend = new DbStateBackend(conf);
+
+ assertFalse(dbBackend.isInitialized());
+
+ // serialize / copy the backend
+ DbStateBackend backend = CommonTestUtils.createCopySerializable(dbBackend);
+ assertFalse(backend.isInitialized());
+
+ Environment env = new DummyEnvironment("test", 1, 0);
+ backend.initializeForJob(env);
+
+ assertNotNull(backend.getConnections());
+ assertTrue(
+ isTableCreated(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString()));
+
+ backend.disposeAllStateForCurrentJob();
+ assertFalse(
+ isTableCreated(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString()));
+ backend.close();
+
+ assertTrue(backend.getConnections().getFirst().isClosed());
+ }
+
+ @Test
+ public void testSerializableState() throws Exception {
+ Environment env = new DummyEnvironment("test", 1, 0);
+ DbStateBackend backend = new DbStateBackend(conf);
+
+ backend.initializeForJob(env);
+
+ String state1 = "dummy state";
+ String state2 = "row row row your boat";
+ Integer state3 = 42;
+
+ StateHandle handle1 = backend.checkpointStateSerializable(state1, 439568923746L,
+ System.currentTimeMillis());
+ StateHandle handle2 = backend.checkpointStateSerializable(state2, 439568923746L,
+ System.currentTimeMillis());
+ StateHandle handle3 = backend.checkpointStateSerializable(state3, 439568923746L,
+ System.currentTimeMillis());
+
+ assertEquals(state1, handle1.getState(getClass().getClassLoader()));
+ handle1.discardState();
+
+ assertEquals(state2, handle2.getState(getClass().getClassLoader()));
+ handle2.discardState();
+
+ assertFalse(isTableEmpty(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString()));
+
+ assertEquals(state3, handle3.getState(getClass().getClassLoader()));
+ handle3.discardState();
+
+ assertTrue(isTableEmpty(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString()));
+
+ backend.close();
+
+ }
+
+ @Test
+ public void testKeyValueState() throws Exception {
+
+ // We will create the DbStateBackend backed with a FsStateBackend for
+ // nonPartitioned states
+ File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
+ try {
+ FsStateBackend fileBackend = new FsStateBackend(localFileUri(tempDir));
+
+ DbStateBackend backend = new DbStateBackend(conf, fileBackend);
+
+ Environment env = new DummyEnvironment("test", 2, 0);
+
+ backend.initializeForJob(env);
+
+ LazyDbKvState kv = backend.createKvState("state1_1", "state1", IntSerializer.INSTANCE,
+ StringSerializer.INSTANCE, null);
+
+ String tableName = "state1_1_" + env.getJobID().toShortString();
+ assertTrue(isTableCreated(backend.getConnections().getFirst(), tableName));
+
+ assertEquals(0, kv.size());
+
+ // some modifications to the state
+ kv.setCurrentKey(1);
+ assertNull(kv.value());
+ kv.update("1");
+ assertEquals(1, kv.size());
+ kv.setCurrentKey(2);
+ assertNull(kv.value());
+ kv.update("2");
+ assertEquals(2, kv.size());
+ kv.setCurrentKey(1);
+ assertEquals("1", kv.value());
+ assertEquals(2, kv.size());
+
+ kv.snapshot(682375462378L, 100);
+
+ // make some more modifications
+ kv.setCurrentKey(1);
+ kv.update("u1");
+ kv.setCurrentKey(2);
+ kv.update("u2");
+ kv.setCurrentKey(3);
+ kv.update("u3");
+
+ // draw another snapshot
+ KvStateSnapshot snapshot2 = kv.snapshot(682375462379L,
+ 200);
+
+ // validate the original state
+ assertEquals(3, kv.size());
+ kv.setCurrentKey(1);
+ assertEquals("u1", kv.value());
+ kv.setCurrentKey(2);
+ assertEquals("u2", kv.value());
+ kv.setCurrentKey(3);
+ assertEquals("u3", kv.value());
+
+ // restore the first snapshot and validate it
+ KvState restored2 = snapshot2.restoreState(backend, IntSerializer.INSTANCE,
+ StringSerializer.INSTANCE, null, getClass().getClassLoader(), 6823754623710L);
+
+ assertEquals(0, restored2.size());
+ restored2.setCurrentKey(1);
+ assertEquals("u1", restored2.value());
+ restored2.setCurrentKey(2);
+ assertEquals("u2", restored2.value());
+ restored2.setCurrentKey(3);
+ assertEquals("u3", restored2.value());
+
+ backend.close();
+ } finally {
+ deleteDirectorySilently(tempDir);
+ }
+ }
+
+ @Test
+ public void testCompaction() throws Exception {
+ DbBackendConfig conf = new DbBackendConfig("flink", "flink", url1);
+ MockAdapter adapter = new MockAdapter();
+ conf.setKvStateCompactionFrequency(2);
+ conf.setDbAdapter(adapter);
+
+ DbStateBackend backend1 = new DbStateBackend(conf);
+ DbStateBackend backend2 = new DbStateBackend(conf);
+ DbStateBackend backend3 = new DbStateBackend(conf);
+
+ backend1.initializeForJob(new DummyEnvironment("test", 3, 0));
+ backend2.initializeForJob(new DummyEnvironment("test", 3, 1));
+ backend3.initializeForJob(new DummyEnvironment("test", 3, 2));
+
+ LazyDbKvState, ?> s1 = backend1.createKvState("a_1", "a", null, null, null);
+ LazyDbKvState, ?> s2 = backend2.createKvState("a_1", "a", null, null, null);
+ LazyDbKvState, ?> s3 = backend3.createKvState("a_1", "a", null, null, null);
+
+ assertTrue(s1.isCompactor());
+ assertFalse(s2.isCompactor());
+ assertFalse(s3.isCompactor());
+ assertNotNull(s1.getExecutor());
+ assertNull(s2.getExecutor());
+ assertNull(s3.getExecutor());
+
+ s1.snapshot(1, 100);
+ s1.notifyCheckpointComplete(1);
+ s1.snapshot(2, 200);
+ s1.snapshot(3, 300);
+ s1.notifyCheckpointComplete(2);
+ s1.notifyCheckpointComplete(3);
+ s1.snapshot(4, 400);
+ s1.snapshot(5, 500);
+ s1.notifyCheckpointComplete(4);
+ s1.notifyCheckpointComplete(5);
+
+ s1.dispose();
+ s2.dispose();
+ s3.dispose();
+
+ // Wait until the compaction completes
+ s1.getExecutor().awaitTermination(5, TimeUnit.SECONDS);
+ assertEquals(2, adapter.numCompcations.get());
+ assertEquals(5, adapter.keptAlive);
+
+ backend1.close();
+ backend2.close();
+ backend3.close();
+ }
+
+ @Test
+ public void testCaching() throws Exception {
+
+ List urls = Lists.newArrayList(url1, url2);
+ DbBackendConfig conf = new DbBackendConfig("flink", "flink",
+ urls);
+
+ conf.setDbAdapter(new DerbyAdapter());
+ conf.setKvCacheSize(3);
+ conf.setMaxKvInsertBatchSize(2);
+
+ // We evict 2 elements when the cache is full
+ conf.setMaxKvCacheEvictFraction(0.6f);
+
+ DbStateBackend backend = new DbStateBackend(conf);
+
+ Environment env = new DummyEnvironment("test", 2, 0);
+
+ String tableName = "state1_1_" + env.getJobID().toShortString();
+ assertFalse(isTableCreated(DriverManager.getConnection(url1, "flink", "flink"), tableName));
+ assertFalse(isTableCreated(DriverManager.getConnection(url2, "flink", "flink"), tableName));
+
+ backend.initializeForJob(env);
+
+ LazyDbKvState kv = backend.createKvState("state1_1", "state1", IntSerializer.INSTANCE,
+ StringSerializer.INSTANCE, "a");
+
+ assertTrue(isTableCreated(DriverManager.getConnection(url1, "flink", "flink"), tableName));
+ assertTrue(isTableCreated(DriverManager.getConnection(url2, "flink", "flink"), tableName));
+
+ Map> cache = kv.getStateCache();
+ Map> modified = kv.getModified();
+
+ assertEquals(0, kv.size());
+
+ // some modifications to the state
+ kv.setCurrentKey(1);
+ assertEquals("a", kv.value());
+
+ kv.update(null);
+ assertEquals(1, kv.size());
+ kv.setCurrentKey(2);
+ assertEquals("a", kv.value());
+ kv.update("2");
+ assertEquals(2, kv.size());
+ assertEquals("2", kv.value());
+
+ kv.setCurrentKey(1);
+ assertEquals("a", kv.value());
+
+ kv.setCurrentKey(3);
+ kv.update("3");
+ assertEquals("3", kv.value());
+
+ assertTrue(modified.containsKey(1));
+ assertTrue(modified.containsKey(2));
+ assertTrue(modified.containsKey(3));
+
+ // 1,2 should be evicted as the cache filled
+ kv.setCurrentKey(4);
+ kv.update("4");
+ assertEquals("4", kv.value());
+
+ assertFalse(modified.containsKey(1));
+ assertFalse(modified.containsKey(2));
+ assertTrue(modified.containsKey(3));
+ assertTrue(modified.containsKey(4));
+
+ assertEquals(Optional.of("3"), cache.get(3));
+ assertEquals(Optional.of("4"), cache.get(4));
+ assertFalse(cache.containsKey(1));
+ assertFalse(cache.containsKey(2));
+
+ // draw a snapshot
+ kv.snapshot(682375462378L, 100);
+
+ assertTrue(modified.isEmpty());
+
+ // make some more modifications
+ kv.setCurrentKey(2);
+ assertEquals("2", kv.value());
+ kv.update(null);
+ assertEquals("a", kv.value());
+
+ assertTrue(modified.containsKey(2));
+ assertEquals(1, modified.size());
+
+ assertEquals(Optional.of("3"), cache.get(3));
+ assertEquals(Optional.of("4"), cache.get(4));
+ assertEquals(Optional.absent(), cache.get(2));
+ assertFalse(cache.containsKey(1));
+
+ assertTrue(modified.containsKey(2));
+ assertFalse(modified.containsKey(3));
+ assertFalse(modified.containsKey(4));
+ assertTrue(cache.containsKey(3));
+ assertTrue(cache.containsKey(4));
+
+ // clear cache from initial keys
+
+ kv.setCurrentKey(5);
+ kv.value();
+ kv.setCurrentKey(6);
+ kv.value();
+ kv.setCurrentKey(7);
+ kv.value();
+
+ assertFalse(modified.containsKey(5));
+ assertTrue(modified.containsKey(6));
+ assertTrue(modified.containsKey(7));
+
+ assertFalse(cache.containsKey(1));
+ assertFalse(cache.containsKey(2));
+ assertFalse(cache.containsKey(3));
+ assertFalse(cache.containsKey(4));
+
+ kv.setCurrentKey(2);
+ assertEquals("a", kv.value());
+
+ long checkpointTs = System.currentTimeMillis();
+
+ // Draw a snapshot that we will restore later
+ KvStateSnapshot snapshot1 = kv.snapshot(682375462379L, checkpointTs);
+ assertTrue(modified.isEmpty());
+
+ // Do some updates then draw another snapshot (imitate a partial
+ // failure), these updates should not be visible if we restore snapshot1
+ kv.setCurrentKey(1);
+ kv.update("123");
+ kv.setCurrentKey(3);
+ kv.update("456");
+ kv.setCurrentKey(2);
+ kv.notifyCheckpointComplete(682375462379L);
+ kv.update("2");
+ kv.setCurrentKey(4);
+ kv.update("4");
+ kv.update("5");
+
+ kv.snapshot(6823754623710L, checkpointTs + 10);
+
+ // restore the second snapshot and validate it (we set a new default
+ // value here to make sure that the default wasn't written)
+ KvState restored = snapshot1.restoreState(backend, IntSerializer.INSTANCE,
+ StringSerializer.INSTANCE, "b", getClass().getClassLoader(), 6823754623711L);
+
+ restored.setCurrentKey(1);
+ assertEquals("b", restored.value());
+ restored.setCurrentKey(2);
+ assertEquals("b", restored.value());
+ restored.setCurrentKey(3);
+ assertEquals("3", restored.value());
+ restored.setCurrentKey(4);
+ assertEquals("4", restored.value());
+
+ backend.close();
+ }
+
+ private static boolean isTableCreated(Connection con, String tableName) throws SQLException {
+ return con.getMetaData().getTables(null, null, tableName.toUpperCase(), null).next();
+ }
+
+ private static boolean isTableEmpty(Connection con, String tableName) throws SQLException {
+ try (Statement smt = con.createStatement()) {
+ ResultSet res = smt.executeQuery("select * from " + tableName);
+ return !res.next();
+ }
+ }
+
+ private static String localFileUri(File path) {
+ return path.toURI().toString();
+ }
+
+ private static void deleteDirectorySilently(File dir) {
+ try {
+ FileUtils.deleteDirectory(dir);
+ } catch (IOException ignored) {
+ }
+ }
+
+ private static class MockAdapter extends DerbyAdapter {
+
+ private static final long serialVersionUID = 1L;
+ public AtomicInteger numCompcations = new AtomicInteger(0);
+ public int keptAlive = 0;
+
+ @Override
+ public void compactKvStates(String kvStateId, Connection con, long lowerTs, long upperTs) throws SQLException {
+ numCompcations.incrementAndGet();
+ }
+
+ @Override
+ public void keepAlive(Connection con) throws SQLException {
+ keptAlive++;
+ }
+ }
+
+}
diff --git a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java
new file mode 100644
index 0000000000000..1f13f4bdc397f
--- /dev/null
+++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.sql.Types;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/**
+ * Adapter for the Derby JDBC driver which has slightly restricted CREATE TABLE
+ * and SELECT semantics compared to the default assumptions.
+ *
+ */
+public class DerbyAdapter extends MySqlAdapter {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * We need to override this method as Derby does not support the
+ * "IF NOT EXISTS" clause at table creation
+ */
+ @Override
+ public void createCheckpointsTable(String jobId, Connection con) throws SQLException {
+
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "CREATE TABLE checkpoints_" + jobId
+ + " ("
+ + "checkpointId bigint, "
+ + "timestamp bigint, "
+ + "handleId bigint,"
+ + "checkpoint blob,"
+ + "PRIMARY KEY (handleId)"
+ + ")");
+ } catch (SQLException se) {
+ if (se.getSQLState().equals("X0Y32")) {
+ // table already created, ignore
+ } else {
+ throw se;
+ }
+ }
+ }
+
+ /**
+ * We need to override this method as Derby does not support the
+ * "IF NOT EXISTS" clause at table creation
+ */
+ @Override
+ public void createKVStateTable(String stateId, Connection con) throws SQLException {
+
+ validateStateId(stateId);
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate(
+ "CREATE TABLE " + stateId
+ + " ("
+ + "timestamp bigint, "
+ + "k varchar(256) for bit data, "
+ + "v blob, "
+ + "PRIMARY KEY (k, timestamp)"
+ + ")");
+ } catch (SQLException se) {
+ if (se.getSQLState().equals("X0Y32")) {
+ // table already created, ignore
+ } else {
+ throw se;
+ }
+ }
+ }
+
+ /**
+ * We need to override this method as Derby does not support "LIMIT n" for
+ * select statements.
+ */
+ @Override
+ public String prepareKeyLookup(String stateId) throws SQLException {
+ validateStateId(stateId);
+ return "SELECT v " + "FROM " + stateId
+ + " WHERE k = ? "
+ + " AND timestamp <= ?"
+ + " ORDER BY timestamp DESC";
+ }
+
+ @Override
+ public void compactKvStates(String stateId, Connection con, long lowerBound, long upperBound)
+ throws SQLException {
+ validateStateId(stateId);
+
+ try (Statement smt = con.createStatement()) {
+ smt.executeUpdate("DELETE FROM " + stateId + " t1"
+ + " WHERE EXISTS"
+ + " ("
+ + " SELECT * FROM " + stateId + " t2"
+ + " WHERE t2.k = t1.k"
+ + " AND t2.timestamp > t1.timestamp"
+ + " AND t2.timestamp <=" + upperBound
+ + " AND t2.timestamp >= " + lowerBound
+ + " )");
+ }
+ }
+
+ @Override
+ public String prepareKVCheckpointInsert(String stateId) throws SQLException {
+ validateStateId(stateId);
+ return "INSERT INTO " + stateId + " (timestamp, k, v) VALUES (?,?,?)";
+ }
+
+ @Override
+ public void insertBatch(final String stateId, final DbBackendConfig conf,
+ final Connection con, final PreparedStatement insertStatement, final long checkpointTs,
+ final List> toInsert) throws IOException {
+
+ SQLRetrier.retry(new Callable() {
+ public Void call() throws Exception {
+ con.setAutoCommit(false);
+ for (Tuple2 kv : toInsert) {
+ setKVInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1);
+ insertStatement.addBatch();
+ }
+ insertStatement.executeBatch();
+ con.commit();
+ con.setAutoCommit(true);
+ insertStatement.clearBatch();
+ return null;
+ }
+ }, new Callable() {
+ public Void call() throws Exception {
+ con.rollback();
+ insertStatement.clearBatch();
+ return null;
+ }
+ }, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries());
+ }
+
+ private void setKVInsertParams(String stateId, PreparedStatement insertStatement, long checkpointId,
+ byte[] key, byte[] value) throws SQLException {
+ insertStatement.setLong(1, checkpointId);
+ insertStatement.setBytes(2, key);
+ if (value != null) {
+ insertStatement.setBytes(3, value);
+ } else {
+ insertStatement.setNull(3, Types.BLOB);
+ }
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index fdb59d9a89a56..09dd2d97768cd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -489,14 +489,16 @@ public void restoreLatestCheckpointedState(
return;
}
}
+
+ long recoveryTimestamp = System.currentTimeMillis();
if (allOrNothingState) {
Map stateCounts = new HashMap();
-
+
for (StateForTask state : latest.getStates()) {
ExecutionJobVertex vertex = tasks.get(state.getOperatorId());
Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt();
- exec.setInitialState(state.getState());
+ exec.setInitialState(state.getState(), recoveryTimestamp);
Integer count = stateCounts.get(vertex);
if (count != null) {
@@ -519,7 +521,7 @@ public void restoreLatestCheckpointedState(
for (StateForTask state : latest.getStates()) {
ExecutionJobVertex vertex = tasks.get(state.getOperatorId());
Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt();
- exec.setInitialState(state.getState());
+ exec.setInitialState(state.getState(), recoveryTimestamp);
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java
index 34b7946fc04c7..82d8e7cde4c5f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java
@@ -39,5 +39,5 @@ public interface CheckpointIDCounter {
* @return The previous checkpoint ID
*/
long getAndIncrement() throws Exception;
-
+
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
index 558fcd039d71c..e6a1583088d26 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
@@ -83,7 +83,9 @@ public final class TaskDeploymentDescriptor implements Serializable {
private final List requiredClasspaths;
private final SerializedValue> operatorState;
-
+
+ private long recoveryTimestamp;
+
/**
* Constructs a task deployment descriptor.
*/
@@ -94,7 +96,7 @@ public TaskDeploymentDescriptor(
List producedPartitions,
List inputGates,
List requiredJarFiles, List requiredClasspaths,
- int targetSlotNumber, SerializedValue> operatorState) {
+ int targetSlotNumber, SerializedValue> operatorState, long recoveryTimestamp) {
checkArgument(indexInSubtaskGroup >= 0);
checkArgument(numberOfSubtasks > indexInSubtaskGroup);
@@ -115,6 +117,7 @@ public TaskDeploymentDescriptor(
this.requiredClasspaths = checkNotNull(requiredClasspaths);
this.targetSlotNumber = targetSlotNumber;
this.operatorState = operatorState;
+ this.recoveryTimestamp = recoveryTimestamp;
}
public TaskDeploymentDescriptor(
@@ -128,7 +131,7 @@ public TaskDeploymentDescriptor(
this(jobID, vertexID, executionId, taskName, indexInSubtaskGroup, numberOfSubtasks,
jobConfiguration, taskConfiguration, invokableClassName, producedPartitions,
- inputGates, requiredJarFiles, requiredClasspaths, targetSlotNumber, null);
+ inputGates, requiredJarFiles, requiredClasspaths, targetSlotNumber, null, -1);
}
/**
@@ -245,4 +248,8 @@ private String collectionToString(Collection> collection) {
public SerializedValue> getOperatorState() {
return operatorState;
}
+
+ public long getRecoveryTimestamp() {
+ return recoveryTimestamp;
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index faabfb3bdaa2a..ce17525fe7e23 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -135,6 +135,8 @@ public class Execution implements Serializable {
private volatile InstanceConnectionInfo assignedResourceLocation; // for the archived execution
private SerializedValue> operatorState;
+
+ private long recoveryTimestamp;
/** The execution context which is used to execute futures. */
@SuppressWarnings("NonSerializableFieldInSerializableClass")
@@ -231,11 +233,12 @@ public void prepareForArchiving() {
partialInputChannelDeploymentDescriptors = null;
}
- public void setInitialState(SerializedValue> initialState) {
+ public void setInitialState(SerializedValue> initialState, long recoveryTimestamp) {
if (state != ExecutionState.CREATED) {
throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED");
}
this.operatorState = initialState;
+ this.recoveryTimestamp = recoveryTimestamp;
}
// --------------------------------------------------------------------------------------------
@@ -359,7 +362,7 @@ public void deployToSlot(final SimpleSlot slot) throws JobException {
attemptNumber, slot.getInstance().getInstanceConnectionInfo().getHostname()));
}
- final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(attemptId, slot, operatorState);
+ final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(attemptId, slot, operatorState, recoveryTimestamp);
// register this execution at the execution graph, to receive call backs
vertex.getExecutionGraph().registerExecution(this);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 1e5d02cb95431..d10aac13a5453 100755
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -811,6 +811,7 @@ else if (current != JobStatus.RESTARTING) {
*
* The recovery of checkpoints might block. Make sure that calls to this method don't
* block the job manager actor and run asynchronously.
+ *
*/
public void restoreLatestCheckpointedState() throws Exception {
synchronized (progressLock) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 6a635287d912e..fba5652bbc427 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -616,7 +616,8 @@ void notifyStateTransition(ExecutionAttemptID executionId, ExecutionState newSta
TaskDeploymentDescriptor createDeploymentDescriptor(
ExecutionAttemptID executionId,
SimpleSlot targetSlot,
- SerializedValue> operatorState) {
+ SerializedValue> operatorState,
+ long recoveryTimestamp) {
// Produced intermediate results
List producedPartitions = new ArrayList(resultPartitions.size());
@@ -651,7 +652,7 @@ TaskDeploymentDescriptor createDeploymentDescriptor(
subTaskIndex, getTotalNumberOfParallelSubtasks(), getExecutionGraph().getJobConfiguration(),
jobVertex.getJobVertex().getConfiguration(), jobVertex.getJobVertex().getInvokableClassName(),
producedPartitions, consumedPartitions, jarFiles, classpaths, targetSlot.getRoot().getSlotNumber(),
- operatorState);
+ operatorState, recoveryTimestamp);
}
// --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
index 894e6d91fff4b..fac4ec471df5a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
@@ -31,8 +31,9 @@ public interface StatefulTask> {
* a snapshot of the state from a previous execution.
*
* @param stateHandle The handle to the state.
+ * @param recoveryTimestamp Global recovery timestamp.
*/
- void setInitialState(T stateHandle) throws Exception;
+ void setInitialState(T stateHandle, long recoveryTimestamp) throws Exception;
/**
* This method is either called directly and asynchronously by the checkpoint
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
index 3d6c56c806ab7..e2e521cb3269f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java
@@ -57,7 +57,8 @@ KvState restoreState(
TypeSerializer keySerializer,
TypeSerializer valueSerializer,
V defaultValue,
- ClassLoader classLoader) throws Exception;
+ ClassLoader classLoader,
+ long recoveryTimestamp) throws Exception;
/**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java
index f8b1cfdbed642..293de956a055c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java
@@ -18,12 +18,12 @@
package org.apache.flink.runtime.state;
-import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.execution.Environment;
import java.io.IOException;
import java.io.OutputStream;
@@ -31,32 +31,32 @@
/**
* A state backend defines how state is stored and snapshotted during checkpoints.
- *
+ *
* @param The type of backend itself. This generic parameter is used to refer to the
* type of backend when creating state backed by this backend.
*/
public abstract class StateBackend> implements java.io.Serializable {
-
+
private static final long serialVersionUID = 4620413814639220247L;
-
+
// ------------------------------------------------------------------------
// initialization and cleanup
// ------------------------------------------------------------------------
-
+
/**
* This method is called by the task upon deployment to initialize the state backend for
* data for a specific job.
- *
- * @param job The ID of the job for which the state backend instance checkpoints data.
+ *
+ * @param The {@link Environment} of the task that instantiated the state backend
* @throws Exception Overwritten versions of this method may throw exceptions, in which
* case the job that uses the state backend is considered failed during
* deployment.
*/
- public abstract void initializeForJob(JobID job) throws Exception;
+ public abstract void initializeForJob(Environment env) throws Exception;
/**
* Disposes all state associated with the current job.
- *
+ *
* @throws Exception Exceptions may occur during disposal of the state and should be forwarded.
*/
public abstract void disposeAllStateForCurrentJob() throws Exception;
@@ -64,33 +64,35 @@ public abstract class StateBackend> implem
/**
* Closes the state backend, releasing all internal resources, but does not delete any persistent
* checkpoint data.
- *
+ *
* @throws Exception Exceptions can be forwarded and will be logged by the system
*/
public abstract void close() throws Exception;
-
+
// ------------------------------------------------------------------------
// key/value state
// ------------------------------------------------------------------------
/**
* Creates a key/value state backed by this state backend.
- *
+ *
+ * @param stateId Unique id that identifies the kv state in the streaming program.
+ * @param stateName Name of the created state
* @param keySerializer The serializer for the key.
* @param valueSerializer The serializer for the value.
* @param defaultValue The value that is returned when no other value has been associated with a key, yet.
* @param The type of the key.
* @param The type of the value.
- *
+ *
* @return A new key/value state backed by this backend.
- *
+ *
* @throws Exception Exceptions may occur during initialization of the state and should be forwarded.
*/
- public abstract KvState createKvState(
+ public abstract KvState createKvState(String stateId, String stateName,
TypeSerializer keySerializer, TypeSerializer valueSerializer,
V defaultValue) throws Exception;
-
-
+
+
// ------------------------------------------------------------------------
// storing state for a checkpoint
// ------------------------------------------------------------------------
@@ -98,16 +100,16 @@ public abstract KvState createKvState(
/**
* Creates an output stream that writes into the state of the given checkpoint. When the stream
* is closes, it returns a state handle that can retrieve the state back.
- *
+ *
* @param checkpointID The ID of the checkpoint.
* @param timestamp The timestamp of the checkpoint.
* @return An output stream that writes state for the given checkpoint.
- *
+ *
* @throws Exception Exceptions may occur while creating the stream and should be forwarded.
*/
public abstract CheckpointStateOutputStream createCheckpointStateOutputStream(
long checkpointID, long timestamp) throws Exception;
-
+
/**
* Creates a {@link DataOutputView} stream that writes into the state of the given checkpoint.
* When the stream is closes, it returns a state handle that can retrieve the state back.
@@ -125,20 +127,20 @@ public CheckpointStateOutputView createCheckpointStateOutputView(
/**
* Writes the given state into the checkpoint, and returns a handle that can retrieve the state back.
- *
+ *
* @param state The state to be checkpointed.
* @param checkpointID The ID of the checkpoint.
* @param timestamp The timestamp of the checkpoint.
* @param The type of the state.
- *
+ *
* @return A state handle that can retrieve the checkpoined state.
- *
+ *
* @throws Exception Exceptions may occur during serialization / storing the state and should be forwarded.
*/
public abstract StateHandle checkpointStateSerializable(
S state, long checkpointID, long timestamp) throws Exception;
-
-
+
+
// ------------------------------------------------------------------------
// Checkpoint state output stream
// ------------------------------------------------------------------------
@@ -151,7 +153,7 @@ public static abstract class CheckpointStateOutputStream extends OutputStream {
/**
* Closes the stream and gets a state handle that can create an input stream
* producing the data written to this stream.
- *
+ *
* @return A state handle that can create an input stream producing the data written to this stream.
* @throws IOException Thrown, if the stream cannot be closed.
*/
@@ -162,9 +164,9 @@ public static abstract class CheckpointStateOutputStream extends OutputStream {
* A dedicated DataOutputView stream that produces a {@code StateHandle} when closed.
*/
public static final class CheckpointStateOutputView extends DataOutputViewStreamWrapper {
-
+
private final CheckpointStateOutputStream out;
-
+
public CheckpointStateOutputView(CheckpointStateOutputStream out) {
super(out);
this.out = out;
@@ -193,7 +195,7 @@ public void close() throws IOException {
private static final class DataInputViewHandle implements StateHandle {
private static final long serialVersionUID = 2891559813513532079L;
-
+
private final StreamStateHandle stream;
private DataInputViewHandle(StreamStateHandle stream) {
@@ -202,7 +204,7 @@ private DataInputViewHandle(StreamStateHandle stream) {
@Override
public DataInputView getState(ClassLoader userCodeClassLoader) throws Exception {
- return new DataInputViewStreamWrapper(stream.getState(userCodeClassLoader));
+ return new DataInputViewStreamWrapper(stream.getState(userCodeClassLoader));
}
@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
index 88b0d1899014b..96e0eb526aebb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
@@ -38,17 +38,19 @@ public class StateUtils {
* The state carrier operator.
* @param state
* The state handle.
+ * @param recoveryTimestamp
+ * Global recovery timestamp
* @param
* Type bound for the
*/
public static > void setOperatorState(StatefulTask> op,
- StateHandle> state) throws Exception {
+ StateHandle> state, long recoveryTimestamp) throws Exception {
@SuppressWarnings("unchecked")
StatefulTask typedOp = (StatefulTask) op;
@SuppressWarnings("unchecked")
T typedHandle = (T) state;
- typedOp.setInitialState(typedHandle);
+ typedOp.setInitialState(typedHandle, recoveryTimestamp);
}
// ------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java
index 781ee3dafd89d..c5c2fd7661987 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java
@@ -62,7 +62,8 @@ public FsHeapKvState restoreState(
final TypeSerializer keySerializer,
final TypeSerializer valueSerializer,
V defaultValue,
- ClassLoader classLoader) throws Exception {
+ ClassLoader classLoader,
+ long recoveryTimestamp) throws Exception {
// validity checks
if (!keySerializer.getClass().getName().equals(keySerializerClassName) ||
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index d7b392cda463e..25c63e5d2252a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -18,14 +18,13 @@
package org.apache.flink.runtime.state.filesystem;
-import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.state.StateBackend;
-
+import org.apache.flink.runtime.state.StateHandle;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -38,27 +37,27 @@
/**
* The file state backend is a state backend that stores the state of streaming jobs in a file system.
- *
+ *
* The state backend has one core directory into which it puts all checkpoint data. Inside that
* directory, it creates a directory per job, inside which each checkpoint gets a directory, with
* files for each state, for example:
- *
+ *
* {@code hdfs://namenode:port/flink-checkpoints//chk-17/6ba7b810-9dad-11d1-80b4-00c04fd430c8 }
*/
public class FsStateBackend extends StateBackend {
private static final long serialVersionUID = -8191916350224044011L;
-
+
private static final Logger LOG = LoggerFactory.getLogger(FsStateBackend.class);
-
-
+
+
/** The path to the directory for the checkpoint data, including the file system
* description via scheme and optional authority */
private final Path basePath;
-
+
/** The directory (job specific) into this initialized instance of the backend stores its data */
private transient Path checkpointDirectory;
-
+
/** Cached handle to the file system for file operations */
private transient FileSystem filesystem;
@@ -104,14 +103,14 @@ public FsStateBackend(Path checkpointDataUri) throws IOException {
/**
* Creates a new state backend that stores its checkpoint data in the file system and location
* defined by the given URI.
- *
+ *
* A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://')
* must be accessible via {@link FileSystem#get(URI)}.
- *
+ *
*
For a state backend targeting HDFS, this means that the URI must either specify the authority
* (host and port), or that the Hadoop configuration that describes that information must be in the
* classpath.
- *
+ *
* @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority),
* and the path to teh checkpoint data directory.
* @throws IOException Thrown, if no file system can be found for the scheme in the URI.
@@ -119,7 +118,7 @@ public FsStateBackend(Path checkpointDataUri) throws IOException {
public FsStateBackend(URI checkpointDataUri) throws IOException {
final String scheme = checkpointDataUri.getScheme();
final String path = checkpointDataUri.getPath();
-
+
// some validity checks
if (scheme == null) {
throw new IllegalArgumentException("The scheme (hdfs://, file://, etc) is null. " +
@@ -132,12 +131,12 @@ public FsStateBackend(URI checkpointDataUri) throws IOException {
if (path.length() == 0 || path.equals("/")) {
throw new IllegalArgumentException("Cannot use the root directory for checkpoints.");
}
-
+
// we do a bit of work to make sure that the URI for the filesystem refers to exactly the same
// (distributed) filesystem on all hosts and includes full host/port information, even if the
// original URI did not include that. We count on the filesystem loading from the configuration
// to fill in the missing data.
-
+
// try to grab the file system for this path/URI
this.filesystem = FileSystem.get(checkpointDataUri);
if (this.filesystem == null) {
@@ -151,7 +150,7 @@ public FsStateBackend(URI checkpointDataUri) throws IOException {
}
catch (URISyntaxException e) {
throw new IOException(
- String.format("Cannot create file system URI for checkpointDataUri %s and filesystem URI %s",
+ String.format("Cannot create file system URI for checkpointDataUri %s and filesystem URI %s",
checkpointDataUri, fsURI), e);
}
}
@@ -159,7 +158,7 @@ public FsStateBackend(URI checkpointDataUri) throws IOException {
/**
* Gets the base directory where all state-containing files are stored.
* The job specific directory is created inside this directory.
- *
+ *
* @return The base directory.
*/
public Path getBasePath() {
@@ -169,7 +168,7 @@ public Path getBasePath() {
/**
* Gets the directory where this state backend stores its checkpoint data. Will be null if
* the state backend has not been initialized.
- *
+ *
* @return The directory where this state backend stores its checkpoint data.
*/
public Path getCheckpointDirectory() {
@@ -179,16 +178,16 @@ public Path getCheckpointDirectory() {
/**
* Checks whether this state backend is initialized. Note that initialization does not carry
* across serialization. After each serialization, the state backend needs to be initialized.
- *
+ *
* @return True, if the file state backend has been initialized, false otherwise.
*/
public boolean isInitialized() {
- return filesystem != null && checkpointDirectory != null;
+ return filesystem != null && checkpointDirectory != null;
}
/**
* Gets the file system handle for the file system that stores the state for this backend.
- *
+ *
* @return This backend's file system handle.
*/
public FileSystem getFileSystem() {
@@ -203,13 +202,13 @@ public FileSystem getFileSystem() {
// ------------------------------------------------------------------------
// initialization and cleanup
// ------------------------------------------------------------------------
-
+
@Override
- public void initializeForJob(JobID jobId) throws Exception {
- Path dir = new Path(basePath, jobId.toString());
-
+ public void initializeForJob(Environment env) throws Exception {
+ Path dir = new Path(basePath, env.getJobID().toString());
+
LOG.info("Initializing file state backend to URI " + dir);
-
+
filesystem = basePath.getFileSystem();
filesystem.mkdirs(dir);
@@ -220,7 +219,7 @@ public void initializeForJob(JobID jobId) throws Exception {
public void disposeAllStateForCurrentJob() throws Exception {
FileSystem fs = this.filesystem;
Path dir = this.checkpointDirectory;
-
+
if (fs != null && dir != null) {
this.filesystem = null;
this.checkpointDirectory = null;
@@ -237,9 +236,9 @@ public void close() throws Exception {}
// ------------------------------------------------------------------------
// state backend operations
// ------------------------------------------------------------------------
-
+
@Override
- public FsHeapKvState createKvState(
+ public FsHeapKvState createKvState(String stateId, String stateName,
TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws Exception {
return new FsHeapKvState(keySerializer, valueSerializer, defaultValue, this);
}
@@ -254,7 +253,7 @@ public StateHandle checkpointStateSerializable(
final Path checkpointDir = createCheckpointDirPath(checkpointID);
filesystem.mkdirs(checkpointDir);
-
+
Exception latestException = null;
for (int attempt = 0; attempt < 10; attempt++) {
@@ -273,19 +272,19 @@ public StateHandle checkpointStateSerializable(
}
return new FileSerializableStateHandle(targetPath);
}
-
+
throw new Exception("Could not open output stream for state backend", latestException);
}
-
+
@Override
public FsCheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception {
checkFileSystemInitialized();
-
+
final Path checkpointDir = createCheckpointDirPath(checkpointID);
filesystem.mkdirs(checkpointDir);
-
+
Exception latestException = null;
-
+
for (int attempt = 0; attempt < 10; attempt++) {
Path targetPath = new Path(checkpointDir, UUID.randomUUID().toString());
try {
@@ -298,7 +297,7 @@ public FsCheckpointStateOutputStream createCheckpointStateOutputStream(long chec
}
throw new Exception("Could not open output stream for state backend", latestException);
}
-
+
// ------------------------------------------------------------------------
// utilities
// ------------------------------------------------------------------------
@@ -308,18 +307,18 @@ private void checkFileSystemInitialized() throws IllegalStateException {
throw new IllegalStateException("filesystem has not been re-initialized after deserialization");
}
}
-
+
private Path createCheckpointDirPath(long checkpointID) {
return new Path(checkpointDirectory, "chk-" + checkpointID);
}
-
+
@Override
public String toString() {
return checkpointDirectory == null ?
- "File State Backend @ " + basePath :
+ "File State Backend @ " + basePath :
"File State Backend (initialized) @ " + checkpointDirectory;
}
-
+
// ------------------------------------------------------------------------
// Output stream for state checkpointing
// ------------------------------------------------------------------------
@@ -331,11 +330,11 @@ public String toString() {
public static final class FsCheckpointStateOutputStream extends CheckpointStateOutputStream {
private final FSDataOutputStream outStream;
-
+
private final Path filePath;
-
+
private final FileSystem fs;
-
+
private boolean closed;
FsCheckpointStateOutputStream(FSDataOutputStream outStream, Path filePath, FileSystem fs) {
@@ -373,7 +372,7 @@ public void close() {
try {
outStream.close();
fs.delete(filePath, false);
-
+
// attempt to delete the parent (will fail and be ignored if the parent has more files)
try {
fs.delete(filePath.getParent(), false);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java
index 1b03defb00852..bda0290e7494b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryHeapKvStateSnapshot.java
@@ -70,7 +70,8 @@ public MemHeapKvState restoreState(
final TypeSerializer keySerializer,
final TypeSerializer valueSerializer,
V defaultValue,
- ClassLoader classLoader) throws Exception {
+ ClassLoader classLoader,
+ long recoveryTimestamp) throws Exception {
// validity checks
if (!keySerializer.getClass().getName().equals(keySerializerClassName) ||
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 8d297d4902d78..2963237738311 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -18,10 +18,10 @@
package org.apache.flink.runtime.state.memory;
-import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import java.io.ByteArrayOutputStream;
@@ -31,15 +31,15 @@
/**
* A {@link StateBackend} that stores all its data and checkpoints in memory and has no
* capabilities to spill to disk. Checkpoints are serialized and the serialized data is
- * transferred
+ * transferred
*/
public class MemoryStateBackend extends StateBackend {
private static final long serialVersionUID = 4109305377809414635L;
-
+
/** The default maximal size that the snapshotted memory state may have (5 MiBytes) */
private static final int DEFAULT_MAX_STATE_SIZE = 5 * 1024 * 1024;
-
+
/** The maximal size that the snapshotted memory state may have */
private final int maxStateSize;
@@ -54,7 +54,7 @@ public MemoryStateBackend() {
/**
* Creates a new memory state backend that accepts states whose serialized forms are
* up to the given number of bytes.
- *
+ *
* @param maxStateSize The maximal size of the serialized state
*/
public MemoryStateBackend(int maxStateSize) {
@@ -66,7 +66,7 @@ public MemoryStateBackend(int maxStateSize) {
// ------------------------------------------------------------------------
@Override
- public void initializeForJob(JobID job) {
+ public void initializeForJob(Environment env) {
// nothing to do here
}
@@ -81,22 +81,22 @@ public void close() throws Exception {}
// ------------------------------------------------------------------------
// State backend operations
// ------------------------------------------------------------------------
-
+
@Override
- public MemHeapKvState createKvState(
+ public MemHeapKvState createKvState(String stateId, String stateName,
TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) {
return new MemHeapKvState(keySerializer, valueSerializer, defaultValue);
}
-
+
/**
* Serialized the given state into bytes using Java serialization and creates a state handle that
* can re-create that state.
- *
+ *
* @param state The state to checkpoint.
* @param checkpointID The ID of the checkpoint.
* @param timestamp The timestamp of the checkpoint.
* @param The type of the state.
- *
+ *
* @return A state handle that contains the given state serialized as bytes.
* @throws Exception Thrown, if the serialization fails.
*/
@@ -119,7 +119,7 @@ public CheckpointStateOutputStream createCheckpointStateOutputStream(
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
-
+
@Override
public String toString() {
return "MemoryStateBackend (data in heap memory / checkpoints to JobManager)";
@@ -133,18 +133,18 @@ static void checkSize(int size, int maxSize) throws IOException {
+ " . Consider using a different state backend, like the File System State backend.");
}
}
-
+
// ------------------------------------------------------------------------
/**
* A CheckpointStateOutputStream that writes into a byte array.
*/
public static final class MemoryCheckpointOutputStream extends CheckpointStateOutputStream {
-
+
private final ByteArrayOutputStream os = new ByteArrayOutputStream();
-
+
private final int maxSize;
-
+
private boolean closed;
public MemoryCheckpointOutputStream(int maxSize) {
@@ -177,7 +177,7 @@ public StreamStateHandle closeAndGetHandle() throws IOException {
/**
* Closes the stream and returns the byte array containing the stream's data.
* @return The byte array containing the stream's data.
- * @throws IOException Thrown if the size of the data exceeds the maximal
+ * @throws IOException Thrown if the size of the data exceeds the maximal
*/
public byte[] closeAndGetBytes() throws IOException {
if (!closed) {
@@ -191,11 +191,11 @@ public byte[] closeAndGetBytes() throws IOException {
}
}
}
-
+
// ------------------------------------------------------------------------
// Static default instance
// ------------------------------------------------------------------------
-
+
/** The default instance of this state backend, using the default maximal state size */
private static final MemoryStateBackend DEFAULT_INSTANCE = new MemoryStateBackend();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index c8d50c793f744..ae1c0cda2db1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -219,6 +219,8 @@ public class Task implements Runnable {
* initialization, to be memory friendly */
private volatile SerializedValue> operatorState;
+ private volatile long recoveryTs;
+
/**
* IMPORTANT: This constructor may not start any work that would need to
* be undone in the case of a failing task deployment.
@@ -252,6 +254,7 @@ public Task(TaskDeploymentDescriptor tdd,
this.requiredClasspaths = checkNotNull(tdd.getRequiredClasspaths());
this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName());
this.operatorState = tdd.getOperatorState();
+ this.recoveryTs = tdd.getRecoveryTimestamp();
this.memoryManager = checkNotNull(memManager);
this.ioManager = checkNotNull(ioManager);
@@ -535,13 +538,14 @@ else if (current == ExecutionState.CANCELING) {
// get our private reference onto the stack (be safe against concurrent changes)
SerializedValue> operatorState = this.operatorState;
+ long recoveryTs = this.recoveryTs;
if (operatorState != null) {
if (invokable instanceof StatefulTask) {
try {
StateHandle> state = operatorState.deserializeValue(userCodeClassLoader);
StatefulTask> op = (StatefulTask>) invokable;
- StateUtils.setOperatorState(op, state);
+ StateUtils.setOperatorState(op, state, recoveryTs);
}
catch (Exception e) {
throw new RuntimeException("Failed to deserialize state handle and setup initial operator state.", e);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 32c15bf613af1..7b2c2d4b3e8e6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -45,25 +45,25 @@
* Tests concerning the restoring of state from a checkpoint to the task executions.
*/
public class CheckpointStateRestoreTest {
-
+
private static final ClassLoader cl = Thread.currentThread().getContextClassLoader();
-
+
@Test
public void testSetState() {
try {
final SerializedValue> serializedState = new SerializedValue>(
new LocalStateHandle(new SerializableObject()));
-
+
final JobID jid = new JobID();
final JobVertexID statefulId = new JobVertexID();
final JobVertexID statelessId = new JobVertexID();
-
+
Execution statefulExec1 = mockExecution();
Execution statefulExec2 = mockExecution();
Execution statefulExec3 = mockExecution();
Execution statelessExec1 = mockExecution();
Execution statelessExec2 = mockExecution();
-
+
ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0);
ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1);
ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2);
@@ -74,44 +74,44 @@ public void testSetState() {
new ExecutionVertex[] { stateful1, stateful2, stateful3 });
ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId,
new ExecutionVertex[] { stateless1, stateless2 });
-
+
Map map = new HashMap();
map.put(statefulId, stateful);
map.put(statelessId, stateless);
-
-
+
+
CheckpointCoordinator coord = new CheckpointCoordinator(jid, 200000L,
new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
new ExecutionVertex[0], cl,
new StandaloneCheckpointIDCounter(),
new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
-
+
// create ourselves a checkpoint with state
final long timestamp = 34623786L;
coord.triggerCheckpoint(timestamp);
-
+
PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
final long checkpointId = pending.getCheckpointId();
-
+
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, serializedState));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
-
+
assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
assertEquals(0, coord.getNumberOfPendingCheckpoints());
-
+
// let the coordinator inject the state
coord.restoreLatestCheckpointedState(map, true, false);
-
+
// verify that each stateful vertex got the state
- verify(statefulExec1, times(1)).setInitialState(serializedState);
- verify(statefulExec2, times(1)).setInitialState(serializedState);
- verify(statefulExec3, times(1)).setInitialState(serializedState);
- verify(statelessExec1, times(0)).setInitialState(Mockito.>>any());
- verify(statelessExec2, times(0)).setInitialState(Mockito.>>any());
+ verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
+ verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
+ verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
+ verify(statelessExec1, times(0)).setInitialState(Mockito.>>any(), Mockito.anyLong());
+ verify(statelessExec2, times(0)).setInitialState(Mockito.>>any(), Mockito.anyLong());
}
catch (Exception e) {
e.printStackTrace();
@@ -189,7 +189,7 @@ public void testStateOnlyPartiallyAvailable() {
fail(e.getMessage());
}
}
-
+
@Test
public void testNoCheckpointAvailable() {
try {
@@ -213,20 +213,20 @@ public void testNoCheckpointAvailable() {
fail(e.getMessage());
}
}
-
+
// ------------------------------------------------------------------------
private Execution mockExecution() {
return mockExecution(ExecutionState.RUNNING);
}
-
+
private Execution mockExecution(ExecutionState state) {
Execution mock = mock(Execution.class);
when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
when(mock.getState()).thenReturn(state);
return mock;
}
-
+
private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask) {
ExecutionVertex mock = mock(ExecutionVertex.class);
when(mock.getJobvertexId()).thenReturn(vertexId);
@@ -234,7 +234,7 @@ private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID ver
when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
return mock;
}
-
+
private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
when(vertex.getParallelism()).thenReturn(vertices.length);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
new file mode 100644
index 0000000000000..71bec4a41b8c5
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.operators.testutils;
+
+import java.util.Map;
+import java.util.concurrent.Future;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+
+public class DummyEnvironment implements Environment {
+
+ private final String taskName;
+ private final int numSubTasks;
+ private final int subTaskIndex;
+ private final JobID jobId = new JobID();
+ private final JobVertexID jobVertexId = new JobVertexID();
+
+ public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) {
+ this.taskName = taskName;
+ this.numSubTasks = numSubTasks;
+ this.subTaskIndex = subTaskIndex;
+ }
+
+ @Override
+ public JobID getJobID() {
+ return jobId;
+ }
+
+ @Override
+ public JobVertexID getJobVertexId() {
+ return jobVertexId;
+ }
+
+ @Override
+ public ExecutionAttemptID getExecutionId() {
+ return null;
+ }
+
+ @Override
+ public Configuration getTaskConfiguration() {
+ return null;
+ }
+
+ @Override
+ public TaskManagerRuntimeInfo getTaskManagerInfo() {
+ return null;
+ }
+
+ @Override
+ public Configuration getJobConfiguration() {
+ return null;
+ }
+
+ @Override
+ public int getNumberOfSubtasks() {
+ return numSubTasks;
+ }
+
+ @Override
+ public int getIndexInSubtaskGroup() {
+ return subTaskIndex;
+ }
+
+ @Override
+ public InputSplitProvider getInputSplitProvider() {
+ return null;
+ }
+
+ @Override
+ public IOManager getIOManager() {
+ return null;
+ }
+
+ @Override
+ public MemoryManager getMemoryManager() {
+ return null;
+ }
+
+ @Override
+ public String getTaskName() {
+ return taskName;
+ }
+
+ @Override
+ public String getTaskNameWithSubtasks() {
+ return taskName;
+ }
+
+ @Override
+ public ClassLoader getUserClassLoader() {
+ return null;
+ }
+
+ @Override
+ public Map> getDistributedCacheEntries() {
+ return null;
+ }
+
+ @Override
+ public BroadcastVariableManager getBroadcastVariableManager() {
+ return null;
+ }
+
+ @Override
+ public AccumulatorRegistry getAccumulatorRegistry() {
+ return null;
+ }
+
+ @Override
+ public void acknowledgeCheckpoint(long checkpointId) {
+ }
+
+ @Override
+ public void acknowledgeCheckpoint(long checkpointId, StateHandle> state) {
+ }
+
+ @Override
+ public ResultPartitionWriter getWriter(int index) {
+ return null;
+ }
+
+ @Override
+ public ResultPartitionWriter[] getAllWriters() {
+ return null;
+ }
+
+ @Override
+ public InputGate getInputGate(int index) {
+ return null;
+ }
+
+ @Override
+ public InputGate[] getAllInputGates() {
+ return null;
+ }
+
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index a6cfae39b60e9..37ccde2bb5d19 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -18,8 +18,22 @@
package org.apache.flink.runtime.state;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.util.Random;
+import java.util.UUID;
+
import org.apache.commons.io.FileUtils;
-import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.FloatSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
@@ -29,41 +43,34 @@
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.testutils.CommonTestUtils;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.StringValue;
+import org.apache.flink.util.OperatingSystem;
import org.junit.Test;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.URI;
-import java.util.Random;
-import java.util.UUID;
-
-import static org.junit.Assert.*;
-
public class FileStateBackendTest {
-
+
@Test
public void testSetupAndSerialization() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
final String backendDir = localFileUri(tempDir);
FsStateBackend originalBackend = new FsStateBackend(backendDir);
-
+
assertFalse(originalBackend.isInitialized());
assertEquals(new URI(backendDir), originalBackend.getBasePath().toUri());
assertNull(originalBackend.getCheckpointDirectory());
-
+
// serialize / copy the backend
FsStateBackend backend = CommonTestUtils.createCopySerializable(originalBackend);
assertFalse(backend.isInitialized());
assertEquals(new URI(backendDir), backend.getBasePath().toUri());
assertNull(backend.getCheckpointDirectory());
-
+
// no file operations should be possible right now
try {
backend.checkpointStateSerializable("exception train rolling in", 2L, System.currentTimeMillis());
@@ -71,17 +78,17 @@ public void testSetupAndSerialization() {
} catch (IllegalStateException e) {
// supreme!
}
-
- backend.initializeForJob(new JobID());
+
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
assertNotNull(backend.getCheckpointDirectory());
-
+
File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
assertTrue(checkpointDir.exists());
assertTrue(isDirectoryEmpty(checkpointDir));
-
+
backend.disposeAllStateForCurrentJob();
assertNull(backend.getCheckpointDirectory());
-
+
assertTrue(isDirectoryEmpty(tempDir));
}
catch (Exception e) {
@@ -92,20 +99,20 @@ public void testSetupAndSerialization() {
deleteDirectorySilently(tempDir);
}
}
-
+
@Test
public void testSerializableState() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir)));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
String state1 = "dummy state";
String state2 = "row row row your boat";
Integer state3 = 42;
-
+
StateHandle handle1 = backend.checkpointStateSerializable(state1, 439568923746L, System.currentTimeMillis());
StateHandle handle2 = backend.checkpointStateSerializable(state2, 439568923746L, System.currentTimeMillis());
StateHandle handle3 = backend.checkpointStateSerializable(state3, 439568923746L, System.currentTimeMillis());
@@ -113,15 +120,15 @@ public void testSerializableState() {
assertFalse(isDirectoryEmpty(checkpointDir));
assertEquals(state1, handle1.getState(getClass().getClassLoader()));
handle1.discardState();
-
+
assertFalse(isDirectoryEmpty(checkpointDir));
assertEquals(state2, handle2.getState(getClass().getClassLoader()));
handle2.discardState();
-
+
assertFalse(isDirectoryEmpty(checkpointDir));
assertEquals(state3, handle3.getState(getClass().getClassLoader()));
handle3.discardState();
-
+
assertTrue(isDirectoryEmpty(checkpointDir));
}
catch (Exception e) {
@@ -138,7 +145,7 @@ public void testStateOutputStream() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir)));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
@@ -146,7 +153,7 @@ public void testStateOutputStream() {
byte[] state2 = new byte[1];
byte[] state3 = new byte[0];
byte[] state4 = new byte[177];
-
+
Random rnd = new Random();
rnd.nextBytes(state1);
rnd.nextBytes(state2);
@@ -155,21 +162,21 @@ public void testStateOutputStream() {
long checkpointId = 97231523452L;
- FsStateBackend.FsCheckpointStateOutputStream stream1 =
+ FsStateBackend.FsCheckpointStateOutputStream stream1 =
backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
FsStateBackend.FsCheckpointStateOutputStream stream2 =
backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
FsStateBackend.FsCheckpointStateOutputStream stream3 =
backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
-
+
stream1.write(state1);
stream2.write(state2);
stream3.write(state3);
-
+
FileStreamStateHandle handle1 = stream1.closeAndGetHandle();
FileStreamStateHandle handle2 = stream2.closeAndGetHandle();
FileStreamStateHandle handle3 = stream3.closeAndGetHandle();
-
+
// use with try-with-resources
StreamStateHandle handle4;
try (StateBackend.CheckpointStateOutputStream stream4 =
@@ -177,7 +184,7 @@ public void testStateOutputStream() {
stream4.write(state4);
handle4 = stream4.closeAndGetHandle();
}
-
+
// close before accessing handle
StateBackend.CheckpointStateOutputStream stream5 =
backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
@@ -189,22 +196,22 @@ public void testStateOutputStream() {
} catch (IOException e) {
// uh-huh
}
-
+
validateBytesInStream(handle1.getState(getClass().getClassLoader()), state1);
handle1.discardState();
assertFalse(isDirectoryEmpty(checkpointDir));
ensureLocalFileDeleted(handle1.getFilePath());
-
+
validateBytesInStream(handle2.getState(getClass().getClassLoader()), state2);
handle2.discardState();
assertFalse(isDirectoryEmpty(checkpointDir));
ensureLocalFileDeleted(handle2.getFilePath());
-
+
validateBytesInStream(handle3.getState(getClass().getClassLoader()), state3);
handle3.discardState();
assertFalse(isDirectoryEmpty(checkpointDir));
ensureLocalFileDeleted(handle3.getFilePath());
-
+
validateBytesInStream(handle4.getState(getClass().getClassLoader()), state4);
handle4.discardState();
assertTrue(isDirectoryEmpty(checkpointDir));
@@ -223,12 +230,12 @@ public void testKeyValueState() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir)));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
+ backend.createKvState("0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
assertEquals(0, kv.size());
@@ -272,7 +279,7 @@ public void testKeyValueState() {
// restore the first snapshot and validate it
KvState restored1 = snapshot1.restoreState(backend,
- IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
assertEquals(2, restored1.size());
restored1.setCurrentKey(1);
@@ -282,7 +289,7 @@ public void testKeyValueState() {
// restore the first snapshot and validate it
KvState restored2 = snapshot2.restoreState(backend,
- IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
assertEquals(3, restored2.size());
restored2.setCurrentKey(1);
@@ -312,12 +319,12 @@ public void testRestoreWithWrongSerializers() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir)));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
-
+
KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
+ backend.createKvState("a_0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
kv.setCurrentKey(1);
kv.update("1");
@@ -338,7 +345,7 @@ public void testRestoreWithWrongSerializers() {
try {
snapshot.restoreState(backend, fakeIntSerializer,
- StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
@@ -348,7 +355,7 @@ public void testRestoreWithWrongSerializers() {
try {
snapshot.restoreState(backend, IntSerializer.INSTANCE,
- fakeStringSerializer, null, getClass().getClassLoader());
+ fakeStringSerializer, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
@@ -358,14 +365,14 @@ public void testRestoreWithWrongSerializers() {
try {
snapshot.restoreState(backend, fakeIntSerializer,
- fakeStringSerializer, null, getClass().getClassLoader());
+ fakeStringSerializer, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
} catch (Exception e) {
fail("wrong exception");
}
-
+
snapshot.discardState();
assertTrue(isDirectoryEmpty(checkpointDir));
@@ -384,10 +391,10 @@ public void testCopyDefaultValue() {
File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir)));
- backend.initializeForJob(new JobID());
-
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
+
KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1));
+ backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1));
kv.setCurrentKey(1);
IntValue default1 = kv.value();
@@ -408,11 +415,11 @@ public void testCopyDefaultValue() {
deleteDirectorySilently(tempDir);
}
}
-
+
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
-
+
private static void ensureLocalFileDeleted(Path path) {
URI uri = path.toUri();
if ("file".equals(uri.getScheme())) {
@@ -423,23 +430,23 @@ private static void ensureLocalFileDeleted(Path path) {
throw new IllegalArgumentException("not a local path");
}
}
-
+
private static void deleteDirectorySilently(File dir) {
try {
FileUtils.deleteDirectory(dir);
}
catch (IOException ignored) {}
}
-
+
private static boolean isDirectoryEmpty(File directory) {
String[] nested = directory.list();
return nested == null || nested.length == 0;
}
-
+
private static String localFileUri(File path) {
return path.toURI().toString();
}
-
+
private static void validateBytesInStream(InputStream is, byte[] data) throws IOException {
byte[] holder = new byte[data.length];
assertEquals("not enough data", holder.length, is.read(holder));
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index f6d1bb51881c8..4b5aebd0c74cf 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -40,7 +40,7 @@
* Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}.
*/
public class MemoryStateBackendTest {
-
+
@Test
public void testSerializableState() {
try {
@@ -49,10 +49,10 @@ public void testSerializableState() {
HashMap state = new HashMap<>();
state.put("hey there", 2);
state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77);
-
+
StateHandle> handle = backend.checkpointStateSerializable(state, 12, 459);
assertNotNull(handle);
-
+
HashMap restored = handle.getState(getClass().getClassLoader());
assertEquals(state, restored);
}
@@ -99,7 +99,7 @@ public void testStateStream() {
oos.writeObject(state);
oos.flush();
StreamStateHandle handle = os.closeAndGetHandle();
-
+
assertNotNull(handle);
ObjectInputStream ois = new ObjectInputStream(handle.getState(getClass().getClassLoader()));
@@ -124,7 +124,7 @@ public void testOversizedStateStream() {
StateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2);
ObjectOutputStream oos = new ObjectOutputStream(os);
-
+
try {
oos.writeObject(state);
oos.flush();
@@ -140,17 +140,17 @@ public void testOversizedStateStream() {
fail(e.getMessage());
}
}
-
+
@Test
public void testKeyValueState() {
try {
MemoryStateBackend backend = new MemoryStateBackend();
-
- KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
-
+
+ KvState kv =
+ backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
+
assertEquals(0, kv.size());
-
+
// some modifications to the state
kv.setCurrentKey(1);
assertNull(kv.value());
@@ -163,7 +163,7 @@ public void testKeyValueState() {
kv.setCurrentKey(1);
assertEquals("1", kv.value());
assertEquals(2, kv.size());
-
+
// draw a snapshot
KvStateSnapshot snapshot1 =
kv.snapshot(682375462378L, System.currentTimeMillis());
@@ -179,7 +179,7 @@ public void testKeyValueState() {
// draw another snapshot
KvStateSnapshot snapshot2 =
kv.snapshot(682375462379L, System.currentTimeMillis());
-
+
// validate the original state
assertEquals(3, kv.size());
kv.setCurrentKey(1);
@@ -188,10 +188,10 @@ public void testKeyValueState() {
assertEquals("u2", kv.value());
kv.setCurrentKey(3);
assertEquals("u3", kv.value());
-
+
// restore the first snapshot and validate it
- KvState restored1 = snapshot1.restoreState(backend,
- IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ KvState restored1 = snapshot1.restoreState(backend,
+ IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
assertEquals(2, restored1.size());
restored1.setCurrentKey(1);
@@ -201,7 +201,7 @@ public void testKeyValueState() {
// restore the first snapshot and validate it
KvState restored2 = snapshot2.restoreState(backend,
- IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
assertEquals(3, restored2.size());
restored2.setCurrentKey(1);
@@ -216,34 +216,34 @@ public void testKeyValueState() {
fail(e.getMessage());
}
}
-
+
@Test
public void testRestoreWithWrongSerializers() {
try {
MemoryStateBackend backend = new MemoryStateBackend();
KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
-
+ backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);
+
kv.setCurrentKey(1);
kv.update("1");
kv.setCurrentKey(2);
kv.update("2");
-
+
KvStateSnapshot snapshot =
kv.snapshot(682375462378L, System.currentTimeMillis());
@SuppressWarnings("unchecked")
- TypeSerializer fakeIntSerializer =
+ TypeSerializer fakeIntSerializer =
(TypeSerializer) (TypeSerializer>) FloatSerializer.INSTANCE;
@SuppressWarnings("unchecked")
- TypeSerializer fakeStringSerializer =
+ TypeSerializer fakeStringSerializer =
(TypeSerializer) (TypeSerializer>) new ValueSerializer(StringValue.class);
try {
snapshot.restoreState(backend, fakeIntSerializer,
- StringSerializer.INSTANCE, null, getClass().getClassLoader());
+ StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
@@ -253,7 +253,7 @@ public void testRestoreWithWrongSerializers() {
try {
snapshot.restoreState(backend, IntSerializer.INSTANCE,
- fakeStringSerializer, null, getClass().getClassLoader());
+ fakeStringSerializer, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
@@ -263,7 +263,7 @@ public void testRestoreWithWrongSerializers() {
try {
snapshot.restoreState(backend, fakeIntSerializer,
- fakeStringSerializer, null, getClass().getClassLoader());
+ fakeStringSerializer, null, getClass().getClassLoader(), 1);
fail("should recognize wrong serializers");
} catch (IllegalArgumentException e) {
// expected
@@ -276,20 +276,20 @@ public void testRestoreWithWrongSerializers() {
fail(e.getMessage());
}
}
-
+
@Test
public void testCopyDefaultValue() {
try {
MemoryStateBackend backend = new MemoryStateBackend();
KvState kv =
- backend.createKvState(IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1));
+ backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1));
kv.setCurrentKey(1);
IntValue default1 = kv.value();
kv.setCurrentKey(2);
IntValue default2 = kv.value();
-
+
assertNotNull(default1);
assertNotNull(default2);
assertEquals(default1, default2);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index fff6e7019ffdf..85f8be58651b4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -198,7 +198,7 @@ public void invoke() throws Exception {
}
@Override
- public void setInitialState(StateHandle stateHandle) throws Exception {
+ public void setInitialState(StateHandle stateHandle, long ts) throws Exception {
}
diff --git a/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java b/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
index 4e4acd2b6fbe8..4fb68203f5433 100644
--- a/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
+++ b/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
@@ -20,7 +20,6 @@
import org.apache.commons.io.FileUtils;
-import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.core.fs.FileStatus;
import org.apache.flink.core.fs.FileSystem;
@@ -29,7 +28,7 @@
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
-
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.hadoop.conf.Configuration;
@@ -63,7 +62,7 @@ public class FileStateBackendTest {
private static MiniDFSCluster HDFS_CLUSTER;
private static FileSystem FS;
-
+
// ------------------------------------------------------------------------
// startup / shutdown
// ------------------------------------------------------------------------
@@ -127,7 +126,7 @@ public void testSetupAndSerialization() {
// supreme!
}
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
assertNotNull(backend.getCheckpointDirectory());
Path checkpointDir = backend.getCheckpointDirectory();
@@ -150,7 +149,7 @@ public void testSerializableState() {
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(randomHdfsFileUri()));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
Path checkpointDir = backend.getCheckpointDirectory();
@@ -186,7 +185,7 @@ public void testSerializableState() {
public void testStateOutputStream() {
try {
FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(randomHdfsFileUri()));
- backend.initializeForJob(new JobID());
+ backend.initializeForJob(new DummyEnvironment("test", 0, 0));
Path checkpointDir = backend.getCheckpointDirectory();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index ce4d174763216..3f1cfae68485e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -23,6 +23,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.runtime.state.KvState;
import org.apache.flink.runtime.state.KvStateSnapshot;
@@ -93,6 +94,8 @@ public abstract class AbstractStreamOperator
private transient TypeSerializer> keySerializer;
private transient HashMap> keyValueStateSnapshots;
+
+ private long recoveryTimestamp;
// ------------------------------------------------------------------------
// Life Cycle
@@ -172,15 +175,23 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
}
@Override
- public void restoreState(StreamTaskState state) throws Exception {
+ public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception {
// restore the key/value state. the actual restore happens lazily, when the function requests
// the state again, because the restore method needs information provided by the user function
keyValueStateSnapshots = state.getKvStates();
+ this.recoveryTimestamp = recoveryTimestamp;
}
@Override
public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
- // by default, nothing needs a notification of checkpoint completion
+ // We check whether the KvStates require notifications
+ if (keyValueStates != null) {
+ for (KvState, ?, ?> kvstate : keyValueStates) {
+ if (kvstate instanceof CheckpointNotifier) {
+ ((CheckpointNotifier) kvstate).notifyCheckpointComplete(checkpointId);
+ }
+ }
+ }
}
// ------------------------------------------------------------------------
@@ -269,7 +280,7 @@ protected OperatorState createKeyValueState(
* @throws IllegalStateException Thrown, if the key/value state was already initialized.
* @throws Exception Thrown, if the state backend cannot create the key/value state.
*/
- @SuppressWarnings({"rawtypes", "unchecked"})
+ @SuppressWarnings("unchecked")
protected > OperatorState createKeyValueState(
String name, TypeSerializer valueSerializer, V defaultValue) throws Exception
{
@@ -304,25 +315,25 @@ else if (this.keySerializer != null) {
throw new RuntimeException();
}
- @SuppressWarnings("unchecked")
Backend stateBackend = (Backend) container.getStateBackend();
KvState kvstate = null;
// check whether we restore the key/value state from a snapshot, or create a new blank one
if (keyValueStateSnapshots != null) {
- @SuppressWarnings("unchecked")
KvStateSnapshot snapshot = (KvStateSnapshot) keyValueStateSnapshots.remove(name);
if (snapshot != null) {
kvstate = snapshot.restoreState(
- stateBackend, keySerializer, valueSerializer, defaultValue, getUserCodeClassloader());
+ stateBackend, keySerializer, valueSerializer, defaultValue, getUserCodeClassloader(), recoveryTimestamp);
}
}
if (kvstate == null) {
+ // create unique state id from operator id + state name
+ String stateId = name + "_" + getOperatorConfig().getVertexID();
// create a new blank key/value state
- kvstate = stateBackend.createKvState(keySerializer, valueSerializer, defaultValue);
+ kvstate = stateBackend.createKvState(stateId ,name , keySerializer, valueSerializer, defaultValue);
}
if (keyValueStatesByName == null) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 32be2ba49e9bb..c20544565ad13 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -147,8 +147,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
}
@Override
- public void restoreState(StreamTaskState state) throws Exception {
- super.restoreState(state);
+ public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception {
+ super.restoreState(state, recoveryTimestamp);
StateHandle stateHandle = state.getFunctionState();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index fac26f153b9d9..1ef3298316f5f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -112,11 +112,13 @@ public interface StreamOperator extends Serializable {
*
* @param state The state of operator that was snapshotted as part of checkpoint
* from which the execution is restored.
+ *
+ * @param recoveryTimestamp Global recovery timestamp
*
* @throws Exception Exceptions during state restore should be forwarded, so that the system can
* properly react to failed state restore and fail the execution attempt.
*/
- void restoreState(StreamTaskState state) throws Exception;
+ void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception;
/**
* Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java
index 90d3d82bf6b32..677a7dd2db4d5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java
@@ -264,8 +264,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
}
@Override
- public void restoreState(StreamTaskState taskState) throws Exception {
- super.restoreState(taskState);
+ public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception {
+ super.restoreState(taskState, recoveryTimestamp);
@SuppressWarnings("unchecked")
StateHandle inputState = (StateHandle) taskState.getOperatorState();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java
index 5e4dea7e52a83..782363139e8b3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java
@@ -536,8 +536,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
}
@Override
- public void restoreState(StreamTaskState taskState) throws Exception {
- super.restoreState(taskState);
+ public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception {
+ super.restoreState(taskState, recoveryTimestamp);
final ClassLoader userClassloader = getUserCodeClassloader();
@SuppressWarnings("unchecked")
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
index f19e76057444a..68c3a5f26bf13 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
@@ -609,8 +609,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
}
@Override
- public void restoreState(StreamTaskState taskState) throws Exception {
- super.restoreState(taskState);
+ public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception {
+ super.restoreState(taskState, recoveryTimestamp);
final ClassLoader userClassloader = getUserCodeClassloader();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 80c63dab90407..c310439789284 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -32,6 +32,7 @@
import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
import org.apache.flink.runtime.util.event.EventListener;
+import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -138,6 +139,8 @@ public abstract class StreamTask>
/** Flag to mark the task "in operation", in which case check
* needs to be initialized to true, so that early cancel() before invoke() behaves correctly */
private volatile boolean isRunning;
+
+ private long recoveryTimestamp;
// ------------------------------------------------------------------------
@@ -169,7 +172,7 @@ public final void registerInputOutput() throws Exception {
accumulatorMap = accumulatorRegistry.getUserMap();
stateBackend = createStateBackend();
- stateBackend.initializeForJob(getEnvironment().getJobID());
+ stateBackend.initializeForJob(getEnvironment());
headOperator = configuration.getStreamOperator(userClassLoader);
operatorChain = new OperatorChain<>(this, headOperator, accumulatorRegistry.getReadWriteReporter());
@@ -382,8 +385,9 @@ public RecordWriterOutput>[] getStreamOutputs() {
// ------------------------------------------------------------------------
@Override
- public void setInitialState(StreamTaskStateList initialState) {
+ public void setInitialState(StreamTaskStateList initialState, long recoveryTimestamp) {
lazyRestoreState = initialState;
+ this.recoveryTimestamp = recoveryTimestamp;
}
public void restoreStateLazy() throws Exception {
@@ -403,7 +407,7 @@ public void restoreStateLazy() throws Exception {
if (state != null && operator != null) {
LOG.debug("Task {} in chain ({}) has checkpointed state", i, getName());
- operator.restoreState(state);
+ operator.restoreState(state, recoveryTimestamp);
}
else if (operator != null) {
LOG.debug("Task {} in chain ({}) does not have checkpointed state", i, getName());
@@ -464,6 +468,11 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception {
if (isRunning) {
LOG.debug("Notification of complete checkpoint for task {}", getName());
+ // We first notify the state backend if necessary
+ if (stateBackend instanceof CheckpointNotifier) {
+ ((CheckpointNotifier) stateBackend).notifyCheckpointComplete(checkpointId);
+ }
+
for (StreamOperator> operator : operatorChain.getAllOperators()) {
if (operator != null) {
operator.notifyOfCompletedCheckpoint(checkpointId);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 62eb268dac38b..63cbd6a281275 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -595,7 +595,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
windowSize, windowSize);
op.setup(mockTask, new StreamConfig(new Configuration()), out2);
- op.restoreState(state);
+ op.restoreState(state, 1);
op.open();
// inject some more elements
@@ -694,7 +694,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
windowSize, windowSlide);
op.setup(mockTask, new StreamConfig(new Configuration()), out2);
- op.restoreState(state);
+ op.restoreState(state, 1);
op.open();
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 4d507fb0919dc..55cd9fe1144ba 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -653,7 +653,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
windowSize, windowSize);
op.setup(mockTask, new StreamConfig(new Configuration()), out2);
- op.restoreState(state);
+ op.restoreState(state, 1);
op.open();
// inject the remaining elements
@@ -759,7 +759,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
windowSize, windowSlide);
op.setup(mockTask, new StreamConfig(new Configuration()), out2);
- op.restoreState(state);
+ op.restoreState(state, 1);
op.open();
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
index 67c018912f5af..42b62303d8b42 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
@@ -50,10 +50,11 @@
public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTestBase {
final long NUM_STRINGS = 10_000_000L;
+ final static int NUM_KEYS = 40;
@Override
public void testProgram(StreamExecutionEnvironment env) {
- assertTrue("Broken test setup", (NUM_STRINGS/2) % 40 == 0);
+ assertTrue("Broken test setup", (NUM_STRINGS/2) % NUM_KEYS == 0);
DataStream stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
DataStream stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
@@ -69,14 +70,14 @@ public void testProgram(StreamExecutionEnvironment env) {
public void postSubmit() {
// verify that we counted exactly right
for (Entry sum : OnceFailingPartitionedSum.allSums.entrySet()) {
- assertEquals(new Long(sum.getKey() * NUM_STRINGS / 40), sum.getValue());
+ assertEquals(new Long(sum.getKey() * NUM_STRINGS / NUM_KEYS), sum.getValue());
}
for (Long count : CounterSink.allCounts.values()) {
- assertEquals(new Long(NUM_STRINGS / 40), count);
+ assertEquals(new Long(NUM_STRINGS / NUM_KEYS), count);
}
- assertEquals(40, CounterSink.allCounts.size());
- assertEquals(40, OnceFailingPartitionedSum.allSums.size());
+ assertEquals(NUM_KEYS, CounterSink.allCounts.size());
+ assertEquals(NUM_KEYS, OnceFailingPartitionedSum.allSums.size());
}
// --------------------------------------------------------------------------------------------
@@ -120,7 +121,7 @@ public void run(SourceContext ctx) throws Exception {
synchronized (lockingObject) {
index += step;
- ctx.collect(index % 40);
+ ctx.collect(index % NUM_KEYS);
}
}
}
@@ -160,9 +161,9 @@ private static class OnceFailingPartitionedSum extends RichMapFunction value) throws Exception {
}
}
- private static class NonSerializableLong {
+ public static class NonSerializableLong {
public Long value;
private NonSerializableLong(long value) {
@@ -225,7 +226,7 @@ public static NonSerializableLong of(long value) {
}
}
- private static class IdentityKeySelector implements KeySelector