[FLINK-6094] [table] Add checks for hashCode/equals and little code cleanup
twalthr committed Jan 9, 2018
1 parent 9623b25 commit 49c6d10
Showing 9 changed files with 128 additions and 73 deletions.
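The central change below renames validateStateType to validateEqualsHashCode and applies it to grouping keys and join inputs: per the updated doc comment, hash-based state and keyBy only behave correctly if the key type overrides equals()/hashCode(). A minimal sketch of the failure mode this guards against, in plain Scala with illustrative class names and no Flink dependencies:

object EqualsHashCodeDemo extends App {
  class BadKey(val id: Int)   // relies on Object's reference equality
  case class GoodKey(id: Int) // equals/hashCode generated by the compiler

  val state = scala.collection.mutable.HashMap.empty[BadKey, String]
  state(new BadKey(1)) = "accumulated state for key 1"
  // a second, equal-valued instance misses the entry: prints None
  println(state.get(new BadKey(1)))

  val state2 = scala.collection.mutable.HashMap.empty[GoodKey, String]
  state2(GoodKey(1)) = "accumulated state for key 1"
  // prints Some(accumulated state for key 1)
  println(state2.get(GoodKey(1)))
}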
@@ -70,8 +70,6 @@ class DataStreamGroupWindowAggregate(
 
   def getWindowProperties: Seq[NamedWindowProperty] = namedProperties
 
-  def getWindowAlias: String = window.aliasAttribute.toString
-
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new DataStreamGroupWindowAggregate(
       window,
@@ -53,8 +53,8 @@ class DataStreamJoin(
     schema: RowSchema,
     ruleDescription: String)
   extends BiRel(cluster, traitSet, leftNode, rightNode)
-    with CommonJoin
-    with DataStreamRel {
+  with CommonJoin
+  with DataStreamRel {
 
   override def deriveRowType(): RelDataType = schema.relDataType
 
@@ -123,8 +123,8 @@ class DataStreamJoin(
     } else {
       throw TableException(
         "Equality join predicate on incompatible types.\n" +
-          s"\tLeft: ${left},\n" +
-          s"\tRight: ${right},\n" +
+          s"\tLeft: $left,\n" +
+          s"\tRight: $right,\n" +
           s"\tCondition: (${joinConditionToString(schema.relDataType,
             joinCondition, getExpressionString)})"
       )
@@ -138,8 +138,9 @@
 
     val (connectOperator, nullCheck) = joinType match {
       case JoinRelType.INNER => (leftDataStream.connect(rightDataStream), false)
-      case _ => throw TableException(s"An Unsupported JoinType [ $joinType ]. Currently only " +
-        s"non-window inner joins with at least one equality predicate are supported")
+      case _ =>
+        throw TableException(s"Unsupported join type '$joinType'. Currently only " +
+          s"non-window inner joins with at least one equality predicate are supported")
     }
 
     val generator = new FunctionCodeGenerator(
@@ -21,9 +21,9 @@ import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rel.{RelNode, RelVisitor}
 import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
 import org.apache.calcite.sql.SqlKind
+import org.apache.flink.table.expressions.ProctimeAttribute
 import org.apache.flink.table.plan.nodes.datastream._
 
-import _root_.scala.collection.JavaConverters._
+import _root_.scala.collection.JavaConversions._
 import scala.collection.mutable
 
@@ -66,15 +66,15 @@ object UpdatingPlanChecker {
     // belong to the same group, i.e., pk1. Here we use the lexicographic smallest attribute as
     // the common group id. A node can have keys if it generates the keys by itself or it
     // forwards keys from its input(s).
-    def visit(node: RelNode): Option[List[(String, String)]] = {
+    def visit(node: RelNode): Option[Seq[(String, String)]] = {
       node match {
         case c: DataStreamCalc =>
           val inputKeys = visit(node.getInput(0))
           // check if input has keys
           if (inputKeys.isDefined) {
             // track keys forward
             val inNames = c.getInput.getRowType.getFieldNames
-            val inOutNames = c.getProgram.getNamedProjects.asScala
+            val inOutNames = c.getProgram.getNamedProjects
               .map(p => {
                 c.getProgram.expandLocalRef(p.left) match {
                   // output field is forwarded input field
@@ -102,7 +102,8 @@
 
             val inputKeysMap = inputKeys.get.toMap
             val inOutGroups = inputKeysAndOutput
-              .map(e => (inputKeysMap(e._1), e._2)).sorted.reverse.toMap
+              .map(e => (inputKeysMap(e._1), e._2))
+              .toMap
 
             // get output keys
             val outputKeys = inputKeysAndOutput
@@ -111,7 +112,7 @@
             // check if all keys have been preserved
             if (outputKeys.map(_._2).distinct.length == inputKeys.get.map(_._2).distinct.length) {
               // all key have been preserved (but possibly renamed)
-              Some(outputKeys.toList)
+              Some(outputKeys)
             } else {
               // some (or all) keys have been removed. Keys are no longer unique and removed
               None
@@ -125,26 +126,27 @@
           visit(node.getInput(0))
         case a: DataStreamGroupAggregate =>
           // get grouping keys
-          val groupKeys = a.getRowType.getFieldNames.asScala.take(a.getGroupings.length)
-          Some(groupKeys.map(e => (e, e)).toList)
+          val groupKeys = a.getRowType.getFieldNames.take(a.getGroupings.length)
+          Some(groupKeys.map(e => (e, e)))
         case w: DataStreamGroupWindowAggregate =>
           // get grouping keys
           val groupKeys =
-            w.getRowType.getFieldNames.asScala.take(w.getGroupings.length).toArray
-          // get window start and end time
-          val windowStartEnd = w.getWindowProperties.map(_.name)
+            w.getRowType.getFieldNames.take(w.getGroupings.length).toArray
+          // proctime is not a valid key
+          val windowProperties = w.getWindowProperties
+            .filter(!_.property.isInstanceOf[ProctimeAttribute])
+            .map(_.name)
           // we have only a unique key if at least one window property is selected
-          if (windowStartEnd.nonEmpty) {
-            val smallestAttribute = windowStartEnd.min
-            Some((groupKeys.map(e => (e, e)) ++ windowStartEnd.map((_, smallestAttribute))).toList)
+          if (windowProperties.nonEmpty) {
+            Some(groupKeys.map(e => (e, e)) ++ windowProperties.map(e => (e, e)))
           } else {
             None
           }
 
         case j: DataStreamJoin =>
           val joinType = j.getJoinType
           joinType match {
-            case JoinRelType.INNER => {
+            case JoinRelType.INNER =>
              // get key(s) for inner join
              val lInKeys = visit(j.getLeft)
              val rInKeys = visit(j.getRight)
@@ -170,18 +172,17 @@
                .map(rInNames.get(_))
                .map(rInNamesToJoinNamesMap(_))
 
-              val inKeys: List[(String, String)] = lInKeys.get ++ rInKeys.get
+              val inKeys: Seq[(String, String)] = lInKeys.get ++ rInKeys.get
                .map(e => (rInNamesToJoinNamesMap(e._1), rInNamesToJoinNamesMap(e._2)))
 
              getOutputKeysForInnerJoin(
                joinNames,
                inKeys,
-                lJoinKeys.zip(rJoinKeys).toList
+                lJoinKeys.zip(rJoinKeys)
              )
            }
-            }
-            case _ => throw new UnsupportedOperationException(
-              s"An Unsupported JoinType [ $joinType ]")
+            case _ =>
+              throw new UnsupportedOperationException(s"Unsupported join type '$joinType'")
          }
        case _: DataStreamRel =>
          // anything else does not forward keys, so we can stop
@@ -199,9 +200,9 @@
    */
  def getOutputKeysForInnerJoin(
      inNames: Seq[String],
-      inKeys: List[(String, String)],
-      joinKeys: List[(String, String)])
-    : Option[List[(String, String)]] = {
+      inKeys: Seq[(String, String)],
+      joinKeys: Seq[(String, String)])
+    : Option[Seq[(String, String)]] = {
 
    val nameToGroups = mutable.HashMap.empty[String,String]
 
@@ -210,7 +211,7 @@
      val ga: String = findGroup(nameA)
      val gb: String = findGroup(nameB)
      if (!ga.equals(gb)) {
-        if(ga.compare(gb) < 0) {
+        if (ga.compare(gb) < 0) {
          nameToGroups += (gb -> ga)
        } else {
          nameToGroups += (ga -> gb)
@@ -242,14 +243,13 @@
    // merge groups
    joinKeys.foreach(e => merge(e._1, e._2))
    // make sure all name point to the group name directly
-    inNames.foreach(findGroup(_))
+    inNames.foreach(findGroup)
 
    val outputGroups = inKeys.map(e => nameToGroups(e._1)).distinct
    Some(
      inNames
        .filter(e => outputGroups.contains(nameToGroups(e)))
        .map(e => (e, nameToGroups(e)))
-        .toList
    )
  }
 }
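The findGroup/merge helpers in the hunks above implement a small union-find over attribute names, with the lexicographically smallest name as the group representative (matching the "lexicographic smallest attribute as the common group id" comment). A hedged, self-contained sketch of that idea, with illustrative key names:

import scala.collection.mutable

object KeyGroupDemo extends App {
  val nameToGroups = mutable.HashMap.empty[String, String]

  // follow links until a name maps to itself (or is unmapped),
  // then shortcut the entry, similar to path compression
  def findGroup(name: String): String = {
    var g = name
    while (nameToGroups.getOrElse(g, g) != g) {
      g = nameToGroups(g)
    }
    nameToGroups(name) = g
    g
  }

  // the lexicographically smaller representative wins, so equal keys
  // from both join inputs end up under one common group id
  def merge(nameA: String, nameB: String): Unit = {
    val ga = findGroup(nameA)
    val gb = findGroup(nameB)
    if (ga != gb) {
      if (ga < gb) nameToGroups(gb) = ga else nameToGroups(ga) = gb
    }
  }

  merge("pk1", "pk2")
  merge("pk2", "pk3")
  println(findGroup("pk3")) // prints pk1
}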
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
 import org.apache.flink.types.Row
 
 /**
@@ -33,6 +34,9 @@ class CRowKeySelector(
  extends KeySelector[CRow, Row]
  with ResultTypeQueryable[Row] {
 
+  // check if type implements proper equals/hashCode
+  validateEqualsHashCode("grouping", returnType)
+
  override def getKey(value: CRow): Row = {
    Row.project(value.row, keyFields)
  }
@@ -22,17 +22,19 @@ package org.apache.flink.table.runtime.join
 import org.apache.flink.api.common.functions.FlatJoinFunction
 import org.apache.flink.api.common.state._
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.api.java.typeutils.TupleTypeInfo
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction
 import org.apache.flink.table.api.{StreamQueryConfig, Types}
+import org.apache.flink.table.codegen.Compiler
 import org.apache.flink.table.runtime.CRowWrappingMultiOutputCollector
 import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.table.typeutils.TypeCheckUtils
+import org.apache.flink.table.typeutils.TypeCheckUtils.validateEqualsHashCode
+import org.apache.flink.table.util.Logging
 import org.apache.flink.types.Row
 import org.apache.flink.util.Collector
-import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import org.apache.flink.table.codegen.Compiler
-import org.apache.flink.table.util.Logging
 
 
 /**
@@ -53,8 +55,12 @@
    genJoinFuncCode: String,
    queryConfig: StreamQueryConfig)
  extends CoProcessFunction[CRow, CRow, CRow]
-    with Compiler[FlatJoinFunction[Row, Row, Row]]
-    with Logging {
+  with Compiler[FlatJoinFunction[Row, Row, Row]]
+  with Logging {
 
+  // check if input types implement proper equals/hashCode
+  validateEqualsHashCode("join", leftType)
+  validateEqualsHashCode("join", rightType)
+
  // state to hold left stream element
  private var leftState: MapState[Row, JTuple2[Int, Long]] = _
@@ -116,7 +122,7 @@
      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
      out: Collector[CRow]): Unit = {
 
-    processElement(valueC, ctx, out, leftTimer, leftState, rightState, true)
+    processElement(valueC, ctx, out, leftTimer, leftState, rightState, isLeft = true)
  }
 
 /**
@@ -132,7 +138,7 @@
      ctx: CoProcessFunction[CRow, CRow, CRow]#Context,
      out: Collector[CRow]): Unit = {
 
-    processElement(valueC, ctx, out, rightTimer, rightState, leftState, false)
+    processElement(valueC, ctx, out, rightTimer, rightState, leftState, isLeft = false)
  }
 
 
@@ -168,7 +174,6 @@
    }
  }
 
-
  def getNewExpiredTime(
      curProcessTime: Long,
      oldExpiredTime: Long): Long = {
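Besides the equals/hashCode checks, the hunks above pass the trailing boolean to processElement as a named argument (isLeft = true/false). A tiny sketch of why this reads better at call sites; the method and names here are illustrative stand-ins, not Flink API:

object NamedArgDemo extends App {
  // a stand-in for the processElement(..., isLeft) signature above
  def processSide(value: String, isLeft: Boolean): Unit =
    println(s"$value arrived on the ${if (isLeft) "left" else "right"} input")

  processSide("row-1", isLeft = true)  // intent is visible at the call site
  processSide("row-2", isLeft = false) // vs. an opaque bare 'false'
}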
@@ -109,52 +109,56 @@ object TypeCheckUtils {
 
  /**
    * Checks whether a type implements own hashCode() and equals() methods for storing an instance
-    * in Flink's state.
+    * in Flink's state or performing a keyBy operation.
    *
+    * @param name name of the operation
    * @param t type information to be validated
    */
-  def validateStateType(t: TypeInformation[_]): Unit = t match {
+  def validateEqualsHashCode(name: String, t: TypeInformation[_]): Unit = t match {
 
    // make sure that a POJO class is a valid state type
    case pt: PojoTypeInfo[_] =>
      // we don't check the types recursively to give a chance of wrapping
      // proper hashCode/equals methods around an immutable type
-      validateStateType(pt.getClass)
+      validateEqualsHashCode(name, pt.getClass)
    // recursively check composite types
    case ct: CompositeType[_] =>
-      validateStateType(t.getTypeClass)
+      validateEqualsHashCode(name, t.getTypeClass)
      // we check recursively for entering Flink types such as tuples and rows
      for (i <- 0 until ct.getArity) {
        val subtype = ct.getTypeAt(i)
-        validateStateType(subtype)
+        validateEqualsHashCode(name, subtype)
      }
    // check other type information only based on the type class
    case _: TypeInformation[_] =>
-      validateStateType(t.getTypeClass)
+      validateEqualsHashCode(name, t.getTypeClass)
  }
 
  /**
    * Checks whether a class implements own hashCode() and equals() methods for storing an instance
-    * in Flink's state.
+    * in Flink's state or performing a keyBy operation.
    *
+    * @param name name of the operation
    * @param c class to be validated
    */
-  def validateStateType(c: Class[_]): Unit = {
+  def validateEqualsHashCode(name: String, c: Class[_]): Unit = {
 
    // skip primitives
    if (!c.isPrimitive) {
      // check the component type of arrays
      if (c.isArray) {
-        validateStateType(c.getComponentType)
+        validateEqualsHashCode(name, c.getComponentType)
      }
      // check type for methods
      else {
        if (c.getMethod("hashCode").getDeclaringClass eq classOf[Object]) {
          throw new ValidationException(
-            s"Type '${c.getCanonicalName}' cannot be used in a stateful operation because it " +
+            s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
            s"does not implement a proper hashCode() method.")
        }
        if (c.getMethod("equals", classOf[Object]).getDeclaringClass eq classOf[Object]) {
          throw new ValidationException(
-            s"Type '${c.getCanonicalName}' cannot be used in a stateful operation because it " +
+            s"Type '${c.getCanonicalName}' cannot be used in a $name operation because it " +
            s"does not implement a proper equals() method.")
        }
      }
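A hedged, dependency-free sketch of the reflective test used above: a class is rejected when its hashCode()/equals(Object) are still the ones declared by java.lang.Object. The class names here are illustrative:

object ReflectiveCheckDemo extends App {
  class NoEquals(val id: Int)    // inherits Object's identity semantics
  case class WithEquals(id: Int) // compiler generates equals/hashCode

  // mirrors the getDeclaringClass checks in the diff above
  def declaresOwnEqualsHashCode(c: Class[_]): Boolean =
    (c.getMethod("hashCode").getDeclaringClass ne classOf[Object]) &&
      (c.getMethod("equals", classOf[Object]).getDeclaringClass ne classOf[Object])

  println(declaresOwnEqualsHashCode(classOf[NoEquals]))   // false -> would fail validation
  println(declaresOwnEqualsHashCode(classOf[WithEquals])) // true  -> passes
}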