diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java index 0da624f1e7f94..887da47ed0c53 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java @@ -1246,7 +1246,7 @@ private boolean checkAndConfigurePersistentIntermediateResult(PlanNode node) { predecessorVertex != null, "Bug: Chained task has not been assigned its containing vertex when connecting."); - predecessorVertex.createAndAddResultDataSet( + predecessorVertex.getOrCreateResultDataSet( // use specified intermediateDataSetID new IntermediateDataSetID( ((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()), @@ -1326,7 +1326,7 @@ private DistributionPattern connectJobVertices( JobEdge edge = targetVertex.connectNewDataSetAsInput( - sourceVertex, distributionPattern, resultType); + sourceVertex, distributionPattern, resultType, isBroadcast); // -------------- configure the source task's ship strategy strategies in task config // -------------- @@ -1403,7 +1403,6 @@ private DistributionPattern connectJobVertices( channel.getTempMode() == TempMode.NONE ? null : channel.getTempMode().toString(); edge.setShipStrategyName(shipStrategy); - edge.setBroadcast(isBroadcast); edge.setForward(channel.getShipStrategy() == ShipStrategyType.FORWARD); edge.setPreProcessingOperationName(localStrategy); edge.setOperatorLevelCachingDescription(caching); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index a2c45ed0c475d..ba10dd59de844 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -133,7 +133,9 @@ private List createInputGateDeploymentDescriptors IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); SubpartitionIndexRange consumedSubpartitionRange = computeConsumedSubpartitionRange( - resultPartition, executionId.getSubtaskIndex()); + consumedPartitionGroup.getNumConsumers(), + resultPartition, + executionId.getSubtaskIndex()); IntermediateDataSetID resultId = consumedIntermediateResult.getId(); ResultPartitionType partitionType = consumedIntermediateResult.getResultType(); @@ -164,8 +166,9 @@ private List createInputGateDeploymentDescriptors } public static SubpartitionIndexRange computeConsumedSubpartitionRange( - IntermediateResultPartition resultPartition, int consumerSubtaskIndex) { - int numConsumers = resultPartition.getConsumerVertexGroup().size(); + int numConsumers, + IntermediateResultPartition resultPartition, + int consumerSubtaskIndex) { int consumerIndex = consumerSubtaskIndex % numConsumers; IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); int numSubpartitions = resultPartition.getNumberOfSubpartitions(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java index 2f160f146054e..ce1f18c550441 100644 --- 
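(Side note, not part of the patch.) The `computeConsumedSubpartitionRange` change above makes the number of consumers an explicit argument taken from the `ConsumedPartitionGroup`, since a partition can now be read by several consumer groups. Below is a minimal standalone sketch of the range arithmetic, assuming subpartitions are split evenly among consumers; the class and method names are illustrative, and the real method additionally distinguishes non-dynamic, dynamic, and broadcast results.

```java
/** Illustrative only; not the actual Flink method body. */
public class SubpartitionRangeSketch {

    /** Returns {startIndex, endIndexInclusive} for the given consumer subtask. */
    static int[] computeConsumedSubpartitionRange(
            int numConsumers, int numSubpartitions, int consumerSubtaskIndex) {
        // same first step as the patched method: map the subtask onto a consumer index
        int consumerIndex = consumerSubtaskIndex % numConsumers;
        // assumption for illustration: an even split of subpartitions across consumers
        int start = consumerIndex * numSubpartitions / numConsumers;
        int end = (consumerIndex + 1) * numSubpartitions / numConsumers - 1;
        return new int[] {start, end};
    }

    public static void main(String[] args) {
        // 6 subpartitions read by 3 consumers -> ranges [0,1], [2,3], [4,5]
        for (int subtask = 0; subtask < 3; subtask++) {
            int[] range = computeConsumedSubpartitionRange(3, 6, subtask);
            System.out.printf("consumer %d -> [%d, %d]%n", subtask, range[0], range[1]);
        }
    }
}
```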
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java @@ -1399,6 +1399,7 @@ private void releasePartitionGroups( final List releasablePartitionGroups) { if (releasablePartitionGroups.size() > 0) { + final List releasablePartitionIds = new ArrayList<>(); // Remove the cache of ShuffleDescriptors when ConsumedPartitionGroups are released for (ConsumedPartitionGroup releasablePartitionGroup : releasablePartitionGroups) { @@ -1406,15 +1407,17 @@ private void releasePartitionGroups( checkNotNull( intermediateResults.get( releasablePartitionGroup.getIntermediateDataSetID())); + for (IntermediateResultPartitionID partitionId : releasablePartitionGroup) { + IntermediateResultPartition partition = + totalResult.getPartitionById(partitionId); + partition.markPartitionGroupReleasable(releasablePartitionGroup); + if (partition.canBeReleased()) { + releasablePartitionIds.add(createResultPartitionId(partitionId)); + } + } totalResult.clearCachedInformationForPartitionGroup(releasablePartitionGroup); } - final List releasablePartitionIds = - releasablePartitionGroups.stream() - .flatMap(IterableUtils::toStream) - .map(this::createResultPartitionId) - .collect(Collectors.toList()); - partitionTracker.stopTrackingAndReleasePartitions(releasablePartitionIds); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java index 8efc25a913b03..1a437e160305c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java @@ -30,12 +30,11 @@ import java.util.Map; import static org.apache.flink.util.Preconditions.checkNotNull; -import static org.apache.flink.util.Preconditions.checkState; /** Class that manages all the connections between tasks. 
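(Side note, not part of the patch.) The `releasePartitionGroups` change above releases a blocking partition only after every `ConsumedPartitionGroup` that reads it has been marked releasable, rather than as soon as any one group finishes. A minimal sketch of that bookkeeping with simplified types; the class below is illustrative and only mirrors the idea behind `markPartitionGroupReleasable`/`canBeReleased`.

```java
import java.util.HashSet;
import java.util.Set;

/** Illustrative bookkeeping; simplified from IntermediateResultPartition in this patch. */
public class PartitionReleaseBookkeeping {

    private final Set<Object> releasableGroups = new HashSet<>();
    private final int totalConsumingGroups;

    public PartitionReleaseBookkeeping(int totalConsumingGroups) {
        this.totalConsumingGroups = totalConsumingGroups;
    }

    /** Called when one consuming group has finished and can let go of the partition. */
    public void markPartitionGroupReleasable(Object consumedPartitionGroup) {
        releasableGroups.add(consumedPartitionGroup);
    }

    /** The partition itself may only be released once all consuming groups agree. */
    public boolean canBeReleased() {
        return releasableGroups.size() == totalConsumingGroups;
    }
}
```

With two downstream vertices reading the same blocking result (as in the new BlockingResultPartitionReleaseTest further down), the partition therefore stays tracked until both consumers have finished.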
*/ public class EdgeManager { - private final Map partitionConsumers = + private final Map> partitionConsumers = new HashMap<>(); private final Map> vertexConsumedPartitions = @@ -50,9 +49,9 @@ public void connectPartitionWithConsumerVertexGroup( checkNotNull(consumerVertexGroup); - checkState( - partitionConsumers.putIfAbsent(resultPartitionId, consumerVertexGroup) == null, - "Currently one partition can have at most one consumer group"); + List groups = + getConsumerVertexGroupsForPartitionInternal(resultPartitionId); + groups.add(consumerVertexGroup); } public void connectVertexWithConsumedPartitionGroup( @@ -66,14 +65,20 @@ public void connectVertexWithConsumedPartitionGroup( consumedPartitions.add(consumedPartitionGroup); } + private List getConsumerVertexGroupsForPartitionInternal( + IntermediateResultPartitionID resultPartitionId) { + return partitionConsumers.computeIfAbsent(resultPartitionId, id -> new ArrayList<>()); + } + private List getConsumedPartitionGroupsForVertexInternal( ExecutionVertexID executionVertexId) { return vertexConsumedPartitions.computeIfAbsent(executionVertexId, id -> new ArrayList<>()); } - public ConsumerVertexGroup getConsumerVertexGroupForPartition( + public List getConsumerVertexGroupsForPartition( IntermediateResultPartitionID resultPartitionId) { - return partitionConsumers.get(resultPartitionId); + return Collections.unmodifiableList( + getConsumerVertexGroupsForPartitionInternal(resultPartitionId)); } public List getConsumedPartitionGroupsForVertex( @@ -100,4 +105,9 @@ public List getConsumedPartitionGroupsById( return Collections.unmodifiableList( getConsumedPartitionGroupsByIdInternal(resultPartitionId)); } + + public int getNumberOfConsumedPartitionGroupsById( + IntermediateResultPartitionID resultPartitionId) { + return getConsumedPartitionGroupsByIdInternal(resultPartitionId).size(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java index 3aa6e59de0d81..8ac55b7b7a260 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java @@ -89,7 +89,7 @@ private static void connectAllToAll( .collect(Collectors.toList()); ConsumedPartitionGroup consumedPartitionGroup = createAndRegisterConsumedPartitionGroupToEdgeManager( - consumedPartitions, intermediateResult); + taskVertices.length, consumedPartitions, intermediateResult); for (ExecutionVertex ev : taskVertices) { ev.addConsumedPartitionGroup(consumedPartitionGroup); } @@ -122,7 +122,9 @@ private static void connectPointwise( ConsumedPartitionGroup consumedPartitionGroup = createAndRegisterConsumedPartitionGroupToEdgeManager( - partition.getPartitionId(), intermediateResult); + consumerVertexGroup.size(), + partition.getPartitionId(), + intermediateResult); executionVertex.addConsumedPartitionGroup(consumedPartitionGroup); } } else if (sourceCount > targetCount) { @@ -147,20 +149,19 @@ private static void connectPointwise( ConsumedPartitionGroup consumedPartitionGroup = createAndRegisterConsumedPartitionGroupToEdgeManager( - consumedPartitions, intermediateResult); + consumerVertexGroup.size(), consumedPartitions, intermediateResult); executionVertex.addConsumedPartitionGroup(consumedPartitionGroup); } } else { for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) { + int start = 
(partitionNum * targetCount + sourceCount - 1) / sourceCount; + int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount; IntermediateResultPartition partition = intermediateResult.getPartitions()[partitionNum]; ConsumedPartitionGroup consumedPartitionGroup = createAndRegisterConsumedPartitionGroupToEdgeManager( - partition.getPartitionId(), intermediateResult); - - int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount; - int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount; + end - start, partition.getPartitionId(), intermediateResult); List consumers = new ArrayList<>(end - start); @@ -179,21 +180,23 @@ private static void connectPointwise( } private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager( + int numConsumers, IntermediateResultPartitionID consumedPartitionId, IntermediateResult intermediateResult) { ConsumedPartitionGroup consumedPartitionGroup = ConsumedPartitionGroup.fromSinglePartition( - consumedPartitionId, intermediateResult.getResultType()); + numConsumers, consumedPartitionId, intermediateResult.getResultType()); registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult); return consumedPartitionGroup; } private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager( + int numConsumers, List consumedPartitions, IntermediateResult intermediateResult) { ConsumedPartitionGroup consumedPartitionGroup = ConsumedPartitionGroup.fromMultiplePartitions( - consumedPartitions, intermediateResult.getResultType()); + numConsumers, consumedPartitions, intermediateResult.getResultType()); registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult); return consumedPartitionGroup; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index 657effb66f6ec..f6de9059869eb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -65,6 +65,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -497,10 +498,7 @@ public CompletableFuture registerProducedPartitions(TaskManagerLocation lo } private static int getPartitionMaxParallelism(IntermediateResultPartition partition) { - return partition - .getIntermediateResult() - .getConsumerExecutionJobVertex() - .getMaxParallelism(); + return partition.getIntermediateResult().getConsumersMaxParallelism(); } /** @@ -718,31 +716,40 @@ public CompletableFuture suspend() { } private void updatePartitionConsumers(final IntermediateResultPartition partition) { - final Optional consumerVertexGroup = - partition.getConsumerVertexGroupOptional(); - if (!consumerVertexGroup.isPresent()) { + final List consumerVertexGroups = partition.getConsumerVertexGroups(); + if (consumerVertexGroups.isEmpty()) { return; } - for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) { - final ExecutionVertex consumerVertex = - vertex.getExecutionGraphAccessor().getExecutionVertexOrThrow(consumerVertexId); - final Execution consumer = consumerVertex.getCurrentExecutionAttempt(); - final ExecutionState consumerState = consumer.getState(); - - // 
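(Side note, not part of the patch.) A worked example of the `connectPointwise` start/end formulas above, using sourceCount = 2 producers and targetCount = 5 consumers, the same shape as the POINTWISE case asserted in the new `EdgeManagerTest#testCalculateNumberOfConsumers`.

```java
public class PointwiseRangeExample {
    public static void main(String[] args) {
        int sourceCount = 2; // producer parallelism
        int targetCount = 5; // consumer parallelism
        for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
            // the same formulas as in connectPointwise
            int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
            int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
            // partition 0 -> consumers [0, 3), partition 1 -> consumers [3, 5)
            System.out.printf(
                    "partition %d -> consumer range [%d, %d), numConsumers = %d%n",
                    partitionNum, start, end, end - start);
        }
    }
}
```

This is why the expected consumer counts for the 2-to-5 POINTWISE case in the tests are {3, 3, 3, 2, 2}: consumers 0-2 share partition 0 and consumers 3-4 share partition 1.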
---------------------------------------------------------------- - // Consumer is recovering or running => send update message now - // Consumer is deploying => cache the partition info which would be - // sent after switching to running - // ---------------------------------------------------------------- - if (consumerState == DEPLOYING - || consumerState == RUNNING - || consumerState == INITIALIZING) { - final PartitionInfo partitionInfo = createPartitionInfo(partition); - - if (consumerState == DEPLOYING) { - consumerVertex.cachePartitionInfo(partitionInfo); - } else { - consumer.sendUpdatePartitionInfoRpcCall(Collections.singleton(partitionInfo)); + final Set updatedVertices = new HashSet<>(); + for (ConsumerVertexGroup consumerVertexGroup : consumerVertexGroups) { + for (ExecutionVertexID consumerVertexId : consumerVertexGroup) { + if (updatedVertices.contains(consumerVertexId)) { + continue; + } + + final ExecutionVertex consumerVertex = + vertex.getExecutionGraphAccessor() + .getExecutionVertexOrThrow(consumerVertexId); + final Execution consumer = consumerVertex.getCurrentExecutionAttempt(); + final ExecutionState consumerState = consumer.getState(); + + // ---------------------------------------------------------------- + // Consumer is recovering or running => send update message now + // Consumer is deploying => cache the partition info which would be + // sent after switching to running + // ---------------------------------------------------------------- + if (consumerState == DEPLOYING + || consumerState == RUNNING + || consumerState == INITIALIZING) { + final PartitionInfo partitionInfo = createPartitionInfo(partition); + updatedVertices.add(consumerVertexId); + + if (consumerState == DEPLOYING) { + consumerVertex.cachePartitionInfo(partitionInfo); + } else { + consumer.sendUpdatePartitionInfoRpcCall( + Collections.singleton(partitionInfo)); + } } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java index 4b666b5742be7..fd8be042f6f05 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java @@ -32,12 +32,15 @@ import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; public class IntermediateResult { @@ -68,6 +71,9 @@ public class IntermediateResult { private final Map> shuffleDescriptorCache; + /** All consumer job vertex ids of this dataset. 
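(Side note, not part of the patch.) In `Execution#updatePartitionConsumers` above, a partition can now have several `ConsumerVertexGroup`s, and the added `HashSet` ensures each consumer execution vertex is updated at most once even if it appears in more than one group (which can presumably happen when a vertex consumes the same data set through more than one edge). A simplified, self-contained illustration with plain strings standing in for `ExecutionVertexID`s:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Simplified illustration of the de-duplication in updatePartitionConsumers. */
public class UpdateConsumersOnceExample {
    public static void main(String[] args) {
        // two consumer vertex groups of the same partition; "A_1" appears in both
        List<List<String>> consumerVertexGroups =
                Arrays.asList(Arrays.asList("A_0", "A_1"), Arrays.asList("A_1", "B_0"));

        Set<String> updatedVertices = new HashSet<>();
        for (List<String> group : consumerVertexGroups) {
            for (String consumerVertexId : group) {
                if (!updatedVertices.add(consumerVertexId)) {
                    continue; // already handled through another group
                }
                // stands in for caching the partition info or sending the update RPC
                System.out.println("update consumer " + consumerVertexId);
            }
        }
    }
}
```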
*/ + private final List consumerVertices = new ArrayList<>(); + public IntermediateResult( IntermediateDataSet intermediateDataSet, ExecutionJobVertex producer, @@ -95,6 +101,10 @@ public IntermediateResult( this.resultType = checkNotNull(resultType); this.shuffleDescriptorCache = new HashMap<>(); + + intermediateDataSet + .getConsumers() + .forEach(jobEdge -> consumerVertices.add(jobEdge.getTarget().getID())); } public void setPartition(int partitionNumber, IntermediateResultPartition partition) { @@ -124,6 +134,10 @@ public IntermediateResultPartition[] getPartitions() { return partitions; } + public List getConsumerVertices() { + return consumerVertices; + } + /** * Returns the partition with the given ID. * @@ -162,20 +176,60 @@ int getNumParallelProducers() { return numParallelProducers; } - ExecutionJobVertex getConsumerExecutionJobVertex() { - final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer()); - final JobVertexID consumerJobVertexId = consumer.getTarget().getID(); - return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId)); + /** + * Currently, this method is only used to compute the maximum number of consumers. For dynamic + * graph, it should be called before adaptively deciding the downstream consumer parallelism. + */ + int getConsumersParallelism() { + List consumers = intermediateDataSet.getConsumers(); + checkState(!consumers.isEmpty()); + + InternalExecutionGraphAccessor graph = getProducer().getGraph(); + int consumersParallelism = + graph.getJobVertex(consumers.get(0).getTarget().getID()).getParallelism(); + if (consumers.size() == 1) { + return consumersParallelism; + } + + // sanity check, all consumer vertices must have the same parallelism: + // 1. for vertices that are not assigned a parallelism initially (for example, dynamic + // graph), the parallelisms will all be -1 (parallelism not decided yet) + // 2. 
for vertices that are initially assigned a parallelism, the parallelisms must be the + // same, which is guaranteed at compilation phase + for (JobVertexID jobVertexID : consumerVertices) { + checkState( + consumersParallelism == graph.getJobVertex(jobVertexID).getParallelism(), + "Consumers must have the same parallelism."); + } + return consumersParallelism; + } + + int getConsumersMaxParallelism() { + List consumers = intermediateDataSet.getConsumers(); + checkState(!consumers.isEmpty()); + + InternalExecutionGraphAccessor graph = getProducer().getGraph(); + int consumersMaxParallelism = + graph.getJobVertex(consumers.get(0).getTarget().getID()).getMaxParallelism(); + if (consumers.size() == 1) { + return consumersMaxParallelism; + } + + // sanity check, all consumer vertices must have the same max parallelism + for (JobVertexID jobVertexID : consumerVertices) { + checkState( + consumersMaxParallelism == graph.getJobVertex(jobVertexID).getMaxParallelism(), + "Consumers must have the same max parallelism."); + } + return consumersMaxParallelism; } public DistributionPattern getConsumingDistributionPattern() { - final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer()); - return consumer.getDistributionPattern(); + return intermediateDataSet.getDistributionPattern(); } public boolean isBroadcast() { - final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer()); - return consumer.isBroadcast(); + return intermediateDataSet.isBroadcast(); } public int getConnectionIndex() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java index 5aef1b07471fb..9b9c176a3d965 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java @@ -21,11 +21,13 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; +import java.util.HashSet; import java.util.List; -import java.util.Optional; +import java.util.Set; import static org.apache.flink.util.Preconditions.checkState; @@ -47,6 +49,12 @@ public class IntermediateResultPartition { /** Whether this partition has produced some data. */ private boolean hasDataProduced = false; + /** + * Releasable {@link ConsumedPartitionGroup}s for this result partition. This result partition + * can be released if all {@link ConsumedPartitionGroup}s are releasable. 
+ */ + private final Set releasablePartitionGroups = new HashSet<>(); + public IntermediateResultPartition( IntermediateResult totalResult, ExecutionVertex producer, @@ -58,6 +66,25 @@ public IntermediateResultPartition( this.edgeManager = edgeManager; } + public void markPartitionGroupReleasable(ConsumedPartitionGroup partitionGroup) { + releasablePartitionGroups.add(partitionGroup); + } + + public boolean canBeReleased() { + if (releasablePartitionGroups.size() + != edgeManager.getNumberOfConsumedPartitionGroupsById(partitionId)) { + return false; + } + for (JobVertexID jobVertexId : totalResult.getConsumerVertices()) { + // for dynamic graph, if any consumer vertex is still not initialized, this result + // partition can not be released + if (!producer.getExecutionGraphAccessor().getJobVertex(jobVertexId).isInitialized()) { + return false; + } + } + return true; + } + public ExecutionVertex getProducer() { return producer; } @@ -78,15 +105,8 @@ public ResultPartitionType getResultType() { return totalResult.getResultType(); } - public ConsumerVertexGroup getConsumerVertexGroup() { - Optional consumerVertexGroup = getConsumerVertexGroupOptional(); - checkState(consumerVertexGroup.isPresent()); - return consumerVertexGroup.get(); - } - - public Optional getConsumerVertexGroupOptional() { - return Optional.ofNullable( - getEdgeManager().getConsumerVertexGroupForPartition(partitionId)); + public List getConsumerVertexGroups() { + return getEdgeManager().getConsumerVertexGroupsForPartition(partitionId); } public List getConsumedPartitionGroups() { @@ -106,12 +126,13 @@ public int getNumberOfSubpartitions() { private int computeNumberOfSubpartitions() { if (!getProducer().getExecutionGraphAccessor().isDynamic()) { - ConsumerVertexGroup consumerVertexGroup = getConsumerVertexGroup(); - checkState(consumerVertexGroup.size() > 0); + List consumerVertexGroups = getConsumerVertexGroups(); + checkState(!consumerVertexGroups.isEmpty()); // The produced data is partitioned among a number of subpartitions, one for each - // consuming sub task. - return consumerVertexGroup.size(); + // consuming sub task. All vertex groups must have the same number of consumers + // for non-dynamic graph. 
+ return consumerVertexGroups.get(0).size(); } else { if (totalResult.isBroadcast()) { // for dynamic graph and broadcast result, we only produced one subpartition, @@ -124,18 +145,16 @@ private int computeNumberOfSubpartitions() { } private int computeNumberOfMaxPossiblePartitionConsumers() { - final ExecutionJobVertex consumerJobVertex = - getIntermediateResult().getConsumerExecutionJobVertex(); final DistributionPattern distributionPattern = getIntermediateResult().getConsumingDistributionPattern(); // decide the max possible consumer job vertex parallelism - int maxConsumerJobVertexParallelism = consumerJobVertex.getParallelism(); + int maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersParallelism(); if (maxConsumerJobVertexParallelism <= 0) { + maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersMaxParallelism(); checkState( - consumerJobVertex.getMaxParallelism() > 0, + maxConsumerJobVertexParallelism > 0, "Neither the parallelism nor the max parallelism of a job vertex is set"); - maxConsumerJobVertexParallelism = consumerJobVertex.getMaxParallelism(); } // compute number of subpartitions according to the distribution pattern @@ -163,6 +182,7 @@ void resetForNewExecution() { consumedPartitionGroup.partitionUnfinished(); } } + releasablePartitionGroups.clear(); hasDataProduced = false; for (ConsumedPartitionGroup consumedPartitionGroup : getConsumedPartitionGroups()) { totalResult.clearCachedInformationForPartitionGroup(consumedPartitionGroup); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java index 6e7181ffd035e..39d7fe7254742 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java @@ -230,12 +230,12 @@ private Iterable getConsumerVerticesToVisit( for (SchedulingExecutionVertex vertex : regionToRestart.getVertices()) { for (SchedulingResultPartition producedPartition : vertex.getProducedResults()) { - final Optional consumerVertexGroup = - producedPartition.getConsumerVertexGroup(); - if (consumerVertexGroup.isPresent() - && !visitedConsumerVertexGroups.contains(consumerVertexGroup.get())) { - visitedConsumerVertexGroups.add(consumerVertexGroup.get()); - consumerVertexGroupsToVisit.add(consumerVertexGroup.get()); + for (ConsumerVertexGroup consumerVertexGroup : + producedPartition.getConsumerVertexGroups()) { + if (!visitedConsumerVertexGroups.contains(consumerVertexGroup)) { + visitedConsumerVertexGroups.add(consumerVertexGroup); + consumerVertexGroupsToVisit.add(consumerVertexGroup); + } } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java index de353be7ba89e..6428284cd49d5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java @@ -32,7 +32,6 @@ import java.util.IdentityHashMap; import 
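(Side note, not part of the patch.) The subpartition-count logic of `IntermediateResultPartition` above, condensed into one illustrative function; the helper below is not Flink code, it just restates the branches.

```java
/** Condensed restatement of computeNumberOfSubpartitions after this change. */
public class SubpartitionCountSketch {

    static int numberOfSubpartitions(
            boolean dynamicGraph,
            boolean broadcast,
            int consumerVertexGroupSize,
            int maxPossibleConsumers) {
        if (!dynamicGraph) {
            // non-dynamic graph: one subpartition per consuming subtask; all consumer
            // vertex groups of the partition have the same size
            return consumerVertexGroupSize;
        }
        // dynamic graph: a broadcast result writes a single subpartition read by every
        // consumer; otherwise reserve one subpartition per possible consumer
        return broadcast ? 1 : maxPossibleConsumers;
    }

    public static void main(String[] args) {
        System.out.println(numberOfSubpartitions(false, false, 4, 128)); // 4
        System.out.println(numberOfSubpartitions(true, true, 4, 128)); // 1
        System.out.println(numberOfSubpartitions(true, false, 4, 128)); // 128
    }
}
```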
java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -116,23 +115,20 @@ private static List> buildOutEdgesDesc( if (producedResult.getResultType().mustBePipelinedConsumed()) { continue; } - final Optional consumerVertexGroup = - producedResult.getConsumerVertexGroup(); - if (!consumerVertexGroup.isPresent()) { - continue; - } - - for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) { - SchedulingExecutionVertex consumerVertex = - executionVertexRetriever.apply(consumerVertexId); - // Skip the ConsumerVertexGroup if its vertices are outside current - // regions and cannot be merged - if (!vertexToRegion.containsKey(consumerVertex)) { - break; - } - if (!currentRegion.contains(consumerVertex)) { - currentRegionOutEdges.add( - regionIndices.get(vertexToRegion.get(consumerVertex))); + for (ConsumerVertexGroup consumerVertexGroup : + producedResult.getConsumerVertexGroups()) { + for (ExecutionVertexID consumerVertexId : consumerVertexGroup) { + SchedulingExecutionVertex consumerVertex = + executionVertexRetriever.apply(consumerVertexId); + // Skip the ConsumerVertexGroup if its vertices are outside current + // regions and cannot be merged + if (!vertexToRegion.containsKey(consumerVertex)) { + break; + } + if (!currentRegion.contains(consumerVertex)) { + currentRegionOutEdges.add( + regionIndices.get(vertexToRegion.get(consumerVertex))); + } } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java index aa45dd4b56fec..2d55cf4c6398a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java @@ -56,6 +56,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; @@ -93,7 +94,7 @@ public class NettyShuffleEnvironment private final FileChannelManager fileChannelManager; - private final Map inputGatesById; + private final Map> inputGatesById; private final ResultPartitionFactory resultPartitionFactory; @@ -169,7 +170,7 @@ public NettyShuffleEnvironmentConfiguration getConfiguration() { } @VisibleForTesting - public Optional getInputGate(InputGateID id) { + public Optional> getInputGate(InputGateID id) { return Optional.ofNullable(inputGatesById.get(id)); } @@ -260,8 +261,24 @@ public List createInputGates( InputGateID id = new InputGateID( igdd.getConsumedResultId(), ownerContext.getExecutionAttemptID()); - inputGatesById.put(id, inputGate); - inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id)); + Set inputGateSet = + inputGatesById.computeIfAbsent( + id, ignored -> ConcurrentHashMap.newKeySet()); + inputGateSet.add(inputGate); + inputGatesById.put(id, inputGateSet); + inputGate + .getCloseFuture() + .thenRun( + () -> + inputGatesById.computeIfPresent( + id, + (key, value) -> { + value.remove(inputGate); + if (value.isEmpty()) { + return null; + } + return value; + })); inputGates[gateIndex] = inputGate; } @@ -297,17 +314,20 @@ public boolean updatePartitionInfo(ExecutionAttemptID consumerID, PartitionInfo IntermediateDataSetID intermediateResultPartitionID = partitionInfo.getIntermediateDataSetID(); InputGateID id = new 
InputGateID(intermediateResultPartitionID, consumerID); - SingleInputGate inputGate = inputGatesById.get(id); - if (inputGate == null) { + Set inputGates = inputGatesById.get(id); + if (inputGates == null || inputGates.isEmpty()) { return false; } + ShuffleDescriptor shuffleDescriptor = partitionInfo.getShuffleDescriptor(); checkArgument( shuffleDescriptor instanceof NettyShuffleDescriptor, "Tried to update unknown channel with unknown ShuffleDescriptor %s.", shuffleDescriptor.getClass().getName()); - inputGate.updateInputChannel( - taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor); + for (SingleInputGate inputGate : inputGates) { + inputGate.updateInputChannel( + taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor); + } return true; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java index d6b3abc3b52ed..2aad26139969f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java @@ -20,7 +20,8 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; -import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -39,11 +40,16 @@ public class IntermediateDataSet implements java.io.Serializable { private final JobVertex producer; // the operation that produced this data set - @Nullable private JobEdge consumer; + // All consumers must have the same partitioner and parallelism + private final List consumers = new ArrayList<>(); // The type of partition to use at runtime private final ResultPartitionType resultType; + private DistributionPattern distributionPattern; + + private boolean isBroadcast; + // -------------------------------------------------------------------------------------------- public IntermediateDataSet( @@ -63,9 +69,16 @@ public JobVertex getProducer() { return producer; } - @Nullable - public JobEdge getConsumer() { - return consumer; + public List getConsumers() { + return this.consumers; + } + + public boolean isBroadcast() { + return isBroadcast; + } + + public DistributionPattern getDistributionPattern() { + return distributionPattern; } public ResultPartitionType getResultType() { @@ -75,10 +88,19 @@ public ResultPartitionType getResultType() { // -------------------------------------------------------------------------------------------- public void addConsumer(JobEdge edge) { - checkState( - this.consumer == null, - "Currently one IntermediateDataSet can have at most one consumer."); - this.consumer = edge; + // sanity check + checkState(id.equals(edge.getSourceId()), "Incompatible dataset id."); + + if (consumers.isEmpty()) { + distributionPattern = edge.getDistributionPattern(); + isBroadcast = edge.isBroadcast(); + } else { + checkState( + distributionPattern == edge.getDistributionPattern(), + "Incompatible distribution pattern."); + checkState(isBroadcast == edge.isBroadcast(), "Incompatible broadcast type."); + } + consumers.add(edge); } // -------------------------------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java 
index 9772ff4dbb146..4649303c14476 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java @@ -44,10 +44,7 @@ public class JobEdge implements java.io.Serializable { private SubtaskStateMapper upstreamSubtaskStateMapper = SubtaskStateMapper.ROUND_ROBIN; /** The data set at the source of the edge, may be null if the edge is not yet connected. */ - private IntermediateDataSet source; - - /** The id of the source intermediate data set. */ - private IntermediateDataSetID sourceId; + private final IntermediateDataSet source; /** * Optional name for the data shipping strategy (forward, partition hash, rebalance, ...), to be @@ -55,7 +52,7 @@ public class JobEdge implements java.io.Serializable { */ private String shipStrategyName; - private boolean isBroadcast; + private final boolean isBroadcast; private boolean isForward; @@ -74,36 +71,20 @@ public class JobEdge implements java.io.Serializable { * @param source The data set that is at the source of this edge. * @param target The operation that is at the target of this edge. * @param distributionPattern The pattern that defines how the connection behaves in parallel. + * @param isBroadcast Whether the source broadcasts data to the target. */ public JobEdge( - IntermediateDataSet source, JobVertex target, DistributionPattern distributionPattern) { + IntermediateDataSet source, + JobVertex target, + DistributionPattern distributionPattern, + boolean isBroadcast) { if (source == null || target == null || distributionPattern == null) { throw new NullPointerException(); } this.target = target; this.distributionPattern = distributionPattern; this.source = source; - this.sourceId = source.getId(); - } - - /** - * Constructs a new job edge that refers to an intermediate result via the Id, rather than - * directly through the intermediate data set structure. - * - * @param sourceId The id of the data set that is at the source of this edge. - * @param target The operation that is at the target of this edge. - * @param distributionPattern The pattern that defines how the connection behaves in parallel. - */ - public JobEdge( - IntermediateDataSetID sourceId, - JobVertex target, - DistributionPattern distributionPattern) { - if (sourceId == null || target == null || distributionPattern == null) { - throw new NullPointerException(); - } - this.target = target; - this.distributionPattern = distributionPattern; - this.sourceId = sourceId; + this.isBroadcast = isBroadcast; } /** @@ -140,11 +121,7 @@ public DistributionPattern getDistributionPattern() { * @return The ID of the consumed data set. */ public IntermediateDataSetID getSourceId() { - return sourceId; - } - - public boolean isIdReference() { - return this.source == null; + return source.getId(); } // -------------------------------------------------------------------------------------------- @@ -173,11 +150,6 @@ public boolean isBroadcast() { return isBroadcast; } - /** Sets whether the edge is broadcast edge. */ - public void setBroadcast(boolean broadcast) { - isBroadcast = broadcast; - } - /** Gets whether the edge is forward edge. 
*/ public boolean isForward() { return isForward; @@ -268,6 +240,6 @@ public void setOperatorLevelCachingDescription(String operatorLevelCachingDescri @Override public String toString() { - return String.format("%s --> %s [%s]", sourceId, target, distributionPattern.name()); + return String.format("%s --> %s [%s]", source.getId(), target, distributionPattern.name()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java index e8baef4d8e7aa..bb821adda7b04 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java @@ -458,10 +458,9 @@ public List getVerticesSortedTopologicallyFromSources() private void addNodesThatHaveNoNewPredecessors( JobVertex start, List target, Set remaining) { - // forward traverse over all produced data sets + // forward traverse over all produced data sets and all their consumers for (IntermediateDataSet dataSet : start.getProducedDataSets()) { - JobEdge edge = dataSet.getConsumer(); - if (edge != null) { + for (JobEdge edge : dataSet.getConsumers()) { // a vertex can be added, if it has no predecessors that are still in the // 'remaining' set diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java index 36d1a1e548791..717bd6d437613 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java @@ -36,7 +36,9 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -72,8 +74,8 @@ public class JobVertex implements java.io.Serializable { */ private final List operatorIDs; - /** List of produced data sets, one per writer. */ - private final ArrayList results = new ArrayList<>(); + /** Produced data sets, one per writer. */ + private final Map results = new LinkedHashMap<>(); /** List of edges with incoming data. One per Reader. 
*/ private final ArrayList inputs = new ArrayList<>(); @@ -374,7 +376,7 @@ public void setInputSplitSource(InputSplitSource inputSplitSource) { } public List getProducedDataSets() { - return this.results; + return new ArrayList<>(results.values()); } public List getInputs() { @@ -481,30 +483,37 @@ public void updateCoLocationGroup(CoLocationGroupImpl group) { } // -------------------------------------------------------------------------------------------- - public IntermediateDataSet createAndAddResultDataSet( + public IntermediateDataSet getOrCreateResultDataSet( IntermediateDataSetID id, ResultPartitionType partitionType) { - - IntermediateDataSet result = new IntermediateDataSet(id, partitionType, this); - this.results.add(result); - return result; + return this.results.computeIfAbsent( + id, key -> new IntermediateDataSet(id, partitionType, this)); } public JobEdge connectNewDataSetAsInput( JobVertex input, DistributionPattern distPattern, ResultPartitionType partitionType) { + return connectNewDataSetAsInput(input, distPattern, partitionType, false); + } + + public JobEdge connectNewDataSetAsInput( + JobVertex input, + DistributionPattern distPattern, + ResultPartitionType partitionType, + boolean isBroadcast) { return connectNewDataSetAsInput( - input, distPattern, partitionType, new IntermediateDataSetID()); + input, distPattern, partitionType, new IntermediateDataSetID(), isBroadcast); } public JobEdge connectNewDataSetAsInput( JobVertex input, DistributionPattern distPattern, ResultPartitionType partitionType, - IntermediateDataSetID intermediateDataSetId) { + IntermediateDataSetID intermediateDataSetId, + boolean isBroadcast) { IntermediateDataSet dataSet = - input.createAndAddResultDataSet(intermediateDataSetId, partitionType); + input.getOrCreateResultDataSet(intermediateDataSetId, partitionType); - JobEdge edge = new JobEdge(dataSet, this, distPattern); + JobEdge edge = new JobEdge(dataSet, this, distPattern, isBroadcast); this.inputs.add(edge); dataSet.addConsumer(edge); return edge; @@ -525,13 +534,7 @@ public boolean isOutputVertex() { } public boolean hasNoConnectedInputs() { - for (JobEdge edge : inputs) { - if (!edge.isIdReference()) { - return false; - } - } - - return true; + return inputs.isEmpty(); } public void markContainsSources() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java index 0fac885dd3f53..1506152fa3864 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java @@ -45,7 +45,6 @@ import java.util.Map; import java.util.function.Function; -import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** @@ -143,15 +142,22 @@ private static Map getMaxSubpartitionNums( Map ret = new HashMap<>(); List producedDataSets = ejv.getJobVertex().getProducedDataSets(); - for (int i = 0; i < producedDataSets.size(); i++) { - IntermediateDataSet producedDataSet = producedDataSets.get(i); - JobEdge outputEdge = checkNotNull(producedDataSet.getConsumer()); - ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID()); - int maxNum = - EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex( - ejv.getParallelism(), - consumerJobVertex.getParallelism(), - 
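(Side note, not part of the patch.) Usage sketch of the new `connectNewDataSetAsInput` overload in `JobVertex` above: two consumers attach to one producer result by passing the same `IntermediateDataSetID`, so the producer creates the data set once via `getOrCreateResultDataSet` and both edges are registered as its consumers. This mirrors the construction used in the new `DefaultExecutionGraphConstructionTest` further down; only the vertex names are made up.

```java
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertex;

public class MultiConsumerWiringExample {
    public static void main(String[] args) {
        JobVertex producer = new JobVertex("producer");
        JobVertex consumerA = new JobVertex("consumerA");
        JobVertex consumerB = new JobVertex("consumerB");

        // both consumers reference the same intermediate data set id
        IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();
        consumerA.connectNewDataSetAsInput(
                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING,
                sharedDataSetId, false);
        consumerB.connectNewDataSetAsInput(
                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING,
                sharedDataSetId, false);

        // one produced data set with two consumer edges
        System.out.println(producer.getProducedDataSets().size()); // 1
        System.out.println(producer.getProducedDataSets().get(0).getConsumers().size()); // 2
    }
}
```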
outputEdge.getDistributionPattern()); + checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph."); + for (IntermediateDataSet producedDataSet : producedDataSets) { + int maxNum = 0; + List outputEdges = producedDataSet.getConsumers(); + + if (!outputEdges.isEmpty()) { + // for non-dynamic graph, the consumer vertices' parallelisms and distribution + // patterns must be the same + JobEdge outputEdge = outputEdges.get(0); + ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID()); + maxNum = + EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex( + ejv.getParallelism(), + consumerJobVertex.getParallelism(), + outputEdge.getDistributionPattern()); + } ret.put(producedDataSet.getId(), maxNum); } @@ -177,7 +183,9 @@ static Map getMaxInputChannelNumsForDynamicGraph ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst())); SubpartitionIndexRange subpartitionIndexRange = TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange( - resultPartition, vertex.getParallelSubtaskIndex()); + partitionGroup.getNumConsumers(), + resultPartition, + vertex.getParallelSubtaskIndex()); ret.merge( partitionGroup.getIntermediateDataSetID(), diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java index 486ae28b264bd..f2fd573e61302 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java @@ -262,7 +262,7 @@ private void generateNewExecutionVerticesAndResultPartitions( List producedPartitions = generateProducedSchedulingResultPartition( vertex.getProducedPartitions(), - edgeManager::getConsumerVertexGroupForPartition); + edgeManager::getConsumerVertexGroupsForPartition); producedPartitions.forEach( partition -> resultPartitionsById.put(partition.getId(), partition)); @@ -285,8 +285,8 @@ private void generateNewExecutionVerticesAndResultPartitions( private static List generateProducedSchedulingResultPartition( Map producedIntermediatePartitions, - Function - partitionConsumerVertexGroupRetriever) { + Function> + partitionConsumerVertexGroupsRetriever) { List producedSchedulingPartitions = new ArrayList<>(producedIntermediatePartitions.size()); @@ -305,8 +305,8 @@ private static List generateProducedSchedulingResultPart ? 
ResultPartitionState.CONSUMABLE : ResultPartitionState.CREATED, () -> - partitionConsumerVertexGroupRetriever.apply( - irp.getPartitionId()), + partitionConsumerVertexGroupsRetriever + .apply(irp.getPartitionId()), irp::getConsumedPartitionGroups))); return producedSchedulingPartitions; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java index b46c54bd58e08..2ae54af430477 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java @@ -27,7 +27,6 @@ import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; import java.util.List; -import java.util.Optional; import java.util.function.Supplier; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -45,7 +44,7 @@ class DefaultResultPartition implements SchedulingResultPartition { private DefaultExecutionVertex producer; - private final Supplier consumerVertexGroupSupplier; + private final Supplier> consumerVertexGroupsSupplier; private final Supplier> consumerPartitionGroupSupplier; @@ -54,13 +53,13 @@ class DefaultResultPartition implements SchedulingResultPartition { IntermediateDataSetID intermediateDataSetId, ResultPartitionType partitionType, Supplier resultPartitionStateSupplier, - Supplier consumerVertexGroupSupplier, + Supplier> consumerVertexGroupsSupplier, Supplier> consumerPartitionGroupSupplier) { this.resultPartitionId = checkNotNull(partitionId); this.intermediateDataSetId = checkNotNull(intermediateDataSetId); this.partitionType = checkNotNull(partitionType); this.resultPartitionStateSupplier = checkNotNull(resultPartitionStateSupplier); - this.consumerVertexGroupSupplier = checkNotNull(consumerVertexGroupSupplier); + this.consumerVertexGroupsSupplier = checkNotNull(consumerVertexGroupsSupplier); this.consumerPartitionGroupSupplier = checkNotNull(consumerPartitionGroupSupplier); } @@ -90,8 +89,8 @@ public DefaultExecutionVertex getProducer() { } @Override - public Optional getConsumerVertexGroup() { - return Optional.ofNullable(consumerVertexGroupSupplier.get()); + public List getConsumerVertexGroups() { + return checkNotNull(consumerVertexGroupsSupplier.get()); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java index 6e4672c48e99b..2f5be93ca514e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java @@ -42,12 +42,17 @@ public class ConsumedPartitionGroup implements Iterable resultPartitions, ResultPartitionType resultPartitionType) { checkArgument( resultPartitions.size() > 0, "The size of result partitions in the ConsumedPartitionGroup should be larger than 0."); + this.numConsumers = numConsumers; this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID(); this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType); @@ -63,16 +68,18 @@ private ConsumedPartitionGroup( } public static ConsumedPartitionGroup fromMultiplePartitions( + int numConsumers, List resultPartitions, ResultPartitionType resultPartitionType) { 
- return new ConsumedPartitionGroup(resultPartitions, resultPartitionType); + return new ConsumedPartitionGroup(numConsumers, resultPartitions, resultPartitionType); } public static ConsumedPartitionGroup fromSinglePartition( + int numConsumers, IntermediateResultPartitionID resultPartition, ResultPartitionType resultPartitionType) { return new ConsumedPartitionGroup( - Collections.singletonList(resultPartition), resultPartitionType); + numConsumers, Collections.singletonList(resultPartition), resultPartitionType); } @Override @@ -88,6 +95,14 @@ public boolean isEmpty() { return resultPartitions.isEmpty(); } + /** + * In dynamic graph cases, the number of consumers of ConsumedPartitionGroup can be different + * even if they contain the same IntermediateResultPartition. + */ + public int getNumConsumers() { + return numConsumers; + } + public IntermediateResultPartitionID getFirst() { return iterator().next(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java index b52150c96cd15..86cafebac6b70 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java @@ -24,7 +24,6 @@ import org.apache.flink.runtime.topology.Result; import java.util.List; -import java.util.Optional; /** Representation of {@link IntermediateResultPartition}. */ public interface SchedulingResultPartition @@ -49,11 +48,11 @@ public interface SchedulingResultPartition ResultPartitionState getState(); /** - * Gets the {@link ConsumerVertexGroup}. + * Gets the {@link ConsumerVertexGroup}s. * - * @return {@link ConsumerVertexGroup} if consumers exists, otherwise {@link Optional#empty()}. + * @return list of {@link ConsumerVertexGroup}s */ - Optional getConsumerVertexGroup(); + List getConsumerVertexGroups(); /** * Gets the {@link ConsumedPartitionGroup}s this partition belongs to. 
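(Side note, not part of the patch.) A small sketch for the `getNumConsumers` javadoc above: in dynamic graphs, two groups built from the same partition may report different consumer counts because each downstream job vertex decides its parallelism independently. The factory calls below match the patched signatures; the no-argument `IntermediateResultPartitionID` constructor is used only for brevity here.

```java
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;

public class NumConsumersExample {
    public static void main(String[] args) {
        IntermediateResultPartitionID sharedPartition = new IntermediateResultPartitionID();

        ConsumedPartitionGroup groupOfVertexA =
                ConsumedPartitionGroup.fromSinglePartition(
                        3, sharedPartition, ResultPartitionType.BLOCKING);
        ConsumedPartitionGroup groupOfVertexB =
                ConsumedPartitionGroup.fromSinglePartition(
                        5, sharedPartition, ResultPartitionType.BLOCKING);

        System.out.println(groupOfVertexA.getNumConsumers()); // 3
        System.out.println(groupOfVertexB.getNumConsumers()); // 5
    }
}
```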
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java index 4f7d01e4e8f90..4d217a9eb9793 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java @@ -24,12 +24,12 @@ import org.apache.flink.runtime.scheduler.SchedulingTopologyListener; import org.apache.flink.util.IterableUtils; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -85,11 +85,9 @@ public void onExecutionStateChange( Set consumerVertices = IterableUtils.toStream(executionVertex.getProducedResults()) - .map(SchedulingResultPartition::getConsumerVertexGroup) - .filter(Optional::isPresent) - .flatMap( - consumerVertexGroup -> - IterableUtils.toStream(consumerVertexGroup.get())) + .map(SchedulingResultPartition::getConsumerVertexGroups) + .flatMap(Collection::stream) + .flatMap(IterableUtils::toStream) .collect(Collectors.toSet()); maybeScheduleVertices(consumerVertices); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java new file mode 100644 index 0000000000000..af1c543c509a8 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.executiongraph; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.blob.TestingBlobWriter; +import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; +import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; +import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService; +import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.scheduler.SchedulerBase; +import org.apache.flink.testutils.TestingUtils; +import org.apache.flink.testutils.executor.TestExecutorExtension; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; + +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex; +import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createSchedulerAndDeploy; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests that blocking result partitions are properly released. */ +class BlockingResultPartitionReleaseTest { + + @RegisterExtension + static final TestExecutorExtension EXECUTOR_RESOURCE = + TestingUtils.defaultExecutorExtension(); + + private ScheduledExecutorService scheduledExecutorService; + private ComponentMainThreadExecutor mainThreadExecutor; + private ManuallyTriggeredScheduledExecutorService ioExecutor; + + @BeforeEach + void setup() { + scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + mainThreadExecutor = + ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor( + scheduledExecutorService); + ioExecutor = new ManuallyTriggeredScheduledExecutorService(); + } + + @AfterEach + void teardown() { + if (scheduledExecutorService != null) { + scheduledExecutorService.shutdownNow(); + } + } + + @Test + void testMultipleConsumersForAdaptiveBatchScheduler() throws Exception { + testResultPartitionConsumedByMultiConsumers(true); + } + + @Test + void testMultipleConsumersForDefaultScheduler() throws Exception { + testResultPartitionConsumedByMultiConsumers(false); + } + + private void testResultPartitionConsumedByMultiConsumers(boolean isAdaptive) throws Exception { + int parallelism = 2; + JobID jobId = new JobID(); + JobVertex producer = ExecutionGraphTestUtils.createNoOpVertex("producer", parallelism); + JobVertex consumer1 = ExecutionGraphTestUtils.createNoOpVertex("consumer1", parallelism); + JobVertex consumer2 = ExecutionGraphTestUtils.createNoOpVertex("consumer2", parallelism); + + TestingPartitionTracker partitionTracker = new TestingPartitionTracker(); + SchedulerBase scheduler = + createSchedulerAndDeploy( + isAdaptive, + jobId, + producer, + new JobVertex[] {consumer1, consumer2}, + DistributionPattern.ALL_TO_ALL, + new TestingBlobWriter(Integer.MAX_VALUE), + mainThreadExecutor, + 
ioExecutor, + partitionTracker, + EXECUTOR_RESOURCE.getExecutor()); + ExecutionGraph executionGraph = scheduler.getExecutionGraph(); + + assertThat(partitionTracker.releasedPartitions).isEmpty(); + + CompletableFuture.runAsync( + () -> finishJobVertex(executionGraph, consumer1.getID()), + mainThreadExecutor) + .join(); + ioExecutor.triggerAll(); + + assertThat(partitionTracker.releasedPartitions).isEmpty(); + + CompletableFuture.runAsync( + () -> finishJobVertex(executionGraph, consumer2.getID()), + mainThreadExecutor) + .join(); + ioExecutor.triggerAll(); + + assertThat(partitionTracker.releasedPartitions.size()).isEqualTo(parallelism); + for (int i = 0; i < parallelism; ++i) { + ExecutionJobVertex ejv = checkNotNull(executionGraph.getJobVertex(producer.getID())); + assertThat( + partitionTracker.releasedPartitions.stream() + .map(ResultPartitionID::getPartitionId)) + .containsExactlyInAnyOrder( + Arrays.stream(ejv.getProducedDataSets()[0].getPartitions()) + .map(IntermediateResultPartition::getPartitionId) + .toArray(IntermediateResultPartitionID[]::new)); + } + } + + private static class TestingPartitionTracker extends NoOpJobMasterPartitionTracker { + + private final List releasedPartitions = new ArrayList<>(); + + @Override + public void stopTrackingAndReleasePartitions( + Collection resultPartitionIds, boolean releaseOnShuffleMaster) { + releasedPartitions.addAll(checkNotNull(resultPartitionIds)); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java index 9245fa8f96688..588a91c6fd7d7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java @@ -24,6 +24,7 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; @@ -48,6 +49,7 @@ import java.util.Set; import java.util.concurrent.ScheduledExecutorService; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -261,6 +263,45 @@ void testSetupInputSplits() throws Exception { assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2); } + @Test + void testMultiConsumersForOneIntermediateResult() throws Exception { + JobVertex v1 = new JobVertex("vertex1"); + JobVertex v2 = new JobVertex("vertex2"); + JobVertex v3 = new JobVertex("vertex3"); + + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + v2.connectNewDataSetAsInput( + v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false); + v3.connectNewDataSetAsInput( + v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false); + + List vertices = new ArrayList<>(Arrays.asList(v1, v2, v3)); + ExecutionGraph eg = createDefaultExecutionGraph(vertices); + eg.attachJobGraph(vertices); + + ExecutionJobVertex ejv1 = 
checkNotNull(eg.getJobVertex(v1.getID())); + assertThat(ejv1.getProducedDataSets()).hasSize(1); + assertThat(ejv1.getProducedDataSets()[0].getId()).isEqualTo(dataSetId); + + ExecutionJobVertex ejv2 = checkNotNull(eg.getJobVertex(v2.getID())); + assertThat(ejv2.getInputs()).hasSize(1); + assertThat(ejv2.getInputs().get(0).getId()).isEqualTo(dataSetId); + + ExecutionJobVertex ejv3 = checkNotNull(eg.getJobVertex(v3.getID())); + assertThat(ejv3.getInputs()).hasSize(1); + assertThat(ejv3.getInputs().get(0).getId()).isEqualTo(dataSetId); + + List<ConsumedPartitionGroup> partitionGroups1 = + ejv2.getTaskVertices()[0].getAllConsumedPartitionGroups(); + assertThat(partitionGroups1).hasSize(1); + assertThat(partitionGroups1.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId); + + List<ConsumedPartitionGroup> partitionGroups2 = + ejv3.getTaskVertices()[0].getAllConsumedPartitionGroups(); + assertThat(partitionGroups2).hasSize(1); + assertThat(partitionGroups2.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId); + } + @Test void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception { JobVertex v1 = new JobVertex("source"); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java index e3b603a8790dd..28b44595a26b6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java @@ -85,7 +85,7 @@ private void testGetMaxNumEdgesToTarget( IntermediateResultPartition partition = ev.getProducedPartitions().values().iterator().next(); - ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroup(); + ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroups().get(0); int actual = consumerVertexGroup.size(); if (actual > actualMaxForUpstream) { actualMaxForUpstream = actual; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java index 81165da61933d..ad94749e7e84f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java @@ -34,9 +34,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL; +import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link EdgeManager}. 
*/ @@ -50,23 +56,7 @@ class EdgeManagerTest { void testGetConsumedPartitionGroup() throws Exception { JobVertex v1 = new JobVertex("source"); JobVertex v2 = new JobVertex("sink"); - - v1.setParallelism(2); - v2.setParallelism(2); - - v1.setInvokableClass(NoOpInvokable.class); - v2.setInvokableClass(NoOpInvokable.class); - - v2.connectNewDataSetAsInput( - v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); - - JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(v1, v2); - SchedulerBase scheduler = - SchedulerTestingUtils.createScheduler( - jobGraph, - ComponentMainThreadExecutorServiceAdapter.forMainThread(), - EXECUTOR_RESOURCE.getExecutor()); - ExecutionGraph eg = scheduler.getExecutionGraph(); + ExecutionGraph eg = buildExecutionGraph(v1, v2, 2, 2, ALL_TO_ALL); ConsumedPartitionGroup groupRetrievedByDownstreamVertex = Objects.requireNonNull(eg.getJobVertex(v2.getID())) @@ -86,9 +76,7 @@ void testGetConsumedPartitionGroup() throws Exception { .isEqualTo(groupRetrievedByDownstreamVertex); ConsumedPartitionGroup groupRetrievedByScheduledResultPartition = - scheduler - .getExecutionGraph() - .getSchedulingTopology() + eg.getSchedulingTopology() .getResultPartition(consumedPartition.getPartitionId()) .getConsumedPartitionGroups() .get(0); @@ -96,4 +84,64 @@ void testGetConsumedPartitionGroup() throws Exception { assertThat(groupRetrievedByScheduledResultPartition) .isEqualTo(groupRetrievedByDownstreamVertex); } + + @Test + public void testCalculateNumberOfConsumers() throws Exception { + testCalculateNumberOfConsumers(5, 2, ALL_TO_ALL, new int[] {2, 2}); + testCalculateNumberOfConsumers(5, 2, POINTWISE, new int[] {1, 1}); + testCalculateNumberOfConsumers(2, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5}); + testCalculateNumberOfConsumers(2, 5, POINTWISE, new int[] {3, 3, 3, 2, 2}); + testCalculateNumberOfConsumers(5, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5}); + testCalculateNumberOfConsumers(5, 5, POINTWISE, new int[] {1, 1, 1, 1, 1}); + } + + private void testCalculateNumberOfConsumers( + int producerParallelism, + int consumerParallelism, + DistributionPattern distributionPattern, + int[] expectedConsumers) + throws Exception { + JobVertex producer = new JobVertex("producer"); + JobVertex consumer = new JobVertex("consumer"); + ExecutionGraph eg = + buildExecutionGraph( + producer, + consumer, + producerParallelism, + consumerParallelism, + distributionPattern); + List partitionGroups = + Arrays.stream(checkNotNull(eg.getJobVertex(consumer.getID())).getTaskVertices()) + .flatMap(ev -> ev.getAllConsumedPartitionGroups().stream()) + .collect(Collectors.toList()); + int index = 0; + for (ConsumedPartitionGroup partitionGroup : partitionGroups) { + assertThat(partitionGroup.getNumConsumers()).isEqualTo(expectedConsumers[index++]); + } + } + + private ExecutionGraph buildExecutionGraph( + JobVertex producer, + JobVertex consumer, + int producerParallelism, + int consumerParallelism, + DistributionPattern distributionPattern) + throws Exception { + producer.setParallelism(producerParallelism); + consumer.setParallelism(consumerParallelism); + + producer.setInvokableClass(NoOpInvokable.class); + consumer.setInvokableClass(NoOpInvokable.class); + + consumer.connectNewDataSetAsInput( + producer, distributionPattern, ResultPartitionType.BLOCKING); + + JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(producer, consumer); + SchedulerBase scheduler = + SchedulerTestingUtils.createScheduler( + jobGraph, + ComponentMainThreadExecutorServiceAdapter.forMainThread(), + 
EXECUTOR_RESOURCE.getExecutor()); + return scheduler.getExecutionGraph(); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java index ff2baf807e178..a9e7df9470d2b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.scheduler.SchedulerBase; import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.runtime.testutils.DirectScheduledExecutorService; @@ -43,6 +44,7 @@ import java.lang.reflect.Field; import java.time.Duration; import java.util.List; +import java.util.Objects; import java.util.Random; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeoutException; @@ -249,6 +251,23 @@ public static void completeCancellingForAllVertices(ExecutionGraph eg) { } } + public static void finishJobVertex(ExecutionGraph executionGraph, JobVertexID jobVertexId) { + for (ExecutionVertex vertex : + Objects.requireNonNull(executionGraph.getJobVertex(jobVertexId)) + .getTaskVertices()) { + finishExecutionVertex(executionGraph, vertex); + } + } + + public static void finishExecutionVertex( + ExecutionGraph executionGraph, ExecutionVertex executionVertex) { + executionGraph.updateState( + new TaskExecutionStateTransition( + new TaskExecutionState( + executionVertex.getCurrentExecutionAttempt().getAttemptId(), + ExecutionState.FINISHED))); + } + /** * Takes all vertices in the given ExecutionGraph and switches their current execution to * FINISHED. 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java index d847e9ede68df..dda1659675178 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java @@ -170,7 +170,7 @@ public static ExecutionJobVertex createDynamicExecutionJobVertex( int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception { JobVertex jobVertex = new JobVertex("testVertex"); jobVertex.setInvokableClass(AbstractInvokable.class); - jobVertex.createAndAddResultDataSet( + jobVertex.getOrCreateResultDataSet( new IntermediateDataSetID(), ResultPartitionType.BLOCKING); if (maxParallelism > 0) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java index 5ff5d9f201cd1..aa906caf6749b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java @@ -22,6 +22,7 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobGraphBuilder; import org.apache.flink.runtime.jobgraph.JobGraphTestUtils; @@ -151,6 +152,32 @@ void testBlockingPartitionResetting() throws Exception { assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse(); } + @Test + void testReleasePartitionGroups() throws Exception { + IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2); + + IntermediateResultPartition partition1 = result.getPartitions()[0]; + IntermediateResultPartition partition2 = result.getPartitions()[1]; + assertThat(partition1.canBeReleased()).isFalse(); + assertThat(partition2.canBeReleased()).isFalse(); + + List<ConsumedPartitionGroup> consumedPartitionGroup1 = + partition1.getConsumedPartitionGroups(); + List<ConsumedPartitionGroup> consumedPartitionGroup2 = + partition2.getConsumedPartitionGroups(); + assertThat(consumedPartitionGroup1).isEqualTo(consumedPartitionGroup2); + + assertThat(consumedPartitionGroup1).hasSize(2); + partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(0)); + assertThat(partition1.canBeReleased()).isFalse(); + + partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(1)); + assertThat(partition1.canBeReleased()).isTrue(); + + result.resetForNewExecution(); + assertThat(partition1.canBeReleased()).isFalse(); + } + @Test void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception { testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7)); @@ -245,11 +272,24 @@ public static ExecutionGraph createExecutionGraph( v2.setMaxParallelism(consumerMaxParallelism); } - v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING); + final JobVertex v3 = new JobVertex("v3"); + v3.setInvokableClass(NoOpInvokable.class); + if (consumerParallelism > 0) { + 
v3.setParallelism(consumerParallelism); + } + if (consumerMaxParallelism > 0) { + v3.setMaxParallelism(consumerMaxParallelism); + } + + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + v2.connectNewDataSetAsInput( + v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false); + v3.connectNewDataSetAsInput( + v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false); final JobGraph jobGraph = JobGraphBuilder.newBatchJobGraphBuilder() - .addJobVertices(Arrays.asList(v1, v2)) + .addJobVertices(Arrays.asList(v1, v2, v3)) .build(); final Configuration configuration = new Configuration(); @@ -287,15 +327,23 @@ private static IntermediateResult createResult( source.setInvokableClass(NoOpInvokable.class); source.setParallelism(parallelism); - JobVertex sink = new JobVertex("v2"); - sink.setInvokableClass(NoOpInvokable.class); - sink.setParallelism(parallelism); + JobVertex sink1 = new JobVertex("v2"); + sink1.setInvokableClass(NoOpInvokable.class); + sink1.setParallelism(parallelism); + + JobVertex sink2 = new JobVertex("v3"); + sink2.setInvokableClass(NoOpInvokable.class); + sink2.setParallelism(parallelism); - sink.connectNewDataSetAsInput(source, DistributionPattern.ALL_TO_ALL, resultPartitionType); + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + sink1.connectNewDataSetAsInput( + source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false); + sink2.connectNewDataSetAsInput( + source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false); ScheduledExecutorService executorService = new DirectScheduledExecutorService(); - JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink); + JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink1, sink2); SchedulerBase scheduler = new DefaultSchedulerBuilder( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java index 7c00da22a856f..4bf976c76b0bf 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.executiongraph; import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.blob.BlobWriter; import org.apache.flink.runtime.blob.TestingBlobWriter; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; @@ -27,21 +26,14 @@ import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy; +import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; -import org.apache.flink.runtime.jobgraph.JobGraph; -import org.apache.flink.runtime.jobgraph.JobGraphBuilder; import org.apache.flink.runtime.jobgraph.JobVertex; 
-import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.jobmaster.LogicalSlot; -import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder; -import org.apache.flink.runtime.scheduler.DefaultScheduler; -import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder; +import org.apache.flink.runtime.scheduler.SchedulerBase; +import org.apache.flink.runtime.scheduler.SchedulerTestingUtils; import org.apache.flink.runtime.shuffle.ShuffleDescriptor; -import org.apache.flink.runtime.taskmanager.TaskExecutionState; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; @@ -50,17 +42,16 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeoutException; import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors; +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishExecutionVertex; +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex; import static org.assertj.core.api.Assertions.assertThat; /** @@ -119,7 +110,7 @@ private void testRemoveCacheForAllToAllEdgeAfterFinished( final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM); final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM); - final DefaultScheduler scheduler = + final SchedulerBase scheduler = createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter); final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); @@ -132,8 +123,7 @@ private void testRemoveCacheForAllToAllEdgeAfterFinished( // For the all-to-all edge, we transition all downstream tasks to finished CompletableFuture.runAsync( - () -> transitionTasksToFinished(executionGraph, v2.getID()), - mainThreadExecutor) + () -> finishJobVertex(executionGraph, v2.getID()), mainThreadExecutor) .join(); ioExecutor.triggerAll(); @@ -164,7 +154,7 @@ private void testRemoveCacheForAllToAllEdgeAfterFailover( final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM); final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM); - final DefaultScheduler scheduler = + final SchedulerBase scheduler = createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter); final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); @@ -206,7 +196,7 @@ private void testRemoveCacheForPointwiseEdgeAfterFinished( final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM); final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM); - final DefaultScheduler scheduler = + final SchedulerBase scheduler = createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter); final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); @@ -222,7 +212,7 @@ private void testRemoveCacheForPointwiseEdgeAfterFinished( Objects.requireNonNull(executionGraph.getJobVertex(v2.getID())) .getTaskVertices()[0]; CompletableFuture.runAsync( - () -> transitionTaskToFinished(executionGraph, ev21), mainThreadExecutor) + () -> 
finishExecutionVertex(executionGraph, ev21), mainThreadExecutor) .join(); ioExecutor.triggerAll(); @@ -263,7 +253,7 @@ private void testRemoveCacheForPointwiseEdgeAfterFailover( final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM); final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM); - final DefaultScheduler scheduler = + final SchedulerBase scheduler = createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter); final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); @@ -292,42 +282,27 @@ private void testRemoveCacheForPointwiseEdgeAfterFailover( assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter); } - private DefaultScheduler createSchedulerAndDeploy( + private SchedulerBase createSchedulerAndDeploy( JobID jobId, JobVertex v1, JobVertex v2, DistributionPattern distributionPattern, BlobWriter blobWriter) throws Exception { - - v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING); - - final List ordered = new ArrayList<>(Arrays.asList(v1, v2)); - final DefaultScheduler scheduler = - createScheduler(jobId, ordered, blobWriter, mainThreadExecutor, ioExecutor); - final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); - final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder(); - - CompletableFuture.runAsync( - () -> { - try { - // Deploy upstream source vertices - deployTasks(executionGraph, v1.getID(), slotBuilder); - // Transition upstream vertices into FINISHED - transitionTasksToFinished(executionGraph, v1.getID()); - // Deploy downstream sink vertices - deployTasks(executionGraph, v2.getID(), slotBuilder); - } catch (Exception e) { - throw new RuntimeException("Exceptions shouldn't happen here.", e); - } - }, - mainThreadExecutor) - .join(); - - return scheduler; + return SchedulerTestingUtils.createSchedulerAndDeploy( + false, + jobId, + v1, + new JobVertex[] {v2}, + distributionPattern, + blobWriter, + mainThreadExecutor, + ioExecutor, + NoOpJobMasterPartitionTracker.INSTANCE, + EXECUTOR_RESOURCE.getExecutor()); } - private void triggerGlobalFailoverAndComplete(DefaultScheduler scheduler, JobVertex upstream) + private void triggerGlobalFailoverAndComplete(SchedulerBase scheduler, JobVertex upstream) throws TimeoutException { final Throwable t = new Exception(); @@ -378,66 +353,6 @@ private void triggerExceptionAndComplete( // ============== Utils ============== - private static DefaultScheduler createScheduler( - final JobID jobId, - final List jobVertices, - final BlobWriter blobWriter, - final ComponentMainThreadExecutor mainThreadExecutor, - final ScheduledExecutorService ioExecutor) - throws Exception { - final JobGraph jobGraph = - JobGraphBuilder.newBatchJobGraphBuilder() - .setJobId(jobId) - .addJobVertices(jobVertices) - .build(); - - return new DefaultSchedulerBuilder( - jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor()) - .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0)) - .setBlobWriter(blobWriter) - .setIoExecutor(ioExecutor) - .build(); - } - - private static void deployTasks( - ExecutionGraph executionGraph, - JobVertexID jobVertexID, - TestingLogicalSlotBuilder slotBuilder) - throws JobException, ExecutionException, InterruptedException { - - for (ExecutionVertex vertex : - Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID)) - .getTaskVertices()) { - LogicalSlot slot = slotBuilder.createTestingLogicalSlot(); - - Execution execution = 
vertex.getCurrentExecutionAttempt(); - execution.registerProducedPartitions(slot.getTaskManagerLocation()).get(); - execution.transitionState(ExecutionState.SCHEDULED); - - vertex.tryAssignResource(slot); - vertex.deploy(); - } - } - - private static void transitionTasksToFinished( - ExecutionGraph executionGraph, JobVertexID jobVertexID) { - - for (ExecutionVertex vertex : - Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID)) - .getTaskVertices()) { - transitionTaskToFinished(executionGraph, vertex); - } - } - - private static void transitionTaskToFinished( - ExecutionGraph executionGraph, ExecutionVertex executionVertex) { - executionGraph.updateState( - new TaskExecutionStateTransition( - new TaskExecutionState( - executionVertex.getCurrentExecutionAttempt().getAttemptId(), - ExecutionState.FINISHED))); - } - private static MaybeOffloaded getConsumedCachedShuffleDescriptor( ExecutionGraph executionGraph, JobVertex vertex) { return getConsumedCachedShuffleDescriptor(executionGraph, vertex, 0); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java index 087a29613e148..130212bcf059a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java @@ -28,8 +28,9 @@ import java.util.List; /** No-op implementation of {@link JobMasterPartitionTracker}. */ -public enum NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker { - INSTANCE; +public class NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker { + public static final NoOpJobMasterPartitionTracker INSTANCE = + new NoOpJobMasterPartitionTracker(); public static final PartitionTrackerFactory FACTORY = lookup -> INSTANCE; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java index 66e097b34b15d..cf59753999f55 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.net.URL; import java.net.URLClassLoader; +import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -41,6 +42,45 @@ @SuppressWarnings("serial") class JobTaskVertexTest { + @Test + void testMultipleConsumersVertices() { + JobVertex producer = new JobVertex("producer"); + JobVertex consumer1 = new JobVertex("consumer1"); + JobVertex consumer2 = new JobVertex("consumer2"); + + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + consumer1.connectNewDataSetAsInput( + producer, + DistributionPattern.ALL_TO_ALL, + ResultPartitionType.BLOCKING, + dataSetId, + false); + consumer2.connectNewDataSetAsInput( + producer, + DistributionPattern.ALL_TO_ALL, + ResultPartitionType.BLOCKING, + dataSetId, + false); + + JobVertex consumer3 = new JobVertex("consumer3"); + consumer3.connectNewDataSetAsInput( + producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); + + assertThat(producer.getProducedDataSets()).hasSize(2); + + IntermediateDataSet dataSet = 
producer.getProducedDataSets().get(0); + assertThat(dataSet.getId()).isEqualTo(dataSetId); + + List consumers1 = dataSet.getConsumers(); + assertThat(consumers1).hasSize(2); + assertThat(consumers1.get(0).getTarget().getID()).isEqualTo(consumer1.getID()); + assertThat(consumers1.get(1).getTarget().getID()).isEqualTo(consumer2.getID()); + + List consumers2 = producer.getProducedDataSets().get(1).getConsumers(); + assertThat(consumers2).hasSize(1); + assertThat(consumers2.get(0).getTarget().getID()).isEqualTo(consumer3.getID()); + } + @Test void testConnectDirectly() { JobVertex source = new JobVertex("source"); @@ -59,7 +99,8 @@ void testConnectDirectly() { assertThat(source.getProducedDataSets().get(0)) .isEqualTo(target.getInputs().get(0).getSource()); - assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target); + assertThat(source.getProducedDataSets().get(0).getConsumers().get(0).getTarget()) + .isEqualTo(target); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java index 7f9488fe90169..a0ce9132e3f39 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java @@ -206,7 +206,8 @@ private JobGraph createFirstJobGraph( sender, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING_PERSISTENT, - intermediateDataSetID); + intermediateDataSetID, + false); return new JobGraph(null, "First Job", sender, receiver); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java index eb09a7e139c5a..b3954d2ea075d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java @@ -20,6 +20,8 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; +import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.blob.BlobWriter; import org.apache.flink.runtime.checkpoint.CheckpointCoordinator; import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; @@ -28,12 +30,24 @@ import org.apache.flink.runtime.checkpoint.PendingCheckpoint; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy; +import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobGraphBuilder; 
+import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; +import org.apache.flink.runtime.jobmaster.LogicalSlot; +import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder; import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProvider; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; @@ -46,12 +60,18 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @@ -297,4 +317,121 @@ public static ExecutionAttemptID getAttemptId( allocationTimeout, new LocalInputPreferredSlotSharingStrategy.Factory()); } + + public static SchedulerBase createSchedulerAndDeploy( + boolean isAdaptive, + JobID jobId, + JobVertex producer, + JobVertex[] consumers, + DistributionPattern distributionPattern, + BlobWriter blobWriter, + ComponentMainThreadExecutor mainThreadExecutor, + ScheduledExecutorService ioExecutor, + JobMasterPartitionTracker partitionTracker, + ScheduledExecutorService scheduledExecutor) + throws Exception { + final List vertices = new ArrayList<>(Collections.singletonList(producer)); + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + for (JobVertex consumer : consumers) { + consumer.connectNewDataSetAsInput( + producer, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false); + vertices.add(consumer); + } + + final SchedulerBase scheduler = + createScheduler( + isAdaptive, + jobId, + vertices, + blobWriter, + mainThreadExecutor, + ioExecutor, + partitionTracker, + scheduledExecutor); + final ExecutionGraph executionGraph = scheduler.getExecutionGraph(); + final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder(); + + CompletableFuture.runAsync( + () -> { + try { + if (isAdaptive) { + initializeExecutionJobVertex(producer.getID(), executionGraph); + } + // Deploy upstream source vertices + deployTasks(executionGraph, producer.getID(), slotBuilder); + // Transition upstream vertices into FINISHED + finishJobVertex(executionGraph, producer.getID()); + // Deploy downstream sink vertices + for (JobVertex consumer : consumers) { + if (isAdaptive) { + initializeExecutionJobVertex( + consumer.getID(), executionGraph); + } + deployTasks(executionGraph, consumer.getID(), slotBuilder); + } + } catch (Exception e) { + throw new RuntimeException("Exceptions shouldn't happen here.", e); + } + }, + mainThreadExecutor) + .join(); + return scheduler; + } + + private static void initializeExecutionJobVertex( + JobVertexID jobVertex, ExecutionGraph executionGraph) { + try { + executionGraph.initializeJobVertex( + executionGraph.getJobVertex(jobVertex), System.currentTimeMillis()); + executionGraph.notifyNewlyInitializedJobVertices( + 
Collections.singletonList(executionGraph.getJobVertex(jobVertex))); + } catch (JobException exception) { + throw new RuntimeException(exception); + } + } + + private static DefaultScheduler createScheduler( + boolean isAdaptive, + JobID jobId, + List jobVertices, + BlobWriter blobWriter, + ComponentMainThreadExecutor mainThreadExecutor, + ScheduledExecutorService ioExecutor, + JobMasterPartitionTracker partitionTracker, + ScheduledExecutorService scheduledExecutor) + throws Exception { + final JobGraph jobGraph = + JobGraphBuilder.newBatchJobGraphBuilder() + .setJobId(jobId) + .addJobVertices(jobVertices) + .build(); + + final DefaultSchedulerBuilder builder = + new DefaultSchedulerBuilder(jobGraph, mainThreadExecutor, scheduledExecutor) + .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0)) + .setBlobWriter(blobWriter) + .setIoExecutor(ioExecutor) + .setPartitionTracker(partitionTracker); + return isAdaptive ? builder.buildAdaptiveBatchJobScheduler() : builder.build(); + } + + private static void deployTasks( + ExecutionGraph executionGraph, + JobVertexID jobVertexID, + TestingLogicalSlotBuilder slotBuilder) + throws JobException, ExecutionException, InterruptedException { + + for (ExecutionVertex vertex : + Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID)) + .getTaskVertices()) { + LogicalSlot slot = slotBuilder.createTestingLogicalSlot(); + + Execution execution = vertex.getCurrentExecutionAttempt(); + execution.registerProducedPartitions(slot.getTaskManagerLocation()).get(); + execution.transitionState(ExecutionState.SCHEDULED); + + vertex.tryAssignResource(slot); + vertex.deploy(); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java index 4e1d9cb76a4ae..f22c3b23dd7a9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java @@ -52,7 +52,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -314,17 +313,23 @@ private static void assertPartitionsEquals( assertPartitionEquals(originalPartition, adaptedPartition); - ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup(); - Optional adaptedConsumers = - adaptedPartition.getConsumerVertexGroup(); - assertThat(adaptedConsumers).isPresent(); - for (ExecutionVertexID originalId : consumerVertexGroup) { + List originalConsumerIds = new ArrayList<>(); + for (ConsumerVertexGroup consumerVertexGroup : + originalPartition.getConsumerVertexGroups()) { + for (ExecutionVertexID executionVertexId : consumerVertexGroup) { + originalConsumerIds.add(executionVertexId); + } + } + List adaptedConsumers = adaptedPartition.getConsumerVertexGroups(); + assertThat(adaptedConsumers).isNotEmpty(); + for (ExecutionVertexID originalId : originalConsumerIds) { // it is sufficient to verify that some vertex exists with the correct ID here, // since deep equality is verified later in the main loop // this DOES rely on an implicit assumption that the vertices objects returned by // the topology are // identical to those stored in the partition - assertThat(adaptedConsumers.get()).contains(originalId); + 
assertThat(adaptedConsumers.stream().flatMap(IterableUtils::toStream)) .contains(originalId); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java index fc2df69746410..d6fb9ace1a385 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java @@ -81,6 +81,7 @@ void setUp() throws Exception { List<ConsumedPartitionGroup> consumedPartitionGroups = Collections.singletonList( ConsumedPartitionGroup.fromSinglePartition( + 1, intermediateResultPartitionId, schedulingResultPartition.getResultType())); Map resultPartitionById = diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java index 6d0024626c18a..9f8c58400e30b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java @@ -28,7 +28,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Supplier; @@ -47,8 +50,8 @@ class DefaultResultPartitionTest { private DefaultResultPartition resultPartition; - private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups = - new HashMap<>(); + private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>> + consumerVertexGroups = new HashMap<>(); @BeforeEach void setUp() { @@ -58,7 +61,9 @@ void setUp() { intermediateResultId, BLOCKING, resultPartitionState, - () -> consumerVertexGroups.get(resultPartitionId), + () -> + consumerVertexGroups.computeIfAbsent( + resultPartitionId, ignored -> new ArrayList<>()), () -> { throw new UnsupportedOperationException(); }); @@ -75,14 +80,15 @@ void testGetPartitionState() { @Test void testGetConsumerVertexGroup() { - assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent(); + assertThat(resultPartition.getConsumerVertexGroups()).isEmpty(); // test update consumers ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0); consumerVertexGroups.put( - resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId)); - assertThat(resultPartition.getConsumerVertexGroup()).isPresent(); - assertThat(resultPartition.getConsumerVertexGroup().get()).contains(executionVertexId); + resultPartition.getId(), + Collections.singletonList(ConsumerVertexGroup.fromSingleVertex(executionVertexId))); + assertThat(resultPartition.getConsumerVertexGroups()).isNotEmpty(); + assertThat(resultPartition.getConsumerVertexGroups().get(0)).contains(executionVertexId); } /** A test {@link ResultPartitionState} supplier. 
*/ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java index 3c9bed61354fe..a6e87f9bfc2f8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java @@ -166,7 +166,7 @@ public JobGraph createJobGraph(boolean withForwardEdge) { sink.connectNewDataSetAsInput( source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); if (withForwardEdge) { - source1.getProducedDataSets().get(0).getConsumer().setForward(true); + source1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); } return new JobGraph(new JobID(), "test job", source1, source2, sink); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java index b45ba3f5ebec7..e9b0da1509de8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java @@ -97,14 +97,14 @@ private void testThreeVerticesConnectSequentially( v2.connectNewDataSetAsInput( v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); if (isForward1) { - v1.getProducedDataSets().get(0).getConsumer().setForward(true); + v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); } v3.connectNewDataSetAsInput( v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); if (isForward2) { - v2.getProducedDataSets().get(0).getConsumer().setForward(true); + v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); } Set groups = computeForwardGroups(v1, v2, v3); @@ -135,10 +135,10 @@ void testTwoInputsMergesIntoOne() throws Exception { v3.connectNewDataSetAsInput( v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); - v1.getProducedDataSets().get(0).getConsumer().setForward(true); + v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); v3.connectNewDataSetAsInput( v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); - v2.getProducedDataSets().get(0).getConsumer().setForward(true); + v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); v4.connectNewDataSetAsInput( v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); @@ -174,8 +174,8 @@ void testOneInputSplitsIntoTwo() throws Exception { v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); v4.connectNewDataSetAsInput( v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); - v2.getProducedDataSets().get(0).getConsumer().setForward(true); - v2.getProducedDataSets().get(1).getConsumer().setForward(true); + v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); + v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true); Set groups = computeForwardGroups(v1, v2, v3, v4); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java index 4621680d3604d..655f725874c71 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java @@ -93,7 +93,9 @@ public List getConsumedPartitionGroups() { void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) { final ConsumedPartitionGroup consumedPartitionGroup = ConsumedPartitionGroup.fromSinglePartition( - consumedPartition.getId(), consumedPartition.getResultType()); + consumedPartition.getNumConsumers(), + consumedPartition.getId(), + consumedPartition.getResultType()); consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup); if (consumedPartition.getState() == ResultPartitionState.CONSUMABLE) { @@ -155,7 +157,8 @@ public Builder withConsumedPartitionGroups( partitionIds.add(partitionId); } this.consumedPartitionGroups.add( - ConsumedPartitionGroup.fromMultiplePartitions(partitionIds, resultType)); + ConsumedPartitionGroup.fromMultiplePartitions( + partitionGroup.getNumConsumers(), partitionIds, resultType)); } return this; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java index 9454268cb5fb8..6274eecb285c0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java @@ -28,7 +28,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -64,6 +63,10 @@ private TestingSchedulingResultPartition( this.consumedPartitionGroups = new ArrayList<>(); } + public int getNumConsumers() { + return consumerVertexGroup == null ? 
1 : consumerVertexGroup.size(); + } + @Override public IntermediateResultPartitionID getId() { return intermediateResultPartitionID; } @@ -90,8 +93,8 @@ public TestingSchedulingExecutionVertex getProducer() { } @Override - public Optional<ConsumerVertexGroup> getConsumerVertexGroup() { - return Optional.of(consumerVertexGroup); + public List<ConsumerVertexGroup> getConsumerVertexGroups() { + return Collections.singletonList(consumerVertexGroup); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java index 82463a1a2cb8d..95596369e3399 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java @@ -350,6 +350,7 @@ protected List connect() { ConsumedPartitionGroup consumedPartitionGroup = ConsumedPartitionGroup.fromMultiplePartitions( + consumers.size(), resultPartitions.stream() .map(TestingSchedulingResultPartition::getId) .collect(Collectors.toList()), diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index be72cacff5c7a..4cf9ed7c9908d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -1079,19 +1079,20 @@ private void connect(Integer headOfChain, StreamEdge edge) { headVertex, DistributionPattern.POINTWISE, resultPartitionType, - intermediateDataSetID); + intermediateDataSetID, + partitioner.isBroadcast()); } else { jobEdge = downStreamVertex.connectNewDataSetAsInput( headVertex, DistributionPattern.ALL_TO_ALL, resultPartitionType, - intermediateDataSetID); + intermediateDataSetID, + partitioner.isBroadcast()); } // set strategy name so that web interface can show it. 
jobEdge.setShipStrategyName(partitioner.toString()); - jobEdge.setBroadcast(partitioner.isBroadcast()); jobEdge.setForward(partitioner instanceof ForwardPartitioner); jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper()); jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper()); diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java index ab50814edd258..137e5040383d2 100644 --- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java +++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java @@ -55,24 +55,26 @@ public static void checkDataFlow(TestJobWithDescription testJob, boolean withDra for (JobVertex upstream : testJob.jobGraph.getVertices()) { for (IntermediateDataSet produced : upstream.getProducedDataSets()) { - JobEdge edge = produced.getConsumer(); - Optional<String> upstreamIDOptional = getTrackedOperatorID(upstream, true, testJob); - Optional<String> downstreamIDOptional = - getTrackedOperatorID(edge.getTarget(), false, testJob); - if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) { - final String upstreamID = upstreamIDOptional.get(); - final String downstreamID = downstreamIDOptional.get(); - if (testJob.sources.contains(upstreamID)) { - // TODO: if we add tests for FLIP-27 sources we might need to adjust - // this condition - LOG.debug( - "Legacy sources do not have the finish() method and thus do not" - + " emit FinishEvent"); + for (JobEdge edge : produced.getConsumers()) { + Optional<String> upstreamIDOptional = + getTrackedOperatorID(upstream, true, testJob); + Optional<String> downstreamIDOptional = + getTrackedOperatorID(edge.getTarget(), false, testJob); + if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) { + final String upstreamID = upstreamIDOptional.get(); + final String downstreamID = downstreamIDOptional.get(); + if (testJob.sources.contains(upstreamID)) { + // TODO: if we add tests for FLIP-27 sources we might need to adjust + // this condition + LOG.debug( + "Legacy sources do not have the finish() method and thus do not" + + " emit FinishEvent"); + } else { + checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain); + } } else { - checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain); + LOG.debug("Ignoring edge (untracked operator): {}", edge); } - } else { - LOG.debug("Ignoring edge (untracked operator): {}", edge); } } }