diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala index ac5c96d609348..962e58c20f81a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala @@ -26,7 +26,6 @@ import org.apache.flink.table.dataformat.{BinaryStringUtil, Decimal, _} import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.runtime.dataview.StateDataViewStore import org.apache.flink.table.runtime.generated.{AggsHandleFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction} -import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getInternalClassForType import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable @@ -36,11 +35,12 @@ import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical._ import org.apache.flink.types.Row - import java.lang.reflect.Method import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort} import java.util.concurrent.atomic.AtomicInteger +import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField} + object CodeGenUtils { // ------------------------------- DEFAULT TERMS ------------------------------------------ @@ -116,7 +116,28 @@ object CodeGenUtils { /** * Retrieve the canonical name of a class type. */ - def className[T](implicit m: Manifest[T]): String = m.runtimeClass.getCanonicalName + def className[T](implicit m: Manifest[T]): String = { + val name = m.runtimeClass.getCanonicalName + if (name == null) { + throw new CodeGenException( + s"Class '${m.runtimeClass.getName}' does not have a canonical name. " + + s"Make sure it is statically accessible.") + } + name + } + + /** + * Returns a term for representing the given class in Java code. + */ + def typeTerm(clazz: Class[_]): String = { + val name = clazz.getCanonicalName + if (name == null) { + throw new CodeGenException( + s"Class '${clazz.getName}' does not have a canonical name. " + + s"Make sure it is statically accessible.") + } + name + } // when casting we first need to unbox Primitives, for example, // float a = 1.0f; @@ -167,18 +188,6 @@ object CodeGenUtils { case RAW => className[BinaryGeneric[_]] } - /** - * Gets the boxed type term from external type info. - * We only use TypeInformation to store external type info. - */ - def boxedTypeTermForExternalType(t: DataType): String = { - if (t.getConversionClass == null) { - ClassLogicalTypeConverter.getDefaultExternalClassForType(t.getLogicalType).getCanonicalName - } else { - t.getConversionClass.getCanonicalName - } - } - /** * Gets the default value for a primitive type, and null for generic types */ @@ -682,11 +691,11 @@ object CodeGenUtils { genToInternal(ctx, t)(term) def genToInternal(ctx: CodeGeneratorContext, t: DataType): String => String = { - val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t)) if (isConverterIdentity(t)) { - term => s"($iTerm) $term" + term => s"$term" } else { - val eTerm = boxedTypeTermForExternalType(t) + val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t)) + val eTerm = typeTerm(t.getConversionClass) val converter = ctx.addReusableObject( DataFormatConverters.getConverterForDataType(t), "converter") @@ -694,38 +703,70 @@ object CodeGenUtils { } } + /** + * Generates code for converting the given external source data type to the internal data format. + * + * Use this function for converting at the edges of the API. + */ def genToInternalIfNeeded( ctx: CodeGeneratorContext, - t: DataType, - term: String): String = { - if (isInternalClass(t)) { - s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term" + sourceDataType: DataType, + externalTerm: String) + : GeneratedExpression = { + val sourceType = sourceDataType.getLogicalType + val sourceClass = sourceDataType.getConversionClass + // convert external source type to internal format + val internalResultTerm = if (isInternalClass(sourceDataType)) { + s"$externalTerm" } else { - genToInternal(ctx, t, term) + genToInternal(ctx, sourceDataType, externalTerm) + } + // extract null term from result term + if (sourceClass.isPrimitive) { + generateNonNullField(sourceType, internalResultTerm) + } else { + generateInputFieldUnboxing(ctx, sourceType, externalTerm, internalResultTerm) } } - def genToExternal(ctx: CodeGeneratorContext, t: DataType, term: String): String = { - val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(t)) - if (isConverterIdentity(t)) { - s"($iTerm) $term" + def genToExternal( + ctx: CodeGeneratorContext, + targetType: DataType, + internalTerm: String): String = { + if (isConverterIdentity(targetType)) { + s"$internalTerm" } else { - val eTerm = boxedTypeTermForExternalType(t) + val iTerm = boxedTypeTermForType(fromDataTypeToLogicalType(targetType)) + val eTerm = typeTerm(targetType.getConversionClass) val converter = ctx.addReusableObject( - DataFormatConverters.getConverterForDataType(t), + DataFormatConverters.getConverterForDataType(targetType), "converter") - s"($eTerm) $converter.toExternal(($iTerm) $term)" + s"($eTerm) $converter.toExternal(($iTerm) $internalTerm)" } } + /** + * Generates code for converting the internal data format to the given external target data type. + * + * Use this function for converting at the edges of the API. + */ def genToExternalIfNeeded( ctx: CodeGeneratorContext, - t: DataType, - term: String): String = { - if (isInternalClass(t)) { - s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term" + targetDataType: DataType, + internalExpr: GeneratedExpression) + : String = { + val targetType = fromDataTypeToLogicalType(targetDataType) + // convert internal format to target type + val externalResultTerm = if (isInternalClass(targetDataType)) { + s"(${boxedTypeTermForType(targetType)}) ${internalExpr.resultTerm}" + } else { + genToExternal(ctx, targetDataType, internalExpr.resultTerm) + } + // merge null term into the result term + if (targetDataType.getConversionClass.isPrimitive) { + externalResultTerm } else { - genToExternal(ctx, t, term) + s"${internalExpr.nullTerm} ? null : ($externalResultTerm)" } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 1c60918df1089..eec9b58b14198 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, req import org.apache.flink.table.planner.codegen.GenerateUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens._ -import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen} +import org.apache.flink.table.planner.codegen.calls.{BridgingSqlFunctionCallGen, FunctionGenerator, ScalarFunctionCallGen, StringCallGen, TableFunctionCallGen} import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._ import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction} @@ -41,6 +41,8 @@ import org.apache.calcite.rex._ import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.`type`.{ReturnTypes, SqlTypeName} import org.apache.calcite.util.TimestampString +import org.apache.flink.table.functions.{ScalarFunction, UserDefinedFunction} +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction import scala.collection.JavaConversions._ @@ -482,7 +484,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) case (o@_, _) => o.accept(this) } - generateCallExpression(ctx, call.getOperator, operands, resultType) + generateCallExpression(ctx, call, operands, resultType) } override def visitOver(over: RexOver): GeneratedExpression = @@ -498,10 +500,10 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) private def generateCallExpression( ctx: CodeGeneratorContext, - operator: SqlOperator, + call: RexCall, operands: Seq[GeneratedExpression], resultType: LogicalType): GeneratedExpression = { - operator match { + call.getOperator match { // arithmetic case PLUS if isNumeric(resultType) => val left = operands.head @@ -780,21 +782,25 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) tsf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray)) .generate(ctx, operands, resultType) + case bf: BridgingSqlFunction if bf.getDefinition.isInstanceOf[ScalarFunction] => + new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType) + // advanced scalar functions case sqlOperator: SqlOperator => - StringCallGen.generateCallExpression(ctx, operator, operands, resultType).getOrElse { - FunctionGenerator - .getCallGenerator( - sqlOperator, - operands.map(expr => expr.resultType), - resultType) - .getOrElse( - throw new CodeGenException(s"Unsupported call: " + - s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" + - s"If you think this function should be supported, " + - s"you can create an issue and start a discussion for it.")) - .generate(ctx, operands, resultType) - } + StringCallGen.generateCallExpression(ctx, call.getOperator, operands, resultType) + .getOrElse { + FunctionGenerator + .getCallGenerator( + sqlOperator, + operands.map(expr => expr.resultType), + resultType) + .getOrElse( + throw new CodeGenException(s"Unsupported call: " + + s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" + + s"If you think this function should be supported, " + + s"you can create an issue and start a discussion for it.")) + .generate(ctx, operands, resultType) + } // unknown or invalid case call@_ => diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala index fb32f58c5b850..2855eef380955 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala @@ -570,17 +570,20 @@ object GenerateUtils { * Wrapper types can autoboxed to their corresponding primitive type (Integer -> int). * * @param ctx code generator context which maintains various code statements. - * @param fieldType type of field - * @param fieldTerm expression term of field to be unboxed + * @param inputType type of field + * @param inputTerm expression term of field to be unboxed + * @param inputUnboxingTerm unboxing/conversion term * @return internal unboxed field representation */ def generateInputFieldUnboxing( ctx: CodeGeneratorContext, - fieldType: LogicalType, - fieldTerm: String): GeneratedExpression = { + inputType: LogicalType, + inputTerm: String, + inputUnboxingTerm: String) + : GeneratedExpression = { - val resultTypeTerm = primitiveTypeTermForType(fieldType) - val defaultValue = primitiveDefaultValue(fieldType) + val resultTypeTerm = primitiveTypeTermForType(inputType) + val defaultValue = primitiveDefaultValue(inputType) val Seq(resultTerm, nullTerm) = ctx.addReusableLocalVariables( (resultTypeTerm, "result"), @@ -588,19 +591,19 @@ object GenerateUtils { val wrappedCode = if (ctx.nullCheck) { s""" - |$nullTerm = $fieldTerm == null; + |$nullTerm = $inputTerm == null; |$resultTerm = $defaultValue; |if (!$nullTerm) { - | $resultTerm = $fieldTerm; + | $resultTerm = $inputUnboxingTerm; |} |""".stripMargin.trim } else { s""" - |$resultTerm = $fieldTerm; + |$resultTerm = $inputUnboxingTerm; |""".stripMargin.trim } - GeneratedExpression(resultTerm, nullTerm, wrappedCode, fieldType) + GeneratedExpression(resultTerm, nullTerm, wrappedCode, inputType) } /** @@ -659,7 +662,7 @@ object GenerateUtils { case _ => val fieldTypeTerm = boxedTypeTermForType(inputType) val inputCode = s"($fieldTypeTerm) $inputTerm" - generateInputFieldUnboxing(ctx, inputType, inputCode) + generateInputFieldUnboxing(ctx, inputType, inputCode, inputCode) } /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala index 9a9bf2e475b4e..d1414670665c4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala @@ -191,7 +191,7 @@ object LookupJoinCodeGenerator { .map { e => val dataType = fromLogicalTypeToDataType(e.resultType) val bType = if (isExternalArgs) { - boxedTypeTermForExternalType(dataType) + typeTerm(dataType.getConversionClass) } else { boxedTypeTermForType(e.resultType) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala index 457310c9e064f..6604ec5ac80ef 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala @@ -103,7 +103,7 @@ class ImperativeAggCodeGen( } else { boxedTypeTermForType(fromDataTypeToLogicalType(externalAccType)) } - val accTypeExternalTerm: String = boxedTypeTermForExternalType(externalAccType) + val accTypeExternalTerm: String = typeTerm(externalAccType.getConversionClass) val argTypes: Array[LogicalType] = { val types = inputTypes ++ constantExprs.map(_.resultType) @@ -250,7 +250,7 @@ class ImperativeAggCodeGen( def getValue(generator: ExprCodeGenerator): GeneratedExpression = { val valueExternalTerm = newName("value_external") - val valueExternalTypeTerm = boxedTypeTermForExternalType(externalResultType) + val valueExternalTypeTerm = typeTerm(externalResultType.getConversionClass) val valueInternalTerm = newName("value_internal") val valueInternalTypeTerm = boxedTypeTermForType(internalResultType) val nullTerm = newName("valueIsNull") @@ -277,8 +277,7 @@ class ImperativeAggCodeGen( if (f >= inputTypes.length) { // index to constant val expr = constantExprs(f - inputTypes.length) - s"${expr.nullTerm} ? null : ${ - genToExternal(ctx, externalInputTypes(index), expr.resultTerm)}" + genToExternalIfNeeded(ctx, externalInputTypes(index), expr) } else { // index to input field val inputRef = if (generator.input1Term.startsWith(DISTINCT_KEY_TERM)) { @@ -297,8 +296,7 @@ class ImperativeAggCodeGen( var inputExpr = generator.generateExpression(inputRef.accept(rexNodeGen)) if (inputFieldCopy) inputExpr = inputExpr.deepCopy(ctx) codes += inputExpr.code - val term = s"${genToExternal(ctx, externalInputTypes(index), inputExpr.resultTerm)}" - s"${inputExpr.nullTerm} ? null : $term" + genToExternalIfNeeded(ctx, externalInputTypes(index), inputExpr) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala index 389cd7eff8681..95dd1bdd9ce8f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala @@ -554,7 +554,7 @@ object AggCodeGenHelper { val singleIterableClass = classOf[SingleElementIterator[_]].getCanonicalName val externalAccT = getAccumulatorTypeOfAggregateFunction(agg) - val javaField = boxedTypeTermForExternalType(externalAccT) + val javaField = typeTerm(externalAccT.getConversionClass) val tmpAcc = newName("tmpAcc") s""" |final $singleIterableClass accIt$aggIndex = new $singleIterableClass(); @@ -625,11 +625,10 @@ object AggCodeGenHelper { agg, externalAccType, inputExprs.map(_.resultType)) val parameters = inputExprs.zipWithIndex.map { case (expr, i) => - s"${expr.nullTerm} ? null : " + - s"${ genToExternal(ctx, externalUDITypes(i), expr.resultTerm)}" + genToExternalIfNeeded(ctx, externalUDITypes(i), expr) } - val javaTerm = boxedTypeTermForExternalType(externalAccType) + val javaTerm = typeTerm(externalAccType.getConversionClass) val tmpAcc = newName("tmpAcc") val innerCode = s""" diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala new file mode 100644 index 0000000000000..4c1cd35279d23 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.codegen.calls +import java.lang.reflect.Method +import java.util.Collections + +import org.apache.calcite.rex.{RexCall, RexCallBinding} +import org.apache.flink.table.functions.UserDefinedFunctionHelper.SCALAR_EVAL +import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction} +import org.apache.flink.table.planner.codegen.CodeGenUtils.{genToExternalIfNeeded, genToInternalIfNeeded, typeTerm} +import org.apache.flink.table.planner.codegen.{CodeGenException, CodeGeneratorContext, GeneratedExpression} +import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction +import org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala +import org.apache.flink.table.types.DataType +import org.apache.flink.table.types.extraction.utils.ExtractionUtils +import org.apache.flink.table.types.extraction.utils.ExtractionUtils.{createMethodSignatureString, isAssignable, isMethodInvokable, primitiveToWrapper} +import org.apache.flink.table.types.inference.TypeInferenceUtil +import org.apache.flink.table.types.logical.LogicalType + +/** + * Generates a call to a user-defined [[ScalarFunction]] or [[TableFunction]] (future work). + */ +class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator { + + private val function: BridgingSqlFunction = call.getOperator.asInstanceOf[BridgingSqlFunction] + private val udf: UserDefinedFunction = function.getDefinition.asInstanceOf[UserDefinedFunction] + + override def generate( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + returnType: LogicalType) + : GeneratedExpression = { + + val inference = function.getTypeInference + + // we could have implemented a dedicated code generation context but the closer we are to + // Calcite the more consistent is the type inference during the data type enrichment + val callContext = new OperatorBindingCallContext( + function.getDataTypeFactory, + udf, + RexCallBinding.create( + function.getTypeFactory, + call, + Collections.emptyList())) + + // enrich argument types with conversion class + val adaptedCallContext = TypeInferenceUtil.adaptArguments( + inference, + callContext, + null) + val enrichedArgumentDataTypes = toScala(adaptedCallContext.getArgumentDataTypes) + verifyArgumentTypes(operands.map(_.resultType), enrichedArgumentDataTypes) + + // enrich output types with conversion class + val enrichedOutputDataType = TypeInferenceUtil.inferOutputType( + adaptedCallContext, + inference.getOutputTypeStrategy) + verifyOutputType(returnType, enrichedOutputDataType) + + // find runtime method and generate call + verifyImplementation(enrichedArgumentDataTypes, enrichedOutputDataType) + generateFunctionCall(ctx, operands, enrichedArgumentDataTypes, enrichedOutputDataType) + } + + private def generateFunctionCall( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + argumentDataTypes: Seq[DataType], + outputDataType: DataType) + : GeneratedExpression = { + + val functionTerm = ctx.addReusableFunction(udf) + + // operand conversion + val externalOperands = prepareExternalOperands(ctx, operands, argumentDataTypes) + val externalOperandTerms = externalOperands.map(_.resultTerm).mkString(", ") + + // result conversion + val externalResultClass = outputDataType.getConversionClass + val externalResultTypeTerm = typeTerm(externalResultClass) + // Janino does not fully support the JVM spec: + // boolean b = (boolean) f(); where f returns Object + // This is not supported and we need to box manually. + val externalResultClassBoxed = primitiveToWrapper(externalResultClass) + val externalResultCasting = if (externalResultClass == externalResultClassBoxed) { + s"($externalResultTypeTerm)" + } else { + s"($externalResultTypeTerm) (${typeTerm(externalResultClassBoxed)})" + } + val externalResultTerm = ctx.addReusableLocalVariable(externalResultTypeTerm, "externalResult") + val internalExpr = genToInternalIfNeeded(ctx, outputDataType, externalResultTerm) + + // function call + internalExpr.copy(code = + s""" + |${externalOperands.map(_.code).mkString("\n")} + |$externalResultTerm = $externalResultCasting $functionTerm + | .$SCALAR_EVAL($externalOperandTerms); + |${internalExpr.code} + |""".stripMargin) + } + + private def prepareExternalOperands( + ctx: CodeGeneratorContext, + operands: Seq[GeneratedExpression], + argumentDataTypes: Seq[DataType]) + : Seq[GeneratedExpression] = { + operands + .zip(argumentDataTypes) + .map { case (operand, dataType) => + operand.copy(resultTerm = genToExternalIfNeeded(ctx, dataType, operand)) + } + } + + private def verifyArgumentTypes( + operandTypes: Seq[LogicalType], + enrichedDataTypes: Seq[DataType]) + : Unit = { + val enrichedTypes = enrichedDataTypes.map(_.getLogicalType) + operandTypes.zip(enrichedTypes).foreach { case (operandType, enrichedType) => + // check that the logical type has not changed during the enrichment + // a nullability mismatch is acceptable if the enriched type can handle it + if (operandType != enrichedType && operandType.copy(true) != enrichedType) { + throw new CodeGenException( + s"Mismatch of function's argument data type '$enrichedType' and actual " + + s"argument type '$operandType'.") + } + } + // the data type class can only partially verify the conversion class, + // now is the time for the final check + enrichedDataTypes.foreach(dataType => { + if (!dataType.getLogicalType.supportsOutputConversion(dataType.getConversionClass)) { + throw new CodeGenException( + s"Data type '$dataType' does not support an output conversion " + + s"to class '${dataType.getConversionClass}'.") + } + }) + } + + private def verifyOutputType( + outputType: LogicalType, + enrichedDataType: DataType) + : Unit = { + val enrichedType = enrichedDataType.getLogicalType + // check that the logical type has not changed during the enrichment + // a nullability mismatch is acceptable if the output type can handle it + if (outputType != enrichedType && outputType != enrichedType.copy(true)) { + throw new CodeGenException( + s"Mismatch of expected output data type '$outputType' and function's " + + s"output type '$enrichedType'.") + } + // the data type class can only partially verify the conversion class, + // now is the time for the final check + if (!enrichedType.supportsInputConversion(enrichedDataType.getConversionClass)) { + throw new CodeGenException( + s"Data type '$enrichedDataType' does not support an input conversion " + + s"to class '${enrichedDataType.getConversionClass}'.") + } + } + + private def verifyImplementation( + argumentDataTypes: Seq[DataType], + outputDataType: DataType) + : Unit = { + val methods = toScala(ExtractionUtils.collectMethods(udf.getClass, SCALAR_EVAL)) + val argumentClasses = argumentDataTypes.map(_.getConversionClass).toArray + val outputClass = outputDataType.getConversionClass + // verifies regular JVM calling semantics + def methodMatches(method: Method): Boolean = { + isMethodInvokable(method, argumentClasses: _*) && + isAssignable(outputClass, method.getReturnType, true) + } + if (!methods.exists(methodMatches)) { + throw new CodeGenException( + s"Could not find an implementation method in class '${typeTerm(udf.getClass)}' for " + + s"function '$function' that matches the following signature: \n" + + s"${createMethodSignatureString(SCALAR_EVAL, argumentClasses, outputClass)}") + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarFunctionCallGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarFunctionCallGen.scala index 4001b40dd63d0..3ef085f5c5aa3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarFunctionCallGen.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarFunctionCallGen.scala @@ -30,6 +30,7 @@ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType +import org.apache.flink.table.types.DataType import org.apache.flink.table.types.extraction.utils.ExtractionUtils import org.apache.flink.table.types.logical.LogicalType import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType @@ -107,7 +108,7 @@ class ScalarFunctionCallGen(scalarFunction: ScalarFunction) extends CallGenerato val resultUnboxing = if (resultClass.isPrimitive) { GenerateUtils.generateNonNullField(returnType, resultTerm) } else { - GenerateUtils.generateInputFieldUnboxing(ctx, returnType, resultTerm) + GenerateUtils.generateInputFieldUnboxing(ctx, returnType, resultTerm, resultTerm) } resultUnboxing.copy(code = s""" @@ -126,6 +127,17 @@ class ScalarFunctionCallGen(scalarFunction: ScalarFunction) extends CallGenerato prepareFunctionArgs(ctx, operands, paramClasses, func.getParameterTypes(paramClasses)) } + def genToInternalIfNeeded( + ctx: CodeGeneratorContext, + t: DataType, + term: String): String = { + if (isInternalClass(t)) { + s"(${boxedTypeTermForType(LogicalTypeDataTypeConverter.fromDataTypeToLogicalType(t))}) $term" + } else { + genToInternal(ctx, t, term) + } + } + } object ScalarFunctionCallGen { @@ -161,10 +173,8 @@ object ScalarFunctionCallGen { } else { signatureTypes(i) } - val externalResultTerm = genToExternalIfNeeded( - ctx, signatureType, operandExpr.resultTerm) - val exprOrNull = s"${operandExpr.nullTerm} ? null : ($externalResultTerm)" - operandExpr.copy(resultTerm = exprOrNull) + val externalResultTerm = genToExternalIfNeeded(ctx, signatureType, operandExpr) + operandExpr.copy(resultTerm = externalResultTerm) } } } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java index 7aa167bc5cb5c..d02cdda08e200 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java @@ -18,27 +18,37 @@ package org.apache.flink.table.planner.runtime.stream.sql; +import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.InputGroup; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.catalog.Catalog; import org.apache.flink.table.catalog.CatalogFunction; +import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.catalog.ObjectPath; import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.table.planner.codegen.CodeGenException; import org.apache.flink.table.planner.factories.utils.TestCollectionTableFactory; import org.apache.flink.table.planner.runtime.utils.StreamingTestBase; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.inference.TypeStrategies; +import org.apache.flink.table.utils.EncodingUtils; import org.apache.flink.types.Row; import org.junit.Test; -import java.util.ArrayList; +import java.math.BigDecimal; import java.util.Arrays; import java.util.List; +import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage; /** * Tests for catalog and system in stream table environment. @@ -400,7 +410,7 @@ private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exc ); TestCollectionTableFactory.reset(); - TestCollectionTableFactory.initData(sourceData, new ArrayList<>(), -1); + TestCollectionTableFactory.initData(sourceData); String sourceDDL = "create table t1(a int, b varchar, c int) with ('connector' = 'COLLECTION')"; String sinkDDL = "create table t2(a int, b varchar, c int) with ('connector' = 'COLLECTION')"; @@ -421,4 +431,178 @@ private void testUserDefinedCatalogFunction(String createFunctionDDL) throws Exc tEnv().sqlUpdate("drop table t1"); tEnv().sqlUpdate("drop table t2"); } + + @Test + public void testPrimitiveScalarFunction() throws Exception { + final List sourceData = Arrays.asList( + Row.of(1, 1L, "-"), + Row.of(2, 2L, "--"), + Row.of(3, 3L, "---") + ); + + final List sinkData = Arrays.asList( + Row.of(1, 3L, "-"), + Row.of(2, 6L, "--"), + Row.of(3, 9L, "---") + ); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().sqlUpdate("CREATE TABLE TestTable(a INT NOT NULL, b BIGINT NOT NULL, c STRING) WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction("PrimitiveScalarFunction", PrimitiveScalarFunction.class); + tEnv().sqlUpdate("INSERT INTO TestTable SELECT a, PrimitiveScalarFunction(a, b, c), c FROM TestTable"); + tEnv().execute("Test Job"); + + assertThat(TestCollectionTableFactory.getResult(), equalTo(sinkData)); + } + + @Test + public void testComplexScalarFunction() throws Exception { + final List sourceData = Arrays.asList( + Row.of(1, new byte[]{1, 2, 3}), + Row.of(2, new byte[]{2, 3, 4}), + Row.of(3, new byte[]{3, 4, 5}), + Row.of(null, null) + ); + + final List sinkData = Arrays.asList( + Row.of(1, "1+2012-12-12 12:12:12.123456789", "[1, 2, 3]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[1, 2, 3]"), + Row.of(2, "2+2012-12-12 12:12:12.123456789", "[2, 3, 4]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[2, 3, 4]"), + Row.of(3, "3+2012-12-12 12:12:12.123456789", "[3, 4, 5]+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "[3, 4, 5]"), + Row.of(null, "null+2012-12-12 12:12:12.123456789", "null+2012-12-12 12:12:12.123456789", new BigDecimal("123.40"), "null") + ); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().sqlUpdate( + "CREATE TABLE SourceTable(i INT, b BYTES) " + + "WITH ('connector' = 'COLLECTION')"); + tEnv().sqlUpdate( + "CREATE TABLE SinkTable(i INT, s1 STRING, s2 STRING, d DECIMAL(5, 2), s3 STRING) " + + "WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction("ComplexScalarFunction", ComplexScalarFunction.class); + tEnv().sqlUpdate( + "INSERT INTO SinkTable " + + "SELECT " + + " i, " + + " ComplexScalarFunction(i, TIMESTAMP '2012-12-12 12:12:12.123456789'), " + + " ComplexScalarFunction(b, TIMESTAMP '2012-12-12 12:12:12.123456789')," + + " ComplexScalarFunction(), " + + " ComplexScalarFunction(b) " + + "FROM SourceTable"); + tEnv().execute("Test Job"); + + assertThat(TestCollectionTableFactory.getResult(), equalTo(sinkData)); + } + + @Test + public void testCustomScalarFunction() throws Exception { + final List sourceData = Arrays.asList( + Row.of(1), + Row.of(2), + Row.of(3), + Row.of((Integer) null) + ); + + final List sinkData = Arrays.asList( + Row.of(1, 1, 5), + Row.of(2, 2, 5), + Row.of(3, 3, 5), + Row.of(null, null, 5) + ); + + TestCollectionTableFactory.reset(); + TestCollectionTableFactory.initData(sourceData); + + tEnv().sqlUpdate("CREATE TABLE SourceTable(i INT) WITH ('connector' = 'COLLECTION')"); + tEnv().sqlUpdate("CREATE TABLE SinkTable(i1 INT, i2 INT, i3 INT) WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction("CustomScalarFunction", CustomScalarFunction.class); + tEnv().sqlUpdate( + "INSERT INTO SinkTable " + + "SELECT " + + " i, " + + " CustomScalarFunction(i), " + + " CustomScalarFunction(CAST(NULL AS INT), 5, i, i) " + + "FROM SourceTable"); + tEnv().execute("Test Job"); + + assertThat(TestCollectionTableFactory.getResult(), equalTo(sinkData)); + } + + @Test + public void testInvalidCustomScalarFunction() { + tEnv().sqlUpdate("CREATE TABLE SinkTable(s STRING) WITH ('connector' = 'COLLECTION')"); + + tEnv().createTemporarySystemFunction("CustomScalarFunction", CustomScalarFunction.class); + try { + tEnv().sqlUpdate( + "INSERT INTO SinkTable " + + "SELECT CustomScalarFunction('test')"); + fail(); + } catch (CodeGenException e) { + assertThat( + e, + hasMessage( + equalTo( + "Could not find an implementation method in class '" + CustomScalarFunction.class.getCanonicalName() + + "' for function 'CustomScalarFunction' that matches the following signature: \n" + + "java.lang.String eval(java.lang.String)"))); + } + } + + // -------------------------------------------------------------------------------------------- + // Test functions + // -------------------------------------------------------------------------------------------- + + /** + * Function that takes and returns primitives. + */ + public static class PrimitiveScalarFunction extends ScalarFunction { + public long eval(int i, long l, String s) { + return i + l + s.length(); + } + } + + /** + * Function that is overloaded and takes use of annotations. + */ + public static class ComplexScalarFunction extends ScalarFunction { + public String eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object o, java.sql.Timestamp t) { + return EncodingUtils.objectToString(o) + "+" + t.toString(); + } + + public @DataTypeHint("DECIMAL(5, 2)") BigDecimal eval() { + return new BigDecimal("123.4"); // 1 digit is missing + } + + public String eval(byte[] bytes) { + return Arrays.toString(bytes); + } + } + + /** + * Function that has a custom type inference that is broader than the actual implementation. + */ + public static class CustomScalarFunction extends ScalarFunction { + public Integer eval(Integer... args) { + for (Integer o : args) { + if (o != null) { + return o; + } + } + return null; + } + + @Override + public TypeInference getTypeInference(DataTypeFactory typeFactory) { + return TypeInference.newBuilder() + .outputTypeStrategy(TypeStrategies.argument(0)) + .build(); + } + } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/factories/utils/TestCollectionTableFactory.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/factories/utils/TestCollectionTableFactory.scala index db1a39ad43b96..01c8fb8631286 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/factories/utils/TestCollectionTableFactory.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/factories/utils/TestCollectionTableFactory.scala @@ -97,6 +97,10 @@ object TestCollectionTableFactory { val RESULT = new JLinkedList[Row]() private var emitIntervalMS = -1L + def initData(sourceData: JList[Row]): Unit ={ + initData(sourceData, List(), -1L) + } + def initData(sourceData: JList[Row], dimData: JList[Row] = List(), emitInterval: Long = -1L): Unit ={ @@ -112,6 +116,8 @@ object TestCollectionTableFactory { emitIntervalMS = -1L } + def getResult: util.List[Row] = RESULT + def getCollectionSource(props: JMap[String, String]): CollectionTableSource = { val properties = new DescriptorProperties() properties.putProperties(props)