Skip to content

Commit

Permalink
Merge broadcast variable runtime.
Browse files Browse the repository at this point in the history
Extend runtime for iterative algorithms.
Add iterative kmeans test for runtime code.
  • Loading branch information
StephanEwen committed Feb 13, 2014
1 parent 4353134 commit 90846d7
Show file tree
Hide file tree
Showing 28 changed files with 772 additions and 427 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
public abstract class AbstractFunction implements Function, Serializable {

private static final long serialVersionUID = 1L;

// --------------------------------------------------------------------------------------------
// Runtime context access
// --------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
import eu.stratosphere.api.common.operators.BulkIteration;
import eu.stratosphere.api.common.operators.FileDataSink;
import eu.stratosphere.api.common.operators.FileDataSource;
import eu.stratosphere.api.java.record.io.CsvInputFormat;
import eu.stratosphere.api.java.record.operators.MapOperator;
import eu.stratosphere.api.java.record.operators.ReduceOperator;
import eu.stratosphere.example.java.record.kmeans.KMeansSingleStep.PointBuilder;
import eu.stratosphere.example.java.record.kmeans.KMeansSingleStep.RecomputeClusterCenter;
import eu.stratosphere.example.java.record.kmeans.KMeansSingleStep.SelectNearestCenter;
import eu.stratosphere.example.java.record.kmeans.udfs.PointInFormat;
import eu.stratosphere.example.java.record.kmeans.udfs.PointOutFormat;
import eu.stratosphere.types.DoubleValue;
import eu.stratosphere.types.IntValue;


public class KMeansIterative implements Program, ProgramDescription {

private static final long serialVersionUID = 1L;

@Override
public Plan getPlan(String... args) {
Expand All @@ -42,10 +45,16 @@ public Plan getPlan(String... args) {
int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 2);

// create DataSourceContract for data point input
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points");
@SuppressWarnings("unchecked")
FileDataSource pointsSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), dataPointInput, "Data Points");

// create DataSourceContract for cluster center input
FileDataSource clusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers");
@SuppressWarnings("unchecked")
FileDataSource clustersSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), clusterInput, "Centers");

MapOperator dataPoints = MapOperator.builder(new PointBuilder()).name("Build data points").input(pointsSource).build();

MapOperator clusterPoints = MapOperator.builder(new PointBuilder()).name("Build cluster points").input(clustersSource).build();

BulkIteration iter = new BulkIteration("k-means loop");
iter.setInput(clusterPoints);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
import eu.stratosphere.api.common.operators.FileDataSource;
import eu.stratosphere.api.java.record.functions.MapFunction;
import eu.stratosphere.api.java.record.functions.ReduceFunction;
import eu.stratosphere.api.java.record.io.CsvInputFormat;
import eu.stratosphere.api.java.record.io.DelimitedOutputFormat;
import eu.stratosphere.api.java.record.operators.MapOperator;
import eu.stratosphere.api.java.record.operators.ReduceOperator;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.example.java.record.kmeans.udfs.PointInFormat;
import eu.stratosphere.example.java.record.kmeans.udfs.PointOutFormat;
import eu.stratosphere.types.DoubleValue;
import eu.stratosphere.types.IntValue;
import eu.stratosphere.types.Record;
import eu.stratosphere.types.Value;
Expand All @@ -43,6 +44,7 @@

public class KMeansSingleStep implements Program, ProgramDescription {

private static final long serialVersionUID = 1L;

@Override
public Plan getPlan(String... args) {
Expand All @@ -53,10 +55,16 @@ public Plan getPlan(String... args) {
String output = (args.length > 3 ? args[3] : "");

// create DataSourceContract for data point input
FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points");
@SuppressWarnings("unchecked")
FileDataSource pointsSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), dataPointInput, "Data Points");

// create DataSourceContract for cluster center input
FileDataSource clusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers");
@SuppressWarnings("unchecked")
FileDataSource clustersSource = new FileDataSource(new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class), clusterInput, "Centers");

MapOperator dataPoints = MapOperator.builder(new PointBuilder()).name("Build data points").input(pointsSource).build();

MapOperator clusterPoints = MapOperator.builder(new PointBuilder()).name("Build cluster points").input(clustersSource).build();

// create CrossOperator for distance computation
MapOperator findNearestClusterCenters = MapOperator.builder(new SelectNearestCenter())
Expand Down Expand Up @@ -175,6 +183,7 @@ public static final class SelectNearestCenter extends MapFunction implements Ser
public void open(Configuration parameters) throws Exception {
Collection<Record> clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers");

centers.clear();
for (Record r : clusterCenters) {
centers.add(new PointWithId(r.getField(0, IntValue.class).getValue(), r.getField(1, Point.class)));
}
Expand All @@ -190,7 +199,7 @@ public void open(Configuration parameters) throws Exception {
*/
@Override
public void map(Record dataPointRecord, Collector<Record> out) {
Point p = dataPointRecord.getField(0, Point.class);
Point p = dataPointRecord.getField(1, Point.class);

double nearestDistance = Double.MAX_VALUE;
int centerId = -1;
Expand All @@ -202,6 +211,7 @@ public void map(Record dataPointRecord, Collector<Record> out) {

// update nearest cluster if necessary
if (distance < nearestDistance) {
nearestDistance = distance;
centerId = center.id;
}
}
Expand Down Expand Up @@ -257,4 +267,40 @@ private final Record sumPointsAndCount(Iterator<Record> dataPoints) {
return next;
}
}

public static final class PointBuilder extends MapFunction {

private static final long serialVersionUID = 1L;

@Override
public void map(Record record, Collector<Record> out) throws Exception {
double x = record.getField(1, DoubleValue.class).getValue();
double y = record.getField(2, DoubleValue.class).getValue();
double z = record.getField(3, DoubleValue.class).getValue();

record.setField(1, new Point(x, y, z));
out.collect(record);
}
}

public static final class PointOutFormat extends DelimitedOutputFormat {

private static final long serialVersionUID = 1L;

private static final String format = "%d|%.1f|%.1f|%.1f|";

@Override
public int serializeRecord(Record rec, byte[] target) throws Exception {
int id = rec.getField(0, IntValue.class).getValue();
Point p = rec.getField(1, Point.class);

byte[] bytes = String.format(format, id, p.x, p.y, p.z).getBytes();
if (bytes.length > target.length) {
return -bytes.length;
} else {
System.arraycopy(bytes, 0, target, 0, bytes.length);
return bytes.length;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@
import org.apache.commons.logging.LogFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* The base class for all tasks able to participate in an iteration.
Expand All @@ -54,30 +51,28 @@ public abstract class AbstractIterativePactTask<S extends Function, OT> extends
implements Terminable
{
private static final Log log = LogFactory.getLog(AbstractIterativePactTask.class);


private final AtomicBoolean terminationRequested = new AtomicBoolean(false);

private RuntimeAggregatorRegistry iterationAggregators;

private List<Integer> iterativeInputs = new ArrayList<Integer>();

private String brokerKey;
protected LongSumAggregator worksetAggregator;

private int superstepNum = 1;
protected BlockingBackChannel worksetBackChannel;

protected boolean isWorksetIteration;

protected boolean isWorksetUpdate;

protected boolean isSolutionSetUpdate;


protected LongSumAggregator worksetAggregator;
private RuntimeAggregatorRegistry iterationAggregators;

protected BlockingBackChannel worksetBackChannel;
private String brokerKey;

private int superstepNum = 1;

private volatile boolean terminationRequested;

// --------------------------------------------------------------------------------------------
// Wrapping methods to supplement behavior of the regular Pact Task
// Main life cycle methods that implement the iterative behavior
// --------------------------------------------------------------------------------------------

@Override
Expand All @@ -93,8 +88,6 @@ protected void initialize() throws Exception {
excludeFromReset(i);
}
}
// initialize the repeatable driver
resDriver.initialize();
}

TaskConfig config = getLastTasksConfig();
Expand All @@ -118,9 +111,20 @@ protected void initialize() throws Exception {

@Override
public void run() throws Exception {
if (!inFirstIteration()) {
if (inFirstIteration()) {
if (this.driver instanceof ResettablePactDriver) {
// initialize the repeatable driver
((ResettablePactDriver<?, ?>) this.driver).initialize();
}
} else {
reinstantiateDriver();
resetAllInputs();

// re-read the iterative broadcast variables
for (int i : this.iterativeBroadcastInputs) {
final String name = getTaskConfig().getBroadcastInputName(i);
readAndSetBroadcastInput(i, name, this.runtimeUdfContext);
}
}

// call the parent to execute the superstep
Expand All @@ -138,35 +142,14 @@ protected void closeLocalStrategiesAndCaches() {
try {
resDriver.teardown();
} catch (Throwable t) {
log.error("Error shutting down a resettable driver.", t);
log.error("Error while shutting down an iterative operator.", t);
}
}
}
}

@Override
protected MutableObjectIterator<?> createInputIterator(int i, MutableReader<?> inputReader, TypeSerializer<?> serializer) {

final MutableObjectIterator<?> inIter = super.createInputIterator(i, inputReader, serializer);
final int numberOfEventsUntilInterrupt = getTaskConfig().getNumberOfEventsUntilInterruptInIterativeGate(i);

if (numberOfEventsUntilInterrupt < 0) {
throw new IllegalArgumentException();
}
else if (numberOfEventsUntilInterrupt > 0) {
inputReader.setIterative(numberOfEventsUntilInterrupt);
this.iterativeInputs.add(i);

if (log.isDebugEnabled()) {
log.debug(formatLogString("Input [" + i + "] reads in supersteps with [" +
+ numberOfEventsUntilInterrupt + "] event(s) till next superstep."));
}
}
return inIter;
}

@Override
public RuntimeUDFContext getRuntimeContext(String taskName) {
public RuntimeUDFContext createRuntimeContext(String taskName) {
Environment env = getEnvironment();
return new IterativeRuntimeUdfContext(taskName, env.getCurrentNumberOfSubtasks(), env.getIndexInSubtaskGroup());
}
Expand Down Expand Up @@ -223,7 +206,7 @@ public RuntimeAggregatorRegistry getIterationAggregators() {

protected void checkForTerminationAndResetEndOfSuperstepState() throws IOException {
// sanity check that there is at least one iterative input reader
if (this.iterativeInputs.isEmpty())
if (this.iterativeInputs.length == 0 && this.iterativeBroadcastInputs.length == 0)
throw new IllegalStateException();

// check whether this step ended due to end-of-superstep, or proper close
Expand Down Expand Up @@ -262,6 +245,23 @@ protected void checkForTerminationAndResetEndOfSuperstepState() throws IOExcepti
}
}
}

for (int inputNum : this.iterativeBroadcastInputs) {
MutableReader<?> reader = this.broadcastInputReaders[inputNum];

if (reader.isInputClosed()) {
anyClosed = true;
}
else {
// sanity check that the BC input is at the end of teh superstep
if (!reader.hasReachedEndOfSuperstep()) {
throw new IllegalStateException("An iterative broadcast input has not been fully consumed.");
}

allClosed = false;
reader.startNextSuperstep();
}
}

// sanity check whether we saw the same state (end-of-superstep or termination) on all inputs
if (allClosed != anyClosed) {
Expand All @@ -275,12 +275,12 @@ protected void checkForTerminationAndResetEndOfSuperstepState() throws IOExcepti

@Override
public boolean terminationRequested() {
return this.terminationRequested.get();
return this.terminationRequested;
}

@Override
public void requestTermination() {
this.terminationRequested.set(true);
this.terminationRequested = true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ public void run() throws Exception {
solutionSet = initHashTable();

// read the initial solution set
@SuppressWarnings("unchecked")
MutableObjectIterator<X> solutionSetInput = (MutableObjectIterator<X>) createInputIterator(
initialSolutionSetInput, inputReaders[initialSolutionSetInput], solutionTypeSerializer);
// @SuppressWarnings("unchecked")
// MutableObjectIterator<X> solutionSetInput = (MutableObjectIterator<X>) createInputIterator(
// initialSolutionSetInput, inputReaders[initialSolutionSetInput], solutionTypeSerializer);
MutableObjectIterator<X> solutionSetInput = getInput(initialSolutionSetInput);
readInitialSolutionSet(solutionSet, solutionSetInput);

SolutionSetBroker.instance().handIn(brokerKey, solutionSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
protected volatile MutableHashTable<?, ?> hashJoin;


protected abstract int getBuildSideIndex();
private final int buildSideIndex;

protected abstract int getProbeSideIndex();
private final int probeSideIndex;

protected AbstractCachedBuildSideMatchDriver(int buildSideIndex, int probeSideIndex) {
this.buildSideIndex = buildSideIndex;
this.probeSideIndex = probeSideIndex;
}

// --------------------------------------------------------------------------------------------

Expand All @@ -45,7 +50,7 @@ public boolean isInputResettable(int inputNum) {
if (inputNum < 0 || inputNum > 1) {
throw new IndexOutOfBoundsException();
}
return inputNum == getBuildSideIndex();
return inputNum == buildSideIndex;
}

@Override
Expand All @@ -66,12 +71,12 @@ public void initialize() throws Exception {
List<MemorySegment> memSegments = this.taskContext.getMemoryManager().allocatePages(
this.taskContext.getOwningNepheleTask(), numMemoryPages);

if (getBuildSideIndex() == 0 && getProbeSideIndex() == 1) {
if (buildSideIndex == 0 && probeSideIndex == 1) {
MutableHashTable<IT1, IT2> hashJoin = new MutableHashTable<IT1, IT2>(serializer1, serializer2, comparator1, comparator2,
pairComparatorFactory.createComparator21(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
this.hashJoin = hashJoin;
hashJoin.open(input1, EmptyMutableObjectIterator.<IT2>get());
} else if (getBuildSideIndex() == 1 && getProbeSideIndex() == 0) {
} else if (buildSideIndex == 1 && probeSideIndex == 0) {
MutableHashTable<IT2, IT1> hashJoin = new MutableHashTable<IT2, IT1>(serializer2, serializer1, comparator2, comparator1,
pairComparatorFactory.createComparator12(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
this.hashJoin = hashJoin;
Expand All @@ -92,7 +97,7 @@ public void run() throws Exception {
final GenericJoiner<IT1, IT2, OT> matchStub = this.taskContext.getStub();
final Collector<OT> collector = this.taskContext.getOutputCollector();

if (getBuildSideIndex() == 0) {
if (buildSideIndex == 0) {
final TypeSerializer<IT1> buildSideSerializer = taskContext.<IT1> getInputSerializer(0);
final TypeSerializer<IT2> probeSideSerializer = taskContext.<IT2> getInputSerializer(1);

Expand All @@ -117,7 +122,7 @@ public void run() throws Exception {
matchStub.join(buildSideRecordFirst, probeSideRecord, collector);
}
}
} else if (getBuildSideIndex() == 1) {
} else if (buildSideIndex == 1) {
final TypeSerializer<IT2> buildSideSerializer = taskContext.<IT2> getInputSerializer(1);
final TypeSerializer<IT1> probeSideSerializer = taskContext.<IT1> getInputSerializer(0);

Expand Down
Loading

0 comments on commit 90846d7

Please sign in to comment.