Skip to content

Commit

Permalink
[FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side
Browse files Browse the repository at this point in the history

This closes apache#20350.
  • Loading branch information
wsry committed Aug 8, 2022
1 parent b3be6bb commit 7240536
Show file tree
Hide file tree
Showing 43 changed files with 958 additions and 404 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ private boolean checkAndConfigurePersistentIntermediateResult(PlanNode node) {
predecessorVertex != null,
"Bug: Chained task has not been assigned its containing vertex when connecting.");

predecessorVertex.createAndAddResultDataSet(
predecessorVertex.getOrCreateResultDataSet(
// use specified intermediateDataSetID
new IntermediateDataSetID(
((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()),
Expand Down Expand Up @@ -1326,7 +1326,7 @@ private DistributionPattern connectJobVertices(

JobEdge edge =
targetVertex.connectNewDataSetAsInput(
sourceVertex, distributionPattern, resultType);
sourceVertex, distributionPattern, resultType, isBroadcast);

// -------------- configure the source task's ship strategy strategies in task config
// --------------
Expand Down Expand Up @@ -1403,7 +1403,6 @@ private DistributionPattern connectJobVertices(
channel.getTempMode() == TempMode.NONE ? null : channel.getTempMode().toString();

edge.setShipStrategyName(shipStrategy);
edge.setBroadcast(isBroadcast);
edge.setForward(channel.getShipStrategy() == ShipStrategyType.FORWARD);
edge.setPreProcessingOperationName(localStrategy);
edge.setOperatorLevelCachingDescription(caching);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
SubpartitionIndexRange consumedSubpartitionRange =
computeConsumedSubpartitionRange(
resultPartition, executionId.getSubtaskIndex());
consumedPartitionGroup.getNumConsumers(),
resultPartition,
executionId.getSubtaskIndex());

IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
Expand Down Expand Up @@ -164,8 +166,9 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
}

public static SubpartitionIndexRange computeConsumedSubpartitionRange(
IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
int numConsumers = resultPartition.getConsumerVertexGroup().size();
int numConsumers,
IntermediateResultPartition resultPartition,
int consumerSubtaskIndex) {
int consumerIndex = consumerSubtaskIndex % numConsumers;
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
int numSubpartitions = resultPartition.getNumberOfSubpartitions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1399,22 +1399,25 @@ private void releasePartitionGroups(
final List<ConsumedPartitionGroup> releasablePartitionGroups) {

if (releasablePartitionGroups.size() > 0) {
final List<ResultPartitionID> releasablePartitionIds = new ArrayList<>();

// Remove the cache of ShuffleDescriptors when ConsumedPartitionGroups are released
for (ConsumedPartitionGroup releasablePartitionGroup : releasablePartitionGroups) {
IntermediateResult totalResult =
checkNotNull(
intermediateResults.get(
releasablePartitionGroup.getIntermediateDataSetID()));
for (IntermediateResultPartitionID partitionId : releasablePartitionGroup) {
IntermediateResultPartition partition =
totalResult.getPartitionById(partitionId);
partition.markPartitionGroupReleasable(releasablePartitionGroup);
if (partition.canBeReleased()) {
releasablePartitionIds.add(createResultPartitionId(partitionId));
}
}
totalResult.clearCachedInformationForPartitionGroup(releasablePartitionGroup);
}

final List<ResultPartitionID> releasablePartitionIds =
releasablePartitionGroups.stream()
.flatMap(IterableUtils::toStream)
.map(this::createResultPartitionId)
.collect(Collectors.toList());

partitionTracker.stopTrackingAndReleasePartitions(releasablePartitionIds);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@
import java.util.Map;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/** Class that manages all the connections between tasks. */
public class EdgeManager {

private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> partitionConsumers =
private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>> partitionConsumers =
new HashMap<>();

private final Map<ExecutionVertexID, List<ConsumedPartitionGroup>> vertexConsumedPartitions =
Expand All @@ -50,9 +49,9 @@ public void connectPartitionWithConsumerVertexGroup(

checkNotNull(consumerVertexGroup);

checkState(
partitionConsumers.putIfAbsent(resultPartitionId, consumerVertexGroup) == null,
"Currently one partition can have at most one consumer group");
List<ConsumerVertexGroup> groups =
getConsumerVertexGroupsForPartitionInternal(resultPartitionId);
groups.add(consumerVertexGroup);
}

public void connectVertexWithConsumedPartitionGroup(
Expand All @@ -66,14 +65,20 @@ public void connectVertexWithConsumedPartitionGroup(
consumedPartitions.add(consumedPartitionGroup);
}

/**
 * Gets the mutable list of consumer vertex groups for the given partition, creating and
 * registering an empty list on first access.
 */
private List<ConsumerVertexGroup> getConsumerVertexGroupsForPartitionInternal(
        IntermediateResultPartitionID resultPartitionId) {
    List<ConsumerVertexGroup> consumerGroups = partitionConsumers.get(resultPartitionId);
    if (consumerGroups == null) {
        consumerGroups = new ArrayList<>();
        partitionConsumers.put(resultPartitionId, consumerGroups);
    }
    return consumerGroups;
}

/**
 * Gets the mutable list of consumed partition groups for the given execution vertex, creating
 * and registering an empty list on first access.
 */
private List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertexInternal(
        ExecutionVertexID executionVertexId) {
    List<ConsumedPartitionGroup> consumedGroups = vertexConsumedPartitions.get(executionVertexId);
    if (consumedGroups == null) {
        consumedGroups = new ArrayList<>();
        vertexConsumedPartitions.put(executionVertexId, consumedGroups);
    }
    return consumedGroups;
}

public ConsumerVertexGroup getConsumerVertexGroupForPartition(
public List<ConsumerVertexGroup> getConsumerVertexGroupsForPartition(
IntermediateResultPartitionID resultPartitionId) {
return partitionConsumers.get(resultPartitionId);
return Collections.unmodifiableList(
getConsumerVertexGroupsForPartitionInternal(resultPartitionId));
}

public List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertex(
Expand All @@ -100,4 +105,9 @@ public List<ConsumedPartitionGroup> getConsumedPartitionGroupsById(
return Collections.unmodifiableList(
getConsumedPartitionGroupsByIdInternal(resultPartitionId));
}

/** Returns how many consumed partition groups reference the given result partition. */
public int getNumberOfConsumedPartitionGroupsById(
        IntermediateResultPartitionID resultPartitionId) {
    final List<ConsumedPartitionGroup> consumedGroups =
            getConsumedPartitionGroupsByIdInternal(resultPartitionId);
    return consumedGroups.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private static void connectAllToAll(
.collect(Collectors.toList());
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
consumedPartitions, intermediateResult);
taskVertices.length, consumedPartitions, intermediateResult);
for (ExecutionVertex ev : taskVertices) {
ev.addConsumedPartitionGroup(consumedPartitionGroup);
}
Expand Down Expand Up @@ -122,7 +122,9 @@ private static void connectPointwise(

ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
partition.getPartitionId(), intermediateResult);
consumerVertexGroup.size(),
partition.getPartitionId(),
intermediateResult);
executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
}
} else if (sourceCount > targetCount) {
Expand All @@ -147,20 +149,19 @@ private static void connectPointwise(

ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
consumedPartitions, intermediateResult);
consumerVertexGroup.size(), consumedPartitions, intermediateResult);
executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
}
} else {
for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;

IntermediateResultPartition partition =
intermediateResult.getPartitions()[partitionNum];
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
partition.getPartitionId(), intermediateResult);

int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
end - start, partition.getPartitionId(), intermediateResult);

List<ExecutionVertexID> consumers = new ArrayList<>(end - start);

Expand All @@ -179,21 +180,23 @@ private static void connectPointwise(
}

private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
int numConsumers,
IntermediateResultPartitionID consumedPartitionId,
IntermediateResult intermediateResult) {
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromSinglePartition(
consumedPartitionId, intermediateResult.getResultType());
numConsumers, consumedPartitionId, intermediateResult.getResultType());
registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
return consumedPartitionGroup;
}

private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
int numConsumers,
List<IntermediateResultPartitionID> consumedPartitions,
IntermediateResult intermediateResult) {
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(
consumedPartitions, intermediateResult.getResultType());
numConsumers, consumedPartitions, intermediateResult.getResultType());
registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
return consumedPartitionGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -497,10 +498,7 @@ public CompletableFuture<Void> registerProducedPartitions(TaskManagerLocation lo
}

// Max parallelism across all consumers of the partition's intermediate result.
// NOTE(review): delegates to IntermediateResult#getConsumersMaxParallelism, which replaced
// the former single-consumer getConsumerExecutionJobVertex().getMaxParallelism() lookup.
private static int getPartitionMaxParallelism(IntermediateResultPartition partition) {
    return partition.getIntermediateResult().getConsumersMaxParallelism();
}

/**
Expand Down Expand Up @@ -718,31 +716,40 @@ public CompletableFuture<?> suspend() {
}

private void updatePartitionConsumers(final IntermediateResultPartition partition) {
final Optional<ConsumerVertexGroup> consumerVertexGroup =
partition.getConsumerVertexGroupOptional();
if (!consumerVertexGroup.isPresent()) {
final List<ConsumerVertexGroup> consumerVertexGroups = partition.getConsumerVertexGroups();
if (consumerVertexGroups.isEmpty()) {
return;
}
for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
final ExecutionVertex consumerVertex =
vertex.getExecutionGraphAccessor().getExecutionVertexOrThrow(consumerVertexId);
final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
final ExecutionState consumerState = consumer.getState();

// ----------------------------------------------------------------
// Consumer is recovering or running => send update message now
// Consumer is deploying => cache the partition info which would be
// sent after switching to running
// ----------------------------------------------------------------
if (consumerState == DEPLOYING
|| consumerState == RUNNING
|| consumerState == INITIALIZING) {
final PartitionInfo partitionInfo = createPartitionInfo(partition);

if (consumerState == DEPLOYING) {
consumerVertex.cachePartitionInfo(partitionInfo);
} else {
consumer.sendUpdatePartitionInfoRpcCall(Collections.singleton(partitionInfo));
final Set<ExecutionVertexID> updatedVertices = new HashSet<>();
for (ConsumerVertexGroup consumerVertexGroup : consumerVertexGroups) {
for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
if (updatedVertices.contains(consumerVertexId)) {
continue;
}

final ExecutionVertex consumerVertex =
vertex.getExecutionGraphAccessor()
.getExecutionVertexOrThrow(consumerVertexId);
final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
final ExecutionState consumerState = consumer.getState();

// ----------------------------------------------------------------
// Consumer is recovering or running => send update message now
// Consumer is deploying => cache the partition info which would be
// sent after switching to running
// ----------------------------------------------------------------
if (consumerState == DEPLOYING
|| consumerState == RUNNING
|| consumerState == INITIALIZING) {
final PartitionInfo partitionInfo = createPartitionInfo(partition);
updatedVertices.add(consumerVertexId);

if (consumerState == DEPLOYING) {
consumerVertex.cachePartitionInfo(partitionInfo);
} else {
consumer.sendUpdatePartitionInfoRpcCall(
Collections.singleton(partitionInfo));
}
}
}
}
Expand Down
Loading

0 comments on commit 7240536

Please sign in to comment.