[FLINK-12405] [DataSet] Introduce a way to generate BLOCKING_PERSISTENT ResultPartition through DataSet API
Xpray authored and StephanEwen committed Jun 15, 2019
1 parent 75cfae2 commit d3ec142
Showing 3 changed files with 161 additions and 1 deletion.
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.io;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Preconditions;

import java.io.IOException;

/**
* An internal {@link OutputFormat} used to mark the incoming edge's result partition type as BLOCKING_PERSISTENT.
*
* @param <T> the type of the records written to this format
*/
@Internal
public final class BlockingShuffleOutputFormat<T> implements OutputFormat<T> {

private final AbstractID intermediateDataSetId;

private BlockingShuffleOutputFormat(AbstractID intermediateDataSetId) {
this.intermediateDataSetId = intermediateDataSetId;
}

public static <T> BlockingShuffleOutputFormat<T> createOutputFormat(AbstractID intermediateDataSetId) {
return new BlockingShuffleOutputFormat<>(Preconditions.checkNotNull(intermediateDataSetId, "intermediateDataSetId is null"));
}

@Override
public void configure(Configuration parameters) {}

@Override
public void open(int taskNumber, int numTasks) throws IOException {}

@Override
public void writeRecord(T record) throws IOException {}

@Override
public void close() throws IOException {}

public AbstractID getIntermediateDataSetId() {
return intermediateDataSetId;
}
}
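For orientation, a minimal usage sketch of the new format from the DataSet API, mirroring the test added at the end of this commit (imports as in that test; the variable names are illustrative):

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple2<Long, Long>> input = env.fromElements(new Tuple2<>(1L, 2L));

// the id under which the BLOCKING_PERSISTENT result partition will be registered
AbstractID intermediateDataSetId = new AbstractID();

// this sink is only a marker: the JobGraphGenerator drops it and instead adds a
// BLOCKING_PERSISTENT result data set with the given id to the predecessor vertex
input.output(BlockingShuffleOutputFormat.createOutputFormat(intermediateDataSetId));

// as in the test below, a regular output branch is kept alongside the marker sink
input.output(new DiscardingOutputFormat<>());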
@@ -27,6 +27,7 @@
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.api.java.io.BlockingShuffleOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.AlgorithmOptions;
import org.apache.flink.configuration.ConfigConstants;
@@ -61,6 +62,7 @@
import org.apache.flink.runtime.iterative.task.IterationTailTask;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.InputFormatVertex;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -480,6 +482,37 @@ public void postVisit(PlanNode node) {
if (node instanceof SourcePlanNode || node instanceof NAryUnionPlanNode || node instanceof SolutionSetPlanNode) {
return;
}

// if this sink is a blocking shuffle marker, add a BLOCKING_PERSISTENT result data set with the specified IntermediateDataSetID to its predecessor and return
if (node instanceof SinkPlanNode) {
Object userCodeObject = node.getProgramOperator().getUserCodeWrapper().getUserCodeObject();
if (userCodeObject instanceof BlockingShuffleOutputFormat) {
Iterable<Channel> inputIterable = node.getInputs();
if (inputIterable == null || inputIterable.iterator() == null ||
!inputIterable.iterator().hasNext()) {
throw new IllegalStateException("SinkPlanNode must have a input.");
}
PlanNode precedentNode = inputIterable.iterator().next().getSource();
JobVertex precedentVertex;
if (vertices.containsKey(precedentNode)) {
precedentVertex = vertices.get(precedentNode);
} else {
precedentVertex = chainedTasks.get(precedentNode).getContainingVertex();
}
if (precedentVertex == null) {
throw new IllegalStateException("Bug: Chained task has not been assigned its containing vertex when connecting.");
}
precedentVertex.createAndAddResultDataSet(
// use specified intermediateDataSetID
new IntermediateDataSetID(((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()),
ResultPartitionType.BLOCKING_PERSISTENT
);

// remove this node so that the OutputFormatVertex will not show up in the final JobGraph.
vertices.remove(node);
return;
}
}

// check if we have an iteration. in that case, translate the step function now
if (node instanceof IterationPlanNode) {
@@ -1252,7 +1285,7 @@ private DistributionPattern connectJobVertices(Channel channel, int inputNumber,
edge.setShipStrategyName(shipStrategy);
edge.setPreProcessingOperationName(localStrategy);
edge.setOperatorLevelCachingDescription(caching);

return distributionPattern;
}

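Condensed from the assertions in the new test below, the effect of this branch on a compiled JobGraph looks roughly as follows (mapVertex and intermediateDataSetId are assumed to come from a plan compiled as in that test):

// the marker sink no longer appears as an OutputFormatVertex; instead, the vertex that
// produced its input carries an extra result data set with the caller-supplied id and
// ResultPartitionType.BLOCKING_PERSISTENT
boolean hasPersistentResult = mapVertex.getProducedDataSets().stream()
    .anyMatch(dataSet -> dataSet.getId().equals(new IntermediateDataSetID(intermediateDataSetId)) &&
        dataSet.getResultType() == ResultPartitionType.BLOCKING_PERSISTENT);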
@@ -23,10 +23,13 @@
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.io.BlockingShuffleOutputFormat;
import org.apache.flink.api.java.operators.DataSink;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.IterativeDataSet;
@@ -35,9 +38,18 @@
import org.apache.flink.configuration.Configuration;
import org.apache.flink.optimizer.Optimizer;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.InputFormatVertex;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;

import org.apache.flink.runtime.jobgraph.OutputFormatVertex;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.util.AbstractID;

import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@@ -265,6 +277,60 @@ public void testArtifactCompression() throws IOException {
assertState(nonExecutableDirEntry, false, true);
}

@Test
public void testGeneratingJobGraphWithUnconsumedResultPartition() {

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple2<Long, Long>> input = env.fromElements(new Tuple2<>(1L, 2L))
.setParallelism(1);

DataSet ds = input.map((MapFunction<Tuple2<Long, Long>, Object>) value -> new Tuple2<>(value.f0 + 1, value.f1))
.setParallelism(3);

AbstractID intermediateDataSetID = new AbstractID();

// this output branch will be excluded.
ds.output(BlockingShuffleOutputFormat.createOutputFormat(intermediateDataSetID))
.setParallelism(1);

// this is the normal output branch.
ds.output(new DiscardingOutputFormat())
.setParallelism(1);

Plan plan = env.createProgramPlan();
Optimizer pc = new Optimizer(new Configuration());
OptimizedPlan op = pc.compile(plan);

JobGraphGenerator jgg = new JobGraphGenerator();
JobGraph jobGraph = jgg.compileJobGraph(op);

Assert.assertEquals(3, jobGraph.getVerticesSortedTopologicallyFromSources().size());

JobVertex inputVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(0);
JobVertex mapVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(1);
JobVertex outputVertex = jobGraph.getVerticesSortedTopologicallyFromSources().get(2);

Assert.assertThat(inputVertex, Matchers.instanceOf(InputFormatVertex.class));
Assert.assertThat(mapVertex, Matchers.instanceOf(JobVertex.class));
Assert.assertThat(outputVertex, Matchers.instanceOf(OutputFormatVertex.class));

TaskConfig cfg = new TaskConfig(outputVertex.getConfiguration());
UserCodeWrapper<OutputFormat<?>> wrapper = cfg.getStubWrapper(this.getClass().getClassLoader());
OutputFormat<?> outputFormat = wrapper.getUserCodeObject(OutputFormat.class, this.getClass().getClassLoader());

// the only remaining OutputFormatVertex carries the DiscardingOutputFormat
Assert.assertThat(outputFormat, Matchers.instanceOf(DiscardingOutputFormat.class));

// there are 2 produced result data sets, one of which has type ResultPartitionType.BLOCKING_PERSISTENT
Assert.assertEquals(2, mapVertex.getProducedDataSets().size());

Assert.assertTrue(mapVertex.getProducedDataSets().stream()
.anyMatch(dataSet -> dataSet.getId().equals(new IntermediateDataSetID(intermediateDataSetID)) &&
dataSet.getResultType() == ResultPartitionType.BLOCKING_PERSISTENT));

}

private static void assertState(DistributedCache.DistributedCacheEntry entry, boolean isExecutable, boolean isZipped) throws IOException {
assertNotNull(entry);
assertEquals(isExecutable, entry.isExecutable);