Skip to content

Commit

Permalink
[hotfix] Migrate a few location preference tests to JUnit5
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzhurk committed Jan 30, 2023
1 parent 143464d commit 00c2ebd
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,75 +19,64 @@
package org.apache.flink.runtime.scheduler;

import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.TestLogger;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import java.util.Collection;

import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createRandomExecutionVertexId;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
import static org.assertj.core.api.Assertions.assertThat;

/** Tests for {@link AvailableInputsLocationsRetriever}. */
public class AvailableInputsLocationsRetrieverTest extends TestLogger {
class AvailableInputsLocationsRetrieverTest {
private static final ExecutionVertexID EV1 = createRandomExecutionVertexId();
private static final ExecutionVertexID EV2 = createRandomExecutionVertexId();

@Test
public void testNoInputLocation() {
void testNoInputLocation() {
TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
InputsLocationsRetriever availableInputsLocationsRetriever =
new AvailableInputsLocationsRetriever(originalLocationRetriever);
assertThat(
availableInputsLocationsRetriever.getTaskManagerLocation(EV1).isPresent(),
is(false));
assertThat(availableInputsLocationsRetriever.getTaskManagerLocation(EV1)).isNotPresent();
}

@Test
public void testNoInputLocationIfNotDone() {
void testNoInputLocationIfNotDone() {
TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
originalLocationRetriever.markScheduled(EV1);
InputsLocationsRetriever availableInputsLocationsRetriever =
new AvailableInputsLocationsRetriever(originalLocationRetriever);
assertThat(
availableInputsLocationsRetriever.getTaskManagerLocation(EV1).isPresent(),
is(false));
assertThat(availableInputsLocationsRetriever.getTaskManagerLocation(EV1)).isNotPresent();
}

@Test
public void testNoInputLocationIfFailed() {
void testNoInputLocationIfFailed() {
TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
originalLocationRetriever.failTaskManagerLocation(EV1, new Throwable());
InputsLocationsRetriever availableInputsLocationsRetriever =
new AvailableInputsLocationsRetriever(originalLocationRetriever);
assertThat(
availableInputsLocationsRetriever.getTaskManagerLocation(EV1).isPresent(),
is(false));
assertThat(availableInputsLocationsRetriever.getTaskManagerLocation(EV1)).isNotPresent();
}

@Test
public void testInputLocationIfDone() {
void testInputLocationIfDone() {
TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
originalLocationRetriever.assignTaskManagerLocation(EV1);
InputsLocationsRetriever availableInputsLocationsRetriever =
new AvailableInputsLocationsRetriever(originalLocationRetriever);
assertThat(
availableInputsLocationsRetriever.getTaskManagerLocation(EV1).isPresent(),
is(true));
assertThat(availableInputsLocationsRetriever.getTaskManagerLocation(EV1)).isPresent();
}

@Test
public void testConsumedResultPartitionsProducers() {
void testConsumedResultPartitionsProducers() {
TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
InputsLocationsRetriever availableInputsLocationsRetriever =
new AvailableInputsLocationsRetriever(originalLocationRetriever);
Collection<Collection<ExecutionVertexID>> producers =
availableInputsLocationsRetriever.getConsumedResultPartitionsProducers(EV2);
assertThat(producers.size(), is(1));
assertThat(producers).hasSize(1);
Collection<ExecutionVertexID> resultProducers = producers.iterator().next();
assertThat(resultProducers, contains(EV1));
assertThat(resultProducers).containsExactly(EV1);
}

private static TestingInputsLocationsRetriever getOriginalLocationRetriever() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.util.TestLogger;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -34,16 +33,13 @@
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.Assert.assertThat;
import static org.assertj.core.api.Assertions.assertThat;

/** Tests {@link DefaultPreferredLocationsRetriever}. */
public class DefaultPreferredLocationsRetrieverTest extends TestLogger {
class DefaultPreferredLocationsRetrieverTest {

@Test
public void testStateLocationsWillBeReturnedIfExist() {
void testStateLocationsWillBeReturnedIfExist() {
final TaskManagerLocation stateLocation = new LocalTaskManagerLocation();

final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
Expand All @@ -65,11 +61,11 @@ public void testStateLocationsWillBeReturnedIfExist() {
final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
locationsRetriever.getPreferredLocations(consumerId, Collections.emptySet());

assertThat(preferredLocations.getNow(null), contains(stateLocation));
assertThat(preferredLocations.getNow(null)).containsExactly(stateLocation);
}

@Test
public void testInputLocationsIgnoresEdgeOfTooManyLocations() {
void testInputLocationsIgnoresEdgeOfTooManyLocations() {
final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
new TestingInputsLocationsRetriever.Builder();

Expand Down Expand Up @@ -98,11 +94,11 @@ public void testInputLocationsIgnoresEdgeOfTooManyLocations() {
final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
locationsRetriever.getPreferredLocations(consumerId, Collections.emptySet());

assertThat(preferredLocations.getNow(null), hasSize(0));
assertThat(preferredLocations.getNow(null)).isEmpty();
}

@Test
public void testInputLocationsChoosesInputOfFewerLocations() {
void testInputLocationsChoosesInputOfFewerLocations() {
final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
new TestingInputsLocationsRetriever.Builder();

Expand Down Expand Up @@ -150,12 +146,12 @@ public void testInputLocationsChoosesInputOfFewerLocations() {
final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
locationsRetriever.getPreferredLocations(consumerId, Collections.emptySet());

assertThat(
preferredLocations.getNow(null), containsInAnyOrder(expectedLocations.toArray()));
assertThat(preferredLocations.getNow(null))
.containsExactlyInAnyOrderElementsOf(expectedLocations);
}

@Test
public void testInputLocationsIgnoresExcludedProducers() {
void testInputLocationsIgnoresExcludedProducers() {
final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
new TestingInputsLocationsRetriever.Builder();

Expand Down Expand Up @@ -186,10 +182,10 @@ public void testInputLocationsIgnoresExcludedProducers() {
locationsRetriever.getPreferredLocations(
consumerId, Collections.singleton(producerId1));

assertThat(preferredLocations.getNow(null), hasSize(1));
assertThat(preferredLocations.getNow(null)).hasSize(1);

final TaskManagerLocation producerLocation2 =
inputsLocationsRetriever.getTaskManagerLocation(producerId2).get().getNow(null);
assertThat(preferredLocations.getNow(null), contains(producerLocation2));
assertThat(preferredLocations.getNow(null)).containsExactly(producerLocation2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,26 @@

import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.util.TestLogger;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import java.util.Collection;
import java.util.Collections;
import java.util.Optional;

import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createRandomExecutionVertexId;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;
import static org.assertj.core.api.Assertions.assertThat;

/** Tests for {@link DefaultSyncPreferredLocationsRetriever}. */
public class DefaultSyncPreferredLocationsRetrieverTest extends TestLogger {
class DefaultSyncPreferredLocationsRetrieverTest {
private static final ExecutionVertexID EV1 = createRandomExecutionVertexId();
private static final ExecutionVertexID EV2 = createRandomExecutionVertexId();
private static final ExecutionVertexID EV3 = createRandomExecutionVertexId();
private static final ExecutionVertexID EV4 = createRandomExecutionVertexId();
private static final ExecutionVertexID EV5 = createRandomExecutionVertexId();

@Test
public void testAvailableInputLocationRetrieval() {
void testAvailableInputLocationRetrieval() {
TestingInputsLocationsRetriever originalLocationRetriever =
new TestingInputsLocationsRetriever.Builder()
.connectConsumerToProducer(EV5, EV1)
Expand All @@ -64,6 +62,6 @@ public void testAvailableInputLocationRetrieval() {
TaskManagerLocation expectedLocation =
originalLocationRetriever.getTaskManagerLocation(EV1).get().join();

assertThat(preferredLocations, contains(expectedLocation));
assertThat(preferredLocations).containsExactly(expectedLocation);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,30 @@
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
import org.apache.flink.util.TestLogger;
import org.apache.flink.testutils.executor.TestExecutorExtension;

import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.util.Collection;
import java.util.Collections;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for {@link ExecutionGraphToInputsLocationsRetrieverAdapter}. */
public class ExecutionGraphToInputsLocationsRetrieverAdapterTest extends TestLogger {
class ExecutionGraphToInputsLocationsRetrieverAdapterTest {

@ClassRule
public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
TestingUtils.defaultExecutorResource();
@RegisterExtension
static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
TestingUtils.defaultExecutorExtension();

/** Tests that can get the producers of consumed result partitions. */
@Test
public void testGetConsumedResultPartitionsProducers() throws Exception {
void testGetConsumedResultPartitionsProducers() throws Exception {
final JobVertex producer1 = ExecutionGraphTestUtils.createNoOpVertex(1);
final JobVertex producer2 = ExecutionGraphTestUtils.createNoOpVertex(1);
final JobVertex consumer = ExecutionGraphTestUtils.createNoOpVertex(1);
Expand All @@ -72,7 +65,7 @@ public void testGetConsumedResultPartitionsProducers() throws Exception {

final ExecutionGraph eg =
ExecutionGraphTestUtils.createExecutionGraph(
EXECUTOR_RESOURCE.getExecutor(), producer1, producer2, consumer);
EXECUTOR_EXTENSION.getExecutor(), producer1, producer2, consumer);
final ExecutionGraphToInputsLocationsRetrieverAdapter inputsLocationsRetriever =
new ExecutionGraphToInputsLocationsRetrieverAdapter(eg);

Expand All @@ -87,41 +80,43 @@ public void testGetConsumedResultPartitionsProducers() throws Exception {
Collection<Collection<ExecutionVertexID>> producersOfConsumer =
inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfConsumer);

assertThat(producersOfProducer1, is(empty()));
assertThat(producersOfProducer2, is(empty()));
assertThat(producersOfConsumer, hasSize(2));
assertThat(producersOfConsumer, hasItem(Collections.singletonList(evIdOfProducer1)));
assertThat(producersOfConsumer, hasItem(Collections.singletonList(evIdOfProducer2)));
assertThat(producersOfProducer1).isEmpty();
assertThat(producersOfProducer2).isEmpty();
assertThat(producersOfConsumer).hasSize(2);
assertThat(producersOfConsumer)
.containsExactlyInAnyOrder(
Collections.singletonList(evIdOfProducer1),
Collections.singletonList(evIdOfProducer2));
}

/** Tests that it will get empty task manager location if vertex is not scheduled. */
@Test
public void testGetEmptyTaskManagerLocationIfVertexNotScheduled() throws Exception {
void testGetEmptyTaskManagerLocationIfVertexNotScheduled() throws Exception {
final JobVertex jobVertex = ExecutionGraphTestUtils.createNoOpVertex(1);

final ExecutionGraph eg =
ExecutionGraphTestUtils.createExecutionGraph(
EXECUTOR_RESOURCE.getExecutor(), jobVertex);
EXECUTOR_EXTENSION.getExecutor(), jobVertex);
final ExecutionGraphToInputsLocationsRetrieverAdapter inputsLocationsRetriever =
new ExecutionGraphToInputsLocationsRetrieverAdapter(eg);

ExecutionVertexID executionVertexId = new ExecutionVertexID(jobVertex.getID(), 0);
Optional<CompletableFuture<TaskManagerLocation>> taskManagerLocation =
inputsLocationsRetriever.getTaskManagerLocation(executionVertexId);

assertFalse(taskManagerLocation.isPresent());
assertThat(taskManagerLocation).isNotPresent();
}

/** Tests that it can get the task manager location in an Execution. */
@Test
public void testGetTaskManagerLocationWhenScheduled() throws Exception {
void testGetTaskManagerLocationWhenScheduled() throws Exception {
final JobVertex jobVertex = ExecutionGraphTestUtils.createNoOpVertex(1);

final TestingLogicalSlot testingLogicalSlot =
new TestingLogicalSlotBuilder().createTestingLogicalSlot();
final ExecutionGraph eg =
ExecutionGraphTestUtils.createExecutionGraph(
EXECUTOR_RESOURCE.getExecutor(), jobVertex);
EXECUTOR_EXTENSION.getExecutor(), jobVertex);
final ExecutionGraphToInputsLocationsRetrieverAdapter inputsLocationsRetriever =
new ExecutionGraphToInputsLocationsRetrieverAdapter(eg);

Expand All @@ -133,34 +128,34 @@ public void testGetTaskManagerLocationWhenScheduled() throws Exception {
Optional<CompletableFuture<TaskManagerLocation>> taskManagerLocationOptional =
inputsLocationsRetriever.getTaskManagerLocation(executionVertexId);

assertTrue(taskManagerLocationOptional.isPresent());
assertThat(taskManagerLocationOptional).isPresent();

final CompletableFuture<TaskManagerLocation> taskManagerLocationFuture =
taskManagerLocationOptional.get();
assertThat(
taskManagerLocationFuture.get(), is(testingLogicalSlot.getTaskManagerLocation()));
assertThat(taskManagerLocationFuture.get())
.isEqualTo(testingLogicalSlot.getTaskManagerLocation());
}

/**
* Tests that it will throw exception when getting the task manager location of a non existing
* execution.
*/
@Test
public void testGetNonExistingExecutionVertexWillThrowException() throws Exception {
void testGetNonExistingExecutionVertexWillThrowException() throws Exception {
final JobVertex jobVertex = ExecutionGraphTestUtils.createNoOpVertex(1);

final ExecutionGraph eg =
ExecutionGraphTestUtils.createExecutionGraph(
EXECUTOR_RESOURCE.getExecutor(), jobVertex);
EXECUTOR_EXTENSION.getExecutor(), jobVertex);
final ExecutionGraphToInputsLocationsRetrieverAdapter inputsLocationsRetriever =
new ExecutionGraphToInputsLocationsRetrieverAdapter(eg);

ExecutionVertexID invalidExecutionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
try {
inputsLocationsRetriever.getTaskManagerLocation(invalidExecutionVertexId);
fail("Should throw exception if execution vertex doesn't exist!");
} catch (IllegalStateException expected) {
// expect this exception
}
assertThatThrownBy(
() ->
inputsLocationsRetriever.getTaskManagerLocation(
invalidExecutionVertexId),
"Should throw exception if execution vertex doesn't exist!")
.isInstanceOf(IllegalStateException.class);
}
}

0 comments on commit 00c2ebd

Please sign in to comment.