[FLINK-6648] [gelly] Transforms for Gelly examples
Replaces the GeneratedGraph class (which was extended by inputs) with
GraphKeyTypeTransform, which can also transform the algorithm result to
obtain consistent hash codes. This allows the removal of the case
statements in the driver checksum tests.

Float and double are now supported types.

This closes apache#4304
greghogan committed Jul 26, 2017
1 parent 3e12673 commit 8695a21
Showing 53 changed files with 2,461 additions and 1,349 deletions.
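
For context before the diff: inputs and algorithms that implement Transformable expose a list of Transform objects; the Runner configures each transform, applies them to the input graph before running the algorithm, appends their identities to the job name, and finally applies them in reverse order to the result DataSet. The minimal Java sketch below illustrates that flow. The simplified Transform/Transformable interfaces and the String placeholders standing in for Gelly's Graph and DataSet are hypothetical; only the method names (getTransformers, transformInput, transformResult, getIdentity) are taken from the diff, and this is not the actual GraphKeyTypeTransform implementation.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical, simplified stand-ins for the interfaces added by this commit.
interface Transform<I, R> {
    String getIdentity();
    I transformInput(I input);
    R transformResult(R result);
}

interface Transformable {
    List<Transform> getTransformers();
}

public class TransformFlowSketch {
    public static void main(String[] args) {
        // Hypothetical transform standing in for GraphKeyTypeTransform.
        Transform<String, String> keyTypeTransform = new Transform<String, String>() {
            @Override
            public String getIdentity() {
                return "key type";
            }

            @Override
            public String transformInput(String input) {
                return input + " [key type adjusted]";
            }

            @Override
            public String transformResult(String result) {
                return result + " [key type restored]";
            }
        };

        List<Transform> transforms = new ArrayList<>();
        transforms.add(keyTypeTransform);

        // Forward pass: transform the input before the algorithm runs,
        // mirroring the transformInput loop in Runner#run().
        String graph = "input graph";
        for (Transform transform : transforms) {
            graph = (String) transform.transformInput(graph);
        }

        // The algorithm would run here; its result is then transformed in
        // reverse order, mirroring the Collections.reverse + transformResult loop.
        String results = graph + " -> algorithm result";
        Collections.reverse(transforms);
        for (Transform transform : transforms) {
            results = (String) transform.transformResult(results);
        }

        System.out.println(results);
    }
}

In the real Runner, this forward/reverse ordering is what lets a transform adjust the key type of a generated graph on input and map the algorithm's result back, which is how the commit obtains the consistent hash codes mentioned in the message above.
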
@@ -53,6 +53,8 @@
import org.apache.flink.graph.drivers.parameter.Parameterized;
import org.apache.flink.graph.drivers.parameter.ParameterizedBase;
import org.apache.flink.graph.drivers.parameter.StringParameter;
import org.apache.flink.graph.drivers.transform.Transform;
import org.apache.flink.graph.drivers.transform.Transformable;
import org.apache.flink.runtime.util.EnvironmentInformation;
import org.apache.flink.util.InstantiationUtil;

@@ -64,6 +66,7 @@
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -127,6 +130,9 @@ public class Runner
private final StringParameter jobDetailsPath = new StringParameter(this, "__job_details_path")
.setDefaultValue(null);

private StringParameter jobName = new StringParameter(this, "__job_name")
.setDefaultValue(null);

/**
* Create an algorithm runner from the given arguments.
*
@@ -225,6 +231,21 @@ private static String getAlgorithmUsage(String algorithmName) {
.toString();
}

/**
* Configure a runtime component. Catch {@link RuntimeException} and
* re-throw with a Flink internal exception which is processed by
* CliFrontend for display to the user.
*
* @param parameterized the component to be configured
*/
private void parameterize(Parameterized parameterized) {
try {
parameterized.configure(parameters);
} catch (RuntimeException ex) {
throw new ProgramParametrizationException(ex.getMessage());
}
}

public void run() throws Exception {
// Set up the execution environment
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
@@ -235,6 +256,7 @@ public void run() throws Exception {
config.disableForceKryo();

config.setGlobalJobParameters(parameters);
parameterize(this);

// configure local parameters and throw proper exception on error
try {
@@ -275,11 +297,7 @@ public void run() throws Exception {
throw new ProgramParametrizationException("No input given");
}

try {
algorithm.configure(parameters);
} catch (RuntimeException ex) {
throw new ProgramParametrizationException(ex.getMessage());
}
parameterize(algorithm);

String inputName = parameters.get(INPUT);
Input input = inputFactory.get(inputName);
@@ -288,11 +306,7 @@
throw new ProgramParametrizationException("Unknown input type: " + inputName);
}

try {
input.configure(parameters);
} catch (RuntimeException ex) {
throw new ProgramParametrizationException(ex.getMessage());
}
parameterize(input);

// output and usage
if (!parameters.has(OUTPUT)) {
@@ -306,10 +320,29 @@
throw new ProgramParametrizationException("Unknown output type: " + outputName);
}

try {
output.configure(parameters);
} catch (RuntimeException ex) {
throw new ProgramParametrizationException(ex.getMessage());
}
parameterize(output);

// ----------------------------------------------------------------------------------------
// Create list of input and algorithm transforms
// ----------------------------------------------------------------------------------------

List<Transform> transforms = new ArrayList<>();

if (input instanceof Transformable) {
transforms.addAll(((Transformable) input).getTransformers());
}

if (algorithm instanceof Transformable) {
transforms.addAll(((Transformable) algorithm).getTransformers());
}

for (Transform transform : transforms) {
parameterize(transform);
}

// unused parameters
if (parameters.getUnrequestedParameters().size() > 0) {
throw new ProgramParametrizationException("Unrequested parameters: " + parameters.getUnrequestedParameters());
}

// ----------------------------------------------------------------------------------------
@@ -319,20 +352,55 @@ public void run() throws Exception {
// Create input
Graph graph = input.create(env);

// Transform input
for (Transform transform : transforms) {
graph = (Graph) transform.transformInput(graph);
}

// Run algorithm
DataSet results = algorithm.plan(graph);

// Output
String executionName = input.getIdentity() + " ⇨ " + algorithmName + " ⇨ " + output.getName();
String executionName = jobName.getValue() != null ? jobName.getValue() + ": " : "";

System.out.println();
executionName += input.getIdentity() + " ⇨ " + algorithmName + " ⇨ " + output.getName();

if (transforms.size() > 0) {
// append identifiers to job name
StringBuffer buffer = new StringBuffer(executionName).append(" [");

for (Transform transform : transforms) {
buffer.append(transform.getIdentity());
}

executionName = buffer.append("]").toString();
}

if (output == null) {
throw new ProgramParametrizationException("Unknown output type: " + outputName);
}

try {
output.configure(parameters);
} catch (RuntimeException ex) {
throw new ProgramParametrizationException(ex.getMessage());
}

if (results == null) {
env.execute(executionName);
} else {
output.write(executionName, System.out, results);
// Transform output if algorithm returned result DataSet
if (transforms.size() > 0) {
Collections.reverse(transforms);
for (Transform transform : transforms) {
results = (DataSet) transform.transformResult(results);
}
}

output.write(executionName.toString(), System.out, results);
}

System.out.println();
algorithm.printAnalytics(System.out);

if (jobDetailsPath.getValue() != null) {
@@ -20,6 +20,7 @@

import org.apache.flink.api.java.DataSet;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.drivers.parameter.BooleanParameter;
import org.apache.flink.graph.drivers.parameter.DoubleParameter;
import org.apache.flink.types.CopyableValue;

@@ -40,6 +41,8 @@ public class AdamicAdar<K extends CopyableValue<K>, VV, EV>
.setDefaultValue(0.0)
.setMinimumValue(0.0, true);

private BooleanParameter mirrorResults = new BooleanParameter(this, "mirror_results");

@Override
public String getShortDescription() {
return "similarity score weighted by centerpoint degree";
@@ -62,6 +65,7 @@ public DataSet plan(Graph<K, VV, EV> graph) throws Exception {
.run(new org.apache.flink.graph.library.similarity.AdamicAdar<K, VV, EV>()
.setMinimumRatio(minRatio.getValue().floatValue())
.setMinimumScore(minScore.getValue().floatValue())
.setMirrorResults(mirrorResults.getValue())
.setParallelism(parallelism.getValue().intValue()));
}
}
@@ -39,11 +39,6 @@ public abstract class DriverBase<K, VV, EV>
protected LongParameter parallelism = new LongParameter(this, "__parallelism")
.setDefaultValue(PARALLELISM_DEFAULT);

@Override
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public void printAnalytics(PrintStream out) {
// analytics are optionally executed by drivers overriding this method
@@ -53,6 +53,8 @@ public class TriangleListing<K extends Comparable<K> & CopyableValue<K>, VV, EV>

private BooleanParameter computeTriadicCensus = new BooleanParameter(this, "triadic_census");

private BooleanParameter permuteResults = new BooleanParameter(this, "permute_results");

private GraphAnalytic<K, VV, EV, ? extends PrintableResult> triadicCensus;

@Override
@@ -86,6 +88,7 @@ public DataSet plan(Graph<K, VV, EV> graph) throws Exception {
@SuppressWarnings("unchecked")
DataSet<PrintableResult> directedResult = (DataSet<PrintableResult>) (DataSet<?>) graph
.run(new org.apache.flink.graph.library.clustering.directed.TriangleListing<K, VV, EV>()
.setPermuteResults(permuteResults.getValue())
.setSortTriangleVertices(sortTriangleVertices.getValue())
.setParallelism(parallelism));
return directedResult;
@@ -100,6 +103,7 @@ public DataSet plan(Graph<K, VV, EV> graph) throws Exception {
@SuppressWarnings("unchecked")
DataSet<PrintableResult> undirectedResult = (DataSet<PrintableResult>) (DataSet<?>) graph
.run(new org.apache.flink.graph.library.clustering.undirected.TriangleListing<K, VV, EV>()
.setPermuteResults(permuteResults.getValue())
.setSortTriangleVertices(sortTriangleVertices.getValue())
.setParallelism(parallelism));
return undirectedResult;
@@ -96,7 +96,7 @@ public void configure(ParameterTool parameterTool) throws ProgramParametrization

@Override
public String getIdentity() {
return getTypeName() + " " + getName() + " (" + offsetRanges + ")";
return getName() + " (" + offsetRanges + ")";
}

@Override
Expand All @@ -105,7 +105,7 @@ protected long vertexCount() {
}

@Override
public Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) {
public Graph<LongValue, NullValue, NullValue> create(ExecutionEnvironment env) {
org.apache.flink.graph.generator.CirculantGraph graph = new org.apache.flink.graph.generator.CirculantGraph(env,
vertexCount.getValue());

@@ -30,14 +30,14 @@
* Generate a {@link org.apache.flink.graph.generator.CompleteGraph}.
*/
public class CompleteGraph
extends GeneratedGraph<LongValue> {
extends GeneratedGraph {

private LongParameter vertexCount = new LongParameter(this, "vertex_count")
.setMinimumValue(MINIMUM_VERTEX_COUNT);

@Override
public String getIdentity() {
return getTypeName() + " " + getName() + " (" + vertexCount.getValue() + ")";
return getName() + " (" + vertexCount.getValue() + ")";
}

@Override
@@ -46,7 +46,7 @@ protected long vertexCount() {
}

@Override
protected Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) throws Exception {
public Graph<LongValue, NullValue, NullValue> create(ExecutionEnvironment env) throws Exception {
return new org.apache.flink.graph.generator.CompleteGraph(env, vertexCount.getValue())
.setParallelism(parallelism.getValue().intValue())
.generate();
@@ -30,14 +30,14 @@
* Generate a {@link org.apache.flink.graph.generator.CycleGraph}.
*/
public class CycleGraph
extends GeneratedGraph<LongValue> {
extends GeneratedGraph {

private LongParameter vertexCount = new LongParameter(this, "vertex_count")
.setMinimumValue(MINIMUM_VERTEX_COUNT);

@Override
public String getIdentity() {
return getTypeName() + " " + getName() + " (" + vertexCount + ")";
return getName() + " (" + vertexCount + ")";
}

@Override
@@ -46,7 +46,7 @@ protected long vertexCount() {
}

@Override
public Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) {
public Graph<LongValue, NullValue, NullValue> create(ExecutionEnvironment env) {
return new org.apache.flink.graph.generator.CycleGraph(env, vertexCount.getValue())
.setParallelism(parallelism.getValue().intValue())
.generate();
@@ -41,7 +41,7 @@ public class EchoGraph

@Override
public String getIdentity() {
return getTypeName() + " " + getName() + " (" + vertexCount.getValue() + ":" + vertexDegree.getValue() + ")";
return getName() + " (" + vertexCount.getValue() + ":" + vertexDegree.getValue() + ")";
}

@Override
@@ -50,7 +50,7 @@ protected long vertexCount() {
}

@Override
protected Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) throws Exception {
public Graph<LongValue, NullValue, NullValue> create(ExecutionEnvironment env) throws Exception {
return new org.apache.flink.graph.generator.EchoGraph(env, vertexCount.getValue(), vertexDegree.getValue())
.setParallelism(parallelism.getValue().intValue())
.generate();
@@ -30,14 +30,14 @@
* Generate an {@link org.apache.flink.graph.generator.EmptyGraph}.
*/
public class EmptyGraph
extends GeneratedGraph<LongValue> {
extends GeneratedGraph {

private LongParameter vertexCount = new LongParameter(this, "vertex_count")
.setMinimumValue(MINIMUM_VERTEX_COUNT);

@Override
public String getIdentity() {
return getTypeName() + " " + getName() + " (" + vertexCount + ")";
return getName() + " (" + vertexCount + ")";
}

@Override
@@ -46,7 +46,7 @@ protected long vertexCount() {
}

@Override
public Graph<LongValue, NullValue, NullValue> generate(ExecutionEnvironment env) {
public Graph<LongValue, NullValue, NullValue> create(ExecutionEnvironment env) {
return new org.apache.flink.graph.generator.EmptyGraph(env, vertexCount.getValue())
.setParallelism(parallelism.getValue().intValue())
.generate();