Skip to content

Commit

Permalink
[hotfix][yarn][test] Avoid using Mockito for AMRMClientAsync in YarnR…
Browse files Browse the repository at this point in the history
…esourceManagerTest.
  • Loading branch information
xintongsong authored and tillrohrmann committed Apr 25, 2020
1 parent 8b2376a commit 16a84fd
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.yarn;

import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.TriConsumer;
import org.apache.flink.util.function.TriFunction;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.impl.AMRMClientAsyncImpl;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

/**
* A Yarn {@link AMRMClientAsync} implementation for testing.
*/
public class TestingYarnAMRMClientAsync extends AMRMClientAsyncImpl<AMRMClient.ContainerRequest> {

private volatile Function<Tuple4<Priority, String, Resource, CallbackHandler>, List<? extends Collection<AMRMClient.ContainerRequest>>>
getMatchingRequestsFunction = ignored -> Collections.emptyList();
private volatile BiConsumer<AMRMClient.ContainerRequest, CallbackHandler> addContainerRequestConsumer = (ignored1, ignored2) -> {};
private volatile BiConsumer<AMRMClient.ContainerRequest, CallbackHandler> removeContainerRequestConsumer = (ignored1, ignored2) -> {};
private volatile BiConsumer<ContainerId, CallbackHandler> releaseAssignedContainerConsumer = (ignored1, ignored2) -> {};
private volatile Consumer<Integer> setHeartbeatIntervalConsumer = (ignored) -> {};
private volatile TriFunction<String, Integer, String, RegisterApplicationMasterResponse> registerApplicationMasterFunction =
(ignored1, ignored2, ignored3) -> RegisterApplicationMasterResponse.newInstance(
Resource.newInstance(0, 0),
Resource.newInstance(Integer.MAX_VALUE, Integer.MAX_VALUE),
Collections.emptyMap(),
null,
Collections.emptyList(),
null,
Collections.emptyList());
private volatile TriConsumer<FinalApplicationStatus, String, String> unregisterApplicationMasterConsumer = (ignored1, ignored2, ignored3) -> {};

TestingYarnAMRMClientAsync(CallbackHandler callbackHandler) {
super(0, callbackHandler);
}

@Override
public List<? extends Collection<AMRMClient.ContainerRequest>> getMatchingRequests(Priority priority, String resourceName, Resource capability) {
return getMatchingRequestsFunction.apply(Tuple4.of(priority, resourceName, capability, handler));
}

@Override
public void addContainerRequest(AMRMClient.ContainerRequest req) {
addContainerRequestConsumer.accept(req, handler);
}

@Override
public void removeContainerRequest(AMRMClient.ContainerRequest req) {
removeContainerRequestConsumer.accept(req, handler);
}

@Override
public void releaseAssignedContainer(ContainerId containerId) {
releaseAssignedContainerConsumer.accept(containerId, handler);
}

@Override
public void setHeartbeatInterval(int interval) {
setHeartbeatIntervalConsumer.accept(interval);
}

@Override
public RegisterApplicationMasterResponse registerApplicationMaster(String appHostName, int appHostPort, String appTrackingUrl) {
return registerApplicationMasterFunction.apply(appHostName, appHostPort, appTrackingUrl);
}

@Override
public void unregisterApplicationMaster(FinalApplicationStatus appStatus, String appMessage, String appTrackingUrl) {
unregisterApplicationMasterConsumer.accept(appStatus, appMessage, appTrackingUrl);
}

void setGetMatchingRequestsFunction(
Function<Tuple4<Priority, String, Resource, CallbackHandler>, List<? extends Collection<AMRMClient.ContainerRequest>>>
getMatchingRequestsFunction) {
this.getMatchingRequestsFunction = Preconditions.checkNotNull(getMatchingRequestsFunction);
}

void setAddContainerRequestConsumer(
BiConsumer<AMRMClient.ContainerRequest, CallbackHandler> addContainerRequestConsumer) {
this.addContainerRequestConsumer = Preconditions.checkNotNull(addContainerRequestConsumer);
}

void setRemoveContainerRequestConsumer(
BiConsumer<AMRMClient.ContainerRequest, CallbackHandler> removeContainerRequestConsumer) {
this.removeContainerRequestConsumer = Preconditions.checkNotNull(removeContainerRequestConsumer);
}

void setReleaseAssignedContainerConsumer(
BiConsumer<ContainerId, CallbackHandler> releaseAssignedContainerConsumer) {
this.releaseAssignedContainerConsumer = Preconditions.checkNotNull(releaseAssignedContainerConsumer);
}

void setSetHeartbeatIntervalConsumer(
Consumer<Integer> setHeartbeatIntervalConsumer) {
this.setHeartbeatIntervalConsumer = setHeartbeatIntervalConsumer;
}

void setRegisterApplicationMasterFunction(
TriFunction<String, Integer, String, RegisterApplicationMasterResponse> registerApplicationMasterFunction) {
this.registerApplicationMasterFunction = registerApplicationMasterFunction;
}

void setUnregisterApplicationMasterConsumer(
TriConsumer<FinalApplicationStatus, String, String> unregisterApplicationMasterConsumer) {
this.unregisterApplicationMasterConsumer = unregisterApplicationMasterConsumer;
}

// ------------------------------------------------------------------------
// Override lifecycle methods to avoid actually starting the service
// ------------------------------------------------------------------------

@Override
protected void serviceInit(Configuration conf) throws Exception {
// noop
}

@Override
protected void serviceStart() throws Exception {
// noop
}

@Override
protected void serviceStop() throws Exception {
// noop
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
Expand All @@ -92,13 +91,16 @@
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.apache.flink.configuration.GlobalConfiguration.FLINK_CONF_FILENAME;
import static org.apache.flink.yarn.YarnConfigKeys.ENV_APP_ID;
Expand All @@ -115,13 +117,10 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

/**
Expand Down Expand Up @@ -179,7 +178,7 @@ public void teardown() throws Exception {
}

static class TestingYarnResourceManager extends YarnResourceManager {
AMRMClientAsync<AMRMClient.ContainerRequest> mockResourceManagerClient;
final TestingYarnAMRMClientAsync testingYarnAMRMClientAsync;
NMClientAsync mockNMClient;

TestingYarnResourceManager(
Expand All @@ -194,7 +193,6 @@ static class TestingYarnResourceManager extends YarnResourceManager {
ClusterInformation clusterInformation,
FatalErrorHandler fatalErrorHandler,
@Nullable String webInterfaceUrl,
AMRMClientAsync<AMRMClient.ContainerRequest> mockResourceManagerClient,
NMClientAsync mockNMClient,
ResourceManagerMetricGroup resourceManagerMetricGroup) {
super(
Expand All @@ -212,7 +210,7 @@ static class TestingYarnResourceManager extends YarnResourceManager {
webInterfaceUrl,
resourceManagerMetricGroup);
this.mockNMClient = mockNMClient;
this.mockResourceManagerClient = mockResourceManagerClient;
this.testingYarnAMRMClientAsync = new TestingYarnAMRMClientAsync(this);
}

<T> CompletableFuture<T> runInMainThread(Callable<T> callable) {
Expand All @@ -228,7 +226,7 @@ protected AMRMClientAsync<AMRMClient.ContainerRequest> createAndStartResourceMan
YarnConfiguration yarnConfiguration,
int yarnHeartbeatIntervalMillis,
@Nullable String webInterfaceUrl) {
return mockResourceManagerClient;
return testingYarnAMRMClientAsync;
}

@Override
Expand Down Expand Up @@ -257,8 +255,7 @@ class Context {

public NMClientAsync mockNMClient = mock(NMClientAsync.class);

@SuppressWarnings("unchecked")
public AMRMClientAsync<AMRMClient.ContainerRequest> mockResourceManagerClient = mock(AMRMClientAsync.class);
final TestingYarnAMRMClientAsync testingYarnAMRMClientAsync;

/**
* Create mock RM dependencies.
Expand Down Expand Up @@ -289,9 +286,10 @@ class Context {
new ClusterInformation("localhost", 1234),
testingFatalErrorHandler,
null,
mockResourceManagerClient,
mockNMClient,
UnregisteredMetricGroups.createUnregisteredResourceManagerMetricGroup());

testingYarnAMRMClientAsync = resourceManager.testingYarnAMRMClientAsync;
}

/**
Expand Down Expand Up @@ -324,12 +322,11 @@ void runTest(RunnableWithException testMethod) throws Exception {
}

void verifyContainerHasBeenStarted(Container testingContainer) {
verify(mockResourceManagerClient, VERIFICATION_TIMEOUT).removeContainerRequest(any(AMRMClient.ContainerRequest.class));
verify(mockNMClient, VERIFICATION_TIMEOUT).startContainerAsync(eq(testingContainer), any(ContainerLaunchContext.class));
}

void verifyContainerHasBeenRequested() {
verify(mockResourceManagerClient, VERIFICATION_TIMEOUT).addContainerRequest(any(AMRMClient.ContainerRequest.class));
void verifyFutureCompleted(CompletableFuture future) throws Exception {
future.get(TIMEOUT.toMilliseconds(), TimeUnit.MILLISECONDS);
}

Container createTestingContainer() {
Expand Down Expand Up @@ -365,18 +362,26 @@ public void testShutdownRequestCausesFatalError() throws Exception {
@Test
public void testStopWorker() throws Exception {
new Context() {{
final CompletableFuture<Void> addContainerRequestFuture = new CompletableFuture<>();
final CompletableFuture<Void> removeContainerRequestFuture = new CompletableFuture<>();
final CompletableFuture<Void> releaseAssignedContainerFuture = new CompletableFuture<>();

testingYarnAMRMClientAsync.setGetMatchingRequestsFunction(ignored ->
Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())));
testingYarnAMRMClientAsync.setAddContainerRequestConsumer((ignored1, ignored2) -> addContainerRequestFuture.complete(null));
testingYarnAMRMClientAsync.setRemoveContainerRequestConsumer((ignored1, ignored2) -> removeContainerRequestFuture.complete(null));
testingYarnAMRMClientAsync.setReleaseAssignedContainerConsumer((ignored1, ignored2) -> releaseAssignedContainerFuture.complete(null));

runTest(() -> {
// Request slot from SlotManager.
registerSlotRequest(resourceManager, rmServices, resourceProfile1, taskHost);

// Callback from YARN when container is allocated.
Container testingContainer = createTestingContainer();

doReturn(Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())))
.when(mockResourceManagerClient).getMatchingRequests(any(Priority.class), anyString(), any(Resource.class));

resourceManager.onContainersAllocated(ImmutableList.of(testingContainer));
verifyContainerHasBeenRequested();
verifyFutureCompleted(addContainerRequestFuture);
verifyFutureCompleted(removeContainerRequestFuture);
verifyContainerHasBeenStarted(testingContainer);

// Remote task executor registers with YarnResourceManager.
Expand Down Expand Up @@ -432,7 +437,7 @@ public void testStopWorker() throws Exception {
unregisterAndReleaseFuture.get();

verify(mockNMClient).stopContainerAsync(any(ContainerId.class), any(NodeId.class));
verify(mockResourceManagerClient).releaseAssignedContainer(any(ContainerId.class));
verifyFutureCompleted(releaseAssignedContainerFuture);
});

// It's now safe to access the SlotManager state since the ResourceManager has been stopped.
Expand Down Expand Up @@ -464,51 +469,74 @@ public void testDeleteApplicationFiles() throws Exception {
@Test
public void testOnContainerCompleted() throws Exception {
new Context() {{
final List<CompletableFuture<Void>> addContainerRequestFutures = new ArrayList<>();
addContainerRequestFutures.add(new CompletableFuture<>());
addContainerRequestFutures.add(new CompletableFuture<>());
addContainerRequestFutures.add(new CompletableFuture<>());
final AtomicInteger addContainerRequestFuturesNumCompleted = new AtomicInteger(0);
final CompletableFuture<Void> removeContainerRequestFuture = new CompletableFuture<>();

testingYarnAMRMClientAsync.setGetMatchingRequestsFunction(ignored ->
Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())));
testingYarnAMRMClientAsync.setAddContainerRequestConsumer((ignored1, ignored2) ->
addContainerRequestFutures.get(addContainerRequestFuturesNumCompleted.getAndIncrement()).complete(null));
testingYarnAMRMClientAsync.setRemoveContainerRequestConsumer((ignored1, ignored2) -> removeContainerRequestFuture.complete(null));

runTest(() -> {
registerSlotRequest(resourceManager, rmServices, resourceProfile1, taskHost);

// Callback from YARN when container is allocated.
Container testingContainer = createTestingContainer();

doReturn(Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())))
.when(mockResourceManagerClient).getMatchingRequests(any(Priority.class), anyString(), any(Resource.class));

resourceManager.onContainersAllocated(ImmutableList.of(testingContainer));
verifyContainerHasBeenRequested();
verifyFutureCompleted(addContainerRequestFutures.get(0));
verifyFutureCompleted(removeContainerRequestFuture);
verifyContainerHasBeenStarted(testingContainer);

// Callback from YARN when container is Completed, pending request can not be fulfilled by pending
// containers, need to request new container.
ContainerStatus testingContainerStatus = createTestingContainerStatus(testingContainer.getId());

resourceManager.onContainersCompleted(ImmutableList.of(testingContainerStatus));
verify(mockResourceManagerClient, VERIFICATION_TIMEOUT.times(2)).addContainerRequest(any(AMRMClient.ContainerRequest.class));
verifyFutureCompleted(addContainerRequestFutures.get(1));

// Callback from YARN when container is Completed happened before global fail, pending request
// slot is already fulfilled by pending containers, no need to request new container.
resourceManager.onContainersCompleted(ImmutableList.of(testingContainerStatus));
verify(mockResourceManagerClient, times(2)).addContainerRequest(any(AMRMClient.ContainerRequest.class));
assertFalse(addContainerRequestFutures.get(2).isDone());
});
}};
}

@Test
public void testOnStartContainerError() throws Exception {
new Context() {{
final List<CompletableFuture<Void>> addContainerRequestFutures = new ArrayList<>();
addContainerRequestFutures.add(new CompletableFuture<>());
addContainerRequestFutures.add(new CompletableFuture<>());
final AtomicInteger addContainerRequestFuturesNumCompleted = new AtomicInteger(0);
final CompletableFuture<Void> removeContainerRequestFuture = new CompletableFuture<>();
final CompletableFuture<Void> releaseAssignedContainerFuture = new CompletableFuture<>();

testingYarnAMRMClientAsync.setGetMatchingRequestsFunction(ignored ->
Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())));
testingYarnAMRMClientAsync.setAddContainerRequestConsumer((ignored1, ignored2) ->
addContainerRequestFutures.get(addContainerRequestFuturesNumCompleted.getAndIncrement()).complete(null));
testingYarnAMRMClientAsync.setRemoveContainerRequestConsumer((ignored1, ignored2) -> removeContainerRequestFuture.complete(null));
testingYarnAMRMClientAsync.setReleaseAssignedContainerConsumer((ignored1, ignored2) -> releaseAssignedContainerFuture.complete(null));

runTest(() -> {
registerSlotRequest(resourceManager, rmServices, resourceProfile1, taskHost);
Container testingContainer = createTestingContainer();

doReturn(Collections.singletonList(Collections.singletonList(resourceManager.getContainerRequest())))
.when(mockResourceManagerClient).getMatchingRequests(any(Priority.class), anyString(), any(Resource.class));

resourceManager.onContainersAllocated(ImmutableList.of(testingContainer));
verifyContainerHasBeenRequested();
verifyFutureCompleted(addContainerRequestFutures.get(0));
verifyFutureCompleted(removeContainerRequestFuture);
verifyContainerHasBeenStarted(testingContainer);

resourceManager.onStartContainerError(testingContainer.getId(), new Exception("start error"));
verify(mockResourceManagerClient, VERIFICATION_TIMEOUT).releaseAssignedContainer(testingContainer.getId());
verify(mockResourceManagerClient, VERIFICATION_TIMEOUT.times(2)).addContainerRequest(any(AMRMClient.ContainerRequest.class));
verifyFutureCompleted(releaseAssignedContainerFuture);
verifyFutureCompleted(addContainerRequestFutures.get(1));
});
}};
}
Expand Down

0 comments on commit 16a84fd

Please sign in to comment.