Skip to content

Commit

Permalink
[FLINK-16589][table-planner-blink] Split code for AggsHandlerCodeGene…
Browse files Browse the repository at this point in the history
…rator

This closes 11512
  • Loading branch information
libenchao committed Jun 18, 2020
1 parent b59ba8b commit 99fca58
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package org.apache.flink.table.planner.codegen

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.functions.{Function, RuntimeContext}
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.table.api.TableConfig
Expand Down Expand Up @@ -109,9 +108,9 @@ class CodeGeneratorContext(val tableConfig: TableConfig) {
private var currentMethodNameForLocalVariables = "DEFAULT"

/**
* Flag that indicates whether the generated code is split into several methods.
* Flag map that indicates whether the generated code for method is split into several methods.
*/
private var isCodeSplit = false
private val isCodeSplitMap = mutable.Map[String, Boolean]()

// map of local variable statements. It will be placed in method if method code not excess
// max code length, otherwise will be placed in member area of the class. The statements
Expand Down Expand Up @@ -149,11 +148,12 @@ class CodeGeneratorContext(val tableConfig: TableConfig) {
}

/**
* Set the flag [[isCodeSplit]] to be true, which indicates the generated code is split into
* several methods.
* Set the flag [[isCodeSplitMap]] to be true for methodName, which indicates
* the generated code is split into several methods.
* @param methodName the method which will be split.
*/
def setCodeSplit(): Unit = {
isCodeSplit = true
def setCodeSplit(methodName: String = currentMethodNameForLocalVariables): Unit = {
isCodeSplitMap(methodName) = true
}

/**
Expand Down Expand Up @@ -210,10 +210,14 @@ class CodeGeneratorContext(val tableConfig: TableConfig) {
*/
def reuseMemberCode(): String = {
val result = reusableMemberStatements.mkString("\n")
if (isCodeSplit) {
if (isCodeSplitMap.nonEmpty) {
val localVariableAsMember = reusableLocalVariableStatements.map(
statements => statements._2.map("private " + _).mkString("\n")
).mkString("\n")
statements => if (isCodeSplitMap.getOrElse(statements._1, false)) {
statements._2.map("private " + _).mkString("\n")
} else {
""
}
).filter(_.length > 0).mkString("\n")
result + "\n" + localVariableAsMember
} else {
result
Expand All @@ -224,8 +228,8 @@ class CodeGeneratorContext(val tableConfig: TableConfig) {
* @return code block of statements that will be placed in the member area of the class
* if generated code is split or in local variables of method
*/
def reuseLocalVariableCode(methodName: String = null): String = {
if (isCodeSplit) {
def reuseLocalVariableCode(methodName: String = currentMethodNameForLocalVariables): String = {
if (isCodeSplitMap.getOrElse(methodName, false)) {
GeneratedExpression.NO_CODE
} else if (methodName == null) {
reusableLocalVariableStatements(currentMethodNameForLocalVariables).mkString("\n")
Expand Down Expand Up @@ -375,8 +379,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) {
clazz: Class[_],
outRecordTerm: String,
outRecordWriterTerm: Option[String] = None): Unit = {
val statement = generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm)
reusableMemberStatements.add(statement)
generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm, this)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,12 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
outRowWriter: Option[String] = Some(DEFAULT_OUT_RECORD_WRITER_TERM),
reusedOutRow: Boolean = true,
outRowAlreadyExists: Boolean = false,
allowSplit: Boolean = false): GeneratedExpression = {
allowSplit: Boolean = false,
methodName: String = null): GeneratedExpression = {
val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap
generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType,
returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists, allowSplit)
returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists,
allowSplit, methodName)
}

/**
Expand All @@ -257,7 +259,8 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
outRowWriter: Option[String],
reusedOutRow: Boolean,
outRowAlreadyExists: Boolean,
allowSplit: Boolean)
allowSplit: Boolean,
methodName: String)
: GeneratedExpression = {
// initial type check
if (returnType.getFieldCount != fieldExprs.length) {
Expand Down Expand Up @@ -298,7 +301,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
val maxCodeLength = ctx.tableConfig.getMaxGeneratedCodeLength
val setFieldsCode = if (allowSplit && totalLen > maxCodeLength) {
// do the split.
ctx.setCodeSplit()
if (methodName != null) {
ctx.setCodeSplit(methodName)
} else {
ctx.setCodeSplit()
}
setFieldsCodes.map(project => {
val methodName = newName("split")
val method =
Expand All @@ -315,9 +322,8 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
}

val outRowInitCode = if (!outRowAlreadyExists) {
val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter)
val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter, ctx)
if (reusedOutRow) {
ctx.addReusableMember(initCode)
NO_CODE
} else {
initCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,21 +208,23 @@ object GenerateUtils {
// --------------------------- General Generate Utils ----------------------------------

/**
* Generates a record declaration statement. The record can be any type of RowData or
* other types.
* Generates a record declaration statement, and add it to reusable member. The record
* can be any type of RowData or other types.
*
* @param t the record type
* @param clazz the specified class of the type (only used when RowType)
* @param recordTerm the record term to be declared
* @param recordWriterTerm the record writer term (only used when BinaryRowData type)
* @return the record declaration statement
* @param ctx the code generator context
* @return the record initialization statement
*/
@tailrec
def generateRecordStatement(
t: LogicalType,
clazz: Class[_],
recordTerm: String,
recordWriterTerm: Option[String] = None)
recordWriterTerm: Option[String] = None,
ctx: CodeGeneratorContext)
: String = t.getTypeRoot match {
// ordered by type root definition
case ROW | STRUCTURED_TYPE if clazz == classOf[BinaryRowData] =>
Expand All @@ -231,26 +233,33 @@ object GenerateUtils {
)
val binaryRowWriter = className[BinaryRowWriter]
val typeTerm = clazz.getCanonicalName
ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});")
ctx.addReusableMember(
s"$binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);")
s"""
|final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});
|final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);
|$recordTerm = new $typeTerm(${getFieldCount(t)});
|$writerTerm = new $binaryRowWriter($recordTerm);
|""".stripMargin.trim
case ROW | STRUCTURED_TYPE if clazz == classOf[GenericRowData] ||
clazz == classOf[BoxedWrapperRowData] =>
val typeTerm = clazz.getCanonicalName
s"final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});"
ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});")
s"$recordTerm = new $typeTerm(${getFieldCount(t)});"
case ROW | STRUCTURED_TYPE if clazz == classOf[JoinedRowData] =>
val typeTerm = clazz.getCanonicalName
s"final $typeTerm $recordTerm = new $typeTerm();"
ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm();")
s"$recordTerm = new $typeTerm();"
case DISTINCT_TYPE =>
generateRecordStatement(
t.asInstanceOf[DistinctType].getSourceType,
clazz,
recordTerm,
recordWriterTerm)
recordWriterTerm,
ctx)
case _ =>
val typeTerm = boxedTypeTermForType(t)
s"final $typeTerm $recordTerm = new $typeTerm();"
ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm();")
s"$recordTerm = new $typeTerm();"
}

def generateNullLiteral(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ object ProjectionCodeGenerator {

val outRowInitCode = {
val initCode = generateRecordStatement(
outType, outClass, outRecordTerm, Some(outRecordWriterTerm))
outType, outClass, outRecordTerm, Some(outRecordWriterTerm), ctx)
if (reusedOutRecord) {
ctx.addReusableMember(initCode)
NO_CODE
} else {
initCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ class AggsHandlerCodeGenerator(
public final class $functionName
implements $NAMESPACE_AGGS_HANDLER_FUNCTION<$namespaceClassName> {

private $namespaceClassName $NAMESPACE_TERM;
${ctx.reuseMemberCode()}

public $functionName(Object[] references) throws Exception {
Expand All @@ -608,14 +609,14 @@ class AggsHandlerCodeGenerator(

@Override
public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$mergeCode
}

@Override
public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM)
throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$setAccumulatorsCode
}

Expand All @@ -631,13 +632,13 @@ class AggsHandlerCodeGenerator(

@Override
public $ROW_DATA getValue(Object ns) throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$getValueCode
}

@Override
public void cleanup(Object ns) throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
${ctx.reuseCleanupCode()}
}

Expand Down Expand Up @@ -684,6 +685,7 @@ class AggsHandlerCodeGenerator(
public final class $functionName
implements ${className[NamespaceTableAggsHandleFunction[_]]}<$namespaceClassName> {

private $namespaceClassName $NAMESPACE_TERM;
${ctx.reuseMemberCode()}
private $CONVERT_COLLECTOR_TYPE_TERM $MEMBER_COLLECTOR_TERM;

Expand All @@ -709,14 +711,14 @@ class AggsHandlerCodeGenerator(

@Override
public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$mergeCode
}

@Override
public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM)
throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$setAccumulatorsCode
}

Expand All @@ -735,13 +737,13 @@ class AggsHandlerCodeGenerator(
$COLLECTOR<$ROW_DATA> $COLLECTOR_TERM) throws Exception {

$MEMBER_COLLECTOR_TERM.$COLLECTOR_TERM = $COLLECTOR_TERM;
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$emitValueCode
}

@Override
public void cleanup(Object ns) throws Exception {
$namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns;
$NAMESPACE_TERM = ($namespaceClassName) ns;
${ctx.reuseCleanupCode()}
}

Expand Down Expand Up @@ -806,7 +808,9 @@ class AggsHandlerCodeGenerator(
accTypeInfo,
classOf[GenericRowData],
outRow = accTerm,
reusedOutRow = false)
reusedOutRow = false,
allowSplit = true,
methodName = methodName)

s"""
|${ctx.reuseLocalVariableCode(methodName)}
Expand All @@ -829,7 +833,9 @@ class AggsHandlerCodeGenerator(
accTypeInfo,
classOf[GenericRowData],
outRow = accTerm,
reusedOutRow = false)
reusedOutRow = false,
allowSplit = true,
methodName = methodName)

s"""
|${ctx.reuseLocalVariableCode(methodName)}
Expand All @@ -845,7 +851,8 @@ class AggsHandlerCodeGenerator(
// bind input1 as accumulators
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(accTypeInfo, inputTerm = ACC_TERM)
val body = aggBufferCodeGens.map(_.setAccumulator(exprGenerator)).mkString("\n")
val body = splitExpressionsIfNecessary(
aggBufferCodeGens.map(_.setAccumulator(exprGenerator)), methodName)

s"""
|${ctx.reuseLocalVariableCode(methodName)}
Expand All @@ -859,7 +866,8 @@ class AggsHandlerCodeGenerator(
ctx.startNewLocalVariableStatement(methodName)

val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val body = aggBufferCodeGens.map(_.resetAccumulator(exprGenerator)).mkString("\n")
val body = splitExpressionsIfNecessary(aggBufferCodeGens.map(_.resetAccumulator(exprGenerator)),
methodName)

s"""
|${ctx.reuseLocalVariableCode(methodName)}
Expand All @@ -878,7 +886,8 @@ class AggsHandlerCodeGenerator(
// bind input1 as inputRow
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(inputType, inputTerm = ACCUMULATE_INPUT_TERM)
val body = aggActionCodeGens.map(_.accumulate(exprGenerator)).mkString("\n")
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.accumulate(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(ACCUMULATE_INPUT_TERM)}
Expand All @@ -901,7 +910,8 @@ class AggsHandlerCodeGenerator(
// bind input1 as inputRow
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(inputType, inputTerm = RETRACT_INPUT_TERM)
val body = aggActionCodeGens.map(_.retract(exprGenerator)).mkString("\n")
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.retract(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(RETRACT_INPUT_TERM)}
Expand Down Expand Up @@ -935,7 +945,8 @@ class AggsHandlerCodeGenerator(
// bind input1 as otherAcc
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(mergedAccType, inputTerm = MERGED_ACC_TERM)
val body = aggActionCodeGens.map(_.merge(exprGenerator)).mkString("\n")
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.merge(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(MERGED_ACC_TERM)}
Expand All @@ -947,6 +958,27 @@ class AggsHandlerCodeGenerator(
}
}

private def splitExpressionsIfNecessary(exprs: Array[String], methodName: String): String = {
val totalLen = exprs.map(_.length).sum
val maxCodeLength = ctx.tableConfig.getMaxGeneratedCodeLength
if (totalLen > maxCodeLength) {
ctx.setCodeSplit(methodName)
exprs.map(expr => {
val splitMethodName = newName("split_" + methodName)
val method =
s"""
|private void $splitMethodName() throws Exception {
| $expr
|}
|""".stripMargin
ctx.addReusableMember(method)
s"$splitMethodName();"
}).mkString("\n")
} else {
exprs.mkString("\n")
}
}

private def getWindowExpressions(
windowProperties: Seq[PlannerWindowProperty]): Seq[GeneratedExpression] = {
windowProperties.map {
Expand Down Expand Up @@ -1006,7 +1038,9 @@ class AggsHandlerCodeGenerator(
valueType,
classOf[GenericRowData],
outRow = aggValueTerm,
reusedOutRow = false)
reusedOutRow = false,
allowSplit = true,
methodName = methodName)

s"""
|${ctx.reuseLocalVariableCode(methodName)}
Expand Down
Loading

0 comments on commit 99fca58

Please sign in to comment.