Skip to content

Commit

Permalink
[FLINK-2663] [gelly] Updated Gelly library methods to use generic key…
Browse files Browse the repository at this point in the history
… types

This squashes the following commits:

[gelly] Added missing Javadocs to GSA classes

[FLINK-2663] [gelly] Updated Gelly library methods to also use generic vertex/edge values where possible

This closes apache#1152
  • Loading branch information
vasia committed Oct 1, 2015
1 parent bbd9735 commit 9f71107
Show file tree
Hide file tree
Showing 24 changed files with 242 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Long map(Long value) throws Exception {
}, env);

DataSet<Vertex<Long, Long>> verticesWithMinIds = graph
.run(new GSAConnectedComponents(maxIterations)).getVertices();
.run(new GSAConnectedComponents<Long, NullValue>(maxIterations));

// emit result
if (fileOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ public Tuple2<String, Long> map(Tuple2<Long, String> tuple2) throws Exception {
public Long map(Tuple2<Long, Long> value) {
return value.f1;
}
}).run(new LabelPropagation<String>(maxIterations))
.getVertices();
}).run(new LabelPropagation<String, NullValue>(maxIterations));

if (fileOutput) {
verticesWithCommunity.writeAsCsv(communitiesOutputPath, "\n", "\t");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
import java.io.Serializable;
import java.util.Collection;

/**
* The base class for the third and last step of a {@link GatherSumApplyIteration}.
*
* @param <K> the vertex ID type
* @param <VV> the vertex value type
* @param <M> the input type (produced by the Sum phase)
*/
@SuppressWarnings("serial")
public abstract class ApplyFunction<K, VV, M> implements Serializable {

Expand All @@ -51,6 +58,14 @@ void setNumberOfVertices(long numberOfVertices) {

//---------------------------------------------------------------------------------------------

/**
* This method is invoked once per superstep, after the {@link SumFunction}
* in a {@link GatherSumApplyIteration}.
* It updates the Vertex values.
*
* @param newValue the value computed during the current superstep.
* @param currentValue the current Vertex value.
*/
public abstract void apply(M newValue, VV currentValue);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
import java.io.Serializable;
import java.util.Collection;

/**
* The base class for the first step of a {@link GatherSumApplyIteration}.
*
* @param <VV> the vertex value type
* @param <EV> the edge value type
* @param <M> the output type
*/
@SuppressWarnings("serial")
public abstract class GatherFunction<VV, EV, M> implements Serializable {

Expand All @@ -49,6 +56,15 @@ void setNumberOfVertices(long numberOfVertices) {

//---------------------------------------------------------------------------------------------

/**
* This method is invoked for each {@link Neighbor} of each Vertex,
* at the beginning of every superstep of a {@link GatherSumApplyIteration}.
* It needs to produce a partial value, which will be combined with the other partial values
* in the next phase of the iteration.
*
* @param neighbor the input Neighbor. It provides access to the source Vertex and the Edge objects.
* @return a partial result to be combined in the Sum phase.
*/
public abstract M gather(Neighbor<VV, EV> neighbor);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
import java.io.Serializable;
import java.util.Collection;

/**
* The base class for the second step of a {@link GatherSumApplyIteration}.
*
* @param <VV> the vertex value type
* @param <EV> the edge value type
* @param <M> the output type
*/
@SuppressWarnings("serial")
public abstract class SumFunction<VV, EV, M> implements Serializable {

Expand All @@ -48,7 +55,16 @@ void setNumberOfVertices(long numberOfVertices) {
}

//---------------------------------------------------------------------------------------------

/**
* This method is invoked once per superstep, after the {@link GatherFunction}
* in a {@link GatherSumApplyIteration}.
* It combines the partial values produced by {@link GatherFunction#gather(Neighbor)}
* in pairs, until a single value has been computed.
*
* @param arg0 the first partial value.
* @param arg1 the second partial value.
* @return the combined value.
*/
public abstract M sum(M arg0, M arg1);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package org.apache.flink.graph.library;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
Expand All @@ -34,18 +36,21 @@
/**
* Community Detection Algorithm.
*
* Initially, each vertex is assigned a tuple formed of its own id along with a score equal to 1.0, as value.
* This implementation expects Long Vertex values and labels. The Vertex values of the input Graph provide the initial label assignments.
*
* Initially, each vertex is assigned a tuple formed of its own initial value along with a score equal to 1.0.
* The vertices propagate their labels and max scores in iterations, each time adopting the label with the
* highest score from the list of received messages. The chosen label is afterwards re-scored using the fraction
* delta/the superstep number. Delta is passed as a parameter and has 0.5 as a default value.
*
* The algorithm converges when vertices no longer update their value or when the maximum number of iterations
* is reached.
*
* @param <K> the Vertex ID type
*
* @see <a href="https://arxiv.org/pdf/0808.2633.pdf">article explaining the algorithm in detail</a>
*/
public class CommunityDetection implements
GraphAlgorithm<Long, Long, Double, Graph<Long, Long, Double>> {
public class CommunityDetection<K> implements GraphAlgorithm<K, Long, Double, Graph<K, Long, Double>> {

private Integer maxIterations;

Expand All @@ -58,20 +63,22 @@ public CommunityDetection(Integer maxIterations, Double delta) {
}

@Override
public Graph<Long, Long, Double> run(Graph<Long, Long, Double> graph) {
public Graph<K, Long, Double> run(Graph<K, Long, Double> graph) {

Graph<Long, Long, Double> undirectedGraph = graph.getUndirected();
DataSet<Vertex<K, Tuple2<Long, Double>>> initializedVertices = graph.getVertices()
.map(new AddScoreToVertexValuesMapper<K>());

Graph<Long, Tuple2<Long, Double>, Double> graphWithScoredVertices = undirectedGraph
.mapVertices(new AddScoreToVertexValuesMapper());
Graph<K, Tuple2<Long, Double>, Double> graphWithScoredVertices =
Graph.fromDataSet(initializedVertices, graph.getEdges(), graph.getContext()).getUndirected();

return graphWithScoredVertices.runVertexCentricIteration(new VertexLabelUpdater(delta),
new LabelMessenger(), maxIterations)
.mapVertices(new RemoveScoreFromVertexValuesMapper());
return graphWithScoredVertices.runVertexCentricIteration(new VertexLabelUpdater<K>(delta),
new LabelMessenger<K>(), maxIterations)
.mapVertices(new RemoveScoreFromVertexValuesMapper<K>());
}

@SuppressWarnings("serial")
public static final class VertexLabelUpdater extends VertexUpdateFunction<Long, Tuple2<Long, Double>, Tuple2<Long, Double>> {
public static final class VertexLabelUpdater<K> extends VertexUpdateFunction<
K, Tuple2<Long, Double>, Tuple2<Long, Double>> {

private Double delta;

Expand All @@ -80,7 +87,7 @@ public VertexLabelUpdater(Double delta) {
}

@Override
public void updateVertex(Vertex<Long, Tuple2<Long, Double>> vertex,
public void updateVertex(Vertex<K, Tuple2<Long, Double>> vertex,
MessageIterator<Tuple2<Long, Double>> inMessages) throws Exception {

// we would like these two maps to be ordered
Expand Down Expand Up @@ -140,34 +147,36 @@ public void updateVertex(Vertex<Long, Tuple2<Long, Double>> vertex,
}

@SuppressWarnings("serial")
public static final class LabelMessenger extends MessagingFunction<Long, Tuple2<Long, Double>,
public static final class LabelMessenger<K> extends MessagingFunction<K, Tuple2<Long, Double>,
Tuple2<Long, Double>, Double> {

@Override
public void sendMessages(Vertex<Long, Tuple2<Long, Double>> vertex) throws Exception {
public void sendMessages(Vertex<K, Tuple2<Long, Double>> vertex) throws Exception {

for(Edge<Long, Double> edge : getEdges()) {
for(Edge<K, Double> edge : getEdges()) {
sendMessageTo(edge.getTarget(), new Tuple2<Long, Double>(vertex.getValue().f0,
vertex.getValue().f1 * edge.getValue()));
}

}
}

@SuppressWarnings("serial")
public static final class AddScoreToVertexValuesMapper implements MapFunction<Vertex<Long, Long>, Tuple2<Long, Double>> {
@ForwardedFields("f0")
public static final class AddScoreToVertexValuesMapper<K> implements MapFunction<
Vertex<K, Long>, Vertex<K, Tuple2<Long, Double>>> {

@Override
public Tuple2<Long, Double> map(Vertex<Long, Long> vertex) throws Exception {
return new Tuple2<Long, Double>(vertex.getValue(), 1.0);
public Vertex<K, Tuple2<Long, Double>> map(Vertex<K, Long> vertex) {
return new Vertex<K, Tuple2<Long, Double>>(
vertex.getId(), new Tuple2<Long, Double>(vertex.getValue(), 1.0));
}
}

@SuppressWarnings("serial")
public static final class RemoveScoreFromVertexValuesMapper implements MapFunction<Vertex<Long, Tuple2<Long, Double>>, Long> {
public static final class RemoveScoreFromVertexValuesMapper<K> implements MapFunction<
Vertex<K, Tuple2<Long, Double>>, Long> {

@Override
public Long map(Vertex<Long, Tuple2<Long, Double>> vertex) throws Exception {
public Long map(Vertex<K, Tuple2<Long, Double>> vertex) throws Exception {
return vertex.getValue().f0;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,32 @@

package org.apache.flink.graph.library;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.MessageIterator;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.graph.spargel.VertexUpdateFunction;
import org.apache.flink.graph.utils.NullValueEdgeMapper;
import org.apache.flink.types.NullValue;

/**
* A vertex-centric implementation of the Connected components algorithm.
* A vertex-centric implementation of the Connected Components algorithm.
*
* Initially, each vertex will have its own ID as a value(is its own component). The vertices propagate their
* current component ID in iterations, each time adopting a new value from the received neighbor IDs,
* This implementation assumes that the vertices of the input Graph are initialized with unique, Long component IDs.
* The vertices propagate their current component ID in iterations, each time adopting a new value from the received neighbor IDs,
* provided that the value is less than the current minimum.
*
* The algorithm converges when vertices no longer update their value or when the maximum number of iterations
* is reached.
*
* The result is a DataSet of vertices, where the vertex value corresponds to the assigned component ID.
*
* @see org.apache.flink.graph.library.GSAConnectedComponents
*/
@SuppressWarnings("serial")
public class ConnectedComponents implements
GraphAlgorithm<Long, Long, NullValue, Graph<Long, Long, NullValue>> {
public class ConnectedComponents<K, EV> implements GraphAlgorithm<K, Long, EV, DataSet<Vertex<K, Long>>> {

private Integer maxIterations;

Expand All @@ -47,21 +52,24 @@ public ConnectedComponents(Integer maxIterations) {
}

@Override
public Graph<Long, Long, NullValue> run(Graph<Long, Long, NullValue> graph) throws Exception {
public DataSet<Vertex<K, Long>> run(Graph<K, Long, EV> graph) throws Exception {

Graph<Long, Long, NullValue> undirectedGraph = graph.getUndirected();
Graph<K, Long, NullValue> undirectedGraph = graph.mapEdges(new NullValueEdgeMapper<K, EV>())
.getUndirected();

// run the Vertex Centric Iteration on the undirected graph (vertex values are expected to be initialized by the caller)
return undirectedGraph.runVertexCentricIteration(new CCUpdater(), new CCMessenger(), maxIterations);
return undirectedGraph.runVertexCentricIteration(
new CCUpdater<K>(), new CCMessenger<K>(), maxIterations)
.getVertices();
}

/**
* Updates the value of a vertex by picking the minimum neighbor ID out of all the incoming messages.
*/
public static final class CCUpdater extends VertexUpdateFunction<Long, Long, Long> {
public static final class CCUpdater<K> extends VertexUpdateFunction<K, Long, Long> {

@Override
public void updateVertex(Vertex<Long, Long> vertex, MessageIterator<Long> messages) throws Exception {
public void updateVertex(Vertex<K, Long> vertex, MessageIterator<Long> messages) throws Exception {
long min = Long.MAX_VALUE;

for (long msg : messages) {
Expand All @@ -78,10 +86,10 @@ public void updateVertex(Vertex<Long, Long> vertex, MessageIterator<Long> messag
/**
* Distributes the minimum ID associated with a given vertex among all the target vertices.
*/
public static final class CCMessenger extends MessagingFunction<Long, Long, Long, NullValue> {
public static final class CCMessenger<K> extends MessagingFunction<K, Long, Long, NullValue> {

@Override
public void sendMessages(Vertex<Long, Long> vertex) throws Exception {
public void sendMessages(Vertex<K, Long> vertex) throws Exception {
// send current minimum to neighbors
sendMessageToAllNeighbors(vertex.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@

package org.apache.flink.graph.library;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.GraphAlgorithm;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.utils.NullValueEdgeMapper;
import org.apache.flink.types.NullValue;

/**
* This is an implementation of the Connected Components algorithm, using a gather-sum-apply iteration.
* This implementation assumes that the vertices of the input Graph are initialized with unique, Long component IDs.
* The result is a DataSet of vertices, where the vertex value corresponds to the assigned component ID.
*
* @see org.apache.flink.graph.library.ConnectedComponents
*/
public class GSAConnectedComponents implements
GraphAlgorithm<Long, Long, NullValue, Graph<Long, Long, NullValue>> {
public class GSAConnectedComponents<K, EV> implements GraphAlgorithm<K, Long, EV, DataSet<Vertex<K, Long>>> {

private Integer maxIterations;

Expand All @@ -39,13 +45,15 @@ public GSAConnectedComponents(Integer maxIterations) {
}

@Override
public Graph<Long, Long, NullValue> run(Graph<Long, Long, NullValue> graph) throws Exception {
public DataSet<Vertex<K, Long>> run(Graph<K, Long, EV> graph) throws Exception {

Graph<Long, Long, NullValue> undirectedGraph = graph.getUndirected();
Graph<K, Long, NullValue> undirectedGraph = graph.mapEdges(new NullValueEdgeMapper<K, EV>())
.getUndirected();

// run the Gather-Sum-Apply Iteration on the undirected graph (vertex values are expected to be initialized by the caller)
return undirectedGraph.runGatherSumApplyIteration(new GatherNeighborIds(), new SelectMinId(), new UpdateComponentId(),
maxIterations);
return undirectedGraph.runGatherSumApplyIteration(
new GatherNeighborIds(), new SelectMinId(), new UpdateComponentId<K>(),
maxIterations).getVertices();
}

// --------------------------------------------------------------------------------------------
Expand All @@ -69,7 +77,7 @@ public Long sum(Long newValue, Long currentValue) {
};

@SuppressWarnings("serial")
private static final class UpdateComponentId extends ApplyFunction<Long, Long, Long> {
private static final class UpdateComponentId<K> extends ApplyFunction<K, Long, Long> {

public void apply(Long summedValue, Long origValue) {
if (summedValue < origValue) {
Expand Down
Loading

0 comments on commit 9f71107

Please sign in to comment.