Skip to content

Commit

Permalink
[FLINK-15487][table] Update code generation for new type inference
Browse files Browse the repository at this point in the history
This updates the code generation for the new type inference and thus
completes FLINK-15487. Scalar functions now work with the types supported
by the planner. The tests added in this PR only cover basic behavior; more
tests per data type are still needed, but that is a follow-up issue.

This closes apache#10960.
  • Loading branch information
twalthr committed Feb 5, 2020
1 parent 75248b4 commit 5e6e851
Show file tree
Hide file tree
Showing 10 changed files with 525 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.flink.table.dataformat.{BinaryStringUtil, Decimal, _}
import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.runtime.dataview.StateDataViewStore
import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter
import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getInternalClassForType
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
Expand All @@ -36,11 +35,12 @@ import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical._
import org.apache.flink.types.Row

import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort}
import java.util.concurrent.atomic.AtomicInteger

import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField}

object CodeGenUtils {

// ------------------------------- DEFAULT TERMS ------------------------------------------
Expand Down Expand Up @@ -116,7 +116,28 @@ object CodeGenUtils {
/**
* Retrieve the canonical name of a class type.
*/
def className[T](implicit m: Manifest[T]): String = m.runtimeClass.getCanonicalName
def className[T](implicit m: Manifest[T]): String = {
  // getCanonicalName returns null when the class has no canonical name;
  // fail with a descriptive error instead of handing back null.
  val runtimeClass = m.runtimeClass
  Option(runtimeClass.getCanonicalName).getOrElse {
    throw new CodeGenException(
      s"Class '${runtimeClass.getName}' does not have a canonical name. " +
        s"Make sure it is statically accessible.")
  }
}

/**
 * Returns a term for representing the given class in Java code.
 *
 * @param clazz class whose canonical name is used as the term
 * @return the canonical name of the given class
 * @throws CodeGenException if the class does not have a canonical name
 */
def typeTerm(clazz: Class[_]): String =
  clazz.getCanonicalName match {
    // no canonical name available for this class
    case null =>
      throw new CodeGenException(
        s"Class '${clazz.getName}' does not have a canonical name. " +
          s"Make sure it is statically accessible.")
    case canonicalName =>
      canonicalName
  }

// when casting we first need to unbox Primitives, for example,
// float a = 1.0f;
Expand Down Expand Up @@ -167,18 +188,6 @@ object CodeGenUtils {
case RAW => className[BinaryGeneric[_]]
}

/**
* Gets the boxed type term from external type info.
* We only use TypeInformation to store external type info.
*/
def boxedTypeTermForExternalType(t: DataType): String = {
if (t.getConversionClass == null) {
ClassLogicalTypeConverter.getDefaultExternalClassForType(t.getLogicalType).getCanonicalName
} else {
t.getConversionClass.getCanonicalName
}
}

/**
* Gets the default value for a primitive type, and null for generic types
*/
Expand Down Expand Up @@ -682,50 +691,82 @@ object CodeGenUtils {
genToInternal(ctx, t)(term)

def genToInternal(ctx: CodeGeneratorContext, t: DataType): String => String = {
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
if (isConverterIdentity(t)) {
term => s"($iTerm) $term"
term => s"$term"
} else {
val eTerm = boxedTypeTermForExternalType(t)
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
val eTerm = typeTerm(t.getConversionClass)
val converter = ctx.addReusableObject(
DataFormatConverters.getConverterForDataType(t),
"converter")
term => s"($iTerm) $converter.toInternal(($eTerm) $term)"
}
}

/**
* Generates code for converting the given external source data type to the internal data format.
*
* Use this function for converting at the edges of the API.
*/
def genToInternalIfNeeded(
ctx: CodeGeneratorContext,
t: DataType,
term: String): String = {
if (isInternalClass(t)) {
s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
sourceDataType: DataType,
externalTerm: String)
: GeneratedExpression = {
val sourceType = sourceDataType.getLogicalType
val sourceClass = sourceDataType.getConversionClass
// convert external source type to internal format
val internalResultTerm = if (isInternalClass(sourceDataType)) {
s"$externalTerm"
} else {
genToInternal(ctx, t, term)
genToInternal(ctx, sourceDataType, externalTerm)
}
// extract null term from result term
if (sourceClass.isPrimitive) {
generateNonNullField(sourceType, internalResultTerm)
} else {
generateInputFieldUnboxing(ctx, sourceType, externalTerm, internalResultTerm)
}
}

def genToExternal(ctx: CodeGeneratorContext, t: DataType, term: String): String = {
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t))
if (isConverterIdentity(t)) {
s"($iTerm) $term"
def genToExternal(
ctx: CodeGeneratorContext,
targetType: DataType,
internalTerm: String): String = {
if (isConverterIdentity(targetType)) {
s"$internalTerm"
} else {
val eTerm = boxedTypeTermForExternalType(t)
val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(targetType))
val eTerm = typeTerm(targetType.getConversionClass)
val converter = ctx.addReusableObject(
DataFormatConverters.getConverterForDataType(t),
DataFormatConverters.getConverterForDataType(targetType),
"converter")
s"($eTerm) $converter.toExternal(($iTerm) $term)"
s"($eTerm) $converter.toExternal(($iTerm) $internalTerm)"
}
}

/**
* Generates code for converting the internal data format to the given external target data type.
*
* Use this function for converting at the edges of the API.
*/
def genToExternalIfNeeded(
ctx: CodeGeneratorContext,
t: DataType,
term: String): String = {
if (isInternalClass(t)) {
s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
targetDataType: DataType,
internalExpr: GeneratedExpression)
: String = {
val targetType = fromDataTypeToLogicalType(targetDataType)
// convert internal format to target type
val externalResultTerm = if (isInternalClass(targetDataType)) {
s"(${boxedTypeTermForType(targetType)}) ${internalExpr.resultTerm}"
} else {
genToExternal(ctx, targetDataType, internalExpr.resultTerm)
}
// merge null term into the result term
if (targetDataType.getConversionClass.isPrimitive) {
externalResultTerm
} else {
genToExternal(ctx, t, term)
s"${internalExpr.nullTerm} ? null : ($externalResultTerm)"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, req
import org.apache.flink.table.planner.codegen.GenerateUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens._
import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen}
import org.apache.flink.table.planner.codegen.calls.{BridgingSqlFunctionCallGen, FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen}
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
Expand All @@ -41,6 +41,8 @@ import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName}
import org.apache.calcite.util.TimestampString
import org.apache.flink.table.functions.{ScalarFunction, UserDefinedFunction}
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction

import scala.collection.JavaConversions._

Expand Down Expand Up @@ -482,7 +484,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
case (o@_, _) => o.accept(this)
}

generateCallExpression(ctx, call.getOperator, operands, resultType)
generateCallExpression(ctx, call, operands, resultType)
}

override def visitOver(over: RexOver): GeneratedExpression =
Expand All @@ -498,10 +500,10 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)

private def generateCallExpression(
ctx: CodeGeneratorContext,
operator: SqlOperator,
call: RexCall,
operands: Seq[GeneratedExpression],
resultType: LogicalType): GeneratedExpression = {
operator match {
call.getOperator match {
// arithmetic
case PLUS if isNumeric(resultType) =>
val left = operands.head
Expand Down Expand Up @@ -780,21 +782,25 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
tsf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
.generate(ctx, operands, resultType)

case bf: BridgingSqlFunction if bf.getDefinition.isInstanceOf[ScalarFunction] =>
new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)

// advanced scalar functions
case sqlOperator: SqlOperator =>
StringCallGen.generateCallExpression(ctx, operator, operands, resultType).getOrElse {
FunctionGenerator
.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}
StringCallGen.generateCallExpression(ctx, call.getOperator, operands, resultType)
.getOrElse {
FunctionGenerator
.getCallGenerator(
sqlOperator,
operands.map(expr => expr.resultType),
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}

// unknown or invalid
case call@_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,37 +570,40 @@ object GenerateUtils {
* Wrapper types can be auto-unboxed to their corresponding primitive type (Integer -> int).
*
* @param ctx code generator context which maintains various code statements.
* @param fieldType type of field
* @param fieldTerm expression term of field to be unboxed
* @param inputType type of field
* @param inputTerm expression term of field to be unboxed
* @param inputUnboxingTerm unboxing/conversion term
* @return internal unboxed field representation
*/
def generateInputFieldUnboxing(
ctx: CodeGeneratorContext,
fieldType: LogicalType,
fieldTerm: String): GeneratedExpression = {
inputType: LogicalType,
inputTerm: String,
inputUnboxingTerm: String)
: GeneratedExpression = {

val resultTypeTerm = primitiveTypeTermForType(fieldType)
val defaultValue = primitiveDefaultValue(fieldType)
val resultTypeTerm = primitiveTypeTermForType(inputType)
val defaultValue = primitiveDefaultValue(inputType)

val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables(
(resultTypeTerm, "result"),
("boolean", "isNull"))

val wrappedCode = if (ctx.nullCheck) {
s"""
|$nullTerm = $fieldTerm == null;
|$nullTerm = $inputTerm == null;
|$resultTerm = $defaultValue;
|if (!$nullTerm) {
| $resultTerm = $fieldTerm;
| $resultTerm = $inputUnboxingTerm;
|}
|""".stripMargin.trim
} else {
s"""
|$resultTerm = $fieldTerm;
|$resultTerm = $inputUnboxingTerm;
|""".stripMargin.trim
}

GeneratedExpression(resultTerm, nullTerm, wrappedCode, fieldType)
GeneratedExpression(resultTerm, nullTerm, wrappedCode, inputType)
}

/**
Expand Down Expand Up @@ -659,7 +662,7 @@ object GenerateUtils {
case _ =>
val fieldTypeTerm = boxedTypeTermForType(inputType)
val inputCode = s"($fieldTypeTerm) $inputTerm"
generateInputFieldUnboxing(ctx, inputType, inputCode)
generateInputFieldUnboxing(ctx, inputType, inputCode, inputCode)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ object LookupJoinCodeGenerator {
.map { e =>
val dataType = fromLogicalTypeToDataType(e.resultType)
val bType = if (isExternalArgs) {
boxedTypeTermForExternalType(dataType)
typeTerm(dataType.getConversionClass)
} else {
boxedTypeTermForType(e.resultType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ImperativeAggCodeGen(
} else {
boxedTypeTermForType(fromDataTypeToLogicalType(externalAccType))
}
val accTypeExternalTerm: String = boxedTypeTermForExternalType(externalAccType)
val accTypeExternalTerm: String = typeTerm(externalAccType.getConversionClass)

val argTypes: Array[LogicalType] = {
val types = inputTypes ++ constantExprs.map(_.resultType)
Expand Down Expand Up @@ -250,7 +250,7 @@ class ImperativeAggCodeGen(

def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
val valueExternalTerm = newName("value_external")
val valueExternalTypeTerm = boxedTypeTermForExternalType(externalResultType)
val valueExternalTypeTerm = typeTerm(externalResultType.getConversionClass)
val valueInternalTerm = newName("value_internal")
val valueInternalTypeTerm = boxedTypeTermForType(internalResultType)
val nullTerm = newName("valueIsNull")
Expand All @@ -277,8 +277,7 @@ class ImperativeAggCodeGen(
if (f >= inputTypes.length) {
// index to constant
val expr = constantExprs(f - inputTypes.length)
s"${expr.nullTerm} ? null : ${
genToExternal(ctx, externalInputTypes(index), expr.resultTerm)}"
genToExternalIfNeeded(ctx, externalInputTypes(index), expr)
} else {
// index to input field
val inputRef = if (generator.input1Term.startsWith(DISTINCT_KEY_TERM)) {
Expand All @@ -297,8 +296,7 @@ class ImperativeAggCodeGen(
var inputExpr = generator.generateExpression(inputRef.accept(rexNodeGen))
if (inputFieldCopy) inputExpr = inputExpr.deepCopy(ctx)
codes += inputExpr.code
val term = s"${genToExternal(ctx, externalInputTypes(index), inputExpr.resultTerm)}"
s"${inputExpr.nullTerm} ? null : $term"
genToExternalIfNeeded(ctx, externalInputTypes(index), inputExpr)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ object AggCodeGenHelper {
val singleIterableClass = classOf[SingleElementIterator[_]].getCanonicalName

val externalAccT = getAccumulatorTypeOfAggregateFunction(agg)
val javaField = boxedTypeTermForExternalType(externalAccT)
val javaField = typeTerm(externalAccT.getConversionClass)
val tmpAcc = newName("tmpAcc")
s"""
|final $singleIterableClass accIt$aggIndex = new $singleIterableClass();
Expand Down Expand Up @@ -625,11 +625,10 @@ object AggCodeGenHelper {
agg, externalAccType, inputExprs.map(_.resultType))
val parameters = inputExprs.zipWithIndex.map {
case (expr, i) =>
s"${expr.nullTerm} ? null : " +
s"${ genToExternal(ctx, externalUDITypes(i), expr.resultTerm)}"
genToExternalIfNeeded(ctx, externalUDITypes(i), expr)
}

val javaTerm = boxedTypeTermForExternalType(externalAccType)
val javaTerm = typeTerm(externalAccType.getConversionClass)
val tmpAcc = newName("tmpAcc")
val innerCode =
s"""
Expand Down
Loading

0 comments on commit 5e6e851

Please sign in to comment.