Skip to content

Commit

Permalink
[FLINK-29767] Let consumerVertexGroup also know the type of result pa…
Browse files Browse the repository at this point in the history
…rtition.
  • Loading branch information
reswqa authored and xintongsong committed Dec 15, 2022
1 parent e739b38 commit 022f7ad
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ private static void connectAllToAll(
.map(ExecutionVertex::getID)
.collect(Collectors.toList());
ConsumerVertexGroup consumerVertexGroup =
ConsumerVertexGroup.fromMultipleVertices(consumerVertices);
ConsumerVertexGroup.fromMultipleVertices(
consumerVertices, intermediateResult.getResultType());
for (IntermediateResultPartition partition : intermediateResult.getPartitions()) {
partition.addConsumers(consumerVertexGroup);
}
Expand All @@ -117,7 +118,8 @@ private static void connectPointwise(
IntermediateResultPartition partition = intermediateResult.getPartitions()[i];

ConsumerVertexGroup consumerVertexGroup =
ConsumerVertexGroup.fromSingleVertex(executionVertex.getID());
ConsumerVertexGroup.fromSingleVertex(
executionVertex.getID(), intermediateResult.getResultType());
partition.addConsumers(consumerVertexGroup);

ConsumedPartitionGroup consumedPartitionGroup =
Expand All @@ -132,7 +134,8 @@ private static void connectPointwise(

ExecutionVertex executionVertex = taskVertices[index];
ConsumerVertexGroup consumerVertexGroup =
ConsumerVertexGroup.fromSingleVertex(executionVertex.getID());
ConsumerVertexGroup.fromSingleVertex(
executionVertex.getID(), intermediateResult.getResultType());

int start = index * sourceCount / targetCount;
int end = (index + 1) * sourceCount / targetCount;
Expand Down Expand Up @@ -173,7 +176,8 @@ private static void connectPointwise(
}

ConsumerVertexGroup consumerVertexGroup =
ConsumerVertexGroup.fromMultipleVertices(consumers);
ConsumerVertexGroup.fromMultipleVertices(
consumers, intermediateResult.getResultType());
partition.addConsumers(consumerVertexGroup);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.runtime.scheduler.strategy;

import org.apache.flink.runtime.io.network.partition.ResultPartitionType;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
Expand All @@ -26,16 +28,26 @@
public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> {
private final List<ExecutionVertexID> vertices;

private ConsumerVertexGroup(List<ExecutionVertexID> vertices) {
private final ResultPartitionType resultPartitionType;

private ConsumerVertexGroup(
List<ExecutionVertexID> vertices, ResultPartitionType resultPartitionType) {
this.vertices = vertices;
this.resultPartitionType = resultPartitionType;
}

public static ConsumerVertexGroup fromMultipleVertices(
List<ExecutionVertexID> vertices, ResultPartitionType resultPartitionType) {
return new ConsumerVertexGroup(vertices, resultPartitionType);
}

public static ConsumerVertexGroup fromMultipleVertices(List<ExecutionVertexID> vertices) {
return new ConsumerVertexGroup(vertices);
public static ConsumerVertexGroup fromSingleVertex(
ExecutionVertexID vertex, ResultPartitionType resultPartitionType) {
return new ConsumerVertexGroup(Collections.singletonList(vertex), resultPartitionType);
}

public static ConsumerVertexGroup fromSingleVertex(ExecutionVertexID vertex) {
return new ConsumerVertexGroup(Collections.singletonList(vertex));
public ResultPartitionType getResultPartitionType() {
return resultPartitionType;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ void testGetConsumerVertexGroup() {
ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
consumerVertexGroups.put(
resultPartition.getId(),
Collections.singletonList(ConsumerVertexGroup.fromSingleVertex(executionVertexId)));
Collections.singletonList(
ConsumerVertexGroup.fromSingleVertex(
executionVertexId, resultPartition.getResultType())));
assertThat(resultPartition.getConsumerVertexGroups()).isNotEmpty();
assertThat(resultPartition.getConsumerVertexGroups().get(0)).contains(executionVertexId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,17 @@ public List<ConsumedPartitionGroup> getConsumedPartitionGroups() {
return Collections.unmodifiableList(consumedPartitionGroups);
}

void addConsumerGroup(Collection<TestingSchedulingExecutionVertex> consumerVertices) {
void addConsumerGroup(
Collection<TestingSchedulingExecutionVertex> consumerVertices,
ResultPartitionType resultPartitionType) {
checkState(this.consumerVertexGroup == null);

final ConsumerVertexGroup consumerVertexGroup =
ConsumerVertexGroup.fromMultipleVertices(
consumerVertices.stream()
.map(TestingSchedulingExecutionVertex::getId)
.collect(Collectors.toList()));
.collect(Collectors.toList()),
resultPartitionType);

this.consumerVertexGroup = consumerVertexGroup;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ public TestingSchedulingTopology connect(
.withResultPartitionType(resultPartitionType)
.build();

resultPartition.addConsumerGroup(Collections.singleton(consumer));
resultPartition.addConsumerGroup(
Collections.singleton(consumer), resultPartition.getResultType());
resultPartition.setProducer(producer);

producer.addProducedPartition(resultPartition);
Expand Down Expand Up @@ -304,7 +305,8 @@ protected List<TestingSchedulingResultPartition> connect() {
resultPartition.setProducer(producer);
producer.addProducedPartition(resultPartition);
consumer.addConsumedPartition(resultPartition);
resultPartition.addConsumerGroup(Collections.singleton(consumer));
resultPartition.addConsumerGroup(
Collections.singleton(consumer), resultPartitionType);
resultPartitions.add(resultPartition);
}

Expand Down Expand Up @@ -344,7 +346,7 @@ protected List<TestingSchedulingResultPartition> connect() {
resultPartition.setProducer(producer);
producer.addProducedPartition(resultPartition);

resultPartition.addConsumerGroup(consumers);
resultPartition.addConsumerGroup(consumers, resultPartitionType);
resultPartitions.add(resultPartition);
}

Expand Down

0 comments on commit 022f7ad

Please sign in to comment.