
Commit

[FLINK-1110] Started implementing the JoinOperatorBase.
Implemented JoinOperatorBase and test cases.
tillrohrmann authored and StephanEwen committed Oct 3, 2014
1 parent 77ac6c0 commit 471f340
Showing 9 changed files with 355 additions and 4 deletions.
@@ -19,7 +19,10 @@

package org.apache.flink.api.common.operators.base;

import java.util.List;

import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.Ordering;
@@ -152,4 +155,10 @@ public boolean isCombinableSecond() {
public void setCombinableSecond(boolean combinableSecond) {
this.combinableSecond = combinableSecond;
}

@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception {
// TODO Auto-generated method stub
return null;
}
}
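
Note: the executeOnCollections override added above (to what appears to be the coGroup operator base, judging by the CoGroupFunction import and the combinable flags) is still a placeholder in this commit and returns null. Purely as a hypothetical illustration, and not part of this commit, a collection-based coGroup could group both inputs by key and invoke the user function once per distinct key. The sketch below uses plain java.util.function key extractors and a stand-in CoGroup interface instead of Flink's CoGroupFunction and comparator machinery, and uses Java 8 idioms for brevity.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

public final class CoGroupOnCollectionsSketch {

    /** Stand-in for a coGroup UDF; this is not Flink's CoGroupFunction. */
    public interface CoGroup<A, B, O> {
        void coGroup(List<A> first, List<B> second, List<O> out);
    }

    /** Groups both inputs by key and calls the UDF once per distinct key. */
    public static <A, B, K, O> List<O> coGroupOnCollections(
            List<A> input1, Function<A, K> key1,
            List<B> input2, Function<B, K> key2,
            CoGroup<A, B, O> udf) {

        Map<K, List<A>> groups1 = new HashMap<>();
        for (A a : input1) {
            groups1.computeIfAbsent(key1.apply(a), k -> new ArrayList<>()).add(a);
        }
        Map<K, List<B>> groups2 = new HashMap<>();
        for (B b : input2) {
            groups2.computeIfAbsent(key2.apply(b), k -> new ArrayList<>()).add(b);
        }

        // Union of the key sets, so keys present on only one side still trigger a call.
        Set<K> allKeys = new LinkedHashSet<>(groups1.keySet());
        allKeys.addAll(groups2.keySet());

        List<O> result = new ArrayList<>();
        for (K key : allKeys) {
            List<A> first = groups1.getOrDefault(key, new ArrayList<>());
            List<B> second = groups2.getOrDefault(key, new ArrayList<>());
            udf.coGroup(first, second, result);
        }
        return result;
    }
}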
@@ -20,10 +20,12 @@
package org.apache.flink.api.common.operators.base;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.flink.api.common.aggregators.AggregatorRegistry;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.IterationOperator;
@@ -329,4 +331,10 @@ public UserCodeWrapper<?> getUserCodeWrapper() {
return null;
}
}

@Override
protected List<ST> executeOnCollections(List<ST> inputData1, List<WT> inputData2, RuntimeContext runtimeContext) throws Exception {
// TODO Auto-generated method stub
return null;
}
}
@@ -20,11 +20,26 @@
package org.apache.flink.api.common.operators.base;

import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* @see org.apache.flink.api.common.functions.FlatJoinFunction
@@ -34,12 +49,91 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
public JoinOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(udf, operatorInfo, keyPositions1, keyPositions2, name);
}

public JoinOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name);
}

public JoinOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name);
}

@SuppressWarnings("unchecked")
@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception {
FlatJoinFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject();

FunctionUtils.setFunctionRuntimeContext(function, runtimeContext);
FunctionUtils.openFunction(function, this.parameters);

TypeInformation<IN1> leftInformation = getOperatorInfo().getFirstInputType();
TypeInformation<IN2> rightInformation = getOperatorInfo().getSecondInputType();

TypeComparator<IN1> leftComparator;
TypeComparator<IN2> rightComparator;

if(leftInformation instanceof AtomicType){
leftComparator = ((AtomicType<IN1>) leftInformation).createComparator(true);
}else if(leftInformation instanceof CompositeType){
int[] keyPositions = getKeyColumns(0);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);

leftComparator = ((CompositeType<IN1>) leftInformation).createComparator(keyPositions, orders);
}else{
throw new RuntimeException("Type information for left input of type " + leftInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}

if(rightInformation instanceof AtomicType){
rightComparator = ((AtomicType<IN2>) rightInformation).createComparator(true);
}else if(rightInformation instanceof CompositeType){
int[] keyPositions = getKeyColumns(1);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);

rightComparator = ((CompositeType<IN2>) rightInformation).createComparator(keyPositions, orders);
}else{
throw new RuntimeException("Type information for right input of type " + rightInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}

TypePairComparator<IN1, IN2> pairComparator = new GenericPairComparator<IN1, IN2>(leftComparator,
rightComparator);

List<OUT> result = new ArrayList<OUT>();
ListCollector<OUT> collector = new ListCollector<OUT>(result);

Map<Integer, List<IN2>> probeTable = new HashMap<Integer, List<IN2>>();

//Build probe table
for(IN2 element: inputData2){
List<IN2> list = probeTable.get(rightComparator.hash(element));
if(list == null){
list = new ArrayList<IN2>();
probeTable.put(rightComparator.hash(element), list);
}

list.add(element);
}

//Probing
for(IN1 left: inputData1){
List<IN2> matchingHashes = probeTable.get(leftComparator.hash(left));

pairComparator.setReference(left);

if(matchingHashes != null){
for(IN2 right: matchingHashes){
if(pairComparator.equalToReference(right)){
function.join(left, right, collector);
}
}
}
}

FunctionUtils.closeFunction(function);

return result;
}
}
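
Note: the executeOnCollections added above is an in-memory hash join. The second input is hashed into buckets keyed by the comparator's hash of its key fields, and each element of the first input then probes the bucket with the matching hash, emitting a joined pair only when the pair comparator confirms key equality (which filters out hash collisions). The following self-contained sketch distills that build/probe pattern; the key-extractor functions and the JoinFn interface stand in for Flink's TypeComparator and FlatJoinFunction and are not Flink API. Because the buckets here are keyed by the key value itself rather than by its hash, the explicit equalToReference check from the diff is not needed.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public final class HashJoinSketch {

    /** Stand-in for a flat-join UDF; this is not Flink's FlatJoinFunction. */
    public interface JoinFn<A, B, O> {
        void join(A left, B right, List<O> out);
    }

    /** Builds a hash table on the second input, then probes it with the first. */
    public static <A, B, K, O> List<O> hashJoin(
            List<A> input1, Function<A, K> key1,
            List<B> input2, Function<B, K> key2,
            JoinFn<A, B, O> udf) {

        // Build phase: bucket the second input by key.
        Map<K, List<B>> buildTable = new HashMap<>();
        for (B b : input2) {
            buildTable.computeIfAbsent(key2.apply(b), k -> new ArrayList<>()).add(b);
        }

        // Probe phase: join every left element with all right elements
        // in the bucket for its key.
        List<O> result = new ArrayList<>();
        for (A a : input1) {
            List<B> candidates = buildTable.get(key1.apply(a));
            if (candidates != null) {
                for (B b : candidates) {
                    udf.join(a, b, result);
                }
            }
        }
        return result;
    }
}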
@@ -16,7 +16,7 @@
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils.runtime;
package org.apache.flink.api.common.typeutils;

import java.io.Serializable;

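
Note: the hunk above changes only the package declaration, moving GenericPairComparator from org.apache.flink.api.java.typeutils.runtime to org.apache.flink.api.common.typeutils, which makes it reachable from the common operator layer where JoinOperatorBase now constructs it from the two input comparators. A minimal usage sketch, mirroring the calls that appear in the JoinOperatorBase hunk and assuming the signatures used elsewhere in this diff (this commit's Flink version); class and variable names are illustrative:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;

public class GenericPairComparatorExample {
    public static void main(String[] args) {
        // Comparators for the two (atomic) input types, created the same way
        // as in JoinOperatorBase.executeOnCollections above.
        TypeComparator<String> leftComparator =
                BasicTypeInfo.STRING_TYPE_INFO.createComparator(true);
        TypeComparator<String> rightComparator =
                BasicTypeInfo.STRING_TYPE_INFO.createComparator(true);

        TypePairComparator<String, String> pairComparator =
                new GenericPairComparator<String, String>(leftComparator, rightComparator);

        // Same calls as in the probe loop above: fix a left-side reference,
        // then test right-side candidates for key equality.
        pairComparator.setReference("foobar");
        System.out.println(pairComparator.equalToReference("foobar")); // expected: true
        System.out.println(pairComparator.equalToReference("foo"));    // expected: false
    }
}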
@@ -0,0 +1,122 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.operators.base;

import static org.junit.Assert.*;

import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.junit.Test;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

@SuppressWarnings("serial")
public class JoinOperatorBaseTest implements Serializable {

@Test
public void testJoinPlain(){
final FlatJoinFunction<String, String, Integer> joiner = new FlatJoinFunction<String, String, Integer>() {

@Override
public void join(String first, String second, Collector<Integer> out) throws Exception {
out.collect(first.length());
out.collect(second.length());
}
};

@SuppressWarnings({ "rawtypes", "unchecked" })
JoinOperatorBase<String, String, Integer,
FlatJoinFunction<String, String,Integer> > base = new JoinOperatorBase(joiner,
new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], "TestJoiner");

List<String> inputData1 = new ArrayList<String>(Arrays.asList("foo", "bar", "foobar"));
List<String> inputData2 = new ArrayList<String>(Arrays.asList("foobar", "foo"));
List<Integer> expected = new ArrayList<Integer>(Arrays.asList(3, 3, 6 ,6));

try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, null);

assertEquals(expected, result);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}

@Test
public void testJoinRich(){
final AtomicBoolean opened = new AtomicBoolean(false);
final AtomicBoolean closed = new AtomicBoolean(false);
final String taskName = "Test rich join function";

final RichFlatJoinFunction<String, String, Integer> joiner = new RichFlatJoinFunction<String, String, Integer>() {
@Override
public void open(Configuration parameters) throws Exception {
opened.compareAndSet(false, true);
assertEquals(0, getRuntimeContext().getIndexOfThisSubtask());
assertEquals(1, getRuntimeContext().getNumberOfParallelSubtasks());
}

@Override
public void close() throws Exception{
closed.compareAndSet(false, true);
}

@Override
public void join(String first, String second, Collector<Integer> out) throws Exception {
out.collect(first.length());
out.collect(second.length());
}
};

JoinOperatorBase<String, String, Integer,
RichFlatJoinFunction<String, String, Integer>> base = new JoinOperatorBase<String, String, Integer,
RichFlatJoinFunction<String, String, Integer>>(joiner, new BinaryOperatorInformation<String, String,
Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], taskName);

final List<String> inputData1 = new ArrayList<String>(Arrays.asList("foo", "bar", "foobar"));
final List<String> inputData2 = new ArrayList<String>(Arrays.asList("foobar", "foo"));
final List<Integer> expected = new ArrayList<Integer>(Arrays.asList(3, 3, 6, 6));


try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName,
1, 0));

assertEquals(expected, result);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}

assertTrue(opened.get());
assertTrue(closed.get());
}
}
@@ -18,6 +18,7 @@

package org.apache.flink.api.java.typeutils.runtime;

import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
