diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index f4abf182d12fa..5fc5d17494e32 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -63,6 +63,7 @@ import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; +import org.apache.flink.util.FlinkException; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; import org.apache.flink.util.WrappingRuntimeException; @@ -79,8 +80,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeoutException; import static org.junit.Assert.assertEquals; @@ -89,12 +92,10 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; -import static org.mockito.Matchers.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -105,8 +106,6 @@ */ public class TaskTest extends TestLogger { - private static final long TIMEOUT = 1000L; - private static OneShotLatch awaitLatch; private static OneShotLatch triggerLatch; @@ -121,7 +120,7 @@ public void setup() { @Test public void testRegularExecution() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setTaskManagerActions(taskManagerActions) .build(); @@ -133,20 +132,16 @@ public void testRegularExecution() throws Exception { // go into the run method. we should switch to DEPLOYING, RUNNING, then // FINISHED, and all should be good - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - task.run(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // verify final state assertEquals(ExecutionState.FINISHED, task.getExecutionState()); assertFalse(task.isCanceledOrFailed()); assertNull(task.getFailureCause()); assertNull(task.getInvokable()); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FINISHED, task, null); } @Test @@ -166,7 +161,7 @@ public void testCancelRightAway() throws Exception { @Test public void testFailExternallyRightAway() throws Exception { - Task task = new TaskBuilder().build(); + final Task task = new TaskBuilder().build(); task.failExternally(new Exception("fail externally")); assertEquals(ExecutionState.FAILED, task.getExecutionState()); @@ -179,7 +174,9 @@ public void testFailExternallyRightAway() throws Exception { @Test public void testLibraryCacheRegistrationFailed() throws Exception { + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() + .setTaskManagerActions(taskManagerActions) .setLibraryCacheManager(mock(LibraryCacheManager.class)) // inactive manager .build(); @@ -199,6 +196,9 @@ public void testLibraryCacheRegistrationFailed() throws Exception { assertTrue(task.getFailureCause().getMessage().contains("classloader")); assertNull(task.getInvokable()); + + taskManagerActions.validateListenerMessage( + ExecutionState.FAILED, task, new Exception("No user code classloader available.")); } @Test @@ -258,7 +258,9 @@ public void testExecutionFailsInNetworkRegistration() throws Exception { when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() + .setTaskManagerActions(taskManagerActions) .setConsumableNotifier(consumableNotifier) .setPartitionProducerStateChecker(partitionProducerStateChecker) .setNetworkEnvironment(network) @@ -271,11 +273,16 @@ public void testExecutionFailsInNetworkRegistration() throws Exception { assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertTrue(task.getFailureCause().getMessage().contains("buffers")); + + taskManagerActions.validateListenerMessage( + ExecutionState.FAILED, task, new RuntimeException("buffers")); } @Test public void testInvokableInstantiationFailed() throws Exception { + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() + .setTaskManagerActions(taskManagerActions) .setInvokable(InvokableNonInstantiable.class) .build(); @@ -286,74 +293,62 @@ public void testInvokableInstantiationFailed() throws Exception { assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + + taskManagerActions.validateListenerMessage( + ExecutionState.FAILED, task, new FlinkException("Could not instantiate the task's invokable class.")); } @Test public void testExecutionFailsInInvoke() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableWithExceptionInInvoke.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - task.run(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertNotNull(task.getFailureCause()); assertNotNull(task.getFailureCause().getMessage()); assertTrue(task.getFailureCause().getMessage().contains("test")); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FAILED, task, new Exception("test")); } @Test public void testFailWithWrappedException() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(FailingInvokableWithChainedException.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - task.run(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); final Throwable cause = task.getFailureCause(); assertTrue(cause instanceof IOException); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FAILED, task, new IOException("test")); } @Test public void testCancelDuringInvoke() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableBlockingInInvoke.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - // run the task asynchronous task.startTaskThread(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // wait till the task is in invoke awaitLatch.await(); @@ -366,26 +361,22 @@ public void testCancelDuringInvoke() throws Exception { assertEquals(ExecutionState.CANCELED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertNull(task.getFailureCause()); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.CANCELED, task, null); } @Test public void testFailExternallyDuringInvoke() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableBlockingInInvoke.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - // run the task asynchronous task.startTaskThread(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // wait till the task is in invoke awaitLatch.await(); @@ -396,51 +387,43 @@ public void testFailExternallyDuringInvoke() throws Exception { assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertTrue(task.getFailureCause().getMessage().contains("test")); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FAILED, task, new Exception("test")); } @Test public void testCanceledAfterExecutionFailedInInvoke() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableWithExceptionInInvoke.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - task.run(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // this should not overwrite the failure state task.cancelExecution(); assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertTrue(task.getFailureCause().getMessage().contains("test")); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FAILED, task, new Exception("test")); } @Test public void testExecutionFailsAfterCanceling() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableWithExceptionOnTrigger.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - // run the task asynchronous task.startTaskThread(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // wait till the task is in invoke awaitLatch.await(); @@ -456,26 +439,22 @@ public void testExecutionFailsAfterCanceling() throws Exception { assertEquals(ExecutionState.CANCELED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertNull(task.getFailureCause()); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.CANCELED, task, null); } @Test public void testExecutionFailsAfterTaskMarkedFailed() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final QueuedNoOpTaskManagerActions taskManagerActions = new QueuedNoOpTaskManagerActions(); final Task task = new TaskBuilder() .setInvokable(InvokableWithExceptionOnTrigger.class) .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - // run the task asynchronous task.startTaskThread(); - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // wait till the task is in invoke awaitLatch.await(); @@ -490,6 +469,9 @@ public void testExecutionFailsAfterTaskMarkedFailed() throws Exception { assertEquals(ExecutionState.FAILED, task.getExecutionState()); assertTrue(task.isCanceledOrFailed()); assertTrue(task.getFailureCause().getMessage().contains("external")); + + taskManagerActions.validateListenerMessage(ExecutionState.RUNNING, task, null); + taskManagerActions.validateListenerMessage(ExecutionState.FAILED, task, new Exception("external")); } @@ -728,11 +710,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { */ @Test public void testWatchDogInterruptsTask() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); - - // guard no fatal error - doThrow(new RuntimeException("Unexpected FatalError message")). - when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + final TaskManagerActions taskManagerActions = new ProhibitFatalErrorTaskManagerActions(); final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL.key(), 5); @@ -759,11 +737,7 @@ public void testWatchDogInterruptsTask() throws Exception { */ @Test public void testInterruptibleSharedLockInInvokeAndCancel() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); - - // guard no fatal error - doThrow(new RuntimeException("Unexpected FatalError message")). - when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + final TaskManagerActions taskManagerActions = new ProhibitFatalErrorTaskManagerActions(); final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); @@ -789,7 +763,8 @@ public void testInterruptibleSharedLockInInvokeAndCancel() throws Exception { */ @Test public void testFatalErrorAfterUnInterruptibleInvoke() throws Exception { - final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final AwaitFatalErrorTaskManagerActions taskManagerActions = + new AwaitFatalErrorTaskManagerActions(); final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); @@ -801,26 +776,21 @@ public void testFatalErrorAfterUnInterruptibleInvoke() throws Exception { .setTaskManagerActions(taskManagerActions) .build(); - final TaskExecutionState state = new TaskExecutionState( - task.getJobID(), - task.getExecutionId(), - ExecutionState.RUNNING); - - task.startTaskThread(); - - verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); + try { + task.startTaskThread(); - awaitLatch.await(); + awaitLatch.await(); - task.cancelExecution(); + task.cancelExecution(); - verify(taskManagerActions, timeout(TIMEOUT)).notifyFatalError( - anyString(), any(Throwable.class)); - - // Interrupt again to clean up Thread - triggerLatch.trigger(); - task.getExecutingThread().interrupt(); - task.getExecutingThread().join(); + // wait for the notification of notifyFatalError + taskManagerActions.latch.await(); + } finally { + // Interrupt again to clean up Thread + triggerLatch.trigger(); + task.getExecutingThread().interrupt(); + task.getExecutingThread().join(); + } } /** @@ -859,10 +829,71 @@ public void testTaskConfig() throws Exception { task.getExecutingThread().join(); } + // ------------------------------------------------------------------------ + // customized TaskManagerActions + // ------------------------------------------------------------------------ + + /** + * Customized TaskManagerActions that queues all calls of updateTaskExecutionState + */ + private class QueuedNoOpTaskManagerActions extends NoOpTaskManagerActions { + private final BlockingQueue queue = new LinkedBlockingDeque<>(); + + @Override + public void updateTaskExecutionState(TaskExecutionState taskExecutionState) { + queue.offer(taskExecutionState); + } + + private void validateListenerMessage(ExecutionState state, Task task, Throwable error) { + try { + // we may have to wait for a bit to give the actors time to receive the message + // and put it into the queue + final TaskExecutionState taskState = queue.take(); + assertNotNull("There is no additional listener message", state); + + assertEquals(task.getJobID(), taskState.getJobID()); + assertEquals(task.getExecutionId(), taskState.getID()); + assertEquals(state, taskState.getExecutionState()); + + final Throwable t = taskState.getError(getClass().getClassLoader()); + if (error == null) { + assertNull(t); + } else { + assertEquals(error.toString(), t.toString()); + } + } catch (InterruptedException e) { + fail("interrupted"); + } + } + } + + /** + * Customized TaskManagerActions that ensures no call of notifyFatalError + */ + private class ProhibitFatalErrorTaskManagerActions extends NoOpTaskManagerActions { + @Override + public void notifyFatalError(String message, Throwable cause) { + throw new RuntimeException("Unexpected FatalError notification"); + } + } + + /** + * Customized TaskManagerActions that waits for a call of notifyFatalError + */ + private class AwaitFatalErrorTaskManagerActions extends NoOpTaskManagerActions { + private final OneShotLatch latch = new OneShotLatch(); + + @Override + public void notifyFatalError(String message, Throwable cause) { + latch.trigger(); + } + } + // ------------------------------------------------------------------------ // helper functions // ------------------------------------------------------------------------ + private void setInputGate(Task task, SingleInputGate inputGate) { try { Field f = Task.class.getDeclaredField("inputGates");