[streaming] KafkaSource checkpointing rework for new interfaces
gyfora committed Jun 25, 2015
1 parent 0ecab82 commit 5ddd232
Showing 3 changed files with 36 additions and 137 deletions.
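For context, here is a minimal sketch (not part of this commit) of the pattern the source is moved to: instead of implementing CheckpointedAsynchronously<long[]> and tracking pending checkpoints in its own LinkedMap, the source registers its offsets as OperatorState via getRuntimeContext().getOperatorState(...) and gets the snapshotted state handed back through the reworked CheckpointCommitter.commitCheckpoint(long, StateHandle<Serializable>). The class and field names below (OffsetTrackingSource, offsets) are made up for illustration; only the Flink interfaces that appear in this diff are assumed.

// Illustrative sketch only, hypothetical class, not part of this commit.
import java.io.Serializable;

import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;

public class OffsetTrackingSource extends RichParallelSourceFunction<Long> implements CheckpointCommitter {

    private static final long serialVersionUID = 1L;

    // Offsets registered as operator state; the runtime snapshots and restores them,
    // so the source no longer needs a pendingCheckpoints map of its own.
    private transient OperatorState<long[]> offsets;

    private volatile boolean running;

    @Override
    public void open(Configuration parameters) throws Exception {
        // The second argument is the default value, used when there is no state to restore.
        offsets = getRuntimeContext().getOperatorState("offset", new long[] { -1L });
        running = true;
    }

    @Override
    public void run(SourceContext<Long> ctx) throws Exception {
        long next = 0;
        while (running) {
            // Update the tracked offset and emit the element. The real PersistentKafkaSource
            // below guards this pair with its checkpointLock to keep them atomic.
            offsets.getState()[0] = next;
            ctx.collect(next++);
        }
    }

    @Override
    public void cancel() {
        running = false;
    }

    @Override
    public void commitCheckpoint(long checkpointId, StateHandle<Serializable> state) throws Exception {
        // The runtime hands back exactly the state that was snapshotted for this checkpoint id,
        // so the committer can push it to the external system (ZooKeeper for PersistentKafkaSource).
        long[] committed = (long[]) state.getState();
        // ... write `committed` to the external system here.
    }
}

This is also why the old testCheckpointing test is removed below: it exercised the source's own pendingCheckpoints bookkeeping, which disappears once the runtime owns the snapshotted state.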
@@ -17,7 +17,17 @@

package org.apache.flink.streaming.connectors.kafka.api.persistent;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import kafka.common.TopicAndPartition;
import kafka.consumer.Consumer;
import kafka.consumer.ConsumerConfig;
@@ -27,33 +37,27 @@
import kafka.message.MessageAndMetadata;
import kafka.utils.ZKGroupTopicDirs;
import kafka.utils.ZkUtils;

import org.I0Itec.zkclient.ZkClient;
import org.I0Itec.zkclient.exception.ZkMarshallingError;
import org.I0Itec.zkclient.serialize.ZkSerializer;
import org.apache.commons.collections.map.LinkedMap;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.util.serialization.DeserializationSchema;
import org.apache.zookeeper.data.Stat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Option;
import scala.collection.JavaConversions;
import scala.collection.Seq;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import com.google.common.base.Preconditions;

/**
* Source for reading from Kafka using Flink Streaming Fault Tolerance.
@@ -63,8 +67,7 @@
*/
public class PersistentKafkaSource<OUT> extends RichParallelSourceFunction<OUT> implements
ResultTypeQueryable<OUT>,
CheckpointCommitter,
CheckpointedAsynchronously<long[]> {
CheckpointCommitter {

private static final long serialVersionUID = 287845877188312621L;

@@ -73,17 +76,14 @@ public class PersistentKafkaSource<OUT> extends RichParallelSourceFunction<OUT>

private final String topicName;
private final DeserializationSchema<OUT> deserializationSchema;

private final LinkedMap pendingCheckpoints = new LinkedMap();

private transient ConsumerConfig consumerConfig;
private transient ConsumerIterator<byte[], byte[]> iteratorToRead;
private transient ConsumerConnector consumer;

private transient ZkClient zkClient;
private transient long[] lastOffsets;
private transient OperatorState<long[]> lastOffsets;
private transient long[] commitedOffsets; // maintain committed offsets, to avoid committing the same over and over again.
private transient long[] restoreState;

private volatile boolean running;

@@ -145,25 +145,23 @@ public void open(Configuration parameters) throws Exception {
// most likely the number of offsets we're going to store here will be lower than the number of partitions.
int numPartitions = getNumberOfPartitions();
LOG.debug("The topic {} has {} partitions", topicName, numPartitions);
this.lastOffsets = new long[numPartitions];
this.lastOffsets = getRuntimeContext().getOperatorState("offset", new long[numPartitions]);
this.commitedOffsets = new long[numPartitions];
// check if there are offsets to restore
if (restoreState != null) {
if (restoreState.length != numPartitions) {
throw new IllegalStateException("There are "+restoreState.length+" offsets to restore for topic "+topicName+" but " +
if (Arrays.equals(lastOffsets.getState(), new long[numPartitions])) {
if (lastOffsets.getState().length != numPartitions) {
throw new IllegalStateException("There are "+lastOffsets.getState().length+" offsets to restore for topic "+topicName+" but " +
"there are only "+numPartitions+" in the topic");
}

LOG.info("Setting restored offsets {} in ZooKeeper", Arrays.toString(restoreState));
setOffsetsInZooKeeper(restoreState);
this.lastOffsets = restoreState;
LOG.info("Setting restored offsets {} in ZooKeeper", Arrays.toString(lastOffsets.getState()));
setOffsetsInZooKeeper(lastOffsets.getState());
} else {
// initialize empty offsets
Arrays.fill(this.lastOffsets, -1);
Arrays.fill(this.lastOffsets.getState(), -1);
}
Arrays.fill(this.commitedOffsets, 0); // just to make it clear

pendingCheckpoints.clear();
running = true;
}

@@ -177,7 +175,7 @@ public void run(SourceContext<OUT> ctx) throws Exception {

while (running && iteratorToRead.hasNext()) {
MessageAndMetadata<byte[], byte[]> message = iteratorToRead.next();
if(lastOffsets[message.partition()] >= message.offset()) {
if(lastOffsets.getState()[message.partition()] >= message.offset()) {
LOG.info("Skipping message with offset {} from partition {}", message.offset(), message.partition());
continue;
}
@@ -190,7 +188,7 @@ public void run(SourceContext<OUT> ctx) throws Exception {

// make the state update and the element emission atomic
synchronized (checkpointLock) {
lastOffsets[message.partition()] = message.offset();
lastOffsets.getState()[message.partition()] = message.offset();
ctx.collect(next);
}

@@ -212,64 +210,19 @@ public void close() {
zkClient.close();
}


// ---------------------- State / Checkpoint handling -----------------
// this source is keeping the partition offsets in Zookeeper

@Override
public long[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
if (lastOffsets == null) {
LOG.warn("State snapshot requested on not yet opened source. Returning null");
return null;
}

if (LOG.isInfoEnabled()) {
LOG.info("Snapshotting state. Offsets: {}, checkpoint id {}, timestamp {}",
Arrays.toString(lastOffsets), checkpointId, checkpointTimestamp);
}

long[] currentOffsets = Arrays.copyOf(lastOffsets, lastOffsets.length);

// the map may be asynchronously updates when committing to Kafka, so we synchronize
synchronized (pendingCheckpoints) {
pendingCheckpoints.put(checkpointId, currentOffsets);
}

return currentOffsets;
}

@Override
public void restoreState(long[] state) {
LOG.info("The state will be restored to {} in the open() method", Arrays.toString(state));
this.restoreState = Arrays.copyOf(state, state.length);
}

/**
* Notification on completed checkpoints
* @param checkpointId The ID of the checkpoint that has been completed.
* @throws Exception
*/
@Override
public void commitCheckpoint(long checkpointId) {
public void commitCheckpoint(long checkpointId, StateHandle<Serializable> state) throws Exception {
LOG.info("Commit checkpoint {}", checkpointId);

long[] checkpointOffsets;

// the map may be asynchronously updates when snapshotting state, so we synchronize
synchronized (pendingCheckpoints) {
final int posInMap = pendingCheckpoints.indexOf(checkpointId);
if (posInMap == -1) {
LOG.warn("Unable to find pending checkpoint for id {}", checkpointId);
return;
}

checkpointOffsets = (long[]) pendingCheckpoints.remove(posInMap);
// remove older checkpoints in map:
if (!pendingCheckpoints.isEmpty()) {
for(int i = 0; i < posInMap; i++) {
pendingCheckpoints.remove(0);
}
}
}

checkpointOffsets = (long[]) state.getState();

if (LOG.isInfoEnabled()) {
LOG.info("Committing offsets {} to ZooKeeper", Arrays.toString(checkpointOffsets));
@@ -43,6 +43,9 @@
import kafka.javaapi.consumer.ConsumerConnector;
import kafka.message.MessageAndMetadata;
import kafka.network.SocketServer;
import kafka.server.KafkaConfig;
import kafka.server.KafkaServer;

import org.I0Itec.zkclient.ZkClient;
import org.apache.commons.collections.map.LinkedMap;
import org.apache.curator.test.TestingServer;
@@ -80,8 +83,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import kafka.server.KafkaConfig;
import kafka.server.KafkaServer;
import scala.collection.Seq;

/**
@@ -183,62 +184,6 @@ public static void shutDownServices() {
zkClient.close();
}

// -------------------------- test checkpointing ------------------------
@Test
public void testCheckpointing() throws Exception {
createTestTopic("testCheckpointing", 1, 1);

Properties props = new Properties();
props.setProperty("zookeeper.connect", zookeeperConnectionString);
props.setProperty("group.id", "testCheckpointing");
props.setProperty("auto.commit.enable", "false");
ConsumerConfig cc = new ConsumerConfig(props);
PersistentKafkaSource<String> source = new PersistentKafkaSource<String>("testCheckpointing", new FakeDeserializationSchema(), cc);


Field pendingCheckpointsField = PersistentKafkaSource.class.getDeclaredField("pendingCheckpoints");
pendingCheckpointsField.setAccessible(true);
LinkedMap pendingCheckpoints = (LinkedMap) pendingCheckpointsField.get(source);


Assert.assertEquals(0, pendingCheckpoints.size());
// first restore
source.restoreState(new long[]{1337});
// then open
source.open(new Configuration());
long[] state1 = source.snapshotState(1, 15);
Assert.assertArrayEquals(new long[]{1337}, state1);
long[] state2 = source.snapshotState(2, 30);
Assert.assertArrayEquals(new long[]{1337}, state2);
Assert.assertEquals(2, pendingCheckpoints.size());

source.commitCheckpoint(1);
Assert.assertEquals(1, pendingCheckpoints.size());

source.commitCheckpoint(2);
Assert.assertEquals(0, pendingCheckpoints.size());

source.commitCheckpoint(666); // invalid checkpoint
Assert.assertEquals(0, pendingCheckpoints.size());

// create 500 snapshots
for(int i = 0; i < 500; i++) {
source.snapshotState(i, 15 * i);
}
Assert.assertEquals(500, pendingCheckpoints.size());

// commit only the second last
source.commitCheckpoint(498);
Assert.assertEquals(1, pendingCheckpoints.size());

// access invalid checkpoint
source.commitCheckpoint(490);

// and the last
source.commitCheckpoint(499);
Assert.assertEquals(0, pendingCheckpoints.size());
}

private static class FakeDeserializationSchema implements DeserializationSchema<String> {

@Override
@@ -37,6 +37,7 @@ public interface CheckpointCommitter {
*
* @param checkpointId The ID of the checkpoint that has been completed.
* @param checkPointedState Handle to the state that was checkpointed with this checkpoint id.
* @throws Exception
*/
void commitCheckpoint(long checkpointId, StateHandle<Serializable> checkPointedState);
void commitCheckpoint(long checkpointId, StateHandle<Serializable> checkPointedState) throws Exception;
}
