diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
index 8fc91d4a849a5..4f75fad4d7c9d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
@@ -21,8 +21,9 @@ package org.apache.flink.table.planner.calcite
 import org.apache.flink.sql.parser.SqlProperty
 import org.apache.flink.sql.parser.dml.RichSqlInsert
 import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.planner.calcite.PreValidateReWriter.appendPartitionAndNullsProjects
+import org.apache.flink.table.planner.calcite.PreValidateReWriter.{appendPartitionAndNullsProjects, notSupported}
 import org.apache.flink.table.planner.plan.schema.{CatalogSourceTable, FlinkPreparingTableBase, LegacyCatalogSourceTable}
+import org.apache.flink.util.Preconditions.checkArgument
 
 import org.apache.calcite.plan.RelOptTable
 import org.apache.calcite.prepare.CalciteCatalogReader
@@ -33,7 +34,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.sql.parser.SqlParserPos
 import org.apache.calcite.sql.util.SqlBasicVisitor
 import org.apache.calcite.sql.validate.{SqlValidatorException, SqlValidatorTable, SqlValidatorUtil}
-import org.apache.calcite.sql.{SqlCall, SqlIdentifier, SqlKind, SqlLiteral, SqlNode, SqlNodeList, SqlSelect, SqlUtil}
+import org.apache.calcite.sql.{SqlCall, SqlIdentifier, SqlKind, SqlLiteral, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil}
 import org.apache.calcite.util.Static.RESOURCE
 
 import java.util
@@ -50,16 +51,11 @@ class PreValidateReWriter(
     call match {
       case r: RichSqlInsert if r.getStaticPartitions.nonEmpty || r.getTargetColumnList != null =>
         r.getSource match {
-          case select: SqlSelect =>
-            appendPartitionAndNullsProjects(r, validator, typeFactory, select, r.getStaticPartitions)
-          case values: SqlCall if values.getKind == SqlKind.VALUES =>
-            val newSource = appendPartitionAndNullsProjects(r, validator, typeFactory, values,
-              r.getStaticPartitions)
+          case call: SqlCall =>
+            val newSource = appendPartitionAndNullsProjects(
+              r, validator, typeFactory, call, r.getStaticPartitions)
             r.setOperand(2, newSource)
-          case source =>
-            throw new ValidationException(
-              s"INSERT INTO PARTITION [(COLUMN LIST)] statement only support "
-                + s"SELECT and VALUES clause for now, '$source' is not supported yet.")
+          case source => throw new ValidationException(notSupported(source))
         }
       case _ =>
     }
   }
@@ -67,7 +63,14 @@ class PreValidateReWriter(
 }
 
 object PreValidateReWriter {
+  //~ Tools ------------------------------------------------------------------
+
+  private def notSupported(source: SqlNode): String = {
+    s"INSERT INTO PARTITION [(COLUMN LIST)] statement only support " +
+      s"SELECT, VALUES, SET_QUERY AND ORDER BY clause for now, '$source' is not supported yet."
+  }
+
   /**
    * Append the static partitions and unspecified columns to the data source projection list.
    * The columns are appended to the corresponding positions.
@@ -108,7 +111,6 @@
       typeFactory: RelDataTypeFactory,
       source: SqlCall,
       partitions: SqlNodeList): SqlCall = {
-    assert(source.getKind == SqlKind.SELECT || source.getKind == SqlKind.VALUES)
     val calciteCatalogReader = validator.getCatalogReader.unwrap(classOf[CalciteCatalogReader])
     val names = sqlInsert.getTargetTable.asInstanceOf[SqlIdentifier].names
     val table = calciteCatalogReader.getTable(names)
@@ -185,11 +187,49 @@
       }
     }
 
-    source match {
-      case select: SqlSelect =>
-        rewriteSelect(validator, select, targetRowType, assignedFields, targetPosition)
-      case values: SqlCall if values.getKind == SqlKind.VALUES =>
-        rewriteValues(values, targetRowType, assignedFields, targetPosition)
+    rewriteSqlCall(validator, source, targetRowType, assignedFields, targetPosition)
+  }
+
+  private def rewriteSqlCall(
+      validator: FlinkCalciteSqlValidator,
+      call: SqlCall,
+      targetRowType: RelDataType,
+      assignedFields: util.LinkedHashMap[Integer, SqlNode],
+      targetPosition: util.List[Int]): SqlCall = {
+
+    def rewrite(node: SqlNode): SqlCall = {
+      checkArgument(node.isInstanceOf[SqlCall], node)
+      rewriteSqlCall(
+        validator,
+        node.asInstanceOf[SqlCall],
+        targetRowType,
+        assignedFields,
+        targetPosition)
+    }
+
+    call.getKind match {
+      case SqlKind.SELECT =>
+        rewriteSelect(
+          validator, call.asInstanceOf[SqlSelect], targetRowType, assignedFields, targetPosition)
+      case SqlKind.VALUES =>
+        rewriteValues(call, targetRowType, assignedFields, targetPosition)
+      case kind if SqlKind.SET_QUERY.contains(kind) =>
+        call.getOperandList.zipWithIndex.foreach {
+          case (operand, index) => call.setOperand(index, rewrite(operand))
+        }
+        call
+      case SqlKind.ORDER_BY =>
+        val operands = call.getOperandList
+        new SqlOrderBy(
+          call.getParserPosition,
+          rewrite(operands.get(0)),
+          operands.get(1).asInstanceOf[SqlNodeList],
+          operands.get(2),
+          operands.get(3))
+      // Not support:
+      // case SqlKind.WITH =>
+      // case SqlKind.EXPLICIT_TABLE =>
+      case _ => throw new ValidationException(notSupported(call))
     }
   }
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
index b900cb65d118a..aa9aa501e26c0 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
@@ -35,6 +35,51 @@ Sink(table=[default_catalog.default_database.sink], fields=[a, b, c, d, e, f, g]
 +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
    +- Exchange(distribution=[hash[a, b, c, d, e]])
       +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ + + + + + + + + + + + + + + + + + + + + +
@@ -62,48 +107,432 @@ Sink(table=[default_catalog.default_database.partitioned_sink], fields=[a, c, d,
 ]]>
-
+
-
+(sum_vcol_marker, 0)])
+   +- GroupAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, SUM_RETRACT(vcol_marker) AS sum_vcol_marker])
+      +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+         +- Union(all=[true], union=[a, c, d, e, f, g, vcol_marker])
+            :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, 1:BIGINT AS vcol_marker])
+            :  +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+            :     +- Exchange(distribution=[hash[a, b, c, d, e]])
+            :        +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+            +- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, -1:BIGINT AS vcol_marker])
+               +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                  +- Exchange(distribution=[hash[a, b, c, d, e]])
+                     +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
 ]]>
-
+
-
+(sum_vcol_marker, 0)])
+   +- HashAggregate(isMerge=[true], groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Final_SUM(sum$0) AS sum_vcol_marker])
+      +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+         +- LocalHashAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Partial_SUM(vcol_marker) AS sum$0])
+            +- Union(all=[true], union=[a, c, d, e, f, g, vcol_marker])
+               :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, 1:BIGINT AS vcol_marker])
+               :  +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+               :     +- Exchange(distribution=[hash[a, b, c, d, e]])
+               :        +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+               :           +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+               +- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, -1:BIGINT AS vcol_marker])
+                  +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                     +- Exchange(distribution=[hash[a, b, c, d, e]])
+                        +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                           +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ + + + + + + + + + +
+(vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, a, c, d, e, f, g], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+   +- GroupAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, COUNT_RETRACT(vcol_left_marker) AS vcol_left_cnt, COUNT_RETRACT(vcol_right_marker) AS vcol_right_cnt])
+      +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+         +- Union(all=[true], union=[a, c, d, e, f, g, vcol_left_marker, vcol_right_marker])
+            :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+            :  +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+            :     +- Exchange(distribution=[hash[a, b, c, d, e]])
+            :        +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+            +- Calc(select=[a, c, d, e, CAST(456:BIGINT) AS f, CAST(789) AS g, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+               +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                  +- Exchange(distribution=[hash[a, b, c, d, e]])
+                     +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ + + + + + + + + + +
+(vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, a, c, d, e, f, g], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+   +- HashAggregate(isMerge=[true], groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Final_COUNT(count$0) AS vcol_left_cnt, Final_COUNT(count$1) AS vcol_right_cnt])
+      +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+         +- LocalHashAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Partial_COUNT(vcol_left_marker) AS count$0, Partial_COUNT(vcol_right_marker) AS count$1])
+            +- Union(all=[true], union=[a, c, d, e, f, g, vcol_left_marker, vcol_right_marker])
+               :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+               :  +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+               :     +- Exchange(distribution=[hash[a, b, c, d, e]])
+               :        +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+               :           +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+               +- Calc(select=[a, c, d, e, CAST(456:BIGINT) AS f, CAST(789) AS g, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+                  +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                     +- Exchange(distribution=[hash[a, b, c, d, e]])
+                        +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+                           +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
index fb6eac13e2cd0..b16ce64919cc0 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
@@ -73,6 +73,49 @@ class PartialInsertTest(isBatch: Boolean) extends TableTestBase {
     util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
       "SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e")
   }
+
+  @Test
+  def testPartialInsertWithUnion(): Unit = {
+    testPartialInsertWithSetOperator("UNION")
+  }
+
+  @Test
+  def testPartialInsertWithUnionAll(): Unit = {
+    testPartialInsertWithSetOperator("UNION ALL")
+  }
+
+  @Test
+  def testPartialInsertWithIntersectAll(): Unit = {
+    testPartialInsertWithSetOperator("INTERSECT ALL")
+  }
+
+  @Test
+  def testPartialInsertWithExceptAll(): Unit = {
+    testPartialInsertWithSetOperator("EXCEPT ALL")
+  }
+
+  private def testPartialInsertWithSetOperator(operator: String): Unit = {
+    util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+      "SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+      operator + " " +
+      "SELECT e,a,789,456,c,d FROM MyTable GROUP BY a,b,c,d,e ")
+  }
+
+  @Test
+  def testPartialInsertWithUnionAllNested(): Unit = {
+    util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+      "SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+      "UNION ALL " +
+      "SELECT e,a,789,456,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+      "UNION ALL " +
+      "SELECT e,a,123,456,c,d FROM MyTable GROUP BY a,b,c,d,e ")
+  }
+
+  @Test
+  def testPartialInsertWithOrderBy(): Unit = {
+    util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+      "SELECT e,a,456,123,c,d FROM MyTable ORDER BY a,e,c,d")
+  }
 }
 
 object PartialInsertTest {