
[BEAM-649] Analyse DAG to determine if RDD/DStream has to be cached or not
jbonofre committed Mar 23, 2017
1 parent 9ac1ffc commit daa10dd
Showing 6 changed files with 168 additions and 31 deletions.
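In place of the previous behaviour (persist every leaf dataset and cache a dataset on its second borrow, both removed below), the runner now walks the pipeline DAG before translation, counts how many transforms read each PCollection, and persists the corresponding RDD/DStream as soon as the dataset is registered whenever it will be read more than once. A compressed sketch of that flow, using invented and simplified types rather than the Beam classes changed in this commit:

import java.util.HashMap;
import java.util.Map;

/** Invented, simplified types showing the order of operations across SparkRunner and EvaluationContext. */
class CacheFlowSketch {

  /** Stands in for the Dataset abstraction used by EvaluationContext below. */
  interface Dataset {
    void cache(String storageLevel);
  }

  private final Map<Object, Long> cacheCandidates = new HashMap<>();
  private final Map<Object, Dataset> datasets = new HashMap<>();

  /** Step 1: the visitor counts reads before any translation happens (cf. CacheVisitor). */
  void recordRead(Object pcollection) {
    Long count = cacheCandidates.get(pcollection);
    cacheCandidates.put(pcollection, count == null ? 1L : count + 1);
  }

  /** Step 2: registering a translated dataset persists it if it will be read again (cf. shouldCache/putDataset). */
  void putDataset(Object pcollection, Dataset dataset, String storageLevel) {
    Long count = cacheCandidates.get(pcollection);
    if (count != null && count > 1) {
      dataset.cache(storageLevel);
    }
    datasets.put(pcollection, dataset);
  }

  public static void main(String[] args) {
    CacheFlowSketch sketch = new CacheFlowSketch();
    Object pcollection = new Object();
    sketch.recordRead(pcollection); // first consumer seen by the visitor
    sketch.recordRead(pcollection); // second consumer: the PCollection becomes a cache candidate
    sketch.putDataset(pcollection,
        storageLevel -> System.out.println("persisted at " + storageLevel), "MEMORY_ONLY");
  }
}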
SparkRunner.java
@@ -38,6 +38,7 @@
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir;
import org.apache.beam.runners.spark.translation.streaming.SparkRunnerStreamingContextFactory;
import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.WatermarksListener;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.Read;
@@ -90,6 +91,7 @@
public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {

private static final Logger LOG = LoggerFactory.getLogger(SparkRunner.class);

/**
* Options used in this pipeline runner.
*/
@@ -143,10 +145,14 @@ public SparkPipelineResult run(final Pipeline pipeline) {

final SparkPipelineResult result;
final Future<?> startPipeline;

final SparkPipelineTranslator translator;

final ExecutorService executorService = Executors.newSingleThreadExecutor();

MetricsEnvironment.setMetricsSupported(true);

// visit the pipeline to determine the translation mode
detectTranslationMode(pipeline);

if (mOptions.isStreaming()) {
@@ -157,6 +163,11 @@ public SparkPipelineResult run(final Pipeline pipeline) {
JavaStreamingContext.getOrCreate(checkpointDir.getSparkCheckpointDir().toString(),
contextFactory);

// update cache candidates
translator = new StreamingTransformTranslator.Translator(
new TransformTranslator.Translator());
updateCacheCandidates(pipeline, translator, contextFactory.getEvaluationContext());

// Checkpoint aggregator/metrics values
jssc.addStreamingListener(
new JavaStreamingListenerWrapper(
@@ -191,17 +202,21 @@ public void run() {

result = new SparkPipelineResult.StreamingMode(startPipeline, jssc);
} else {
// create the evaluation context
final JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions);
final EvaluationContext evaluationContext = new EvaluationContext(jsc, pipeline);
translator = new TransformTranslator.Translator();

// update the cache candidates
updateCacheCandidates(pipeline, translator, evaluationContext);

initAccumulators(mOptions, jsc);

startPipeline = executorService.submit(new Runnable() {

@Override
public void run() {
pipeline.traverseTopologically(new Evaluator(new TransformTranslator.Translator(),
evaluationContext));
pipeline.traverseTopologically(new Evaluator(translator, evaluationContext));
evaluationContext.computeOutputs();
LOG.info("Batch pipeline execution complete.");
}
@@ -240,9 +255,7 @@ public static void initAccumulators(SparkPipelineOptions opts, JavaSparkContext
}

/**
* Detect the translation mode for the pipeline and change options in case streaming
* translation is needed.
* @param pipeline
* Visit the pipeline to determine the translation mode (batch/streaming).
*/
private void detectTranslationMode(Pipeline pipeline) {
TranslationModeDetector detector = new TranslationModeDetector();
@@ -253,6 +266,17 @@ private void detectTranslationMode(Pipeline pipeline) {
}
}

/**
* Evaluator that updates/populates the cache candidates.
*/
private void updateCacheCandidates(
Pipeline pipeline,
SparkPipelineTranslator translator,
EvaluationContext evaluationContext) {
CacheVisitor cacheVisitor = new CacheVisitor(translator, evaluationContext);
pipeline.traverseTopologically(cacheVisitor);
}

/**
* The translation mode of the Beam Pipeline.
*/
@@ -297,6 +321,36 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) {
}
}

/**
* Traverses the pipeline to populate the candidates for caching.
*/
static class CacheVisitor extends Evaluator {

protected CacheVisitor(
SparkPipelineTranslator translator,
EvaluationContext evaluationContext) {
super(translator, evaluationContext);
}

@Override
public void doVisitTransform(TransformHierarchy.Node node) {
// we populate cache candidates by updating the map with the inputs of each node.
// The goal is to detect the PCollections accessed more than once, so that the
// underlying RDDs or DStreams can be cached.

for (TaggedPValue input : node.getInputs()) {
PValue value = input.getValue();
if (value instanceof PCollection) {
long count = 1L;
if (ctxt.getCacheCandidates().get(value) != null) {
count = ctxt.getCacheCandidates().get(value) + 1;
}
ctxt.getCacheCandidates().put((PCollection) value, count);
}
}
}
}

/**
* Evaluator on the pipeline.
*/
BoundedDataset.java
@@ -99,7 +99,8 @@ public WindowedValue<T> apply(byte[] bytes) {

@Override
public void cache(String storageLevel) {
rdd.persist(StorageLevel.fromString(storageLevel));
// populate the rdd if needed
getRDD().persist(StorageLevel.fromString(storageLevel));
}

@Override
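The one-line change in cache() above matters because a dataset created from in-memory values only materializes its RDD on demand: persisting the raw rdd field can happen before any RDD exists, whereas going through getRDD() forces materialization first. A hypothetical, simplified illustration of that lazy-getter pattern, with invented names and plain Java collections standing in for the real RDD plumbing:

import java.util.List;
import java.util.function.Supplier;

/** Hypothetical sketch: persist through the lazy getter, never the raw field. */
class LazyDatasetSketch<T> {

  private final Supplier<List<T>> materializer; // stands in for parallelizing the coded values
  private List<T> data;                         // stands in for the RDD field, null until first needed

  LazyDatasetSketch(Supplier<List<T>> materializer) {
    this.materializer = materializer;
  }

  /** Creates the data on first access, in the spirit of getRDD() above. */
  List<T> getData() {
    if (data == null) {
      data = materializer.get();
    }
    return data;
  }

  void cache() {
    // Acting on the raw field here could dereference null; going through the getter
    // guarantees there is something for a real implementation to persist(...).
    getData();
  }
}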
EvaluationContext.java
@@ -21,12 +21,14 @@
import static com.google.common.base.Preconditions.checkArgument;

import com.google.common.collect.Iterables;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.AppliedPTransform;
@@ -52,10 +54,10 @@ public class EvaluationContext {
private final Map<PValue, Dataset> datasets = new LinkedHashMap<>();
private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>();
private final Set<Dataset> leaves = new LinkedHashSet<>();
private final Set<PValue> multiReads = new LinkedHashSet<>();
private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
private AppliedPTransform<?, ?, ?> currentTransform;
private final SparkPCollectionView pviews = new SparkPCollectionView();
private final Map<PCollection, Long> cacheCandidates = new HashMap<>();

public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
this.jsc = jsc;
@@ -116,6 +118,15 @@ public List<TaggedPValue> getOutputs(PTransform<?, ?> transform) {
return currentTransform.getOutputs();
}

private boolean shouldCache(PValue pvalue) {
if ((pvalue instanceof PCollection)
&& cacheCandidates.containsKey(pvalue)
&& cacheCandidates.get(pvalue) > 1) {
return true;
}
return false;
}

public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
putDataset(getOutput(transform), dataset);
}
@@ -126,13 +137,30 @@ public void putDataset(PValue pvalue, Dataset dataset) {
} catch (IllegalStateException e) {
// name not set, ignore
}
if (shouldCache(pvalue)) {
dataset.cache(storageLevel());
}
datasets.put(pvalue, dataset);
leaves.add(dataset);
}

<T> void putBoundedDatasetFromValues(
PTransform<?, ? extends PValue> transform, Iterable<T> values, Coder<T> coder) {
datasets.put(getOutput(transform), new BoundedDataset<>(values, jsc, coder));
PValue output = getOutput(transform);
if (shouldCache(output)) {
// eagerly create the RDD, as it will be reused.
Iterable<WindowedValue<T>> elems = Iterables.transform(values,
WindowingHelpers.<T>windowValueFunction());
WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder =
WindowedValue.getValueOnlyCoder(coder);
JavaRDD<WindowedValue<T>> rdd =
getSparkContext().parallelize(CoderHelpers.toByteArrays(elems, windowCoder))
.map(CoderHelpers.fromByteFunction(windowCoder));
putDataset(transform, new BoundedDataset<>(rdd));
} else {
// create a BoundedDataset that will create an RDD on demand
datasets.put(getOutput(transform), new BoundedDataset<>(values, jsc, coder));
}
}

public Dataset borrowDataset(PTransform<? extends PValue, ?> transform) {
@@ -142,12 +170,6 @@ public Dataset borrowDataset(PValue pvalue) {
public Dataset borrowDataset(PValue pvalue) {
Dataset dataset = datasets.get(pvalue);
leaves.remove(dataset);
if (multiReads.contains(pvalue)) {
// Ensure the RDD is marked as cached
dataset.cache(storageLevel());
} else {
multiReads.add(pvalue);
}
return dataset;
}

@@ -157,8 +179,6 @@ public Dataset borrowDataset(PValue pvalue) {
*/
public void computeOutputs() {
for (Dataset dataset : leaves) {
// cache so that any subsequent get() is cheap.
dataset.cache(storageLevel());
dataset.action(); // force computation.
}
}
@@ -185,18 +205,6 @@ public <T> T get(PValue value) {
throw new IllegalStateException("Cannot resolve un-known PObject: " + value);
}

/**
* Retrieves an iterable of results associated with the PCollection passed in.
*
* @param pcollection Collection we wish to translate.
* @param <T> Type of elements contained in collection.
* @return Natively types result associated with collection.
*/
<T> Iterable<T> get(PCollection<T> pcollection) {
Iterable<WindowedValue<T>> windowedValues = getWindowedValues(pcollection);
return Iterables.transform(windowedValues, WindowingHelpers.<T>unwindowValueFunction());
}

/**
* Return the current views created in the pipeline.
*
@@ -220,6 +228,15 @@ public void putPView(
pviews.putPView(view, value, coder);
}

/**
* Get the map of cache candidates held by the evaluation context.
*
* @return The current {@link Map} of cache candidates.
*/
public Map<PCollection, Long> getCacheCandidates() {
return this.cacheCandidates;
}

<T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
@SuppressWarnings("unchecked")
BoundedDataset<T> boundedDataset = (BoundedDataset<T>) datasets.get(pcollection);
SparkRunnerStreamingContextFactory.java
@@ -91,6 +91,10 @@ public JavaStreamingContext create() {
return jssc;
}

public EvaluationContext getEvaluationContext() {
return this.ctxt;
}

private void checkpoint(JavaStreamingContext jssc) {
Path rootCheckpointPath = checkpointDir.getRootCheckpointDir();
Path sparkCheckpointPath = checkpointDir.getSparkCheckpointDir();
CacheTest.java
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark;

import static org.junit.Assert.assertEquals;

import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.Rule;
import org.junit.Test;

/**
* This test checks how the cache candidates map is populated by the runner when evaluating the
* pipeline.
*/
public class CacheTest {

@Rule
public final transient PipelineRule pipelineRule = PipelineRule.batch();

@Test
public void cacheCandidatesUpdaterTest() throws Exception {
Pipeline pipeline = pipelineRule.createPipeline();
PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));
// first read
pCollection.apply(Count.<String>globally());
// second read
// as we access the same PCollection twice, the Spark runner optimizes and
// will cache the RDD representing this PCollection
pCollection.apply(Count.<String>globally());

JavaSparkContext jsc = SparkContextFactory.getSparkContext(pipelineRule.getOptions());
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline);
SparkRunner.CacheVisitor cacheVisitor =
new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt);
pipeline.traverseTopologically(cacheVisitor);
assertEquals(2L, (long) ctxt.getCacheCandidates().get(pCollection));
}

}
StorageLevelTest.java
@@ -37,9 +37,9 @@ public class StorageLevelTest {
@Test
public void test() throws Exception {
pipelineRule.getOptions().setStorageLevel("DISK_ONLY");
Pipeline p = pipelineRule.createPipeline();
Pipeline pipeline = pipelineRule.createPipeline();

PCollection<String> pCollection = p.apply(Create.of("foo"));
PCollection<String> pCollection = pipeline.apply(Create.of("foo"));

// by default, the Spark runner doesn't cache the RDD if it is accessed only once.
// So, to "force" the caching of the RDD, we have to access the RDD at least twice.
@@ -50,7 +50,7 @@ public void test() throws Exception {

PAssert.thatSingleton(output).isEqualTo("Disk Serialized 1x Replicated");

p.run();
pipeline.run();
}

}

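As a footnote to the assertion in StorageLevelTest, the expected string is Spark's human-readable rendering of the configured storage level. A small standalone check of where "Disk Serialized 1x Replicated" comes from, using the same StorageLevel API the cache(String storageLevel) override earlier in this commit already relies on:

import org.apache.spark.storage.StorageLevel;

public class StorageLevelDescriptionDemo {
  public static void main(String[] args) {
    // Same parsing call the runner uses for the configured level.
    StorageLevel level = StorageLevel.fromString("DISK_ONLY");
    // description() renders the disk/memory/serialization/replication flags;
    // for DISK_ONLY that reads "Disk Serialized 1x Replicated".
    System.out.println(level.description());
  }
}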
