Skip to content

Commit

Permalink
[FLINK-20337] Let StatefulSinkWriterOperator load StreamingFileSink's…
Browse files Browse the repository at this point in the history
… state

To allow stateful migration from `StreamingFileSink` to `FileSink` we
let the `StatefulSinkWriterOperator` load the `StreamingFileSink`'s
state ("bucket-state") if it exists.
  • Loading branch information
guoweiM authored and aljoscha committed Nov 30, 2020
1 parent 15e1901 commit ff32471
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.util.CollectionUtil;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.List;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
* Runtime {@link org.apache.flink.streaming.api.operators.StreamOperator} for executing {@link
* SinkWriter Writers} that have state.
Expand All @@ -54,18 +59,28 @@ final class StatefulSinkWriterOperator<InputT, CommT, WriterStateT> extends Abst
/** The writer operator's state serializer. */
private final SimpleVersionedSerializer<WriterStateT> writerStateSimpleVersionedSerializer;

/** The previous sink operator's state name. */
@Nullable
private final String previousSinkStateName;

// ------------------------------- runtime fields ---------------------------------------

/** The previous sink operator's state. */
@Nullable
private ListState<WriterStateT> previousSinkState;

/** The operator's state. */
private ListState<WriterStateT> writerState;

StatefulSinkWriterOperator(
@Nullable final String previousSinkStateName,
final ProcessingTimeService processingTimeService,
final Sink<InputT, CommT, WriterStateT, ?> sink,
final SimpleVersionedSerializer<WriterStateT> writerStateSimpleVersionedSerializer) {
super(processingTimeService);
this.sink = sink;
this.writerStateSimpleVersionedSerializer = writerStateSimpleVersionedSerializer;
this.previousSinkStateName = previousSinkStateName;
}

@Override
Expand All @@ -74,17 +89,40 @@ public void initializeState(StateInitializationContext context) throws Exception

final ListState<byte[]> rawState = context.getOperatorStateStore().getListState(WRITER_RAW_STATES_DESC);
writerState = new SimpleVersionedListState<>(rawState, writerStateSimpleVersionedSerializer);

if (previousSinkStateName != null) {
final ListStateDescriptor<byte[]> preSinkStateDesc = new ListStateDescriptor<>(
previousSinkStateName,
BytePrimitiveArraySerializer.INSTANCE);

final ListState<byte[]> preRawState = context
.getOperatorStateStore()
.getListState(preSinkStateDesc);
this.previousSinkState = new SimpleVersionedListState<>(
preRawState,
writerStateSimpleVersionedSerializer);
}
}

@SuppressWarnings("unchecked")
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
writerState.update((List<WriterStateT>) sinkWriter.snapshotState());
if (previousSinkState != null) {
previousSinkState.clear();
}
}

@Override
SinkWriter<InputT, CommT, WriterStateT> createWriter() throws Exception {
final List<WriterStateT> committables = CollectionUtil.iterableToList(writerState.get());
return sink.createWriter(createInitContext(), committables);
final List<WriterStateT> writerStates = CollectionUtil.iterableToList(writerState.get());
final List<WriterStateT> states = new ArrayList<>(writerStates);
if (previousSinkStateName != null) {
checkNotNull(previousSinkState);
final List<WriterStateT> previousSinkStates = CollectionUtil.iterableToList(
previousSinkState.get());
states.addAll(previousSinkStates);
}
return sink.createWriter(createInitContext(), states);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;

import javax.annotation.Nullable;

/**
* A {@link org.apache.flink.streaming.api.operators.StreamOperatorFactory} for {@link
* StatefulSinkWriterOperator}.
Expand All @@ -35,13 +37,27 @@ public final class StatefulSinkWriterOperatorFactory<InputT, CommT, WriterStateT

private final Sink<InputT, CommT, WriterStateT, ?> sink;

@Nullable
private final String previousSinkStateName;

public StatefulSinkWriterOperatorFactory(Sink<InputT, CommT, WriterStateT, ?> sink) {
this(sink, null);
}

public StatefulSinkWriterOperatorFactory(
Sink<InputT, CommT, WriterStateT, ?> sink,
@Nullable String previousSinkStateName) {
this.sink = sink;
this.previousSinkStateName = previousSinkStateName;
}

@Override
AbstractSinkWriterOperator<InputT, CommT> createWriterOperator(ProcessingTimeService processingTimeService) {
return new StatefulSinkWriterOperator<>(processingTimeService, sink, sink.getWriterStateSerializer().get());
return new StatefulSinkWriterOperator<>(
previousSinkStateName,
processingTimeService,
sink,
sink.getWriterStateSerializer().get());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public class SinkTransformationTranslator<InputT, CommT, WriterStateT, GlobalCom

protected static final Logger LOG = LoggerFactory.getLogger(SinkTransformationTranslator.class);

// Currently we only support load the state from streaming file sink;
private static final String PREVIOUS_SINK_STATE_NAME = "bucket-states";

@Override
public Collection<Integer> translateForBatch(
SinkTransformation<InputT, CommT, WriterStateT, GlobalCommT> transformation,
Expand All @@ -75,6 +78,7 @@ public Collection<Integer> translateForBatch(
internalTranslate(
transformation,
parallelism,
PREVIOUS_SINK_STATE_NAME,
new BatchCommitterOperatorFactory<>(transformation.getSink()),
1,
1,
Expand All @@ -101,6 +105,7 @@ public Collection<Integer> translateForStreaming(
internalTranslate(
transformation,
parallelism,
PREVIOUS_SINK_STATE_NAME,
new StreamingCommitterOperatorFactory<>(transformation.getSink()),
parallelism,
transformation.getMaxParallelism(),
Expand All @@ -117,9 +122,9 @@ public Collection<Integer> translateForStreaming(

/**
* Add the sink operators to the stream graph.
*
* @param sinkTransformation The sink transformation that committer and global committer belongs to.
* @param writerParallelism The parallelism of the writer operator.
* @param previousSinkStateName The state name of previous sink's state.
* @param committerFactory The committer operator factory.
* @param committerParallelism The parallelism of the committer operator.
* @param committerMaxParallelism The max parallelism of the committer operator.
Expand All @@ -128,6 +133,7 @@ public Collection<Integer> translateForStreaming(
private void internalTranslate(
SinkTransformation<InputT, CommT, WriterStateT, GlobalCommT> sinkTransformation,
int writerParallelism,
@SuppressWarnings("SameParameterValue") @Nullable String previousSinkStateName,
OneInputStreamOperatorFactory<CommT, CommT> committerFactory,
int committerParallelism,
int committerMaxParallelism,
Expand All @@ -139,6 +145,7 @@ private void internalTranslate(
final int writerId = addWriter(
sinkTransformation,
writerParallelism,
previousSinkStateName,
context);

final int committerId = addCommitter(
Expand All @@ -161,12 +168,14 @@ private void internalTranslate(
*
* @param sinkTransformation The transformation that the writer belongs to
* @param parallelism The parallelism of the writer
* @param previousSinkStateName The state name of previous sink's state.
*
* @return The stream node id of the writer
*/
private int addWriter(
SinkTransformation<InputT, CommT, WriterStateT, GlobalCommT> sinkTransformation,
int parallelism,
@Nullable String previousSinkStateName,
Context context) {
final boolean hasState = sinkTransformation
.getSink()
Expand All @@ -180,7 +189,9 @@ private int addWriter(
final TypeInformation<InputT> inputTypeInfo = input.getOutputType();

final StreamOperatorFactory<CommT> writer =
hasState ? new StatefulSinkWriterOperatorFactory<>(sinkTransformation.getSink()) : new StatelessSinkWriterOperatorFactory<>(
hasState ? new StatefulSinkWriterOperatorFactory<>(
sinkTransformation.getSink(),
previousSinkStateName) : new StatelessSinkWriterOperatorFactory<>(
sinkTransformation.getSink());

final String prefix = "Sink Writer:";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,31 @@

package org.apache.flink.streaming.runtime.operators.sink;

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.connector.sink.SinkWriter;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.util.SimpleVersionedListState;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.TestHarnessUtil;

import org.junit.Test;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertThat;

/**
Expand Down Expand Up @@ -90,6 +103,70 @@ public void stateIsRestored() throws Exception {
new StreamRecord<>(Tuple3.of(2, initialTime + 2, initialTime).toString())));
}

@Test
public void loadPreviousSinkState() throws Exception {
//1. Build previous sink state
final List<String> previousSinkInputs = Arrays.asList("bit", "mention", "thick", "stick", "stir",
"easy", "sleep", "forth", "cost", "prompt");

final OneInputStreamOperatorTestHarness<String, String> previousSink =
new OneInputStreamOperatorTestHarness<>(
new DummySinkOperator(),
StringSerializer.INSTANCE);

OperatorSubtaskState previousSinkState = TestHarnessUtil.buildSubtaskState(
previousSink,
previousSinkInputs);

//2. Load previous sink state and verify the output
final OneInputStreamOperatorTestHarness<Integer, String> compatibleWriterOperator =
createCompatibleSinkOperator();

final List<StreamRecord<String>> expectedOutput1 =
previousSinkInputs.stream().map(StreamRecord::new).collect(Collectors.toList());
expectedOutput1.add(new StreamRecord<>(Tuple3.of(1, 1, Long.MIN_VALUE).toString()));

// load the state from previous sink
compatibleWriterOperator.initializeState(previousSinkState);

compatibleWriterOperator.open();

compatibleWriterOperator.processElement(1, 1);

// this will flush out the committables that were restored from previous sink
compatibleWriterOperator.endInput();

OperatorSubtaskState operatorStateWithoutPreviousState = compatibleWriterOperator.snapshot(
1L,
1L);

compatibleWriterOperator.close();

assertThat(
compatibleWriterOperator.getOutput(),
containsInAnyOrder(expectedOutput1.toArray()));

//3. Restore the sink without previous sink's state
final OneInputStreamOperatorTestHarness<Integer, String> restoredSinkOperator =
createCompatibleSinkOperator();
final List<StreamRecord<String>> expectedOutput2 =
Arrays.asList(
new StreamRecord<>(Tuple3.of(2, 2, Long.MIN_VALUE).toString()),
new StreamRecord<>(Tuple3.of(3, 3, Long.MIN_VALUE).toString()));

restoredSinkOperator.initializeState(operatorStateWithoutPreviousState);

restoredSinkOperator.open();

restoredSinkOperator.processElement(2, 2);
restoredSinkOperator.processElement(3, 3);

// this will flush out the committables that were restored
restoredSinkOperator.endInput();

assertThat(restoredSinkOperator.getOutput(), containsInAnyOrder(expectedOutput2.toArray()));
}

/**
* A {@link SinkWriter} buffers elements and snapshots them when asked.
*/
Expand All @@ -105,4 +182,36 @@ void restoredFrom(List<String> states) {
this.elements = states;
}
}

static class DummySinkOperator extends AbstractStreamOperator<String> implements OneInputStreamOperator<String, String> {

static final String DUMMY_SINK_STATE_NAME = "dummy_sink_state";

static final ListStateDescriptor<byte[]> SINK_STATE_DESC = new ListStateDescriptor<>(
DUMMY_SINK_STATE_NAME,
BytePrimitiveArraySerializer.INSTANCE);
ListState<String> sinkState;

public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
sinkState = new SimpleVersionedListState<>(context
.getOperatorStateStore()
.getListState(SINK_STATE_DESC), TestSink.StringCommittableSerializer.INSTANCE);
}

@Override
public void processElement(StreamRecord<String> element) throws Exception {
sinkState.add(element.getValue());
}
}

private OneInputStreamOperatorTestHarness<Integer, String> createCompatibleSinkOperator() throws Exception {
return new OneInputStreamOperatorTestHarness<>(
new StatefulSinkWriterOperatorFactory<>(TestSink
.newBuilder()
.setWriter(new SnapshottingBufferingSinkWriter())
.withWriterState()
.build(), DummySinkOperator.DUMMY_SINK_STATE_NAME),
IntSerializer.INSTANCE);
}
}

0 comments on commit ff32471

Please sign in to comment.