diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index 296df132b89b7..5cd3bb5438d0c 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -617,7 +617,7 @@ public IterativeDataStream<OUT> iterate(long maxWaitTimeMillis) {
	 */
	public <R> SingleOutputStreamOperator<R, ?> fold(R initialValue, FoldFunction<OUT, R> folder) {
		TypeInformation<R> outType = TypeExtractor.getFoldReturnTypes(clean(folder), getType(),
-				Utils.getCallLocationName(), false);
+				Utils.getCallLocationName(), true);

		return transform("Fold", outType, new StreamFold<OUT, R>(clean(folder), initialValue, outType));
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/JSONGenerator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/JSONGenerator.java
index 4ae2f97d27547..bc20fff126776 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/JSONGenerator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/JSONGenerator.java
@@ -73,7 +73,7 @@ private void visit(JSONArray jsonArray, List<Integer> toVisit,
		node.put(PREDECESSORS, inputs);

		for (StreamEdge inEdge : vertex.getInEdges()) {
-			int inputID = inEdge.getSourceID();
+			int inputID = inEdge.getSourceId();

			Integer mappedID = (edgeRemapings.keySet().contains(inputID)) ? edgeRemapings
					.get(inputID) : inputID;
@@ -85,7 +85,7 @@ private void visit(JSONArray jsonArray, List<Integer> toVisit,
		} else {
			Integer iterationHead = -1;
			for (StreamEdge inEdge : vertex.getInEdges()) {
-				int operator = inEdge.getSourceID();
+				int operator = inEdge.getSourceId();
				if (streamGraph.vertexIDtoLoop.containsKey(operator)) {
					iterationHead = operator;
@@ -127,7 +127,7 @@ private void visitIteration(JSONArray jsonArray, List<Integer> toVisit, int head
		obj.put(PREDECESSORS, inEdges);

		for (StreamEdge inEdge : vertex.getInEdges()) {
-			int inputID = inEdge.getSourceID();
+			int inputID = inEdge.getSourceId();

			if (edgeRemapings.keySet().contains(inputID)) {
				decorateEdge(inEdges, vertexID, inputID, inputID);
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
index 3b000003c86aa..329b4dd304a52 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
@@ -422,7 +422,7 @@ public String toString() {
		builder.append("\nOutput names: " + getNonChainedOutputs(cl));
		builder.append("\nPartitioning:");
		for (StreamEdge output : getNonChainedOutputs(cl)) {
-			int outputname = output.getTargetID();
+			int outputname = output.getTargetId();
			builder.append("\n\t" + outputname + ": " + output.getPartitioner());
		}
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
index d34b21a4b1426..293f5e06b9961 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamEdge.java
@@ -68,12 +68,12 @@ public StreamNode getTargetVertex() {
		return targetVertex;
	}

-	public int getSourceID() {
-		return sourceVertex.getID();
+	public int getSourceId() {
+		return sourceVertex.getId();
	}

-	public int getTargetID() {
-		return targetVertex.getID();
+	public int getTargetId() {
+		return targetVertex.getId();
	}

	public int getTypeNumber() {
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
index 95cbc2317d663..ade1c6e4d9be0 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
@@ -232,7 +232,7 @@ public void addIterationTail(Integer sinkID, Integer iterationTail, Integer iter
		getStreamNode(sinkID).setOperatorName("IterationTail-" + iterationTail);
		iteration.getSource().setParallelism(iteration.getSink().getParallelism());

-		setBufferTimeout(iteration.getSource().getID(), getStreamNode(iterationTail)
+		setBufferTimeout(iteration.getSource().getId(), getStreamNode(iterationTail)
				.getBufferTimeout());

		if (LOG.isDebugEnabled()) {
@@ -257,8 +257,8 @@ public void addEdge(Integer upStreamVertexID, Integer downStreamVertexID,
		StreamEdge edge = new StreamEdge(getStreamNode(upStreamVertexID),
				getStreamNode(downStreamVertexID), typeNumber, outputNames, partitionerObject);

-		getStreamNode(edge.getSourceID()).addOutEdge(edge);
-		getStreamNode(edge.getTargetID()).addInEdge(edge);
+		getStreamNode(edge.getSourceId()).addOutEdge(edge);
+		getStreamNode(edge.getTargetId()).addInEdge(edge);
	}

	public void addOutputSelector(Integer vertexID, OutputSelector<?> outputSelector) {
@@ -335,7 +335,7 @@ public StreamEdge getStreamEdge(int sourceId, int targetId) {
		while (outIterator.hasNext()) {
			StreamEdge edge = outIterator.next();

-			if (edge.getTargetID() == targetId) {
+			if (edge.getTargetId() == targetId) {
				return edge;
			}
		}
@@ -354,7 +354,7 @@ public Collection<StreamNode> getStreamNodes() {
	public Set<Tuple2<Integer, StreamOperator<?>>> getOperators() {
		Set<Tuple2<Integer, StreamOperator<?>>> operatorSet = new HashSet<Tuple2<Integer, StreamOperator<?>>>();
		for (StreamNode vertex : streamNodes.values()) {
-			operatorSet.add(new Tuple2<Integer, StreamOperator<?>>(vertex.getID(), vertex
+			operatorSet.add(new Tuple2<Integer, StreamOperator<?>>(vertex.getId(), vertex
					.getOperator()));
		}
		return operatorSet;
@@ -389,7 +389,7 @@ protected void removeVertex(StreamNode toRemove) {
		for (StreamEdge edge : edgesToRemove) {
			removeEdge(edge);
		}
-		streamNodes.remove(toRemove.getID());
+		streamNodes.remove(toRemove.getId());
	}

	/**
@@ -462,7 +462,7 @@ public static enum ResourceStrategy {
	 * Object for representing loops in streaming programs.
	 *
	 */
-	protected static class StreamLoop {
+	public static class StreamLoop {

		private Integer loopID;
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
index ddc71cbc2e53c..ccca2f1b821b8 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamNode.java
@@ -42,7 +42,7 @@ public class StreamNode implements Serializable {

	transient private StreamExecutionEnvironment env;

-	private Integer ID;
+	private Integer id;
	private Integer parallelism = null;
	private Long bufferTimeout = null;
	private String operatorName;
@@ -62,11 +62,11 @@ public class StreamNode implements Serializable {

	private InputFormat inputFormat;

-	public StreamNode(StreamExecutionEnvironment env, Integer ID, StreamOperator<?> operator,
+	public StreamNode(StreamExecutionEnvironment env, Integer id, StreamOperator<?> operator,
			String operatorName, List<OutputSelector<?>> outputSelector, Class<? extends AbstractInvokable> jobVertexClass) {
		this.env = env;
-		this.ID = ID;
+		this.id = id;
		this.operatorName = operatorName;
		this.operator = operator;
		this.outputSelectors = outputSelector;
@@ -75,16 +75,16 @@ public StreamNode(StreamExecutionEnvironment env, Integer ID, StreamOperator
	}

	public void addInEdge(StreamEdge inEdge) {
-		if (inEdge.getTargetID() != getID()) {
-			throw new IllegalArgumentException("Destination ID doesn't match the StreamNode ID");
+		if (inEdge.getTargetId() != getId()) {
+			throw new IllegalArgumentException("Destination id doesn't match the StreamNode id");
		} else {
			inEdges.add(inEdge);
		}
	}

	public void addOutEdge(StreamEdge outEdge) {
-		if (outEdge.getSourceID() != getID()) {
-			throw new IllegalArgumentException("Source ID doesn't match the StreamNode ID");
+		if (outEdge.getSourceId() != getId()) {
+			throw new IllegalArgumentException("Source id doesn't match the StreamNode id");
		} else {
			outEdges.add(outEdge);
		}
@@ -102,7 +102,7 @@ public List<Integer> getOutEdgeIndices() {
		List<Integer> outEdgeIndices = new ArrayList<Integer>();

		for (StreamEdge edge : outEdges) {
-			outEdgeIndices.add(edge.getTargetID());
+			outEdgeIndices.add(edge.getTargetId());
		}

		return outEdgeIndices;
@@ -112,14 +112,14 @@ public List<Integer> getInEdgeIndices() {
		List<Integer> inEdgeIndices = new ArrayList<Integer>();

		for (StreamEdge edge : inEdges) {
-			inEdgeIndices.add(edge.getSourceID());
+			inEdgeIndices.add(edge.getSourceId());
		}

		return inEdgeIndices;
	}

-	public Integer getID() {
-		return ID;
+	public Integer getId() {
+		return id;
	}

	public int getParallelism() {
@@ -216,7 +216,7 @@ public void isolateSlot() {

	@Override
	public String toString() {
-		return operatorName + ID;
+		return operatorName + id;
	}

}
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 1670a479b1a37..9e12a682f8da3 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -109,7 +109,7 @@ private void setPhysicalEdges() {
		Map<Integer, List<StreamEdge>> physicalInEdgesInOrder = new HashMap<Integer, List<StreamEdge>>();

		for (StreamEdge edge : physicalEdgesInOrder) {
-			int target = edge.getTargetID();
+			int target = edge.getTargetId();

			List<StreamEdge> inEdges = physicalInEdgesInOrder.get(target);
@@ -154,12 +154,12 @@ private List<StreamEdge> createChain(Integer startNode, Integer current) {
			}

			for (StreamEdge chainable : chainableOutputs) {
-				transitiveOutEdges.addAll(createChain(startNode, chainable.getTargetID()));
+				transitiveOutEdges.addAll(createChain(startNode, chainable.getTargetId()));
			}

			for (StreamEdge nonChainable : nonChainableOutputs) {
				transitiveOutEdges.add(nonChainable);
-				createChain(nonChainable.getTargetID(), nonChainable.getTargetID());
+				createChain(nonChainable.getTargetId(), nonChainable.getTargetId());
			}

			chainedNames.put(current, createChainedName(current, chainableOutputs));
@@ -203,14 +203,14 @@ private String createChainedName(Integer vertexID, List<StreamEdge> chainedOutpu
		if (chainedOutputs.size() > 1) {
			List<String> outputChainedNames = new ArrayList<String>();
			for (StreamEdge chainable : chainedOutputs) {
-				outputChainedNames.add(chainedNames.get(chainable.getTargetID()));
+				outputChainedNames.add(chainedNames.get(chainable.getTargetId()));
			}
			String returnOperatorName = operatorName + " -> ("
					+ StringUtils.join(outputChainedNames, ", ") + ")";
			return returnOperatorName;
		} else if (chainedOutputs.size() == 1) {
			String returnOperatorName = operatorName + " -> "
-					+ chainedNames.get(chainedOutputs.get(0).getTargetID());
+					+ chainedNames.get(chainedOutputs.get(0).getTargetId());
			return returnOperatorName;
		} else {
			return operatorName;
@@ -281,8 +281,8 @@ private void setVertexConfig(Integer vertexID, StreamConfig config,
		allOutputs.addAll(nonChainableOutputs);

		for (StreamEdge output : allOutputs) {
-			config.setSelectedNames(output.getTargetID(),
-					streamGraph.getStreamEdge(vertexID, output.getTargetID()).getSelectedNames());
+			config.setSelectedNames(output.getTargetId(),
+					streamGraph.getStreamEdge(vertexID, output.getTargetId()).getSelectedNames());
		}

		vertexConfigs.put(vertexID, config);
@@ -292,7 +292,7 @@ private void connect(Integer headOfChain, StreamEdge edge) {

		physicalEdgesInOrder.add(edge);

-		Integer downStreamvertexID = edge.getTargetID();
+		Integer downStreamvertexID = edge.getTargetId();

		AbstractJobVertex headVertex = jobVertices.get(headOfChain);
		AbstractJobVertex downStreamVertex = jobVertices.get(downStreamvertexID);
@@ -358,8 +358,8 @@ private void setSlotSharing() {

		for (StreamLoop loop : streamGraph.getStreamLoops()) {
			CoLocationGroup ccg = new CoLocationGroup();
-			AbstractJobVertex tail = jobVertices.get(loop.getSink().getID());
-			AbstractJobVertex head = jobVertices.get(loop.getSource().getID());
+			AbstractJobVertex tail = jobVertices.get(loop.getSink().getId());
+			AbstractJobVertex head = jobVertices.get(loop.getSource().getId());
			ccg.addVertex(head);
			ccg.addVertex(tail);
			tail.updateCoLocationGroup(ccg);
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/WindowingOptimizer.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/WindowingOptimizer.java
index 92043e77f64d1..dce7684c8eda3 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/WindowingOptimizer.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/graph/WindowingOptimizer.java
@@ -52,9 +52,9 @@ private static void removeMergeBeforeFlatten(StreamGraph streamGraph) {
			}
		}

-		for (Integer flattenerID : flatteners) {
+		for (Integer flattenerId : flatteners) {
			// Flatteners should have exactly one input
-			StreamNode input = streamGraph.getStreamNode(flattenerID).getInEdges().get(0)
+			StreamNode input = streamGraph.getStreamNode(flattenerId).getInEdges().get(0)
					.getSourceVertex();

			// Check whether the flatten is applied after a merge
@@ -64,18 +64,18 @@ private static void removeMergeBeforeFlatten(StreamGraph streamGraph) {
				StreamNode mergeInput = input.getInEdges().get(0).getSourceVertex();

				// We connect the merge input to the flattener directly
-				streamGraph.addEdge(mergeInput.getID(), flattenerID,
+				streamGraph.addEdge(mergeInput.getId(), flattenerId,
						new RebalancePartitioner(true), 0, new ArrayList<String>());

				// If the merger is only connected to the flattener we delete it
				// completely, otherwise we only remove the edge
				if (input.getOutEdges().size() > 1) {
-					streamGraph.removeEdge(streamGraph.getStreamEdge(input.getID(), flattenerID));
+					streamGraph.removeEdge(streamGraph.getStreamEdge(input.getId(), flattenerId));
				} else {
					streamGraph.removeVertex(input);
				}

-				streamGraph.setParallelism(flattenerID, mergeInput.getParallelism());
+				streamGraph.setParallelism(flattenerId, mergeInput.getParallelism());
			}
		}

@@ -137,14 +137,14 @@ private static void setDiscretizerReuse(StreamGraph streamGraph) {
			if (matchList.size() > 1) {
				StreamNode first = matchList.get(0);
				for (int i = 1; i < matchList.size(); i++) {
-					replaceDiscretizer(streamGraph, matchList.get(i).getID(), first.getID());
+					replaceDiscretizer(streamGraph, matchList.get(i).getId(), first.getId());
				}
			}
		}
	}

	private static void replaceDiscretizer(StreamGraph streamGraph, Integer toReplaceID,
-			Integer replaceWithID) {
+			Integer replaceWithId) {
		// Convert to array to create a copy
		List<StreamEdge> outEdges = new ArrayList<StreamEdge>(streamGraph
				.getStreamNode(toReplaceID).getOutEdges());
@@ -155,7 +155,7 @@ private static void replaceDiscretizer(StreamGraph streamGraph, Integer toReplac
		for (int i = 0; i < numOutputs; i++) {
			StreamEdge outEdge = outEdges.get(i);

-			streamGraph.addEdge(replaceWithID, outEdge.getTargetID(), outEdge.getPartitioner(), 0,
+			streamGraph.addEdge(replaceWithId, outEdge.getTargetId(), outEdge.getPartitioner(), 0,
					new ArrayList<String>());
		}
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
index 20aeb898ba64d..2094d31efd90e 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
@@ -78,8 +78,8 @@ public OutputHandler(StreamTask vertex) {
		for (StreamEdge outEdge : outEdgesInOrder) {
			StreamOutput streamOutput = createStreamOutput(
					outEdge,
-					outEdge.getTargetID(),
-					chainedConfigs.get(outEdge.getSourceID()),
+					outEdge.getTargetId(),
+					chainedConfigs.get(outEdge.getSourceId()),
					outEdgesInOrder.indexOf(outEdge));
			outputMap.put(outEdge, streamOutput);
		}
@@ -134,7 +134,7 @@ private Output createChainedCollector(StreamConfig chainedTaskConfig) {

		// Create collectors for the chained outputs
		for (StreamEdge outputEdge : chainedTaskConfig.getChainedOutputs(cl)) {
-			Integer output = outputEdge.getTargetID();
+			Integer output = outputEdge.getTargetId();

			Collector outCollector = createChainedCollector(chainedConfigs.get(output));
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index c245f76e0ce98..58f7ebae12322 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -22,18 +22,26 @@
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;

+import java.util.List;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.FoldFunction;
+import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.datastream.ConnectedDataStream;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.DataStreamSink;
 import org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.datastream.GroupedDataStream;
+import org.apache.flink.streaming.api.datastream.IterativeDataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.datastream.SplitDataStream;
 import org.apache.flink.streaming.api.datastream.WindowedDataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.WindowMapFunction;
@@ -42,8 +50,17 @@
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.streaming.api.graph.StreamGraph.StreamLoop;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.windowing.helper.Count;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.FieldsPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.util.TestStreamEnvironment;
 import org.apache.flink.util.Collector;
 import org.junit.Test;
@@ -329,10 +346,167 @@ public CustomPOJO fold(CustomPOJO accumulator, String value) throws Exception {
		assertEquals(TypeExtractor.getForClass(CustomPOJO.class), flatten.getType());
	}

+	@Test
+	public void operatorTest() {
+		StreamExecutionEnvironment env = new TestStreamEnvironment(PARALLELISM, MEMORYSIZE);
+
+		StreamGraph streamGraph = env.getStreamGraph();
+
+		DataStreamSource<Long> src = env.generateSequence(0, 0);
+
+		MapFunction<Long, Integer> mapFunction = new MapFunction<Long, Integer>() {
+			@Override
+			public Integer map(Long value) throws Exception {
+				return null;
+			}
+		};
+		DataStream<Integer> map = src.map(mapFunction);
+		assertEquals(mapFunction, getFunctionForDataStream(map));
+
+
+		FlatMapFunction<Long, Integer> flatMapFunction = new FlatMapFunction<Long, Integer>() {
+			@Override
+			public void flatMap(Long value, Collector<Integer> out) throws Exception {
+			}
+		};
+		DataStream<Integer> flatMap = src.flatMap(flatMapFunction);
+		assertEquals(flatMapFunction, getFunctionForDataStream(flatMap));
+
+		FilterFunction<Integer> filterFunction = new FilterFunction<Integer>() {
+			@Override
+			public boolean filter(Integer value) throws Exception {
+				return false;
+			}
+		};
+
+		DataStream<Integer> unionFilter = map
+				.union(flatMap)
+				.filter(filterFunction);
+
+		assertEquals(filterFunction, getFunctionForDataStream(unionFilter));
+
+		try {
+			streamGraph.getStreamEdge(map.getId(), unionFilter.getId());
+		} catch (RuntimeException e) {
+			fail(e.getMessage());
+		}
+
+		try {
+			streamGraph.getStreamEdge(flatMap.getId(), unionFilter.getId());
+		} catch (RuntimeException e) {
+			fail(e.getMessage());
+		}
+
+		OutputSelector<Integer> outputSelector = new OutputSelector<Integer>() {
+			@Override
+			public Iterable<String> select(Integer value) {
+				return null;
+			}
+		};
+
+		SplitDataStream<Integer> split = unionFilter.split(outputSelector);
+		List<OutputSelector<?>> outputSelectors = streamGraph.getStreamNode(split.getId()).getOutputSelectors();
+		assertEquals(1, outputSelectors.size());
+		assertEquals(outputSelector, outputSelectors.get(0));
+
+		DataStream<Integer> select = split.select("a");
+		DataStreamSink<Integer> sink = select.print();
+
+		StreamEdge splitEdge = streamGraph.getStreamEdge(select.getId(), sink.getId());
+		assertEquals("a", splitEdge.getSelectedNames().get(0));
+
+		FoldFunction<Integer, String> foldFunction = new FoldFunction<Integer, String>() {
+			@Override
+			public String fold(String accumulator, Integer value) throws Exception {
+				return null;
+			}
+		};
+		DataStream<String> fold = map.fold("", foldFunction);
+		assertEquals(foldFunction, getFunctionForDataStream(fold));
+
+		ConnectedDataStream<String, Integer> connect = fold.connect(flatMap);
+		CoMapFunction<String, Integer, String> coMapper = new CoMapFunction<String, Integer, String>() {
+			@Override
+			public String map1(String value) {
+				return null;
+			}
+
+			@Override
+			public String map2(Integer value) {
+				return null;
+			}
+		};
+		DataStream<String> coMap = connect.map(coMapper);
+		assertEquals(coMapper, getFunctionForDataStream(coMap));
+
+		try {
+			streamGraph.getStreamEdge(fold.getId(), coMap.getId());
+		} catch (RuntimeException e) {
+			fail(e.getMessage());
+		}
+
+		try {
+			streamGraph.getStreamEdge(flatMap.getId(), coMap.getId());
+		} catch (RuntimeException e) {
+			fail(e.getMessage());
+		}
+	}
+
+	@Test
+	public void testChannelSelectors() {
+		StreamExecutionEnvironment env = new TestStreamEnvironment(PARALLELISM, MEMORYSIZE);
+
+		StreamGraph streamGraph = env.getStreamGraph();
+
+		DataStreamSource<Long> src = env.generateSequence(0, 0);
+
+		DataStream<Long> broadcast = src.broadcast();
+		DataStreamSink<Long> broadcastSink = broadcast.print();
+		StreamPartitioner<?> broadcastPartitioner =
+				streamGraph.getStreamEdge(broadcast.getId(), broadcastSink.getId()).getPartitioner();
+		assertTrue(broadcastPartitioner instanceof BroadcastPartitioner);
+
+		DataStream<Long> shuffle = src.shuffle();
+		DataStreamSink<Long> shuffleSink = shuffle.print();
+		StreamPartitioner<?> shufflePartitioner =
+				streamGraph.getStreamEdge(shuffle.getId(), shuffleSink.getId()).getPartitioner();
+		assertTrue(shufflePartitioner instanceof ShufflePartitioner);
+
+		DataStream<Long> forward = src.forward();
+		DataStreamSink<Long> forwardSink = forward.print();
+		StreamPartitioner<?> forwardPartitioner =
+				streamGraph.getStreamEdge(forward.getId(), forwardSink.getId()).getPartitioner();
+		assertTrue(forwardPartitioner instanceof RebalancePartitioner);
+
+		DataStream<Long> rebalance = src.rebalance();
+		DataStreamSink<Long> rebalanceSink = rebalance.print();
+		StreamPartitioner<?> rebalancePartitioner =
+				streamGraph.getStreamEdge(rebalance.getId(), rebalanceSink.getId()).getPartitioner();
+		assertTrue(rebalancePartitioner instanceof RebalancePartitioner);
+
+		DataStream<Long> global = src.global();
+		DataStreamSink<Long> globalSink = global.print();
+		StreamPartitioner<?> globalPartitioner =
+				streamGraph.getStreamEdge(global.getId(), globalSink.getId()).getPartitioner();
+		assertTrue(globalPartitioner instanceof GlobalPartitioner);
+	}
+
	/////////////////////////////////////////////////////////////
	// Utilities
	/////////////////////////////////////////////////////////////

+	private static StreamOperator<?> getOperatorForDataStream(DataStream<?> dataStream) {
+		StreamExecutionEnvironment env = dataStream.getExecutionEnvironment();
+		StreamGraph streamGraph = env.getStreamGraph();
+		return streamGraph.getStreamNode(dataStream.getId()).getOperator();
+	}
+
+	private static Function getFunctionForDataStream(DataStream<?> dataStream) {
+		AbstractUdfStreamOperator<?, ?> operator =
+				(AbstractUdfStreamOperator<?, ?>) getOperatorForDataStream(dataStream);
+		return operator.getUserFunction();
+	}
+
	private static Integer createDownStreamId(DataStream<?> dataStream) {
		return dataStream.print().getId();
	}
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java
index 2f5f30d314ada..c4a3b6973977c 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java
@@ -100,14 +100,13 @@ public void testSources() {
		StreamExecutionEnvironment env = new TestStreamEnvironment(PARALLELISM, MEMORYSIZE);

		SourceFunction<Integer> srcFun = new SourceFunction<Integer>() {
+			@Override
-			public boolean reachedEnd() throws Exception {
-				return false;
+			public void run(SourceContext<Integer> ctx) throws Exception {
			}

			@Override
-			public Integer next() throws Exception {
-				return null;
+			public void cancel() {
			}
		};
		DataStreamSource<Integer> src1 = env.addSource(srcFun);
@@ -117,7 +116,6 @@ public Integer next() throws Exception {
		DataStreamSource<Long> src2 = env.generateSequence(0, 2);
		assertTrue(getFunctionForDataSource(src2) instanceof FromIteratorFunction);
-		checkIfSameElements(list, getFunctionForDataSource(src2));

		DataStreamSource<Long> src3 = env.fromElements(0L, 1L, 2L);
		assertTrue(getFunctionForDataSource(src3) instanceof FromElementsFunction);
@@ -133,23 +131,6 @@ public Integer next() throws Exception {
	// Utilities
	/////////////////////////////////////////////////////////////

-	private static <T> void checkIfSameElements(Collection<T> collection, SourceFunction<T> sourceFunction) {
-		for (T elem : collection) {
-			try {
-				assertEquals(elem, sourceFunction.next());
-			} catch (Exception e) {
-				e.printStackTrace();
-				fail(e.getMessage());
-			}
-		}
-
-		try {
-			assertTrue(sourceFunction.reachedEnd());
-		} catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
-	}

	private static StreamOperator<?> getOperatorForDataStream(DataStream<?> dataStream) {
		StreamExecutionEnvironment env = dataStream.getExecutionEnvironment();
diff --git a/flink-staging/flink-streaming/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala b/flink-staging/flink-streaming/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
index 2f80967b4762a..5e348eb6d2f71 100644
--- a/flink-staging/flink-streaming/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
+++ b/flink-staging/flink-streaming/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
@@ -18,12 +18,18 @@
 package org.apache.flink.streaming.api.scala

+import java.lang
+
+import org.apache.flink.api.common.functions._
 import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.streaming.api.graph.{StreamEdge, StreamGraph}
+import org.apache.flink.streaming.api.collector.selector.OutputSelector
+import org.apache.flink.streaming.api.functions.co.CoMapFunction
+import org.apache.flink.streaming.api.graph.{StreamEdge, StreamGraph, StreamNode}
+import org.apache.flink.streaming.api.operators.{AbstractUdfStreamOperator, StreamOperator}
 import org.apache.flink.streaming.api.windowing.helper.Count
-import org.apache.flink.streaming.runtime.partitioner.FieldsPartitioner
+import org.apache.flink.streaming.runtime.partitioner._
 import org.apache.flink.util.Collector
-import org.junit.Assert._
+import org.junit.Assert.fail
 import org.junit.Test

 class DataStreamTest {
@@ -53,11 +59,11 @@ class DataStreamTest {
     ).name("testCoFlatMap")

     assert("testCoFlatMap" == connected.getName)

-    val fu: ((Long, Long) => Long) =
+    val func: ((Long, Long) => Long) =
       (x: Long, y: Long) => 0L

     val windowed = connected.window(Count.of(10))
-      .foldWindow(0L, fu)
+      .foldWindow(0L, func)

     windowed.name("testWindowFold")
     assert("testWindowFold" == windowed.getName)
@@ -93,10 +99,10 @@ class DataStreamTest {
     val group3 = src1.groupBy("_1")
     val group4 = src1.groupBy(x => x._1)

-    assert(isPartitioned(graph.getStreamEdge(group1.getId, createDownStreamId(group1))));
-    assert(isPartitioned(graph.getStreamEdge(group2.getId, createDownStreamId(group2))));
-    assert(isPartitioned(graph.getStreamEdge(group3.getId, createDownStreamId(group3))));
-    assert(isPartitioned(graph.getStreamEdge(group4.getId, createDownStreamId(group4))));
+    assert(isPartitioned(graph.getStreamEdge(group1.getId, createDownStreamId(group1))))
+    assert(isPartitioned(graph.getStreamEdge(group2.getId, createDownStreamId(group2))))
+    assert(isPartitioned(graph.getStreamEdge(group3.getId, createDownStreamId(group3))))
+    assert(isPartitioned(graph.getStreamEdge(group4.getId, createDownStreamId(group4))))

     //Testing ConnectedDataStream grouping
     val connectedGroup1: ConnectedDataStream[_, _] = connected.groupBy(0, 0)
@@ -238,33 +244,193 @@ class DataStreamTest {

   @Test def testTypeInfo {
-    val env: StreamExecutionEnvironment = StreamExecutionEnvironment
-      .createLocalEnvironment(parallelism)
+    val env = StreamExecutionEnvironment.createLocalEnvironment(parallelism)

     val src1: DataStream[Long] = env.generateSequence(0, 0)
-    assertEquals(TypeExtractor.getForClass(classOf[Long]), src1.getType)
+    assert(TypeExtractor.getForClass(classOf[Long]) == src1.getType)

     val map: DataStream[(Integer, String)] = src1.map(x => null)
-    assertEquals(classOf[scala.Tuple2[Integer, String]], map.getType.getTypeClass)
+    assert(classOf[scala.Tuple2[Integer, String]] == map.getType.getTypeClass)

     val window: WindowedDataStream[String] = map
       .window(Count.of(5))
       .mapWindow((x: Iterable[(Integer, String)], y: Collector[String]) => {})
-    assertEquals(TypeExtractor.getForClass(classOf[String]), window.getType)
+    assert(TypeExtractor.getForClass(classOf[String]) == window.getType)

     val flatten: DataStream[Int] = window
       .foldWindow(0,
         (accumulator: Int, value: String) => 0
       ).flatten
-    assertEquals(TypeExtractor.getForClass(classOf[Int]), flatten.getType)
+    assert(TypeExtractor.getForClass(classOf[Int]) == flatten.getType)

     // TODO check for custom case class
   }

+  @Test def operatorTest {
+    val env = StreamExecutionEnvironment.createLocalEnvironment(parallelism)
+
+    val streamGraph = env.getStreamGraph
+
+    val src = env.generateSequence(0, 0)
+
+    val mapFunction = new MapFunction[Long, Int] {
+      override def map(value: Long): Int = 0
+    };
+    val map = src.map(mapFunction)
+    assert(mapFunction == getFunctionForDataStream(map))
+    assert(getFunctionForDataStream(map.map(x => 0)).isInstanceOf[MapFunction[Int, Int]])
+
+
+    val flatMapFunction = new FlatMapFunction[Long, Int] {
+      override def flatMap(value: Long, out: Collector[Int]): Unit = {}
+    }
+    val flatMap = src.flatMap(flatMapFunction)
+    assert(flatMapFunction == getFunctionForDataStream(flatMap))
+    assert(
+      getFunctionForDataStream(flatMap
+        .flatMap((x: Int, out: Collector[Int]) => {}))
+        .isInstanceOf[FlatMapFunction[Int, Int]])
+
+    val filterFunction = new FilterFunction[Int] {
+      override def filter(value: Int): Boolean = false
+    }
+
+    val unionFilter = map.union(flatMap).filter(filterFunction)
+    assert(filterFunction == getFunctionForDataStream(unionFilter))
+    assert(
+      getFunctionForDataStream(map
+        .filter((x: Int) => true))
+        .isInstanceOf[FilterFunction[Int]])
+
+    try {
+      streamGraph.getStreamEdge(map.getId, unionFilter.getId)
+    }
+    catch {
+      case e => {
+        fail(e.getMessage)
+      }
+    }
+
+    try {
+      streamGraph.getStreamEdge(flatMap.getId, unionFilter.getId)
+    }
+    catch {
+      case e => {
+        fail(e.getMessage)
+      }
+    }
+
+    val outputSelector = new OutputSelector[Int] {
+      override def select(value: Int): lang.Iterable[String] = null
+    }
+
+    val split = unionFilter.split(outputSelector)
+    val outputSelectors = streamGraph.getStreamNode(split.getId).getOutputSelectors
+    assert(1 == outputSelectors.size)
+    assert(outputSelector == outputSelectors.get(0))
+
+    unionFilter.split(x => List("a"))
+    val moreOutputSelectors = streamGraph.getStreamNode(split.getId).getOutputSelectors
+    assert(2 == moreOutputSelectors.size)
+
+    val select = split.select("a")
+    val sink = select.print
+    val splitEdge = streamGraph.getStreamEdge(select.getId, sink.getId)
+    assert("a" == splitEdge.getSelectedNames.get(0))
+
+    val foldFunction = new FoldFunction[Int, String] {
+      override def fold(accumulator: String, value: Int): String = ""
+    }
+    val fold = map.fold("", foldFunction)
+    assert(foldFunction == getFunctionForDataStream(fold))
+    assert(
+      getFunctionForDataStream(map
+        .fold("", (x: String, y: Int) => ""))
+        .isInstanceOf[FoldFunction[Int, String]])
+
+    val connect = fold.connect(flatMap)
+
+    val coMapFunction =
+      new CoMapFunction[String, Int, String] {
+        override def map1(value: String): String = ""
+
+        override def map2(value: Int): String = ""
+      }
+    val coMap = connect.map(coMapFunction)
+    assert(coMapFunction == getFunctionForDataStream(coMap))
+
+    try {
+      streamGraph.getStreamEdge(fold.getId, coMap.getId)
+    }
+    catch {
+      case e => {
+        fail(e.getMessage)
+      }
+    }
+
+    try {
+      streamGraph.getStreamEdge(flatMap.getId, coMap.getId)
+    }
+    catch {
+      case e => {
+        fail(e.getMessage)
+      }
+    }
+  }
+
+  @Test
+  def testChannelSelectors {
+    val env = StreamExecutionEnvironment.createLocalEnvironment(parallelism)
+
+    val streamGraph = env.getStreamGraph
+    val src = env.generateSequence(0, 0)
+
+    val broadcast = src.broadcast
+    val broadcastSink = broadcast.print
+    val broadcastPartitioner = streamGraph
+      .getStreamEdge(broadcast.getId, broadcastSink.getId).getPartitioner
+    assert(broadcastPartitioner.isInstanceOf[BroadcastPartitioner[_]])
+
+    val shuffle: DataStream[Long] = src.shuffle
+    val shuffleSink = shuffle.print
+    val shufflePartitioner = streamGraph
+      .getStreamEdge(shuffle.getId, shuffleSink.getId).getPartitioner
+    assert(shufflePartitioner.isInstanceOf[ShufflePartitioner[_]])
+
+    val forward: DataStream[Long] = src.forward
+    val forwardSink = forward.print
+    val forwardPartitioner = streamGraph
+      .getStreamEdge(forward.getId, forwardSink.getId).getPartitioner
+    assert(forwardPartitioner.isInstanceOf[RebalancePartitioner[_]])
+
+    val rebalance: DataStream[Long] = src.rebalance
+    val rebalanceSink = rebalance.print
+    val rebalancePartitioner = streamGraph
+      .getStreamEdge(rebalance.getId, rebalanceSink.getId).getPartitioner
+    assert(rebalancePartitioner.isInstanceOf[RebalancePartitioner[_]])
+
+    val global: DataStream[Long] = src.global
+    val globalSink = global.print
+    val globalPartitioner = streamGraph
+      .getStreamEdge(global.getId, globalSink.getId).getPartitioner
+    assert(globalPartitioner.isInstanceOf[GlobalPartitioner[_]])
+  }
+
   /////////////////////////////////////////////////////////////
   // Utilities
   /////////////////////////////////////////////////////////////

+  private def getFunctionForDataStream(dataStream: DataStream[_]): Function = {
+    val operator = getOperatorForDataStream(dataStream)
+      .asInstanceOf[AbstractUdfStreamOperator[_, _]]
+    return operator.getUserFunction.asInstanceOf[Function]
+  }
+
+  private def getOperatorForDataStream(dataStream: DataStream[_]): StreamOperator[_] = {
+    val env = dataStream.getJavaStream.getExecutionEnvironment
+    val streamGraph: StreamGraph = env.getStreamGraph
+    streamGraph.getStreamNode(dataStream.getId).getOperator
+  }
+
   private def isPartitioned(edge: StreamEdge): Boolean = {
     return edge.getPartitioner.isInstanceOf[FieldsPartitioner[_]]
   }
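
Note (not part of the patch): below is a minimal sketch of how graph-walking code reads after this rename. The StreamGraphDump helper itself is hypothetical and assumes only the methods visible in the diff above (StreamGraph#getStreamNodes(), StreamNode#getOutEdges(), StreamEdge#getSourceId()/getTargetId()/getPartitioner()).

import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;

// Hypothetical debugging helper; this class is not added by the commit.
public final class StreamGraphDump {

	private StreamGraphDump() {
	}

	// Prints every edge of the graph as "sourceId -> targetId [partitioner]",
	// using the lower-camel-case accessors introduced by this change
	// (getSourceId()/getTargetId()/getId() replace getSourceID()/getTargetID()/getID()).
	public static void dump(StreamGraph streamGraph) {
		for (StreamNode node : streamGraph.getStreamNodes()) {
			for (StreamEdge edge : node.getOutEdges()) {
				System.out.println(edge.getSourceId() + " -> " + edge.getTargetId()
						+ " [" + edge.getPartitioner() + "]");
			}
		}
	}
}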