Skip to content

Commit

Permalink
[FLINK-1862] [apis] Add support for non-serializable types for collec…
Browse files Browse the repository at this point in the history
…t() by switching from Java serialization to Flink serialization
  • Loading branch information
StephanEwen committed Apr 10, 2015
1 parent 3246255 commit 211d0bd
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@

package org.apache.flink.api.common.accumulators;

import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.InputViewDataInputStreamWrapper;
import org.apache.flink.core.memory.OutputViewDataOutputStreamWrapper;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -37,19 +43,24 @@ public class SerializedListAccumulator<T> implements Accumulator<T, ArrayList<by
private static final long serialVersionUID = 1L;

private ArrayList<byte[]> localValue = new ArrayList<byte[]>();


@Override
public void add(T value) {
if (value == null) {
throw new NullPointerException("Value to accumulate must nor be null");
}

throw new UnsupportedOperationException();
}

public void add(T value, TypeSerializer<T> serializer) throws IOException {
try {
byte[] byteArray = InstantiationUtil.serializeObject(value);
localValue.add(byteArray);
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
OutputViewDataOutputStreamWrapper out =
new OutputViewDataOutputStreamWrapper(new DataOutputStream(outStream));

serializer.serialize(value, out);
localValue.add(outStream.toByteArray());
}
catch (IOException e) {
throw new RuntimeException("Serialization of accumulated value failed", e);
throw new IOException("Failed to serialize value '" + value + '\'', e);
}
}

Expand All @@ -58,21 +69,6 @@ public ArrayList<byte[]> getLocalValue() {
return localValue;
}

public ArrayList<T> deserializeLocalValue(ClassLoader classLoader) {
try {
ArrayList<T> arrList = new ArrayList<T>(localValue.size());
for (byte[] byteArr : localValue) {
@SuppressWarnings("unchecked")
T item = (T) InstantiationUtil.deserializeObject(byteArr, classLoader);
arrList.add(item);
}
return arrList;
}
catch (Exception e) {
throw new RuntimeException("Cannot deserialize accumulator list element", e);
}
}

@Override
public void resetLocal() {
localValue.clear();
Expand All @@ -91,12 +87,15 @@ public SerializedListAccumulator<T> clone() {
}

@SuppressWarnings("unchecked")
public static <T> List<T> deserializeList(ArrayList<byte[]> data, ClassLoader loader)
public static <T> List<T> deserializeList(ArrayList<byte[]> data, TypeSerializer<T> serializer)
throws IOException, ClassNotFoundException
{
List<T> result = new ArrayList<T>(data.size());
for (byte[] bytes : data) {
result.add((T) InstantiationUtil.deserializeObject(bytes, loader));
ByteArrayInputStream inStream = new ByteArrayInputStream(bytes);
InputViewDataInputStreamWrapper in = new InputViewDataInputStreamWrapper(new DataInputStream(inStream));
T val = serializer.deserialize(in);
result.add(val);
}
return result;
}
Expand Down
18 changes: 5 additions & 13 deletions flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint;
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FirstReducer;
import org.apache.flink.api.java.functions.FormattingMapper;
Expand Down Expand Up @@ -411,24 +412,15 @@ public long count() throws Exception {
* @see org.apache.flink.api.java.Utils.CollectHelper
*/
public List<T> collect() throws Exception {
// validate that our type is actually serializable
Class<?> typeClass = getType().getTypeClass();
ClassLoader cl = typeClass.getClassLoader() == null ? ClassLoader.getSystemClassLoader()
: typeClass.getClassLoader();

if (!java.io.Serializable.class.isAssignableFrom(typeClass)) {
throw new UnsupportedOperationException("collect() can only be used with serializable data types. "
+ "The DataSet type '" + typeClass.getName() + "' does not implement java.io.Serializable.");
}

final String id = new AbstractID().toString();

this.flatMap(new Utils.CollectHelper<T>(id)).output(new DiscardingOutputFormat<T>());
final TypeSerializer<T> serializer = getType().createSerializer(getExecutionEnvironment().getConfig());

this.flatMap(new Utils.CollectHelper<T>(id, serializer)).output(new DiscardingOutputFormat<T>());
JobExecutionResult res = getExecutionEnvironment().execute();

ArrayList<byte[]> accResult = res.getAccumulatorResult(id);
try {
return SerializedListAccumulator.deserializeList(accResult, cl);
return SerializedListAccumulator.deserializeList(accResult, serializer);
}
catch (ClassNotFoundException e) {
throw new RuntimeException("Cannot find type class of collected data type.", e);
Expand Down
12 changes: 8 additions & 4 deletions flink-java/src/main/java/org/apache/flink/api/java/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.accumulators.SerializedListAccumulator;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;

import java.lang.reflect.Field;
Expand Down Expand Up @@ -97,21 +98,24 @@ public static class CollectHelper<T> extends RichFlatMapFunction<T, T> {
private static final long serialVersionUID = 1L;

private final String id;
private final SerializedListAccumulator<T> accumulator;
private final TypeSerializer<T> serializer;

private SerializedListAccumulator<T> accumulator;

public CollectHelper(String id) {
public CollectHelper(String id, TypeSerializer<T> serializer) {
this.id = id;
this.accumulator = new SerializedListAccumulator<T>();
this.serializer = serializer;
}

@Override
public void open(Configuration parameters) throws Exception {
this.accumulator = new SerializedListAccumulator<T>();
getRuntimeContext().addAccumulator(id, accumulator);
}

@Override
public void flatMap(T value, Collector<T> out) throws Exception {
accumulator.add(value);
accumulator.add(value, serializer);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.io.{FileOutputFormat, OutputFormat}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.Utils.CountHelper
import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
Expand Down Expand Up @@ -537,24 +538,18 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
*/
@throws(classOf[Exception])
def collect(): Seq[T] = {
val typeClass: Class[_] = getType().getTypeClass()
val cl: ClassLoader = if (typeClass.getClassLoader == null) ClassLoader.getSystemClassLoader
else typeClass.getClassLoader

if (typeClass != null && !classOf[java.io.Serializable].isAssignableFrom(typeClass)) {
throw new UnsupportedOperationException(
"collect() can only be used with serializable data types. " +
"The DataSet type '" + typeClass.getName + "' does not implement java.io.Serializable.")
}

val id = new AbstractID().toString
javaSet.flatMap(new Utils.CollectHelper[T](id)).output(new DiscardingOutputFormat[T])
val serializer = getType().createSerializer(getExecutionEnvironment.getConfig)

javaSet.flatMap(new Utils.CollectHelper[T](id, serializer))
.output(new DiscardingOutputFormat[T])

val res = getExecutionEnvironment.execute()

val accResult: java.util.ArrayList[Array[Byte]] = res.getAccumulatorResult(id)

try {
SerializedListAccumulator.deserializeList(accResult, cl).asScala
SerializedListAccumulator.deserializeList(accResult, serializer).asScala
}
catch {
case e: ClassNotFoundException => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;

import java.io.Serializable;
import java.util.Collection;

/**
Expand Down Expand Up @@ -107,7 +106,7 @@ public static void main(String[] args) throws Exception {
/**
* A simple two-dimensional point.
*/
public static class Point implements Serializable {
public static class Point {

public double x, y;

Expand Down

0 comments on commit 211d0bd

Please sign in to comment.