[FLINK-12140][table-planner-blink] Support e2e sort merge join operator in batch mode (apache#8127)
JingsongLi authored and KurtYoung committed Apr 12, 2019
1 parent a2f59a1 commit c2e3bdc
Showing 25 changed files with 2,164 additions and 99 deletions.
FlinkPlannerImpl.scala
@@ -221,16 +221,4 @@ object FlinkPlannerImpl {
* The default field collation direction if not specified; consistent with Calcite.
*/
val defaultCollationDirection: RelFieldCollation.Direction = RelFieldCollation.Direction.ASCENDING

-  /** Returns the default null direction if not specified. */
-  def getNullDefaultOrders(ascendings: Array[Boolean]): Array[Boolean] = {
-    ascendings.map { asc =>
-      FlinkPlannerImpl.defaultNullCollation.last(!asc)
-    }
-  }
-
-  /** Returns the default null direction if not specified. */
-  def getNullDefaultOrder(ascending: Boolean): Boolean = {
-    FlinkPlannerImpl.defaultNullCollation.last(!ascending)
-  }
}
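These two helpers move to SortUtil in this commit. A minimal sketch of the behavior they preserve, assuming the planner's default null collation is NullCollation.LOW (nulls compare as the smallest value):

  import org.apache.flink.table.plan.util.SortUtil

  // Under NullCollation.LOW, last(!asc) places nulls first on ascending
  // keys and last on descending keys.
  val nullsIsLast: Array[Boolean] = SortUtil.getNullDefaultOrders(Array(true, false))
  // nullsIsLast == Array(false, true)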
FunctionCodeGenerator.scala
@@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.functions.async.{AsyncFunction, RichAsyncF
import org.apache.flink.table.`type`.InternalType
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.Indenter.toISC
-import org.apache.flink.table.generated.GeneratedFunction
+import org.apache.flink.table.generated.{GeneratedFunction, GeneratedJoinCondition, JoinCondition}

/**
* A code generator for generating Flink [[org.apache.flink.api.common.functions.Function]]s.
@@ -159,4 +159,50 @@ object FunctionCodeGenerator {

new GeneratedFunction(funcName, funcCode, ctx.references.toArray)
}

/**
 * Generates a [[JoinCondition]] that can be passed to the Java compiler.
 *
 * @param ctx The context of the code generator.
 * @param name Class name of the Function. Need not be unique, but must be a valid Java class
 *             identifier.
 * @param bodyCode Code contents of the SAM (Single Abstract Method).
 * @param input1Term The first input term.
 * @param input2Term The second input term.
 * @return An instance of GeneratedJoinCondition.
 */
def generateJoinCondition(
ctx: CodeGeneratorContext,
name: String,
bodyCode: String,
input1Term: String = CodeGenUtils.DEFAULT_INPUT1_TERM,
input2Term: String = CodeGenUtils.DEFAULT_INPUT2_TERM): GeneratedJoinCondition = {
val funcName = newName(name)

val baseClass = classOf[JoinCondition]

val funcCode =
j"""
public class $funcName implements ${baseClass.getCanonicalName} {

${ctx.reuseMemberCode()}

public $funcName(Object[] references) throws Exception {
${ctx.reuseInitCode()}
}

${ctx.reuseConstructorCode(funcName)}

@Override
public boolean apply($BASE_ROW $input1Term, $BASE_ROW $input2Term) throws Exception {
${ctx.reusePerRecordCode()}
${ctx.reuseLocalVariableCode()}
${ctx.reuseInputUnboxingCode()}
$bodyCode
}
}
""".stripMargin

new GeneratedJoinCondition(funcName, funcCode, ctx.references.toArray)
}
}
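A hedged usage sketch of the new method (not part of this diff); the config, class loader, and row arguments are assumptions, and GeneratedJoinCondition is assumed to follow the usual GeneratedClass contract of compiling via newInstance:

  // Generate a condition with a hypothetical body that accepts every row pair.
  val generated: GeneratedJoinCondition = FunctionCodeGenerator.generateJoinCondition(
    CodeGeneratorContext(config),  // `config` is an assumed TableConfig in scope
    "MyJoinCondition",
    "return true;")

  // Compile the generated source and evaluate it on two BaseRows.
  val condition: JoinCondition = generated.newInstance(getClass.getClassLoader)
  val matched: Boolean = condition.apply(left, right)  // `left`/`right`: assumed BaseRows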
GenerateUtils.scala
@@ -21,12 +21,13 @@ package org.apache.flink.table.codegen
import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.{AtomicType => AtomicTypeInfo}
import org.apache.flink.table.`type`._
-import org.apache.flink.table.calcite.FlinkPlannerImpl
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.calls.CurrentTimePointCallGen
import org.apache.flink.table.dataformat._
+import org.apache.flink.table.plan.util.SortUtil
import org.apache.flink.table.typeutils.TypeCheckUtils.{isReference, isTemporal}

import org.apache.calcite.avatica.util.ByteString
import org.apache.commons.lang3.StringEscapeUtils

@@ -653,7 +654,7 @@ object GenerateUtils {
val compareFunc = newName("compareArray")
val compareCode = generateArrayCompare(
ctx,
-      FlinkPlannerImpl.getNullDefaultOrder(true), at, "a", "b")
+      SortUtil.getNullDefaultOrder(true), at, "a", "b")
val funcCode: String =
s"""
public int $compareFunc($BINARY_ARRAY a, $BINARY_ARRAY b) {
@@ -670,7 +671,7 @@
rowType.getFieldTypes.indices.toArray,
rowType.getFieldTypes,
orders,
-      FlinkPlannerImpl.getNullDefaultOrders(orders),
+      SortUtil.getNullDefaultOrders(orders),
"a",
"b")
val compareFunc = newName("compareRow")
HashAggCodeGenHelper.scala
@@ -22,7 +22,6 @@ import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.metrics.Gauge
import org.apache.flink.table.`type`.{InternalType, RowType}
import org.apache.flink.table.api.TableConfig
-import org.apache.flink.table.calcite.FlinkPlannerImpl
import org.apache.flink.table.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull}
import org.apache.flink.table.codegen.agg.batch.AggCodeGenHelper.buildAggregateArgsMapping
import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext, ExprCodeGenerator, GenerateUtils, GeneratedExpression, OperatorCodeGenerator, SortCodeGenerator}
@@ -31,6 +30,7 @@ import org.apache.flink.table.expressions.{CallExpression, Expression, Expressio
import org.apache.flink.table.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction}
import org.apache.flink.table.generated.{NormalizedKeyComputer, RecordComparator}
+import org.apache.flink.table.plan.util.SortUtil
import org.apache.flink.table.runtime.aggregate.{BytesHashMap, BytesHashMapSpillMemorySegmentPool}
import org.apache.flink.table.runtime.sort.BufferedKVExternalSorter
import org.apache.flink.table.typeutils.BinaryRowSerializer
@@ -835,7 +835,7 @@ object HashAggCodeGenHelper {
val keyFieldTypes = aggMapKeyType.getFieldTypes
val keys = keyFieldTypes.indices.toArray
val orders = keys.map((_) => true)
-    val nullsIsLast = FlinkPlannerImpl.getNullDefaultOrders(orders)
+    val nullsIsLast = SortUtil.getNullDefaultOrders(orders)

val sortCodeGenerator = new SortCodeGenerator(
ctx.tableConfig, keys, keyFieldTypes, orders, nullsIsLast)
BatchExecExchange.scala
@@ -18,11 +18,26 @@

package org.apache.flink.table.plan.nodes.physical.batch

import org.apache.flink.runtime.operators.DamBehavior
import org.apache.flink.streaming.api.transformations.{PartitionTransformation, StreamTransformation}
import org.apache.flink.streaming.runtime.partitioner.{BroadcastPartitioner, GlobalPartitioner, RebalancePartitioner}
import org.apache.flink.table.`type`.RowType
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.{CodeGeneratorContext, HashCodeGenerator}
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.plan.nodes.common.CommonPhysicalExchange
import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode}
import org.apache.flink.table.runtime.BinaryHashPartitioner
import org.apache.flink.table.typeutils.BaseRowTypeInfo

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.{RelDistribution, RelNode}

import java.util

import scala.collection.JavaConversions._

/**
* This RelNode represents a change of partitioning of the input elements.
*
@@ -77,7 +92,14 @@ class BatchExecExchange(
inputRel: RelNode,
relDistribution: RelDistribution)
extends CommonPhysicalExchange(cluster, traitSet, inputRel, relDistribution)
-  with BatchPhysicalRel {
+  with BatchPhysicalRel
+  with BatchExecNode[BaseRow] {

// TODO reuse PartitionTransformation
// Currently, an Exchange's input transformation is reused if it is reusable,
// but a distinct PartitionTransformation object is still created for each reuse,
// all sharing the same input. Cache the input transformation so it can be reused
// (see the sketch after this class).
private var reusedInput: Option[StreamTransformation[BaseRow]] = None

override def copy(
traitSet: RelTraitSet,
@@ -86,5 +108,77 @@
new BatchExecExchange(cluster, traitSet, newInput, relDistribution)
}

override def getDamBehavior: DamBehavior = {
distribution.getType match {
case RelDistribution.Type.RANGE_DISTRIBUTED => DamBehavior.FULL_DAM
case _ => DamBehavior.PIPELINED
}
}

override def getInputNodes: util.List[ExecNode[BatchTableEnvironment, _]] =
getInputs.map(_.asInstanceOf[ExecNode[BatchTableEnvironment, _]])

override def translateToPlanInternal(
tableEnv: BatchTableEnvironment): StreamTransformation[BaseRow] = {
val input = reusedInput match {
case Some(transformation) => transformation
case None =>
val input = getInputNodes.get(0).translateToPlan(tableEnv)
.asInstanceOf[StreamTransformation[BaseRow]]
reusedInput = Some(input)
input
}

val inputType = input.getOutputType.asInstanceOf[BaseRowTypeInfo]
val outputRowType = FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo

relDistribution.getType match {
case RelDistribution.Type.ANY =>
val transformation = new PartitionTransformation(
input,
null)
transformation.setOutputType(outputRowType)
transformation

case RelDistribution.Type.SINGLETON =>
val transformation = new PartitionTransformation(
input,
new GlobalPartitioner[BaseRow])
transformation.setOutputType(outputRowType)
transformation

case RelDistribution.Type.RANDOM_DISTRIBUTED =>
val transformation = new PartitionTransformation(
input,
new RebalancePartitioner[BaseRow])
transformation.setOutputType(outputRowType)
transformation

case RelDistribution.Type.BROADCAST_DISTRIBUTED =>
val transformation = new PartitionTransformation(
input,
new BroadcastPartitioner[BaseRow])
transformation.setOutputType(outputRowType)
transformation

case RelDistribution.Type.HASH_DISTRIBUTED =>
// TODO Eliminate duplicate keys
val keys = relDistribution.getKeys
val partitioner = new BinaryHashPartitioner(
HashCodeGenerator.generateRowHash(
CodeGeneratorContext(tableEnv.config),
new RowType(inputType.getInternalTypes: _*),
"HashPartitioner",
keys.map(_.intValue()).toArray))
val transformation = new PartitionTransformation(
input,
partitioner)
transformation.setOutputType(outputRowType)
transformation
case _ =>
  throw new UnsupportedOperationException(
    s"Unsupported RelDistribution type: ${relDistribution.getType}.")
}
}
}
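To illustrate the caching noted in the TODO above, a hedged sketch (not from the diff) of how repeated translation would behave; `exchange` and `tableEnv` are assumed to be in scope:

  // Each call builds a fresh PartitionTransformation, but after the first call
  // the input translation is served from the reusedInput cache.
  val first = exchange.translateToPlanInternal(tableEnv)
    .asInstanceOf[PartitionTransformation[BaseRow]]
  val second = exchange.translateToPlanInternal(tableEnv)
    .asInstanceOf[PartitionTransformation[BaseRow]]
  assert(!(first eq second))                 // distinct partition transformations
  assert(first.getInput eq second.getInput)  // one shared input transformation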

BatchExecJoinBase.scala
@@ -17,6 +17,10 @@
*/
package org.apache.flink.table.plan.nodes.physical.batch

import org.apache.flink.table.`type`.RowType
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, FunctionCodeGenerator}
import org.apache.flink.table.generated.GeneratedJoinCondition
import org.apache.flink.table.plan.nodes.common.CommonPhysicalJoin

import org.apache.calcite.rel.core.Join
@@ -26,4 +30,30 @@ import org.apache.calcite.rel.core.Join
*/
trait BatchExecJoinBase extends CommonPhysicalJoin with BatchPhysicalRel {

private[flink] def generateCondition(
config: TableConfig,
leftType: RowType,
rightType: RowType): GeneratedJoinCondition = {
val ctx = CodeGeneratorContext(config)
val exprGenerator = new ExprCodeGenerator(ctx, false)
.bindInput(leftType)
.bindSecondInput(rightType)

val body = if (joinInfo.isEqui) {
// the join condition is purely equi, so no remaining predicate needs checking
"return true;"
} else {
val nonEquiPredicates = joinInfo.getRemaining(getCluster.getRexBuilder)
val condition = exprGenerator.generateExpression(nonEquiPredicates)
s"""
|${condition.code}
|return ${condition.resultTerm};
|""".stripMargin
}

FunctionCodeGenerator.generateJoinCondition(
ctx,
"JoinConditionFunction",
body)
}
}
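A hedged sketch of how a concrete batch join node might consume generateCondition when wiring up the sort-merge join this commit targets; the surrounding names (tableEnv, leftRowType, rightRowType) and the operator wiring are assumptions, not guaranteed by this diff:

  // Inside a join node's translation (hypothetical context):
  val condFunc: GeneratedJoinCondition =
    generateCondition(tableEnv.config, leftRowType, rightRowType)
  // The generated source ships with the operator and is compiled on the task
  // side; the operator then evaluates
  //   condFunc.newInstance(classLoader).apply(leftRow, rightRow)
  // for each candidate pair whose equi-keys already matched.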