Skip to content

Commit

Permalink
Added API for broadcast variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
aalexandrov authored and StephanEwen committed Feb 13, 2014
1 parent 5300f4b commit 4c93530
Show file tree
Hide file tree
Showing 14 changed files with 237 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
**********************************************************************************************************************/
package eu.stratosphere.api.common.functions;

import java.util.Collection;
import java.util.HashMap;

import eu.stratosphere.api.common.accumulators.Accumulator;
import eu.stratosphere.api.common.accumulators.DoubleCounter;
import eu.stratosphere.api.common.accumulators.Histogram;
import eu.stratosphere.api.common.accumulators.IntCounter;
import eu.stratosphere.api.common.accumulators.LongCounter;
import eu.stratosphere.types.Record;

/**
*
Expand All @@ -30,6 +32,8 @@ public interface RuntimeContext {
int getNumberOfParallelSubtasks();

int getIndexOfThisSubtask();

// --------------------------------------------------------------------------------------------

/**
* Add this accumulator. Throws an exception if the counter is already
Expand All @@ -56,18 +60,23 @@ public interface RuntimeContext {
HashMap<String, Accumulator<?, ?>> getAllAccumulators();

/**
* Convenience function to create a counter object for integers. This
* creates an accumulator object for double values internally.
*
* @param name
* @return
* Convenience function to create a counter object for integers.
*/
IntCounter getIntCounter(String name);

/**
* Convenience function to create a counter object for longs.
*/
LongCounter getLongCounter(String name);

/**
* Convenience function to create a counter object for doubles.
*/
DoubleCounter getDoubleCounter(String name);

/**
* Convenience function to create a counter object for histograms.
*/
Histogram getHistogram(String name);

// /**
Expand Down Expand Up @@ -97,5 +106,18 @@ public interface RuntimeContext {
// */
// <T> SimpleAccumulator<T> getSimpleAccumulator(String name,
// Class<? extends SimpleAccumulator<T>> accumulatorClass);

// --------------------------------------------------------------------------------------------

/**
* Sets the value of the broadcast variable identified by the given
* {@code name}.
*/
void setBroadcastVariable(String name, Collection<?> value);

/**
* Returns the result bound to the broadcast variable identified by the
* given {@code name}.
*/
<RT> Collection<RT> getBroadcastVariable(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

package eu.stratosphere.api.common.operators;

import java.util.HashMap;
import java.util.Map;

import eu.stratosphere.api.common.functions.Function;
import eu.stratosphere.api.common.operators.util.UserCodeWrapper;

/**
* Abstract superclass for all contracts that represent actual Pacts.
* Abstract superclass for all contracts that represent actual operators.
*/
public abstract class AbstractUdfOperator<T extends Function> extends Operator {

Expand All @@ -26,29 +29,34 @@ public abstract class AbstractUdfOperator<T extends Function> extends Operator {
*/
protected final UserCodeWrapper<T> stub;

/**
* The extra inputs which parameterize the user function.
*/
protected final Map<String, Operator> broadcastInputs = new HashMap<String, Operator>();

// --------------------------------------------------------------------------------------------

/**
* Creates a new abstract Pact with the given name wrapping the given user function.
* Creates a new abstract operator with the given name wrapping the given user function.
*
* @param stub The object containing the user function.
* @param name The given name for the Pact, used in plans, logs and progress messages.
* @param name The given name for the operator, used in plans, logs and progress messages.
*/
protected AbstractUdfOperator(UserCodeWrapper<T> stub, String name) {
super(name);
this.stub = stub;
}

// --------------------------------------------------------------------------------------------

/**
* Gets the stub that is wrapped by this contract. The stub is the actual implementation of the
* user code.
*
* This throws an exception if the pact does not contain an object but a class for the user
* code.
*
* @return The object with the user function for this Pact.
* @return The object with the user function for this operator.
*
* @see eu.stratosphere.api.common.operators.Operator#getUserCodeWrapper()
*/
Expand All @@ -59,10 +67,53 @@ public UserCodeWrapper<T> getUserCodeWrapper() {

// --------------------------------------------------------------------------------------------

// TODO: add delegates for the parameter input setters to the Operator builders

/**
* Returns the input, or null, if none is set.
*
* @return The broadcast input root operator.
*/
public Map<String, Operator> getBroadcastInputs() {
return this.broadcastInputs;
}

/**
* Binds the result produced by a plan rooted at {@code root} to a variable
* used by the UDF wrapped in this operator.
*
* @param root The root of the plan producing this input.
*/
public void setBroadcastVariable(String name, Operator root) {
if (name == null) {
throw new IllegalArgumentException("The broadcast input name may not be null.");
}
if (root == null) {
throw new IllegalArgumentException("The broadcast input root operator may not be null.");
}

this.broadcastInputs.put(name, root);
}

/**
* Clears all previous broadcast inputs and binds the given inputs as
* broadcast variables of this operator.
*
* @param inputs The <name, root> pairs to be set as broadcast inputs.
*/
public void setBroadcastVariables(Map<String, Operator> roots) {
this.broadcastInputs.clear();
for (Map.Entry<String, Operator> e: roots.entrySet()) {
setBroadcastVariable(e.getKey(), e.getValue());
}
}

// --------------------------------------------------------------------------------------------

/**
* Gets the number of inputs for this Pact.
* Gets the number of inputs for this operator.
*
* @return The number of inputs for this Pact.
* @return The number of inputs for this operator.
*/
public abstract int getNumberOfInputs();

Expand All @@ -84,7 +135,7 @@ public UserCodeWrapper<T> getUserCodeWrapper() {
*/
protected static final <U> Class<U>[] asArray(Class<U> clazz) {
@SuppressWarnings("unchecked")
Class<U>[] array = (Class<U>[]) new Class[] { clazz };
Class<U>[] array = new Class[] { clazz };
return array;
}

Expand All @@ -96,7 +147,7 @@ protected static final <U> Class<U>[] asArray(Class<U> clazz) {
*/
protected static final <U> Class<U>[] emptyClassArray() {
@SuppressWarnings("unchecked")
Class<U>[] array = (Class<U>[]) new Class[0];
Class<U>[] array = new Class[0];
return array;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ public void accept(Visitor<Operator> visitor) {
for (Operator c : this.input2) {
c.accept(visitor);
}
for (Operator c : this.broadcastInputs.values()) {
c.accept(visitor);
}
visitor.postVisit(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ public void accept(Visitor<Operator> visitor) {
for (Operator c : this.input) {
c.accept(visitor);
}
for (Operator c : this.broadcastInputs.values()) {
c.accept(visitor);
}
visitor.postVisit(this);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public Record(Value value) {
}

/**
* Creates a new record containing exactly to fields, which are the given values.
* Creates a new record containing exactly two fields, which are the given values.
*
* @param val1 The value for the first field.
* @param val2 The value for the second field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import eu.stratosphere.api.common.operators.Operator;
import eu.stratosphere.api.common.operators.Ordering;
Expand All @@ -27,7 +29,6 @@
import eu.stratosphere.api.common.operators.util.UserCodeObjectWrapper;
import eu.stratosphere.api.common.operators.util.UserCodeWrapper;
import eu.stratosphere.api.java.record.functions.CoGroupFunction;
import eu.stratosphere.api.java.record.functions.JoinFunction;
import eu.stratosphere.types.Key;

/**
Expand Down Expand Up @@ -92,6 +93,7 @@ protected CoGroupOperator(Builder builder) {
this.keyTypes = builder.getKeyClassesArray();
setFirstInputs(builder.inputs1);
setSecondInputs(builder.inputs2);
setBroadcastVariables(builder.broadcastInputs);
setGroupOrderForInputOne(builder.secondaryOrder1);
setGroupOrderForInputTwo(builder.secondaryOrder2);
}
Expand Down Expand Up @@ -209,6 +211,7 @@ public static class Builder {
/* The optional parameters */
private List<Operator> inputs1;
private List<Operator> inputs2;
private Map<String, Operator> broadcastInputs;
private Ordering secondaryOrder1 = null;
private Ordering secondaryOrder2 = null;
private String name = DEFAULT_NAME;
Expand All @@ -233,10 +236,9 @@ protected Builder(UserCodeWrapper<CoGroupFunction> udf, Class<? extends Key> key
this.keyColumns2.add(keyColumn2);
this.inputs1 = new ArrayList<Operator>();
this.inputs2 = new ArrayList<Operator>();
this.broadcastInputs = new HashMap<String, Operator>();
}



/**
* Creates a Builder with the provided {@link JoinFunction} implementation. This method is intended
* for special case sub-types only.
Expand All @@ -250,6 +252,7 @@ protected Builder(UserCodeWrapper<CoGroupFunction> udf) {
this.keyColumns2 = new ArrayList<Integer>();
this.inputs1 = new ArrayList<Operator>();
this.inputs2 = new ArrayList<Operator>();
this.broadcastInputs = new HashMap<String, Operator>();
}

private int[] getKeyColumnsArray1() {
Expand Down Expand Up @@ -352,6 +355,24 @@ public Builder inputs2(List<Operator> inputs) {
return this;
}

/**
* Binds the result produced by a plan rooted at {@code root} to a
* variable used by the UDF wrapped in this operator.
*/
public Builder setBroadcastVariable(String name, Operator input) {
this.broadcastInputs.put(name, input);
return this;
}

/**
* Binds multiple broadcast variables.
*/
public Builder setBroadcastVariables(Map<String, Operator> inputs) {
this.broadcastInputs.clear();
this.broadcastInputs.putAll(inputs);
return this;
}

/**
* Sets the name of this contract.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package eu.stratosphere.api.java.record.operators;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import eu.stratosphere.api.common.operators.Operator;
import eu.stratosphere.api.common.operators.base.CrossOperatorBase;
Expand Down Expand Up @@ -62,6 +64,7 @@ protected CrossOperator(Builder builder) {
super(builder.udf, builder.name);
setFirstInputs(builder.inputs1);
setSecondInputs(builder.inputs2);
setBroadcastVariables(builder.broadcastInputs);
}


Expand All @@ -83,6 +86,7 @@ public static class Builder {
/* The optional parameters */
private List<Operator> inputs1;
private List<Operator> inputs2;
private Map<String, Operator> broadcastInputs;
private String name = DEFAULT_NAME;

/**
Expand All @@ -94,6 +98,7 @@ protected Builder(UserCodeWrapper<CrossFunction> udf) {
this.udf = udf;
this.inputs1 = new ArrayList<Operator>();
this.inputs2 = new ArrayList<Operator>();
this.broadcastInputs = new HashMap<String, Operator>();
}

/**
Expand Down Expand Up @@ -142,6 +147,24 @@ public Builder inputs2(List<Operator> inputs) {
return this;
}

/**
* Binds the result produced by a plan rooted at {@code root} to a
* variable used by the UDF wrapped in this operator.
*/
public Builder setBroadcastVariable(String name, Operator input) {
this.broadcastInputs.put(name, input);
return this;
}

/**
* Binds multiple broadcast variables.
*/
public Builder setBroadcastVariables(Map<String, Operator> inputs) {
this.broadcastInputs.clear();
this.broadcastInputs.putAll(inputs);
return this;
}

/**
* Sets the name of this contract.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private Builder(UserCodeWrapper<CrossFunction> udf) {
*
* @return The created contract
*/
@Override
public CrossWithLargeOperator build() {
return new CrossWithLargeOperator(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private Builder(UserCodeWrapper<CrossFunction> udf) {
*
* @return The created contract
*/
@Override
public CrossWithSmallOperator build() {
return new CrossWithSmallOperator(this);
}
Expand Down
Loading

0 comments on commit 4c93530

Please sign in to comment.