From bc32991a358889464a5b596684aa40ff3e31acbe Mon Sep 17 00:00:00 2001 From: Piotr Nowojski Date: Wed, 18 Oct 2017 15:01:37 +0200 Subject: [PATCH] [FLINK-7623][tests] Add tests verifying isRestored flag --- .../checkpoint/StateAssignmentOperation.java | 2 +- .../runtime/tasks/RestoreStreamTaskTest.java | 359 ++++++++++++++++++ 2 files changed, 360 insertions(+), 1 deletion(-) create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 6a58acb9adef7..d80311ca74736 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -219,7 +219,7 @@ private void assignTaskStateToExecutionJobVertices( } } - private static OperatorSubtaskState operatorSubtaskStateFrom( + public static OperatorSubtaskState operatorSubtaskStateFrom( OperatorInstanceID instanceID, Map> subManagedOperatorState, Map> subRawOperatorState, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java new file mode 100644 index 0000000000000..48240976aa0f6 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.tasks; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.StateAssignmentOperation; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.jobgraph.OperatorInstanceID; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.util.TestLogger; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.junit.Assert.assertEquals; + +/** + * Tests ensuring correct behaviour of {@link org.apache.flink.runtime.state.ManagedInitializationContext#isRestored} + * method. + */ +public class RestoreStreamTaskTest extends TestLogger { + + private static final Set RESTORED_OPERATORS = ConcurrentHashMap.newKeySet(); + + @Before + public void setup() { + RESTORED_OPERATORS.clear(); + } + + @Test + public void testRestore() throws Exception { + OperatorID headOperatorID = new OperatorID(42L, 42L); + OperatorID tailOperatorID = new OperatorID(44L, 44L); + AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.empty()); + + assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + + TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + + AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.of(stateHandles)); + + assertEquals( + new HashSet() {{ + add(headOperatorID); + add(tailOperatorID); + }}, + RESTORED_OPERATORS); + } + + @Test + public void testRestoreHeadWithNewId() throws Exception { + OperatorID tailOperatorID = new OperatorID(44L, 44L); + AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + new OperatorID(42L, 42L), + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.empty()); + + assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + + TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + + AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + new OperatorID(4242L, 4242L), + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.of(stateHandles)); + + assertEquals( + new HashSet() {{ + add(tailOperatorID); + }}, + RESTORED_OPERATORS); + } + + @Test + public void testRestoreTailWithNewId() throws Exception { + OperatorID headOperatorID = new OperatorID(42L, 42L); + + AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + new OperatorID(44L, 44L), + new CounterOperator(), + Optional.empty()); + + assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + + TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + + AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + new OperatorID(4444L, 4444L), + new CounterOperator(), + Optional.of(stateHandles)); + + assertEquals( + new HashSet() {{ + add(headOperatorID); + }}, + RESTORED_OPERATORS); + } + + @Test + public void testRestoreAfterScaleUp() throws Exception { + OperatorID headOperatorID = new OperatorID(42L, 42L); + OperatorID tailOperatorID = new OperatorID(44L, 44L); + + AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.empty()); + + assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + + // test empty state in case of scale up + OperatorSubtaskState emptyHeadOperatorState = StateAssignmentOperation.operatorSubtaskStateFrom( + new OperatorInstanceID(0, headOperatorID), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap()); + + TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + stateHandles.putSubtaskStateByOperatorID(headOperatorID, emptyHeadOperatorState); + + AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator(), + tailOperatorID, + new CounterOperator(), + Optional.of(stateHandles)); + + assertEquals( + new HashSet() {{ + add(headOperatorID); + add(tailOperatorID); + }}, + RESTORED_OPERATORS); + } + + @Test + public void testRestoreWithoutState() throws Exception { + OperatorID headOperatorID = new OperatorID(42L, 42L); + OperatorID tailOperatorID = new OperatorID(44L, 44L); + + AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain( + headOperatorID, + new StatelessOperator(), + tailOperatorID, + new CounterOperator(), + Optional.empty()); + + assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size()); + + TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles(); + + AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain( + headOperatorID, + new StatelessOperator(), + tailOperatorID, + new CounterOperator(), + Optional.of(stateHandles)); + + assertEquals( + new HashSet() {{ + add(headOperatorID); + add(tailOperatorID); + }}, + RESTORED_OPERATORS); + } + + private AcknowledgeStreamMockEnvironment createRunAndCheckpointOperatorChain( + OperatorID headId, + OneInputStreamOperator headOperator, + OperatorID tailId, + OneInputStreamOperator tailOperator, + Optional stateHandles) throws Exception { + + final OneInputStreamTask streamTask = new OneInputStreamTask<>(); + final OneInputStreamTaskTestHarness testHarness = + new OneInputStreamTaskTestHarness( + streamTask, 1, 1, + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setupOperatorChain(headId, headOperator) + .chain(tailId, tailOperator, StringSerializer.INSTANCE) + .finish(); + + AcknowledgeStreamMockEnvironment environment = new AcknowledgeStreamMockEnvironment( + testHarness.jobConfig, + testHarness.taskConfig, + testHarness.executionConfig, + testHarness.memorySize, + new MockInputSplitProvider(), + testHarness.bufferSize); + + if (stateHandles.isPresent()) { + streamTask.setInitialState(stateHandles.get()); + } + testHarness.invoke(environment); + testHarness.waitForTaskRunning(); + + processRecords(testHarness); + triggerCheckpoint(testHarness, environment, streamTask); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + + return environment; + } + + private void triggerCheckpoint( + OneInputStreamTaskTestHarness testHarness, + AcknowledgeStreamMockEnvironment environment, + OneInputStreamTask streamTask) throws Exception { + long checkpointId = 1L; + CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L); + + while (!streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forFullCheckpoint())) {} + + environment.getCheckpointLatch().await(); + assertEquals(checkpointId, environment.getCheckpointId()); + } + + private void processRecords(OneInputStreamTaskTestHarness testHarness) throws Exception { + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + testHarness.processElement(new StreamRecord<>("10"), 0, 0); + testHarness.processElement(new StreamRecord<>("20"), 0, 0); + testHarness.processElement(new StreamRecord<>("30"), 0, 0); + + testHarness.waitForInputProcessing(); + + expectedOutput.add(new StreamRecord<>("10")); + expectedOutput.add(new StreamRecord<>("20")); + expectedOutput.add(new StreamRecord<>("30")); + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + } + + private abstract static class RestoreWatchOperator + extends AbstractStreamOperator + implements OneInputStreamOperator { + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + if (context.isRestored()) { + RESTORED_OPERATORS.add(getOperatorID()); + } + } + } + + /** + * Operator that counts processed messages and keeps result on state. + */ + private static class CounterOperator extends RestoreWatchOperator { + private static final long serialVersionUID = 2048954179291813243L; + + private ListState counterState; + private long counter = 0; + + @Override + public void processElement(StreamRecord element) throws Exception { + counter++; + output.collect(element); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + counterState = context + .getOperatorStateStore() + .getListState(new ListStateDescriptor<>("counter-state", LongSerializer.INSTANCE)); + + if (context.isRestored()) { + for (Long value : counterState.get()) { + counter += value; + } + counterState.clear(); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + counterState.add(counter); + } + } + + /** + * Operator that does nothing except counting state restorations. + */ + private static class StatelessOperator extends RestoreWatchOperator { + + private static final long serialVersionUID = 2048954179291813244L; + + @Override + public void processElement(StreamRecord element) throws Exception { + output.collect(element); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + } + } +}