[FLINK-1167] Handle unions at the root of the iteration
This closes apache#160
StephanEwen authored and rmetzger committed Oct 18, 2014
1 parent 867e3a5 commit 259f10c
Showing 4 changed files with 144 additions and 108 deletions.
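
The pattern this commit enables is a bulk or delta iteration whose step function ends in a union. A minimal sketch of the previously unhandled shape, adapted from the tests added in this commit (env is an ExecutionEnvironment, as in those tests):

	IterativeDataSet<Long> iteration = env.generateSequence(-4, 1000).iterate(100);

	// the root of the step function is a union of two map branches
	iteration.closeWith(
			iteration.map(new IdentityMapper<Long>())
					.union(iteration.map(new IdentityMapper<Long>())))
			.print();

The fix makes the optimizer inject a no-op operator above such a union, so the root of the step function is always a real operator: the first two files below add that check, while the third and fourth adapt the JSON plan dump generator and the compiler tests.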
File 1 of 4

@@ -132,7 +132,7 @@ public void setNextPartialSolution(OptimizerNode nextPartialSolution, OptimizerN
 		// check if the root of the step function has the same DOP as the iteration
 		// or if the step function has any operator at all
 		if (nextPartialSolution.getDegreeOfParallelism() != getDegreeOfParallelism() ||
-				nextPartialSolution == partialSolution)
+				nextPartialSolution == partialSolution || nextPartialSolution instanceof BinaryUnionNode)
 		{
 			// add a no-op to the root to express the re-partitioning
 			NoOpNode noop = new NoOpNode();
File 2 of 4

@@ -160,7 +160,7 @@ public void setNextPartialSolution(OptimizerNode solutionSetDelta, OptimizerNode

 		// there needs to be at least one node in the workset path, so
 		// if the next workset is equal to the workset, we need to inject a no-op node
-		if (nextWorkset == worksetNode) {
+		if (nextWorkset == worksetNode || nextWorkset instanceof BinaryUnionNode) {
 			NoOpNode noop = new NoOpNode();
 			noop.setDegreeOfParallelism(getDegreeOfParallelism());
 
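
Both hunks above apply the same guard: when the root of the step function (the next partial solution or the next workset) is a BinaryUnionNode, a NoOpNode is injected so that the union itself never becomes the iteration root; per the comment in the first hunk, the no-op also gives the optimizer a place to express the re-partitioning. A condensed sketch of the shared guard-and-inject pattern, using only calls visible in the two hunks (the code that wires the no-op into the plan lies outside the shown context):

	// inject a placeholder operator whenever the step function's root
	// cannot serve as the iteration root directly
	if (nextWorkset == worksetNode || nextWorkset instanceof BinaryUnionNode) {
		NoOpNode noop = new NoOpNode();
		noop.setDegreeOfParallelism(getDegreeOfParallelism());
		// ... connect 'noop' between the union and the iteration tail
		// (wiring not shown in the diff above)
	}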
File 3 of 4

@@ -16,7 +16,6 @@
  * limitations under the License.
  */
 
-
 package org.apache.flink.compiler.plandump;
 
 import java.io.File;
@@ -26,7 +25,6 @@
 import java.io.StringWriter;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -47,7 +45,6 @@
 import org.apache.flink.compiler.dataproperties.LocalProperties;
 import org.apache.flink.compiler.plan.BulkIterationPlanNode;
 import org.apache.flink.compiler.plan.Channel;
-import org.apache.flink.compiler.plan.NAryUnionPlanNode;
 import org.apache.flink.compiler.plan.OptimizedPlan;
 import org.apache.flink.compiler.plan.PlanNode;
 import org.apache.flink.compiler.plan.SingleInputPlanNode;
@@ -265,121 +262,104 @@ private boolean visit(DumpableNode<?> node, PrintWriter writer, boolean first) {
 		if (inConns != null && inConns.hasNext()) {
 			// start predecessor list
 			writer.print(",\n\t\t\"predecessors\": [");
-			int connNum = 0;
 			int inputNum = 0;
 
 			while (inConns.hasNext()) {
-				final DumpableConnection<?> conn = inConns.next();
-
-				final Collection<DumpableConnection<?>> inConnsForInput;
-				if (conn.getSource() instanceof NAryUnionPlanNode) {
-					inConnsForInput = new ArrayList<DumpableConnection<?>>();
-
-					for (DumpableConnection<?> inputOfUnion : conn.getSource().getDumpableInputs()) {
-						inConnsForInput.add(inputOfUnion);
-					}
-				}
-				else {
-					inConnsForInput = Collections.<DumpableConnection<?>>singleton(conn);
-				}
-
-				for (DumpableConnection<?> inConn : inConnsForInput) {
-					final DumpableNode<?> source = inConn.getSource();
-					writer.print(connNum == 0 ? "\n" : ",\n");
-					if (connNum == 0) {
-						child1name += child1name.length() > 0 ? ", " : "";
-						child1name += source.getOptimizerNode().getPactContract().getName();
-					} else if (connNum == 1) {
-						child2name += child2name.length() > 0 ? ", " : "";
-						child2name = source.getOptimizerNode().getPactContract().getName();
-					}
-
-					// output predecessor id
-					writer.print("\t\t\t{\"id\": " + this.nodeIds.get(source));
-
-					// output connection side
-					if (inConns.hasNext() || inputNum > 0) {
-						writer.print(", \"side\": \"" + (inputNum == 0 ? "first" : "second") + "\"");
-					}
-					// output shipping strategy and channel type
-					final Channel channel = (inConn instanceof Channel) ? (Channel) inConn : null;
-					final ShipStrategyType shipType = channel != null ? channel.getShipStrategy() :
-							((PactConnection) inConn).getShipStrategy();
-
-					String shipStrategy = null;
-					if (shipType != null) {
-						switch (shipType) {
-						case NONE:
-							// nothing
-							break;
-						case FORWARD:
-							shipStrategy = "Forward";
-							break;
-						case BROADCAST:
-							shipStrategy = "Broadcast";
-							break;
-						case PARTITION_HASH:
-							shipStrategy = "Hash Partition";
-							break;
-						case PARTITION_RANGE:
-							shipStrategy = "Range Partition";
-							break;
-						case PARTITION_RANDOM:
-							shipStrategy = "Redistribute";
-							break;
-						case PARTITION_FORCED_REBALANCE:
-							shipStrategy = "Rebalance";
-							break;
-						default:
-							throw new CompilerException("Unknown ship strategy '" + conn.getShipStrategy().name()
-								+ "' in JSON generator.");
-						}
-					}
-
-					if (channel != null && channel.getShipStrategyKeys() != null && channel.getShipStrategyKeys().size() > 0) {
-						shipStrategy += " on " + (channel.getShipStrategySortOrder() == null ?
-							channel.getShipStrategyKeys().toString() :
-							Utils.createOrdering(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder()).toString());
-					}
-
-					if (shipStrategy != null) {
-						writer.print(", \"ship_strategy\": \"" + shipStrategy + "\"");
-					}
-
-					if (channel != null) {
-						String localStrategy = null;
-						switch (channel.getLocalStrategy()) {
-						case NONE:
-							break;
-						case SORT:
-							localStrategy = "Sort";
-							break;
-						case COMBININGSORT:
-							localStrategy = "Sort (combining)";
-							break;
-						default:
-							throw new CompilerException("Unknown local strategy " + channel.getLocalStrategy().name());
-						}
-
-						if (channel != null && channel.getLocalStrategyKeys() != null && channel.getLocalStrategyKeys().size() > 0) {
-							localStrategy += " on " + (channel.getLocalStrategySortOrder() == null ?
-								channel.getLocalStrategyKeys().toString() :
-								Utils.createOrdering(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder()).toString());
-						}
-
-						if (localStrategy != null) {
-							writer.print(", \"local_strategy\": \"" + localStrategy + "\"");
-						}
-
-						if (channel != null && channel.getTempMode() != TempMode.NONE) {
-							String tempMode = channel.getTempMode().toString();
-							writer.print(", \"temp_mode\": \"" + tempMode + "\"");
-						}
-					}
-
-					writer.print('}');
-					connNum++;
-				}
-
-				inputNum++;
+				final DumpableConnection<?> inConn = inConns.next();
+				final DumpableNode<?> source = inConn.getSource();
+				writer.print(inputNum == 0 ? "\n" : ",\n");
+				if (inputNum == 0) {
+					child1name += child1name.length() > 0 ? ", " : "";
+					child1name += source.getOptimizerNode().getPactContract().getName();
+				} else if (inputNum == 1) {
+					child2name += child2name.length() > 0 ? ", " : "";
+					child2name = source.getOptimizerNode().getPactContract().getName();
+				}
+
+				// output predecessor id
+				writer.print("\t\t\t{\"id\": " + this.nodeIds.get(source));
+
+				// output connection side
+				if (inConns.hasNext() || inputNum > 0) {
+					writer.print(", \"side\": \"" + (inputNum == 0 ? "first" : "second") + "\"");
+				}
+				// output shipping strategy and channel type
+				final Channel channel = (inConn instanceof Channel) ? (Channel) inConn : null;
+				final ShipStrategyType shipType = channel != null ? channel.getShipStrategy() :
+						((PactConnection) inConn).getShipStrategy();
+
+				String shipStrategy = null;
+				if (shipType != null) {
+					switch (shipType) {
+					case NONE:
+						// nothing
+						break;
+					case FORWARD:
+						shipStrategy = "Forward";
+						break;
+					case BROADCAST:
+						shipStrategy = "Broadcast";
+						break;
+					case PARTITION_HASH:
+						shipStrategy = "Hash Partition";
+						break;
+					case PARTITION_RANGE:
+						shipStrategy = "Range Partition";
+						break;
+					case PARTITION_RANDOM:
+						shipStrategy = "Redistribute";
+						break;
+					case PARTITION_FORCED_REBALANCE:
+						shipStrategy = "Rebalance";
+						break;
+					default:
+						throw new CompilerException("Unknown ship strategy '" + inConn.getShipStrategy().name()
+							+ "' in JSON generator.");
+					}
+				}
+
+				if (channel != null && channel.getShipStrategyKeys() != null && channel.getShipStrategyKeys().size() > 0) {
+					shipStrategy += " on " + (channel.getShipStrategySortOrder() == null ?
+						channel.getShipStrategyKeys().toString() :
+						Utils.createOrdering(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder()).toString());
+				}
+
+				if (shipStrategy != null) {
+					writer.print(", \"ship_strategy\": \"" + shipStrategy + "\"");
+				}
+
+				if (channel != null) {
+					String localStrategy = null;
+					switch (channel.getLocalStrategy()) {
+					case NONE:
+						break;
+					case SORT:
+						localStrategy = "Sort";
+						break;
+					case COMBININGSORT:
+						localStrategy = "Sort (combining)";
+						break;
+					default:
+						throw new CompilerException("Unknown local strategy " + channel.getLocalStrategy().name());
+					}
+
+					if (channel != null && channel.getLocalStrategyKeys() != null && channel.getLocalStrategyKeys().size() > 0) {
+						localStrategy += " on " + (channel.getLocalStrategySortOrder() == null ?
+							channel.getLocalStrategyKeys().toString() :
+							Utils.createOrdering(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder()).toString());
+					}
+
+					if (localStrategy != null) {
+						writer.print(", \"local_strategy\": \"" + localStrategy + "\"");
+					}
+
+					if (channel != null && channel.getTempMode() != TempMode.NONE) {
+						String tempMode = channel.getTempMode().toString();
+						writer.print(", \"temp_mode\": \"" + tempMode + "\"");
+					}
+				}
+
+				writer.print('}');
+				inputNum++;
 			}
 			// finish predecessors
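
For reference, the predecessor fragment that this rewritten loop emits for a two-input node has roughly the following shape (ids and strategy strings are illustrative; "side", "ship_strategy", "local_strategy", and "temp_mode" appear only when the corresponding conditions in the code hold):

	"predecessors": [
		{"id": 4, "side": "first", "ship_strategy": "Hash Partition on [0]"},
		{"id": 7, "side": "second", "ship_strategy": "Forward"}
	]

The functional change is the removal of the NAryUnionPlanNode special case: union inputs are no longer flattened into the predecessor list of the consuming node, so union nodes now appear as predecessors in their own right.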
File 4 of 4

@@ -29,8 +29,8 @@
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.compiler.CompilerTestBase;
 import org.apache.flink.compiler.plan.OptimizedPlan;
-import org.apache.flink.compiler.plandump.PlanJSONDumpGenerator;
 import org.apache.flink.compiler.plantranslate.NepheleJobGraphGenerator;
+import org.apache.flink.compiler.testfunctions.IdentityMapper;
 import org.junit.Test;
 
 @SuppressWarnings("serial")
@@ -57,7 +57,7 @@ public void testIdentityIteration() {
 	}
 
 	@Test
-	public void testIdentityWorksetIteration() {
+	public void testEmptyWorksetIteration() {
 		try {
 			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 			env.setDegreeOfParallelism(43);
@@ -76,7 +76,63 @@ public void testIdentityWorksetIteration() {
 			Plan p = env.createProgramPlan();
 			OptimizedPlan op = compileNoStats(p);
 
-			System.out.println(new PlanJSONDumpGenerator().getOptimizerPlanAsJSON(op));
 			new NepheleJobGraphGenerator().compileJobGraph(op);
 		}
 		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testIterationWithUnionRoot() {
+		try {
+			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+			env.setDegreeOfParallelism(43);
+
+			IterativeDataSet<Long> iteration = env.generateSequence(-4, 1000).iterate(100);
+
+			iteration.closeWith(
+				iteration.map(new IdentityMapper<Long>()).union(iteration.map(new IdentityMapper<Long>())))
+				.print();
+
+			Plan p = env.createProgramPlan();
+			OptimizedPlan op = compileNoStats(p);
+
+			new NepheleJobGraphGenerator().compileJobGraph(op);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testWorksetIterationWithUnionRoot() {
+		try {
+			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+			env.setDegreeOfParallelism(43);
+
+			DataSet<Tuple2<Long, Long>> input = env.generateSequence(1, 20)
+				.map(new MapFunction<Long, Tuple2<Long, Long>>() {
+					@Override
+					public Tuple2<Long, Long> map(Long value){ return null; }
+				});
+
+			DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iter = input.iterateDelta(input, 100, 0);
+			iter.closeWith(
+				iter.getWorkset().map(new IdentityMapper<Tuple2<Long,Long>>())
+					.union(
+				iter.getWorkset().map(new IdentityMapper<Tuple2<Long,Long>>()))
+				, iter.getWorkset().map(new IdentityMapper<Tuple2<Long,Long>>())
+					.union(
+				iter.getWorkset().map(new IdentityMapper<Tuple2<Long,Long>>()))
+				)
+				.print();
+
+			Plan p = env.createProgramPlan();
+			OptimizedPlan op = compileNoStats(p);
+
+			new NepheleJobGraphGenerator().compileJobGraph(op);
+		}
