Skip to content

Commit

Permalink
[FLINK-10958] [table] Add UDF's eval method parameters support subcla…
Browse files Browse the repository at this point in the history
…ss matching.

This closes apache#7152
  • Loading branch information
dianfu authored and sunjincheng121 committed Nov 26, 2018
1 parent 6e62ca2 commit d9c7f97
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,17 @@ object UserDefinedFunctionUtils {
// match parameters of signature to actual parameters
methodSignature.length == signatures.length &&
signatures.zipWithIndex.forall { case (clazz, i) =>
parameterTypeEquals(methodSignature(i), clazz)
parameterTypeApplicable(methodSignature(i), clazz)
}
case cur if cur.isVarArgs =>
val signatures = cur.getParameterTypes
methodSignature.zipWithIndex.forall {
// non-varargs
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(clazz, signatures(i))
parameterTypeApplicable(clazz, signatures(i))
// varargs
case (clazz, i) if i >= signatures.length - 1 =>
parameterTypeEquals(clazz, signatures.last.getComponentType)
parameterTypeApplicable(clazz, signatures.last.getComponentType)
} || (methodSignature.isEmpty && signatures.length == 1) // empty varargs
}

Expand All @@ -171,26 +171,57 @@ object UserDefinedFunctionUtils {
fixedMethodsCount > 0 && !cur.isVarArgs ||
fixedMethodsCount == 0 && cur.isVarArgs
}
val maximallySpecific = if (found.length > 1) {
implicit val methodOrdering = new scala.Ordering[Method] {
override def compare(x: Method, y: Method): Int = {
def specificThan(left: Method, right: Method) = {
// left parameter type is more specific than right parameter type
left.getParameterTypes.zip(right.getParameterTypes).forall {
case (leftParameterType, rightParameterType) =>
parameterTypeApplicable(leftParameterType, rightParameterType)
} &&
// non-equal
left.getParameterTypes.zip(right.getParameterTypes).exists {
case (leftParameterType, rightParameterType) =>
!parameterTypeEquals(leftParameterType, rightParameterType)
}
}

if (specificThan(x, y)) {
1
} else if (specificThan(y, x)) {
-1
} else {
0
}
}
}

val max = found.max
found.filter(methodOrdering.compare(max, _) == 0)
} else {
found
}

// check if there is a Scala varargs annotation
if (found.isEmpty &&
if (maximallySpecific.isEmpty &&
methods.exists { method =>
val signatures = method.getParameterTypes
signatures.zipWithIndex.forall {
case (clazz, i) if i < signatures.length - 1 =>
parameterTypeEquals(methodSignature(i), clazz)
parameterTypeApplicable(methodSignature(i), clazz)
case (clazz, i) if i == signatures.length - 1 =>
clazz.getName.equals("scala.collection.Seq")
}
}) {
throw new ValidationException(
s"Scala-style variable arguments in '$methodName' methods are not supported. Please " +
s"add a @scala.annotation.varargs annotation.")
} else if (found.length > 1) {
} else if (maximallySpecific.length > 1) {
throw new ValidationException(
s"Found multiple '$methodName' methods which match the signature.")
}
found.headOption
maximallySpecific.headOption
}

/**
Expand Down Expand Up @@ -719,19 +750,22 @@ object UserDefinedFunctionUtils {
* Compares parameter candidate classes with expected classes. If true, the parameters match.
* Candidate can be null (acts as a wildcard).
*/
private def parameterTypeApplicable(candidate: Class[_], expected: Class[_]): Boolean =
parameterTypeEquals(candidate, expected) ||
((expected != null && expected.isAssignableFrom(candidate)) ||
expected.isPrimitive && Primitives.wrap(expected).isAssignableFrom(candidate))

private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
candidate == null ||
candidate == expected ||
expected == classOf[Object] ||
expected.isPrimitive && Primitives.wrap(expected) == candidate ||
// time types
candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) ||
// arrays
(candidate.isArray && expected.isArray &&
(candidate.getComponentType == expected.getComponentType ||
expected.getComponentType == classOf[Object]))
(candidate.getComponentType == expected.getComponentType))

/**
* Creates a [[LogicalTableFunctionCall]] by parsing a String expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,26 @@ class Func20 extends ScalarFunction {
}
}

object Func21 extends ScalarFunction {
def eval(p: People): String = {
p.name
}

def eval(p: Student): String = {
"student#" + p.name
}
}

object Func22 extends ScalarFunction {
def eval(a: Array[People]): String = {
a.head.name
}

def eval(a: Array[Student]): String = {
"student#" + a.head.name
}
}

class SplitUDF(deterministic: Boolean) extends ScalarFunction {
def eval(x: String, sep: String, index: Int): String = {
val splits = StringUtils.splitByWholeSeparator(x, sep)
Expand All @@ -321,3 +341,9 @@ class SplitUDF(deterministic: Boolean) extends ScalarFunction {
}
override def isDeterministic: Boolean = deterministic
}

class People(val name: String)

class Student(name: String) extends People(name)

class GraduatedStudent(name: String) extends Student(name)
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.Literal
import org.apache.flink.table.expressions.utils.{Func13, RichFunc1, RichFunc2, SplitUDF}
import org.apache.flink.table.expressions.utils._
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, UserDefinedFunctionTestUtils}
import org.apache.flink.test.util.AbstractTestBase
import org.apache.flink.types.Row
Expand Down Expand Up @@ -350,4 +350,62 @@ class CalcITCase extends AbstractTestBase {
"{9=Comment#3}")
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

@Test
def testOverload(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

StreamITCase.testResults = mutable.MutableList()

val testData = new mutable.MutableList[GraduatedStudent]
testData.+=(new GraduatedStudent("Jack#22"))
testData.+=(new GraduatedStudent("John#19"))
testData.+=(new GraduatedStudent("Anna#44"))
testData.+=(new GraduatedStudent("nosharp"))

val t = env.fromCollection(testData).toTable(tEnv).as('a)

val result = t.select(Func21('a))

result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = mutable.MutableList(
"student#Jack#22",
"student#John#19",
"student#Anna#44",
"student#nosharp"
)
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}

@Test
def testOverloadWithArray(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

StreamITCase.testResults = mutable.MutableList()

val testData = new mutable.MutableList[Array[GraduatedStudent]]
testData.+=(Array(new GraduatedStudent("Jack#22")))
testData.+=(Array(new GraduatedStudent("John#19")))
testData.+=(Array(new GraduatedStudent("Anna#44")))
testData.+=(Array(new GraduatedStudent("nosharp")))

val t = env.fromCollection(testData).toTable(tEnv).as('a)

val result = t.select(Func22('a))

result.addSink(new StreamITCase.StringSink[Row])
env.execute()

val expected = mutable.MutableList(
"student#Jack#22",
"student#John#19",
"student#Anna#44",
"student#nosharp"
)
assertEquals(expected.sorted, StreamITCase.testResults.sorted)
}
}

0 comments on commit d9c7f97

Please sign in to comment.