Skip to content

Commit

Permalink
[FLINK-1110] Add Collection-Based execution for Reduce Operators
Browse files Browse the repository at this point in the history
Also fix some bugs resulting from moving stuff between packages
  • Loading branch information
aljoscha authored and StephanEwen committed Oct 3, 2014
1 parent 114af5a commit fd3f5c2
Show file tree
Hide file tree
Showing 10 changed files with 692 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@
package org.apache.flink.api.common.operators.base;


import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
* @see org.apache.flink.api.common.functions.GroupReduceFunction
Expand Down Expand Up @@ -56,7 +68,7 @@ public GroupReduceOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation
public GroupReduceOperatorBase(UserCodeWrapper<FT> udf, UnaryOperatorInformation<IN, OUT> operatorInfo, String name) {
// The user function arrives already wrapped, so the wrapper, the operator information and the
// operator name are handed to the superclass constructor unchanged.
super(udf, operatorInfo, name);
}

/**
 * Creates a group-reduce operator from a user function object.
 *
 * The function object is wrapped in a {@link UserCodeObjectWrapper} before it is passed,
 * together with the operator information and the name, to the superclass constructor.
 *
 * @param udf The user function object.
 * @param operatorInfo The input and output type information of the operator.
 * @param name The name of the operator.
 */
public GroupReduceOperatorBase(FT udf, UnaryOperatorInformation<IN, OUT> operatorInfo, String name) {
super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, name);
}
Expand Down Expand Up @@ -115,4 +127,49 @@ public boolean isCombinable() {
return this.combinable;
}

// --------------------------------------------------------------------------------------------

/**
 * Executes this group-reduce operator on an in-memory list.
 *
 * The input is sorted on the key columns so that elements with equal keys become adjacent,
 * and the user function's {@code reduce} method is then invoked once per key group. The
 * sorting is done on a private copy, so the list passed in by the caller is neither
 * reordered nor required to be modifiable.
 *
 * @param inputData The input elements; left untouched by this method.
 * @param ctx The runtime context that is set on the user function before it is opened.
 * @return The elements emitted by the user function, in emission order.
 * @throws InvalidProgramException If the input type is not a composite type.
 * @throws Exception Any exception thrown while opening, invoking or closing the user function.
 */
@Override
protected List<OUT> executeOnCollections(List<IN> inputData, RuntimeContext ctx)
		throws Exception {
	GroupReduceFunction<IN, OUT> function = this.userFunction.getUserCodeObject();

	UnaryOperatorInformation<IN, OUT> operatorInfo = getOperatorInfo();
	TypeInformation<IN> inputType = operatorInfo.getInputType();

	// The key comparator is created by the composite type, so other types cannot be grouped here.
	if (!(inputType instanceof CompositeType)) {
		throw new InvalidProgramException("Input type of groupReduce operation must be" +
				" composite type.");
	}

	int[] inputColumns = getKeyColumns(0);
	// One ordering flag per key column, all left at the default value (false).
	boolean[] inputOrderings = new boolean[inputColumns.length];
	final TypeComparator<IN> inputComparator =
			((CompositeType<IN>) inputType).createComparator(inputColumns, inputOrderings);

	FunctionUtils.setFunctionRuntimeContext(function, ctx);
	FunctionUtils.openFunction(function, this.parameters);

	ArrayList<OUT> result = new ArrayList<OUT>(inputData.size());
	ListCollector<OUT> collector = new ListCollector<OUT>(result);

	// Sort a copy: the caller's list may be shared with other consumers or be unmodifiable.
	List<IN> sortedInput = new ArrayList<IN>(inputData);
	sortedInput.sort(new Comparator<IN>() {
		@Override
		public int compare(IN o1, IN o2) {
			// The comparator result is negated, exactly as before; grouping only needs
			// elements with equal keys to end up next to each other.
			return - inputComparator.compare(o1, o2);
		}
	});

	ListKeyGroupedIterator<IN> keyedIterator =
			new ListKeyGroupedIterator<IN>(sortedInput, inputComparator);

	// One reduce() call per run of elements with equal keys.
	while (keyedIterator.nextKey()) {
		function.reduce(keyedIterator.getValues(), collector);
	}

	FunctionUtils.closeFunction(function);

	return result;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@

package org.apache.flink.api.common.operators.base;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.TypeComparable;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
Expand Down Expand Up @@ -88,7 +100,7 @@ public ReduceOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<T, T
public ReduceOperatorBase(UserCodeWrapper<FT> udf, UnaryOperatorInformation<T, T> operatorInfo, String name) {
// The user function arrives already wrapped, so the wrapper, the operator information and the
// operator name are handed to the superclass constructor unchanged.
super(udf, operatorInfo, name);
}

/**
* Creates a non-grouped reduce data flow operator (all-reduce).
*
Expand All @@ -110,4 +122,61 @@ public ReduceOperatorBase(FT udf, UnaryOperatorInformation<T, T> operatorInfo, S
public ReduceOperatorBase(Class<? extends FT> udf, UnaryOperatorInformation<T, T> operatorInfo, String name) {
// Only the class of the user function is given; it is wrapped in a UserCodeClassWrapper
// before being passed to the superclass constructor with the operator information and name.
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, name);
}

// --------------------------------------------------------------------------------------------

/**
 * Executes this reduce operator on an in-memory list.
 *
 * If key columns are defined, the elements are grouped by key and each group is combined
 * into a single element with the user function; the result holds one element per key.
 * Without key columns (all-reduce), all elements are combined into one single element.
 * An empty input produces an empty result in both cases.
 *
 * The user function is opened before the first element is processed and closed again
 * before the result is returned, for the grouped as well as for the non-grouped case.
 *
 * @param inputData The input elements.
 * @param ctx The runtime context that is set on the user function before it is opened.
 * @return The reduced elements: one per key when grouped, at most one otherwise.
 * @throws InvalidProgramException If key columns are defined but the input type is not a
 *         composite type.
 * @throws Exception Any exception thrown while opening, invoking or closing the user function.
 */
@SuppressWarnings("unchecked")
@Override
protected List<T> executeOnCollections(List<T> inputData, RuntimeContext ctx)
		throws Exception {
	ReduceFunction<T> function = this.userFunction.getUserCodeObject();

	UnaryOperatorInformation<T, T> operatorInfo = getOperatorInfo();
	TypeInformation<T> inputType = operatorInfo.getInputType();

	int[] inputColumns = getKeyColumns(0);
	boolean grouped = inputColumns.length > 0;

	// Only the grouped case builds a key comparator, so only it needs a composite type.
	if (grouped && !(inputType instanceof CompositeType)) {
		throw new InvalidProgramException("Input type of grouped reduce operation must be" +
				" composite type.");
	}

	FunctionUtils.setFunctionRuntimeContext(function, ctx);
	FunctionUtils.openFunction(function, this.parameters);

	List<T> result;
	if (grouped) {
		// One ordering flag per key column, all left at the default value (false).
		boolean[] inputOrderings = new boolean[inputColumns.length];
		TypeComparator<T> inputComparator =
				((CompositeType<T>) inputType).createComparator(inputColumns, inputOrderings);

		// Maps each key (an element wrapped together with the key comparator) to the
		// aggregate computed so far for that key.
		Map<TypeComparable<T>, T> aggregateMap =
				new HashMap<TypeComparable<T>, T>(inputData.size() / 10);

		for (T next : inputData) {
			TypeComparable<T> wrapper = new TypeComparable<T>(next, inputComparator);
			T existing = aggregateMap.get(wrapper);
			// The first element of a key is stored as is; later ones are folded in.
			T aggregate = (existing != null) ? function.reduce(existing, next) : next;
			aggregateMap.put(wrapper, aggregate);
		}

		result = new ArrayList<T>(aggregateMap.values());
	} else {
		result = new ArrayList<T>(1);
		// Guard against empty input: there is no first element to start the fold from.
		if (!inputData.isEmpty()) {
			T aggregate = inputData.get(0);
			for (int i = 1; i < inputData.size(); i++) {
				aggregate = function.reduce(aggregate, inputData.get(i));
			}
			result.add(aggregate);
		}
	}

	// Close on every path; the all-reduce path previously re-set the runtime context
	// here instead and never closed the function.
	FunctionUtils.closeFunction(function);
	return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.util;

import org.apache.flink.api.common.typeutils.TypeComparator;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * The ListKeyGroupedIterator walks over a list whose elements are arranged such that elements
 * with the same key are adjacent. It returns one group of values per key: {@link #nextKey()}
 * advances to the next group and {@link #getValues()} iterates over the values of that group.
 *
 * Keys are compared with the given {@link TypeComparator}, using its reference mechanism
 * ({@code setReference} / {@code equalToReference}). A <code>null</code> element in the list is
 * treated like the end of the input.
 */
public final class ListKeyGroupedIterator<E> {

	private final List<E> input;

	private final TypeComparator<E> comparator;

	private ValuesIterator valuesIterator;

	/** Index of the next list element that has not been read yet. */
	private int currentPosition = 0;

	/** First element of the following key group, if the values iterator already ran into it. */
	private E lookahead = null;

	private boolean done;

	/**
	 * Initializes the ListKeyGroupedIterator.
	 *
	 * @param input The list with the input elements.
	 * @param comparator The comparator for the data type iterated over.
	 * @throws NullPointerException If the input or the comparator is <code>null</code>.
	 */
	public ListKeyGroupedIterator(List<E> input, TypeComparator<E> comparator)
	{
		if (input == null || comparator == null) {
			throw new NullPointerException();
		}

		this.input = input;
		this.comparator = comparator;
	}

	/**
	 * Moves the iterator to the next key. This method may skip any values that have not yet been returned by the
	 * iterator created by the {@link #getValues()} method. Hence, if called multiple times it "removes" key groups.
	 *
	 * @return true, if the input has another group of records with the same key, false if the input is
	 *         empty or exhausted.
	 */
	public boolean nextKey() throws IOException {

		if (lookahead != null) {
			// common case: whole value-iterator was consumed and a new key group is available.
			this.comparator.setReference(this.lookahead);
			this.valuesIterator.next = this.lookahead;
			this.lookahead = null;
			return true;
		}

		// first element, empty/done, or the values iterator was not entirely consumed
		if (this.done) {
			return false;
		}

		if (this.valuesIterator != null) {
			// values was not entirely consumed. move to the next key
			// Required if user code / reduce() method did not read the whole value iterator.
			// The bounds check keeps the skip loop from reading past the end of the list when
			// the unconsumed group is the last one.
			while (currentPosition < input.size()) {
				E next = this.input.get(currentPosition++);
				if (next == null) {
					// a null element counts as the end of the input
					break;
				}
				if (!this.comparator.equalToReference(next)) {
					// the keys do not match, so we have a new group. store the current key
					this.comparator.setReference(next);
					this.valuesIterator.next = next;
					return true;
				}
			}

			// input exhausted
			this.valuesIterator.next = null;
			this.valuesIterator = null;
			this.done = true;
			return false;
		}
		else {
			// first element; an empty list has none, so check the bounds before reading
			if (currentPosition < input.size()) {
				E first = input.get(currentPosition++);
				if (first != null) {
					this.comparator.setReference(first);
					this.valuesIterator = new ValuesIterator(first);
					return true;
				}
			}

			// empty input
			this.done = true;
			return false;
		}
	}

	/**
	 * Reads the next list element on behalf of the values iterator.
	 *
	 * @return The element, if it has the same key as the current reference; <code>null</code> if the
	 *         element starts a new key group (it is then kept as lookahead) or if the list is consumed.
	 */
	private E advanceToNext() {
		if (currentPosition < input.size()) {
			E next = input.get(currentPosition++);
			if (comparator.equalToReference(next)) {
				// same key
				return next;
			} else {
				// moved to the next key, no more values here
				lookahead = next;
				return null;
			}
		}
		else {
			// backing list is consumed
			this.done = true;
			return null;
		}
	}

	/**
	 * Returns an iterator over all values that belong to the current key. The iterator is <code>null</code>
	 * before the first call to {@link #nextKey()}, and it is reset to <code>null</code> when
	 * {@link #nextKey()} reaches the end of the input while skipping values that were not consumed.
	 * In general, this method returns a non-null value if a previous call to {@link #nextKey()}
	 * returned <code>true</code>.
	 *
	 * @return Iterator over all values that belong to the current key.
	 */
	public ValuesIterator getValues() {
		return this.valuesIterator;
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Iterator over the values of the current key group. The same instance is reused for all
	 * key groups and also serves as its own {@link Iterable}, so it can be traversed only once
	 * per key group.
	 */
	public final class ValuesIterator implements Iterator<E>, Iterable<E> {

		/** The value returned by the next call to {@link #next()}, or null if the group is finished. */
		private E next;

		private ValuesIterator(E first) {
			this.next = first;
		}

		@Override
		public boolean hasNext() {
			return next != null;
		}

		@Override
		public E next() {
			if (this.next != null) {
				E current = this.next;
				// read ahead by one element so that hasNext() can answer without touching the list
				this.next = ListKeyGroupedIterator.this.advanceToNext();
				return current;
			} else {
				throw new NoSuchElementException();
			}
		}

		@Override
		public void remove() {
			throw new UnsupportedOperationException();
		}

		@Override
		public Iterator<E> iterator() {
			return this;
		}
	}
}
Loading

0 comments on commit fd3f5c2

Please sign in to comment.