Skip to content

Commit

Permalink
[FLINK-2294] [streaming] Fix partitioned state next-input setting for…
Browse files Browse the repository at this point in the history
… copying chained collectors
  • Loading branch information
gyfora committed Jun 30, 2015
1 parent df42160 commit fef9f11
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ public CopyingOperatorCollector(OneInputStreamOperator operator, TypeSerializer<
@Override
public void collect(T record) {
try {
operator.getRuntimeContext().setNextInput(record);
operator.processElement(serializer.copy(record));
} catch (Exception e) {
if (LOG.isErrorEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ public void apiTest() throws Exception {
public void invoke(String value) throws Exception {}
});

keyedStream.map(new StatefulMapper2()).setParallelism(1).addSink(new SinkFunction<String>() {
private static final long serialVersionUID = 1L;
public void invoke(String value) throws Exception {}
});

try {
keyedStream.shuffle();
fail();
Expand Down Expand Up @@ -224,6 +229,36 @@ public void restoreState(Integer state) {
}
}

public static class StatefulMapper2 extends RichMapFunction<Integer, String> {
private static final long serialVersionUID = 1L;
OperatorState<Integer> groupCounter;

@Override
public String map(Integer value) throws Exception {
groupCounter.updateState(groupCounter.getState() + 1);

return value.toString();
}

@Override
public void open(Configuration conf) throws IOException {
groupCounter = getRuntimeContext().getOperatorState("groupCounter", 0, true);
}

@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
public void close() throws Exception {
Map<String, StreamOperatorState> states = ((StreamingRuntimeContext) getRuntimeContext()).getOperatorStates();
PartitionedStreamOperatorState<Integer, Integer, Integer> groupCounter = (PartitionedStreamOperatorState<Integer, Integer, Integer>) states.get("groupCounter");
for (Entry<Serializable, Integer> count : groupCounter.getPartitionedState().entrySet()) {
Integer key = (Integer) count.getKey();
Integer expected = key < 3 ? 2 : 1;
assertEquals(expected, count.getValue());
}
}

}

public static class ModKey implements KeySelector<Integer, Serializable> {

private static final long serialVersionUID = 4193026742083046736L;
Expand Down

0 comments on commit fef9f11

Please sign in to comment.