Skip to content

Commit

Permalink
[hotfix][tests] Refactor unit tests in RemoteInputChannelTest to avoi…
Browse files Browse the repository at this point in the history
…d mock way
  • Loading branch information
zhijiangW committed May 12, 2020
1 parent bcf7e28 commit a267019
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.LocalConnectionManager;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.TaskEventPublisher;
import org.apache.flink.runtime.io.network.TestingConnectionManager;
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
Expand All @@ -43,7 +43,7 @@ public class InputChannelBuilder {
private ConnectionID connectionID = STUB_CONNECTION_ID;
private ResultPartitionManager partitionManager = new ResultPartitionManager();
private TaskEventPublisher taskEventPublisher = new TaskEventDispatcher();
private ConnectionManager connectionManager = new LocalConnectionManager();
private ConnectionManager connectionManager = new TestingConnectionManager();
private int initialBackoff = 0;
private int maxBackoff = 0;
private InputChannelMetrics metrics = InputChannelTestUtils.newUnregisteredInputChannelMetrics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -204,10 +203,9 @@ private void testConcurrentReleaseAndSomething(

@Test(expected = IllegalStateException.class)
public void testRetriggerWithoutPartitionRequest() throws Exception {
PartitionRequestClient connClient = mock(PartitionRequestClient.class);
SingleInputGate inputGate = createSingleInputGate(1);

RemoteInputChannel ch = createRemoteInputChannel(inputGate, connClient, 500, 3000);
RemoteInputChannel ch = createRemoteInputChannel(inputGate, 500, 3000);

ch.retriggerSubpartitionRequest(0);
}
Expand All @@ -218,20 +216,21 @@ public void testPartitionRequestExponentialBackoff() throws Exception {
int[] expectedDelays = {500, 1000, 2000, 3000};

// Setup
PartitionRequestClient connClient = mock(PartitionRequestClient.class);
SingleInputGate inputGate = createSingleInputGate(1);

RemoteInputChannel ch = createRemoteInputChannel(inputGate, connClient, 500, 3000);
ResultPartitionID partitionId = new ResultPartitionID();
TestVerifyPartitionRequestClient client = new TestVerifyPartitionRequestClient();
ConnectionManager connectionManager = new TestVerifyConnectionManager(client);
RemoteInputChannel ch = createRemoteInputChannel(inputGate, connectionManager, partitionId, 500, 3000);

// Initial request
ch.requestSubpartition(0);
verify(connClient).requestSubpartition(eq(ch.partitionId), eq(0), eq(ch), eq(0));
client.verifyResult(partitionId, 0, 0);

// Request subpartition and verify that the actual requests are delayed.
for (int expected : expectedDelays) {
ch.retriggerSubpartitionRequest(0);

verify(connClient).requestSubpartition(eq(ch.partitionId), eq(0), eq(ch), eq(expected));
client.verifyResult(partitionId, 0, expected);
}

// Exception after backoff is greater than the maximum backoff.
Expand All @@ -247,18 +246,19 @@ public void testPartitionRequestExponentialBackoff() throws Exception {
@Test
public void testPartitionRequestSingleBackoff() throws Exception {
// Setup
PartitionRequestClient connClient = mock(PartitionRequestClient.class);
SingleInputGate inputGate = createSingleInputGate(1);

RemoteInputChannel ch = createRemoteInputChannel(inputGate, connClient, 500, 500);
ResultPartitionID partitionId = new ResultPartitionID();
TestVerifyPartitionRequestClient client = new TestVerifyPartitionRequestClient();
ConnectionManager connectionManager = new TestVerifyConnectionManager(client);
RemoteInputChannel ch = createRemoteInputChannel(inputGate, connectionManager, partitionId, 500, 500);

// No delay for first request
ch.requestSubpartition(0);
verify(connClient).requestSubpartition(eq(ch.partitionId), eq(0), eq(ch), eq(0));
client.verifyResult(partitionId, 0, 0);

// Initial delay for second request
ch.retriggerSubpartitionRequest(0);
verify(connClient).requestSubpartition(eq(ch.partitionId), eq(0), eq(ch), eq(500));
client.verifyResult(partitionId, 0, 500);

// Exception after backoff is greater than the maximum backoff.
try {
Expand All @@ -273,14 +273,15 @@ public void testPartitionRequestSingleBackoff() throws Exception {
@Test
public void testPartitionRequestNoBackoff() throws Exception {
// Setup
PartitionRequestClient connClient = mock(PartitionRequestClient.class);
SingleInputGate inputGate = createSingleInputGate(1);

RemoteInputChannel ch = createRemoteInputChannel(inputGate, connClient, 0, 0);
ResultPartitionID partitionId = new ResultPartitionID();
TestVerifyPartitionRequestClient client = new TestVerifyPartitionRequestClient();
ConnectionManager connectionManager = new TestVerifyConnectionManager(client);
RemoteInputChannel ch = createRemoteInputChannel(inputGate, connectionManager, partitionId, 0, 0);

// No delay for first request
ch.requestSubpartition(0);
verify(connClient).requestSubpartition(eq(ch.partitionId), eq(0), eq(ch), eq(0));
client.verifyResult(partitionId, 0, 0);

// Exception, because backoff is disabled.
try {
Expand Down Expand Up @@ -1053,24 +1054,28 @@ public void testUnblockReleasedChannel() throws Exception {

// ---------------------------------------------------------------------------------------------

private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws IOException, InterruptedException {
return createRemoteInputChannel(inputGate, mock(PartitionRequestClient.class), 0, 0);
private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) {
return createRemoteInputChannel(inputGate, 0, 0);
}

private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate, int initialBackoff, int maxBackoff) {
return InputChannelBuilder.newBuilder()
.setInitialBackoff(initialBackoff)
.setMaxBackoff(maxBackoff)
.buildRemoteChannel(inputGate);
}

private RemoteInputChannel createRemoteInputChannel(
SingleInputGate inputGate,
PartitionRequestClient partitionRequestClient,
ConnectionManager connectionManager,
ResultPartitionID partitionId,
int initialBackoff,
int maxBackoff) throws IOException, InterruptedException {

final ConnectionManager connectionManager = mock(ConnectionManager.class);
when(connectionManager.createPartitionRequestClient(any(ConnectionID.class)))
.thenReturn(partitionRequestClient);

int maxBackoff) {
return InputChannelBuilder.newBuilder()
.setConnectionManager(connectionManager)
.setInitialBackoff(initialBackoff)
.setMaxBackoff(maxBackoff)
.setPartitionId(partitionId)
.setConnectionManager(connectionManager)
.buildRemoteChannel(inputGate);
}

Expand Down Expand Up @@ -1294,4 +1299,36 @@ boolean isInvoked() {
return isInvoked;
}
}

private static final class TestVerifyConnectionManager extends TestingConnectionManager {
private final PartitionRequestClient client;

TestVerifyConnectionManager(TestingPartitionRequestClient client) {
this.client = checkNotNull(client);
}

@Override
public PartitionRequestClient createPartitionRequestClient(ConnectionID connectionId) {
return client;
}
}

private static final class TestVerifyPartitionRequestClient extends TestingPartitionRequestClient {
private ResultPartitionID partitionId;
private int subpartitionIndex;
private int delayMs;

@Override
public void requestSubpartition(ResultPartitionID partitionId, int subpartitionIndex, RemoteInputChannel channel, int delayMs) {
this.partitionId = partitionId;
this.subpartitionIndex = subpartitionIndex;
this.delayMs = delayMs;
}

void verifyResult(ResultPartitionID expectedId, int expectedSubpartitionIndex, int expectedDelayMs) {
assertEquals(expectedId, partitionId);
assertEquals(expectedSubpartitionIndex, subpartitionIndex);
assertEquals(expectedDelayMs, delayMs);
}
}
}

0 comments on commit a267019

Please sign in to comment.