From 623a5696bc328a9a55bf5de67ad0070a985c96ee Mon Sep 17 00:00:00 2001
From: Aviem Zur
Date: Wed, 22 Mar 2017 15:20:51 +0200
Subject: [PATCH] [BEAM-1074] Set default-partitioner in SourceRDD.Unbounded

---
 .../spark/SparkNativePipelineVisitor.java        |  1 -
 .../beam/runners/spark/io/SourceDStream.java     | 52 ++++++++++++++-----
 .../beam/runners/spark/io/SourceRDD.java         | 19 +++++--
 .../spark/io/SparkUnboundedSource.java           | 15 +++---
 4 files changed, 63 insertions(+), 24 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
index c2784a240f2c0..c2d38d7a2e006 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
@@ -92,7 +92,6 @@ public boolean apply(NativeTransform debugTransform) {
   @Override
   <TransformT extends PTransform<? super PInput, POutput>> void doVisitTransform(
       TransformHierarchy.Node node) {
-    super.doVisitTransform(node);
     @SuppressWarnings("unchecked")
     TransformT transform = (TransformT) node.getTransform();
     @SuppressWarnings("unchecked")
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
index 8a0763b7052b3..3f2c10a428888 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
@@ -28,6 +28,7 @@
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.streaming.StreamingContext;
 import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.dstream.InputDStream;
 import org.apache.spark.streaming.scheduler.RateController;
 import org.apache.spark.streaming.scheduler.RateController$;
@@ -36,7 +37,6 @@
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import scala.Tuple2;
@@ -60,6 +60,9 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   private final UnboundedSource<T, CheckpointMarkT> unboundedSource;
   private final SparkRuntimeContext runtimeContext;
   private final Duration boundReadDuration;
+  // Number of partitions for the DStream is final and remains the same throughout the entire
+  // lifetime of the pipeline, including when resuming from checkpoint.
+  private final int numPartitions;
   // the initial parallelism, set by Spark's backend, will be determined once when the job starts.
   // in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
   // property should not be reset.
@@ -67,40 +70,55 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   // the bound on max records is optional.
   // in case it is set explicitly via PipelineOptions, it takes precedence
   // otherwise it could be activated via RateController.
-  private Long boundMaxRecords = null;
+  private final long boundMaxRecords;
 
   SourceDStream(
       StreamingContext ssc,
       UnboundedSource<T, CheckpointMarkT> unboundedSource,
-      SparkRuntimeContext runtimeContext) {
-
+      SparkRuntimeContext runtimeContext,
+      Long boundMaxRecords) {
     super(ssc, JavaSparkContext$.MODULE$.<Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
     this.unboundedSource = unboundedSource;
     this.runtimeContext = runtimeContext;
+
     SparkPipelineOptions options = runtimeContext.getPipelineOptions().as(
         SparkPipelineOptions.class);
+
     this.boundReadDuration = boundReadDuration(options.getReadTimePercentage(),
         options.getMinReadTimeMillis());
     // set initial parallelism once.
     this.initialParallelism = ssc().sc().defaultParallelism();
     checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
-  }
 
-  public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
-    boundMaxRecords = maxRecordsPerBatch;
+    this.boundMaxRecords = boundMaxRecords > 0 ? boundMaxRecords : rateControlledMaxRecords();
+
+    try {
+      this.numPartitions =
+          createMicrobatchSource()
+              .splitIntoBundles(initialParallelism, options)
+              .size();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
   }
 
   @Override
   public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
-    long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
-    MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
-        unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
-        id());
-    RDD<Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
-        ssc().sc(), runtimeContext, microbatchSource);
+    RDD<Tuple2<Source<T>, CheckpointMarkT>> rdd =
+        new SourceRDD.Unbounded<>(
+            ssc().sc(),
+            runtimeContext,
+            createMicrobatchSource(),
+            numPartitions);
     return scala.Option.apply(rdd);
   }
+
+  private MicrobatchSource<T, CheckpointMarkT> createMicrobatchSource() {
+    return new MicrobatchSource<>(unboundedSource, boundReadDuration, initialParallelism,
+        boundMaxRecords, -1, id());
+  }
+
   @Override
   public void start() { }
@@ -112,6 +130,14 @@ public String name() {
     return "Beam UnboundedSource [" + id() + "]";
   }
 
+  /**
+   * Number of partitions is exposed so clients of {@link SourceDStream} can use this to set
+   * appropriate partitioning for operations such as {@link JavaPairDStream#mapWithState}.
+   */
+  int getNumPartitions() {
+    return numPartitions;
+  }
+
   //---- Bound by time.
 
   // return the largest between the proportional read time (%batchDuration dedicated for read)
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index cf37b3a6526a6..1a3537fe63564 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -30,15 +30,17 @@
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.Dependency;
+import org.apache.spark.HashPartitioner;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.Partition;
+import org.apache.spark.Partitioner;
 import org.apache.spark.SparkContext;
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.rdd.RDD;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
+import scala.Option;
 
 /**
@@ -213,8 +215,10 @@ public Source<T> getSource() {
    */
   public static class Unbounded<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
       extends RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> {
 
+    private final MicrobatchSource<T, CheckpointMarkT> microbatchSource;
     private final SparkRuntimeContext runtimeContext;
+    private final Partitioner partitioner;
 
     // to satisfy Scala API.
     private static final scala.collection.immutable.List<Dependency<?>> NIL =
@@ -222,12 +226,14 @@ public static class Unbounded<T, CheckpointMarkT extends UnboundedSource.Checkp
             Collections.<Dependency<?>>emptyList()).toList();
 
     public Unbounded(SparkContext sc,
-                     SparkRuntimeContext runtimeContext,
-                     MicrobatchSource<T, CheckpointMarkT> microbatchSource) {
+                     SparkRuntimeContext runtimeContext,
+                     MicrobatchSource<T, CheckpointMarkT> microbatchSource,
+                     int initialNumPartitions) {
       super(sc, NIL,
           JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
       this.runtimeContext = runtimeContext;
       this.microbatchSource = microbatchSource;
+      this.partitioner = new HashPartitioner(initialNumPartitions);
     }
 
     @Override
@@ -246,6 +252,13 @@ public Partition[] getPartitions() {
       }
     }
 
+    @Override
+    public Option<Partitioner> partitioner() {
+      // setting the partitioner helps to "keep" the same partitioner in the following
+      // mapWithState read for Read.Unbounded, preventing a post-mapWithState shuffle.
+      return scala.Some.apply(partitioner);
+    }
+
     @Override
     public scala.collection.Iterator<scala.Tuple2<Source<T>, CheckpointMarkT>>
     compute(Partition split, TaskContext context) {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index e5bbaf185e052..6c047ac596ebd 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -77,11 +77,9 @@ public static <T, CheckpointMarkT extends CheckpointMark> UnboundedDataset<T> re
 
     SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
     Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
-    SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
-    // if max records per batch was set by the user.
-    if (maxRecordsPerBatch > 0) {
-      sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
-    }
+    SourceDStream<T, CheckpointMarkT> sourceDStream =
+        new SourceDStream<>(jssc.ssc(), source, rc, maxRecordsPerBatch);
+
     JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
         JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
             JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
@@ -89,8 +87,11 @@ public static <T, CheckpointMarkT extends CheckpointMark> UnboundedDataset<T> re
 
     // call mapWithState to read from a checkpointable sources.
     JavaMapWithStateDStream<Source<T>, CheckpointMarkT, Tuple2<byte[], Instant>,
-        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream = inputDStream.mapWithState(
-            StateSpec.function(StateSpecFunctions.mapSourceFunction(rc)));
+        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream =
+        inputDStream.mapWithState(
+            StateSpec
+                .function(StateSpecFunctions.mapSourceFunction(rc))
+                .numPartitions(sourceDStream.getNumPartitions()));
 
     // set checkpoint duration for read stream, if set.
     checkpointStream(mapWithStateDStream, options);
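
Why this removes the shuffle: Spark's general contract is that once an RDD reports a partitioner, a downstream key-based operation invoked with an equal partitioner (same class, same number of partitions) reuses the existing layout through a narrow dependency instead of shuffling. That is what the patch sets up: SourceRDD.Unbounded reports a HashPartitioner sized by the source's split count, and SparkUnboundedSource passes the same size to StateSpec.numPartitions so the mapWithState stage matches it. The standalone sketch below illustrates that contract only; the class name, data, and local master are made up for illustration, and it uses the batch partitionBy/reduceByKey pair rather than streaming mapWithState so it stays self-contained.

import java.util.Arrays;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

// Hypothetical demo class, not part of the patch.
public class PartitionerReuseSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("partitioner-reuse");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    int numPartitions = 4;

    JavaPairRDD<String, Integer> pairs = jsc.parallelizePairs(Arrays.asList(
        new Tuple2<>("a", 1), new Tuple2<>("b", 2), new Tuple2<>("a", 3)));

    // Fix the partitioning once, analogous to SourceRDD.Unbounded reporting a
    // HashPartitioner built from the source's number of splits.
    JavaPairRDD<String, Integer> partitioned =
        pairs.partitionBy(new HashPartitioner(numPartitions));

    // reduceByKey stands in for mapWithState here: because it is invoked with an equal
    // partitioner (same class, same numPartitions), Spark keeps the existing layout and
    // adds a narrow dependency for this stage instead of another shuffle.
    JavaPairRDD<String, Integer> reduced =
        partitioned.reduceByKey(new HashPartitioner(numPartitions), (a, b) -> a + b);

    // The lineage shows a single shuffle boundary (from partitionBy), none for reduceByKey.
    System.out.println(reduced.toDebugString());
    System.out.println(reduced.collectAsMap());

    jsc.stop();
  }
}

With a mismatched partitioner (or none reported), the second stage would instead insert its own ShuffledRDD, which is exactly the post-mapWithState shuffle the patch is avoiding.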