
Commit 922db79
[FLINK-1535] [tests] Add custom accumulators and custom type collect() to classloading tests

Consolidate class loading tests into one cluster execution to validate repeated use of different class loaders.
StephanEwen committed Apr 10, 2015
1 parent 6774054 commit 922db79
Showing 7 changed files with 118 additions and 227 deletions.
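
In outline, the consolidated test starts a single ForkableFlinkMiniCluster and submits several jar-packaged programs to it in turn, so each submission exercises a fresh user-code class loader. Below is a minimal sketch of that pattern, assuming the APIs shown in the diffs that follow; the class name is hypothetical, and the per-program arguments (parallelism, KMeans data) are elided:

```java
import java.io.File;

import org.apache.flink.client.program.PackagedProgram;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;

public class RepeatedClassLoadingSketch {
    public static void main(String[] args) throws Exception {
        ForkableFlinkMiniCluster cluster = new ForkableFlinkMiniCluster(new Configuration(), false);
        try {
            int port = cluster.getJobManagerRPCPort();

            // Running several packaged programs against the same cluster
            // validates repeated dynamic class loading: one class loader per jar.
            for (String jar : new String[] {
                    "target/customsplit-test-jar.jar",
                    "target/streamingclassloader-test-jar.jar",
                    "target/kmeans-test-jar.jar" }) {
                PackagedProgram prog = new PackagedProgram(new File(jar),
                        new String[] { jar, "localhost", String.valueOf(port) });
                prog.invokeInteractiveModeForExecution();
            }
        }
        finally {
            cluster.shutdown();
        }
    }
}
```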
22 changes: 2 additions & 20 deletions flink-tests/pom.xml
@@ -339,7 +339,7 @@ under the License.
<configuration>
<archive>
<manifest>
- <mainClass>org.apache.flink.test.util.testjar.KMeansForTest</mainClass>
+ <mainClass>org.apache.flink.test.classloading.jar.KMeansForTest</mainClass>
</manifest>
</archive>
<finalName>kmeans</finalName>
@@ -397,24 +397,6 @@ under the License.
<artifactId>maven-clean-plugin</artifactId>
<version>2.5</version><!--$NO-MVN-MAN-VER$-->
<executions>
- <execution>
- <id>remove-kmeans-test-dependencies</id>
- <phase>process-test-classes</phase>
- <goals>
- <goal>clean</goal>
- </goals>
- <configuration>
- <excludeDefaultDirectories>true</excludeDefaultDirectories>
- <filesets>
- <fileset>
- <directory>${project.build.testOutputDirectory}</directory>
- <includes>
- <include>**/testjar/*.class</include>
- </includes>
- </fileset>
- </filesets>
- </configuration>
- </execution>
<execution>
<id>remove-classloading-test-dependencies</id>
<phase>process-test-classes</phase>
@@ -463,7 +445,7 @@ under the License.
</goals>
</pluginExecutionFilter>
<action>
- <ignore></ignore>
+ <ignore/>
</action>
</pluginExecution>
<pluginExecution>
3 changes: 2 additions & 1 deletion flink-tests/src/test/assembly/test-kmeans-assembly.xml
@@ -30,7 +30,8 @@ under the License.
<outputDirectory>/</outputDirectory>
<!--modify/add include to match your package(s) -->
<includes>
- <include>org/apache/flink/test/util/testjar/**</include>
+ <include>org/apache/flink/test/classloading/jar/KMeansForTest.class</include>
+ <include>org/apache/flink/test/classloading/jar/KMeansForTest$*.class</include>
</includes>
</fileSet>
</fileSets>
flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java (renamed from InputSplitClassLoaderITCase.java)
@@ -23,14 +23,19 @@
import org.apache.flink.client.program.PackagedProgram;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
+ import org.apache.flink.test.testdata.KMeansData;
import org.apache.flink.test.util.ForkableFlinkMiniCluster;

import org.junit.Assert;
import org.junit.Test;

- public class InputSplitClassLoaderITCase {
+ public class ClassLoaderITCase {

- private static final String JAR_FILE = "target/customsplit-test-jar.jar";
+ private static final String INPUT_SPLITS_PROG_JAR_FILE = "target/customsplit-test-jar.jar";
+
+ private static final String STREAMING_PROG_JAR_FILE = "target/streamingclassloader-test-jar.jar";
+
+ private static final String KMEANS_JAR_PATH = "target/kmeans-test-jar.jar";

@Test
public void testJobWithCustomInputFormat() {
@@ -40,12 +45,32 @@ public void testJobWithCustomInputFormat() {
config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 2);

ForkableFlinkMiniCluster testCluster = new ForkableFlinkMiniCluster(config, false);

try {
int port = testCluster.getJobManagerRPCPort();

- PackagedProgram prog = new PackagedProgram(new File(JAR_FILE),
- new String[] { JAR_FILE, "localhost", String.valueOf(port) } );
- prog.invokeInteractiveModeForExecution();
+ PackagedProgram inputSplitTestProg = new PackagedProgram(new File(INPUT_SPLITS_PROG_JAR_FILE),
+ new String[] { INPUT_SPLITS_PROG_JAR_FILE,
+ "localhost",
+ String.valueOf(port),
+ "4" // parallelism
+ } );
+ inputSplitTestProg.invokeInteractiveModeForExecution();
+
+ PackagedProgram streamingProg = new PackagedProgram(new File(STREAMING_PROG_JAR_FILE),
+ new String[] { STREAMING_PROG_JAR_FILE, "localhost", String.valueOf(port) } );
+ streamingProg.invokeInteractiveModeForExecution();
+
+ PackagedProgram kMeansProg = new PackagedProgram(new File(KMEANS_JAR_PATH),
+ new String[] { KMEANS_JAR_PATH,
+ "localhost",
+ String.valueOf(port),
+ "4", // parallelism
+ KMeansData.DATAPOINTS,
+ KMeansData.INITIAL_CENTERS,
+ "25"
+ } );
+ kMeansProg.invokeInteractiveModeForExecution();
}
finally {
testCluster.shutdown();
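
For context on the client API used above: PackagedProgram loads the jar's manifest main class (configured via the maven-jar-plugin change in the pom above) in a dedicated user-code class loader, and invokeInteractiveModeForExecution() calls that class's main method with the given arguments. A minimal standalone sketch, with a hypothetical jar path and port:

```java
import java.io.File;

import org.apache.flink.client.program.PackagedProgram;

public class PackagedProgramSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical path; the real test uses the jars built by the test assemblies.
        String jar = "target/example-test-jar.jar";

        // Arguments are handed straight through to the jar's main method.
        PackagedProgram program = new PackagedProgram(new File(jar),
                new String[] { jar, "localhost", "6123" });

        // Runs the manifest main class in this JVM; that class then builds a
        // remote environment and submits itself to the cluster, shipping the jar.
        program.invokeInteractiveModeForExecution();
    }
}
```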

This file was deleted.

@@ -47,8 +47,10 @@ public static void main(String[] args) throws Exception {
final String jarFile = args[0];
final String host = args[1];
final int port = Integer.parseInt(args[2]);
+ final int parallelism = Integer.parseInt(args[3]);

ExecutionEnvironment env = ExecutionEnvironment.createRemoteEnvironment(host, port, jarFile);
+ env.setParallelism(parallelism);

DataSet<Integer> data = env.createInput(new CustomInputFormat());

flink-tests/src/test/java/org/apache/flink/test/classloading/jar/KMeansForTest.java (moved from org/apache/flink/test/util/testjar/)
@@ -16,61 +16,65 @@
* limitations under the License.
*/

- package org.apache.flink.test.util.testjar;
+ package org.apache.flink.test.classloading.jar;

- import org.apache.flink.api.common.Plan;
- import org.apache.flink.api.common.Program;
+ import org.apache.flink.api.common.accumulators.Accumulator;
+ import org.apache.flink.api.common.accumulators.SimpleAccumulator;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
- import org.apache.flink.test.localDistributed.PackagedProgramEndToEndITCase;

import java.io.Serializable;
import java.util.Collection;

/**
- * This class belongs to the {@link PackagedProgramEndToEndITCase} test.
+ * This class belongs to the {@link org.apache.flink.test.classloading.ClassLoaderITCase} test.
*
- * <p> It's removed by Maven from classpath, so other tests must not depend on it.
+ * It tests dynamic class loading for:
+ * <ul>
+ * <li>Custom Functions</li>
+ * <li>Custom Data Types</li>
+ * <li>Custom Accumulators</li>
+ * <li>Custom Types in collect()</li>
+ * </ul>
+ *
+ * <p>
+ * It's removed by Maven from classpath, so other tests must not depend on it.
*/
@SuppressWarnings("serial")
- public class KMeansForTest implements Program {
+ public class KMeansForTest {

// *************************************************************************
// PROGRAM
// *************************************************************************

- @Override
- public Plan getPlan(String... args) {
- if (args.length < 4) {
+ public static void main(String[] args) throws Exception {
+ if (args.length < 7) {
throw new IllegalArgumentException("Missing parameters");
}

- final String pointsPath = args[0];
- final String centersPath = args[1];
- final String outputPath = args[2];
- final int numIterations = Integer.parseInt(args[3]);
+ final String jarFile = args[0];
+ final String host = args[1];
+ final int port = Integer.parseInt(args[2]);
+
+ final int parallelism = Integer.parseInt(args[3]);
+
+ final String pointsData = args[4];
+ final String centersData = args[5];
+ final int numIterations = Integer.parseInt(args[6]);

- ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
- env.setParallelism(4);
+ ExecutionEnvironment env = ExecutionEnvironment.createRemoteEnvironment(host, port, jarFile);
+ env.setParallelism(parallelism);

// get input data
- DataSet<Point> points = env.readCsvFile(pointsPath)
- .fieldDelimiter("|")
- .includeFields(true, true)
- .types(Double.class, Double.class)
+ DataSet<Point> points = env.fromElements(pointsData.split("\n"))
.map(new TuplePointConverter());

- DataSet<Centroid> centroids = env.readCsvFile(centersPath)
- .fieldDelimiter("|")
- .includeFields(true, true, true)
- .types(Integer.class, Double.class, Double.class)
+ DataSet<Centroid> centroids = env.fromElements(centersData.split("\n"))
.map(new TupleCentroidConverter());

// set number of bulk iterations for KMeans algorithm
@@ -79,24 +83,21 @@ public Plan getPlan(String... args) {
DataSet<Centroid> newCentroids = points
// compute closest centroid for each point
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
- // count and sum point coordinates for each centroid
+
+ // count and sum point coordinates for each centroid (test pojo return type)
.map(new CountAppender())
- // !test if key expressions are working!
+
+ // !test if key expressions are working!
.groupBy("field0").reduce(new CentroidAccumulator())
- // compute new centroids from point counts and coordinate sums
+
+ // compute new centroids from point counts and coordinate sums
.map(new CentroidAverager());

// feed new centroids back into next iteration
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

- DataSet<Tuple2<Integer, Point>> clusteredPoints = points
- // assign points to final clusters
- .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
-
- // emit result
- clusteredPoints.writeAsCsv(outputPath, "\n", " ");
-
- return env.createProgramPlan();
+ // test that custom data type collects are working
+ finalCentroids.collect();
}

// *************************************************************************
Expand Down Expand Up @@ -173,31 +174,37 @@ public String toString() {
// *************************************************************************

/** Converts a Tuple2<Double,Double> into a Point. */
- public static final class TuplePointConverter extends RichMapFunction<Tuple2<Double, Double>, Point> {
+ public static final class TuplePointConverter extends RichMapFunction<String, Point> {

@Override
- public Point map(Tuple2<Double, Double> t) throws Exception {
- return new Point(t.f0, t.f1);
+ public Point map(String str) {
+ String[] fields = str.split("\\|");
+ return new Point(Double.parseDouble(fields[1]), Double.parseDouble(fields[2]));
}
}

/** Converts a Tuple3<Integer, Double,Double> into a Centroid. */
- public static final class TupleCentroidConverter extends RichMapFunction<Tuple3<Integer, Double, Double>, Centroid> {
+ public static final class TupleCentroidConverter extends RichMapFunction<String, Centroid> {

@Override
- public Centroid map(Tuple3<Integer, Double, Double> t) throws Exception {
- return new Centroid(t.f0, t.f1, t.f2);
+ public Centroid map(String str) {
+ String[] fields = str.split("\\|");
+ return new Centroid(Integer.parseInt(fields[0]), Double.parseDouble(fields[1]), Double.parseDouble(fields[2]));
}
}

/** Determines the closest cluster center for a data point. */
public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {

private Collection<Centroid> centroids;
+ private CustomAccumulator acc;

/** Reads the centroid values from a broadcast variable into a collection. */
@Override
public void open(Configuration parameters) throws Exception {
this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
+ this.acc = new CustomAccumulator();
+ getRuntimeContext().addAccumulator("myAcc", this.acc);
}

@Override
@@ -219,6 +226,7 @@ public Tuple2<Integer, Point> map(Point p) throws Exception {
}

// emit a new record with the center id and the data point.
+ acc.add(1L);
return new Tuple2<Integer, Point>(closestCentroidId, p);
}
}
@@ -264,4 +272,36 @@ public Centroid map(DummyTuple3IntPointLong value) {
return new Centroid(value.field0, value.field1.div(value.field2));
}
}

+ public static class CustomAccumulator implements SimpleAccumulator<Long> {
+
+ private long value;
+
+ @Override
+ public void add(Long value) {
+ this.value += value;
+ }
+
+ @Override
+ public Long getLocalValue() {
+ return this.value;
+ }
+
+ @Override
+ public void resetLocal() {
+ this.value = 0L;
+ }
+
+ @Override
+ public void merge(Accumulator<Long, Long> other) {
+ this.value += other.getLocalValue();
+ }
+
+ @Override
+ public Accumulator<Long, Long> clone() {
+ CustomAccumulator acc = new CustomAccumulator();
+ acc.value = this.value;
+ return acc;
+ }
+ }
}
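
The new CustomAccumulator is registered as "myAcc" in SelectNearestCenter.open() and incremented once per map() call, exercising accumulator classes that live only in the user-code jar. As a side note, accumulator values are normally read back on the client from the JobExecutionResult; a minimal sketch under that assumption (not part of this commit; the class and job names are illustrative):

```java
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.java.ExecutionEnvironment;

public class AccumulatorReadbackSketch {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // ... build a job whose rich functions call
        // getRuntimeContext().addAccumulator("myAcc", ...) in open() ...

        // execute() returns the job result, which carries all accumulator values.
        JobExecutionResult result = env.execute("KMeans for class loading");
        long mapInvocations = result.<Long>getAccumulatorResult("myAcc");
        System.out.println("map() was invoked " + mapInvocations + " times");
    }
}
```

In the diffed program itself, finalCentroids.collect() triggers execution, so the accumulator code path runs without an explicit execute() call.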
