[FLINK-5280] [table] Refactor TableSource interface.
This closes apache#3039.
mushketyk authored and fhueske committed Jan 10, 2017
1 parent d4d7cc3 commit 38ded2b
Showing 22 changed files with 259 additions and 135 deletions.
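The refactoring shrinks the TableSource contract: getNumberOfFields, getFieldsNames and getFieldTypes are removed, getReturnType becomes the single source of schema information, and the new DefinedFieldNames trait lets a source override the derived field names. A rough sketch of the resulting interfaces, reconstructed from the call sites in this diff (the signatures are inferred, not copied from the elided interface files):

import org.apache.flink.api.common.typeinfo.TypeInformation

// TableSource is now generic in its produced type and only has to describe it.
trait TableSource[T] {
  def getReturnType: TypeInformation[T]
}

// Optional mix-in for sources that supply explicit field names and positions
// instead of the names derived from getReturnType.
trait DefinedFieldNames {
  def getFieldNames: Array[String]
  def getFieldIndices: Array[Int]
}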
KafkaTableSource.java
@@ -19,12 +19,12 @@
package org.apache.flink.streaming.connectors.kafka;

import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.types.Row;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.table.sources.StreamTableSource;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.util.serialization.DeserializationSchema;
+import org.apache.flink.table.sources.StreamTableSource;
+import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

import java.util.Properties;
@@ -111,24 +111,9 @@ public DataStream<Row> getDataStream(StreamExecutionEnvironment env) {
return kafkaSource;
}

-@Override
-public int getNumberOfFields() {
-return fieldNames.length;
-}
-
-@Override
-public String[] getFieldsNames() {
-return fieldNames;
-}
-
-@Override
-public TypeInformation<?>[] getFieldTypes() {
-return fieldTypes;
-}
-
@Override
public TypeInformation<Row> getReturnType() {
-return new RowTypeInfo(fieldTypes);
+return new RowTypeInfo(fieldTypes, fieldNames);
}

/**
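With the schema getters gone, a source only implements its data stream and its return type. A hypothetical minimal source after this change could look like the sketch below (OneRowSource and its contents are invented for illustration; only new RowTypeInfo(fieldTypes, fieldNames) and the StreamTableSource methods above come from the diff):

import java.util.Collections

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.table.sources.StreamTableSource
import org.apache.flink.types.Row

class OneRowSource extends StreamTableSource[Row] {

  private val fieldNames = Array("word", "count")
  private val fieldTypes = Array[TypeInformation[_]](
    BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO)

  override def getDataStream(env: StreamExecutionEnvironment): DataStream[Row] = {
    val row = new Row(2)
    row.setField(0, "hello")
    row.setField(1, Int.box(1))
    env.fromCollection(Collections.singletonList(row), getReturnType)
  }

  // Field names now travel inside the RowTypeInfo instead of separate getters.
  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
}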
BatchTableEnvironment.scala
@@ -298,7 +298,7 @@ abstract class BatchTableEnvironment(
* @return The [[DataSet]] that corresponds to the translated [[Table]].
*/
protected def translate[A](logicalPlan: RelNode)(implicit tpe: TypeInformation[A]): DataSet[A] = {
-validateType(tpe)
+TableEnvironment.validateType(tpe)

logicalPlan match {
case node: DataSetRel =>
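Both environments apply the same one-line change: validateType moves from a protected instance method to the TableEnvironment companion object (added later in this commit), so translate now calls the shared static check. The stream environment in the next file mirrors this file exactly.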
StreamTableEnvironment.scala
@@ -308,7 +308,7 @@ abstract class StreamTableEnvironment(
protected def translate[A]
(logicalPlan: RelNode)(implicit tpe: TypeInformation[A]): DataStream[A] = {

-validateType(tpe)
+TableEnvironment.validateType(tpe)

logicalPlan match {
case node: DataStreamRel =>
TableEnvironment.scala
@@ -18,8 +18,8 @@

package org.apache.flink.table.api

-import _root_.java.util.concurrent.atomic.AtomicInteger
import _root_.java.lang.reflect.Modifier
+import _root_.java.util.concurrent.atomic.AtomicInteger

import org.apache.calcite.config.Lex
import org.apache.calcite.jdbc.CalciteSchema
@@ -32,7 +32,8 @@ import org.apache.calcite.sql.parser.SqlParser
import org.apache.calcite.sql.util.ChainedSqlOperatorTable
import org.apache.calcite.tools.{FrameworkConfig, Frameworks, RuleSet, RuleSets}
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
-import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo}
import org.apache.flink.api.java.{ExecutionEnvironment => JavaBatchExecEnv}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv}
@@ -48,6 +49,7 @@ import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.plan.cost.DataSetCostFactory
import org.apache.flink.table.plan.schema.RelTable
import org.apache.flink.table.sinks.TableSink
+import org.apache.flink.table.sources.{DefinedFieldNames, TableSource}
import org.apache.flink.table.validate.FunctionCatalog

import _root_.scala.collection.JavaConverters._
@@ -336,48 +338,16 @@ abstract class TableEnvironment(val config: TableConfig) {
frameworkConfig
}

-protected def validateType(typeInfo: TypeInformation[_]): Unit = {
-val clazz = typeInfo.getTypeClass
-if ((clazz.isMemberClass && !Modifier.isStatic(clazz.getModifiers)) ||
-!Modifier.isPublic(clazz.getModifiers) ||
-clazz.getCanonicalName == null) {
-throw TableException(s"Class '$clazz' described in type information '$typeInfo' must be " +
-s"static and globally accessible.")
-}
-}

/**
* Returns field names and field positions for a given [[TypeInformation]].
*
* Field names are automatically extracted for
* [[org.apache.flink.api.common.typeutils.CompositeType]].
* The method fails if inputType is not a
* [[org.apache.flink.api.common.typeutils.CompositeType]].
*
* @param inputType The TypeInformation to extract the field names and positions from.
* @tparam A The type of the TypeInformation.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
protected[flink] def getFieldInfo[A](inputType: TypeInformation[A]):
-(Array[String], Array[Int]) =
-{
-validateType(inputType)
-
-val fieldNames: Array[String] = inputType match {
-case t: TupleTypeInfo[A] => t.getFieldNames
-case c: CaseClassTypeInfo[A] => c.getFieldNames
-case p: PojoTypeInfo[A] => p.getFieldNames
-case r: RowTypeInfo => r.getFieldNames
-case tpe =>
-throw new TableException(s"Type $tpe lacks explicit field naming")
-}
-val fieldIndexes = fieldNames.indices.toArray
-
-if (fieldNames.contains("*")) {
-throw new TableException("Field name can not be '*'.")
-}
-
-(fieldNames, fieldIndexes)
+(Array[String], Array[Int]) = {
+(TableEnvironment.getFieldNames(inputType), TableEnvironment.getFieldIndices(inputType))
}

/**
@@ -393,7 +363,7 @@ abstract class TableEnvironment(val config: TableConfig) {
inputType: TypeInformation[A],
exprs: Array[Expression]): (Array[String], Array[Int]) = {

-validateType(inputType)
+TableEnvironment.validateType(inputType)

val indexedNames: Array[(Int, String)] = inputType match {
case a: AtomicType[A] =>
@@ -554,4 +524,95 @@ object TableEnvironment {

new ScalaStreamTableEnv(executionEnvironment, tableConfig)
}

+/**
+ * Returns field names for a given [[TypeInformation]].
+ *
+ * @param inputType The TypeInformation to extract the field names from.
+ * @tparam A The type of the TypeInformation.
+ * @return An array holding the field names.
+ */
+def getFieldNames[A](inputType: TypeInformation[A]): Array[String] = {
+validateType(inputType)
+
+val fieldNames: Array[String] = inputType match {
+case t: CompositeType[_] => t.getFieldNames
+case a: AtomicType[_] => Array("f0")
+case tpe =>
+throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
+s"Type $tpe lacks explicit field naming")
+}
+
+if (fieldNames.contains("*")) {
+throw new TableException("Field name can not be '*'.")
+}
+
+fieldNames
+}
+
+/**
+ * Validates that the class represented by the typeInfo is static and globally accessible.
+ *
+ * @param typeInfo The type to check.
+ * @throws TableException if the type does not meet these criteria.
+ */
+def validateType(typeInfo: TypeInformation[_]): Unit = {
+val clazz = typeInfo.getTypeClass
+if ((clazz.isMemberClass && !Modifier.isStatic(clazz.getModifiers)) ||
+!Modifier.isPublic(clazz.getModifiers) ||
+clazz.getCanonicalName == null) {
+throw TableException(s"Class '$clazz' described in type information '$typeInfo' must be " +
+s"static and globally accessible.")
+}
+}
+
+/**
+ * Returns field indexes for a given [[TypeInformation]].
+ *
+ * @param inputType The TypeInformation to extract the field positions from.
+ * @return An array holding the field positions.
+ */
+def getFieldIndices(inputType: TypeInformation[_]): Array[Int] = {
+getFieldNames(inputType).indices.toArray
+}
+
+/**
+ * Returns field types for a given [[TypeInformation]].
+ *
+ * @param inputType The TypeInformation to extract field types from.
+ * @return An array holding the field types.
+ */
+def getFieldTypes(inputType: TypeInformation[_]): Array[TypeInformation[_]] = {
+validateType(inputType)
+
+inputType match {
+case t: CompositeType[_] => 0.until(t.getArity).map(t.getTypeAt(_)).toArray
+case a: AtomicType[_] => Array(a.asInstanceOf[TypeInformation[_]])
+case tpe =>
+throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
+}
+}
+
+/**
+ * Returns field names for a given [[TableSource]].
+ *
+ * @param tableSource The TableSource to extract field names from.
+ * @tparam A The type of the TableSource.
+ * @return An array holding the field names.
+ */
+def getFieldNames[A](tableSource: TableSource[A]): Array[String] = tableSource match {
+case d: DefinedFieldNames => d.getFieldNames
+case _ => TableEnvironment.getFieldNames(tableSource.getReturnType)
+}
+
+/**
+ * Returns field indices for a given [[TableSource]].
+ *
+ * @param tableSource The TableSource to extract field indices from.
+ * @tparam A The type of the TableSource.
+ * @return An array holding the field indices.
+ */
+def getFieldIndices[A](tableSource: TableSource[A]): Array[Int] = tableSource match {
+case d: DefinedFieldNames => d.getFieldIndices
+case _ => TableEnvironment.getFieldIndices(tableSource.getReturnType)
+}
}
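Assuming the helpers are public as shown above, their behaviour can be summarized in a short sketch (the results in the comments follow from the match arms; Order is a made-up case class):

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.table.api.TableEnvironment

case class Order(user: String, amount: Int)

val orderType = createTypeInformation[Order]

TableEnvironment.getFieldNames(orderType)    // Array("user", "amount")
TableEnvironment.getFieldIndices(orderType)  // Array(0, 1)
TableEnvironment.getFieldTypes(orderType)    // Array(String, Int) type informations

// Atomic types fall back to the single synthetic field "f0".
TableEnvironment.getFieldNames(BasicTypeInfo.LONG_TYPE_INFO)  // Array("f0")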
UserDefinedFunctionUtils.scala
@@ -29,7 +29,7 @@ import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.api.{TableException, ValidationException}
+import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
@@ -268,23 +268,9 @@ object UserDefinedFunctionUtils {
def getFieldInfo(inputType: TypeInformation[_])
: (Array[String], Array[Int], Array[TypeInformation[_]]) = {

-val fieldNames: Array[String] = inputType match {
-case t: CompositeType[_] => t.getFieldNames
-case a: AtomicType[_] => Array("f0")
-case tpe =>
-throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
-s"Type $tpe lacks explicit field naming")
-}
-val fieldIndexes = fieldNames.indices.toArray
-val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
-inputType match {
-case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
-case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
-case tpe =>
-throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
-}
-}
-(fieldNames, fieldIndexes, fieldTypes)
+(TableEnvironment.getFieldNames(inputType),
+TableEnvironment.getFieldIndices(inputType),
+TableEnvironment.getFieldTypes(inputType))
}

/**
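getFieldInfo keeps its (names, indices, types) result shape but now derives all three components from the shared TableEnvironment helpers. For an atomic input such as BasicTypeInfo.INT_TYPE_INFO this still yields (Array("f0"), Array(0), Array(INT_TYPE_INFO)), matching the deleted inline logic.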
FlinkRel.scala
@@ -18,7 +18,9 @@

package org.apache.flink.table.plan.nodes

-import org.apache.calcite.rel.`type`.RelDataType
+import java.util
+
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.flink.api.common.functions.MapFunction
@@ -103,10 +105,12 @@ trait FlinkRel {

}


private[flink] def estimateRowSize(rowType: RelDataType): Double = {
+val fieldList = rowType.getFieldList

-rowType.getFieldList.map(_.getType.getSqlTypeName).foldLeft(0) { (s, t) =>
-t match {
+fieldList.map(_.getType.getSqlTypeName).zipWithIndex.foldLeft(0) { (s, t) =>
+t._1 match {
case SqlTypeName.TINYINT => s + 1
case SqlTypeName.SMALLINT => s + 2
case SqlTypeName.INTEGER => s + 4
@@ -120,6 +124,7 @@ trait FlinkRel {
case typeName if SqlTypeName.YEAR_INTERVAL_TYPES.contains(typeName) => s + 8
case typeName if SqlTypeName.DAY_INTERVAL_TYPES.contains(typeName) => s + 4
case SqlTypeName.TIME | SqlTypeName.TIMESTAMP | SqlTypeName.DATE => s + 12
+case SqlTypeName.ROW => s + estimateRowSize(fieldList.get(t._2).getType()).asInstanceOf[Int]
case _ => throw TableException(s"Unsupported data type encountered: $t")
}
}
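The new ROW branch makes the size estimate recursive: a row typed (INTEGER, ROW(SMALLINT, TIMESTAMP)) now folds to 4 + (2 + 12) = 18 bytes, the nested row contributing its own recursive estimate. Note that the Double returned by the recursive call is truncated through the asInstanceOf[Int] cast before being added.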
BatchTableSourceScan.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
-import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.api.{BatchTableEnvironment, TableEnvironment}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.BatchTableSource
@@ -38,7 +38,9 @@ class BatchTableSourceScan(

override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
-flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
+flinkTypeFactory.buildRowDataType(
+TableEnvironment.getFieldNames(tableSource),
+TableEnvironment.getFieldTypes(tableSource.getReturnType))
}

override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
@@ -57,7 +59,7 @@

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.item("fields", tableSource.getFieldsNames.mkString(", "))
.item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", "))
}

override def translateToPlan(
StreamTableSourceScan.scala
@@ -26,7 +26,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.StreamTableSource
import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.StreamTableEnvironment
+import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment}

/** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */
class StreamTableSourceScan(
@@ -38,7 +38,9 @@ class StreamTableSourceScan(

override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
-flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
+flinkTypeFactory.buildRowDataType(
+TableEnvironment.getFieldNames(tableSource),
+TableEnvironment.getFieldTypes(tableSource.getReturnType))
}

override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
@@ -57,7 +59,7 @@

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.item("fields", tableSource.getFieldsNames.mkString(", "))
.item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", "))
}

override def translateToPlan(
BatchTableSourceScanRule.scala
@@ -39,9 +39,9 @@ class BatchTableSourceScanRule
/** Rule must only match if TableScan targets a [[BatchTableSource]] */
override def matches(call: RelOptRuleCall): Boolean = {
val scan: TableScan = call.rel(0).asInstanceOf[TableScan]
-val dataSetTable = scan.getTable.unwrap(classOf[TableSourceTable])
+val dataSetTable = scan.getTable.unwrap(classOf[TableSourceTable[_]])
dataSetTable match {
-case tst: TableSourceTable =>
+case tst: TableSourceTable[_] =>
tst.tableSource match {
case _: BatchTableSource[_] =>
true
@@ -57,7 +57,7 @@
val scan: TableScan = rel.asInstanceOf[TableScan]
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)

-val tableSource = scan.getTable.unwrap(classOf[TableSourceTable]).tableSource
+val tableSource = scan.getTable.unwrap(classOf[TableSourceTable[_]]).tableSource
.asInstanceOf[BatchTableSource[_]]
new BatchTableSourceScan(
rel.getCluster,
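Since TableSourceTable is now parameterized by the source's produced type, the rule matches and unwraps with classOf[TableSourceTable[_]]. Type erasure makes the unwrap identical at runtime; the wildcard exists only to satisfy the compiler.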