[FLINK-3226] Translate logical joins to physical
This closes apache#1632
vasia committed Mar 18, 2016
1 parent 670ec6a commit f89d303
Showing 6 changed files with 221 additions and 36 deletions.
CodeGenerator.scala
@@ -32,9 +32,9 @@ import org.apache.flink.api.table.codegen.Indenter.toISC
import org.apache.flink.api.table.codegen.OperatorCodeGen._
import org.apache.flink.api.table.plan.TypeConverter.sqlTypeToTypeInfo
import org.apache.flink.api.table.typeinfo.RowTypeInfo

import scala.collection.JavaConversions._
import scala.collection.mutable
import org.apache.flink.api.common.functions.FlatJoinFunction

/**
* A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s.
@@ -148,16 +148,25 @@ class CodeGenerator(
if (clazz == classOf[FlatMapFunction[_,_]]) {
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
(s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
}

// MapFunction
else if (clazz == classOf[MapFunction[_,_]]) {
val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
("Object map(Object _in1)",
s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")
List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
}

// FlatJoinFunction
else if (clazz == classOf[FlatJoinFunction[_,_,_]]) {
val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1)
val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse(
throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
(s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;",
s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
}
else {
// TODO more functions
throw new CodeGenException("Unsupported Function.")
@@ -175,7 +184,7 @@ class CodeGenerator(

@Override
public ${samHeader._1} {
- ${samHeader._2}
+ ${samHeader._2.mkString("\n")}
${reuseInputUnboxingCode()}
$bodyCode
}
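To make the new FlatJoinFunction case concrete: the source emitted by generateFunction for a join over two Row inputs now has roughly the shape sketched below. This is an illustrative sketch only; the class name, input terms, and collector name are placeholders standing in for whatever CodeGenerator actually produces.

object GeneratedFlatJoinSketch {
  // Rough shape of the generated source (placeholder names throughout): the new SAM header
  // plus the two input-unboxing casts, wrapped by generateFunction's @Override method.
  val source: String =
    """
      |public class GeneratedFlatJoin0
      |    implements org.apache.flink.api.common.functions.FlatJoinFunction {
      |
      |  @Override
      |  public void join(Object _in1, Object _in2, org.apache.flink.util.Collector out) {
      |    org.apache.flink.api.table.Row in1 = (org.apache.flink.api.table.Row) _in1;
      |    org.apache.flink.api.table.Row in2 = (org.apache.flink.api.table.Row) _in2;
      |    // body code supplied by the caller, e.g. DataSetJoinRule: evaluate the join
      |    // condition, build the result Row, and collect it
      |  }
      |}
    """.stripMargin
}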
DataSetJoin.scala
@@ -27,6 +27,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.{TableConfig, Row}
import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.flink.api.table.plan.TypeConverter._
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.api.table.typeinfo.RowTypeInfo
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import org.apache.flink.api.table.plan.TypeConverter

/**
* Flink RelNode that translates into a DataSet join operator and its related operations.
@@ -42,7 +50,8 @@ class DataSetJoin(
joinKeysRight: Array[Int],
joinType: JoinType,
joinHint: JoinHint,
- func: JoinFunction[Row, Row, Row])
+ func: (TableConfig, TypeInformation[Any], TypeInformation[Any], TypeInformation[Any]) =>
+ FlatJoinFunction[Any, Any, Any])
extends BiRel(cluster, traitSet, left, right)
with DataSetRel {

@@ -71,6 +80,19 @@
override def translateToPlan(
config: TableConfig,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
- ???

val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(config)
val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(config)

val returnType = determineReturnType(
getRowType,
expectedType,
config.getNullCheck,
config.getEfficientTypeUsage)

val joinFun = func.apply(config, leftDataSet.getType, rightDataSet.getType, returnType)
leftDataSet.join(rightDataSet).where(joinKeysLeft: _*).equalTo(joinKeysRight: _*)
.`with`(joinFun).asInstanceOf[DataSet[Any]]
}

}
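The func parameter is now a factory rather than a pre-built join function, so code generation can wait until translateToPlan knows the concrete input and result types. A minimal sketch of a value with the expected shape is shown below; the object name and the trivial body are placeholders (DataSetJoinRule supplies a code-generating factory instead).

import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.TableConfig
import org.apache.flink.util.Collector

object JoinFunctionFactorySketch {
  // Same shape as DataSetJoin's new constructor parameter:
  // (config, left input type, right input type, result type) => FlatJoinFunction
  val makeJoinFun: (TableConfig, TypeInformation[Any], TypeInformation[Any], TypeInformation[Any]) =>
      FlatJoinFunction[Any, Any, Any] =
    (config, leftType, rightType, resultType) =>
      new FlatJoinFunction[Any, Any, Any] {
        // placeholder body; the real factory returns a FlatJoinRunner wrapping generated code
        override def join(first: Any, second: Any, out: Collector[Any]): Unit =
          out.collect(first)
      }
}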
DataSetJoinRule.scala
@@ -24,6 +24,16 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetJoin}
import org.apache.flink.api.table.plan.nodes.logical.{FlinkJoin, FlinkConvention}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import org.apache.flink.api.table.plan.TypeConverter._
import org.apache.flink.api.table.runtime.FlatJoinRunner
import org.apache.flink.api.table.TableConfig
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.table.codegen.CodeGenerator
import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.calcite.rel.core.JoinInfo
import org.apache.flink.api.table.TableException

class DataSetJoinRule
extends ConverterRule(
@@ -39,18 +49,82 @@ class DataSetJoinRule
val convLeft: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
val convRight: RelNode = RelOptRule.convert(join.getInput(1), DataSetConvention.INSTANCE)

- new DataSetJoin(
- rel.getCluster,
- traitSet,
- convLeft,
- convRight,
- rel.getRowType,
- join.toString,
- Array[Int](),
- Array[Int](),
- JoinType.INNER,
- null,
- null)
// get the equality keys
val joinInfo = join.analyzeCondition
val keyPairs = joinInfo.pairs

if (keyPairs.isEmpty) { // if no equality keys => not supported
throw new TableException("Joins should have at least one equality condition")
}
else { // at least one equality expression => generate a join function
val conditionType = join.getCondition.getType
val func = getJoinFunction(join, joinInfo)
val leftKeys = ArrayBuffer.empty[Int]
val rightKeys = ArrayBuffer.empty[Int]

keyPairs.foreach { pair =>
  leftKeys.add(pair.source)
  rightKeys.add(pair.target)
}

new DataSetJoin(
rel.getCluster,
traitSet,
convLeft,
convRight,
rel.getRowType,
join.toString,
leftKeys.toArray,
rightKeys.toArray,
JoinType.INNER,
null,
func)
}
}

def getJoinFunction(join: FlinkJoin, joinInfo: JoinInfo):
((TableConfig, TypeInformation[Any], TypeInformation[Any], TypeInformation[Any]) =>
FlatJoinFunction[Any, Any, Any]) = {

val func = (
config: TableConfig,
leftInputType: TypeInformation[Any],
rightInputType: TypeInformation[Any],
returnType: TypeInformation[Any]) => {

val generator = new CodeGenerator(config, leftInputType, Some(rightInputType))
val conversion = generator.generateConverterResultExpression(returnType)
var body = ""

if (joinInfo.isEqui) {
// only equality condition
body = s"""
|${conversion.code}
|${generator.collectorTerm}.collect(${conversion.resultTerm});
|""".stripMargin
}
else {
val condition = generator.generateExpression(join.getCondition)
body = s"""
|${condition.code}
|if (${condition.resultTerm}) {
| ${conversion.code}
| ${generator.collectorTerm}.collect(${conversion.resultTerm});
|}
|""".stripMargin
}
val genFunction = generator.generateFunction(
description,
classOf[FlatJoinFunction[Any, Any, Any]],
body,
returnType)

new FlatJoinRunner[Any, Any, Any](
genFunction.name,
genFunction.code,
genFunction.returnType)
}
func
}
}

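A worked example of the key extraction, based on the new JoinITCase query "b === e && a < 6 && h < b" (in1 has fields a, b, c and in2 has d, e, f, g, h; the 0-based field positions below are given for illustration only):

object JoinKeyExtractionSketch {
  // For "b === e && a < 6 && h < b": Calcite's JoinInfo yields one IntPair for the equality,
  // whose source/target indices become the where()/equalTo() keys of the DataSet join.
  val leftKeys: Array[Int]  = Array(1)  // position of b in (a, b, c)      -> IntPair.source
  val rightKeys: Array[Int] = Array(1)  // position of e in (d, e, f, g, h) -> IntPair.target
  // The remaining predicate (a < 6 && h < b) is not turned into keys; it is evaluated inside
  // the generated FlatJoinFunction body, guarded by the "if (condition)" branch above.
}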
FlatJoinRunner.scala
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.runtime

import org.apache.flink.api.common.functions.{FlatJoinFunction, RichFlatJoinFunction}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.ResultTypeQueryable
import org.apache.flink.configuration.Configuration
import org.apache.flink.util.Collector
import org.slf4j.LoggerFactory

class FlatJoinRunner[IN1, IN2, OUT](
name: String,
code: String,
@transient returnType: TypeInformation[OUT])
extends RichFlatJoinFunction[IN1, IN2, OUT]
with ResultTypeQueryable[OUT]
with FunctionCompiler[FlatJoinFunction[IN1, IN2, OUT]] {

val LOG = LoggerFactory.getLogger(this.getClass)

private var function: FlatJoinFunction[IN1, IN2, OUT] = null

override def open(parameters: Configuration): Unit = {
LOG.debug(s"Compiling FlatJoinFunction: $name \n\n Code:\n$code")
val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
LOG.debug("Instantiating FlatJoinFunction.")
function = clazz.newInstance()
}

override def join(first: IN1, second: IN2, out: Collector[OUT]): Unit =
function.join(first, second, out)

override def getProducedType: TypeInformation[OUT] = returnType
}
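A usage sketch of the runner, assuming a hand-written source string in place of generated code: FlatJoinRunner only stores the class name and source; compilation happens in open() with the user-code classloader, and every join() call is delegated to the compiled instance. The class name, source, and types below are illustrative.

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.runtime.FlatJoinRunner

object FlatJoinRunnerSketch {
  // Raw types mirror the style of the generated code (Object parameters, explicit casts).
  val source: String =
    """
      |public class ConcatJoin implements org.apache.flink.api.common.functions.FlatJoinFunction {
      |  public void join(Object first, Object second, org.apache.flink.util.Collector out) {
      |    out.collect(first.toString() + "," + second.toString());
      |  }
      |}
    """.stripMargin

  // The runner can then be handed to leftDataSet.join(rightDataSet)...`with`(runner),
  // as DataSetJoin.translateToPlan does with the generated function.
  val runner = new FlatJoinRunner[String, String, String](
    "ConcatJoin", source, BasicTypeInfo.STRING_TYPE_INFO)
}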
JoinITCase.java
@@ -20,18 +20,17 @@

import org.apache.flink.api.table.Row;
import org.apache.flink.api.table.Table;
+ import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.table.TableEnvironment;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
import org.apache.flink.test.util.MultipleProgramsTestBase;
- import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
- import scala.NotImplementedError;

import java.util.List;

@@ -43,7 +42,7 @@ public JoinITCase(TestExecutionMode mode) {
super(mode);
}

- @Test(expected = NotImplementedError.class)
+ @Test
public void testJoin() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
@@ -62,7 +61,7 @@ public void testJoin() throws Exception {
compareResultAsText(results, expected);
}

- @Test(expected = NotImplementedError.class)
+ @Test
public void testJoinWithFilter() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
@@ -81,7 +80,27 @@ public void testJoinWithFilter() throws Exception {
compareResultAsText(results, expected);
}

- @Test(expected = NotImplementedError.class)
+ @Test
public void testJoinWithJoinFilter() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();

DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.get3TupleDataSet(env);
DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env);

Table in1 = tableEnv.fromDataSet(ds1, "a, b, c");
Table in2 = tableEnv.fromDataSet(ds2, "d, e, f, g, h");

Table result = in1.join(in2).where("b === e && a < 6 && h < b").select("c, g");

DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "Hello world, how are you?,Hallo Welt wie\n" +
"I am fine.,Hallo Welt wie\n";
compareResultAsText(results, expected);
}

@Test
public void testJoinWithMultipleKeys() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
@@ -120,9 +139,7 @@ public void testJoinNonExistingKey() throws Exception {
compareResultAsText(results, expected);
}

- // Calcite does not eagerly check the compatibility of compared types
- @Ignore
- @Test(expected = IllegalArgumentException.class)
+ @Test(expected = InvalidProgramException.class)
public void testJoinWithNonMatchingKeyTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
@@ -162,7 +179,7 @@ public void testJoinWithAmbiguousFields() throws Exception {
compareResultAsText(results, expected);
}

- @Test(expected = NotImplementedError.class)
+ @Test
public void testJoinWithAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
(Diff for the remaining changed file was not loaded on this page.)
