[FLINK-2191] Fix inconsistent use of closure cleaner in Scala Streaming
The closure cleaner still cannot be disabled for the Timestamp extractor
in Time and for the delta function in Delta (windowing helpers).

Closes apache#813
aljoscha authored and mbalassi committed Jun 10, 2015
1 parent 4fe2e18 commit e2304c4
Showing 11 changed files with 141 additions and 75 deletions.
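The recurring pattern in the diffs below, shown as a minimal self-contained sketch (the names CleaningPatternSketch, fun, before, and after are illustrative, not from the commit): the Scala wrappers used to call clean(fun) inside their anonymous function classes, so cleaning ran once per element and went through a statically imported helper that ignored the ExecutionConfig. After this commit the closure is cleaned exactly once, through the environment, and only the cleaned function is captured.

import org.apache.flink.api.common.functions.MapFunction

object CleaningPatternSketch {
  // Illustrative stand-ins: in the real code, `clean` is the environment's
  // config-aware cleaner and `fun` is a user-supplied closure.
  def clean[F](f: F): F = f
  val fun: String => Int = _.length

  // Before: clean(fun) is evaluated inside the anonymous class, i.e. on
  // every map() call, and the wrapper captures its enclosing scope.
  val before = new MapFunction[String, Int] {
    def map(in: String): Int = clean(fun)(in)
  }

  // After: the closure is cleaned once, when the operator is built, and
  // the wrapper captures only the already-cleaned function.
  val cleanFun = clean(fun)
  val after = new MapFunction[String, Int] {
    def map(in: String): Int = cleanFun(in)
  }
}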
File: DataStream.java
@@ -37,7 +37,6 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.io.CsvOutputFormat;
@@ -238,12 +237,8 @@ protected void fillInType(TypeInformation<OUT> typeInfo) {
this.typeInfo = typeInfo;
}

-public <F> F clean(F f) {
-if (getExecutionEnvironment().getConfig().isClosureCleanerEnabled()) {
-ClosureCleaner.clean(f, true);
-}
-ClosureCleaner.ensureSerializable(f);
-return f;
+protected <F> F clean(F f) {
+return getExecutionEnvironment().clean(f);
}

public StreamExecutionEnvironment getExecutionEnvironment() {
File: StreamExecutionEnvironment.java
@@ -975,7 +975,7 @@ public <OUT> DataStreamSource<OUT> addSource(SourceFunction<OUT> function, Strin

boolean isParallel = function instanceof ParallelSourceFunction;

-ClosureCleaner.clean(function, true);
+clean(function);
StreamOperator<OUT> sourceOperator = new StreamSource<OUT>(function);

return new DataStreamSource<OUT>(this, sourceName, typeInfo, sourceOperator,
@@ -1169,4 +1169,16 @@ private static <OUT> void checkCollection(Collection<OUT> elements, Class<OUT> v
}
}

+/**
+* Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning
+* is not disabled in the {@link org.apache.flink.api.common.ExecutionConfig}
+*/
+public <F> F clean(F f) {
+if (getConfig().isClosureCleanerEnabled()) {
+ClosureCleaner.clean(f, true);
+}
+ClosureCleaner.ensureSerializable(f);
+return f;
+}

}
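With the cleaning logic centralized behind the ExecutionConfig check above, programs can now opt out of closure cleaning altogether. A usage sketch, assuming the Scala environment exposes the config the same way (the object name is illustrative):

import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment

object DisableCleanerSketch {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    // With the flag disabled, clean(f) skips the bytecode rewriting and
    // only runs ClosureCleaner.ensureSerializable(f).
    env.getConfig.disableClosureCleaner()
  }
}

Per the commit message, the Timestamp extractor in Time and the delta function in Delta still bypass this switch.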
File: ConnectedDataStream.scala
@@ -21,15 +21,12 @@ package org.apache.flink.streaming.api.scala
import java.util
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.functions.KeySelector
-import org.apache.flink.api.scala.ClosureCleaner
import org.apache.flink.streaming.api.datastream.{ConnectedDataStream => JavaCStream, DataStream => JavaStream}
import org.apache.flink.streaming.api.functions.co.{CoFlatMapFunction, CoMapFunction, CoReduceFunction, CoWindowFunction}
-import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment.clean
import org.apache.flink.util.Collector
-import scala.collection.JavaConversions.asScalaBuffer
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.flink.streaming.api.operators.co.CoStreamFlatMap
import org.apache.flink.streaming.api.operators.co.CoStreamMap
import org.apache.flink.streaming.api.operators.co.CoStreamReduce

class ConnectedDataStream[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) {

@@ -49,9 +46,11 @@
if (fun1 == null || fun2 == null) {
throw new NullPointerException("Map function must not be null.")
}
+val cleanFun1 = clean(fun1)
+val cleanFun2 = clean(fun2)
val comapper = new CoMapFunction[IN1, IN2, R] {
-def map1(in1: IN1): R = clean(fun1)(in1)
-def map2(in2: IN2): R = clean(fun2)(in2)
+def map1(in1: IN1): R = cleanFun1(in1)
+def map2(in2: IN2): R = cleanFun2(in2)
}

map(comapper)
@@ -121,9 +120,11 @@
if (fun1 == null || fun2 == null) {
throw new NullPointerException("FlatMap functions must not be null.")
}
+val cleanFun1 = clean(fun1)
+val cleanFun2 = clean(fun2)
val flatMapper = new CoFlatMapFunction[IN1, IN2, R] {
-def flatMap1(value: IN1, out: Collector[R]): Unit = clean(fun1)(value, out)
-def flatMap2(value: IN2, out: Collector[R]): Unit = clean(fun2)(value, out)
+def flatMap1(value: IN1, out: Collector[R]): Unit = cleanFun1(value, out)
+def flatMap2(value: IN2, out: Collector[R]): Unit = cleanFun2(value, out)
}
flatMap(flatMapper)
}
@@ -143,9 +144,9 @@
if (fun1 == null || fun2 == null) {
throw new NullPointerException("FlatMap functions must not be null.")
}
+val cleanFun1 = clean(fun1)
+val cleanFun2 = clean(fun2)
val flatMapper = new CoFlatMapFunction[IN1, IN2, R] {
-val cleanFun1 = clean(fun1)
-val cleanFun2 = clean(fun2)
def flatMap1(value: IN1, out: Collector[R]) = { cleanFun1(value) foreach out.collect }
def flatMap2(value: IN2, out: Collector[R]) = { cleanFun2(value) foreach out.collect }
}
@@ -238,11 +239,13 @@
def groupBy[K: TypeInformation, L: TypeInformation](fun1: IN1 => K, fun2: IN2 => L):
ConnectedDataStream[IN1, IN2] = {

+val cleanFun1 = clean(fun1)
+val cleanFun2 = clean(fun2)
val keyExtractor1 = new KeySelector[IN1, K] {
-def getKey(in: IN1) = clean(fun1)(in)
+def getKey(in: IN1) = cleanFun1(in)
}
val keyExtractor2 = new KeySelector[IN2, L] {
-def getKey(in: IN2) = clean(fun2)(in)
+def getKey(in: IN2) = cleanFun2(in)
}

javaStream.groupBy(keyExtractor1, keyExtractor2)
@@ -324,11 +327,14 @@
def partitionByHash[K: TypeInformation, L: TypeInformation](fun1: IN1 => K, fun2: IN2 => L):
ConnectedDataStream[IN1, IN2] = {

+val cleanFun1 = clean(fun1)
+val cleanFun2 = clean(fun2)

val keyExtractor1 = new KeySelector[IN1, K] {
-def getKey(in: IN1) = clean(fun1)(in)
+def getKey(in: IN1) = cleanFun1(in)
}
val keyExtractor2 = new KeySelector[IN2, L] {
-def getKey(in: IN2) = clean(fun2)(in)
+def getKey(in: IN2) = cleanFun2(in)
}

javaStream.partitionByHash(keyExtractor1, keyExtractor2)
@@ -378,11 +384,16 @@
throw new NullPointerException("Reduce functions must not be null.")
}

+val cleanReducer1 = clean(reducer1)
+val cleanReducer2 = clean(reducer2)
+val cleanMapper1 = clean(mapper1)
+val cleanMapper2 = clean(mapper2)

val reducer = new CoReduceFunction[IN1, IN2, R] {
-def reduce1(value1: IN1, value2: IN1): IN1 = clean(reducer1)(value1, value2)
-def map2(value: IN2): R = clean(mapper2)(value)
-def reduce2(value1: IN2, value2: IN2): IN2 = clean(reducer2)(value1, value2)
-def map1(value: IN1): R = clean(mapper1)(value)
+def reduce1(value1: IN1, value2: IN1): IN1 = cleanReducer1(value1, value2)
+def reduce2(value1: IN2, value2: IN2): IN2 = cleanReducer2(value1, value2)
+def map1(value: IN1): R = cleanMapper1(value)
+def map2(value: IN2): R = cleanMapper2(value)
}
reduce(reducer)
}
@@ -442,9 +453,11 @@
throw new NullPointerException("CoWindow function must no be null")
}

+val cleanCoWindower = clean(coWindower)

val coWindowFun = new CoWindowFunction[IN1, IN2, R] {
def coWindow(first: util.List[IN1], second: util.List[IN2],
-out: Collector[R]): Unit = clean(coWindower)(first, second, out)
+out: Collector[R]): Unit = cleanCoWindower(first.asScala, second.asScala, out)
}

windowReduce(coWindowFun, windowSize, slideInterval)
@@ -486,4 +499,12 @@
javaStream.getType2
}

+/**
+* Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning
+* is not disabled in the {@link org.apache.flink.api.common.ExecutionConfig}
+*/
+private[flink] def clean[F <: AnyRef](f: F): F = {
+new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f)
+}

}
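The scalaClean helper called here is not shown in this excerpt; a plausible reconstruction, assuming it mirrors the Java-side clean() but delegates to the Scala API's ClosureCleaner (the class name ScalaCleanSketch is illustrative):

package org.apache.flink.streaming.api.scala

import org.apache.flink.api.scala.ClosureCleaner
import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaEnv}

class ScalaCleanSketch(javaEnv: JavaEnv) {
  // Mirrors the Java-side clean(): rewrite the closure only when the
  // ExecutionConfig allows it, but always verify serializability.
  private[flink] def scalaClean[F <: AnyRef](f: F): F = {
    if (javaEnv.getConfig.isClosureCleanerEnabled) {
      ClosureCleaner.clean(f, true) // second argument: also check serializability
    }
    ClosureCleaner.ensureSerializable(f)
    f
  }
}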
File: DataStream.scala
@@ -19,6 +19,7 @@
package org.apache.flink.streaming.api.scala

import org.apache.flink.api.common.io.OutputFormat
+import org.apache.flink.api.scala.ClosureCleaner
import org.apache.flink.api.scala.operators.ScalaCsvOutputFormat
import org.apache.flink.core.fs.{FileSystem, Path}

@@ -34,9 +35,8 @@ import org.apache.flink.streaming.api.collector.selector.OutputSelector
import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, DataStreamSink, GroupedDataStream, SingleOutputStreamOperator}
import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction.AggregationType
import org.apache.flink.streaming.api.functions.aggregation.SumFunction
-import org.apache.flink.streaming.api.functions.sink.{FileSinkFunctionByMillis, SinkFunction}
+import org.apache.flink.streaming.api.functions.sink.SinkFunction
import org.apache.flink.streaming.api.operators.{StreamGroupedReduce, StreamReduce}
-import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment.clean
import org.apache.flink.streaming.api.windowing.helper.WindowingHelper
import org.apache.flink.streaming.api.windowing.policy.{EvictionPolicy, TriggerPolicy}
import org.apache.flink.streaming.util.serialization.SerializationSchema
@@ -225,8 +225,8 @@ class DataStream[T](javaStream: JavaStream[T]) {
*/
def groupBy[K: TypeInformation](fun: T => K): DataStream[T] = {

+val cleanFun = clean(fun)
val keyExtractor = new KeySelector[T, K] {
-val cleanFun = clean(fun)
def getKey(in: T) = cleanFun(in)
}
javaStream.groupBy(keyExtractor)
@@ -251,8 +251,8 @@
*/
def partitionByHash[K: TypeInformation](fun: T => K): DataStream[T] = {

+val cleanFun = clean(fun)
val keyExtractor = new KeySelector[T, K] {
-val cleanFun = clean(fun)
def getKey(in: T) = cleanFun(in)
}
javaStream.partitionByHash(keyExtractor)
@@ -472,8 +472,8 @@
if (fun == null) {
throw new NullPointerException("Map function must not be null.")
}
+val cleanFun = clean(fun)
val mapper = new MapFunction[T, R] {
-val cleanFun = clean(fun)
def map(in: T): R = cleanFun(in)
}

@@ -513,8 +513,8 @@
if (fun == null) {
throw new NullPointerException("FlatMap function must not be null.")
}
+val cleanFun = clean(fun)
val flatMapper = new FlatMapFunction[T, R] {
-val cleanFun = clean(fun)
def flatMap(in: T, out: Collector[R]) { cleanFun(in, out) }
}
flatMap(flatMapper)
@@ -528,8 +528,8 @@
if (fun == null) {
throw new NullPointerException("FlatMap function must not be null.")
}
+val cleanFun = clean(fun)
val flatMapper = new FlatMapFunction[T, R] {
-val cleanFun = clean(fun)
def flatMap(in: T, out: Collector[R]) { cleanFun(in) foreach out.collect }
}
flatMap(flatMapper)
@@ -555,8 +555,8 @@
if (fun == null) {
throw new NullPointerException("Reduce function must not be null.")
}
+val cleanFun = clean(fun)
val reducer = new ReduceFunction[T] {
-val cleanFun = clean(fun)
def reduce(v1: T, v2: T) = { cleanFun(v1, v2) }
}
reduce(reducer)
@@ -584,9 +584,8 @@
if (fun == null) {
throw new NullPointerException("Fold function must not be null.")
}
+val cleanFun = clean(fun)
val folder = new FoldFunction[T,R] {
-val cleanFun = clean(fun)

def fold(acc: R, v: T) = {
cleanFun(acc, v)
}
@@ -611,8 +610,8 @@
if (fun == null) {
throw new NullPointerException("Filter function must not be null.")
}
+val cleanFun = clean(fun)
val filter = new FilterFunction[T] {
-val cleanFun = clean(fun)
def filter(in: T) = cleanFun(in)
}
this.filter(filter)
@@ -665,8 +664,8 @@
if (fun == null) {
throw new NullPointerException("OutputSelector must not be null.")
}
+val cleanFun = clean(fun)
val selector = new OutputSelector[T] {
-val cleanFun = clean(fun)
def select(in: T): java.lang.Iterable[String] = {
cleanFun(in).toIterable.asJava
}
@@ -786,11 +785,19 @@
if (fun == null) {
throw new NullPointerException("Sink function must not be null.")
}
+val cleanFun = clean(fun)
val sinkFunction = new SinkFunction[T] {
-val cleanFun = clean(fun)
def invoke(in: T) = cleanFun(in)
}
this.addSink(sinkFunction)
}

+/**
+* Returns a "closure-cleaned" version of the given function. Cleans only if closure cleaning
+* is not disabled in the {@link org.apache.flink.api.common.ExecutionConfig}
+*/
+private[flink] def clean[F <: AnyRef](f: F): F = {
+new StreamExecutionEnvironment(javaStream.getExecutionEnvironment).scalaClean(f)
+}

}
File: StreamCrossOperator.scala
@@ -29,7 +29,6 @@ import org.apache.flink.streaming.api.datastream.temporal.TemporalWindow
import org.apache.flink.streaming.api.datastream.{DataStream => JavaStream, SingleOutputStreamOperator}
import org.apache.flink.streaming.api.functions.co.CrossWindowFunction
import org.apache.flink.streaming.api.operators.co.CoStreamWindow
-import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment.clean

import scala.reflect.ClassTag

@@ -82,8 +81,12 @@
*/
def apply[R: TypeInformation: ClassTag](fun: (I1, I2) => R): DataStream[R] = {

+val cleanCrossWindowFunction = clean(getCrossWindowFunction(op, fun))
val operator = new CoStreamWindow[I1, I2, R](
-clean(getCrossWindowFunction(op, fun)), op.windowSize, op.slideInterval, op.timeStamp1,
+cleanCrossWindowFunction,
+op.windowSize,
+op.slideInterval,
+op.timeStamp1,
op.timeStamp2)

javaStream.getExecutionEnvironment().getStreamGraph().setOperator(javaStream.getId(),
@@ -110,9 +113,8 @@
CrossWindowFunction[I1, I2, R] = {
require(crossFunction != null, "Join function must not be null.")

+val cleanFun = op.input1.clean(crossFunction)
val crossFun = new CrossFunction[I1, I2, R] {
-val cleanFun = op.input1.clean(crossFunction)

override def cross(first: I1, second: I2): R = {
cleanFun(first, second)
}