diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperator.java index 3751a3e8322fe..65b395c57a6a7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperator.java @@ -31,8 +31,13 @@ import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.util.CollectionUtil; +import javax.annotation.Nullable; + +import java.util.ArrayList; import java.util.List; +import static org.apache.flink.util.Preconditions.checkNotNull; + /** * Runtime {@link org.apache.flink.streaming.api.operators.StreamOperator} for executing {@link * SinkWriter Writers} that have state. @@ -54,18 +59,28 @@ final class StatefulSinkWriterOperator extends Abst /** The writer operator's state serializer. */ private final SimpleVersionedSerializer writerStateSimpleVersionedSerializer; + /** The previous sink operator's state name. */ + @Nullable + private final String previousSinkStateName; + // ------------------------------- runtime fields --------------------------------------- + /** The previous sink operator's state. */ + @Nullable + private ListState previousSinkState; + /** The operator's state. */ private ListState writerState; StatefulSinkWriterOperator( + @Nullable final String previousSinkStateName, final ProcessingTimeService processingTimeService, final Sink sink, final SimpleVersionedSerializer writerStateSimpleVersionedSerializer) { super(processingTimeService); this.sink = sink; this.writerStateSimpleVersionedSerializer = writerStateSimpleVersionedSerializer; + this.previousSinkStateName = previousSinkStateName; } @Override @@ -74,17 +89,40 @@ public void initializeState(StateInitializationContext context) throws Exception final ListState rawState = context.getOperatorStateStore().getListState(WRITER_RAW_STATES_DESC); writerState = new SimpleVersionedListState<>(rawState, writerStateSimpleVersionedSerializer); + + if (previousSinkStateName != null) { + final ListStateDescriptor preSinkStateDesc = new ListStateDescriptor<>( + previousSinkStateName, + BytePrimitiveArraySerializer.INSTANCE); + + final ListState preRawState = context + .getOperatorStateStore() + .getListState(preSinkStateDesc); + this.previousSinkState = new SimpleVersionedListState<>( + preRawState, + writerStateSimpleVersionedSerializer); + } } @SuppressWarnings("unchecked") @Override public void snapshotState(StateSnapshotContext context) throws Exception { writerState.update((List) sinkWriter.snapshotState()); + if (previousSinkState != null) { + previousSinkState.clear(); + } } @Override SinkWriter createWriter() throws Exception { - final List committables = CollectionUtil.iterableToList(writerState.get()); - return sink.createWriter(createInitContext(), committables); + final List writerStates = CollectionUtil.iterableToList(writerState.get()); + final List states = new ArrayList<>(writerStates); + if (previousSinkStateName != null) { + checkNotNull(previousSinkState); + final List previousSinkStates = CollectionUtil.iterableToList( + previousSinkState.get()); + states.addAll(previousSinkStates); + } + return sink.createWriter(createInitContext(), states); } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorFactory.java index 2741e373888d1..95da4878ac48a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorFactory.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorFactory.java @@ -23,6 +23,8 @@ import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import javax.annotation.Nullable; + /** * A {@link org.apache.flink.streaming.api.operators.StreamOperatorFactory} for {@link * StatefulSinkWriterOperator}. @@ -35,13 +37,27 @@ public final class StatefulSinkWriterOperatorFactory sink; + @Nullable + private final String previousSinkStateName; + public StatefulSinkWriterOperatorFactory(Sink sink) { + this(sink, null); + } + + public StatefulSinkWriterOperatorFactory( + Sink sink, + @Nullable String previousSinkStateName) { this.sink = sink; + this.previousSinkStateName = previousSinkStateName; } @Override AbstractSinkWriterOperator createWriterOperator(ProcessingTimeService processingTimeService) { - return new StatefulSinkWriterOperator<>(processingTimeService, sink, sink.getWriterStateSerializer().get()); + return new StatefulSinkWriterOperator<>( + previousSinkStateName, + processingTimeService, + sink, + sink.getWriterStateSerializer().get()); } @Override diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java index ddd5e062db2d8..9e643dea5e446 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/translators/SinkTransformationTranslator.java @@ -63,6 +63,9 @@ public class SinkTransformationTranslator translateForBatch( SinkTransformation transformation, @@ -75,6 +78,7 @@ public Collection translateForBatch( internalTranslate( transformation, parallelism, + PREVIOUS_SINK_STATE_NAME, new BatchCommitterOperatorFactory<>(transformation.getSink()), 1, 1, @@ -101,6 +105,7 @@ public Collection translateForStreaming( internalTranslate( transformation, parallelism, + PREVIOUS_SINK_STATE_NAME, new StreamingCommitterOperatorFactory<>(transformation.getSink()), parallelism, transformation.getMaxParallelism(), @@ -117,9 +122,9 @@ public Collection translateForStreaming( /** * Add the sink operators to the stream graph. - * * @param sinkTransformation The sink transformation that committer and global committer belongs to. * @param writerParallelism The parallelism of the writer operator. + * @param previousSinkStateName The state name of previous sink's state. * @param committerFactory The committer operator factory. * @param committerParallelism The parallelism of the committer operator. * @param committerMaxParallelism The max parallelism of the committer operator. @@ -128,6 +133,7 @@ public Collection translateForStreaming( private void internalTranslate( SinkTransformation sinkTransformation, int writerParallelism, + @SuppressWarnings("SameParameterValue") @Nullable String previousSinkStateName, OneInputStreamOperatorFactory committerFactory, int committerParallelism, int committerMaxParallelism, @@ -139,6 +145,7 @@ private void internalTranslate( final int writerId = addWriter( sinkTransformation, writerParallelism, + previousSinkStateName, context); final int committerId = addCommitter( @@ -161,12 +168,14 @@ private void internalTranslate( * * @param sinkTransformation The transformation that the writer belongs to * @param parallelism The parallelism of the writer + * @param previousSinkStateName The state name of previous sink's state. * * @return The stream node id of the writer */ private int addWriter( SinkTransformation sinkTransformation, int parallelism, + @Nullable String previousSinkStateName, Context context) { final boolean hasState = sinkTransformation .getSink() @@ -180,7 +189,9 @@ private int addWriter( final TypeInformation inputTypeInfo = input.getOutputType(); final StreamOperatorFactory writer = - hasState ? new StatefulSinkWriterOperatorFactory<>(sinkTransformation.getSink()) : new StatelessSinkWriterOperatorFactory<>( + hasState ? new StatefulSinkWriterOperatorFactory<>( + sinkTransformation.getSink(), + previousSinkStateName) : new StatelessSinkWriterOperatorFactory<>( sinkTransformation.getSink()); final String prefix = "Sink Writer:"; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorTest.java index bc7c79bac9733..35a93a482efec 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/StatefulSinkWriterOperatorTest.java @@ -18,18 +18,31 @@ package org.apache.flink.streaming.runtime.operators.sink; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; import org.apache.flink.api.connector.sink.SinkWriter; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.util.SimpleVersionedListState; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; import org.junit.Test; +import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertThat; /** @@ -90,6 +103,70 @@ public void stateIsRestored() throws Exception { new StreamRecord<>(Tuple3.of(2, initialTime + 2, initialTime).toString()))); } + @Test + public void loadPreviousSinkState() throws Exception { + //1. Build previous sink state + final List previousSinkInputs = Arrays.asList("bit", "mention", "thick", "stick", "stir", + "easy", "sleep", "forth", "cost", "prompt"); + + final OneInputStreamOperatorTestHarness previousSink = + new OneInputStreamOperatorTestHarness<>( + new DummySinkOperator(), + StringSerializer.INSTANCE); + + OperatorSubtaskState previousSinkState = TestHarnessUtil.buildSubtaskState( + previousSink, + previousSinkInputs); + + //2. Load previous sink state and verify the output + final OneInputStreamOperatorTestHarness compatibleWriterOperator = + createCompatibleSinkOperator(); + + final List> expectedOutput1 = + previousSinkInputs.stream().map(StreamRecord::new).collect(Collectors.toList()); + expectedOutput1.add(new StreamRecord<>(Tuple3.of(1, 1, Long.MIN_VALUE).toString())); + + // load the state from previous sink + compatibleWriterOperator.initializeState(previousSinkState); + + compatibleWriterOperator.open(); + + compatibleWriterOperator.processElement(1, 1); + + // this will flush out the committables that were restored from previous sink + compatibleWriterOperator.endInput(); + + OperatorSubtaskState operatorStateWithoutPreviousState = compatibleWriterOperator.snapshot( + 1L, + 1L); + + compatibleWriterOperator.close(); + + assertThat( + compatibleWriterOperator.getOutput(), + containsInAnyOrder(expectedOutput1.toArray())); + + //3. Restore the sink without previous sink's state + final OneInputStreamOperatorTestHarness restoredSinkOperator = + createCompatibleSinkOperator(); + final List> expectedOutput2 = + Arrays.asList( + new StreamRecord<>(Tuple3.of(2, 2, Long.MIN_VALUE).toString()), + new StreamRecord<>(Tuple3.of(3, 3, Long.MIN_VALUE).toString())); + + restoredSinkOperator.initializeState(operatorStateWithoutPreviousState); + + restoredSinkOperator.open(); + + restoredSinkOperator.processElement(2, 2); + restoredSinkOperator.processElement(3, 3); + + // this will flush out the committables that were restored + restoredSinkOperator.endInput(); + + assertThat(restoredSinkOperator.getOutput(), containsInAnyOrder(expectedOutput2.toArray())); + } + /** * A {@link SinkWriter} buffers elements and snapshots them when asked. */ @@ -105,4 +182,36 @@ void restoredFrom(List states) { this.elements = states; } } + + static class DummySinkOperator extends AbstractStreamOperator implements OneInputStreamOperator { + + static final String DUMMY_SINK_STATE_NAME = "dummy_sink_state"; + + static final ListStateDescriptor SINK_STATE_DESC = new ListStateDescriptor<>( + DUMMY_SINK_STATE_NAME, + BytePrimitiveArraySerializer.INSTANCE); + ListState sinkState; + + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + sinkState = new SimpleVersionedListState<>(context + .getOperatorStateStore() + .getListState(SINK_STATE_DESC), TestSink.StringCommittableSerializer.INSTANCE); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + sinkState.add(element.getValue()); + } + } + + private OneInputStreamOperatorTestHarness createCompatibleSinkOperator() throws Exception { + return new OneInputStreamOperatorTestHarness<>( + new StatefulSinkWriterOperatorFactory<>(TestSink + .newBuilder() + .setWriter(new SnapshottingBufferingSinkWriter()) + .withWriterState() + .build(), DummySinkOperator.DUMMY_SINK_STATE_NAME), + IntSerializer.INSTANCE); + } }