Skip to content

Commit

Permalink
[FLINK-3748] [table] Add CASE function to Table API
Browse files Browse the repository at this point in the history
This closes apache#1893.
  • Loading branch information
twalthr committed Apr 19, 2016
1 parent 85fcfc4 commit 4be297e
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ trait ImplicitExpressionOperations {

def as(name: Symbol) = Naming(expr, name.name)

/**
* Conditional operator that decides which of two other expressions should be evaluated
* based on a evaluated boolean condition.
*
* e.g. (42 > 5).eval("A", "B") leads to "A"
*
* @param ifTrue expression to be evaluated if condition holds
* @param ifFalse expression to be evaluated if condition does not hold
*/
def eval(ifTrue: Expression, ifFalse: Expression) = {
Eval(expr, ifTrue, ifFalse)
}

// scalar functions

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,9 @@ class CodeGenerator(
requireBoolean(operand)
generateNot(nullCheck, operand)

case CASE =>
generateIfElse(nullCheck, operands, resultType)

// casting
case CAST =>
val operand = operands.head
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,65 @@ object ScalarOperators {
}
}

def generateIfElse(
nullCheck: Boolean,
operands: Seq[GeneratedExpression],
resultType: TypeInformation[_],
i: Int = 0)
: GeneratedExpression = {
// else part
if (i == operands.size - 1) {
generateCast(nullCheck, operands(i), resultType)
}
else {
// check that the condition is boolean
// we do not check for null instead we use the default value
// thus null is false
requireBoolean(operands(i))
val condition = operands(i)
val trueAction = generateCast(nullCheck, operands(i + 1), resultType)
val falseAction = generateIfElse(nullCheck, operands, resultType, i + 2)

val resultTerm = newName("result")
val nullTerm = newName("isNull")
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)

val operatorCode = if (nullCheck) {
s"""
|${condition.code}
|$resultTypeTerm $resultTerm;
|boolean $nullTerm;
|if (${condition.resultTerm}) {
| ${trueAction.code}
| $resultTerm = ${trueAction.resultTerm};
| $nullTerm = ${trueAction.nullTerm};
|}
|else {
| ${falseAction.code}
| $resultTerm = ${falseAction.resultTerm};
| $nullTerm = ${falseAction.nullTerm};
|}
|""".stripMargin
}
else {
s"""
|${condition.code}
|$resultTypeTerm $resultTerm;
|if (${condition.resultTerm}) {
| ${trueAction.code}
| $resultTerm = ${trueAction.resultTerm};
|}
|else {
| ${falseAction.code}
| $resultTerm = ${falseAction.resultTerm};
|}
|""".stripMargin
}

GeneratedExpression(resultTerm, nullTerm, operatorCode, resultType)
}
}

// ----------------------------------------------------------------------------------------------

private def generateUnaryOperatorIfNotNull(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
case e ~ _ ~ target ~ _ => Naming(e, target.name)
}

lazy val eval: PackratParser[Expression] = atom ~
".eval(" ~ expression ~ "," ~ expression ~ ")" ^^ {
case condition ~ _ ~ ifTrue ~ _ ~ ifFalse ~ _ => Eval(condition, ifTrue, ifFalse)
}

// general function calls

lazy val functionCall = ident ~ "(" ~ rep1sep(expression, ",") ~ ")" ^^ {
Expand Down Expand Up @@ -200,7 +205,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {

lazy val suffix =
isNull | isNotNull |
sum | min | max | count | avg | cast | nullLiteral |
sum | min | max | count | avg | cast | nullLiteral | eval |
specialFunctionCalls | functionCall | functionCallWithoutArgs |
specialSuffixFunctionCalls | suffixFunctionCall | suffixFunctionCallWithoutArgs |
atom
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package org.apache.flink.api.table.expressions

import scala.collection.JavaConversions._

import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
Expand All @@ -35,7 +33,7 @@ case class Call(functionName: String, args: Expression*) extends Expression {
override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.call(
BuiltInFunctionNames.toSqlOperator(functionName),
args.map(_.toRexNode))
args.map(_.toRexNode): _*)
}

override def toString = s"\\$functionName(${args.mkString(", ")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.flink.api.table.expressions

import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder

abstract class BinaryPredicate extends BinaryExpression { self: Product => }
Expand Down Expand Up @@ -54,3 +55,23 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
relBuilder.or(left.toRexNode, right.toRexNode)
}
}

case class Eval(
condition: Expression,
ifTrue: Expression,
ifFalse: Expression)
extends Expression {
def children = Seq(condition, ifTrue, ifFalse)

override def toString = s"($condition)? $ifTrue : $ifFalse"

override val name = Expression.freshName("if-" + condition.name +
"-then-" + ifTrue.name + "-else-" + ifFalse.name)

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
val c = condition.toRexNode
val t = ifTrue.toRexNode
val f = ifFalse.toRexNode
relBuilder.call(SqlStdOperatorTable.CASE, c, t, f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ object RexNodeTranslator {
val l = extractAggCalls(b.left, tableEnv)
val r = extractAggCalls(b.right, tableEnv)
(b.makeCopy(List(l._1, r._1)), l._2 ::: r._2)
case e: Eval =>
val c = extractAggCalls(e.condition, tableEnv)
val t = extractAggCalls(e.ifTrue, tableEnv)
val f = extractAggCalls(e.ifFalse, tableEnv)
(e.makeCopy(List(c._1, t._1, f._1)), c._2 ::: t._2 ::: f._2)

// Scalar functions
case c@Call(name, args@_*) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,46 @@ public void testNullLiteral() throws Exception {
}
}

@Test
public void testEval() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());

DataSource<Tuple2<Integer, Boolean>> input =
env.fromElements(new Tuple2<>(5, true));

Table table =
tableEnv.fromDataSet(input, "a, b");

Table result = table.select(
"(b && true).eval('true', 'false')," +
"false.eval('true', 'false')," +
"true.eval(true.eval(true.eval(10, 4), 4), 4)");

DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "true,false,10";
compareResultAsText(results, expected);
}

@Test(expected = IllegalArgumentException.class)
public void testEvalInvalidTypes() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());

DataSource<Tuple2<Integer, Boolean>> input =
env.fromElements(new Tuple2<>(5, true));

Table table =
tableEnv.fromDataSet(input, "a, b");

Table result = table.select("(b && true).eval(5, 'false')");

DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "true,false,3,10";
compareResultAsText(results, expected);
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,55 @@ class ExpressionsITCase(
}
}

@Test
def testCase(): Unit = {

val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val sqlQuery = "SELECT " +
"CASE 11 WHEN 1 THEN 'a' ELSE 'b' END," +
"CASE 2 WHEN 1 THEN 'a' ELSE 'b' END," +
"CASE 1 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 THEN '3' " +
" ELSE 'none of the above' END" +
" FROM MyTable"

val ds = env.fromElements((1, 0))
tEnv.registerDataSet("MyTable", ds, 'a, 'b)

val result = tEnv.sql(sqlQuery)

val expected = "b,b,1 or 2"
val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testCaseWithNull(): Unit = {
if (!config.getNullCheck) {
return
}

val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val sqlQuery = "SELECT " +
"CASE WHEN 'a'='a' THEN 1 END," +
"CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END," +
"CASE a WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END," +
"CASE b WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END," +
"CASE 42 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END," +
"CASE 1 WHEN 1 THEN true WHEN 2 THEN false ELSE NULL END" +
" FROM MyTable"

val ds = env.fromElements((1, 0))
tEnv.registerDataSet("MyTable", ds, 'a, 'b)

val result = tEnv.sql(sqlQuery)

val expected = "1,bcd,11,null,null,true"
val results = result.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,35 @@ class ExpressionsITCase(
}
}

@Test
def testEval(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val t = env.fromElements((5, true)).toTable(tEnv, 'a, 'b)
.select(
('b && true).eval("true", "false"),
false.eval("true", "false"),
true.eval(true.eval(true.eval(10, 4), 4), 4))

val expected = "true,false,10"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test(expected = classOf[IllegalArgumentException])
def testEvalInvalidTypes(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env, config)

val t = env.fromElements((5, true)).toTable(tEnv, 'a, 'b)
.select(('b && true).eval(5, "false"))

val expected = "true,false,3,10"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

// Date literals not yet supported
@Ignore
@Test
Expand Down

0 comments on commit 4be297e

Please sign in to comment.