Skip to content

Commit

Permalink
[FLINK-6226] [table] Add tests for UDFs with Byte, Short, and Float a…
Browse files Browse the repository at this point in the history
…rguments.
  • Loading branch information
fhueske committed Nov 2, 2017
1 parent 37df826 commit 6e118d1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
"Func1(f0)",
"43")

testAllApis(
Func1('f11),
"Func1(f11)",
"Func1(f11)",
"4")

testAllApis(
Func1('f12),
"Func1(f12)",
"Func1(f12)",
"4")

testAllApis(
Func1('f13),
"Func1(f13)",
"Func1(f13)",
"4.0")

testAllApis(
Func2('f0, 'f1, 'f3),
"Func2(f0, f1, f3)",
Expand Down Expand Up @@ -360,7 +378,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
// ----------------------------------------------------------------------------------------------

override def testData: Any = {
val testData = new Row(11)
val testData = new Row(14)
testData.setField(0, 42)
testData.setField(1, "Test")
testData.setField(2, null)
Expand All @@ -372,6 +390,9 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
testData.setField(8, 1000L)
testData.setField(9, Seq("Hello", "World"))
testData.setField(10, Array[Integer](1, 2, null))
testData.setField(11, 3.toByte)
testData.setField(12, 3.toShort)
testData.setField(13, 3.toFloat)
testData
}

Expand All @@ -387,7 +408,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
Types.INTERVAL_MONTHS,
Types.INTERVAL_MILLIS,
TypeInformation.of(classOf[Seq[String]]),
BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO
BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
Types.BYTE,
Types.SHORT,
Types.FLOAT
).asInstanceOf[TypeInformation[Any]]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ object Func1 extends ScalarFunction {
def eval(index: Integer): Integer = {
index + 1
}

def eval(b: Byte): Byte = (b + 1).toByte

def eval(s: Short): Short = (s + 1).toShort

def eval(f: Float): Float = f + 1
}

object Func2 extends ScalarFunction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.api.{TableEnvironment, TableException, Types, ValidationException}
import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.api.scala._
import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2}
Expand Down Expand Up @@ -230,6 +230,27 @@ class CorrelateITCase(
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testByteShortFloatArguments(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env, config)
val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
val tFunc = new TableFunc4

val result = in
.select('a.cast(Types.BYTE) as 'a, 'a.cast(Types.SHORT) as 'b, 'b.cast(Types.FLOAT) as 'c)
.join(tFunc('a, 'b, 'c) as ('a2, 'b2, 'c2))
.toDataSet[Row]

val results = result.collect()
val expected = Seq(
"1,1,1.0,Byte=1,Short=1,Float=1.0",
"2,2,2.0,Byte=2,Short=2,Float=2.0",
"3,3,2.0,Byte=3,Short=3,Float=2.0",
"4,4,3.0,Byte=4,Short=4,Float=3.0").mkString("\n")
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testUserDefinedTableFunctionWithParameter(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.lang.Boolean
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.Tuple3
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.{Types, ValidationException}
import org.apache.flink.table.functions.{FunctionContext, TableFunction}
import org.apache.flink.types.Row
import org.junit.Assert
Expand Down Expand Up @@ -109,6 +109,16 @@ class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[
}
}

class TableFunc4 extends TableFunction[Row] {
def eval(b: Byte, s: Short, f: Float): Unit = {
collect(Row.of("Byte=" + b, "Short=" + s, "Float=" + f))
}

override def getResultType: TypeInformation[Row] = {
new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING)
}
}

class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
def eval(user: String) {
if (user.contains("#")) {
Expand Down

0 comments on commit 6e118d1

Please sign in to comment.