Skip to content

Commit

Permalink
[FLINK-5906] [table] Add support to register UDAGGs for Table API and…
Browse files Browse the repository at this point in the history
… SQL.

This closes apache#3809.
  • Loading branch information
shaoxuan-wang authored and fhueske committed May 4, 2017
1 parent d6435e8 commit 981dea4
Show file tree
Hide file tree
Showing 38 changed files with 1,617 additions and 241 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkRelBuilder, FlinkT
import org.apache.flink.table.catalog.{ExternalCatalog, ExternalCatalogSchema}
import org.apache.flink.table.codegen.{CodeGenerator, ExpressionReducer}
import org.apache.flink.table.expressions.{Alias, Expression, UnresolvedFieldReference}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createScalarSqlFunction, createTableSqlFunctions}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, AggregateFunction}
import org.apache.flink.table.plan.cost.DataSetCostFactory
import org.apache.flink.table.plan.logical.{CatalogNode, LogicalRelNode}
import org.apache.flink.table.plan.rules.FlinkRuleSets
Expand Down Expand Up @@ -351,6 +351,27 @@ abstract class TableEnvironment(val config: TableConfig) {
functionCatalog.registerSqlFunctions(sqlFunctions)
}

/**
* Registers an [[AggregateFunction]] under a unique name. Replaces already existing
* user-defined functions under this name.
*/
private[flink] def registerAggregateFunctionInternal[T: TypeInformation, ACC](
name: String, function: AggregateFunction[T, ACC]): Unit = {
// check if class not Scala object
checkNotSingleton(function.getClass)
// check if class could be instantiated
checkForInstantiation(function.getClass)

val typeInfo: TypeInformation[_] = implicitly[TypeInformation[T]]

// register in Table API
functionCatalog.registerFunction(name, function.getClass)

// register in SQL API
val sqlFunctions = createAggregateSqlFunction(name, function, typeInfo, typeFactory)
functionCatalog.registerSqlFunction(sqlFunctions)
}

/**
* Registers a [[Table]] under a unique name in the TableEnvironment's catalog.
* Registered tables can be referenced in SQL queries.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.table.expressions.ExpressionParser
import org.apache.flink.table.api._
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.table.functions.{AggregateFunction, TableFunction}

/**
* The [[TableEnvironment]] for a Java batch [[DataSet]]
Expand Down Expand Up @@ -178,4 +178,24 @@ class BatchTableEnvironment(

registerTableFunctionInternal[T](name, tf)
}

/**
* Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
* @param f The AggregateFunction to register.
* @tparam T The type of the output value.
* @tparam ACC The type of aggregate accumulator.
*/
def registerFunction[T, ACC](
name: String,
f: AggregateFunction[T, ACC])
: Unit = {
implicit val typeInfo: TypeInformation[T] = TypeExtractor
.createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0)
.asInstanceOf[TypeInformation[T]]

registerAggregateFunctionInternal[T, ACC](name, f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.flink.table.api.java
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.api._
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
import org.apache.flink.table.expressions.ExpressionParser
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
Expand Down Expand Up @@ -180,4 +180,24 @@ class StreamTableEnvironment(

registerTableFunctionInternal[T](name, tf)
}

/**
* Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
* @param f The AggregateFunction to register.
* @tparam T The type of the output value.
* @tparam ACC The type of aggregate accumulator.
*/
def registerFunction[T, ACC](
name: String,
f: AggregateFunction[T, ACC])
: Unit = {
implicit val typeInfo: TypeInformation[T] = TypeExtractor
.createTypeInfo(f, classOf[AggregateFunction[T, ACC]], f.getClass, 0)
.asInstanceOf[TypeInformation[T]]

registerAggregateFunctionInternal[T, ACC](name, f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.table.functions.{AggregateFunction, TableFunction}

import _root_.scala.reflect.ClassTag

Expand Down Expand Up @@ -151,4 +151,20 @@ class BatchTableEnvironment(
def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
registerTableFunctionInternal(name, tf)
}

/**
* Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
* @param f The AggregateFunction to register.
* @tparam T The type of the output value.
* @tparam ACC The type of aggregate accumulator.
*/
def registerFunction[T: TypeInformation, ACC](
name: String,
f: AggregateFunction[T, ACC])
: Unit = {
registerAggregateFunctionInternal[T, ACC](name, f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.flink.table.api.scala

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api.{TableEnvironment, Table, TableConfig}
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.table.functions.{AggregateFunction, TableFunction}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.scala.asScalaStream
Expand Down Expand Up @@ -152,4 +152,20 @@ class StreamTableEnvironment(
def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
registerTableFunctionInternal(name, tf)
}

/**
* Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
* Registered functions can be referenced in Table API and SQL queries.
*
* @param name The name under which the function is registered.
* @param f The AggregateFunction to register.
* @tparam T The type of the output value.
* @tparam ACC The type of aggregate accumulator.
*/
def registerFunction[T: TypeInformation, ACC](
name: String,
f: AggregateFunction[T, ACC])
: Unit = {
registerAggregateFunctionInternal[T, ACC](name, f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.flink.table.api.{TableException, CurrentRow, CurrentRange, Unb
import org.apache.flink.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval}
import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.AggregateFunction

import scala.language.implicitConversions

Expand Down Expand Up @@ -773,6 +774,8 @@ trait ImplicitExpressionConversions {
implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression =
Literal(sqlTimestamp)
implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
implicit def userDefinedAggFunctionConstructor[T: TypeInformation, ACC]
(udagg: AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
}

// ------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ class Table(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
select(fieldExprs: _*)
//get the correct expression for AggFunctionCall
val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, tableEnv))
select(withResolvedAggFunctionCall: _*)
}

/**
Expand All @@ -167,7 +169,7 @@ class Table(
def as(fields: Expression*): Table = {

logicalPlan match {
case functionCall: LogicalTableFunctionCall if functionCall.child == null => {
case functionCall: LogicalTableFunctionCall if functionCall.child == null =>
// If the logical plan is a TableFunctionCall, we replace its field names to avoid special
// cases during the validation.
if (fields.length != functionCall.output.length) {
Expand All @@ -181,15 +183,14 @@ class Table(
}
new Table(
tableEnv,
new LogicalTableFunctionCall(
LogicalTableFunctionCall(
functionCall.functionName,
functionCall.tableFunction,
functionCall.parameters,
functionCall.resultType,
fields.map(_.asInstanceOf[UnresolvedFieldReference].name).toArray,
functionCall.child)
)
}
case _ =>
// prepend an AliasNode
new Table(tableEnv, AliasNode(fields, logicalPlan).validate(tableEnv))
Expand Down Expand Up @@ -908,7 +909,9 @@ class GroupedTable(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
select(fieldExprs: _*)
//get the correct expression for AggFunctionCall
val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
select(withResolvedAggFunctionCall: _*)
}
}

Expand Down Expand Up @@ -983,7 +986,9 @@ class OverWindowedTable(

def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
select(fieldExprs: _*)
//get the correct expression for AggFunctionCall
val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
select(withResolvedAggFunctionCall: _*)
}
}

Expand Down Expand Up @@ -1043,7 +1048,9 @@ class WindowGroupedTable(
*/
def select(fields: String): Table = {
val fieldExprs = ExpressionParser.parseExpressionList(fields)
select(fieldExprs: _*)
//get the correct expression for AggFunctionCall
val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, table.tableEnv))
select(withResolvedAggFunctionCall: _*)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.table.codegen

import java.lang.reflect.ParameterizedType
import java.lang.{Iterable => JIterable}
import java.math.{BigDecimal => JBigDecimal}

import org.apache.calcite.avatica.util.DateTimeUtils
Expand Down Expand Up @@ -45,6 +47,7 @@ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.runtime.TableFunctionCollector
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.types.Row
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}

import scala.collection.JavaConversions._
import scala.collection.mutable
Expand Down Expand Up @@ -258,6 +261,9 @@ class CodeGenerator(
* @param constantFlags An optional parameter to define where to set constant boolean flags in
* the output row.
* @param outputArity The number of fields in the output row.
* @param needRetract a flag to indicate if the aggregate needs the retract method
* @param needMerge a flag to indicate if the aggregate needs the merge method
* @param needReset a flag to indicate if the aggregate needs the resetAccumulator method
*
* @return A GeneratedAggregationsFunction
*/
Expand All @@ -274,27 +280,89 @@ class CodeGenerator(
constantFlags: Option[Array[(Int, Boolean)]],
outputArity: Int,
needRetract: Boolean,
needMerge: Boolean)
needMerge: Boolean,
needReset: Boolean)
: GeneratedAggregationsFunction = {

// get unique function name
val funcName = newName(name)
// register UDAGGs
val aggs = aggregates.map(a => generator.addReusableFunction(a))
// get java types of accumulators
val accTypes = aggregates.map { a =>
a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
val accTypeClasses = aggregates.map { a =>
a.getClass.getMethod("createAccumulator").getReturnType
}
val accTypes = accTypeClasses.map(_.getCanonicalName)

// get java types of input fields
val javaTypes = inputType.getFieldList
.map(f => FlinkTypeFactory.toTypeInfo(f.getType))
.map(t => t.getTypeClass.getCanonicalName)
// get java classes of input fields
val javaClasses = inputType.getFieldList
.map(f => FlinkTypeFactory.toTypeInfo(f.getType).getTypeClass)
// get parameter lists for aggregation functions
val parameters = aggFields.map {inFields =>
val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
val parameters = aggFields.map { inFields =>
val fields = for (f <- inFields) yield
s"(${javaClasses(f).getCanonicalName}) input.getField($f)"
fields.mkString(", ")
}
val methodSignaturesList = aggFields.map {
inFields => for (f <- inFields) yield javaClasses(f)
}

// check and validate the needed methods
aggregates.zipWithIndex.map {
case (a, i) => {
getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
.getOrElse(
throw new CodeGenException(
s"No matching accumulate method found for AggregateFunction " +
s"'${a.getClass.getCanonicalName}'" +
s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
)

if (needRetract) {
getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
.getOrElse(
throw new CodeGenException(
s"No matching retract method found for AggregateFunction " +
s"'${a.getClass.getCanonicalName}'" +
s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
)
}

if (needMerge) {
val methods =
getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
.getOrElse(
throw new CodeGenException(
s"No matching merge method found for AggregateFunction " +
s"${a.getClass.getCanonicalName}'.")
)

var iterableTypeClass = methods.getGenericParameterTypes.apply(1)
.asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0)
// further extract iterableTypeClass if the accumulator has generic type
iterableTypeClass match {
case impl: ParameterizedType => iterableTypeClass = impl.getRawType
case _ =>
}

if (iterableTypeClass != accTypeClasses(i)) {
throw new CodeGenException(
s"merge method in AggregateFunction ${a.getClass.getCanonicalName} does not have " +
s"the correct Iterable type. Actually: ${iterableTypeClass.toString}. " +
s"Expected: ${accTypeClasses(i).toString}")
}
}

if (needReset) {
getUserDefinedMethod(a, "resetAccumulator", Array(accTypeClasses(i)))
.getOrElse(
throw new CodeGenException(
s"No matching resetAccumulator method found for " +
s"aggregate ${a.getClass.getCanonicalName}'.")
)
}
}
}

def genSetAggregationResults: String = {

Expand Down Expand Up @@ -529,9 +597,14 @@ class CodeGenerator(
| ((${accTypes(i)}) accs.getField($i)));""".stripMargin
}.mkString("\n")

j"""$sig {
|$reset
| }""".stripMargin
if (needReset) {
j"""$sig {
|$reset
| }""".stripMargin
} else {
j"""$sig {
| }""".stripMargin
}
}

var funcCode =
Expand Down
Loading

0 comments on commit 981dea4

Please sign in to comment.