Skip to content

Commit

Permalink
Merge branch 'FLINK-4086'
Browse files Browse the repository at this point in the history
  • Loading branch information
twalthr committed Jul 4, 2016
2 parents e34fea5 + 429b844 commit 18995c8
Show file tree
Hide file tree
Showing 14 changed files with 202 additions and 193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,32 @@ abstract class Expression extends TreeNode[Expression] {
* Returns the [[TypeInformation]] for evaluating this expression.
* It is sometimes not available until the expression is valid.
*/
def resultType: TypeInformation[_]
private[flink] def resultType: TypeInformation[_]

/**
* One pass validation of the expression tree in post order.
*/
lazy val valid: Boolean = childrenValid && validateInput().isSuccess
private[flink] lazy val valid: Boolean = childrenValid && validateInput().isSuccess

def childrenValid: Boolean = children.forall(_.valid)
private[flink] def childrenValid: Boolean = children.forall(_.valid)

/**
* Check input data types, inputs number or other properties specified by this expression.
* Return `ValidationSuccess` if it pass the check,
* or `ValidationFailure` with supplement message explaining the error.
* Note: we should only call this method until `childrenValid == true`
*/
def validateInput(): ExprValidationResult = ValidationSuccess
private[flink] def validateInput(): ExprValidationResult = ValidationSuccess

/**
* Convert Expression to its counterpart in Calcite, i.e. RexNode
*/
def toRexNode(implicit relBuilder: RelBuilder): RexNode =
private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
throw new UnsupportedOperationException(
s"${this.getClass.getName} cannot be transformed to RexNode"
)

def checkEquals(other: Expression): Boolean = {
private[flink] def checkEquals(other: Expression): Boolean = {
if (this.getClass != other.getClass) {
false
} else {
Expand All @@ -73,16 +73,16 @@ abstract class Expression extends TreeNode[Expression] {
}

abstract class BinaryExpression extends Expression {
def left: Expression
def right: Expression
def children = Seq(left, right)
private[flink] def left: Expression
private[flink] def right: Expression
private[flink] def children = Seq(left, right)
}

abstract class UnaryExpression extends Expression {
def child: Expression
def children = Seq(child)
private[flink] def child: Expression
private[flink] def children = Seq(child)
}

abstract class LeafExpression extends Expression {
val children = Nil
private[flink] val children = Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ trait InputTypeSpec extends Expression {
* def expectedTypes: Seq[TypeInformation[_]] = DOUBLE_TYPE_INFO :: DOUBLE_TYPE_INFO :: Nil
* }}}
*/
def expectedTypes: Seq[TypeInformation[_]]
private[flink] def expectedTypes: Seq[TypeInformation[_]]

override def validateInput(): ExprValidationResult = {
override private[flink] def validateInput(): ExprValidationResult = {
val typeMismatches = mutable.ArrayBuffer.empty[String]
children.zip(expectedTypes).zipWithIndex.foreach { case ((e, tpe), i) =>
if (e.resultType != tpe) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,69 +29,73 @@ abstract sealed class Aggregation extends UnaryExpression {

override def toString = s"Aggregate($child)"

override def toRexNode(implicit relBuilder: RelBuilder): RexNode =
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
throw new UnsupportedOperationException("Aggregate cannot be transformed to RexNode")

/**
* Convert Aggregate to its counterpart in Calcite, i.e. AggCall
*/
def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall
private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall
}

case class Sum(child: Expression) extends Aggregation {
override def toString = s"sum($child)"

override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(SqlStdOperatorTable.SUM, false, null, name, child.toRexNode)
}

override def resultType = child.resultType
override private[flink] def resultType = child.resultType

override def validateInput = TypeCheckUtils.assertNumericExpr(child.resultType, "sum")
override private[flink] def validateInput =
TypeCheckUtils.assertNumericExpr(child.resultType, "sum")
}

case class Min(child: Expression) extends Aggregation {
override def toString = s"min($child)"

override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(SqlStdOperatorTable.MIN, false, null, name, child.toRexNode)
}

override def resultType = child.resultType
override private[flink] def resultType = child.resultType

override def validateInput = TypeCheckUtils.assertOrderableExpr(child.resultType, "min")
override private[flink] def validateInput =
TypeCheckUtils.assertOrderableExpr(child.resultType, "min")
}

case class Max(child: Expression) extends Aggregation {
override def toString = s"max($child)"

override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(SqlStdOperatorTable.MAX, false, null, name, child.toRexNode)
}

override def resultType = child.resultType
override private[flink] def resultType = child.resultType

override def validateInput = TypeCheckUtils.assertOrderableExpr(child.resultType, "max")
override private[flink] def validateInput =
TypeCheckUtils.assertOrderableExpr(child.resultType, "max")
}

case class Count(child: Expression) extends Aggregation {
override def toString = s"count($child)"

override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, false, null, name, child.toRexNode)
}

override def resultType = BasicTypeInfo.LONG_TYPE_INFO
override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO
}

case class Avg(child: Expression) extends Aggregation {
override def toString = s"avg($child)"

override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
relBuilder.aggregateCall(SqlStdOperatorTable.AVG, false, null, name, child.toRexNode)
}

override def resultType = child.resultType
override private[flink] def resultType = child.resultType

override def validateInput = TypeCheckUtils.assertNumericExpr(child.resultType, "avg")
override private[flink] def validateInput =
TypeCheckUtils.assertNumericExpr(child.resultType, "avg")
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ import org.apache.flink.api.table.validate._
import scala.collection.JavaConversions._

abstract class BinaryArithmetic extends BinaryExpression {
def sqlOperator: SqlOperator
private[flink] def sqlOperator: SqlOperator

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.call(sqlOperator, children.map(_.toRexNode))
}

override def resultType: TypeInformation[_] =
override private[flink] def resultType: TypeInformation[_] =
TypeCoercion.widerTypeOf(left.resultType, right.resultType) match {
case Some(t) => t
case None =>
throw new RuntimeException("This should never happen.")
}

// TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = {
override private[flink] def validateInput(): ExprValidationResult = {
if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
ValidationFailure(s"$this requires both operands Numeric, got " +
s"${left.resultType} and ${right.resultType}")
Expand All @@ -56,9 +56,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {
override def toString = s"($left + $right)"

val sqlOperator = SqlStdOperatorTable.PLUS
private[flink] val sqlOperator = SqlStdOperatorTable.PLUS

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
if(isString(left.resultType)) {
val castedRight = Cast(right, BasicTypeInfo.STRING_TYPE_INFO)
relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, castedRight.toRexNode)
Expand All @@ -73,7 +73,7 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {
}

// TODO: tighten this rule once we implemented type coercion rules during validation
override def validateInput(): ExprValidationResult = {
override private[flink] def validateInput(): ExprValidationResult = {
if (isString(left.resultType) || isString(right.resultType)) {
ValidationSuccess
} else if (!isNumeric(left.resultType) || !isNumeric(right.resultType)) {
Expand All @@ -88,36 +88,36 @@ case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {
case class UnaryMinus(child: Expression) extends UnaryExpression {
override def toString = s"-($child)"

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.call(SqlStdOperatorTable.UNARY_MINUS, child.toRexNode)
}

override def resultType = child.resultType
override private[flink] def resultType = child.resultType

override def validateInput(): ExprValidationResult =
override private[flink] def validateInput(): ExprValidationResult =
TypeCheckUtils.assertNumericExpr(child.resultType, "unary minus")
}

case class Minus(left: Expression, right: Expression) extends BinaryArithmetic {
override def toString = s"($left - $right)"

val sqlOperator = SqlStdOperatorTable.MINUS
private[flink] val sqlOperator = SqlStdOperatorTable.MINUS
}

case class Div(left: Expression, right: Expression) extends BinaryArithmetic {
override def toString = s"($left / $right)"

val sqlOperator = SqlStdOperatorTable.DIVIDE
private[flink] val sqlOperator = SqlStdOperatorTable.DIVIDE
}

case class Mul(left: Expression, right: Expression) extends BinaryArithmetic {
override def toString = s"($left * $right)"

val sqlOperator = SqlStdOperatorTable.MULTIPLY
private[flink] val sqlOperator = SqlStdOperatorTable.MULTIPLY
}

case class Mod(left: Expression, right: Expression) extends BinaryArithmetic {
override def toString = s"($left % $right)"

val sqlOperator = SqlStdOperatorTable.MOD
private[flink] val sqlOperator = SqlStdOperatorTable.MOD
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ import org.apache.flink.api.table.validate.{ExprValidationResult, ValidationFail
*/
case class Call(functionName: String, args: Seq[Expression]) extends Expression {

override def children: Seq[Expression] = args
override private[flink] def children: Seq[Expression] = args

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
throw new UnresolvedException(s"trying to convert UnresolvedFunction $functionName to RexNode")
}

override def toString = s"\\$functionName(${args.mkString(", ")})"

override def resultType =
override private[flink] def resultType =
throw new UnresolvedException(s"calling resultType on UnresolvedFunction $functionName")

override def validateInput(): ExprValidationResult =
override private[flink] def validateInput(): ExprValidationResult =
ValidationFailure(s"Unresolved function call: $functionName")
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ case class Cast(child: Expression, resultType: TypeInformation[_]) extends Unary

override def toString = s"$child.cast($resultType)"

override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
relBuilder.cast(child.toRexNode, TypeConverter.typeInfoToSqlType(resultType))
}

override def makeCopy(anyRefs: Array[AnyRef]): this.type = {
override private[flink] def makeCopy(anyRefs: Array[AnyRef]): this.type = {
val child: Expression = anyRefs.head.asInstanceOf[Expression]
copy(child, resultType).asInstanceOf[this.type]
}

override def validateInput(): ExprValidationResult = {
override private[flink] def validateInput(): ExprValidationResult = {
if (TypeCoercion.canCast(child.resultType, resultType)) {
ValidationSuccess
} else {
Expand Down
Loading

0 comments on commit 18995c8

Please sign in to comment.