From 4414008551b2843d98b3caddbb171fa1934e1f40 Mon Sep 17 00:00:00 2001 From: twalthr Date: Fri, 23 Sep 2016 16:44:42 +0200 Subject: [PATCH] [FLINK-4554] [table] Add support for array types This closes #2919. --- docs/dev/table_api.md | 158 +++++++- .../flink/api/scala/table/expressionDsl.scala | 50 ++- .../flink/api/table/FlinkTypeFactory.scala | 22 +- .../api/table/codegen/CodeGenUtils.scala | 11 +- .../api/table/codegen/CodeGenerator.scala | 37 ++ .../api/table/codegen/ExpressionReducer.scala | 4 +- .../table/codegen/calls/ScalarOperators.scala | 198 +++++++++- .../table/expressions/ExpressionParser.scala | 9 +- .../table/expressions/ExpressionUtils.scala | 61 ++- .../flink/api/table/expressions/array.scala | 146 +++++++ .../api/table/expressions/comparison.scala | 3 - .../api/table/plan/ProjectionTranslator.scala | 8 +- .../table/plan/schema/ArrayRelDataType.scala | 53 +++ .../api/table/typeutils/TypeCheckUtils.scala | 14 +- .../api/table/validate/FunctionCatalog.scala | 12 +- .../src/test/resources/log4j-test.properties | 2 +- .../api/table/expressions/ArrayTypeTest.scala | 359 ++++++++++++++++++ .../table/expressions/SqlExpressionTest.scala | 14 +- 18 files changed, 1122 insertions(+), 39 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index 6cf0dee905779..2b42ab2827bab 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -1470,7 +1470,14 @@ The Table API is built on top of Flink's DataSet and DataStream API. Internally, | `Types.INTERVAL_MONTHS`| `INTERVAL YEAR TO MONTH` | `java.lang.Integer` | | `Types.INTERVAL_MILLIS`| `INTERVAL DAY TO SECOND(3)` | `java.lang.Long` | -Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and arrays can be fields of a row. Generic types and arrays are treated as a black box within Table API and SQL yet. Composite types, however, are fully supported types where fields of a composite type can be accessed using the `.get()` operator in Table API and dot operator (e.g. `MyTable.pojoColumn.myField`) in SQL. Composite types can also be flattened using `.flatten()` in Table API or `MyTable.pojoColumn.*` in SQL. + +Advanced types such as generic types, composite types (e.g. POJOs or Tuples), and array types (object or primitive arrays) can be fields of a row. + +Generic types are treated as a black box within Table API and SQL yet. + +Composite types, however, are fully supported types where fields of a composite type can be accessed using the `.get()` operator in Table API and dot operator (e.g. `MyTable.pojoColumn.myField`) in SQL. Composite types can also be flattened using `.flatten()` in Table API or `MyTable.pojoColumn.*` in SQL. + +Array types can be accessed using the `myArray.at(1)` operator in Table API and `myArray[1]` operator in SQL. Array literals can be created using `array(1, 2, 3)` in Table API and `ARRAY[1, 2, 3]` in SQL. {% top %} @@ -2038,6 +2045,50 @@ COMPOSITE.get(INT) + + + {% highlight java %} +ARRAY.at(INT) +{% endhighlight %} + + +

Returns the element at a particular position in an array. The index starts at 1.

+ + + + + + {% highlight java %} +array(ANY [, ANY ]*) +{% endhighlight %} + + +

Creates an array from a list of values. The array will be an array of objects (not primitives).

+ + + + + + {% highlight java %} +ARRAY.cardinality() +{% endhighlight %} + + +

Returns the number of elements of an array.

+ + + + + + {% highlight scala %} +ARRAY.element() +{% endhighlight %} + + +

Returns the sole element of an array with a single element. Returns null if the array is empty. Throws an exception if the array has more than one element.

+ + + @@ -2599,6 +2650,50 @@ COMPOSITE.get(INT) + + + {% highlight scala %} +ARRAY.at(INT) +{% endhighlight %} + + +

Returns the element at a particular position in an array. The index starts at 1.

+ + + + + + {% highlight scala %} +array(ANY [, ANY ]*) +{% endhighlight %} + + +

Creates an array from a list of values. The array will be an array of objects (not primitives).

+ + + + + + {% highlight scala %} +ARRAY.cardinality() +{% endhighlight %} + + +

Returns the number of elements of an array.

+ + + + + + {% highlight scala %} +ARRAY.element() +{% endhighlight %} + + +

Returns the sole element of an array with a single element. Returns null if the array is empty. Throws an exception if the array has more than one element.

+ + + @@ -3368,8 +3463,6 @@ CAST(value AS type) - - + + + + {% highlight text %} +array ‘[’ index ‘]’ +{% endhighlight %} + + +

Returns the element at a particular position in an array. The index starts at 1.

+ + + + + + {% highlight text %} +ARRAY ‘[’ value [, value ]* ‘]’ +{% endhighlight %} + + +

Creates an array from a list of values.

+ + + ---> @@ -3657,6 +3774,39 @@ tableName.compositeType.*
+ + + + + + + + + + + + + + + + + + + +
Array functionsDescription
+ {% highlight text %} +CARDINALITY(ARRAY) +{% endhighlight %} + +

Returns the number of elements of an array.

+
+ {% highlight text %} +ELEMENT(ARRAY) +{% endhighlight %} + +

Returns the sole element of an array with a single element. Returns null if the array is empty. Throws an exception if the array has more than one element.

+
+ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala index 175ce2ee851de..823458a4f2f25 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala @@ -21,9 +21,10 @@ import java.sql.{Date, Time, Timestamp} import org.apache.calcite.avatica.util.DateTimeUtils._ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} -import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval} +import org.apache.flink.api.table.expressions.ExpressionUtils.{convertArray, toMilliInterval, toMonthInterval, toRowInterval} import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.api.table.expressions._ +import java.math.{BigDecimal => JBigDecimal} import scala.language.implicitConversions @@ -461,6 +462,29 @@ trait ImplicitExpressionOperations { * into a flat representation where every subtype is a separate field. */ def flatten() = Flattening(expr) + + /** + * Accesses the element of an array based on an index (starting at 1). + * + * @param index position of the element (starting at 1) + * @return value of the element + */ + def at(index: Expression) = ArrayElementAt(expr, index) + + /** + * Returns the number of elements of an array. + * + * @return number of elements + */ + def cardinality() = ArrayCardinality(expr) + + /** + * Returns the sole element of an array with a single element. Returns null if the array is + * empty. Throws an exception if the array has more than one element. + * + * @return the first and only element of an array with a single element + */ + def element() = ArrayElement(expr) } /** @@ -540,18 +564,24 @@ trait ImplicitExpressionConversions { implicit def float2Literal(d: Float): Expression = Literal(d) implicit def string2Literal(str: String): Expression = Literal(str) implicit def boolean2Literal(bool: Boolean): Expression = Literal(bool) - implicit def javaDec2Literal(javaDec: java.math.BigDecimal): Expression = Literal(javaDec) - implicit def scalaDec2Literal(scalaDec: scala.math.BigDecimal): Expression = + implicit def javaDec2Literal(javaDec: JBigDecimal): Expression = Literal(javaDec) + implicit def scalaDec2Literal(scalaDec: BigDecimal): Expression = Literal(scalaDec.bigDecimal) implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate) implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime) - implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp) + implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = + Literal(sqlTimestamp) + implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array) } // ------------------------------------------------------------------------------------------------ // Expressions with no parameters // ------------------------------------------------------------------------------------------------ +// we disable the object checker here as it checks for capital letters of objects +// but we want that objects look like functions in certain cases e.g. array(1, 2, 3) +// scalastyle:off object.name + /** * Returns the current SQL date in UTC time zone. */ @@ -645,5 +675,17 @@ object temporalOverlaps { } } +/** + * Creates an array of literals. The array will be an array of objects (not primitives). + */ +object array { + /** + * Creates an array of literals. The array will be an array of objects (not primitives). + */ + def apply(head: Expression, tail: Expression*): Expression = { + ArrayConstructor(head +: tail.toSeq) + } +} +// scalastyle:on object.name diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala index bb11576766fca..8dcd66085a8b0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala @@ -26,11 +26,12 @@ import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.parser.SqlParserPos import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, SqlTimeTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo import org.apache.flink.api.java.typeutils.ValueTypeInfo._ import org.apache.flink.api.table.FlinkTypeFactory.typeInfoToSqlTypeName -import org.apache.flink.api.table.plan.schema.{CompositeRelDataType, GenericRelDataType} +import org.apache.flink.api.table.plan.schema.{ArrayRelDataType, CompositeRelDataType, GenericRelDataType} import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo import org.apache.flink.api.table.typeutils.TypeCheckUtils.isSimple @@ -102,11 +103,22 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp } } + override def createArrayType(elementType: RelDataType, maxCardinality: Long): RelDataType = + new ArrayRelDataType( + ObjectArrayTypeInfo.getInfoFor(FlinkTypeFactory.toTypeInfo(elementType)), + elementType, + true) + private def createAdvancedType(typeInfo: TypeInformation[_]): RelDataType = typeInfo match { case ct: CompositeType[_] => new CompositeRelDataType(ct, this) - // TODO add specific RelDataTypes for PrimitiveArrayTypeInfo, ObjectArrayTypeInfo + case pa: PrimitiveArrayTypeInfo[_] => + new ArrayRelDataType(pa, createTypeFromTypeInfo(pa.getComponentType), false) + + case oa: ObjectArrayTypeInfo[_, _] => + new ArrayRelDataType(oa, createTypeFromTypeInfo(oa.getComponentInfo), true) + case ti: TypeInformation[_] => new GenericRelDataType(typeInfo, getTypeSystem.asInstanceOf[FlinkTypeSystem]) @@ -190,6 +202,10 @@ object FlinkTypeFactory { // ROW and CURSOR for UDTF case, whose type info will never be used, just a placeholder case ROW | CURSOR => new NothingTypeInfo + case ARRAY if relDataType.isInstanceOf[ArrayRelDataType] => + val arrayRelDataType = relDataType.asInstanceOf[ArrayRelDataType] + arrayRelDataType.typeInfo + case _@t => throw TableException(s"Type is not supported: $t") } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala index b78012c831989..4092a24c3556b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenUtils.scala @@ -155,7 +155,6 @@ object CodeGenUtils { def enumValueOf[T <: Enum[T]](cls: Class[_], stringValue: String): Enum[_] = Enum.valueOf(cls.asInstanceOf[Class[T]], stringValue).asInstanceOf[Enum[_]] - // ---------------------------------------------------------------------------------------------- def requireNumeric(genExpr: GeneratedExpression) = @@ -189,6 +188,16 @@ object CodeGenUtils { throw new CodeGenException("Interval expression type expected.") } + def requireArray(genExpr: GeneratedExpression) = + if (!TypeCheckUtils.isArray(genExpr.resultType)) { + throw new CodeGenException("Array expression type expected.") + } + + def requireInteger(genExpr: GeneratedExpression) = + if (!TypeCheckUtils.isInteger(genExpr.resultType)) { + throw new CodeGenException("Integer expression type expected.") + } + // ---------------------------------------------------------------------------------------------- def isReference(genExpr: GeneratedExpression): Boolean = isReference(genExpr.resultType) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala index f7d68634edee1..7caad126acf0f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -976,6 +976,27 @@ class CodeGenerator( requireString(left) generateArithmeticOperator("+", nullCheck, resultType, left, right) + // arrays + case ARRAY_VALUE_CONSTRUCTOR => + generateArray(this, resultType, operands) + + case ITEM => + val array = operands.head + val index = operands(1) + requireArray(array) + requireInteger(index) + generateArrayElementAt(this, array, index) + + case CARDINALITY => + val array = operands.head + requireArray(array) + generateArrayCardinality(nullCheck, array) + + case ELEMENT => + val array = operands.head + requireArray(array) + generateArrayElement(this, array) + // advanced scalar functions case sqlOperator: SqlOperator => val callGen = FunctionGenerator.getCallGenerator( @@ -1393,6 +1414,22 @@ class CodeGenerator( fieldTerm } + /** + * Adds a reusable array to the member area of the generated [[Function]]. + */ + def addReusableArray(clazz: Class[_], size: Int): String = { + val fieldTerm = newName("array") + val classQualifier = clazz.getCanonicalName // works also for int[] etc. + val initArray = classQualifier.replaceFirst("\\[", s"[$size") + val fieldArray = + s""" + |transient $classQualifier $fieldTerm = + | new $initArray; + |""".stripMargin + reusableMemberStatements.add(fieldArray) + fieldTerm + } + /** * Adds a reusable timestamp to the beginning of the SAM of the generated [[Function]]. */ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala index 74756ef2d73be..731452f5e5718 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionReducer.scala @@ -63,7 +63,7 @@ class ExpressionReducer(config: TableConfig) ) // we don't support object literals yet, we skip those constant expressions - case (SqlTypeName.ANY, _) | (SqlTypeName.ROW, _) => None + case (SqlTypeName.ANY, _) | (SqlTypeName.ROW, _) | (SqlTypeName.ARRAY, _) => None case (_, e) => Some(e) } @@ -101,7 +101,7 @@ class ExpressionReducer(config: TableConfig) val unreduced = constExprs.get(i) unreduced.getType.getSqlTypeName match { // we insert the original expression for object literals - case SqlTypeName.ANY | SqlTypeName.ROW => + case SqlTypeName.ANY | SqlTypeName.ROW | SqlTypeName.ARRAY => reducedValues.add(unreduced) case _ => val literal = rexBuilder.makeLiteral( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala index 75c0149a0f87a..330e2fe021a9a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala @@ -21,9 +21,10 @@ import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} import org.apache.calcite.util.BuiltInMethod import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ -import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, SqlTimeTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo import org.apache.flink.api.table.codegen.CodeGenUtils._ -import org.apache.flink.api.table.codegen.{CodeGenException, GeneratedExpression} +import org.apache.flink.api.table.codegen.{CodeGenerator, CodeGenException, GeneratedExpression} import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo import org.apache.flink.api.table.typeutils.TypeCheckUtils._ @@ -91,6 +92,12 @@ object ScalarOperators { else if (isTemporal(left.resultType) && left.resultType == right.resultType) { generateComparison("==", nullCheck, left, right) } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { + (leftTerm, rightTerm) => s"java.util.Arrays.equals($leftTerm, $rightTerm)" + } + } // comparable types of same type else if (isComparable(left.resultType) && left.resultType == right.resultType) { generateComparison("==", nullCheck, left, right) @@ -125,6 +132,12 @@ object ScalarOperators { else if (isTemporal(left.resultType) && left.resultType == right.resultType) { generateComparison("!=", nullCheck, left, right) } + // array types + else if (isArray(left.resultType) && left.resultType == right.resultType) { + generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) { + (leftTerm, rightTerm) => s"!java.util.Arrays.equals($leftTerm, $rightTerm)" + } + } // comparable types else if (isComparable(left.resultType) && left.resultType == right.resultType) { generateComparison("!=", nullCheck, left, right) @@ -428,7 +441,7 @@ object ScalarOperators { // Date/Time/Timestamp -> String case (dtt: SqlTimeTypeInfo[_], STRING_TYPE_INFO) => generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { - (operandTerm) => s"""${internalToTimePointCode(dtt, operandTerm)}.toString()""" + (operandTerm) => s"${internalToTimePointCode(dtt, operandTerm)}.toString()" } // Interval Months -> String @@ -447,6 +460,18 @@ object ScalarOperators { (operandTerm) => s"$method($operandTerm, $timeUnitRange, 3)" // milli second precision } + // Object array -> String + case (_:ObjectArrayTypeInfo[_, _], STRING_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"java.util.Arrays.deepToString($operandTerm)" + } + + // Primitive array -> String + case (_:PrimitiveArrayTypeInfo[_], STRING_TYPE_INFO) => + generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { + (operandTerm) => s"java.util.Arrays.toString($operandTerm)" + } + // * (not Date/Time/Timestamp) -> String case (_, STRING_TYPE_INFO) => generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) { @@ -701,6 +726,173 @@ object ScalarOperators { generateUnaryArithmeticOperator(operator, nullCheck, operand.resultType, operand) } + def generateArray( + codeGenerator: CodeGenerator, + resultType: TypeInformation[_], + elements: Seq[GeneratedExpression]) + : GeneratedExpression = { + val arrayTerm = codeGenerator.addReusableArray(resultType.getTypeClass, elements.size) + + val boxedElements: Seq[GeneratedExpression] = resultType match { + + case oati: ObjectArrayTypeInfo[_, _] => + // we box the elements to also represent null values + val boxedTypeTerm = boxedTypeTermForTypeInfo(oati.getComponentInfo) + + elements.map { e => + val boxedExpr = codeGenerator.generateOutputFieldBoxing(e) + val exprOrNull: String = if (codeGenerator.nullCheck) { + s"${boxedExpr.nullTerm} ? null : ($boxedTypeTerm) ${boxedExpr.resultTerm}" + } else { + boxedExpr.resultTerm + } + boxedExpr.copy(resultTerm = exprOrNull) + } + + // no boxing necessary + case _: PrimitiveArrayTypeInfo[_] => elements + } + + val code = boxedElements + .zipWithIndex + .map { case (element, idx) => + s""" + |${element.code} + |$arrayTerm[$idx] = ${element.resultTerm}; + |""".stripMargin + } + .mkString("\n") + + GeneratedExpression(arrayTerm, GeneratedExpression.NEVER_NULL, code, resultType) + } + + def generateArrayElementAt( + codeGenerator: CodeGenerator, + array: GeneratedExpression, + index: GeneratedExpression) + : GeneratedExpression = { + + val resultTerm = newName("result") + + array.resultType match { + + // unbox object array types + case oati: ObjectArrayTypeInfo[_, _] => + // get boxed array element + val resultTypeTerm = boxedTypeTermForTypeInfo(oati.getComponentInfo) + + val arrayAccessCode = if (codeGenerator.nullCheck) { + s""" + |${array.code} + |${index.code} + |$resultTypeTerm $resultTerm = (${array.nullTerm} || ${index.nullTerm}) ? + | null : ${array.resultTerm}[${index.resultTerm} - 1]; + |""".stripMargin + } else { + s""" + |${array.code} + |${index.code} + |$resultTypeTerm $resultTerm = ${array.resultTerm}[${index.resultTerm} - 1]; + |""".stripMargin + } + + // generate unbox code + val unboxing = codeGenerator.generateInputFieldUnboxing(oati.getComponentInfo, resultTerm) + + unboxing.copy(code = + s""" + |$arrayAccessCode + |${unboxing.code} + |""".stripMargin + ) + + // no unboxing necessary + case pati: PrimitiveArrayTypeInfo[_] => + generateOperatorIfNotNull(codeGenerator.nullCheck, pati.getComponentType, array, index) { + (leftTerm, rightTerm) => s"$leftTerm[$rightTerm - 1]" + } + } + } + + def generateArrayElement( + codeGenerator: CodeGenerator, + array: GeneratedExpression) + : GeneratedExpression = { + + val nullTerm = newName("isNull") + val resultTerm = newName("result") + val resultType = array.resultType match { + case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo + case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType + } + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) + val defaultValue = primitiveDefaultValue(resultType) + + val arrayLengthCode = if (codeGenerator.nullCheck) { + s"${array.nullTerm} ? 0 : ${array.resultTerm}.length" + } else { + s"${array.resultTerm}.length" + } + + val arrayAccessCode = array.resultType match { + case oati: ObjectArrayTypeInfo[_, _] => + // generate unboxing code + val unboxing = codeGenerator.generateInputFieldUnboxing( + oati.getComponentInfo, + s"${array.resultTerm}[0]") + + s""" + |${array.code} + |${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" } + |$resultTypeTerm $resultTerm; + |switch ($arrayLengthCode) { + | case 0: + | ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" } + | $resultTerm = $defaultValue; + | break; + | case 1: + | ${unboxing.code} + | ${if (codeGenerator.nullCheck) s"$nullTerm = ${unboxing.nullTerm};" else "" } + | $resultTerm = ${unboxing.resultTerm}; + | break; + | default: + | throw new RuntimeException("Array has more than one element."); + |} + |""".stripMargin + + case pati: PrimitiveArrayTypeInfo[_] => + s""" + |${array.code} + |${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" } + |$resultTypeTerm $resultTerm; + |switch ($arrayLengthCode) { + | case 0: + | ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" } + | $resultTerm = $defaultValue; + | break; + | case 1: + | ${if (codeGenerator.nullCheck) s"$nullTerm = false;" else "" } + | $resultTerm = ${array.resultTerm}[0]; + | break; + | default: + | throw new RuntimeException("Array has more than one element."); + |} + |""".stripMargin + } + + GeneratedExpression(resultTerm, nullTerm, arrayAccessCode, resultType) + } + + def generateArrayCardinality( + nullCheck: Boolean, + array: GeneratedExpression) + : GeneratedExpression = { + + generateUnaryOperatorIfNotNull(nullCheck, INT_TYPE_INFO, array) { + (operandTerm) => s"${array.resultTerm}.length" + } + } + // ---------------------------------------------------------------------------------------------- private def generateUnaryOperatorIfNotNull( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala index a926717b87c30..c960a79a7a0c8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala @@ -48,6 +48,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { // Keyword + lazy val ARRAY: Keyword = Keyword("Array") lazy val AS: Keyword = Keyword("as") lazy val COUNT: Keyword = Keyword("count") lazy val AVG: Keyword = Keyword("avg") @@ -88,7 +89,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val FLATTEN: Keyword = Keyword("flatten") def functionIdent: ExpressionParser.Parser[String] = - not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ + not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ not(SUM) ~ not(START) ~ not(END)~ not(CAST) ~ not(NULL) ~ not(IF) ~> super.ident @@ -298,6 +299,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { // prefix operators + lazy val prefixArray: PackratParser[Expression] = + ARRAY ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { elements => ArrayConstructor(elements) } + lazy val prefixSum: PackratParser[Expression] = SUM ~ "(" ~> expression <~ ")" ^^ { e => Sum(e) } @@ -372,7 +376,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { FLATTEN ~ "(" ~> composite <~ ")" ^^ { e => Flattening(e) } lazy val prefixed: PackratParser[Expression] = - prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | prefixStart | prefixEnd | + prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg | + prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening | prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala index c071c5926b6e3..86575347add3c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionUtils.scala @@ -18,13 +18,16 @@ package org.apache.flink.api.table.expressions -import java.math.BigDecimal +import java.lang.{Boolean => JBoolean, Byte => JByte, Short => JShort, Integer => JInteger, Long => JLong, Float => JFloat, Double => JDouble} +import java.math.{BigDecimal => JBigDecimal} +import java.sql.{Date, Time, Timestamp} import org.apache.calcite.avatica.util.TimeUnit import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex.{RexBuilder, RexNode} import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.table.ValidationException import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} object ExpressionUtils { @@ -54,6 +57,48 @@ object ExpressionUtils { throw new IllegalArgumentException("Invalid value for row interval literal.") } + private[flink] def convertArray(array: Array[_]): Expression = { + def createArray(): Expression = { + ArrayConstructor(array.map(Literal(_))) + } + + array match { + // primitives + case _: Array[Boolean] => createArray() + case _: Array[Byte] => createArray() + case _: Array[Short] => createArray() + case _: Array[Int] => createArray() + case _: Array[Long] => createArray() + case _: Array[Float] => createArray() + case _: Array[Double] => createArray() + + // boxed types + case _: Array[JBoolean] => createArray() + case _: Array[JByte] => createArray() + case _: Array[JShort] => createArray() + case _: Array[JInteger] => createArray() + case _: Array[JLong] => createArray() + case _: Array[JFloat] => createArray() + case _: Array[JDouble] => createArray() + + // others + case _: Array[String] => createArray() + case _: Array[JBigDecimal] => createArray() + case _: Array[Date] => createArray() + case _: Array[Time] => createArray() + case _: Array[Timestamp] => createArray() + case bda: Array[BigDecimal] => ArrayConstructor(bda.map { bd => Literal(bd.bigDecimal) }) + + case _ => + // nested + if (array.length > 0 && array.head.isInstanceOf[Array[_]]) { + ArrayConstructor(array.map { na => convertArray(na.asInstanceOf[Array[_]]) }) + } else { + throw ValidationException("Unsupported array type.") + } + } + } + // ---------------------------------------------------------------------------------------------- // RexNode conversion functions (see org.apache.calcite.sql2rel.StandardConvertletTable) // ---------------------------------------------------------------------------------------------- @@ -61,7 +106,7 @@ object ExpressionUtils { /** * Copy of [[org.apache.calcite.sql2rel.StandardConvertletTable#getFactor()]]. */ - private[flink] def getFactor(unit: TimeUnit): BigDecimal = unit match { + private[flink] def getFactor(unit: TimeUnit): JBigDecimal = unit match { case TimeUnit.DAY => java.math.BigDecimal.ONE case TimeUnit.HOUR => TimeUnit.DAY.multiplier case TimeUnit.MINUTE => TimeUnit.HOUR.multiplier @@ -78,20 +123,20 @@ object ExpressionUtils { rexBuilder: RexBuilder, resType: RelDataType, res: RexNode, - value: BigDecimal) + value: JBigDecimal) : RexNode = { - if (value == BigDecimal.ONE) return res + if (value == JBigDecimal.ONE) return res rexBuilder.makeCall(SqlStdOperatorTable.MOD, res, rexBuilder.makeExactLiteral(value, resType)) } /** * Copy of [[org.apache.calcite.sql2rel.StandardConvertletTable#divide()]]. */ - private[flink] def divide(rexBuilder: RexBuilder, res: RexNode, value: BigDecimal): RexNode = { - if (value == BigDecimal.ONE) return res - if (value.compareTo(BigDecimal.ONE) < 0 && value.signum == 1) { + private[flink] def divide(rexBuilder: RexBuilder, res: RexNode, value: JBigDecimal): RexNode = { + if (value == JBigDecimal.ONE) return res + if (value.compareTo(JBigDecimal.ONE) < 0 && value.signum == 1) { try { - val reciprocal = BigDecimal.ONE.divide(value, BigDecimal.ROUND_UNNECESSARY) + val reciprocal = JBigDecimal.ONE.divide(value, JBigDecimal.ROUND_UNNECESSARY) return rexBuilder.makeCall( SqlStdOperatorTable.MULTIPLY, res, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala new file mode 100644 index 0000000000000..78084de084201 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/array.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.expressions + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo +import org.apache.flink.api.table.FlinkRelBuilder +import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} + +import scala.collection.JavaConverters._ + +case class ArrayConstructor(elements: Seq[Expression]) extends Expression { + + override private[flink] def children: Seq[Expression] = elements + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + val relDataType = relBuilder + .asInstanceOf[FlinkRelBuilder] + .getTypeFactory + .createTypeFromTypeInfo(resultType) + val values = elements.map(_.toRexNode).toList.asJava + relBuilder + .getRexBuilder + .makeCall(relDataType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, values) + } + + override def toString = s"array(${elements.mkString(", ")})" + + override private[flink] def resultType = ObjectArrayTypeInfo.getInfoFor(elements.head.resultType) + + override private[flink] def validateInput(): ValidationResult = { + if (elements.isEmpty) { + return ValidationFailure("Empty arrays are not supported yet.") + } + val elementType = elements.head.resultType + if (!elements.forall(_.resultType == elementType)) { + ValidationFailure("Not all elements of the array have the same type.") + } else { + ValidationSuccess + } + } +} + +case class ArrayElementAt(array: Expression, index: Expression) extends Expression { + + override private[flink] def children: Seq[Expression] = Seq(array, index) + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder + .getRexBuilder + .makeCall(SqlStdOperatorTable.ITEM, array.toRexNode, index.toRexNode) + } + + override def toString = s"($array).at($index)" + + override private[flink] def resultType = array.resultType match { + case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo + case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType + } + + override private[flink] def validateInput(): ValidationResult = { + array.resultType match { + case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => + if (index.resultType == INT_TYPE_INFO) { + // check for common user mistake + index match { + case Literal(value: Int, INT_TYPE_INFO) if value < 1 => + ValidationFailure( + s"Array element access needs an index starting at 1 but was $value.") + case _ => ValidationSuccess + } + } else { + ValidationFailure( + s"Array element access needs an integer index but was '${index.resultType}'.") + } + case other@_ => ValidationFailure(s"Array expected but was '$other'.") + } + } +} + +case class ArrayCardinality(array: Expression) extends Expression { + + override private[flink] def children: Seq[Expression] = Seq(array) + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder + .getRexBuilder + .makeCall(SqlStdOperatorTable.CARDINALITY, array.toRexNode) + } + + override def toString = s"($array).cardinality()" + + override private[flink] def resultType = BasicTypeInfo.INT_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + array.resultType match { + case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => ValidationSuccess + case other@_ => ValidationFailure(s"Array expected but was '$other'.") + } + } +} + +case class ArrayElement(array: Expression) extends Expression { + + override private[flink] def children: Seq[Expression] = Seq(array) + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder + .getRexBuilder + .makeCall(SqlStdOperatorTable.ELEMENT, array.toRexNode) + } + + override def toString = s"($array).element()" + + override private[flink] def resultType = array.resultType match { + case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo + case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType + } + + override private[flink] def validateInput(): ValidationResult = { + array.resultType match { + case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => ValidationSuccess + case other@_ => ValidationFailure(s"Array expected but was '$other'.") + } + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala index d5244d0dd0afa..5a150f864126f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala @@ -36,7 +36,6 @@ abstract class BinaryComparison extends BinaryExpression { override private[flink] def resultType = BOOLEAN_TYPE_INFO - // TODO: tighten this rule once we implemented type coercion rules during validation override private[flink] def validateInput(): ValidationResult = (left.resultType, right.resultType) match { case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess @@ -56,7 +55,6 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison override private[flink] def validateInput(): ValidationResult = (left.resultType, right.resultType) match { case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess - // TODO widen this rule once we support custom objects as types (FLINK-3916) case (lType, rType) if lType == rType => ValidationSuccess case (lType, rType) => ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType") @@ -71,7 +69,6 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari override private[flink] def validateInput(): ValidationResult = (left.resultType, right.resultType) match { case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess - // TODO widen this rule once we support custom objects as types (FLINK-3916) case (lType, rType) if lType == rType => ValidationSuccess case (lType, rType) => ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala index c093f1aa207e1..22b77b47a32b5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala @@ -143,7 +143,13 @@ object ProjectionTranslator { case sfc @ ScalarFunctionCall(clazz, args) => val newArgs: Seq[Expression] = args .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames)) - sfc.makeCopy(Array(clazz,newArgs)) + sfc.makeCopy(Array(clazz, newArgs)) + + // array constructor + case c @ ArrayConstructor(args) => + val newArgs = c.elements + .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames)) + c.makeCopy(Array(newArgs)) // General expression case e: Expression => diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala new file mode 100644 index 0000000000000..92fcb83000900 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/ArrayRelDataType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.plan.schema + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.sql.`type`.ArraySqlType +import org.apache.flink.api.common.typeinfo.TypeInformation + +/** + * Flink distinguishes between primitive arrays (int[], double[], ...) and + * object arrays (Integer[], MyPojo[], ...). This custom type supports both cases. + */ +class ArrayRelDataType( + val typeInfo: TypeInformation[_], + elementType: RelDataType, + isNullable: Boolean) + extends ArraySqlType( + elementType, + isNullable) { + + override def toString = s"ARRAY($typeInfo)" + + def canEqual(other: Any): Boolean = other.isInstanceOf[ArrayRelDataType] + + override def equals(other: Any): Boolean = other match { + case that: ArrayRelDataType => + super.equals(that) && + (that canEqual this) && + typeInfo == that.typeInfo + case _ => false + } + + override def hashCode(): Int = { + typeInfo.hashCode() + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala index aa8614bd4c0e4..e30e2733a1fa3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/TypeCheckUtils.scala @@ -17,8 +17,9 @@ */ package org.apache.flink.api.table.typeutils -import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, STRING_TYPE_INFO} -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, NumericTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{BIG_DEC_TYPE_INFO, BOOLEAN_TYPE_INFO, INT_TYPE_INFO, STRING_TYPE_INFO} +import org.apache.flink.api.common.typeinfo._ +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo import org.apache.flink.api.table.validate._ object TypeCheckUtils { @@ -61,8 +62,15 @@ object TypeCheckUtils { def isDecimal(dataType: TypeInformation[_]): Boolean = dataType == BIG_DEC_TYPE_INFO + def isInteger(dataType: TypeInformation[_]): Boolean = dataType == INT_TYPE_INFO + + def isArray(dataType: TypeInformation[_]): Boolean = dataType match { + case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => true + case _ => false + } + def isComparable(dataType: TypeInformation[_]): Boolean = - classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass) + classOf[Comparable[_]].isAssignableFrom(dataType.getTypeClass) && !isArray(dataType) def assertNumericExpr( dataType: TypeInformation[_], diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala index dc68b89a88472..8e409cc892e61 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala @@ -185,7 +185,12 @@ object FunctionCatalog { "localTime" -> classOf[LocalTime], "localTimestamp" -> classOf[LocalTimestamp], "quarter" -> classOf[Quarter], - "temporalOverlaps" -> classOf[TemporalOverlaps] + "temporalOverlaps" -> classOf[TemporalOverlaps], + + // array + "cardinality" -> classOf[ArrayCardinality], + "at" -> classOf[ArrayElementAt], + "element" -> classOf[ArrayElement] // TODO implement function overloading here // "floor" -> classOf[TemporalFloor] @@ -258,6 +263,11 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { SqlStdOperatorTable.MIN, SqlStdOperatorTable.MAX, SqlStdOperatorTable.AVG, + // ARRAY OPERATORS + SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, + SqlStdOperatorTable.ITEM, + SqlStdOperatorTable.CARDINALITY, + SqlStdOperatorTable.ELEMENT, // SPECIAL OPERATORS SqlStdOperatorTable.ROW, SqlStdOperatorTable.OVERLAPS, diff --git a/flink-libraries/flink-table/src/test/resources/log4j-test.properties b/flink-libraries/flink-table/src/test/resources/log4j-test.properties index f713aa89b6ad3..4c74d85d7c625 100644 --- a/flink-libraries/flink-table/src/test/resources/log4j-test.properties +++ b/flink-libraries/flink-table/src/test/resources/log4j-test.properties @@ -18,7 +18,7 @@ # Set root logger level to OFF to not flood build logs # set manually to INFO for debugging purposes -log4j.rootLogger=INFO, testlogger +log4j.rootLogger=OFF, testlogger # A1 is set to be a ConsoleAppender. log4j.appender.testlogger=org.apache.log4j.ConsoleAppender diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala new file mode 100644 index 0000000000000..034ce0ba191e5 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/ArrayTypeTest.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.expressions + +import java.sql.Date + +import org.apache.flink.api.common.typeinfo.{PrimitiveArrayTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.expressions.utils.ExpressionTestBase +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.{Row, Types, ValidationException} +import org.junit.Test + +class ArrayTypeTest extends ExpressionTestBase { + + @Test(expected = classOf[ValidationException]) + def testObviousInvalidIndexTableApi(): Unit = { + testTableApi('f2.at(0), "FAIL", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyArraySql(): Unit = { + testSqlApi("ARRAY[]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testEmptyArrayTableApi(): Unit = { + testTableApi("FAIL", "array()", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testNullArraySql(): Unit = { + testSqlApi("ARRAY[NULL]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testDifferentTypesArraySql(): Unit = { + testSqlApi("ARRAY[1, TRUE]", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testDifferentTypesArrayTableApi(): Unit = { + testTableApi("FAIL", "array(1, true)", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testUnsupportedComparison(): Unit = { + testAllApis( + 'f2 <= 'f5.at(1), + "f2 <= f5.at(1)", + "f2 <= f5[1]", + "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testElementNonArray(): Unit = { + testTableApi( + 'f0.element(), + "FAIL", + "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testElementNonArraySql(): Unit = { + testSqlApi( + "ELEMENT(f0)", + "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testCardinalityOnNonArray(): Unit = { + testTableApi('f0.cardinality(), "FAIL", "FAIL") + } + + @Test(expected = classOf[ValidationException]) + def testCardinalityOnNonArraySql(): Unit = { + testSqlApi("CARDINALITY(f0)", "FAIL") + } + + @Test + def testArrayLiterals(): Unit = { + // primitive literals + testAllApis(array(1, 2, 3), "array(1, 2, 3)", "ARRAY[1, 2, 3]", "[1, 2, 3]") + + testAllApis( + array(true, true, true), + "array(true, true, true)", + "ARRAY[TRUE, TRUE, TRUE]", + "[true, true, true]") + + // object literals + testTableApi(array(BigDecimal(1), BigDecimal(1)), "array(1p, 1p)", "[1, 1]") + + testAllApis( + array(array(array(1), array(1))), + "array(array(array(1), array(1)))", + "ARRAY[ARRAY[ARRAY[1], ARRAY[1]]]", + "[[[1], [1]]]") + + testAllApis( + array(1 + 1, 3 * 3), + "array(1 + 1, 3 * 3)", + "ARRAY[1 + 1, 3 * 3]", + "[2, 9]") + + testAllApis( + array(Null(Types.INT), 1), + "array(Null(INT), 1)", + "ARRAY[NULLIF(1,1), 1]", + "[null, 1]") + + testAllApis( + array(array(Null(Types.INT), 1)), + "array(array(Null(INT), 1))", + "ARRAY[ARRAY[NULLIF(1,1), 1]]", + "[[null, 1]]") + + // implicit conversion + testTableApi( + Array(1, 2, 3), + "array(1, 2, 3)", + "[1, 2, 3]") + + testTableApi( + Array[Integer](1, 2, 3), + "array(1, 2, 3)", + "[1, 2, 3]") + + testAllApis( + Array(Date.valueOf("1985-04-11")), + "array('1985-04-11'.toDate)", + "ARRAY[DATE '1985-04-11']", + "[1985-04-11]") + + testAllApis( + Array(BigDecimal(2.0002), BigDecimal(2.0003)), + "Array(2.0002p, 2.0003p)", + "ARRAY[CAST(2.0002 AS DECIMAL), CAST(2.0003 AS DECIMAL)]", + "[2.0002, 2.0003]") + + testAllApis( + Array(Array(x = true)), + "Array(Array(true))", + "ARRAY[ARRAY[TRUE]]", + "[[true]]") + + testAllApis( + Array(Array(1, 2, 3), Array(3, 2, 1)), + "Array(Array(1, 2, 3), Array(3, 2, 1))", + "ARRAY[ARRAY[1, 2, 3], ARRAY[3, 2, 1]]", + "[[1, 2, 3], [3, 2, 1]]") + } + + @Test + def testArrayField(): Unit = { + testAllApis( + array('f0, 'f1), + "array(f0, f1)", + "ARRAY[f0, f1]", + "[null, 42]") + + testAllApis( + array('f0, 'f1), + "array(f0, f1)", + "ARRAY[f0, f1]", + "[null, 42]") + + testAllApis( + 'f2, + "f2", + "f2", + "[1, 2, 3]") + + testAllApis( + 'f3, + "f3", + "f3", + "[1984-03-12, 1984-02-10]") + + testAllApis( + 'f5, + "f5", + "f5", + "[[1, 2, 3], null]") + + testAllApis( + 'f6, + "f6", + "f6", + "[1, null, null, 4]") + + testAllApis( + 'f2, + "f2", + "f2", + "[1, 2, 3]") + + testAllApis( + 'f2.at(1), + "f2.at(1)", + "f2[1]", + "1") + + testAllApis( + 'f3.at(1), + "f3.at(1)", + "f3[1]", + "1984-03-12") + + testAllApis( + 'f3.at(2), + "f3.at(2)", + "f3[2]", + "1984-02-10") + + testAllApis( + 'f5.at(1).at(2), + "f5.at(1).at(2)", + "f5[1][2]", + "2") + + testAllApis( + 'f5.at(2).at(2), + "f5.at(2).at(2)", + "f5[2][2]", + "null") + + testAllApis( + 'f4.at(2).at(2), + "f4.at(2).at(2)", + "f4[2][2]", + "null") + } + + @Test + def testArrayOperations(): Unit = { + // cardinality + testAllApis( + 'f2.cardinality(), + "f2.cardinality()", + "CARDINALITY(f2)", + "3") + + testAllApis( + 'f4.cardinality(), + "f4.cardinality()", + "CARDINALITY(f4)", + "null") + + // element + testAllApis( + 'f9.element(), + "f9.element()", + "ELEMENT(f9)", + "1") + + testAllApis( + 'f8.element(), + "f8.element()", + "ELEMENT(f8)", + "4.0") + + testAllApis( + 'f10.element(), + "f10.element()", + "ELEMENT(f10)", + "null") + + testAllApis( + 'f4.element(), + "f4.element()", + "ELEMENT(f4)", + "null") + + // comparison + testAllApis( + 'f2 === 'f5.at(1), + "f2 === f5.at(1)", + "f2 = f5[1]", + "true") + + testAllApis( + 'f6 === array(1, 2, 3), + "f6 === array(1, 2, 3)", + "f6 = ARRAY[1, 2, 3]", + "false") + + testAllApis( + 'f2 !== 'f5.at(1), + "f2 !== f5.at(1)", + "f2 <> f5[1]", + "false") + + testAllApis( + 'f2 === 'f7, + "f2 === f7", + "f2 = f7", + "false") + + testAllApis( + 'f2 !== 'f7, + "f2 !== f7", + "f2 <> f7", + "true") + } + + // ---------------------------------------------------------------------------------------------- + + case class MyCaseClass(string: String, int: Int) + + override def testData: Any = { + val testData = new Row(11) + testData.setField(0, null) + testData.setField(1, 42) + testData.setField(2, Array(1, 2, 3)) + testData.setField(3, Array(Date.valueOf("1984-03-12"), Date.valueOf("1984-02-10"))) + testData.setField(4, null) + testData.setField(5, Array(Array(1, 2, 3), null)) + testData.setField(6, Array[Integer](1, null, null, 4)) + testData.setField(7, Array(1, 2, 3, 4)) + testData.setField(8, Array(4.0)) + testData.setField(9, Array[Integer](1)) + testData.setField(10, Array[Integer]()) + testData + } + + override def typeInfo: TypeInformation[Any] = { + new RowTypeInfo(Seq( + Types.INT, + Types.INT, + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + ObjectArrayTypeInfo.getInfoFor(Types.DATE), + ObjectArrayTypeInfo.getInfoFor(ObjectArrayTypeInfo.getInfoFor(Types.INT)), + ObjectArrayTypeInfo.getInfoFor(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO), + ObjectArrayTypeInfo.getInfoFor(Types.INT), + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, + ObjectArrayTypeInfo.getInfoFor(Types.INT), + ObjectArrayTypeInfo.getInfoFor(Types.INT) + )).asInstanceOf[TypeInformation[Any]] + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala index b892cfb321712..52dc848a93d23 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/SqlExpressionTest.scala @@ -135,11 +135,13 @@ class SqlExpressionTest extends ExpressionTestBase { testSqlApi("CAST(2 AS DOUBLE)", "2.0") } - @Ignore // TODO we need a special code path that flattens ROW types @Test def testValueConstructorFunctions(): Unit = { - testSqlApi("ROW('hello world', 12)", "hello world") // test base only returns field 0 - testSqlApi("('hello world', 12)", "hello world") // test base only returns field 0 + // TODO we need a special code path that flattens ROW types + // testSqlApi("ROW('hello world', 12)", "hello world") // test base only returns field 0 + // testSqlApi("('hello world', 12)", "hello world") // test base only returns field 0 + testSqlApi("ARRAY[TRUE, FALSE][2]", "false") + testSqlApi("ARRAY[TRUE, TRUE]", "[true, true]") } @Test @@ -155,6 +157,12 @@ class SqlExpressionTest extends ExpressionTestBase { testSqlApi("QUARTER(DATE '2016-04-12')", "2") } + @Test + def testArrayFunctions(): Unit = { + testSqlApi("CARDINALITY(ARRAY[TRUE, TRUE, FALSE])", "3") + testSqlApi("ELEMENT(ARRAY['HELLO WORLD'])", "HELLO WORLD") + } + override def testData: Any = new Row(0) override def typeInfo: TypeInformation[Any] =