diff --git a/flink-contrib/flink-streaming-contrib/pom.xml b/flink-contrib/flink-streaming-contrib/pom.xml index 68e65f6a9c253..22b11b2a45500 100644 --- a/flink-contrib/flink-streaming-contrib/pom.xml +++ b/flink-contrib/flink-streaming-contrib/pom.xml @@ -53,6 +53,30 @@ under the License. ${project.version} test + + org.apache.flink + flink-tests + ${project.version} + test-jar + test + + + com.google.guava + guava + ${guava.version} + + + org.apache.derby + derbyclient + 10.12.1.1 + test + + + org.apache.derby + derbynet + 10.12.1.1 + test + diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbAdapter.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbAdapter.java new file mode 100644 index 0000000000000..26c27ddd0b1ca --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbAdapter.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.IOException; +import java.io.Serializable; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +import org.apache.flink.api.java.tuple.Tuple2; + +/** + * Adapter interface for executing different checkpointing related operations on + * the underlying database. + * + */ +public interface DbAdapter extends Serializable { + + /** + * Initialize tables for storing non-partitioned checkpoints for the given + * job id and database connection. + * + */ + void createCheckpointsTable(String jobId, Connection con) throws SQLException; + + /** + * Checkpoints will be inserted in the database using prepared statements. + * This methods should prepare and return the statement that will be used + * later to insert using the given connection. + * + */ + PreparedStatement prepareCheckpointInsert(String jobId, Connection con) throws SQLException; + + /** + * Set the {@link PreparedStatement} parameters for the statement returned + * by {@link #prepareCheckpointInsert(String, Connection)}. + * + * @param jobId + * Id of the current job. + * @param insertStatement + * Statement returned by + * {@link #prepareCheckpointInsert(String, Connection)}. + * @param checkpointId + * Global checkpoint id. + * @param timestamp + * Global checkpoint timestamp. + * @param handleId + * Unique id assigned to this state checkpoint (should be primary + * key). + * @param checkpoint + * The serialized checkpoint. 
+ * @throws SQLException + */ + void setCheckpointInsertParams(String jobId, PreparedStatement insertStatement, long checkpointId, + long timestamp, long handleId, byte[] checkpoint) throws SQLException; + + /** + * Retrieve the serialized checkpoint data from the database. + * + * @param jobId + * Id of the current job. + * @param con + * Database connection + * @param checkpointId + * Global checkpoint id. + * @param checkpointTs + * Global checkpoint timestamp. + * @param handleId + * Unique id assigned to this state checkpoint (should be primary + * key). + * @return The byte[] corresponding to the checkpoint or null if missing. + * @throws SQLException + */ + byte[] getCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId) + throws SQLException; + + /** + * Remove the given checkpoint from the database. + * + * @param jobId + * Id of the current job. + * @param con + * Database connection + * @param checkpointId + * Global checkpoint id. + * @param checkpointTs + * Global checkpoint timestamp. + * @param handleId + * Unique id assigned to this state checkpoint (should be primary + * key). + * @return The byte[] corresponding to the checkpoint or null if missing. + * @throws SQLException + */ + void deleteCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId) + throws SQLException; + + /** + * Remove all states for the given JobId, by for instance dropping the + * entire table. + * + * @throws SQLException + */ + void disposeAllStateForJob(String jobId, Connection con) throws SQLException; + + /** + * Initialize the necessary tables for the given stateId. The state id + * consist of the JobId+OperatorId+StateName. + * + */ + void createKVStateTable(String stateId, Connection con) throws SQLException; + + /** + * Prepare the the statement that will be used to insert key-value pairs in + * the database. + * + */ + String prepareKVCheckpointInsert(String stateId) throws SQLException; + + /** + * Prepare the statement that will be used to lookup keys from the database. + * Keys and values are assumed to be byte arrays. + * + */ + String prepareKeyLookup(String stateId) throws SQLException; + + /** + * Retrieve the latest value from the database for a given key and + * timestamp. + * + * @param stateId + * Unique identifier of the kvstate (usually the table name). + * @param lookupStatement + * The statement returned by + * {@link #prepareKeyLookup(String, Connection)}. + * @param key + * The key to lookup. + * @return The latest valid value for the key. + * @throws SQLException + */ + byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupId) + throws SQLException; + + /** + * Clean up states between the checkpoint and recovery timestamp. + * + */ + void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTimestamp, + long recoveryTimestamp) throws SQLException; + + /** + * Insert a list of Key-Value pairs into the database. The suggested + * approach is to use idempotent inserts(updates) as 1 batch operation. + * + */ + void insertBatch(String stateId, DbBackendConfig conf, Connection con, PreparedStatement insertStatement, + long checkpointTimestamp, List> toInsert) throws IOException; + + /** + * Compact the states between two checkpoint timestamp by only keeping the + * most recent. 
+ */ + void compactKvStates(String kvStateId, Connection con, long lowerTs, long upperTs) throws SQLException; + + /** + * Execute a simple operation to refresh the current database connection in + * case no data is written for a longer time period. Usually something like + * "select 1" + */ + void keepAlive(Connection con) throws SQLException; + +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbBackendConfig.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbBackendConfig.java new file mode 100644 index 0000000000000..883b65ab8ff12 --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbBackendConfig.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.Serializable; +import java.sql.SQLException; +import java.util.List; + +import org.apache.flink.contrib.streaming.state.ShardedConnection.Partitioner; + +import com.google.common.collect.Lists; + +/** + * + * Configuration object for {@link DbStateBackend}, containing information to + * shard and connect to the databases that will store the state checkpoints. + * + */ +public class DbBackendConfig implements Serializable { + + private static final long serialVersionUID = 1L; + + // Database connection properties + private final String userName; + private final String userPassword; + private final List shardUrls; + + // JDBC Driver + DbAdapter information + private DbAdapter dbAdapter = new MySqlAdapter(); + private String JDBCDriver = null; + + private int maxNumberOfSqlRetries = 5; + private int sleepBetweenSqlRetries = 100; + + // KvState properties + private int kvStateCacheSize = 10000; + private int maxKvInsertBatchSize = 1000; + private float maxKvEvictFraction = 0.1f; + private int kvStateCompactionFreq = -1; + + private Partitioner shardPartitioner; + + /** + * Creates a new sharded database state backend configuration with the given + * parameters and default {@link MySqlAdapter}. + * + * @param dbUserName + * The username used to connect to the database at the given url. + * @param dbUserPassword + * The password used to connect to the database at the given url + * and username. + * @param dbShardUrls + * The list of JDBC urls of the databases that will be used as + * shards for the state backend. Sharding of the state will + * happen based on the subtask index of the given task. 
+ */ + public DbBackendConfig(String dbUserName, String dbUserPassword, List dbShardUrls) { + this.userName = dbUserName; + this.userPassword = dbUserPassword; + this.shardUrls = dbShardUrls; + } + + /** + * Creates a new database state backend configuration with the given + * parameters and default {@link MySqlAdapter}. + * + * @param dbUserName + * The username used to connect to the database at the given url. + * @param dbUserPassword + * The password used to connect to the database at the given url + * and username. + * @param dbUrl + * The JDBC url of the database for example + * "jdbc:mysql://localhost:3306/flinkdb". + */ + public DbBackendConfig(String dbUserName, String dbUserPassword, String dbUrl) { + this(dbUserName, dbUserPassword, Lists.newArrayList(dbUrl)); + } + + /** + * The username used to connect to the database at the given urls. + */ + public String getUserName() { + return userName; + } + + /** + * The password used to connect to the database at the given url and + * username. + */ + public String getUserPassword() { + return userPassword; + } + + /** + * Number of database shards defined. + */ + public int getNumberOfShards() { + return shardUrls.size(); + } + + /** + * Database shard urls as provided in the constructor. + * + */ + public List getShardUrls() { + return shardUrls; + } + + /** + * The url of the first shard. + * + */ + public String getUrl() { + return getShardUrl(0); + } + + /** + * The url of a specific shard. + * + */ + public String getShardUrl(int shardIndex) { + validateShardIndex(shardIndex); + return shardUrls.get(shardIndex); + } + + private void validateShardIndex(int i) { + if (i < 0) { + throw new IllegalArgumentException("Index must be positive."); + } else if (getNumberOfShards() <= i) { + throw new IllegalArgumentException("Index must be less then the total number of shards."); + } + } + + /** + * Get the {@link DbAdapter} that will be used to operate on the database + * during checkpointing. + * + */ + public DbAdapter getDbAdapter() { + return dbAdapter; + } + + /** + * Set the {@link DbAdapter} that will be used to operate on the database + * during checkpointing. + * + */ + public void setDbAdapter(DbAdapter adapter) { + this.dbAdapter = adapter; + } + + /** + * The class name that should be used to load the JDBC driver using + * Class.forName(JDBCDriverClass). + */ + public String getJDBCDriver() { + return JDBCDriver; + } + + /** + * Set the class name that should be used to load the JDBC driver using + * Class.forName(JDBCDriverClass). + */ + public void setJDBCDriver(String jDBCDriverClassName) { + JDBCDriver = jDBCDriverClassName; + } + + /** + * The maximum number of key-value pairs stored in one task instance's cache + * before evicting to the underlying database. + * + */ + public int getKvCacheSize() { + return kvStateCacheSize; + } + + /** + * Set the maximum number of key-value pairs stored in one task instance's + * cache before evicting to the underlying database. When the cache is full + * the N least recently used keys will be evicted to the database, where N = + * maxKvEvictFraction*KvCacheSize. + * + */ + public void setKvCacheSize(int size) { + kvStateCacheSize = size; + } + + /** + * The maximum number of key-value pairs inserted in the database as one + * batch operation. + */ + public int getMaxKvInsertBatchSize() { + return maxKvInsertBatchSize; + } + + /** + * Set the maximum number of key-value pairs inserted in the database as one + * batch operation. 
+ */ + public void setMaxKvInsertBatchSize(int size) { + maxKvInsertBatchSize = size; + } + + /** + * Sets the maximum fraction of key-value states evicted from the cache if + * the cache is full. + */ + public void setMaxKvCacheEvictFraction(float fraction) { + if (fraction > 1 || fraction <= 0) { + throw new RuntimeException("Must be a number between 0 and 1"); + } else { + maxKvEvictFraction = fraction; + } + } + + /** + * The maximum fraction of key-value states evicted from the cache if the + * cache is full. + */ + public float getMaxKvCacheEvictFraction() { + return maxKvEvictFraction; + } + + /** + * The number of elements that will be evicted when the cache is full. + * + */ + public int getNumElementsToEvict() { + return (int) Math.ceil(getKvCacheSize() * getMaxKvCacheEvictFraction()); + } + + /** + * Sets how often will automatic compaction be performed on the database to + * remove old overwritten state changes. The frequency is set in terms of + * number of successful checkpoints between two compactions and should take + * the state size and checkpoint frequency into account. + *
+ * <p>
+ * By default automatic compaction is turned off. + */ + public void setKvStateCompactionFrequency(int compactEvery) { + this.kvStateCompactionFreq = compactEvery; + } + + /** + * Sets how often will automatic compaction be performed on the database to + * remove old overwritten state changes. The frequency is set in terms of + * number of successful checkpoints between two compactions and should take + * the state size and checkpoint frequency into account. + *
+ * <p>
+ * By default automatic compaction is turned off. + */ + public int getKvStateCompactionFrequency() { + return kvStateCompactionFreq; + } + + /** + * The number of times each SQL command will be retried on failure. + */ + public int getMaxNumberOfSqlRetries() { + return maxNumberOfSqlRetries; + } + + /** + * Sets the number of times each SQL command will be retried on failure. + */ + public void setMaxNumberOfSqlRetries(int maxNumberOfSqlRetries) { + this.maxNumberOfSqlRetries = maxNumberOfSqlRetries; + } + + /** + * The number of milliseconds slept between two SQL retries. The actual + * sleep time will be chosen randomly between 1 and the given time. + * + */ + public int getSleepBetweenSqlRetries() { + return sleepBetweenSqlRetries; + } + + /** + * Sets the number of milliseconds slept between two SQL retries. The actual + * sleep time will be chosen randomly between 1 and the given time. + * + */ + public void setSleepBetweenSqlRetries(int sleepBetweenSqlRetries) { + this.sleepBetweenSqlRetries = sleepBetweenSqlRetries; + } + + /** + * Sets the partitioner used to assign keys to different database shards + */ + public void setPartitioner(Partitioner partitioner) { + this.shardPartitioner = partitioner; + } + + /** + * Creates a new {@link ShardedConnection} using the set parameters. + * + * @throws SQLException + */ + public ShardedConnection createShardedConnection() throws SQLException { + if (JDBCDriver != null) { + try { + Class.forName(JDBCDriver); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not load JDBC driver class", e); + } + } + if (shardPartitioner == null) { + return new ShardedConnection(shardUrls, userName, userPassword); + } else { + return new ShardedConnection(shardUrls, userName, userPassword, shardPartitioner); + } + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java new file mode 100644 index 0000000000000..72482aedeb0c0 --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateBackend.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
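[Note: minimal usage sketch of the configuration class above, not part of the patch. The MySQL URLs and credentials are made up, and registering the backend through StreamExecutionEnvironment#setStateBackend is an assumption for the example.]

import java.util.Arrays;

import org.apache.flink.contrib.streaming.state.DbBackendConfig;
import org.apache.flink.contrib.streaming.state.DbStateBackend;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class DbBackendUsageSketch {
	public static void main(String[] args) throws Exception {
		// Two shards: key-value state is partitioned between the two databases
		DbBackendConfig conf = new DbBackendConfig("flink", "secret", Arrays.asList(
				"jdbc:mysql://host1:3306/flinkdb",
				"jdbc:mysql://host2:3306/flinkdb"));

		// Optional tuning through the setters defined above
		conf.setKvCacheSize(50000);
		conf.setMaxKvInsertBatchSize(5000);
		conf.setKvStateCompactionFrequency(10);

		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		env.setStateBackend(new DbStateBackend(conf));
	}
}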
+ */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.IOException; +import java.io.Serializable; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.Random; +import java.util.concurrent.Callable; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.util.InstantiationUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry; + +/** + * {@link StateBackend} for storing checkpoints in JDBC supporting databases. + * Key-Value state is stored out-of-core and is lazily fetched using the + * {@link LazyDbKvState} implementation. A different backend can also be + * provided in the constructor to store the non-partitioned states. A common use + * case would be to store the key-value states in the database and store larger + * non-partitioned states on a distributed file system. + *
+ * <p>
+ * This backend implementation also allows the sharding of the checkpointed + * states among multiple database instances, which can be enabled by passing + * multiple database urls to the {@link DbBackendConfig} instance. + *
+ * <p>
+ * By default there are multiple tables created in the given databases: 1 table + * for non-partitioned checkpoints and 1 table for each key-value state in the + * streaming program. + *
+ * <p>
+ * To control table creation, insert/lookup operations and to provide + * compatibility for different SQL implementations, a custom + * {@link MySqlAdapter} can be supplied in the {@link DbBackendConfig}. + * + */ +public class DbStateBackend extends StateBackend { + + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(DbStateBackend.class); + + private Random rnd; + + // ------------------------------------------------------ + + private Environment env; + + // ------------------------------------------------------ + + private final DbBackendConfig dbConfig; + private final DbAdapter dbAdapter; + + private ShardedConnection connections; + + private final int numSqlRetries; + private final int sqlRetrySleep; + + private PreparedStatement insertStatement; + + // ------------------------------------------------------ + + // We allow to use a different backend for storing non-partitioned states + private StateBackend nonPartitionedStateBackend = null; + + // ------------------------------------------------------ + + /** + * Create a new {@link DbStateBackend} using the provided + * {@link DbBackendConfig} configuration. + * + */ + public DbStateBackend(DbBackendConfig backendConfig) { + this.dbConfig = backendConfig; + dbAdapter = backendConfig.getDbAdapter(); + numSqlRetries = backendConfig.getMaxNumberOfSqlRetries(); + sqlRetrySleep = backendConfig.getSleepBetweenSqlRetries(); + } + + /** + * Create a new {@link DbStateBackend} using the provided + * {@link DbBackendConfig} configuration and a different backend for storing + * non-partitioned state snapshots. + * + */ + public DbStateBackend(DbBackendConfig backendConfig, StateBackend backend) { + this(backendConfig); + this.nonPartitionedStateBackend = backend; + } + + /** + * Get the database connections maintained by the backend. + */ + public ShardedConnection getConnections() { + return connections; + } + + /** + * Check whether the backend has been initialized. + * + */ + public boolean isInitialized() { + return connections != null; + } + + public Environment getEnvironment() { + return env; + } + + /** + * Get the backend configuration object. + */ + public DbBackendConfig getConfiguration() { + return dbConfig; + } + + @Override + public StateHandle checkpointStateSerializable(final S state, final long checkpointID, + final long timestamp) throws Exception { + + // If we set a different backend for non-partitioned checkpoints we use + // that otherwise write to the database. 
+ if (nonPartitionedStateBackend == null) { + return retry(new Callable>() { + public DbStateHandle call() throws Exception { + // We create a unique long id for each handle, but we also + // store the checkpoint id and timestamp for bookkeeping + long handleId = rnd.nextLong(); + String jobIdShort = env.getJobID().toShortString(); + + dbAdapter.setCheckpointInsertParams(jobIdShort, insertStatement, + checkpointID, timestamp, handleId, + InstantiationUtil.serializeObject(state)); + + insertStatement.executeUpdate(); + + return new DbStateHandle(jobIdShort, checkpointID, timestamp, handleId, + dbConfig); + } + }, numSqlRetries, sqlRetrySleep); + } else { + return nonPartitionedStateBackend.checkpointStateSerializable(state, checkpointID, timestamp); + } + } + + @Override + public CheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) + throws Exception { + if (nonPartitionedStateBackend == null) { + // We don't implement this functionality for the DbStateBackend as + // we cannot directly write a stream to the database anyways. + throw new UnsupportedOperationException("Use ceckpointStateSerializable instead."); + } else { + return nonPartitionedStateBackend.createCheckpointStateOutputStream(checkpointID, timestamp); + } + } + + @Override + public LazyDbKvState createKvState(String stateId, String stateName, + TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws IOException { + return new LazyDbKvState( + stateId + "_" + env.getJobID().toShortString(), + env.getIndexInSubtaskGroup() == 0, + getConnections(), + getConfiguration(), + keySerializer, + valueSerializer, + defaultValue); + } + + @Override + public void initializeForJob(final Environment env) throws Exception { + this.rnd = new Random(); + this.env = env; + + connections = dbConfig.createShardedConnection(); + + // We want the most light-weight transaction isolation level as we don't + // have conflicting reads/writes. We just want to be able to roll back + // batch inserts for k-v snapshots. This requirement might be removed in + // the future. + connections.setTransactionIsolation(Connection.TRANSACTION_READ_UNCOMMITTED); + + // If we have a different backend for non-partitioned states we + // initialize that, otherwise create tables for storing the checkpoints. 
+ // + // Currently all non-partitioned states are written to the first + // database shard + if (nonPartitionedStateBackend == null) { + insertStatement = retry(new Callable() { + public PreparedStatement call() throws SQLException { + dbAdapter.createCheckpointsTable(env.getJobID().toShortString(), getConnections().getFirst()); + return dbAdapter.prepareCheckpointInsert(env.getJobID().toShortString(), + getConnections().getFirst()); + } + }, numSqlRetries, sqlRetrySleep); + } else { + nonPartitionedStateBackend.initializeForJob(env); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("Database state backend successfully initialized"); + } + } + + @Override + public void close() throws Exception { + // We first close the statement/non-partitioned backend, then we close + // the database connection + try (ShardedConnection c = connections) { + if (nonPartitionedStateBackend == null) { + insertStatement.close(); + } else { + nonPartitionedStateBackend.close(); + } + } + } + + @Override + public void disposeAllStateForCurrentJob() throws Exception { + if (nonPartitionedStateBackend == null) { + dbAdapter.disposeAllStateForJob(env.getJobID().toShortString(), connections.getFirst()); + } else { + nonPartitionedStateBackend.disposeAllStateForCurrentJob(); + } + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java new file mode 100644 index 0000000000000..2ecfcc4b7840d --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/DbStateHandle.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry; + +import java.io.IOException; +import java.io.Serializable; +import java.util.concurrent.Callable; + +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.util.InstantiationUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * State handle implementation for storing checkpoints as byte arrays in + * databases using the {@link MySqlAdapter} defined in the {@link DbBackendConfig}. 
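[Note: the DbStateBackend above can delegate non-partitioned state to a different backend via its two-argument constructor. A sketch of that use case follows; the file-system backend class and the HDFS path are assumptions for illustration, not part of the patch.]

import org.apache.flink.contrib.streaming.state.DbBackendConfig;
import org.apache.flink.contrib.streaming.state.DbStateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;

public class MixedBackendSketch {
	public static DbStateBackend create() throws Exception {
		DbBackendConfig conf = new DbBackendConfig("flink", "secret",
				"jdbc:mysql://localhost:3306/flinkdb");
		// Key-value states go to the database, all other checkpointed state
		// is handled by the backend passed as the second argument
		return new DbStateBackend(conf, new FsStateBackend("hdfs:///flink/checkpoints"));
	}
}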
+ * + */ +public class DbStateHandle implements Serializable, StateHandle { + + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(DbStateHandle.class); + + private final String jobId; + private final DbBackendConfig dbConfig; + + private final long checkpointId; + private final long checkpointTs; + + private final long handleId; + + public DbStateHandle(String jobId, long checkpointId, long checkpointTs, long handleId, DbBackendConfig dbConfig) { + this.checkpointId = checkpointId; + this.handleId = handleId; + this.jobId = jobId; + this.dbConfig = dbConfig; + this.checkpointTs = checkpointTs; + } + + protected byte[] getBytes() throws IOException { + return retry(new Callable() { + public byte[] call() throws Exception { + try (ShardedConnection con = dbConfig.createShardedConnection()) { + return dbConfig.getDbAdapter().getCheckpoint(jobId, con.getFirst(), checkpointId, checkpointTs, handleId); + } + } + }, dbConfig.getMaxNumberOfSqlRetries(), dbConfig.getSleepBetweenSqlRetries()); + } + + @Override + public void discardState() { + try { + retry(new Callable() { + public Boolean call() throws Exception { + try (ShardedConnection con = dbConfig.createShardedConnection()) { + dbConfig.getDbAdapter().deleteCheckpoint(jobId, con.getFirst(), checkpointId, checkpointTs, handleId); + } + return true; + } + }, dbConfig.getMaxNumberOfSqlRetries(), dbConfig.getSleepBetweenSqlRetries()); + } catch (IOException e) { + // We don't want to fail the job here, but log the error. + if (LOG.isDebugEnabled()) { + LOG.debug("Could not discard state."); + } + } + } + + @Override + public S getState(ClassLoader userCodeClassLoader) throws IOException, ClassNotFoundException { + return InstantiationUtil.deserializeObject(getBytes(), userCodeClassLoader); + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java new file mode 100644 index 0000000000000..3d7abff1c9ffd --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/LazyDbKvState.java @@ -0,0 +1,624 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.contrib.streaming.state; + +import static org.apache.flink.contrib.streaming.state.SQLRetrier.retry; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.contrib.streaming.state.ShardedConnection.ShardedStatement; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; +import org.apache.flink.util.InstantiationUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Optional; + +/** + * + * Lazily fetched {@link KvState} using a SQL backend. Key-value pairs are + * cached on heap and are lazily retrieved on access. + * + */ +public class LazyDbKvState implements KvState, CheckpointNotifier { + + private static final Logger LOG = LoggerFactory.getLogger(LazyDbKvState.class); + + // ------------------------------------------------------ + + // Unique id for this state (jobID_operatorID_stateName) + private final String kvStateId; + private final boolean compact; + + private K currentKey; + private final V defaultValue; + + private final TypeSerializer keySerializer; + private final TypeSerializer valueSerializer; + + // ------------------------------------------------------ + + // Max number of retries for failed database operations + private final int numSqlRetries; + // Sleep time between two retries + private final int sqlRetrySleep; + // Max number of key-value pairs inserted in one batch to the database + private final int maxInsertBatchSize; + // We will do database compaction every so many checkpoints + private final int compactEvery; + // Executor for automatic compactions + private ExecutorService executor = null; + + // Database properties + private final DbBackendConfig conf; + private final ShardedConnection connections; + private final DbAdapter dbAdapter; + + // Convenience object for handling inserts to the database + private final BatchInserter batchInsert; + + // Statements for key-lookups and inserts as prepared by the dbAdapter + private ShardedStatement selectStatements; + private ShardedStatement insertStatements; + + // ------------------------------------------------------ + + // LRU cache for the key-value states backed by the database + private final StateCache cache; + + private long nextTs; + private Map completedCheckpoints = new HashMap<>(); + + private volatile long lastCompactedTs; + + // ------------------------------------------------------ + + /** + * Constructor to initialize the {@link LazyDbKvState} the first time the + * job starts. + */ + public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, DbBackendConfig conf, + TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws IOException { + this(kvStateId, compact, cons, conf, keySerializer, valueSerializer, defaultValue, 1, 0); + } + + /** + * Initialize the {@link LazyDbKvState} from a snapshot. 
+ */ + public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, final DbBackendConfig conf, + TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue, long nextTs, + long lastCompactedTs) + throws IOException { + + this.kvStateId = kvStateId; + this.compact = compact; + if (compact) { + // Compactions will run in a seperate thread + executor = Executors.newSingleThreadExecutor(); + } + + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + this.defaultValue = defaultValue; + + this.maxInsertBatchSize = conf.getMaxKvInsertBatchSize(); + this.conf = conf; + this.connections = cons; + this.dbAdapter = conf.getDbAdapter(); + this.compactEvery = conf.getKvStateCompactionFrequency(); + this.numSqlRetries = conf.getMaxNumberOfSqlRetries(); + this.sqlRetrySleep = conf.getSleepBetweenSqlRetries(); + + this.nextTs = nextTs; + this.lastCompactedTs = lastCompactedTs; + + this.cache = new StateCache(conf.getKvCacheSize(), conf.getNumElementsToEvict()); + + initDB(this.connections); + + batchInsert = new BatchInserter(connections.getNumShards()); + + if (LOG.isDebugEnabled()) { + LOG.debug("Lazy database kv-state ({}) successfully initialized", kvStateId); + } + } + + @Override + public void setCurrentKey(K key) { + this.currentKey = key; + } + + @Override + public void update(V value) throws IOException { + try { + cache.put(currentKey, Optional.fromNullable(value)); + } catch (RuntimeException e) { + // We need to catch the RuntimeExceptions thrown in the StateCache + // methods here + throw new IOException(e); + } + } + + @Override + public V value() throws IOException { + try { + // We get the value from the cache (which will automatically load it + // from the database if necessary). If null, we return a copy of the + // default value + V val = cache.get(currentKey).orNull(); + return val != null ? val : copyDefault(); + } catch (RuntimeException e) { + // We need to catch the RuntimeExceptions thrown in the StateCache + // methods here + throw new IOException(e); + } + } + + @Override + public DbKvStateSnapshot snapshot(long checkpointId, long timestamp) throws IOException { + + // Validate timing assumptions + if (timestamp <= nextTs) { + throw new RuntimeException("Checkpoint timestamp is smaller than previous ts + 1, " + + "this should not happen."); + } + + // If there are any modified states we perform the inserts + if (!cache.modified.isEmpty()) { + // We insert the modified elements to the database with the current + // timestamp then clear the modified states + for (Entry> state : cache.modified.entrySet()) { + batchInsert.add(state, timestamp); + } + batchInsert.flush(timestamp); + cache.modified.clear(); + } else if (compact) { + // Otherwise we call the keep alive method to avoid dropped + // connections (only call this on the compactor instance) + for (final Connection c : connections.connections()) { + SQLRetrier.retry(new Callable() { + @Override + public Void call() throws Exception { + dbAdapter.keepAlive(c); + return null; + } + }, numSqlRetries, sqlRetrySleep); + } + } + + nextTs = timestamp + 1; + completedCheckpoints.put(checkpointId, timestamp); + return new DbKvStateSnapshot(kvStateId, timestamp, lastCompactedTs); + } + + /** + * Returns the number of elements currently stored in the task's cache. Note + * that the number of elements in the database is not counted here. 
+ */ + @Override + public int size() { + return cache.size(); + } + + /** + * Return a copy the default value or null if the default was null. + * + */ + private V copyDefault() { + return defaultValue != null ? valueSerializer.copy(defaultValue) : null; + } + + /** + * Create a table for the kvstate checkpoints (based on the kvStateId) and + * prepare the statements used during checkpointing. + */ + private void initDB(final ShardedConnection cons) throws IOException { + + retry(new Callable() { + public Void call() throws Exception { + + for (Connection con : cons.connections()) { + dbAdapter.createKVStateTable(kvStateId, con); + } + + insertStatements = cons.prepareStatement(dbAdapter.prepareKVCheckpointInsert(kvStateId)); + selectStatements = cons.prepareStatement(dbAdapter.prepareKeyLookup(kvStateId)); + + return null; + } + + }, numSqlRetries, sqlRetrySleep); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) { + final Long ts = completedCheckpoints.remove(checkpointId); + if (ts == null) { + LOG.warn("Complete notification for missing checkpoint: " + checkpointId); + } else { + // If compaction is turned on we compact on the compactor subtask + // asynchronously in the background + if (compactEvery > 0 && compact && checkpointId % compactEvery == 0) { + executor.execute(new Compactor(ts)); + } + } + } + + @Override + public void dispose() { + // We are only closing the statements here, the connection is borrowed + // from the state backend and will be closed there. + try { + selectStatements.close(); + } catch (SQLException e) { + // There is not much to do about this + } + try { + insertStatements.close(); + } catch (SQLException e) { + // There is not much to do about this + } + + if (executor != null) { + executor.shutdown(); + } + } + + /** + * Return the Map of cached states. + * + */ + public Map> getStateCache() { + return cache; + } + + /** + * Return the Map of modified states that hasn't been written to the + * database yet. + * + */ + public Map> getModified() { + return cache.modified; + } + + /** + * Used for testing purposes + */ + public boolean isCompactor() { + return compact; + } + + /** + * Used for testing purposes + */ + public ExecutorService getExecutor() { + return executor; + } + + /** + * Snapshot that stores a specific checkpoint timestamp and state id, and + * also rolls back the database to that point upon restore. The rollback is + * done by removing all state checkpoints that have timestamps between the + * checkpoint and recovery timestamp. + * + */ + private static class DbKvStateSnapshot implements KvStateSnapshot { + + private static final long serialVersionUID = 1L; + + private final String kvStateId; + private final long checkpointTimestamp; + private final long lastCompactedTimestamp; + + public DbKvStateSnapshot(String kvStateId, long checkpointTimestamp, long lastCompactedTs) { + this.checkpointTimestamp = checkpointTimestamp; + this.kvStateId = kvStateId; + this.lastCompactedTimestamp = lastCompactedTs; + } + + @Override + public LazyDbKvState restoreState(final DbStateBackend stateBackend, + final TypeSerializer keySerializer, final TypeSerializer valueSerializer, final V defaultValue, + ClassLoader classLoader, final long recoveryTimestamp) throws IOException { + + // Validate timing assumptions + if (recoveryTimestamp <= checkpointTimestamp) { + throw new RuntimeException( + "Recovery timestamp is smaller or equal to checkpoint timestamp. 
" + + "This might happen if the job was started with a new JobManager " + + "and the clocks got really out of sync."); + } + + // First we clean up the states written by partially failed + // snapshots + retry(new Callable() { + public Void call() throws Exception { + + // We need to perform cleanup on all shards to be safe here + for (Connection c : stateBackend.getConnections().connections()) { + stateBackend.getConfiguration().getDbAdapter().cleanupFailedCheckpoints(kvStateId, + c, checkpointTimestamp, recoveryTimestamp); + } + + return null; + } + }, stateBackend.getConfiguration().getMaxNumberOfSqlRetries(), + stateBackend.getConfiguration().getSleepBetweenSqlRetries()); + + boolean cleanup = stateBackend.getEnvironment().getIndexInSubtaskGroup() == 0; + + // Restore the KvState + LazyDbKvState restored = new LazyDbKvState(kvStateId, cleanup, + stateBackend.getConnections(), stateBackend.getConfiguration(), keySerializer, valueSerializer, + defaultValue, recoveryTimestamp, lastCompactedTimestamp); + + if (LOG.isDebugEnabled()) { + LOG.debug("KV state({},{}) restored.", kvStateId, recoveryTimestamp); + } + + return restored; + } + + @Override + public void discardState() throws Exception { + // Don't discard, it will be compacted by the LazyDbKvState + } + + } + + /** + * LRU cache implementation for storing the key-value states. When the cache + * is full elements are not evicted one by one but are evicted in a batch + * defined by the evictionSize parameter. + *
+ * <p>
+ * Keys not found in the cached will be retrieved from the underlying + * database + */ + private final class StateCache extends LinkedHashMap> { + private static final long serialVersionUID = 1L; + + private final int cacheSize; + private final int evictionSize; + + // We keep track the state modified since the last checkpoint + private final Map> modified = new HashMap<>(); + + public StateCache(int cacheSize, int evictionSize) { + super(cacheSize, 0.75f, true); + this.cacheSize = cacheSize; + this.evictionSize = evictionSize; + } + + @Override + public Optional put(K key, Optional value) { + // Put kv pair in the cache and evict elements if the cache is full + Optional old = super.put(key, value); + modified.put(key, value); + evictIfFull(); + return old; + } + + @SuppressWarnings("unchecked") + @Override + public Optional get(Object key) { + // First we check whether the value is cached + Optional value = super.get(key); + if (value == null) { + // If it doesn't try to load it from the database + value = Optional.fromNullable(getFromDatabaseOrNull((K) key)); + put((K) key, value); + } + return value; + } + + @Override + protected boolean removeEldestEntry(Entry> eldest) { + // We need to remove elements manually if the cache becomes full, so + // we always return false here. + return false; + } + + /** + * Fetch the current value from the database if exists or return null. + * + * @param key + * @return The value corresponding to the key and the last checkpointid + * from the database if exists or null. + */ + private V getFromDatabaseOrNull(final K key) { + try { + return retry(new Callable() { + public V call() throws Exception { + byte[] serializedKey = InstantiationUtil.serializeToByteArray(keySerializer, key); + // We lookup using the adapter and serialize/deserialize + // with the TypeSerializers + byte[] serializedVal = dbAdapter.lookupKey(kvStateId, + selectStatements.getForKey(key), serializedKey, nextTs); + + return serializedVal != null + ? InstantiationUtil.deserializeFromByteArray(valueSerializer, serializedVal) : null; + } + }, numSqlRetries, sqlRetrySleep); + } catch (IOException e) { + // We need to re-throw this exception to conform to the map + // interface, we will catch this when we call the the put/get + throw new RuntimeException(e); + } + } + + /** + * If the cache is full we remove the evictionSize least recently + * accessed elements and write them to the database if they were + * modified since the last checkpoint. 
+ */ + private void evictIfFull() { + if (size() > cacheSize) { + if (LOG.isDebugEnabled()) { + LOG.debug("State cache is full for {}, evicting {} elements.", kvStateId, evictionSize); + } + try { + int numEvicted = 0; + + Iterator>> entryIterator = entrySet().iterator(); + while (numEvicted++ < evictionSize && entryIterator.hasNext()) { + + Entry> next = entryIterator.next(); + + // We only need to write to the database if modified + if (modified.remove(next.getKey()) != null) { + batchInsert.add(next, nextTs); + } + + entryIterator.remove(); + } + + batchInsert.flush(nextTs); + + } catch (IOException e) { + // We need to re-throw this exception to conform to the map + // interface, we will catch this when we call the the + // put/get + throw new RuntimeException(e); + } + } + } + + @Override + public void putAll(Map> m) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + } + + /** + * Object for handling inserts to the database by batching them together + * partitioned on the sharding key. The batches are written to the database + * when they are full or when the inserter is flushed. + * + */ + private class BatchInserter { + + // Map from shard index to the kv pairs to be inserted + // Map>> inserts = new HashMap<>(); + + List>[] inserts; + + @SuppressWarnings("unchecked") + public BatchInserter(int numShards) { + inserts = new List[numShards]; + for (int i = 0; i < numShards; i++) { + inserts[i] = new ArrayList<>(); + } + } + + public void add(Entry> next, long timestamp) throws IOException { + + K key = next.getKey(); + V value = next.getValue().orNull(); + + // Get the current partition if present or initialize empty list + int shardIndex = connections.getShardIndex(key); + + List> insertPartition = inserts[shardIndex]; + + // Add the k-v pair to the partition + byte[] k = InstantiationUtil.serializeToByteArray(keySerializer, key); + byte[] v = value != null ? 
InstantiationUtil.serializeToByteArray(valueSerializer, value) : null; + insertPartition.add(Tuple2.of(k, v)); + + // If partition is full write to the database and clear + if (insertPartition.size() == maxInsertBatchSize) { + dbAdapter.insertBatch(kvStateId, conf, + connections.getForIndex(shardIndex), + insertStatements.getForIndex(shardIndex), + timestamp, insertPartition); + + insertPartition.clear(); + } + } + + public void flush(long timestamp) throws IOException { + // We flush all non-empty partitions + for (int i = 0; i < inserts.length; i++) { + List> insertPartition = inserts[i]; + if (!insertPartition.isEmpty()) { + dbAdapter.insertBatch(kvStateId, conf, connections.getForIndex(i), + insertStatements.getForIndex(i), timestamp, insertPartition); + insertPartition.clear(); + } + } + + } + } + + private class Compactor implements Runnable { + + private long upperBound; + + public Compactor(long upperBound) { + this.upperBound = upperBound; + } + + @Override + public void run() { + // We create new database connections to make sure we don't + // interfere with the checkpointing (connections are not thread + // safe) + try (ShardedConnection sc = conf.createShardedConnection()) { + for (final Connection c : sc.connections()) { + SQLRetrier.retry(new Callable() { + @Override + public Void call() throws Exception { + dbAdapter.compactKvStates(kvStateId, c, lastCompactedTs, upperBound); + return null; + } + }, numSqlRetries, sqlRetrySleep); + } + if (LOG.isInfoEnabled()) { + LOG.info("State succesfully compacted for {} between {} and {}.", kvStateId, + lastCompactedTs, + upperBound); + } + lastCompactedTs = upperBound; + } catch (SQLException | IOException e) { + LOG.warn("State compaction failed due: {}", e); + } + } + + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java new file mode 100644 index 0000000000000..9eaa2833d13de --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/MySqlAdapter.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Types; +import java.util.List; +import java.util.concurrent.Callable; + +import org.apache.flink.api.java.tuple.Tuple2; + +/** + * + * Adapter for bridging inconsistencies between the different SQL + * implementations. 
The default implementation has been tested to work well with + * MySQL + * + */ +public class MySqlAdapter implements DbAdapter { + + private static final long serialVersionUID = 1L; + + // ----------------------------------------------------------------------------- + // Non-partitioned state checkpointing + // ----------------------------------------------------------------------------- + + @Override + public void createCheckpointsTable(String jobId, Connection con) throws SQLException { + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "CREATE TABLE IF NOT EXISTS checkpoints_" + jobId + + " (" + + "checkpointId bigint, " + + "timestamp bigint, " + + "handleId bigint," + + "checkpoint blob," + + "PRIMARY KEY (handleId)" + + ")"); + } + + } + + @Override + public PreparedStatement prepareCheckpointInsert(String jobId, Connection con) throws SQLException { + return con.prepareStatement( + "INSERT INTO checkpoints_" + jobId + + " (checkpointId, timestamp, handleId, checkpoint) VALUES (?,?,?,?)"); + } + + @Override + public void setCheckpointInsertParams(String jobId, PreparedStatement insertStatement, long checkpointId, + long timestamp, long handleId, byte[] checkpoint) throws SQLException { + insertStatement.setLong(1, checkpointId); + insertStatement.setLong(2, timestamp); + insertStatement.setLong(3, handleId); + insertStatement.setBytes(4, checkpoint); + } + + @Override + public byte[] getCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId) + throws SQLException { + try (Statement smt = con.createStatement()) { + ResultSet rs = smt.executeQuery( + "SELECT checkpoint FROM checkpoints_" + jobId + + " WHERE handleId = " + handleId); + if (rs.next()) { + return rs.getBytes(1); + } else { + throw new SQLException("Checkpoint cannot be found in the database."); + } + } + } + + @Override + public void deleteCheckpoint(String jobId, Connection con, long checkpointId, long checkpointTs, long handleId) + throws SQLException { + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "DELETE FROM checkpoints_" + jobId + + " WHERE handleId = " + handleId); + } + } + + @Override + public void disposeAllStateForJob(String jobId, Connection con) throws SQLException { + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "DROP TABLE checkpoints_" + jobId); + } + } + + // ----------------------------------------------------------------------------- + // Partitioned state checkpointing + // ----------------------------------------------------------------------------- + + @Override + public void createKVStateTable(String stateId, Connection con) throws SQLException { + validateStateId(stateId); + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "CREATE TABLE IF NOT EXISTS " + stateId + + " (" + + "timestamp bigint, " + + "k varbinary(256), " + + "v blob, " + + "PRIMARY KEY (k, timestamp) " + + ")"); + } + } + + @Override + public String prepareKVCheckpointInsert(String stateId) throws SQLException { + validateStateId(stateId); + return "INSERT INTO " + stateId + " (timestamp, k, v) VALUES (?,?,?) " + + "ON DUPLICATE KEY UPDATE v=? "; + } + + @Override + public String prepareKeyLookup(String stateId) throws SQLException { + validateStateId(stateId); + return "SELECT v" + + " FROM " + stateId + + " WHERE k = ?" + + " AND timestamp <= ?" 
+ + " ORDER BY timestamp DESC LIMIT 1"; + } + + @Override + public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupTs) + throws SQLException { + lookupStatement.setBytes(1, key); + lookupStatement.setLong(2, lookupTs); + + ResultSet res = lookupStatement.executeQuery(); + + if (res.next()) { + return res.getBytes(1); + } else { + return null; + } + } + + @Override + public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTs, + long recoveryTs) throws SQLException { + validateStateId(stateId); + try (Statement smt = con.createStatement()) { + smt.executeUpdate("DELETE FROM " + stateId + + " WHERE timestamp > " + checkpointTs + + " AND timestamp < " + recoveryTs); + } + } + + @Override + public void compactKvStates(String stateId, Connection con, long lowerId, long upperId) + throws SQLException { + validateStateId(stateId); + + try (Statement smt = con.createStatement()) { + smt.executeUpdate("DELETE state.* FROM " + stateId + " AS state" + + " JOIN" + + " (" + + " SELECT MAX(timestamp) AS maxts, k FROM " + stateId + + " WHERE timestamp BETWEEN " + lowerId + " AND " + upperId + + " GROUP BY k" + + " ) m" + + " ON state.k = m.k" + + " AND state.timestamp >= " + lowerId); + } + } + + /** + * Tries to avoid SQL injection with weird state names. + * + */ + protected static void validateStateId(String name) { + if (!name.matches("[a-zA-Z0-9_]+")) { + throw new RuntimeException("State name contains invalid characters."); + } + } + + @Override + public void insertBatch(final String stateId, final DbBackendConfig conf, + final Connection con, final PreparedStatement insertStatement, final long checkpointTs, + final List> toInsert) throws IOException { + + SQLRetrier.retry(new Callable() { + public Void call() throws Exception { + for (Tuple2 kv : toInsert) { + setKvInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1); + insertStatement.addBatch(); + } + insertStatement.executeBatch(); + insertStatement.clearBatch(); + return null; + } + }, new Callable() { + public Void call() throws Exception { + insertStatement.clearBatch(); + return null; + } + }, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries()); + } + + private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointTs, + byte[] key, byte[] value) throws SQLException { + insertStatement.setLong(1, checkpointTs); + insertStatement.setBytes(2, key); + if (value != null) { + insertStatement.setBytes(3, value); + insertStatement.setBytes(4, value); + } else { + insertStatement.setNull(3, Types.BLOB); + insertStatement.setNull(4, Types.BLOB); + } + } + + @Override + public void keepAlive(Connection con) throws SQLException { + try(Statement smt = con.createStatement()) { + smt.executeQuery("SELECT 1"); + } + } + +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java new file mode 100644 index 0000000000000..4ae3fd23ef8c3 --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/SQLRetrier.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.IOException; +import java.io.Serializable; +import java.sql.SQLException; +import java.util.Random; +import java.util.concurrent.Callable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Simple utility to retry failed SQL commands for a predefined number of times + * before declaring failure. The retrier waits a random amount of time between two retries. + * + */ +public final class SQLRetrier implements Serializable { + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(SQLRetrier.class); + private static final Random rnd = new Random(); + + private static final int SLEEP_TIME = 10; + + private SQLRetrier() { + + } + + /** + * Tries to run the given {@link Callable} the predefined number of times + * before throwing an {@link IOException}. This method only retries + * calls that fail with a {@link SQLException}. Other exceptions will cause a + * {@link RuntimeException}. + * + * @param callable + * The callable to be retried. + * @param numRetries + * Max number of retries before throwing an {@link IOException}. + * @return The result of the {@link Callable#call()}. + * @throws IOException + * The wrapped {@link SQLException}. + */ + public static <X> X retry(Callable<X> callable, int numRetries) throws IOException { + return retry(callable, numRetries, SLEEP_TIME); + } + + /** + * Tries to run the given {@link Callable} the predefined number of times + * before throwing an {@link IOException}. This method only retries + * calls that fail with a {@link SQLException}. Other exceptions will cause a + * {@link RuntimeException}. + * + * @param callable + * The callable to be retried. + * @param numRetries + * Max number of retries before throwing an {@link IOException}. + * @param sleep + * The retrier will wait a random number of milliseconds between 1 and + * sleep. + * @return The result of the {@link Callable#call()}. + * @throws IOException + * The wrapped {@link SQLException}. + */ + public static <X> X retry(Callable<X> callable, int numRetries, int sleep) throws IOException { + int numtries = 0; + while (true) { + try { + return callable.call(); + } catch (SQLException e) { + handleSQLException(e, ++numtries, numRetries, sleep); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * Tries to run the given {@link Callable} the predefined number of times + * before throwing an {@link IOException}. This method only retries + * calls that fail with a {@link SQLException}. Other exceptions will cause a + * {@link RuntimeException}. + * + * Additionally, the user can supply a second callable which will be executed + * every time the first callable throws a {@link SQLException}. + * + * @param callable + * The callable to be retried. + * @param onException + * The callable to be executed when an {@link SQLException} was + * encountered. Exceptions thrown during this call are ignored. + * @param numRetries + * Max number of retries before throwing an {@link IOException}. + * @param sleep + * The retrier will wait a random number of milliseconds between 1 and + * sleep. + * @return The result of the {@link Callable#call()}. + * @throws IOException + * The wrapped {@link SQLException}. + */ + public static <X> X retry(Callable<X> callable, Callable<?> onException, int numRetries, int sleep) + throws IOException { + int numtries = 0; + while (true) { + try { + return callable.call(); + } catch (SQLException se) { + try { + onException.call(); + } catch (Exception e) { + // Exceptions thrown in this call will be ignored + } + handleSQLException(se, ++numtries, numRetries, sleep); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + } + + /** + * Tries to run the given {@link Callable} the predefined number of times + * before throwing an {@link IOException}. This method only retries + * calls that fail with a {@link SQLException}. Other exceptions will cause a + * {@link RuntimeException}. + * + * Additionally, the user can supply a second callable which will be executed + * every time the first callable throws a {@link SQLException}. + * + * @param callable + * The callable to be retried. + * @param onException + * The callable to be executed when an {@link SQLException} was + * encountered. Exceptions thrown during this call are ignored. + * @param numRetries + * Max number of retries before throwing an {@link IOException}. + * @return The result of the {@link Callable#call()}. + * @throws IOException + * The wrapped {@link SQLException}. + */ + public static <X> X retry(Callable<X> callable, Callable<?> onException, int numRetries) + throws IOException { + return retry(callable, onException, numRetries, SLEEP_TIME); + } + + private static void handleSQLException(SQLException e, int numTries, int maxRetries, int sleep) throws IOException { + if (numTries < maxRetries) { + if (LOG.isDebugEnabled()) { + LOG.debug("Error while executing SQL statement: {}\nRetrying...", + e.getMessage()); + } + try { + Thread.sleep(numTries * rnd.nextInt(sleep)); + } catch (InterruptedException ie) { + throw new RuntimeException("Thread has been interrupted.", ie); + } + } else { + throw new IOException( + "Could not execute SQL statement after " + maxRetries + " retries.", e); + } + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java new file mode 100644 index 0000000000000..44995de7d2b1a --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/main/java/org/apache/flink/contrib/streaming/state/ShardedConnection.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.Serializable; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +/** + * Helper class to maintain a sharded database connection and get + * {@link Connection}s and {@link PreparedStatement}s for keys. + * + */ +public class ShardedConnection implements AutoCloseable, Serializable { + + private static final long serialVersionUID = 1L; + private final Connection[] connections; + private final int numShards; + + private final Partitioner partitioner; + + public ShardedConnection(List shardUrls, String user, String password, Partitioner partitioner) + throws SQLException { + numShards = shardUrls.size(); + connections = new Connection[numShards]; + for (int i = 0; i < numShards; i++) { + connections[i] = DriverManager.getConnection(shardUrls.get(i), user, password); + } + this.partitioner = partitioner; + } + + public ShardedConnection(List shardUrls, String user, String password) throws SQLException { + this(shardUrls, user, password, new ModHashPartitioner()); + } + + public ShardedStatement prepareStatement(String sql) throws SQLException { + return new ShardedStatement(sql); + } + + public Connection[] connections() { + return connections; + } + + public Connection getForKey(Object key) { + return connections[getShardIndex(key)]; + } + + public Connection getForIndex(int index) { + if (index < numShards) { + return connections[index]; + } else { + throw new RuntimeException("Index out of range"); + } + } + + public Connection getFirst() { + return connections[0]; + } + + public int getNumShards() { + return numShards; + } + + @Override + public void close() throws SQLException { + if (connections != null) { + for (Connection c : connections) { + c.close(); + } + } + } + + public int getShardIndex(Object key) { + return partitioner.getShardIndex(key, numShards); + } + + public void setTransactionIsolation(int level) throws SQLException { + for (Connection con : connections) { + con.setTransactionIsolation(level); + } + } + + public class ShardedStatement implements AutoCloseable, Serializable { + + private static final long serialVersionUID = 1L; + private final PreparedStatement[] statements = new PreparedStatement[numShards]; + + public ShardedStatement(final String sql) throws SQLException { + for (int i = 0; i < numShards; i++) { + statements[i] = connections[i].prepareStatement(sql); + } + } + + public PreparedStatement getForKey(Object key) { + return statements[getShardIndex(key)]; + } + + public PreparedStatement getForIndex(int index) { + if (index < numShards) { + return statements[index]; + } else { + throw new RuntimeException("Index out of range"); + } + } + + public PreparedStatement getFirst() { + return statements[0]; + } + + @Override + public void close() throws SQLException { + if (statements != null) { + for (PreparedStatement t : statements) { + t.close(); + } + } + } + + } + + public interface Partitioner extends Serializable { + int getShardIndex(Object key, int numShards); + } + + public static class ModHashPartitioner implements Partitioner { + + private static final long serialVersionUID = 1L; + + @Override + public int getShardIndex(Object key, int numShards) { + return Math.abs(key.hashCode() % numShards); + } + + } +} diff --git 
a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java new file mode 100644 index 0000000000000..337960ff71c30 --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DBStateCheckpointingTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import static org.junit.Assert.assertEquals; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.io.FileUtils; +import org.apache.derby.drda.NetworkServerControl; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.OperatorState; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.test.checkpointing.PartitionedStateCheckpointingITCase.IdentityKeySelector; +import org.apache.flink.test.checkpointing.PartitionedStateCheckpointingITCase.NonSerializableLong; +import org.apache.flink.test.checkpointing.StreamFaultToleranceTestBase; +import org.junit.After; +import org.junit.Before; + +@SuppressWarnings("serial") +public class DBStateCheckpointingTest extends StreamFaultToleranceTestBase { + + final long NUM_STRINGS = 1_000_000L; + final static int NUM_KEYS = 100; + private static NetworkServerControl server; + private static File tempDir; + + @Before + public void startDerbyServer() throws UnknownHostException, Exception { + server = new NetworkServerControl(InetAddress.getByName("localhost"), 1526, "flink", "flink"); + server.start(null); + tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); + } + + @After + public void stopDerbyServer() { + try { + server.shutdown(); + FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB1")); + FileUtils.forceDelete(new File("derby.log")); 
+ } catch (Exception ignore) { + } + } + + @Override + public void testProgram(StreamExecutionEnvironment env) { + env.enableCheckpointing(500); + + DbBackendConfig conf = new DbBackendConfig("flink", "flink", + "jdbc:derby://localhost:1526/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true"); + conf.setDbAdapter(new DerbyAdapter()); + conf.setKvStateCompactionFrequency(2); + + // We store the non-partitioned states (source offset) in-memory + DbStateBackend backend = new DbStateBackend(conf, new MemoryStateBackend()); + + env.setStateBackend(backend); + + DataStream stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2)); + DataStream stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2)); + + stream1.union(stream2).keyBy(new IdentityKeySelector()).map(new OnceFailingPartitionedSum(NUM_STRINGS)) + .keyBy(0).addSink(new CounterSink()); + } + + @Override + public void postSubmit() { + // verify that we counted exactly right + for (Entry sum : OnceFailingPartitionedSum.allSums.entrySet()) { + assertEquals(new Long(sum.getKey() * NUM_STRINGS / NUM_KEYS), sum.getValue()); + } + for (Long count : CounterSink.allCounts.values()) { + assertEquals(new Long(NUM_STRINGS / NUM_KEYS), count); + } + + assertEquals(NUM_KEYS, CounterSink.allCounts.size()); + assertEquals(NUM_KEYS, OnceFailingPartitionedSum.allSums.size()); + } + + // -------------------------------------------------------------------------------------------- + // Custom Functions + // -------------------------------------------------------------------------------------------- + + private static class IntGeneratingSourceFunction extends RichParallelSourceFunction + implements Checkpointed { + + private final long numElements; + + private int index; + private int step; + + private Random rnd = new Random(); + + private volatile boolean isRunning = true; + + static final long[] counts = new long[PARALLELISM]; + + @Override + public void close() throws IOException { + counts[getRuntimeContext().getIndexOfThisSubtask()] = index; + } + + IntGeneratingSourceFunction(long numElements) { + this.numElements = numElements; + } + + @Override + public void open(Configuration parameters) throws IOException { + step = getRuntimeContext().getNumberOfParallelSubtasks(); + if (index == 0) { + index = getRuntimeContext().getIndexOfThisSubtask(); + } + } + + @Override + public void run(SourceContext ctx) throws Exception { + final Object lockingObject = ctx.getCheckpointLock(); + + while (isRunning && index < numElements) { + + synchronized (lockingObject) { + index += step; + ctx.collect(index % NUM_KEYS); + } + + if (rnd.nextDouble() < 0.008) { + Thread.sleep(1); + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @Override + public Integer snapshotState(long checkpointId, long checkpointTimestamp) { + return index; + } + + @Override + public void restoreState(Integer state) { + index = state; + } + } + + private static class OnceFailingPartitionedSum extends RichMapFunction> { + + private static Map allSums = new ConcurrentHashMap(); + + private static volatile boolean hasFailed = false; + + private final long numElements; + + private long failurePos; + private long count; + + private OperatorState sum; + + OnceFailingPartitionedSum(long numElements) { + this.numElements = numElements; + } + + @Override + public void open(Configuration parameters) throws IOException { + long failurePosMin = (long) (0.6 * numElements / getRuntimeContext().getNumberOfParallelSubtasks()); + long 
failurePosMax = (long) (0.8 * numElements / getRuntimeContext().getNumberOfParallelSubtasks()); + + failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin; + count = 0; + sum = getRuntimeContext().getKeyValueState("my_state", Long.class, 0L); + } + + @Override + public Tuple2 map(Integer value) throws Exception { + count++; + if (!hasFailed && count >= failurePos) { + hasFailed = true; + throw new Exception("Test Failure"); + } + + long currentSum = sum.value() + value; + sum.update(currentSum); + allSums.put(value, currentSum); + return new Tuple2(value, currentSum); + } + } + + private static class CounterSink extends RichSinkFunction> { + + private static Map allCounts = new ConcurrentHashMap(); + + private OperatorState aCounts; + private OperatorState bCounts; + + @Override + public void open(Configuration parameters) throws IOException { + aCounts = getRuntimeContext().getKeyValueState("a", NonSerializableLong.class, NonSerializableLong.of(0L)); + bCounts = getRuntimeContext().getKeyValueState("b", Long.class, 0L); + } + + @Override + public void invoke(Tuple2 value) throws Exception { + long ac = aCounts.value().value; + long bc = bCounts.value(); + assertEquals(ac, bc); + + long currentCount = ac + 1; + aCounts.update(NonSerializableLong.of(currentCount)); + bCounts.update(currentCount); + + allCounts.put(value.f0, currentCount); + } + } +} diff --git a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java new file mode 100644 index 0000000000000..209086f8d94a0 --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DbStateBackendTest.java @@ -0,0 +1,478 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.contrib.streaming.state; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.io.FileUtils; +import org.apache.derby.drda.NetworkServerControl; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.KvState; +import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.shaded.com.google.common.collect.Lists; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.google.common.base.Optional; + +public class DbStateBackendTest { + + private static NetworkServerControl server; + private static File tempDir; + private static DbBackendConfig conf; + private static String url1; + private static String url2; + + @BeforeClass + public static void startDerbyServer() throws UnknownHostException, Exception { + server = new NetworkServerControl(InetAddress.getByName("localhost"), 1527, "flink", "flink"); + server.start(null); + tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); + conf = new DbBackendConfig("flink", "flink", + "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true"); + conf.setDbAdapter(new DerbyAdapter()); + conf.setKvStateCompactionFrequency(1); + + url1 = "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB1;create=true"; + url2 = "jdbc:derby://localhost:1527/" + tempDir.getAbsolutePath() + "/flinkDB2;create=true"; + } + + @AfterClass + public static void stopDerbyServer() throws Exception { + try { + server.shutdown(); + FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB1")); + FileUtils.deleteDirectory(new File(tempDir.getAbsolutePath() + "/flinkDB2")); + FileUtils.forceDelete(new File("derby.log")); + } catch (Exception ignore) { + } + } + + @Test + public void testSetupAndSerialization() throws Exception { + DbStateBackend dbBackend = new DbStateBackend(conf); + + assertFalse(dbBackend.isInitialized()); + + // serialize / copy the backend + DbStateBackend backend = CommonTestUtils.createCopySerializable(dbBackend); + assertFalse(backend.isInitialized()); + + Environment env = new DummyEnvironment("test", 1, 0); + backend.initializeForJob(env); + + assertNotNull(backend.getConnections()); + assertTrue( + isTableCreated(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString())); + + backend.disposeAllStateForCurrentJob(); + assertFalse( + 
isTableCreated(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString())); + backend.close(); + + assertTrue(backend.getConnections().getFirst().isClosed()); + } + + @Test + public void testSerializableState() throws Exception { + Environment env = new DummyEnvironment("test", 1, 0); + DbStateBackend backend = new DbStateBackend(conf); + + backend.initializeForJob(env); + + String state1 = "dummy state"; + String state2 = "row row row your boat"; + Integer state3 = 42; + + StateHandle handle1 = backend.checkpointStateSerializable(state1, 439568923746L, + System.currentTimeMillis()); + StateHandle handle2 = backend.checkpointStateSerializable(state2, 439568923746L, + System.currentTimeMillis()); + StateHandle handle3 = backend.checkpointStateSerializable(state3, 439568923746L, + System.currentTimeMillis()); + + assertEquals(state1, handle1.getState(getClass().getClassLoader())); + handle1.discardState(); + + assertEquals(state2, handle2.getState(getClass().getClassLoader())); + handle2.discardState(); + + assertFalse(isTableEmpty(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString())); + + assertEquals(state3, handle3.getState(getClass().getClassLoader())); + handle3.discardState(); + + assertTrue(isTableEmpty(backend.getConnections().getFirst(), "checkpoints_" + env.getJobID().toShortString())); + + backend.close(); + + } + + @Test + public void testKeyValueState() throws Exception { + + // We will create the DbStateBackend backed with a FsStateBackend for + // nonPartitioned states + File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); + try { + FsStateBackend fileBackend = new FsStateBackend(localFileUri(tempDir)); + + DbStateBackend backend = new DbStateBackend(conf, fileBackend); + + Environment env = new DummyEnvironment("test", 2, 0); + + backend.initializeForJob(env); + + LazyDbKvState kv = backend.createKvState("state1_1", "state1", IntSerializer.INSTANCE, + StringSerializer.INSTANCE, null); + + String tableName = "state1_1_" + env.getJobID().toShortString(); + assertTrue(isTableCreated(backend.getConnections().getFirst(), tableName)); + + assertEquals(0, kv.size()); + + // some modifications to the state + kv.setCurrentKey(1); + assertNull(kv.value()); + kv.update("1"); + assertEquals(1, kv.size()); + kv.setCurrentKey(2); + assertNull(kv.value()); + kv.update("2"); + assertEquals(2, kv.size()); + kv.setCurrentKey(1); + assertEquals("1", kv.value()); + assertEquals(2, kv.size()); + + kv.snapshot(682375462378L, 100); + + // make some more modifications + kv.setCurrentKey(1); + kv.update("u1"); + kv.setCurrentKey(2); + kv.update("u2"); + kv.setCurrentKey(3); + kv.update("u3"); + + // draw another snapshot + KvStateSnapshot snapshot2 = kv.snapshot(682375462379L, + 200); + + // validate the original state + assertEquals(3, kv.size()); + kv.setCurrentKey(1); + assertEquals("u1", kv.value()); + kv.setCurrentKey(2); + assertEquals("u2", kv.value()); + kv.setCurrentKey(3); + assertEquals("u3", kv.value()); + + // restore the first snapshot and validate it + KvState restored2 = snapshot2.restoreState(backend, IntSerializer.INSTANCE, + StringSerializer.INSTANCE, null, getClass().getClassLoader(), 6823754623710L); + + assertEquals(0, restored2.size()); + restored2.setCurrentKey(1); + assertEquals("u1", restored2.value()); + restored2.setCurrentKey(2); + assertEquals("u2", restored2.value()); + restored2.setCurrentKey(3); + assertEquals("u3", restored2.value()); + + backend.close(); + } 
finally { + deleteDirectorySilently(tempDir); + } + } + + @Test + public void testCompaction() throws Exception { + DbBackendConfig conf = new DbBackendConfig("flink", "flink", url1); + MockAdapter adapter = new MockAdapter(); + conf.setKvStateCompactionFrequency(2); + conf.setDbAdapter(adapter); + + DbStateBackend backend1 = new DbStateBackend(conf); + DbStateBackend backend2 = new DbStateBackend(conf); + DbStateBackend backend3 = new DbStateBackend(conf); + + backend1.initializeForJob(new DummyEnvironment("test", 3, 0)); + backend2.initializeForJob(new DummyEnvironment("test", 3, 1)); + backend3.initializeForJob(new DummyEnvironment("test", 3, 2)); + + LazyDbKvState s1 = backend1.createKvState("a_1", "a", null, null, null); + LazyDbKvState s2 = backend2.createKvState("a_1", "a", null, null, null); + LazyDbKvState s3 = backend3.createKvState("a_1", "a", null, null, null); + + assertTrue(s1.isCompactor()); + assertFalse(s2.isCompactor()); + assertFalse(s3.isCompactor()); + assertNotNull(s1.getExecutor()); + assertNull(s2.getExecutor()); + assertNull(s3.getExecutor()); + + s1.snapshot(1, 100); + s1.notifyCheckpointComplete(1); + s1.snapshot(2, 200); + s1.snapshot(3, 300); + s1.notifyCheckpointComplete(2); + s1.notifyCheckpointComplete(3); + s1.snapshot(4, 400); + s1.snapshot(5, 500); + s1.notifyCheckpointComplete(4); + s1.notifyCheckpointComplete(5); + + s1.dispose(); + s2.dispose(); + s3.dispose(); + + // Wait until the compaction completes + s1.getExecutor().awaitTermination(5, TimeUnit.SECONDS); + assertEquals(2, adapter.numCompcations.get()); + assertEquals(5, adapter.keptAlive); + + backend1.close(); + backend2.close(); + backend3.close(); + } + + @Test + public void testCaching() throws Exception { + + List urls = Lists.newArrayList(url1, url2); + DbBackendConfig conf = new DbBackendConfig("flink", "flink", + urls); + + conf.setDbAdapter(new DerbyAdapter()); + conf.setKvCacheSize(3); + conf.setMaxKvInsertBatchSize(2); + + // We evict 2 elements when the cache is full + conf.setMaxKvCacheEvictFraction(0.6f); + + DbStateBackend backend = new DbStateBackend(conf); + + Environment env = new DummyEnvironment("test", 2, 0); + + String tableName = "state1_1_" + env.getJobID().toShortString(); + assertFalse(isTableCreated(DriverManager.getConnection(url1, "flink", "flink"), tableName)); + assertFalse(isTableCreated(DriverManager.getConnection(url2, "flink", "flink"), tableName)); + + backend.initializeForJob(env); + + LazyDbKvState kv = backend.createKvState("state1_1", "state1", IntSerializer.INSTANCE, + StringSerializer.INSTANCE, "a"); + + assertTrue(isTableCreated(DriverManager.getConnection(url1, "flink", "flink"), tableName)); + assertTrue(isTableCreated(DriverManager.getConnection(url2, "flink", "flink"), tableName)); + + Map> cache = kv.getStateCache(); + Map> modified = kv.getModified(); + + assertEquals(0, kv.size()); + + // some modifications to the state + kv.setCurrentKey(1); + assertEquals("a", kv.value()); + + kv.update(null); + assertEquals(1, kv.size()); + kv.setCurrentKey(2); + assertEquals("a", kv.value()); + kv.update("2"); + assertEquals(2, kv.size()); + assertEquals("2", kv.value()); + + kv.setCurrentKey(1); + assertEquals("a", kv.value()); + + kv.setCurrentKey(3); + kv.update("3"); + assertEquals("3", kv.value()); + + assertTrue(modified.containsKey(1)); + assertTrue(modified.containsKey(2)); + assertTrue(modified.containsKey(3)); + + // 1,2 should be evicted as the cache filled + kv.setCurrentKey(4); + kv.update("4"); + assertEquals("4", kv.value()); + + 
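// With a kv cache size of 3 and an evict fraction of 0.6f, touching a fourth key is + // expected to evict the two oldest cached entries (keys 1 and 2) and flush their pending + // modifications to the database, which is what the assertions below check. +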
assertFalse(modified.containsKey(1)); + assertFalse(modified.containsKey(2)); + assertTrue(modified.containsKey(3)); + assertTrue(modified.containsKey(4)); + + assertEquals(Optional.of("3"), cache.get(3)); + assertEquals(Optional.of("4"), cache.get(4)); + assertFalse(cache.containsKey(1)); + assertFalse(cache.containsKey(2)); + + // draw a snapshot + kv.snapshot(682375462378L, 100); + + assertTrue(modified.isEmpty()); + + // make some more modifications + kv.setCurrentKey(2); + assertEquals("2", kv.value()); + kv.update(null); + assertEquals("a", kv.value()); + + assertTrue(modified.containsKey(2)); + assertEquals(1, modified.size()); + + assertEquals(Optional.of("3"), cache.get(3)); + assertEquals(Optional.of("4"), cache.get(4)); + assertEquals(Optional.absent(), cache.get(2)); + assertFalse(cache.containsKey(1)); + + assertTrue(modified.containsKey(2)); + assertFalse(modified.containsKey(3)); + assertFalse(modified.containsKey(4)); + assertTrue(cache.containsKey(3)); + assertTrue(cache.containsKey(4)); + + // clear cache from initial keys + + kv.setCurrentKey(5); + kv.value(); + kv.setCurrentKey(6); + kv.value(); + kv.setCurrentKey(7); + kv.value(); + + assertFalse(modified.containsKey(5)); + assertTrue(modified.containsKey(6)); + assertTrue(modified.containsKey(7)); + + assertFalse(cache.containsKey(1)); + assertFalse(cache.containsKey(2)); + assertFalse(cache.containsKey(3)); + assertFalse(cache.containsKey(4)); + + kv.setCurrentKey(2); + assertEquals("a", kv.value()); + + long checkpointTs = System.currentTimeMillis(); + + // Draw a snapshot that we will restore later + KvStateSnapshot snapshot1 = kv.snapshot(682375462379L, checkpointTs); + assertTrue(modified.isEmpty()); + + // Do some updates then draw another snapshot (imitate a partial + // failure), these updates should not be visible if we restore snapshot1 + kv.setCurrentKey(1); + kv.update("123"); + kv.setCurrentKey(3); + kv.update("456"); + kv.setCurrentKey(2); + kv.notifyCheckpointComplete(682375462379L); + kv.update("2"); + kv.setCurrentKey(4); + kv.update("4"); + kv.update("5"); + + kv.snapshot(6823754623710L, checkpointTs + 10); + + // restore the second snapshot and validate it (we set a new default + // value here to make sure that the default wasn't written) + KvState restored = snapshot1.restoreState(backend, IntSerializer.INSTANCE, + StringSerializer.INSTANCE, "b", getClass().getClassLoader(), 6823754623711L); + + restored.setCurrentKey(1); + assertEquals("b", restored.value()); + restored.setCurrentKey(2); + assertEquals("b", restored.value()); + restored.setCurrentKey(3); + assertEquals("3", restored.value()); + restored.setCurrentKey(4); + assertEquals("4", restored.value()); + + backend.close(); + } + + private static boolean isTableCreated(Connection con, String tableName) throws SQLException { + return con.getMetaData().getTables(null, null, tableName.toUpperCase(), null).next(); + } + + private static boolean isTableEmpty(Connection con, String tableName) throws SQLException { + try (Statement smt = con.createStatement()) { + ResultSet res = smt.executeQuery("select * from " + tableName); + return !res.next(); + } + } + + private static String localFileUri(File path) { + return path.toURI().toString(); + } + + private static void deleteDirectorySilently(File dir) { + try { + FileUtils.deleteDirectory(dir); + } catch (IOException ignored) { + } + } + + private static class MockAdapter extends DerbyAdapter { + + private static final long serialVersionUID = 1L; + public AtomicInteger numCompcations = new 
AtomicInteger(0); + public int keptAlive = 0; + + @Override + public void compactKvStates(String kvStateId, Connection con, long lowerTs, long upperTs) throws SQLException { + numCompcations.incrementAndGet(); + } + + @Override + public void keepAlive(Connection con) throws SQLException { + keptAlive++; + } + } + +} diff --git a/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java new file mode 100644 index 0000000000000..1f13f4bdc397f --- /dev/null +++ b/flink-contrib/flink-streaming-contrib/src/test/java/org/apache/flink/contrib/streaming/state/DerbyAdapter.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Types; +import java.util.List; +import java.util.concurrent.Callable; + +import org.apache.flink.api.java.tuple.Tuple2; + +/** + * Adapter for the Derby JDBC driver which has slightly restricted CREATE TABLE + * and SELECT semantics compared to the default assumptions. + * + */ +public class DerbyAdapter extends MySqlAdapter { + + private static final long serialVersionUID = 1L; + + /** + * We need to override this method as Derby does not support the + * "IF NOT EXISTS" clause at table creation + */ + @Override + public void createCheckpointsTable(String jobId, Connection con) throws SQLException { + + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "CREATE TABLE checkpoints_" + jobId + + " (" + + "checkpointId bigint, " + + "timestamp bigint, " + + "handleId bigint," + + "checkpoint blob," + + "PRIMARY KEY (handleId)" + + ")"); + } catch (SQLException se) { + if (se.getSQLState().equals("X0Y32")) { + // table already created, ignore + } else { + throw se; + } + } + } + + /** + * We need to override this method as Derby does not support the + * "IF NOT EXISTS" clause at table creation + */ + @Override + public void createKVStateTable(String stateId, Connection con) throws SQLException { + + validateStateId(stateId); + try (Statement smt = con.createStatement()) { + smt.executeUpdate( + "CREATE TABLE " + stateId + + " (" + + "timestamp bigint, " + + "k varchar(256) for bit data, " + + "v blob, " + + "PRIMARY KEY (k, timestamp)" + + ")"); + } catch (SQLException se) { + if (se.getSQLState().equals("X0Y32")) { + // table already created, ignore + } else { + throw se; + } + } + } + + /** + * We need to override this method as Derby does not support "LIMIT n" for + * select statements. 
+ */ + @Override + public String prepareKeyLookup(String stateId) throws SQLException { + validateStateId(stateId); + return "SELECT v " + "FROM " + stateId + + " WHERE k = ? " + + " AND timestamp <= ?" + + " ORDER BY timestamp DESC"; + } + + @Override + public void compactKvStates(String stateId, Connection con, long lowerBound, long upperBound) + throws SQLException { + validateStateId(stateId); + + try (Statement smt = con.createStatement()) { + smt.executeUpdate("DELETE FROM " + stateId + " t1" + + " WHERE EXISTS" + + " (" + + " SELECT * FROM " + stateId + " t2" + + " WHERE t2.k = t1.k" + + " AND t2.timestamp > t1.timestamp" + + " AND t2.timestamp <=" + upperBound + + " AND t2.timestamp >= " + lowerBound + + " )"); + } + } + + @Override + public String prepareKVCheckpointInsert(String stateId) throws SQLException { + validateStateId(stateId); + return "INSERT INTO " + stateId + " (timestamp, k, v) VALUES (?,?,?)"; + } + + @Override + public void insertBatch(final String stateId, final DbBackendConfig conf, + final Connection con, final PreparedStatement insertStatement, final long checkpointTs, + final List> toInsert) throws IOException { + + SQLRetrier.retry(new Callable() { + public Void call() throws Exception { + con.setAutoCommit(false); + for (Tuple2 kv : toInsert) { + setKVInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1); + insertStatement.addBatch(); + } + insertStatement.executeBatch(); + con.commit(); + con.setAutoCommit(true); + insertStatement.clearBatch(); + return null; + } + }, new Callable() { + public Void call() throws Exception { + con.rollback(); + insertStatement.clearBatch(); + return null; + } + }, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries()); + } + + private void setKVInsertParams(String stateId, PreparedStatement insertStatement, long checkpointId, + byte[] key, byte[] value) throws SQLException { + insertStatement.setLong(1, checkpointId); + insertStatement.setBytes(2, key); + if (value != null) { + insertStatement.setBytes(3, value); + } else { + insertStatement.setNull(3, Types.BLOB); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index fdb59d9a89a56..09dd2d97768cd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -489,14 +489,16 @@ public void restoreLatestCheckpointedState( return; } } + + long recoveryTimestamp = System.currentTimeMillis(); if (allOrNothingState) { Map stateCounts = new HashMap(); - + for (StateForTask state : latest.getStates()) { ExecutionJobVertex vertex = tasks.get(state.getOperatorId()); Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt(); - exec.setInitialState(state.getState()); + exec.setInitialState(state.getState(), recoveryTimestamp); Integer count = stateCounts.get(vertex); if (count != null) { @@ -519,7 +521,7 @@ public void restoreLatestCheckpointedState( for (StateForTask state : latest.getStates()) { ExecutionJobVertex vertex = tasks.get(state.getOperatorId()); Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt(); - exec.setInitialState(state.getState()); + exec.setInitialState(state.getState(), recoveryTimestamp); } } } diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java index 34b7946fc04c7..82d8e7cde4c5f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointIDCounter.java @@ -39,5 +39,5 @@ public interface CheckpointIDCounter { * @return The previous checkpoint ID */ long getAndIncrement() throws Exception; - + } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 558fcd039d71c..e6a1583088d26 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -83,7 +83,9 @@ public final class TaskDeploymentDescriptor implements Serializable { private final List requiredClasspaths; private final SerializedValue> operatorState; - + + private long recoveryTimestamp; + /** * Constructs a task deployment descriptor. */ @@ -94,7 +96,7 @@ public TaskDeploymentDescriptor( List producedPartitions, List inputGates, List requiredJarFiles, List requiredClasspaths, - int targetSlotNumber, SerializedValue> operatorState) { + int targetSlotNumber, SerializedValue> operatorState, long recoveryTimestamp) { checkArgument(indexInSubtaskGroup >= 0); checkArgument(numberOfSubtasks > indexInSubtaskGroup); @@ -115,6 +117,7 @@ public TaskDeploymentDescriptor( this.requiredClasspaths = checkNotNull(requiredClasspaths); this.targetSlotNumber = targetSlotNumber; this.operatorState = operatorState; + this.recoveryTimestamp = recoveryTimestamp; } public TaskDeploymentDescriptor( @@ -128,7 +131,7 @@ public TaskDeploymentDescriptor( this(jobID, vertexID, executionId, taskName, indexInSubtaskGroup, numberOfSubtasks, jobConfiguration, taskConfiguration, invokableClassName, producedPartitions, - inputGates, requiredJarFiles, requiredClasspaths, targetSlotNumber, null); + inputGates, requiredJarFiles, requiredClasspaths, targetSlotNumber, null, -1); } /** @@ -245,4 +248,8 @@ private String collectionToString(Collection collection) { public SerializedValue> getOperatorState() { return operatorState; } + + public long getRecoveryTimestamp() { + return recoveryTimestamp; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index faabfb3bdaa2a..ce17525fe7e23 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -135,6 +135,8 @@ public class Execution implements Serializable { private volatile InstanceConnectionInfo assignedResourceLocation; // for the archived execution private SerializedValue> operatorState; + + private long recoveryTimestamp; /** The execution context which is used to execute futures. 
*/ @SuppressWarnings("NonSerializableFieldInSerializableClass") @@ -231,11 +233,12 @@ public void prepareForArchiving() { partialInputChannelDeploymentDescriptors = null; } - public void setInitialState(SerializedValue> initialState) { + public void setInitialState(SerializedValue> initialState, long recoveryTimestamp) { if (state != ExecutionState.CREATED) { throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED"); } this.operatorState = initialState; + this.recoveryTimestamp = recoveryTimestamp; } // -------------------------------------------------------------------------------------------- @@ -359,7 +362,7 @@ public void deployToSlot(final SimpleSlot slot) throws JobException { attemptNumber, slot.getInstance().getInstanceConnectionInfo().getHostname())); } - final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(attemptId, slot, operatorState); + final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(attemptId, slot, operatorState, recoveryTimestamp); // register this execution at the execution graph, to receive call backs vertex.getExecutionGraph().registerExecution(this); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 1e5d02cb95431..d10aac13a5453 100755 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -811,6 +811,7 @@ else if (current != JobStatus.RESTARTING) { * *

The recovery of checkpoints might block. Make sure that calls to this method don't * block the job manager actor and run asynchronously. + * */ public void restoreLatestCheckpointedState() throws Exception { synchronized (progressLock) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 6a635287d912e..fba5652bbc427 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -616,7 +616,8 @@ void notifyStateTransition(ExecutionAttemptID executionId, ExecutionState newSta TaskDeploymentDescriptor createDeploymentDescriptor( ExecutionAttemptID executionId, SimpleSlot targetSlot, - SerializedValue> operatorState) { + SerializedValue> operatorState, + long recoveryTimestamp) { // Produced intermediate results List producedPartitions = new ArrayList(resultPartitions.size()); @@ -651,7 +652,7 @@ TaskDeploymentDescriptor createDeploymentDescriptor( subTaskIndex, getTotalNumberOfParallelSubtasks(), getExecutionGraph().getJobConfiguration(), jobVertex.getJobVertex().getConfiguration(), jobVertex.getJobVertex().getInvokableClassName(), producedPartitions, consumedPartitions, jarFiles, classpaths, targetSlot.getRoot().getSlotNumber(), - operatorState); + operatorState, recoveryTimestamp); } // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java index 894e6d91fff4b..fac4ec471df5a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java @@ -31,8 +31,9 @@ public interface StatefulTask> { * a snapshot of the state from a previous execution. * * @param stateHandle The handle to the state. + * @param recoveryTimestamp Global recovery timestamp. 
*/ - void setInitialState(T stateHandle) throws Exception; + void setInitialState(T stateHandle, long recoveryTimestamp) throws Exception; /** * This method is either called directly and asynchronously by the checkpoint diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java index 3d6c56c806ab7..e2e521cb3269f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KvStateSnapshot.java @@ -57,7 +57,8 @@ KvState restoreState( TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue, - ClassLoader classLoader) throws Exception; + ClassLoader classLoader, + long recoveryTimestamp) throws Exception; /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java index f8b1cfdbed642..293de956a055c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateBackend.java @@ -18,12 +18,12 @@ package org.apache.flink.runtime.state; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.execution.Environment; import java.io.IOException; import java.io.OutputStream; @@ -31,32 +31,32 @@ /** * A state backend defines how state is stored and snapshotted during checkpoints. - * + * * @param <Backend> The type of backend itself. This generic parameter is used to refer to the * type of backend when creating state backed by this backend. */ public abstract class StateBackend<Backend extends StateBackend<Backend>> implements java.io.Serializable { - + private static final long serialVersionUID = 4620413814639220247L; - + // ------------------------------------------------------------------------ // initialization and cleanup // ------------------------------------------------------------------------ - + /** * This method is called by the task upon deployment to initialize the state backend for * data for a specific job. - * - * @param job The ID of the job for which the state backend instance checkpoints data. + * + * @param env The {@link Environment} of the task that instantiated the state backend + * @throws Exception Overwritten versions of this method may throw exceptions, in which * case the job that uses the state backend is considered failed during * deployment. */ - public abstract void initializeForJob(JobID job) throws Exception; + public abstract void initializeForJob(Environment env) throws Exception; /** * Disposes all state associated with the current job. - * + * * @throws Exception Exceptions may occur during disposal of the state and should be forwarded. */ public abstract void disposeAllStateForCurrentJob() throws Exception; @@ -64,33 +64,35 @@ public abstract class StateBackend> implem /** * Closes the state backend, releasing all internal resources, but does not delete any persistent * checkpoint data. - * + * * @throws Exception Exceptions can be forwarded and will be logged by the system */ public abstract void close() throws Exception; - + // ------------------------------------------------------------------------ // key/value state // ------------------------------------------------------------------------ /** * Creates a key/value state backed by this state backend. - * + * + * @param stateId Unique id that identifies the kv state in the streaming program. + * @param stateName Name of the created state * @param keySerializer The serializer for the key. * @param valueSerializer The serializer for the value. * @param defaultValue The value that is returned when no other value has been associated with a key, yet. * @param <K> The type of the key. * @param <V> The type of the value. - * + * * @return A new key/value state backed by this backend. - * + * * @throws Exception Exceptions may occur during initialization of the state and should be forwarded. */ - public abstract KvState createKvState( + public abstract KvState createKvState(String stateId, String stateName, TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws Exception; - - + + // ------------------------------------------------------------------------ // storing state for a checkpoint // ------------------------------------------------------------------------ @@ -98,16 +100,16 @@ public abstract KvState createKvState( /** * Creates an output stream that writes into the state of the given checkpoint. When the stream * is closed, it returns a state handle that can retrieve the state back. - * + * * @param checkpointID The ID of the checkpoint. * @param timestamp The timestamp of the checkpoint. * @return An output stream that writes state for the given checkpoint. - * + * * @throws Exception Exceptions may occur while creating the stream and should be forwarded. */ public abstract CheckpointStateOutputStream createCheckpointStateOutputStream( long checkpointID, long timestamp) throws Exception; - + /** * Creates a {@link DataOutputView} stream that writes into the state of the given checkpoint. * When the stream is closed, it returns a state handle that can retrieve the state back. @@ -125,20 +127,20 @@ public CheckpointStateOutputView createCheckpointStateOutputView( /** * Writes the given state into the checkpoint, and returns a handle that can retrieve the state back. - * + * * @param state The state to be checkpointed. * @param checkpointID The ID of the checkpoint. * @param timestamp The timestamp of the checkpoint. * @param <S> The type of the state. - * + * * @return A state handle that can retrieve the checkpointed state. - * + * * @throws Exception Exceptions may occur during serialization / storing the state and should be forwarded. */ public abstract StateHandle checkpointStateSerializable( S state, long checkpointID, long timestamp) throws Exception; - - + + // ------------------------------------------------------------------------ // Checkpoint state output stream // ------------------------------------------------------------------------ @@ -151,7 +153,7 @@ public static abstract class CheckpointStateOutputStream extends OutputStream { /** * Closes the stream and gets a state handle that can create an input stream * producing the data written to this stream. - * + * * @return A state handle that can create an input stream producing the data written to this stream. * @throws IOException Thrown, if the stream cannot be closed.
*/ @@ -162,9 +164,9 @@ public static abstract class CheckpointStateOutputStream extends OutputStream { * A dedicated DataOutputView stream that produces a {@code StateHandle} when closed. */ public static final class CheckpointStateOutputView extends DataOutputViewStreamWrapper { - + private final CheckpointStateOutputStream out; - + public CheckpointStateOutputView(CheckpointStateOutputStream out) { super(out); this.out = out; @@ -193,7 +195,7 @@ public void close() throws IOException { private static final class DataInputViewHandle implements StateHandle { private static final long serialVersionUID = 2891559813513532079L; - + private final StreamStateHandle stream; private DataInputViewHandle(StreamStateHandle stream) { @@ -202,7 +204,7 @@ private DataInputViewHandle(StreamStateHandle stream) { @Override public DataInputView getState(ClassLoader userCodeClassLoader) throws Exception { - return new DataInputViewStreamWrapper(stream.getState(userCodeClassLoader)); + return new DataInputViewStreamWrapper(stream.getState(userCodeClassLoader)); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java index 88b0d1899014b..96e0eb526aebb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java @@ -38,17 +38,19 @@ public class StateUtils { * The state carrier operator. * @param state * The state handle. + * @param recoveryTimestamp + * Global recovery timestamp * @param * Type bound for the */ public static > void setOperatorState(StatefulTask op, - StateHandle state) throws Exception { + StateHandle state, long recoveryTimestamp) throws Exception { @SuppressWarnings("unchecked") StatefulTask typedOp = (StatefulTask) op; @SuppressWarnings("unchecked") T typedHandle = (T) state; - typedOp.setInitialState(typedHandle); + typedOp.setInitialState(typedHandle, recoveryTimestamp); } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java index 781ee3dafd89d..c5c2fd7661987 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsHeapKvStateSnapshot.java @@ -62,7 +62,8 @@ public FsHeapKvState restoreState( final TypeSerializer keySerializer, final TypeSerializer valueSerializer, V defaultValue, - ClassLoader classLoader) throws Exception { + ClassLoader classLoader, + long recoveryTimestamp) throws Exception { // validity checks if (!keySerializer.getClass().getName().equals(keySerializerClassName) || diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java index d7b392cda463e..25c63e5d2252a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java @@ -18,14 +18,13 @@ package org.apache.flink.runtime.state.filesystem; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeutils.TypeSerializer; import 
org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.state.StateBackend; - +import org.apache.flink.runtime.state.StateHandle; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,27 +37,27 @@ /** * The file state backend is a state backend that stores the state of streaming jobs in a file system. - * + * *
The state backend has one core directory into which it puts all checkpoint data. Inside that * directory, it creates a directory per job, inside which each checkpoint gets a directory, with * files for each state, for example: - * + * * {@code hdfs://namenode:port/flink-checkpoints//chk-17/6ba7b810-9dad-11d1-80b4-00c04fd430c8 } */ public class FsStateBackend extends StateBackend { private static final long serialVersionUID = -8191916350224044011L; - + private static final Logger LOG = LoggerFactory.getLogger(FsStateBackend.class); - - + + /** The path to the directory for the checkpoint data, including the file system * description via scheme and optional authority */ private final Path basePath; - + /** The directory (job specific) into this initialized instance of the backend stores its data */ private transient Path checkpointDirectory; - + /** Cached handle to the file system for file operations */ private transient FileSystem filesystem; @@ -104,14 +103,14 @@ public FsStateBackend(Path checkpointDataUri) throws IOException { /** * Creates a new state backend that stores its checkpoint data in the file system and location * defined by the given URI. - * + * *
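The directory layout quoted in the class Javadoc above nests a per-job directory inside the base path, a per-checkpoint directory inside that, and one file per state handle. A small sketch of how those paths compose; the base URI and checkpoint id are made-up example values, and Path(parent, child) is used exactly as in the hunks further down:

    import java.util.UUID;
    import org.apache.flink.core.fs.Path;

    class CheckpointPathSketch {
        public static void main(String[] args) {
            Path basePath = new Path("hdfs://namenode:9000/flink-checkpoints");
            Path jobDir = new Path(basePath, "a-job-id");                     // created by initializeForJob(env)
            Path chkDir = new Path(jobDir, "chk-17");                          // created per checkpoint
            Path stateFile = new Path(chkDir, UUID.randomUUID().toString());   // one file per state handle
            System.out.println(stateFile);
        }
    }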
A file system for the file system scheme in the URI (e.g., 'file://', 'hdfs://', or 'S3://') * must be accessible via {@link FileSystem#get(URI)}. - * + * *
For a state backend targeting HDFS, this means that the URI must either specify the authority * (host and port), or that the Hadoop configuration that describes that information must be in the * classpath. - * + * * @param checkpointDataUri The URI describing the filesystem (scheme and optionally authority), * and the path to teh checkpoint data directory. * @throws IOException Thrown, if no file system can be found for the scheme in the URI. @@ -119,7 +118,7 @@ public FsStateBackend(Path checkpointDataUri) throws IOException { public FsStateBackend(URI checkpointDataUri) throws IOException { final String scheme = checkpointDataUri.getScheme(); final String path = checkpointDataUri.getPath(); - + // some validity checks if (scheme == null) { throw new IllegalArgumentException("The scheme (hdfs://, file://, etc) is null. " + @@ -132,12 +131,12 @@ public FsStateBackend(URI checkpointDataUri) throws IOException { if (path.length() == 0 || path.equals("/")) { throw new IllegalArgumentException("Cannot use the root directory for checkpoints."); } - + // we do a bit of work to make sure that the URI for the filesystem refers to exactly the same // (distributed) filesystem on all hosts and includes full host/port information, even if the // original URI did not include that. We count on the filesystem loading from the configuration // to fill in the missing data. - + // try to grab the file system for this path/URI this.filesystem = FileSystem.get(checkpointDataUri); if (this.filesystem == null) { @@ -151,7 +150,7 @@ public FsStateBackend(URI checkpointDataUri) throws IOException { } catch (URISyntaxException e) { throw new IOException( - String.format("Cannot create file system URI for checkpointDataUri %s and filesystem URI %s", + String.format("Cannot create file system URI for checkpointDataUri %s and filesystem URI %s", checkpointDataUri, fsURI), e); } } @@ -159,7 +158,7 @@ public FsStateBackend(URI checkpointDataUri) throws IOException { /** * Gets the base directory where all state-containing files are stored. * The job specific directory is created inside this directory. - * + * * @return The base directory. */ public Path getBasePath() { @@ -169,7 +168,7 @@ public Path getBasePath() { /** * Gets the directory where this state backend stores its checkpoint data. Will be null if * the state backend has not been initialized. - * + * * @return The directory where this state backend stores its checkpoint data. */ public Path getCheckpointDirectory() { @@ -179,16 +178,16 @@ public Path getCheckpointDirectory() { /** * Checks whether this state backend is initialized. Note that initialization does not carry * across serialization. After each serialization, the state backend needs to be initialized. - * + * * @return True, if the file state backend has been initialized, false otherwise. */ public boolean isInitialized() { - return filesystem != null && checkpointDirectory != null; + return filesystem != null && checkpointDirectory != null; } /** * Gets the file system handle for the file system that stores the state for this backend. - * + * * @return This backend's file system handle. 
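For reference, a usage sketch of the URI-based constructor described above; a local file URI is used so the example needs no running HDFS, and the directory is an arbitrary example. The scheme is mandatory and the root directory is rejected, as the validity checks above show:

    import java.net.URI;
    import org.apache.flink.runtime.state.filesystem.FsStateBackend;

    class FsBackendConstructionSketch {
        public static void main(String[] args) throws Exception {
            // accepted: a scheme plus a non-root path
            FsStateBackend backend = new FsStateBackend(new URI("file:///tmp/flink-checkpoints"));
            System.out.println(backend);

            // rejected with IllegalArgumentException by the checks above:
            //   new FsStateBackend(new URI("/tmp/flink-checkpoints"));  // no scheme
            //   new FsStateBackend(new URI("file:///"));                // root directory
        }
    }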
*/ public FileSystem getFileSystem() { @@ -203,13 +202,13 @@ public FileSystem getFileSystem() { // ------------------------------------------------------------------------ // initialization and cleanup // ------------------------------------------------------------------------ - + @Override - public void initializeForJob(JobID jobId) throws Exception { - Path dir = new Path(basePath, jobId.toString()); - + public void initializeForJob(Environment env) throws Exception { + Path dir = new Path(basePath, env.getJobID().toString()); + LOG.info("Initializing file state backend to URI " + dir); - + filesystem = basePath.getFileSystem(); filesystem.mkdirs(dir); @@ -220,7 +219,7 @@ public void initializeForJob(JobID jobId) throws Exception { public void disposeAllStateForCurrentJob() throws Exception { FileSystem fs = this.filesystem; Path dir = this.checkpointDirectory; - + if (fs != null && dir != null) { this.filesystem = null; this.checkpointDirectory = null; @@ -237,9 +236,9 @@ public void close() throws Exception {} // ------------------------------------------------------------------------ // state backend operations // ------------------------------------------------------------------------ - + @Override - public FsHeapKvState createKvState( + public FsHeapKvState createKvState(String stateId, String stateName, TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) throws Exception { return new FsHeapKvState(keySerializer, valueSerializer, defaultValue, this); } @@ -254,7 +253,7 @@ public StateHandle checkpointStateSerializable( final Path checkpointDir = createCheckpointDirPath(checkpointID); filesystem.mkdirs(checkpointDir); - + Exception latestException = null; for (int attempt = 0; attempt < 10; attempt++) { @@ -273,19 +272,19 @@ public StateHandle checkpointStateSerializable( } return new FileSerializableStateHandle(targetPath); } - + throw new Exception("Could not open output stream for state backend", latestException); } - + @Override public FsCheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception { checkFileSystemInitialized(); - + final Path checkpointDir = createCheckpointDirPath(checkpointID); filesystem.mkdirs(checkpointDir); - + Exception latestException = null; - + for (int attempt = 0; attempt < 10; attempt++) { Path targetPath = new Path(checkpointDir, UUID.randomUUID().toString()); try { @@ -298,7 +297,7 @@ public FsCheckpointStateOutputStream createCheckpointStateOutputStream(long chec } throw new Exception("Could not open output stream for state backend", latestException); } - + // ------------------------------------------------------------------------ // utilities // ------------------------------------------------------------------------ @@ -308,18 +307,18 @@ private void checkFileSystemInitialized() throws IllegalStateException { throw new IllegalStateException("filesystem has not been re-initialized after deserialization"); } } - + private Path createCheckpointDirPath(long checkpointID) { return new Path(checkpointDirectory, "chk-" + checkpointID); } - + @Override public String toString() { return checkpointDirectory == null ? 
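As context for the size limit mentioned above: the default cap is 5 MiB (DEFAULT_MAX_STATE_SIZE), and the int constructor shown just below overrides it. A small usage sketch with an arbitrary 1 MiB cap:

    import org.apache.flink.runtime.state.memory.MemoryStateBackend;

    class MemoryBackendSizeSketch {
        public static void main(String[] args) {
            MemoryStateBackend withDefaults = new MemoryStateBackend();        // up to 5 MiB of serialized state
            MemoryStateBackend smallCap = new MemoryStateBackend(1024 * 1024); // example: cap at 1 MiB instead
            System.out.println(withDefaults + " / " + smallCap);
        }
    }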
- * + * * @param maxStateSize The maximal size of the serialized state */ public MemoryStateBackend(int maxStateSize) { @@ -66,7 +66,7 @@ public MemoryStateBackend(int maxStateSize) { // ------------------------------------------------------------------------ @Override - public void initializeForJob(JobID job) { + public void initializeForJob(Environment env) { // nothing to do here } @@ -81,22 +81,22 @@ public void close() throws Exception {} // ------------------------------------------------------------------------ // State backend operations // ------------------------------------------------------------------------ - + @Override - public MemHeapKvState createKvState( + public MemHeapKvState createKvState(String stateId, String stateName, TypeSerializer keySerializer, TypeSerializer valueSerializer, V defaultValue) { return new MemHeapKvState(keySerializer, valueSerializer, defaultValue); } - + /** * Serialized the given state into bytes using Java serialization and creates a state handle that * can re-create that state. - * + * * @param state The state to checkpoint. * @param checkpointID The ID of the checkpoint. * @param timestamp The timestamp of the checkpoint. * @param The type of the state. - * + * * @return A state handle that contains the given state serialized as bytes. * @throws Exception Thrown, if the serialization fails. */ @@ -119,7 +119,7 @@ public CheckpointStateOutputStream createCheckpointStateOutputStream( // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ - + @Override public String toString() { return "MemoryStateBackend (data in heap memory / checkpoints to JobManager)"; @@ -133,18 +133,18 @@ static void checkSize(int size, int maxSize) throws IOException { + " . Consider using a different state backend, like the File System State backend."); } } - + // ------------------------------------------------------------------------ /** * A CheckpointStateOutputStream that writes into a byte array. */ public static final class MemoryCheckpointOutputStream extends CheckpointStateOutputStream { - + private final ByteArrayOutputStream os = new ByteArrayOutputStream(); - + private final int maxSize; - + private boolean closed; public MemoryCheckpointOutputStream(int maxSize) { @@ -177,7 +177,7 @@ public StreamStateHandle closeAndGetHandle() throws IOException { /** * Closes the stream and returns the byte array containing the stream's data. * @return The byte array containing the stream's data. 
- * @throws IOException Thrown if the size of the data exceeds the maximal + * @throws IOException Thrown if the size of the data exceeds the maximal */ public byte[] closeAndGetBytes() throws IOException { if (!closed) { @@ -191,11 +191,11 @@ public byte[] closeAndGetBytes() throws IOException { } } } - + // ------------------------------------------------------------------------ // Static default instance // ------------------------------------------------------------------------ - + /** The default instance of this state backend, using the default maximal state size */ private static final MemoryStateBackend DEFAULT_INSTANCE = new MemoryStateBackend(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index c8d50c793f744..ae1c0cda2db1a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -219,6 +219,8 @@ public class Task implements Runnable { * initialization, to be memory friendly */ private volatile SerializedValue> operatorState; + private volatile long recoveryTs; + /** *
IMPORTANT: This constructor may not start any work that would need to * be undone in the case of a failing task deployment.
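The recoveryTs field introduced above, together with the StateUtils change earlier in this diff, threads the recovery timestamp from TaskDeploymentDescriptor.getRecoveryTimestamp() down to StatefulTask.setInitialState; the hunks below show the plumbing. A rough sketch of the forwarding only, with placeholder types rather than the Flink classes:

    class RecoveryTimestampSketch {

        interface StatefulTaskSketch<T> {
            void setInitialState(T initialState, long recoveryTimestamp) throws Exception;
        }

        // mirrors StateUtils.setOperatorState after this patch: the timestamp is
        // forwarded unchanged alongside the deserialized state handle
        static <T> void setOperatorState(StatefulTaskSketch<T> op, T state, long recoveryTimestamp)
                throws Exception {
            op.setInitialState(state, recoveryTimestamp);
        }
    }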
@@ -252,6 +254,7 @@ public Task(TaskDeploymentDescriptor tdd, this.requiredClasspaths = checkNotNull(tdd.getRequiredClasspaths()); this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName()); this.operatorState = tdd.getOperatorState(); + this.recoveryTs = tdd.getRecoveryTimestamp(); this.memoryManager = checkNotNull(memManager); this.ioManager = checkNotNull(ioManager); @@ -535,13 +538,14 @@ else if (current == ExecutionState.CANCELING) { // get our private reference onto the stack (be safe against concurrent changes) SerializedValue> operatorState = this.operatorState; + long recoveryTs = this.recoveryTs; if (operatorState != null) { if (invokable instanceof StatefulTask) { try { StateHandle state = operatorState.deserializeValue(userCodeClassLoader); StatefulTask op = (StatefulTask) invokable; - StateUtils.setOperatorState(op, state); + StateUtils.setOperatorState(op, state, recoveryTs); } catch (Exception e) { throw new RuntimeException("Failed to deserialize state handle and setup initial operator state.", e); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 32c15bf613af1..7b2c2d4b3e8e6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -45,25 +45,25 @@ * Tests concerning the restoring of state from a checkpoint to the task executions. */ public class CheckpointStateRestoreTest { - + private static final ClassLoader cl = Thread.currentThread().getContextClassLoader(); - + @Test public void testSetState() { try { final SerializedValue> serializedState = new SerializedValue>( new LocalStateHandle(new SerializableObject())); - + final JobID jid = new JobID(); final JobVertexID statefulId = new JobVertexID(); final JobVertexID statelessId = new JobVertexID(); - + Execution statefulExec1 = mockExecution(); Execution statefulExec2 = mockExecution(); Execution statefulExec3 = mockExecution(); Execution statelessExec1 = mockExecution(); Execution statelessExec2 = mockExecution(); - + ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0); ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1); ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2); @@ -74,44 +74,44 @@ public void testSetState() { new ExecutionVertex[] { stateful1, stateful2, stateful3 }); ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId, new ExecutionVertex[] { stateless1, stateless2 }); - + Map map = new HashMap(); map.put(statefulId, stateful); map.put(statelessId, stateless); - - + + CheckpointCoordinator coord = new CheckpointCoordinator(jid, 200000L, new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 }, new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 }, new ExecutionVertex[0], cl, new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE); - + // create ourselves a checkpoint with state final long timestamp = 34623786L; coord.triggerCheckpoint(timestamp); - + PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next(); final long checkpointId = pending.getCheckpointId(); - + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), 
checkpointId, serializedState)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, serializedState)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId)); - + assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); - + // let the coordinator inject the state coord.restoreLatestCheckpointedState(map, true, false); - + // verify that each stateful vertex got the state - verify(statefulExec1, times(1)).setInitialState(serializedState); - verify(statefulExec2, times(1)).setInitialState(serializedState); - verify(statefulExec3, times(1)).setInitialState(serializedState); - verify(statelessExec1, times(0)).setInitialState(Mockito.>>any()); - verify(statelessExec2, times(0)).setInitialState(Mockito.>>any()); + verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong()); + verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong()); + verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong()); + verify(statelessExec1, times(0)).setInitialState(Mockito.>>any(), Mockito.anyLong()); + verify(statelessExec2, times(0)).setInitialState(Mockito.>>any(), Mockito.anyLong()); } catch (Exception e) { e.printStackTrace(); @@ -189,7 +189,7 @@ public void testStateOnlyPartiallyAvailable() { fail(e.getMessage()); } } - + @Test public void testNoCheckpointAvailable() { try { @@ -213,20 +213,20 @@ public void testNoCheckpointAvailable() { fail(e.getMessage()); } } - + // ------------------------------------------------------------------------ private Execution mockExecution() { return mockExecution(ExecutionState.RUNNING); } - + private Execution mockExecution(ExecutionState state) { Execution mock = mock(Execution.class); when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID()); when(mock.getState()).thenReturn(state); return mock; } - + private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask) { ExecutionVertex mock = mock(ExecutionVertex.class); when(mock.getJobvertexId()).thenReturn(vertexId); @@ -234,7 +234,7 @@ private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID ver when(mock.getCurrentExecutionAttempt()).thenReturn(execution); return mock; } - + private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) { ExecutionJobVertex vertex = mock(ExecutionJobVertex.class); when(vertex.getParallelism()).thenReturn(vertices.length); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java new file mode 100644 index 0000000000000..71bec4a41b8c5 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators.testutils; + +import java.util.Map; +import java.util.concurrent.Future; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.accumulators.AccumulatorRegistry; +import org.apache.flink.runtime.broadcast.BroadcastVariableManager; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; + +public class DummyEnvironment implements Environment { + + private final String taskName; + private final int numSubTasks; + private final int subTaskIndex; + private final JobID jobId = new JobID(); + private final JobVertexID jobVertexId = new JobVertexID(); + + public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) { + this.taskName = taskName; + this.numSubTasks = numSubTasks; + this.subTaskIndex = subTaskIndex; + } + + @Override + public JobID getJobID() { + return jobId; + } + + @Override + public JobVertexID getJobVertexId() { + return jobVertexId; + } + + @Override + public ExecutionAttemptID getExecutionId() { + return null; + } + + @Override + public Configuration getTaskConfiguration() { + return null; + } + + @Override + public TaskManagerRuntimeInfo getTaskManagerInfo() { + return null; + } + + @Override + public Configuration getJobConfiguration() { + return null; + } + + @Override + public int getNumberOfSubtasks() { + return numSubTasks; + } + + @Override + public int getIndexInSubtaskGroup() { + return subTaskIndex; + } + + @Override + public InputSplitProvider getInputSplitProvider() { + return null; + } + + @Override + public IOManager getIOManager() { + return null; + } + + @Override + public MemoryManager getMemoryManager() { + return null; + } + + @Override + public String getTaskName() { + return taskName; + } + + @Override + public String getTaskNameWithSubtasks() { + return taskName; + } + + @Override + public ClassLoader getUserClassLoader() { + return null; + } + + @Override + public Map> getDistributedCacheEntries() { + return null; + } + + @Override + public BroadcastVariableManager getBroadcastVariableManager() { + return null; + } + + @Override + public AccumulatorRegistry getAccumulatorRegistry() { + return null; + } + + @Override + public void acknowledgeCheckpoint(long checkpointId) { + } + + @Override + public void acknowledgeCheckpoint(long 
checkpointId, StateHandle state) { + } + + @Override + public ResultPartitionWriter getWriter(int index) { + return null; + } + + @Override + public ResultPartitionWriter[] getAllWriters() { + return null; + } + + @Override + public InputGate getInputGate(int index) { + return null; + } + + @Override + public InputGate[] getAllInputGates() { + return null; + } + +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java index a6cfae39b60e9..37ccde2bb5d19 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java @@ -18,8 +18,22 @@ package org.apache.flink.runtime.state; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.util.Random; +import java.util.UUID; + import org.apache.commons.io.FileUtils; -import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.FloatSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; @@ -29,41 +43,34 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.Path; import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; import org.apache.flink.types.IntValue; import org.apache.flink.types.StringValue; +import org.apache.flink.util.OperatingSystem; import org.junit.Test; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.util.Random; -import java.util.UUID; - -import static org.junit.Assert.*; - public class FileStateBackendTest { - + @Test public void testSetupAndSerialization() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { final String backendDir = localFileUri(tempDir); FsStateBackend originalBackend = new FsStateBackend(backendDir); - + assertFalse(originalBackend.isInitialized()); assertEquals(new URI(backendDir), originalBackend.getBasePath().toUri()); assertNull(originalBackend.getCheckpointDirectory()); - + // serialize / copy the backend FsStateBackend backend = CommonTestUtils.createCopySerializable(originalBackend); assertFalse(backend.isInitialized()); assertEquals(new URI(backendDir), backend.getBasePath().toUri()); assertNull(backend.getCheckpointDirectory()); - + // no file operations should be possible right now try { backend.checkpointStateSerializable("exception train rolling in", 2L, System.currentTimeMillis()); @@ -71,17 +78,17 @@ public void testSetupAndSerialization() { } catch (IllegalStateException e) { // supreme! 
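The DummyEnvironment added above exists so tests can satisfy the new initializeForJob(Environment) signature without mocking a full runtime environment: it fabricates a fresh JobID and returns null for everything the state backends do not touch. The test hunks below swap it in for the former new JobID() argument. A usage sketch, with an arbitrary local checkpoint directory:

    import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
    import org.apache.flink.runtime.state.filesystem.FsStateBackend;

    class DummyEnvironmentUsageSketch {
        public static void main(String[] args) throws Exception {
            FsStateBackend backend = new FsStateBackend("file:///tmp/flink-checkpoints-test");
            // task name "test", 0 subtasks, index 0 -- only the generated JobID matters here
            backend.initializeForJob(new DummyEnvironment("test", 0, 0));
            System.out.println(backend.getCheckpointDirectory());
        }
    }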
} - - backend.initializeForJob(new JobID()); + + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); assertNotNull(backend.getCheckpointDirectory()); - + File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); assertTrue(checkpointDir.exists()); assertTrue(isDirectoryEmpty(checkpointDir)); - + backend.disposeAllStateForCurrentJob(); assertNull(backend.getCheckpointDirectory()); - + assertTrue(isDirectoryEmpty(tempDir)); } catch (Exception e) { @@ -92,20 +99,20 @@ public void testSetupAndSerialization() { deleteDirectorySilently(tempDir); } } - + @Test public void testSerializableState() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); String state1 = "dummy state"; String state2 = "row row row your boat"; Integer state3 = 42; - + StateHandle handle1 = backend.checkpointStateSerializable(state1, 439568923746L, System.currentTimeMillis()); StateHandle handle2 = backend.checkpointStateSerializable(state2, 439568923746L, System.currentTimeMillis()); StateHandle handle3 = backend.checkpointStateSerializable(state3, 439568923746L, System.currentTimeMillis()); @@ -113,15 +120,15 @@ public void testSerializableState() { assertFalse(isDirectoryEmpty(checkpointDir)); assertEquals(state1, handle1.getState(getClass().getClassLoader())); handle1.discardState(); - + assertFalse(isDirectoryEmpty(checkpointDir)); assertEquals(state2, handle2.getState(getClass().getClassLoader())); handle2.discardState(); - + assertFalse(isDirectoryEmpty(checkpointDir)); assertEquals(state3, handle3.getState(getClass().getClassLoader())); handle3.discardState(); - + assertTrue(isDirectoryEmpty(checkpointDir)); } catch (Exception e) { @@ -138,7 +145,7 @@ public void testStateOutputStream() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); @@ -146,7 +153,7 @@ public void testStateOutputStream() { byte[] state2 = new byte[1]; byte[] state3 = new byte[0]; byte[] state4 = new byte[177]; - + Random rnd = new Random(); rnd.nextBytes(state1); rnd.nextBytes(state2); @@ -155,21 +162,21 @@ public void testStateOutputStream() { long checkpointId = 97231523452L; - FsStateBackend.FsCheckpointStateOutputStream stream1 = + FsStateBackend.FsCheckpointStateOutputStream stream1 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); FsStateBackend.FsCheckpointStateOutputStream stream2 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); FsStateBackend.FsCheckpointStateOutputStream stream3 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); - + stream1.write(state1); stream2.write(state2); stream3.write(state3); - + FileStreamStateHandle handle1 = stream1.closeAndGetHandle(); FileStreamStateHandle handle2 = stream2.closeAndGetHandle(); FileStreamStateHandle handle3 = stream3.closeAndGetHandle(); - + // use 
with try-with-resources StreamStateHandle handle4; try (StateBackend.CheckpointStateOutputStream stream4 = @@ -177,7 +184,7 @@ public void testStateOutputStream() { stream4.write(state4); handle4 = stream4.closeAndGetHandle(); } - + // close before accessing handle StateBackend.CheckpointStateOutputStream stream5 = backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); @@ -189,22 +196,22 @@ public void testStateOutputStream() { } catch (IOException e) { // uh-huh } - + validateBytesInStream(handle1.getState(getClass().getClassLoader()), state1); handle1.discardState(); assertFalse(isDirectoryEmpty(checkpointDir)); ensureLocalFileDeleted(handle1.getFilePath()); - + validateBytesInStream(handle2.getState(getClass().getClassLoader()), state2); handle2.discardState(); assertFalse(isDirectoryEmpty(checkpointDir)); ensureLocalFileDeleted(handle2.getFilePath()); - + validateBytesInStream(handle3.getState(getClass().getClassLoader()), state3); handle3.discardState(); assertFalse(isDirectoryEmpty(checkpointDir)); ensureLocalFileDeleted(handle3.getFilePath()); - + validateBytesInStream(handle4.getState(getClass().getClassLoader()), state4); handle4.discardState(); assertTrue(isDirectoryEmpty(checkpointDir)); @@ -223,12 +230,12 @@ public void testKeyValueState() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); KvState kv = - backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); + backend.createKvState("0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); assertEquals(0, kv.size()); @@ -272,7 +279,7 @@ public void testKeyValueState() { // restore the first snapshot and validate it KvState restored1 = snapshot1.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader()); + IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); assertEquals(2, restored1.size()); restored1.setCurrentKey(1); @@ -282,7 +289,7 @@ public void testKeyValueState() { // restore the first snapshot and validate it KvState restored2 = snapshot2.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader()); + IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); assertEquals(3, restored2.size()); restored2.setCurrentKey(1); @@ -312,12 +319,12 @@ public void testRestoreWithWrongSerializers() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath()); - + KvState kv = - backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); + backend.createKvState("a_0", "a", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); kv.setCurrentKey(1); kv.update("1"); @@ -338,7 +345,7 @@ public void testRestoreWithWrongSerializers() { try { snapshot.restoreState(backend, fakeIntSerializer, - 
StringSerializer.INSTANCE, null, getClass().getClassLoader()); + StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected @@ -348,7 +355,7 @@ public void testRestoreWithWrongSerializers() { try { snapshot.restoreState(backend, IntSerializer.INSTANCE, - fakeStringSerializer, null, getClass().getClassLoader()); + fakeStringSerializer, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected @@ -358,14 +365,14 @@ public void testRestoreWithWrongSerializers() { try { snapshot.restoreState(backend, fakeIntSerializer, - fakeStringSerializer, null, getClass().getClassLoader()); + fakeStringSerializer, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected } catch (Exception e) { fail("wrong exception"); } - + snapshot.discardState(); assertTrue(isDirectoryEmpty(checkpointDir)); @@ -384,10 +391,10 @@ public void testCopyDefaultValue() { File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(localFileUri(tempDir))); - backend.initializeForJob(new JobID()); - + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); + KvState kv = - backend.createKvState(IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); + backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); kv.setCurrentKey(1); IntValue default1 = kv.value(); @@ -408,11 +415,11 @@ public void testCopyDefaultValue() { deleteDirectorySilently(tempDir); } } - + // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ - + private static void ensureLocalFileDeleted(Path path) { URI uri = path.toUri(); if ("file".equals(uri.getScheme())) { @@ -423,23 +430,23 @@ private static void ensureLocalFileDeleted(Path path) { throw new IllegalArgumentException("not a local path"); } } - + private static void deleteDirectorySilently(File dir) { try { FileUtils.deleteDirectory(dir); } catch (IOException ignored) {} } - + private static boolean isDirectoryEmpty(File directory) { String[] nested = directory.list(); return nested == null || nested.length == 0; } - + private static String localFileUri(File path) { return path.toURI().toString(); } - + private static void validateBytesInStream(InputStream is, byte[] data) throws IOException { byte[] holder = new byte[data.length]; assertEquals("not enough data", holder.length, is.read(holder)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java index f6d1bb51881c8..4b5aebd0c74cf 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java @@ -40,7 +40,7 @@ * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}. 
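The restore calls patched throughout these tests follow one pattern: take a snapshot, then restore it through the backend with the user class loader and the new trailing recovery-timestamp argument (the tests simply pass 1). A condensed sketch of that round trip against the memory backend (raw types used for brevity; key/value content is made up):

    import org.apache.flink.api.common.typeutils.base.IntSerializer;
    import org.apache.flink.api.common.typeutils.base.StringSerializer;
    import org.apache.flink.runtime.state.KvState;
    import org.apache.flink.runtime.state.KvStateSnapshot;
    import org.apache.flink.runtime.state.memory.MemoryStateBackend;

    @SuppressWarnings({"rawtypes", "unchecked"})
    class SnapshotRestoreSketch {
        public static void main(String[] args) throws Exception {
            MemoryStateBackend backend = new MemoryStateBackend();
            KvState kv = backend.createKvState("s_0", "s",
                    IntSerializer.INSTANCE, StringSerializer.INSTANCE, null);

            kv.setCurrentKey(1);
            kv.update("one");

            // snapshot for an arbitrary checkpoint id and the current time
            KvStateSnapshot snapshot = kv.snapshot(1L, System.currentTimeMillis());

            // restore: the last two arguments (class loader, recovery timestamp) are
            // the ones this patch adds to the restoreState signature
            KvState restored = snapshot.restoreState(backend,
                    IntSerializer.INSTANCE, StringSerializer.INSTANCE, null,
                    SnapshotRestoreSketch.class.getClassLoader(), 1L);

            System.out.println(restored.size());
        }
    }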
*/ public class MemoryStateBackendTest { - + @Test public void testSerializableState() { try { @@ -49,10 +49,10 @@ public void testSerializableState() { HashMap state = new HashMap<>(); state.put("hey there", 2); state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77); - + StateHandle> handle = backend.checkpointStateSerializable(state, 12, 459); assertNotNull(handle); - + HashMap restored = handle.getState(getClass().getClassLoader()); assertEquals(state, restored); } @@ -99,7 +99,7 @@ public void testStateStream() { oos.writeObject(state); oos.flush(); StreamStateHandle handle = os.closeAndGetHandle(); - + assertNotNull(handle); ObjectInputStream ois = new ObjectInputStream(handle.getState(getClass().getClassLoader())); @@ -124,7 +124,7 @@ public void testOversizedStateStream() { StateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2); ObjectOutputStream oos = new ObjectOutputStream(os); - + try { oos.writeObject(state); oos.flush(); @@ -140,17 +140,17 @@ public void testOversizedStateStream() { fail(e.getMessage()); } } - + @Test public void testKeyValueState() { try { MemoryStateBackend backend = new MemoryStateBackend(); - - KvState kv = - backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - + + KvState kv = + backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); + assertEquals(0, kv.size()); - + // some modifications to the state kv.setCurrentKey(1); assertNull(kv.value()); @@ -163,7 +163,7 @@ public void testKeyValueState() { kv.setCurrentKey(1); assertEquals("1", kv.value()); assertEquals(2, kv.size()); - + // draw a snapshot KvStateSnapshot snapshot1 = kv.snapshot(682375462378L, System.currentTimeMillis()); @@ -179,7 +179,7 @@ public void testKeyValueState() { // draw another snapshot KvStateSnapshot snapshot2 = kv.snapshot(682375462379L, System.currentTimeMillis()); - + // validate the original state assertEquals(3, kv.size()); kv.setCurrentKey(1); @@ -188,10 +188,10 @@ public void testKeyValueState() { assertEquals("u2", kv.value()); kv.setCurrentKey(3); assertEquals("u3", kv.value()); - + // restore the first snapshot and validate it - KvState restored1 = snapshot1.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader()); + KvState restored1 = snapshot1.restoreState(backend, + IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); assertEquals(2, restored1.size()); restored1.setCurrentKey(1); @@ -201,7 +201,7 @@ public void testKeyValueState() { // restore the first snapshot and validate it KvState restored2 = snapshot2.restoreState(backend, - IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader()); + IntSerializer.INSTANCE, StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); assertEquals(3, restored2.size()); restored2.setCurrentKey(1); @@ -216,34 +216,34 @@ public void testKeyValueState() { fail(e.getMessage()); } } - + @Test public void testRestoreWithWrongSerializers() { try { MemoryStateBackend backend = new MemoryStateBackend(); KvState kv = - backend.createKvState(IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); - + backend.createKvState("s_0", "s", IntSerializer.INSTANCE, StringSerializer.INSTANCE, null); + kv.setCurrentKey(1); kv.update("1"); kv.setCurrentKey(2); kv.update("2"); - + KvStateSnapshot snapshot = kv.snapshot(682375462378L, System.currentTimeMillis()); 
@SuppressWarnings("unchecked") - TypeSerializer fakeIntSerializer = + TypeSerializer fakeIntSerializer = (TypeSerializer) (TypeSerializer) FloatSerializer.INSTANCE; @SuppressWarnings("unchecked") - TypeSerializer fakeStringSerializer = + TypeSerializer fakeStringSerializer = (TypeSerializer) (TypeSerializer) new ValueSerializer(StringValue.class); try { snapshot.restoreState(backend, fakeIntSerializer, - StringSerializer.INSTANCE, null, getClass().getClassLoader()); + StringSerializer.INSTANCE, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected @@ -253,7 +253,7 @@ public void testRestoreWithWrongSerializers() { try { snapshot.restoreState(backend, IntSerializer.INSTANCE, - fakeStringSerializer, null, getClass().getClassLoader()); + fakeStringSerializer, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected @@ -263,7 +263,7 @@ public void testRestoreWithWrongSerializers() { try { snapshot.restoreState(backend, fakeIntSerializer, - fakeStringSerializer, null, getClass().getClassLoader()); + fakeStringSerializer, null, getClass().getClassLoader(), 1); fail("should recognize wrong serializers"); } catch (IllegalArgumentException e) { // expected @@ -276,20 +276,20 @@ public void testRestoreWithWrongSerializers() { fail(e.getMessage()); } } - + @Test public void testCopyDefaultValue() { try { MemoryStateBackend backend = new MemoryStateBackend(); KvState kv = - backend.createKvState(IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); + backend.createKvState("a_0", "a", IntSerializer.INSTANCE, IntValueSerializer.INSTANCE, new IntValue(-1)); kv.setCurrentKey(1); IntValue default1 = kv.value(); kv.setCurrentKey(2); IntValue default2 = kv.value(); - + assertNotNull(default1); assertNotNull(default2); assertEquals(default1, default2); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index fff6e7019ffdf..85f8be58651b4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -198,7 +198,7 @@ public void invoke() throws Exception { } @Override - public void setInitialState(StateHandle stateHandle) throws Exception { + public void setInitialState(StateHandle stateHandle, long ts) throws Exception { } diff --git a/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java b/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java index 4e4acd2b6fbe8..4fb68203f5433 100644 --- a/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java +++ b/flink-staging/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java @@ -20,7 +20,6 @@ import org.apache.commons.io.FileUtils; -import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.FileStatus; import org.apache.flink.core.fs.FileSystem; @@ -29,7 +28,7 @@ import org.apache.flink.runtime.state.StateHandle; import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; - +import 
org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.hadoop.conf.Configuration; @@ -63,7 +62,7 @@ public class FileStateBackendTest { private static MiniDFSCluster HDFS_CLUSTER; private static FileSystem FS; - + // ------------------------------------------------------------------------ // startup / shutdown // ------------------------------------------------------------------------ @@ -127,7 +126,7 @@ public void testSetupAndSerialization() { // supreme! } - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); assertNotNull(backend.getCheckpointDirectory()); Path checkpointDir = backend.getCheckpointDirectory(); @@ -150,7 +149,7 @@ public void testSerializableState() { try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(randomHdfsFileUri())); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); Path checkpointDir = backend.getCheckpointDirectory(); @@ -186,7 +185,7 @@ public void testSerializableState() { public void testStateOutputStream() { try { FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(randomHdfsFileUri())); - backend.initializeForJob(new JobID()); + backend.initializeForJob(new DummyEnvironment("test", 0, 0)); Path checkpointDir = backend.getCheckpointDirectory(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index ce4d174763216..3f1cfae68485e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.runtime.state.KvState; import org.apache.flink.runtime.state.KvStateSnapshot; @@ -93,6 +94,8 @@ public abstract class AbstractStreamOperator private transient TypeSerializer keySerializer; private transient HashMap> keyValueStateSnapshots; + + private long recoveryTimestamp; // ------------------------------------------------------------------------ // Life Cycle @@ -172,15 +175,23 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) } @Override - public void restoreState(StreamTaskState state) throws Exception { + public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception { // restore the key/value state. 
the actual restore happens lazily, when the function requests // the state again, because the restore method needs information provided by the user function keyValueStateSnapshots = state.getKvStates(); + this.recoveryTimestamp = recoveryTimestamp; } @Override public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - // by default, nothing needs a notification of checkpoint completion + // We check whether the KvStates require notifications + if (keyValueStates != null) { + for (KvState kvstate : keyValueStates) { + if (kvstate instanceof CheckpointNotifier) { + ((CheckpointNotifier) kvstate).notifyCheckpointComplete(checkpointId); + } + } + } } // ------------------------------------------------------------------------ @@ -269,7 +280,7 @@ protected OperatorState createKeyValueState( * @throws IllegalStateException Thrown, if the key/value state was already initialized. * @throws Exception Thrown, if the state backend cannot create the key/value state. */ - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings("unchecked") protected > OperatorState createKeyValueState( String name, TypeSerializer valueSerializer, V defaultValue) throws Exception { @@ -304,25 +315,25 @@ else if (this.keySerializer != null) { throw new RuntimeException(); } - @SuppressWarnings("unchecked") Backend stateBackend = (Backend) container.getStateBackend(); KvState kvstate = null; // check whether we restore the key/value state from a snapshot, or create a new blank one if (keyValueStateSnapshots != null) { - @SuppressWarnings("unchecked") KvStateSnapshot snapshot = (KvStateSnapshot) keyValueStateSnapshots.remove(name); if (snapshot != null) { kvstate = snapshot.restoreState( - stateBackend, keySerializer, valueSerializer, defaultValue, getUserCodeClassloader()); + stateBackend, keySerializer, valueSerializer, defaultValue, getUserCodeClassloader(), recoveryTimestamp); } } if (kvstate == null) { + // create unique state id from operator id + state name + String stateId = name + "_" + getOperatorConfig().getVertexID(); // create a new blank key/value state - kvstate = stateBackend.createKvState(keySerializer, valueSerializer, defaultValue); + kvstate = stateBackend.createKvState(stateId ,name , keySerializer, valueSerializer, defaultValue); } if (keyValueStatesByName == null) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index 32be2ba49e9bb..c20544565ad13 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -147,8 +147,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) } @Override - public void restoreState(StreamTaskState state) throws Exception { - super.restoreState(state); + public void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception { + super.restoreState(state, recoveryTimestamp); StateHandle stateHandle = state.getFunctionState(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index fac26f153b9d9..1ef3298316f5f 100644 --- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -112,11 +112,13 @@ public interface StreamOperator extends Serializable { * * @param state The state of operator that was snapshotted as part of checkpoint * from which the execution is restored. + * + * @param recoveryTimestamp Global recovery timestamp * * @throws Exception Exceptions during state restore should be forwarded, so that the system can * properly react to failed state restore and fail the execution attempt. */ - void restoreState(StreamTaskState state) throws Exception; + void restoreState(StreamTaskState state, long recoveryTimestamp) throws Exception; /** * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager. diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java index 90d3d82bf6b32..677a7dd2db4d5 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java @@ -264,8 +264,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) } @Override - public void restoreState(StreamTaskState taskState) throws Exception { - super.restoreState(taskState); + public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception { + super.restoreState(taskState, recoveryTimestamp); @SuppressWarnings("unchecked") StateHandle inputState = (StateHandle) taskState.getOperatorState(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java index 5e4dea7e52a83..782363139e8b3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/NonKeyedWindowOperator.java @@ -536,8 +536,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) } @Override - public void restoreState(StreamTaskState taskState) throws Exception { - super.restoreState(taskState); + public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception { + super.restoreState(taskState, recoveryTimestamp); final ClassLoader userClassloader = getUserCodeClassloader(); @SuppressWarnings("unchecked") diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index f19e76057444a..68c3a5f26bf13 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -609,8 +609,8 @@ public StreamTaskState snapshotOperatorState(long 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
index f19e76057444a..68c3a5f26bf13 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
@@ -609,8 +609,8 @@ public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp)
 	}
 
 	@Override
-	public void restoreState(StreamTaskState taskState) throws Exception {
-		super.restoreState(taskState);
+	public void restoreState(StreamTaskState taskState, long recoveryTimestamp) throws Exception {
+		super.restoreState(taskState, recoveryTimestamp);
 
 		final ClassLoader userClassloader = getUserCodeClassloader();
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 80c63dab90407..c310439789284 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.runtime.util.event.EventListener;
+import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -138,6 +139,8 @@ public abstract class StreamTask>
 	/** Flag to mark the task "in operation", in which case check
 	 * needs to be initialized to true, so that early cancel() before invoke() behaves correctly */
 	private volatile boolean isRunning;
+
+	private long recoveryTimestamp;
 
 	// ------------------------------------------------------------------------
 
@@ -169,7 +172,7 @@ public final void registerInputOutput() throws Exception {
 		accumulatorMap = accumulatorRegistry.getUserMap();
 
 		stateBackend = createStateBackend();
-		stateBackend.initializeForJob(getEnvironment().getJobID());
+		stateBackend.initializeForJob(getEnvironment());
 
 		headOperator = configuration.getStreamOperator(userClassLoader);
 		operatorChain = new OperatorChain<>(this, headOperator, accumulatorRegistry.getReadWriteReporter());
@@ -382,8 +385,9 @@ public RecordWriterOutput[] getStreamOutputs() {
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(StreamTaskStateList initialState) {
+	public void setInitialState(StreamTaskStateList initialState, long recoveryTimestamp) {
 		lazyRestoreState = initialState;
+		this.recoveryTimestamp = recoveryTimestamp;
 	}
 
 	public void restoreStateLazy() throws Exception {
@@ -403,7 +407,7 @@ public void restoreStateLazy() throws Exception {
 				if (state != null && operator != null) {
 					LOG.debug("Task {} in chain ({}) has checkpointed state", i, getName());
-					operator.restoreState(state);
+					operator.restoreState(state, recoveryTimestamp);
 				}
 				else if (operator != null) {
 					LOG.debug("Task {} in chain ({}) does not have checkpointed state", i, getName());
@@ -464,6 +468,11 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception {
 			if (isRunning) {
 				LOG.debug("Notification of complete checkpoint for task {}", getName());
 
+				// We first notify the state backend if necessary
+				if (stateBackend instanceof CheckpointNotifier) {
+					((CheckpointNotifier) stateBackend).notifyCheckpointComplete(checkpointId);
+				}
+
 				for (StreamOperator operator : operatorChain.getAllOperators()) {
 					if (operator != null) {
 						operator.notifyOfCompletedCheckpoint(checkpointId);
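The StreamTask change above uses the same opt-in pattern for the state backend: the notification is forwarded only if the backend happens to implement CheckpointNotifier. The standalone helper below is a minimal sketch of that dispatch pattern for illustration; the class and method names are made up and it is not part of the patch.

    import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;

    /** Illustrative helper mirroring the instanceof-based forwarding used in StreamTask. */
    public final class NotificationForwarder {

        private NotificationForwarder() {}

        /**
         * Forwards the completed-checkpoint notification to the given object
         * if, and only if, it implements CheckpointNotifier.
         */
        public static void forwardIfNotifier(Object target, long checkpointId) throws Exception {
            if (target instanceof CheckpointNotifier) {
                ((CheckpointNotifier) target).notifyCheckpointComplete(checkpointId);
            }
        }
    }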
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 62eb268dac38b..63cbd6a281275 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -595,7 +595,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
 					windowSize, windowSize);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
-			op.restoreState(state);
+			op.restoreState(state, 1);
 			op.open();
 
 			// inject some more elements
@@ -694,7 +694,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
 					windowSize, windowSlide);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
-			op.restoreState(state);
+			op.restoreState(state, 1);
 			op.open();
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 4d507fb0919dc..55cd9fe1144ba 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -653,7 +653,7 @@ public void checkpointRestoreWithPendingWindowTumbling() {
 					windowSize, windowSize);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
-			op.restoreState(state);
+			op.restoreState(state, 1);
 			op.open();
 
 			// inject the remaining elements
@@ -759,7 +759,7 @@ public void checkpointRestoreWithPendingWindowSliding() {
 					windowSize, windowSlide);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
-			op.restoreState(state);
+			op.restoreState(state, 1);
 			op.open();
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
index 67c018912f5af..42b62303d8b42 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
@@ -50,10 +50,11 @@ public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTestBase {
 
 	final long NUM_STRINGS = 10_000_000L;
+	final static int NUM_KEYS = 40;
 
 	@Override
 	public void testProgram(StreamExecutionEnvironment env) {
-		assertTrue("Broken test setup", (NUM_STRINGS/2) % 40 == 0);
+		assertTrue("Broken test setup", (NUM_STRINGS/2) % NUM_KEYS == 0);
 
 		DataStream stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
 		DataStream stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
@@ -69,14 +70,14 @@ public void testProgram(StreamExecutionEnvironment env) {
 	public void postSubmit() {
 		// verify that we counted exactly right
 		for (Entry sum : OnceFailingPartitionedSum.allSums.entrySet()) {
-			assertEquals(new Long(sum.getKey() * NUM_STRINGS / 40), sum.getValue());
+			assertEquals(new Long(sum.getKey() * NUM_STRINGS / NUM_KEYS), sum.getValue());
 		}
		for (Long count : CounterSink.allCounts.values()) {
-			assertEquals(new Long(NUM_STRINGS / 40), count);
+			assertEquals(new Long(NUM_STRINGS / NUM_KEYS), count);
 		}
 
-		assertEquals(40, CounterSink.allCounts.size());
-		assertEquals(40, OnceFailingPartitionedSum.allSums.size());
+		assertEquals(NUM_KEYS, CounterSink.allCounts.size());
+		assertEquals(NUM_KEYS, OnceFailingPartitionedSum.allSums.size());
 	}
 
 	// --------------------------------------------------------------------------------------------
 
@@ -120,7 +121,7 @@ public void run(SourceContext ctx) throws Exception {
 				synchronized (lockingObject) {
 					index += step;
-					ctx.collect(index % 40);
+					ctx.collect(index % NUM_KEYS);
 				}
 			}
 		}
@@ -160,9 +161,9 @@ private static class OnceFailingPartitionedSum extends RichMapFunction value) throws Exception {
 		}
 	}
 
-	private static class NonSerializableLong {
+	public static class NonSerializableLong {
 
 		public Long value;
 
 		private NonSerializableLong(long value) {
@@ -225,7 +226,7 @@ public static NonSerializableLong of(long value) {
 		}
 	}
 
-	private static class IdentityKeySelector implements KeySelector {
+	public static class IdentityKeySelector implements KeySelector {
 
 		@Override
 		public T getKey(T value) throws Exception {
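For clarity on the expected values in the assertions above: the two sources together emit NUM_STRINGS records and the keys are assigned via index % NUM_KEYS, so each of the NUM_KEYS keys should be observed exactly NUM_STRINGS / NUM_KEYS times, and the sum recorded for key k should be k * NUM_STRINGS / NUM_KEYS. The snippet below is only an illustrative sanity check of that arithmetic, not part of the test.

    public class ExpectedCountsCheck {
        public static void main(String[] args) {
            final long NUM_STRINGS = 10_000_000L;
            final int NUM_KEYS = 40;

            // precondition asserted by the test: each of the two sources emits
            // NUM_STRINGS / 2 records, which must be evenly divisible by the key count
            assert (NUM_STRINGS / 2) % NUM_KEYS == 0;

            // each key is expected to be counted this many times by CounterSink
            long expectedPerKey = NUM_STRINGS / NUM_KEYS;   // 250_000
            System.out.println("expected count per key = " + expectedPerKey);
        }
    }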