[FLINK-18726][table-planner-blink] Support INSERT INTO specific colum… #14977

Merged 2 commits on Mar 1, 2021
@@ -259,4 +259,8 @@ private DynamicTableSource createDynamicTableSource(
Thread.currentThread().getContextClassLoader(),
schemaTable.isTemporary());
}

public CatalogTable getCatalogTable() {
return catalogTable;
}
}
@@ -39,6 +39,7 @@ import org.apache.flink.table.planner.plan.schema.{GenericRelDataType, _}
import org.apache.flink.table.runtime.types.{LogicalTypeDataTypeConverter, PlannerTypeUtils}
import org.apache.flink.table.types.logical._
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
import org.apache.flink.table.utils.TableSchemaUtils
import org.apache.flink.types.Nothing
import org.apache.flink.util.Preconditions.checkArgument

@@ -234,6 +235,16 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem)
buildStructType(fields.map(_.getName), fields.map(_.getType), StructKind.FULLY_QUALIFIED)
}

/**
* Creates a struct type with the physical columns using FlinkTypeFactory
*
* @param tableSchema schema to convert to Calcite's specific one
* @return a struct type with the input fieldNames, input fieldTypes.
*/
def buildPhysicalRelNodeRowType(tableSchema: TableSchema): RelDataType = {
buildRelNodeRowType(TableSchemaUtils.getPhysicalSchema(tableSchema))
}
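For context on what the new helper returns, here is a minimal sketch (the schema, column names, and computed-column expression are illustrative and assume the TableSchema builder API of this Flink version): TableSchemaUtils.getPhysicalSchema drops non-physical columns such as computed columns, so the row type produced by buildPhysicalRelNodeRowType only contains fields that physically exist in the table.

import org.apache.flink.table.api.{DataTypes, TableSchema}
import org.apache.flink.table.utils.TableSchemaUtils

// Illustrative schema: `c` is a computed column and therefore not physical.
val schema = TableSchema.builder()
  .field("a", DataTypes.INT())
  .field("b", DataTypes.STRING())
  .field("c", DataTypes.INT(), "a + 1")
  .build()

// Only `a` and `b` survive; buildPhysicalRelNodeRowType converts this
// physical schema into a Calcite struct type via buildRelNodeRowType.
val physicalSchema = TableSchemaUtils.getPhysicalSchema(schema)
// physicalSchema.getFieldNames => Array("a", "b")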

/**
* Creates a struct type with the input fieldNames and input fieldTypes using FlinkTypeFactory.
*
@@ -21,7 +21,8 @@ package org.apache.flink.table.planner.calcite
import org.apache.flink.sql.parser.SqlProperty
import org.apache.flink.sql.parser.dml.RichSqlInsert
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.planner.calcite.PreValidateReWriter.appendPartitionProjects
import org.apache.flink.table.planner.calcite.PreValidateReWriter.appendPartitionAndNullsProjects
import org.apache.flink.table.planner.plan.schema.{CatalogSourceTable, FlinkPreparingTableBase, LegacyCatalogSourceTable}

import org.apache.calcite.plan.RelOptTable
import org.apache.calcite.prepare.CalciteCatalogReader
@@ -47,16 +48,17 @@ class PreValidateReWriter(
val typeFactory: RelDataTypeFactory) extends SqlBasicVisitor[Unit] {
override def visit(call: SqlCall): Unit = {
call match {
case r: RichSqlInsert if r.getStaticPartitions.nonEmpty => r.getSource match {
case r: RichSqlInsert
if r.getStaticPartitions.nonEmpty || r.getTargetColumnList != null => r.getSource match {
case select: SqlSelect =>
appendPartitionProjects(r, validator, typeFactory, select, r.getStaticPartitions)
appendPartitionAndNullsProjects(r, validator, typeFactory, select, r.getStaticPartitions)
case values: SqlCall if values.getKind == SqlKind.VALUES =>
val newSource = appendPartitionProjects(r, validator, typeFactory, values,
val newSource = appendPartitionAndNullsProjects(r, validator, typeFactory, values,
r.getStaticPartitions)
r.setOperand(2, newSource)
case source =>
throw new ValidationException(
s"INSERT INTO <table> PARTITION statement only support "
s"INSERT INTO <table> PARTITION [(COLUMN LIST)] statement only support "
+ s"SELECT and VALUES clause for now, '$source' is not supported yet.")
}
case _ =>
@@ -67,8 +69,8 @@ object PreValidateReWriter {
object PreValidateReWriter {
//~ Tools ------------------------------------------------------------------
/**
* Append the static partitions to the data source projection list. The columns are appended to
* the corresponding positions.
* Append the static partitions and unspecified columns to the data source projection list.
* The columns are appended to the corresponding positions.
*
* <p>If we have a table A with schema (&lt;a&gt;, &lt;b&gt;, &lt;c&gt) whose
* partition columns are (&lt;a&gt;, &lt;c&gt;), and got a query
@@ -83,13 +85,25 @@ object PreValidateReWriter {
* </pre></blockquote>
* Where the "tpe1" and "tpe2" are data types of column a and c of target table A.
*
* <p>If we have a table A with schema (&lt;a&gt;, &lt;b&gt;, &lt;c&gt), and got a query
* <blockquote><pre>
* insert into A (a, b)
* select a, b from B
* </pre></blockquote>
* The query would be rewritten to:
* <blockquote><pre>
* insert into A
* select a, b, cast(null as tpeC) from B
* </pre></blockquote>
* Where the "tpeC" is data type of column c for target table A.
*
* @param sqlInsert RichSqlInsert instance
* @param validator Validator
* @param typeFactory type factory
* @param source Source to rewrite
* @param partitions Static partition statements
*/
def appendPartitionProjects(sqlInsert: RichSqlInsert,
def appendPartitionAndNullsProjects(sqlInsert: RichSqlInsert,
validator: FlinkCalciteSqlValidator,
typeFactory: RelDataTypeFactory,
source: SqlCall,
@@ -103,8 +117,7 @@ object PreValidateReWriter {
// just skip to let other validation error throw.
return source
}
val targetRowType = createTargetRowType(typeFactory,
calciteCatalogReader, table, sqlInsert.getTargetColumnList)
val targetRowType = createTargetRowType(typeFactory, table)
// validate partition fields first.
val assignedFields = new util.LinkedHashMap[Integer, SqlNode]
val relOptTable = table match {
@@ -121,26 +134,83 @@ object PreValidateReWriter {
assignedFields.put(targetField.getIndex,
maybeCast(value, value.createSqlType(typeFactory), targetField.getType, typeFactory))
}

// validate partial insert columns.

// the columnList may reorder fields (compare with fields of sink)
val targetPosition = new util.ArrayList[Int]()

if (sqlInsert.getTargetColumnList != null) {
val targetFields = new util.HashSet[Integer]
val targetColumns =
sqlInsert
.getTargetColumnList
.getList
.map(id => {
val targetField = SqlValidatorUtil.getTargetField(
targetRowType, typeFactory, id.asInstanceOf[SqlIdentifier],
calciteCatalogReader, relOptTable)
validateField(targetFields.add, id.asInstanceOf[SqlIdentifier], targetField)
targetField
})

val partitionColumns =
partitions
.getList
.map(property =>
SqlValidatorUtil.getTargetField(
targetRowType, typeFactory, property.asInstanceOf[SqlProperty].getKey,
calciteCatalogReader, relOptTable))

for (targetField <- targetRowType.getFieldList) {
if (!partitionColumns.contains(targetField)) {
if (!targetColumns.contains(targetField)) {
// padding null
val id = new SqlIdentifier(targetField.getName, SqlParserPos.ZERO)
if (!targetField.getType.isNullable) {
throw newValidationError(id, RESOURCE.columnNotNullable(targetField.getName))
}
validateField(idx => !assignedFields.contains(idx), id, targetField)
assignedFields.put(targetField.getIndex,
maybeCast(
SqlLiteral.createNull(SqlParserPos.ZERO),
typeFactory.createUnknownType(),
targetField.getType,
typeFactory))
} else {
// handle reorder
targetPosition.add(targetColumns.indexOf(targetField))
}
}
}
}

source match {
case select: SqlSelect =>
rewriteSelect(validator, select, targetRowType, assignedFields)
rewriteSelect(validator, select, targetRowType, assignedFields, targetPosition)
case values: SqlCall if values.getKind == SqlKind.VALUES =>
rewriteValues(values, targetRowType, assignedFields)
rewriteValues(values, targetRowType, assignedFields, targetPosition)
}
}
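To make the overall rewrite concrete, here is a hedged end-to-end sketch in the style of the values-connector ITCase quoted in the review thread below (table name, schema, and data are illustrative, not part of this change): when the column list omits a nullable column, the rewriter pads it with a CAST(NULL AS ...) projection.

// Target table with a nullable column `c` that the INSERT below omits.
tEnv.executeSql(
  s"""
     |CREATE TABLE sinkA (
     |  `a` INT,
     |  `b` STRING,
     |  `c` BIGINT
     |)
     |WITH (
     |  'connector' = 'values',
     |  'sink-insert-only' = 'false'
     |)
     |""".stripMargin)

// After PreValidateReWriter, this statement behaves like
//   INSERT INTO sinkA SELECT 1, 'x', CAST(NULL AS BIGINT)
tEnv.executeSql("INSERT INTO sinkA (a, b) VALUES (1, 'x')").await()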

private def rewriteSelect(
validator: FlinkCalciteSqlValidator,
select: SqlSelect,
targetRowType: RelDataType,
assignedFields: util.LinkedHashMap[Integer, SqlNode]): SqlCall = {
assignedFields: util.LinkedHashMap[Integer, SqlNode],
targetPosition: util.List[Int]): SqlCall = {
// Expands the select list first in case there is a star(*).
// Validates the select first to register the where scope.
validator.validate(select)
val sourceList = validator.expandStar(select.getSelectList, select, false).getList

val fixedNodes = new util.ArrayList[SqlNode]
val currentNodes = new util.ArrayList[SqlNode](sourceList)
val currentNodes =
if (targetPosition.isEmpty) {
new util.ArrayList[SqlNode](sourceList)
} else {
reorder(new util.ArrayList[SqlNode](sourceList), targetPosition)
}
0 until targetRowType.getFieldList.length foreach {
idx =>
if (assignedFields.containsKey(idx)) {
@@ -161,7 +231,8 @@ object PreValidateReWriter {
private def rewriteValues(
values: SqlCall,
targetRowType: RelDataType,
assignedFields: util.LinkedHashMap[Integer, SqlNode]): SqlCall = {
assignedFields: util.LinkedHashMap[Integer, SqlNode],
targetPosition: util.List[Int]): SqlCall = {
val fixedNodes = new util.ArrayList[SqlNode]
0 until values.getOperandList.size() foreach {
valueIdx =>
@@ -171,7 +242,12 @@ object PreValidateReWriter {
} else {
Collections.singletonList(value)
}
val currentNodes = new util.ArrayList[SqlNode](valueAsList)
val currentNodes =
if (targetPosition.isEmpty) {
new util.ArrayList[SqlNode](valueAsList)
} else {
reorder(new util.ArrayList[SqlNode](valueAsList), targetPosition)
}
val fieldNodes = new util.ArrayList[SqlNode]
0 until targetRowType.getFieldList.length foreach {
fieldIdx =>
@@ -191,41 +267,40 @@ object PreValidateReWriter {
SqlStdOperatorTable.VALUES.createCall(values.getParserPosition, fixedNodes)
}

private def reorder(
sourceList: util.ArrayList[SqlNode],
targetPosition: util.List[Int]): util.ArrayList[SqlNode] = {
val targetList = new Array[SqlNode](sourceList.size())
0 until sourceList.size() foreach {
idx => targetList(targetPosition.get(idx)) = sourceList.get(idx)
Contributor commented: Here it should be idx => targetList(idx) = sourceList.get(targetPosition.get(idx)) instead. What do you think? @leonardBang @wuchong

@wuchong (Member, Mar 10, 2021) replied: Is there any test to prove this?

Contributor replied with the following ITCase:
@Test
  def testPartialInsertWithComplexReorder(): Unit = {
    tEnv.executeSql(
      s"""
         |CREATE TABLE testSink (
         |  `a` INT,
         |  `c` STRING,
         |  `c1` STRING,
         |  `c2` STRING,
         |  `c3` BIGINT,
         |  `d` INT,
         |  `e` DOUBLE
         |)
         |WITH (
         |  'connector' = 'values',
         |  'sink-insert-only' = 'false'
         |)
         |""".stripMargin)

    val t = env.fromCollection(tupleData2).toTable(tEnv, 'x, 'y)
    tEnv.createTemporaryView("MyTable", t)

    tEnv.executeSql(
      s"""
         |INSERT INTO testSink (a,c2,e,c,c1,c3,d)
         |SELECT 1,'c2',sum(y),'c','c1',33333,12 FROM MyTable GROUP BY x
         |""".stripMargin).await()
    val expected = List(
      "1,c,c1,c2,33333,12,0.1",
      "1,c,c1,c2,33333,12,0.4",
      "1,c,c1,c2,33333,12,1.0",
      "1,c,c1,c2,33333,12,2.2",
      "1,c,c1,c2,33333,12,3.9")
    val result = TestValuesTableFactory.getResults("testSink")
    assertEquals(expected.sorted, result.sorted)
  }

Contributor added: BTW, I think we should use plan testing instead of an ITCase.

}
new util.ArrayList[SqlNode](targetList.toList)
}
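A self-contained sketch of the index question raised in the review thread above (values are illustrative): with targetPosition(k) holding the source position of the k-th target field, the assignment in the diff and the one suggested in the review apply inverse permutations, so they only agree when the permutation is its own inverse.

import java.util

// targetPosition(k) = position of the k-th target field in the source list.
val source = util.Arrays.asList("s0", "s1", "s2")
val targetPosition = util.Arrays.asList(2, 0, 1)

// Variant in the diff: place source(k) at slot targetPosition(k).
val asMerged = new Array[String](source.size())
(0 until source.size()).foreach(k => asMerged(targetPosition.get(k)) = source.get(k))
// asMerged => Array("s1", "s2", "s0")

// Variant suggested in the review: read the source at targetPosition(k).
val asSuggested = new Array[String](source.size())
(0 until source.size()).foreach(k => asSuggested(k) = source.get(targetPosition.get(k)))
// asSuggested => Array("s2", "s0", "s1")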
/**
* Derives a row-type for INSERT and UPDATE operations.
* Derives a physical row-type for INSERT and UPDATE operations.
*
* <p>This code snippet is almost inspired by
* [[org.apache.calcite.sql.validate.SqlValidatorImpl#createTargetRowType]].
* It is the best that the logic can be merged into Apache Calcite,
* but this needs time.
*
* @param typeFactory TypeFactory
* @param catalogReader CalciteCatalogReader
* @param table Target table for INSERT/UPDATE
* @param targetColumnList List of target columns, or null if not specified
* @return Rowtype
*/
private def createTargetRowType(
typeFactory: RelDataTypeFactory,
catalogReader: CalciteCatalogReader,
table: SqlValidatorTable,
targetColumnList: SqlNodeList): RelDataType = {
val rowType = table.getRowType
if (targetColumnList == null) return rowType
val fields = new util.ArrayList[util.Map.Entry[String, RelDataType]]
val assignedFields = new util.HashSet[Integer]
val relOptTable = table match {
case t: RelOptTable => t
case _ => null
}
for (node <- targetColumnList) {
val id = node.asInstanceOf[SqlIdentifier]
val targetField = SqlValidatorUtil.getTargetField(rowType,
typeFactory, id, catalogReader, relOptTable)
validateField(assignedFields.add, id, targetField)
fields.add(targetField)
table: SqlValidatorTable): RelDataType = {
table.unwrap(classOf[FlinkPreparingTableBase]) match {
case t: CatalogSourceTable =>
val schema = t.getCatalogTable.getSchema
typeFactory.asInstanceOf[FlinkTypeFactory].buildPhysicalRelNodeRowType(schema)
case t: LegacyCatalogSourceTable[_] =>
val schema = t.catalogTable.getSchema
typeFactory.asInstanceOf[FlinkTypeFactory].buildPhysicalRelNodeRowType(schema)
case _ =>
table.getRowType
}
typeFactory.createStructType(fields)
}

/** Check whether the field is valid. **/