Skip to content

Commit

Permalink
[FLINK-3226] Casting support for arithmetic operators
Browse files Browse the repository at this point in the history
  • Loading branch information
twalthr authored and vasia committed Mar 18, 2016
1 parent 7a46bfa commit f05f8fb
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,52 @@
package org.apache.flink.api.table.codegen

import org.apache.flink.api.common.typeinfo.BasicTypeInfo.BOOLEAN_TYPE_INFO
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, TypeInformation}
import org.apache.flink.api.table.codegen.CodeGenUtils._

object OperatorCodeGen {

def generateArithmeticOperator(
def generateArithmeticOperator(
operator: String,
nullCheck: Boolean,
resultType: TypeInformation[_],
left: GeneratedExpression,
right: GeneratedExpression)
: GeneratedExpression = {
generateOperatorIfNotNull(nullCheck, resultType, left, right) {
// String arithmetic // TODO rework
if (isString(left)) {
generateOperatorIfNotNull(nullCheck, resultType, left, right) {
(leftTerm, rightTerm) => s"$leftTerm $operator $rightTerm"
}
}
// Numeric arithmetic
else if (isNumeric(left) && isNumeric(right)) {
val leftType = left.resultType.asInstanceOf[NumericTypeInfo[_]]
val rightType = right.resultType.asInstanceOf[NumericTypeInfo[_]]
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)

generateOperatorIfNotNull(nullCheck, resultType, left, right) {
(leftTerm, rightTerm) =>
// no casting required
if (leftType == resultType && rightType == resultType) {
s"$leftTerm $operator $rightTerm"
}
// left needs casting
else if (leftType != resultType && rightType == resultType) {
s"(($resultTypeTerm) $leftTerm) $operator $rightTerm"
}
// right needs casting
else if (leftType == resultType && rightType != resultType) {
s"$leftTerm $operator (($resultTypeTerm) $rightTerm)"
}
// both sides need casting
else {
s"(($resultTypeTerm) $leftTerm) $operator (($resultTypeTerm) $rightTerm)"
}
}
}
else {
throw new CodeGenException("Unsupported arithmetic operation.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,77 +18,76 @@

package org.apache.flink.api.java.table.test;

import java.util.List;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.table.TableEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.table.Table;
import org.apache.flink.api.java.tuple.Tuple6;
import org.apache.flink.api.java.tuple.Tuple7;
import org.apache.flink.api.java.tuple.Tuple8;
import org.apache.flink.api.table.Row;
import org.apache.flink.api.table.Table;
import org.apache.flink.api.table.codegen.CodeGenException;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.table.TableEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.tuple.Tuple7;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.junit.Ignore;
import org.apache.flink.api.table.test.TableProgramsTestBase;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import scala.NotImplementedError;

import java.util.List;

@RunWith(Parameterized.class)
public class CastingITCase extends MultipleProgramsTestBase {
public class CastingITCase extends TableProgramsTestBase {

public CastingITCase(TestExecutionMode mode){
super(mode);
public CastingITCase(TestExecutionMode mode, TableConfigMode configMode){
super(mode, configMode);
}

@Ignore
@Test(expected = NotImplementedError.class)
@Test
public void testNumericAutocastInArithmetic() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
TableEnvironment tableEnv = getJavaTableEnvironment();

DataSource<Tuple7<Byte, Short, Integer, Long, Float, Double, String>> input =
env.fromElements(new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, "Hello"));
DataSource<Tuple8<Byte, Short, Integer, Long, Float, Double, Long, Double>> input =
env.fromElements(new Tuple8<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, 1L, 1001.1));

Table table =
tableEnv.fromDataSet(input);

Table result = table.select("f0 + 1, f1 +" +
" 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1");
" 1, f2 + 1L, f3 + 1.0f, f4 + 1.0d, f5 + 1, f6 + 1.0d, f7 + f0");

DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "2,2,2,2.0,2.0,2.0";
String expected = "2,2,2,2.0,2.0,2.0,2.0,1002.1";
compareResultAsText(results, expected);
}

@Test
public void testNumericAutocastInComparison() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
TableEnvironment tableEnv = getJavaTableEnvironment();

DataSource<Tuple7<Byte, Short, Integer, Long, Float, Double, String>> input =
DataSource<Tuple6<Byte, Short, Integer, Long, Float, Double>> input =
env.fromElements(
new Tuple7<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d, "Hello"),
new Tuple7<>((byte) 2, (short) 2, 2, 2L, 2.0f, 2.0d, "Hello"));
new Tuple6<>((byte) 1, (short) 1, 1, 1L, 1.0f, 1.0d),
new Tuple6<>((byte) 2, (short) 2, 2, 2L, 2.0f, 2.0d));

Table table =
tableEnv.fromDataSet(input, "a,b,c,d,e,f,g");
tableEnv.fromDataSet(input, "a,b,c,d,e,f");

Table result = table
.filter("a > 1 && b > 1 && c > 1L && d > 1.0f && e > 1.0d && f > 1");

DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
List<Row> results = ds.collect();
String expected = "2,2,2,2,2.0,2.0,Hello";
String expected = "2,2,2,2,2.0,2.0";
compareResultAsText(results, expected);
}

// TODO support advanced String operations

@Test(expected = CodeGenException.class)
public void testCastFromString() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,33 @@
package org.apache.flink.api.scala.table.test

import java.util.Date
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.api.table.Row
import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
import org.apache.flink.api.table.codegen.CodeGenException
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import scala.collection.JavaConverters._
import org.apache.flink.api.table.codegen.CodeGenException

@RunWith(classOf[Parameterized])
class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {

@Ignore // String autocasting not yet supported
@Test
def testAutoCastToString(): Unit = {

val env = ExecutionEnvironment.getExecutionEnvironment
val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, new Date(0))).toTable
.select('_1 + "b", '_2 + "s", '_3 + "i", '_4 + "L", '_5 + "f", '_6 + "d", '_7 + "Date")

val expected = "1b,1s,1i,1L,1.0f,1.0d,1970-01-01 00:00:00.000Date"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Ignore // gives different types of exceptions for cluster and collection modes
@Test(expected = classOf[NotImplementedError])
def testNumericAutoCastInArithmetic(): Unit = {

// don't test everything, just some common cast directions

val env = ExecutionEnvironment.getExecutionEnvironment
val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d)).toTable
.select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1)
val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, 1L, 1001.1)).toTable
.select('_1 + 1, '_2 + 1, '_3 + 1L, '_4 + 1.0f, '_5 + 1.0d, '_6 + 1, '_7 + 1.0d, '_8 + '_1)

val expected = "2,2,2,2.0,2.0,2.0"
val expected = "2,2,2,2.0,2.0,2.0,2.0,1002.1"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
Expand All @@ -78,6 +66,21 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

// TODO support advanced String operations

@Ignore
@Test
def testAutoCastToString(): Unit = {

val env = ExecutionEnvironment.getExecutionEnvironment
val t = env.fromElements((1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, new Date(0))).toTable
.select('_1 + "b", '_2 + "s", '_3 + "i", '_4 + "L", '_5 + "f", '_6 + "d", '_7 + "Date")

val expected = "1b,1s,1i,1L,1.0f,1.0d,1970-01-01 00:00:00.000Date"
val results = t.toDataSet[Row].collect()
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test(expected = classOf[CodeGenException])
def testCastFromString: Unit = {

Expand Down

0 comments on commit f05f8fb

Please sign in to comment.