From 237e3a2df4b027995491b67adbbe77f8f3648704 Mon Sep 17 00:00:00 2001 From: godfreyhe Date: Tue, 21 May 2019 20:30:48 +0800 Subject: [PATCH 01/92] [FLINK-12575][table-planner-blink] Introduce planner rules to remove redundant shuffle and collation This closes #8499 i#refactor satisfyTraits method --- .../nodes/physical/FlinkPhysicalRel.scala | 4 +- .../nodes/physical/batch/BatchExecCalc.scala | 63 +- .../physical/batch/BatchExecCorrelate.scala | 83 +- .../batch/BatchExecGroupAggregateBase.scala | 11 +- .../batch/BatchExecHashAggregate.scala | 58 +- .../physical/batch/BatchExecHashJoin.scala | 33 + .../physical/batch/BatchExecJoinBase.scala | 156 +- .../batch/BatchExecLocalHashAggregate.scala | 40 +- .../batch/BatchExecLocalSortAggregate.scala | 48 +- .../batch/BatchExecNestedLoopJoin.scala | 6 + .../batch/BatchExecOverAggregate.scala | 105 +- .../nodes/physical/batch/BatchExecRank.scala | 125 +- .../batch/BatchExecSortAggregate.scala | 69 +- .../batch/BatchExecSortMergeJoin.scala | 54 +- .../nodes/physical/batch/BatchExecUnion.scala | 39 +- .../table/plan/rules/FlinkBatchRuleSets.scala | 25 +- .../plan/rules/FlinkStreamRuleSets.scala | 18 +- .../logical/WindowGroupReorderRule.scala | 137 ++ .../physical/FlinkExpandConversionRule.scala | 18 +- .../RemoveRedundantLocalHashAggRule.scala | 60 + .../batch/RemoveRedundantLocalRankRule.scala | 60 + .../RemoveRedundantLocalSortAggRule.scala | 110 ++ .../plan/batch/sql/DagOptimizationTest.xml | 45 +- .../plan/batch/sql/DeadlockBreakupTest.xml | 58 +- .../plan/batch/sql/RemoveCollationTest.xml | 708 +++++++++ .../plan/batch/sql/RemoveShuffleTest.xml | 1264 +++++++++++++++++ .../table/plan/batch/sql/SubplanReuseTest.xml | 63 +- .../plan/batch/sql/agg/GroupingSetsTest.xml | 18 +- .../plan/batch/sql/agg/OverAggregateTest.xml | 271 ++-- .../join/BroadcastHashSemiAntiJoinTest.xml | 13 +- .../batch/sql/join/NestedLoopJoinTest.xml | 4 +- .../sql/join/NestedLoopSemiAntiJoinTest.xml | 13 +- .../plan/batch/sql/join/SemiAntiJoinTest.xml | 197 ++- .../batch/sql/join/ShuffledHashJoinTest.xml | 8 +- .../sql/join/ShuffledHashSemiAntiJoinTest.xml | 140 +- .../plan/batch/sql/join/SingleRowJoinTest.xml | 43 +- .../plan/batch/sql/join/SortMergeJoinTest.xml | 8 +- .../sql/join/SortMergeSemiAntiJoinTest.xml | 140 +- .../logical/WindowGroupReorderRuleTest.xml | 280 ++++ .../RemoveRedundantLocalHashAggRuleTest.xml | 84 ++ .../RemoveRedundantLocalRankRuleTest.xml | 114 ++ .../RemoveRedundantLocalSortAggRuleTest.xml | 86 ++ .../plan/batch/sql/RemoveCollationTest.scala | 384 +++++ .../plan/batch/sql/RemoveShuffleTest.scala | 547 +++++++ .../batch/sql/agg/OverAggregateTest.scala | 51 +- .../logical/WindowGroupReorderRuleTest.scala | 179 +++ .../RemoveRedundantLocalHashAggRuleTest.scala | 71 + .../RemoveRedundantLocalRankRuleTest.scala | 72 + .../RemoveRedundantLocalSortAggRuleTest.scala | 66 + 49 files changed, 5693 insertions(+), 556 deletions(-) create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala create mode 100644 
flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/FlinkPhysicalRel.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/FlinkPhysicalRel.scala index cd85e7ba347617..cc49d0ad37a72e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/FlinkPhysicalRel.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/FlinkPhysicalRel.scala @@ -35,8 +35,8 @@ trait FlinkPhysicalRel extends FlinkRelNode { * * @param requiredTraitSet required traits * @return A converted node which satisfy required traits by inputs node of current node. - * Returns null if required traits cannot be pushed down into inputs. + * Returns None if required traits cannot be satisfied. 
*/ - def satisfyTraitsByInput(requiredTraitSet: RelTraitSet): RelNode = null + def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = None } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCalc.scala index ace866d2839417..1e184dc217f2e9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCalc.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCalc.scala @@ -24,6 +24,7 @@ import org.apache.flink.table.api.{BatchTableEnvironment, TableConfigOptions} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.{CalcCodeGenerator, CodeGeneratorContext} import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef, TraitUtil} import org.apache.flink.table.plan.nodes.common.CommonCalc import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.util.RelExplainUtil @@ -32,7 +33,9 @@ import org.apache.calcite.plan._ import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Calc -import org.apache.calcite.rex.RexProgram +import org.apache.calcite.rex.{RexCall, RexInputRef, RexProgram} +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.util.mapping.{Mapping, MappingType, Mappings} import java.util @@ -57,6 +60,64 @@ class BatchExecCalc( new BatchExecCalc(cluster, traitSet, child, program, outputRowType) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + // Does not push broadcast distribution trait down into Calc. + if (requiredDistribution.getType == RelDistribution.Type.BROADCAST_DISTRIBUTED) { + return None + } + val projects = calcProgram.getProjectList.map(calcProgram.expandLocalRef) + + def getProjectMapping: Mapping = { + val mapping = Mappings.create(MappingType.INVERSE_FUNCTION, + getInput.getRowType.getFieldCount, projects.size) + projects.zipWithIndex.foreach { + case (project, index) => + project match { + case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index) + case call: RexCall if call.getKind == SqlKind.AS => + call.getOperands.head match { + case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index) + case _ => // ignore + } + case _ => // ignore + } + } + mapping.inverse() + } + + val mapping = getProjectMapping + val appliedDistribution = requiredDistribution.apply(mapping) + // If both distribution and collation can be satisfied, satisfy both. If only distribution + // can be satisfied, only satisfy distribution. There is no possibility to only satisfy + // collation here except for there is no distribution requirement. 
+ if ((!requiredDistribution.isTop) && (appliedDistribution eq FlinkRelDistribution.ANY)) { + return None + } + + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val appliedCollation = TraitUtil.apply(requiredCollation, mapping) + val canCollationPushedDown = !appliedCollation.getFieldCollations.isEmpty + // If required traits only contains collation requirements, but collation keys are not columns + // from input, then no need to satisfy required traits. + if ((appliedDistribution eq FlinkRelDistribution.ANY) && !canCollationPushedDown) { + return None + } + + var inputRequiredTraits = getInput.getTraitSet + var providedTraits = getTraitSet + if (!appliedDistribution.isTop) { + inputRequiredTraits = inputRequiredTraits.replace(appliedDistribution) + providedTraits = providedTraits.replace(requiredDistribution) + } + if (canCollationPushedDown) { + inputRequiredTraits = inputRequiredTraits.replace(appliedCollation) + providedTraits = providedTraits.replace(requiredCollation) + } + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(providedTraits, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCorrelate.scala index 7097e82c1a44f4..16f0cb52e7d83d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCorrelate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecCorrelate.scala @@ -23,16 +23,18 @@ import org.apache.flink.table.api.{BatchTableEnvironment, TableConfigOptions} import org.apache.flink.table.codegen.{CodeGeneratorContext, CorrelateCodeGenerator} import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.utils.TableSqlFunction +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef, TraitUtil} import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan import org.apache.flink.table.plan.util.RelExplainUtil -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Correlate -import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.calcite.rex.{RexCall, RexNode, RexProgram} -import org.apache.calcite.sql.SemiJoinType +import org.apache.calcite.rel.{RelCollationTraitDef, RelDistribution, RelFieldCollation, RelNode, RelWriter, SingleRel} +import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexProgram} +import org.apache.calcite.sql.{SemiJoinType, SqlKind} +import org.apache.calcite.util.mapping.{Mapping, MappingType, Mappings} import java.util @@ -94,6 +96,79 @@ class BatchExecCorrelate( .itemIf("condition", condition.orNull, condition.isDefined) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + // Correlate could not provide broadcast distribution + if 
(requiredDistribution.getType == RelDistribution.Type.BROADCAST_DISTRIBUTED) { + return None + } + + def getOutputInputMapping: Mapping = { + val inputFieldCnt = getInput.getRowType.getFieldCount + projectProgram match { + case Some(program) => + val projects = program.getProjectList.map(program.expandLocalRef) + val mapping = Mappings.create(MappingType.INVERSE_FUNCTION, inputFieldCnt, projects.size) + projects.zipWithIndex.foreach { + case (project, index) => + project match { + case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index) + case call: RexCall if call.getKind == SqlKind.AS => + call.getOperands.head match { + case inputRef: RexInputRef => mapping.set(inputRef.getIndex, index) + case _ => // ignore + } + case _ => // ignore + } + } + mapping.inverse() + case _ => + val mapping = Mappings.create(MappingType.FUNCTION, inputFieldCnt, inputFieldCnt) + (0 until inputFieldCnt).foreach { + index => mapping.set(index, index) + } + mapping + } + } + + val mapping = getOutputInputMapping + val appliedDistribution = requiredDistribution.apply(mapping) + // If both distribution and collation can be satisfied, satisfy both. If only distribution + // can be satisfied, only satisfy distribution. There is no possibility to only satisfy + // collation here except for there is no distribution requirement. + if ((!requiredDistribution.isTop) && (appliedDistribution eq FlinkRelDistribution.ANY)) { + return None + } + + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val appliedCollation = TraitUtil.apply(requiredCollation, mapping) + // the required collation can be satisfied if field collations are not empty + // and the direction of each field collation is non-STRICTLY + val canSatisfyCollation = appliedCollation.getFieldCollations.nonEmpty && + !appliedCollation.getFieldCollations.exists { c => + (c.getDirection eq RelFieldCollation.Direction.STRICTLY_ASCENDING) || + (c.getDirection eq RelFieldCollation.Direction.STRICTLY_DESCENDING) + } + // If required traits only contains collation requirements, but collation keys are not columns + // from input, then no need to satisfy required traits. 
+ if ((appliedDistribution eq FlinkRelDistribution.ANY) && !canSatisfyCollation) { + return None + } + + var inputRequiredTraits = getInput.getTraitSet + var providedTraits = getTraitSet + if (!appliedDistribution.isTop) { + inputRequiredTraits = inputRequiredTraits.replace(appliedDistribution) + providedTraits = providedTraits.replace(requiredDistribution) + } + if (canSatisfyCollation) { + inputRequiredTraits = inputRequiredTraits.replace(appliedCollation) + providedTraits = providedTraits.replace(requiredCollation) + } + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(providedTraits, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala index 60481a66edcab0..fc3456010f780f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala @@ -18,9 +18,9 @@ package org.apache.flink.table.plan.nodes.physical.batch -import org.apache.flink.table.api.TableException +import org.apache.flink.table.api.{AggPhaseEnforcer, PlannerConfigOptions, TableException} import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, RelExplainUtil} import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType @@ -84,4 +84,11 @@ abstract class BatchExecGroupAggregateBase( isFinal) } + protected def isEnforceTwoStageAgg: Boolean = { + val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(this) + val aggConfig = tableConfig.getConf.getString( + PlannerConfigOptions.SQL_OPTIMIZER_AGG_PHASE_ENFORCER) + AggPhaseEnforcer.TWO_PHASE.toString.equalsIgnoreCase(aggConfig) + } + } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala index 19bc869f47db61..2827987790cfb9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala @@ -19,19 +19,24 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation -import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, RelExplainUtil} 
-import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.{ImmutableIntList, Util} import java.util +import scala.collection.JavaConversions._ + /** * Batch physical RelNode for (global) hash-based aggregate operator. * @@ -94,6 +99,53 @@ class BatchExecHashAggregate( isGlobal = true)) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case SINGLETON => grouping.length == 0 + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val groupKeysList = ImmutableIntList.of(grouping.indices.toArray: _*) + if (requiredDistribution.requireStrict) { + shuffleKeys == groupKeysList + } else if (Util.startsWith(shuffleKeys, groupKeysList)) { + // If required distribution is not strict, Hash[a] can satisfy Hash[a, b]. + // so return true if shuffleKeys(Hash[a, b]) start with groupKeys(Hash[a]) + true + } else { + // If partialKey is enabled, try to use partial key to satisfy the required distribution + val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(this) + val partialKeyEnabled = tableConfig.getConf.getBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED) + partialKeyEnabled && groupKeysList.containsAll(shuffleKeys) + } + case _ => false + } + if (!canSatisfy) { + return None + } + + val inputRequiredDistribution = requiredDistribution.getType match { + case SINGLETON => requiredDistribution + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val groupKeysList = ImmutableIntList.of(grouping.indices.toArray: _*) + if (requiredDistribution.requireStrict) { + FlinkRelDistribution.hash(grouping, requireStrict = true) + } else if (Util.startsWith(shuffleKeys, groupKeysList)) { + // Hash[a] can satisfy Hash[a, b] + FlinkRelDistribution.hash(grouping, requireStrict = false) + } else { + // use partial key to satisfy the required distribution + FlinkRelDistribution.hash(shuffleKeys.map(grouping(_)).toArray, requireStrict = false) + } + } + + val newInput = RelOptRule.convert(getInput, inputRequiredDistribution) + val newProvidedTraitSet = getTraitSet.replace(requiredDistribution) + Some(copy(newProvidedTraitSet, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior = DamBehavior.FULL_DAM diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala index 579a7065cb371d..2205c0fd76494f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashJoin.scala @@ -21,7 +21,9 @@ import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation 
import org.apache.flink.table.api.{BatchTableEnvironment, TableException} import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} +import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.exec.ExecNode import org.apache.flink.table.plan.util.{FlinkRelMdUtil, JoinUtil} import org.apache.flink.table.runtime.join.HashJoinType @@ -127,6 +129,37 @@ class BatchExecHashJoin( } } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + if (!isBroadcast) { + satisfyTraitsOnNonBroadcastHashJoin(requiredTraitSet) + } else { + satisfyTraitsOnBroadcastJoin(requiredTraitSet, leftIsBuild) + } + } + + private def satisfyTraitsOnNonBroadcastHashJoin( + requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val (canSatisfyDistribution, leftRequiredDistribution, rightRequiredDistribution) = + satisfyHashDistributionOnNonBroadcastJoin(requiredDistribution) + if (!canSatisfyDistribution) { + return None + } + + val toRestrictHashDistributionByKeys = (distribution: FlinkRelDistribution) => + getCluster.getPlanner + .emptyTraitSet + .replace(FlinkConventions.BATCH_PHYSICAL) + .replace(distribution) + val leftRequiredTraits = toRestrictHashDistributionByKeys(leftRequiredDistribution) + val rightRequiredTraits = toRestrictHashDistributionByKeys(rightRequiredDistribution) + val newLeft = RelOptRule.convert(getLeft, leftRequiredTraits) + val newRight = RelOptRule.convert(getRight, rightRequiredTraits) + val providedTraits = getTraitSet.replace(requiredDistribution) + // HashJoin can not satisfy collation. 
+ Some(copy(providedTraits, Seq(newLeft, newRight))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala index f22a64accb6cbc..ff7a9c3467bce4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecJoinBase.scala @@ -22,13 +22,19 @@ import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, FunctionCodeGenerator} import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.generated.GeneratedJoinCondition +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.nodes.common.CommonPhysicalJoin import org.apache.flink.table.plan.nodes.exec.BatchExecNode -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.RelNode +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, RANGE_DISTRIBUTED} import org.apache.calcite.rel.core.{Join, JoinRelType} +import org.apache.calcite.rel.{RelCollations, RelNode} import org.apache.calcite.rex.RexNode +import org.apache.calcite.util.ImmutableIntList + +import scala.collection.JavaConversions._ +import scala.collection.mutable /** * Batch physical RelNode for [[Join]] @@ -70,4 +76,150 @@ abstract class BatchExecJoinBase( "JoinConditionFunction", body) } + + /** + * Try to satisfy hash distribution on Non-BroadcastJoin (including SortMergeJoin and + * Non-Broadcast HashJoin). + * + * @param requiredDistribution distribution requirement + * @return a tuple of 3 elements. + * The first element is a flag which indicates whether the requirement can be satisfied. + * The second element is the distribution requirement of the left child if the requirement + * can be pushed down into the join. + * The third element is the distribution requirement of the right child if the requirement + * can be pushed down into the join. + */ + def satisfyHashDistributionOnNonBroadcastJoin( + requiredDistribution: FlinkRelDistribution + ): (Boolean, FlinkRelDistribution, FlinkRelDistribution) = { + // Only hash distribution can be satisfied by a non-broadcast join + if (requiredDistribution.getType != HASH_DISTRIBUTED) { + return (false, null, null) + } + // Full outer join cannot provide hash distribution because it will generate nulls for the left/right + // side if there is no matching row. 
+ if (joinType == JoinRelType.FULL) { + return (false, null, null) + } + + val leftKeys = joinInfo.leftKeys + val rightKeys = joinInfo.rightKeys + val leftKeysToRightKeys = leftKeys.zip(rightKeys).toMap + val rightKeysToLeftKeys = rightKeys.zip(leftKeys).toMap + val leftFieldCnt = getLeft.getRowType.getFieldCount + val requiredShuffleKeys = requiredDistribution.getKeys + val requiredLeftShuffleKeys = mutable.ArrayBuffer[Int]() + val requiredRightShuffleKeys = mutable.ArrayBuffer[Int]() + requiredShuffleKeys.foreach { key => + if (key < leftFieldCnt && joinType != JoinRelType.RIGHT) { + leftKeysToRightKeys.get(key) match { + case Some(rk) => + requiredLeftShuffleKeys += key + requiredRightShuffleKeys += rk + case None if requiredDistribution.requireStrict => + // Cannot satisfy the required hash distribution because the required distribution is strict + // but the key has no corresponding key on the right side + return (false, null, null) + case _ => // do nothing + } + } else if (key >= leftFieldCnt && + (joinType == JoinRelType.RIGHT || joinType == JoinRelType.INNER)) { + val keysOnRightChild = key - leftFieldCnt + rightKeysToLeftKeys.get(keysOnRightChild) match { + case Some(lk) => + requiredLeftShuffleKeys += lk + requiredRightShuffleKeys += keysOnRightChild + case None if requiredDistribution.requireStrict => + // Cannot satisfy the required hash distribution because the required distribution is strict + // but the key has no corresponding key on the left side + return (false, null, null) + case _ => // do nothing + } + } else { + // Cannot satisfy the required hash distribution if the required shuffle keys do not come from + // the left side when the join is a LEFT outer join, or from the right side when it is a RIGHT outer join. + return (false, null, null) + } + } + if (requiredLeftShuffleKeys.isEmpty) { + // The join cannot satisfy the required hash distribution + // because the derived input shuffle keys are empty + return (false, null, null) + } + + val (leftShuffleKeys, rightShuffleKeys) = if (joinType == JoinRelType.INNER && + !requiredDistribution.requireStrict) { + (requiredLeftShuffleKeys.distinct, requiredRightShuffleKeys.distinct) + } else { + (requiredLeftShuffleKeys, requiredRightShuffleKeys) + } + (true, + FlinkRelDistribution.hash(ImmutableIntList.of(leftShuffleKeys: _*), requireStrict = true), + FlinkRelDistribution.hash(ImmutableIntList.of(rightShuffleKeys: _*), requireStrict = true)) + } + + /** + * Try to satisfy the given required traits on BroadcastJoin (including Broadcast-HashJoin and + * NestedLoopJoin). + * + * @param requiredTraitSet required traits + * @return An equivalent Join which satisfies the required traitSet, or None if the + requirement cannot be satisfied. 
+ */ + protected def satisfyTraitsOnBroadcastJoin( + requiredTraitSet: RelTraitSet, + leftIsBroadcast: Boolean): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val keys = requiredDistribution.getKeys + val left = getLeft + val right = getRight + val leftFieldCnt = left.getRowType.getFieldCount + val canSatisfy = requiredDistribution.getType match { + case HASH_DISTRIBUTED | RANGE_DISTRIBUTED => + // required distribution can be satisfied only if distribution keys all from + // non-broadcast side of BroadcastJoin + if (leftIsBroadcast) { + // all distribution keys must come from right child + keys.forall(_ >= leftFieldCnt) + } else { + // all distribution keys must come from left child + keys.forall(_ < leftFieldCnt) + } + // SINGLETON, BROADCAST_DISTRIBUTED, ANY, RANDOM_DISTRIBUTED, ROUND_ROBIN_DISTRIBUTED + // distribution cannot be pushed down. + case _ => false + } + if (!canSatisfy) { + return None + } + + val keysInProbeSide = if (leftIsBroadcast) { + ImmutableIntList.of(keys.map(_ - leftFieldCnt): _ *) + } else { + keys + } + + val inputRequiredDistribution = requiredDistribution.getType match { + case HASH_DISTRIBUTED => + FlinkRelDistribution.hash(keysInProbeSide, requiredDistribution.requireStrict) + case RANGE_DISTRIBUTED => + FlinkRelDistribution.range(keysInProbeSide) + } + // remove collation traits from input traits and provided traits + val (newLeft, newRight) = if (leftIsBroadcast) { + val rightRequiredTraitSet = right.getTraitSet + .replace(inputRequiredDistribution) + .replace(RelCollations.EMPTY) + val newRight = RelOptRule.convert(right, rightRequiredTraitSet) + (left, newRight) + } else { + val leftRequiredTraitSet = left.getTraitSet + .replace(inputRequiredDistribution) + .replace(RelCollations.EMPTY) + val newLeft = RelOptRule.convert(left, leftRequiredTraitSet) + (newLeft, right) + } + val providedTraitSet = requiredTraitSet.replace(RelCollations.EMPTY) + Some(copy(providedTraitSet, Seq(newLeft, newRight))) + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala index 9bf440ce201ad7..b475f4653cbb75 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala @@ -19,21 +19,24 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation -import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig} +import org.apache.flink.table.api.TableConfig import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.plan.nodes.exec.ExecNode +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.util.RelExplainUtil -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type import org.apache.calcite.rel.`type`.RelDataType import 
org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.ImmutableIntList import java.util import scala.collection.JavaConversions._ +import scala.collection.mutable /** * Batch physical RelNode for local hash-based aggregate operator. @@ -92,6 +95,37 @@ class BatchExecLocalHashAggregate( isGlobal = false)) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + // Does not to try to satisfy requirement by localAgg's input if enforce to use two-stage agg. + if (isEnforceTwoStageAgg) { + return None + } + + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case Type.HASH_DISTRIBUTED | Type.RANGE_DISTRIBUTED => + val groupCount = grouping.length + // Cannot satisfy distribution if keys are not group keys of agg + requiredDistribution.getKeys.forall(_ < groupCount) + case _ => false + } + if (!canSatisfy) { + return None + } + + val keys = requiredDistribution.getKeys.map(grouping(_)) + val inputRequiredDistributionKeys = ImmutableIntList.of(keys: _*) + val inputRequiredDistribution = requiredDistribution.getType match { + case Type.HASH_DISTRIBUTED => + FlinkRelDistribution.hash(inputRequiredDistributionKeys, requiredDistribution.requireStrict) + case Type.RANGE_DISTRIBUTED => FlinkRelDistribution.range(inputRequiredDistributionKeys) + } + val inputRequiredTraits = input.getTraitSet.replace(inputRequiredDistribution) + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + val providedTraits = getTraitSet.replace(requiredDistribution) + Some(copy(providedTraits, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala index bac96fd18efcb4..e62c850aee9399 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala @@ -22,16 +22,22 @@ import org.apache.flink.streaming.api.transformations.StreamTransformation import org.apache.flink.table.api.TableConfig import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, RelExplainUtil} -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.ImmutableIntList import java.util +import scala.collection.JavaConversions._ +import scala.collection.mutable + /** * Batch physical RelNode for local sort-based aggregate operator. 
* @@ -90,6 +96,44 @@ class BatchExecLocalSortAggregate( isGlobal = false)) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + // Does not to try to satisfy requirement by localAgg's input if enforce to use two-stage agg. + if (isEnforceTwoStageAgg) { + return None + } + val groupCount = grouping.length + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case Type.HASH_DISTRIBUTED | Type.RANGE_DISTRIBUTED => + // Cannot satisfy distribution if keys are not group keys of agg + requiredDistribution.getKeys.forall(_ < groupCount) + case _ => false + } + if (!canSatisfy) { + return None + } + + val keys = requiredDistribution.getKeys.map(grouping(_)) + val inputRequiredDistributionKeys = ImmutableIntList.of(keys: _*) + val inputRequiredDistribution = requiredDistribution.getType match { + case Type.HASH_DISTRIBUTED => + FlinkRelDistribution.hash(inputRequiredDistributionKeys, requiredDistribution.requireStrict) + case Type.RANGE_DISTRIBUTED => + FlinkRelDistribution.range(inputRequiredDistributionKeys) + } + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val providedFieldCollations = (0 until groupCount).map(FlinkRelOptUtil.ofRelFieldCollation) + val providedCollation = RelCollations.of(providedFieldCollations) + val newProvidedTraits = if (providedCollation.satisfies(requiredCollation)) { + getTraitSet.replace(requiredDistribution).replace(requiredCollation) + } else { + getTraitSet.replace(requiredDistribution) + } + val inputRequiredTraits = getInput.getTraitSet.replace(inputRequiredDistribution) + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(newProvidedTraits, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala index a9919f8805ace7..1043c6506cf24e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecNestedLoopJoin.scala @@ -82,6 +82,7 @@ class BatchExecNestedLoopJoin( if (leftRowCnt == null || rightRowCnt == null) { return null } + val buildRel = if (leftIsBuild) getLeft else getRight val buildRows = mq.getRowCount(buildRel) val buildRowSize = mq.getAverageRowSize(buildRel) @@ -104,6 +105,11 @@ class BatchExecNestedLoopJoin( } } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + // Assume NestedLoopJoin always broadcast data from child which smaller. 
+ satisfyTraitsOnBroadcastJoin(requiredTraitSet, leftIsBuild) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala index 14f50a8af253ef..1c07de73425b34 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecOverAggregate.scala @@ -22,7 +22,7 @@ import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} import org.apache.flink.table.CalcitePair import org.apache.flink.table.`type`.InternalTypes -import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableConfigOptions} +import org.apache.flink.table.api.{BatchTableEnvironment, PlannerConfigOptions, TableConfig, TableConfigOptions} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGeneratorContext import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator @@ -31,17 +31,19 @@ import org.apache.flink.table.codegen.sort.ComparatorCodeGenerator import org.apache.flink.table.dataformat.{BaseRow, BinaryRow} import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.generated.GeneratedRecordComparator +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.nodes.physical.batch.OverWindowMode.OverWindowMode import org.apache.flink.table.plan.util.AggregateUtil.transformToBatchAggregateInfoList import org.apache.flink.table.plan.util.OverAggregateUtil.getLongBoundary -import org.apache.flink.table.plan.util.{OverAggregateUtil, RelExplainUtil} +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, OverAggregateUtil, RelExplainUtil} import org.apache.flink.table.runtime.over.frame.OffsetOverFrame.CalcOffsetFunc import org.apache.flink.table.runtime.over.frame.{InsensitiveOverFrame, OffsetOverFrame, OverWindowFrame, RangeSlidingOverFrame, RangeUnboundedFollowingOverFrame, RangeUnboundedPrecedingOverFrame, RowSlidingOverFrame, RowUnboundedFollowingOverFrame, RowUnboundedPrecedingOverFrame, UnboundedOverWindowFrame} import org.apache.flink.table.runtime.over.{BufferDataOverWindowOperator, NonBufferOverWindowOperator} import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelDistribution.Type._ import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Window.Group @@ -51,6 +53,7 @@ import org.apache.calcite.rex.RexWindowBound import org.apache.calcite.sql.SqlKind import org.apache.calcite.sql.fun.SqlLeadLagAggFunction import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.ImmutableIntList import java.util @@ -179,7 +182,7 @@ class BatchExecOverAggregate( yield new CalcitePair[AggregateCall, String](aggregateCalls.get(i), "windowAgg$" + i) } - private[flink] def 
splitOutOffsetOrInsensitiveGroup() + private def splitOutOffsetOrInsensitiveGroup() : Seq[(OverWindowMode, Window.Group, Seq[(AggregateCall, UserDefinedFunction)])] = { def compareTo(o1: Window.RexWinAggCall, o2: Window.RexWinAggCall): Boolean = { @@ -242,6 +245,102 @@ class BatchExecOverAggregate( windowGroupInfo } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + if (requiredDistribution.getType == ANY && requiredCollation.getFieldCollations.isEmpty) { + return None + } + + val selfProvidedTraitSet = inferProvidedTraitSet() + if (selfProvidedTraitSet.satisfies(requiredTraitSet)) { + // The current node can satisfy the requiredTraitSet, return the current node with the provided traitSet + return Some(copy(selfProvidedTraitSet, Seq(getInput))) + } + + val inputFieldCnt = getInput.getRowType.getFieldCount + val canSatisfy = if (requiredDistribution.getType == ANY) { + true + } else { + if (!grouping.isEmpty) { + if (requiredDistribution.requireStrict) { + requiredDistribution.getKeys == ImmutableIntList.of(grouping: _*) + } else { + val isAllFieldsFromInput = requiredDistribution.getKeys.forall(_ < inputFieldCnt) + if (isAllFieldsFromInput) { + val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(this) + if (tableConfig.getConf.getBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED)) { + ImmutableIntList.of(grouping: _*).containsAll(requiredDistribution.getKeys) + } else { + requiredDistribution.getKeys == ImmutableIntList.of(grouping: _*) + } + } else { + // If the required distribution keys do not all come directly from the input, + // the required distribution and collations cannot be satisfied. + false + } + } + } else { + requiredDistribution.getType == SINGLETON + } + } + // If OverAggregate can provide distribution, but its traits cannot satisfy the required + // distribution, do not push down the distribution and collation requirements (because a later + // shuffle would destroy the previous collation). + if (!canSatisfy) { + return None + } + + var inputRequiredTraits = getInput.getTraitSet + var providedTraits = selfProvidedTraitSet + val providedCollation = selfProvidedTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + if (!requiredDistribution.isTop) { + inputRequiredTraits = inputRequiredTraits.replace(requiredDistribution) + providedTraits = providedTraits.replace(requiredDistribution) + } + + if (providedCollation.satisfies(requiredCollation)) { + // the providedCollation can satisfy the requirement, + // so don't push down the sort into its input. + } else if (providedCollation.getFieldCollations.isEmpty && + requiredCollation.getFieldCollations.nonEmpty) { + // If OverAgg cannot provide collation itself, try to push down the collation requirements into + // its input if the collation fields all come from the input node. + val canPushDownCollation = requiredCollation.getFieldCollations + .forall(_.getFieldIndex < inputFieldCnt) + if (canPushDownCollation) { + inputRequiredTraits = inputRequiredTraits.replace(requiredCollation) + providedTraits = providedTraits.replace(requiredCollation) + } + } else { + // Don't push down the sort into its input, + // because this node's provided collation would destroy the collation provided by the input. 
+ } + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(providedTraits, Seq(newInput))) + } + + private def inferProvidedTraitSet(): RelTraitSet = { + var selfProvidedTraitSet = getTraitSet + // provided distribution + val providedDistribution = if (grouping.nonEmpty) { + FlinkRelDistribution.hash(grouping.map(Integer.valueOf).toList, requireStrict = false) + } else { + FlinkRelDistribution.SINGLETON + } + selfProvidedTraitSet = selfProvidedTraitSet.replace(providedDistribution) + // provided collation + val firstGroup = windowGroupToAggCallToAggFunction.head._1 + if (OverAggregateUtil.needCollationTrait(logicWindow, firstGroup)) { + val collation = OverAggregateUtil.createCollation(firstGroup) + if (!collation.equals(RelCollations.EMPTY)) { + selfProvidedTraitSet = selfProvidedTraitSet.replace(collation) + } + } + selfProvidedTraitSet + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala index 1363a4707ee67a..ce030e70c7cb48 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecRank.scala @@ -20,22 +20,25 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} -import org.apache.flink.table.api.{BatchTableEnvironment, TableException} +import org.apache.flink.table.api.{BatchTableEnvironment, PlannerConfigOptions, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.sort.ComparatorCodeGenerator import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} import org.apache.flink.table.plan.nodes.calcite.Rank import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} -import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, RelExplainUtil} import org.apache.flink.table.runtime.rank.{ConstantRankRange, RankRange, RankType} import org.apache.flink.table.runtime.sort.RankOperator import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelDistribution.Type +import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataTypeField import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.util.ImmutableBitSet +import org.apache.calcite.util.{ImmutableBitSet, ImmutableIntList, Util} import java.util @@ -112,6 +115,122 @@ class BatchExecRank( costFactory.makeCost(rowCount, cpuCost, 0, 0, memCost) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + if (isGlobal) { + satisfyTraitsOnGlobalRank(requiredTraitSet) + } else { + satisfyTraitsOnLocalRank(requiredTraitSet) + } + } + + private def satisfyTraitsOnGlobalRank(requiredTraitSet: RelTraitSet): 
Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case SINGLETON => partitionKey.cardinality() == 0 + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val partitionKeyList = ImmutableIntList.of(partitionKey.toArray: _*) + if (requiredDistribution.requireStrict) { + shuffleKeys == partitionKeyList + } else if (Util.startsWith(shuffleKeys, partitionKeyList)) { + // If required distribution is not strict, Hash[a] can satisfy Hash[a, b]. + // so return true if shuffleKeys(Hash[a, b]) start with partitionKeyList(Hash[a]) + true + } else { + // If partialKey is enabled, try to use partial key to satisfy the required distribution + val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(this) + val partialKeyEnabled = tableConfig.getConf.getBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED) + partialKeyEnabled && partitionKeyList.containsAll(shuffleKeys) + } + case _ => false + } + if (!canSatisfy) { + return None + } + + val inputRequiredDistribution = requiredDistribution.getType match { + case SINGLETON => requiredDistribution + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val partitionKeyList = ImmutableIntList.of(partitionKey.toArray: _*) + if (requiredDistribution.requireStrict) { + FlinkRelDistribution.hash(partitionKeyList) + } else if (Util.startsWith(shuffleKeys, partitionKeyList)) { + // Hash[a] can satisfy Hash[a, b] + FlinkRelDistribution.hash(partitionKeyList, requireStrict = false) + } else { + // use partial key to satisfy the required distribution + FlinkRelDistribution.hash(shuffleKeys.map(partitionKeyList(_)), requireStrict = false) + } + } + + // sort by partition keys + orderby keys + val providedFieldCollations = partitionKey.toArray.map { + k => FlinkRelOptUtil.ofRelFieldCollation(k) + }.toList ++ orderKey.getFieldCollations + val providedCollation = RelCollations.of(providedFieldCollations) + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val newProvidedTraitSet = if (providedCollation.satisfies(requiredCollation)) { + getTraitSet.replace(requiredDistribution).replace(requiredCollation) + } else { + getTraitSet.replace(requiredDistribution) + } + val newInput = RelOptRule.convert(getInput, inputRequiredDistribution) + Some(copy(newProvidedTraitSet, Seq(newInput))) + } + + private def satisfyTraitsOnLocalRank(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + requiredDistribution.getType match { + case Type.SINGLETON => + val inputRequiredDistribution = requiredDistribution + // sort by orderby keys + val providedCollation = orderKey + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val newProvidedTraitSet = if (providedCollation.satisfies(requiredCollation)) { + getTraitSet.replace(requiredDistribution).replace(requiredCollation) + } else { + getTraitSet.replace(requiredDistribution) + } + + val inputRequiredTraits = getInput.getTraitSet.replace(inputRequiredDistribution) + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(newProvidedTraitSet, Seq(newInput))) + case Type.HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + if (outputRankNumber) { + // rank function column is the last one + val rankColumnIndex = getRowType.getFieldCount - 1 + if 
(!shuffleKeys.contains(rankColumnIndex)) { + // Cannot satisfy required distribution if some keys are not from input + return None + } + } + + val inputRequiredDistributionKeys = shuffleKeys + val inputRequiredDistribution = FlinkRelDistribution.hash( + inputRequiredDistributionKeys, requiredDistribution.requireStrict) + + // sort by partition keys + orderby keys + val providedFieldCollations = partitionKey.toArray.map { + k => FlinkRelOptUtil.ofRelFieldCollation(k) + }.toList ++ orderKey.getFieldCollations + val providedCollation = RelCollations.of(providedFieldCollations) + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val newProvidedTraitSet = if (providedCollation.satisfies(requiredCollation)) { + getTraitSet.replace(requiredDistribution).replace(requiredCollation) + } else { + getTraitSet.replace(requiredDistribution) + } + + val inputRequiredTraits = getInput.getTraitSet.replace(inputRequiredDistribution) + val newInput = RelOptRule.convert(getInput, inputRequiredTraits) + Some(copy(newProvidedTraitSet, Seq(newInput))) + case _ => None + } + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala index 208ad9ec1e574e..1fded572bf77b2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortAggregate.scala @@ -19,19 +19,24 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation -import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.plan.util.RelExplainUtil +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} +import org.apache.flink.table.plan.util.{FlinkRelOptUtil, RelExplainUtil} -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.tools.RelBuilder +import org.apache.calcite.util.{ImmutableIntList, Util} import java.util +import scala.collection.JavaConversions._ + /** * Batch physical RelNode for (global) sort-based aggregate operator. 
* @@ -95,6 +100,64 @@ class BatchExecSortAggregate( isGlobal = true)) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case SINGLETON => grouping.length == 0 + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val groupKeysList = ImmutableIntList.of(grouping.indices.toArray: _*) + if (requiredDistribution.requireStrict) { + shuffleKeys == groupKeysList + } else if (Util.startsWith(shuffleKeys, groupKeysList)) { + // If required distribution is not strict, Hash[a] can satisfy Hash[a, b]. + // so return true if shuffleKeys(Hash[a, b]) start with groupKeys(Hash[a]) + true + } else { + // If partialKey is enabled, try to use partial key to satisfy the required distribution + val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(this) + val partialKeyEnabled = tableConfig.getConf.getBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED) + partialKeyEnabled && groupKeysList.containsAll(shuffleKeys) + } + case _ => false + } + if (!canSatisfy) { + return None + } + + val inputRequiredDistribution = requiredDistribution.getType match { + case SINGLETON => requiredDistribution + case HASH_DISTRIBUTED => + val shuffleKeys = requiredDistribution.getKeys + val groupKeysList = ImmutableIntList.of(grouping.indices.toArray: _*) + if (requiredDistribution.requireStrict) { + FlinkRelDistribution.hash(grouping, requireStrict = true) + } else if (Util.startsWith(shuffleKeys, groupKeysList)) { + // Hash [a] can satisfy Hash[a, b] + FlinkRelDistribution.hash(grouping, requireStrict = false) + } else { + // use partial key to satisfy the required distribution + FlinkRelDistribution.hash(shuffleKeys.map(grouping(_)).toArray, requireStrict = false) + } + } + + val providedCollation = if (grouping.length == 0) { + RelCollations.EMPTY + } else { + val providedFieldCollations = grouping.map(FlinkRelOptUtil.ofRelFieldCollation).toList + RelCollations.of(providedFieldCollations) + } + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val newProvidedTraitSet = if (providedCollation.satisfies(requiredCollation)) { + getTraitSet.replace(requiredDistribution).replace(requiredCollation) + } else { + getTraitSet.replace(requiredDistribution) + } + val newInput = RelOptRule.convert(getInput, inputRequiredDistribution) + Some(copy(newProvidedTraitSet, Seq(newInput))) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala index fd24c169156a95..42d248b0e507cf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala @@ -26,16 +26,17 @@ import org.apache.flink.table.codegen.CodeGeneratorContext import org.apache.flink.table.codegen.ProjectionCodeGenerator.generateProjection import org.apache.flink.table.codegen.sort.SortCodeGenerator import org.apache.flink.table.dataformat.BaseRow +import 
org.apache.flink.table.plan.`trait`.FlinkRelDistributionTraitDef import org.apache.flink.table.plan.cost.{FlinkCost, FlinkCostFactory} import org.apache.flink.table.plan.nodes.ExpressionFormat import org.apache.flink.table.plan.nodes.exec.ExecNode -import org.apache.flink.table.plan.util.{FlinkRelMdUtil, JoinUtil, SortUtil} +import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil, JoinUtil, SortUtil} import org.apache.flink.table.runtime.join.{FlinkJoinType, SortMergeJoinOperator} import org.apache.calcite.plan._ import org.apache.calcite.rel.core._ import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.rel.{RelNode, RelWriter} +import org.apache.calcite.rel.{RelCollationTraitDef, RelNode, RelWriter} import org.apache.calcite.rex.RexNode import java.util @@ -142,6 +143,55 @@ class BatchExecSortMergeJoin( costFactory.makeCost(rowCount, cpuCost, 0, 0, sortMemCost) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val (canSatisfyDistribution, leftRequiredDistribution, rightRequiredDistribution) = + satisfyHashDistributionOnNonBroadcastJoin(requiredDistribution) + if (!canSatisfyDistribution) { + return None + } + + val requiredCollation = requiredTraitSet.getTrait(RelCollationTraitDef.INSTANCE) + val requiredFieldCollations = requiredCollation.getFieldCollations + val shuffleKeysSize = leftRequiredDistribution.getKeys.size + + val newLeft = RelOptRule.convert(getLeft, leftRequiredDistribution) + val newRight = RelOptRule.convert(getRight, rightRequiredDistribution) + + // SortMergeJoin can provide collation trait, check whether provided collation can satisfy + // required collations + val canProvideCollation = if (requiredCollation.getFieldCollations.isEmpty) { + false + } else if (requiredFieldCollations.size > shuffleKeysSize) { + // Sort by [a, b] can satisfy [a], but cannot satisfy [a, b, c] + false + } else { + val leftKeys = leftRequiredDistribution.getKeys + val leftFieldCnt = getLeft.getRowType.getFieldCount + val rightKeys = rightRequiredDistribution.getKeys.map(_ + leftFieldCnt) + requiredFieldCollations.zipWithIndex.forall { case (collation, index) => + val idxOfCollation = collation.getFieldIndex + // Full outer join is handled before, so does not need care about it + if (idxOfCollation < leftFieldCnt && joinType != JoinRelType.RIGHT) { + val fieldCollationOnLeftSortKey = FlinkRelOptUtil.ofRelFieldCollation(leftKeys.get(index)) + collation == fieldCollationOnLeftSortKey + } else if (idxOfCollation >= leftFieldCnt && + (joinType == JoinRelType.RIGHT || joinType == JoinRelType.INNER)) { + val fieldCollationOnRightSortKey = + FlinkRelOptUtil.ofRelFieldCollation(rightKeys.get(index)) + collation == fieldCollationOnRightSortKey + } else { + false + } + } + } + var newProvidedTraitSet = getTraitSet.replace(requiredDistribution) + if (canProvideCollation) { + newProvidedTraitSet = newProvidedTraitSet.replace(requiredCollation) + } + Some(copy(newProvidedTraitSet, Seq(newLeft, newRight))) + } + //~ ExecNode methods ----------------------------------------------------------- /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecUnion.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecUnion.scala index 9919d0bb7a0122..b0f6aa03fe4970 100644 --- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecUnion.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecUnion.scala @@ -21,9 +21,11 @@ import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.{StreamTransformation, UnionTransformation} import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} +import org.apache.calcite.rel.RelDistribution.Type.{ANY, BROADCAST_DISTRIBUTED, HASH_DISTRIBUTED, RANDOM_DISTRIBUTED, RANGE_DISTRIBUTED, ROUND_ROBIN_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{SetOp, Union} import org.apache.calcite.rel.{RelNode, RelWriter} @@ -58,6 +60,41 @@ class BatchExecUnion( .item("union", getRowType.getFieldNames.mkString(", ")) } + override def satisfyTraits(requiredTraitSet: RelTraitSet): Option[RelNode] = { + // union will destroy collation trait. So does not handle collation requirement. + val requiredDistribution = requiredTraitSet.getTrait(FlinkRelDistributionTraitDef.INSTANCE) + val canSatisfy = requiredDistribution.getType match { + case RANDOM_DISTRIBUTED | + ROUND_ROBIN_DISTRIBUTED | + BROADCAST_DISTRIBUTED | + HASH_DISTRIBUTED => true + // range distribution cannot be satisfied because partition's [lower, upper] of each union + // child may be different. + case RANGE_DISTRIBUTED => false + // singleton cannot cannot be satisfied because singleton exchange limits the parallelism of + // exchange output RelNode to 1. + // Push down Singleton into input of union will destroy the limitation. 
+ case SINGLETON => false + // there is no need to satisfy Any distribution + case ANY => false + } + if (!canSatisfy) { + return None + } + + val inputRequiredDistribution = requiredDistribution.getType match { + case RANDOM_DISTRIBUTED | ROUND_ROBIN_DISTRIBUTED | BROADCAST_DISTRIBUTED => + requiredDistribution + case HASH_DISTRIBUTED => + // apply strict hash distribution of each child + // to avoid inconsistent of shuffle of each child + FlinkRelDistribution.hash(requiredDistribution.getKeys) + } + val newInputs = getInputs.map(RelOptRule.convert(_, inputRequiredDistribution)) + val providedTraitSet = getTraitSet.replace(inputRequiredDistribution) + Some(copy(providedTraitSet, newInputs)) + } + //~ ExecNode methods ----------------------------------------------------------- override def getDamBehavior: DamBehavior = DamBehavior.PIPELINED diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala index 2474dba8ecadda..d7e45b30794b5b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala @@ -192,7 +192,8 @@ object FlinkBatchRuleSets { val WINDOW_RULES: RuleSet = RuleSets.ofList( // slices a project into sections which contain window agg functions and sections which do not. ProjectToWindowRule.PROJECT, - // TODO add ExchangeWindowGroupRule + //adjust the sequence of window's groups. + WindowGroupReorderRule.INSTANCE, // Transform window to LogicalWindowAggregate WindowPropertiesRules.WINDOW_PROPERTIES_RULE, WindowPropertiesRules.WINDOW_PROPERTIES_HAVING_RULE @@ -311,28 +312,44 @@ object FlinkBatchRuleSets { */ val PHYSICAL_OPT_RULES: RuleSet = RuleSets.ofList( FlinkExpandConversionRule.BATCH_INSTANCE, + // source BatchExecBoundedStreamScanRule.INSTANCE, BatchExecScanTableSourceRule.INSTANCE, BatchExecIntermediateTableScanRule.INSTANCE, BatchExecValuesRule.INSTANCE, + // calc BatchExecCalcRule.INSTANCE, + // union BatchExecUnionRule.INSTANCE, + // sort BatchExecSortRule.INSTANCE, BatchExecLimitRule.INSTANCE, BatchExecSortLimitRule.INSTANCE, + // rank BatchExecRankRule.INSTANCE, + RemoveRedundantLocalRankRule.INSTANCE, + // expand BatchExecExpandRule.INSTANCE, + // group agg BatchExecHashAggRule.INSTANCE, BatchExecSortAggRule.INSTANCE, + RemoveRedundantLocalSortAggRule.WITHOUT_SORT, + RemoveRedundantLocalSortAggRule.WITH_SORT, + RemoveRedundantLocalHashAggRule.INSTANCE, + // over agg + BatchExecOverAggregateRule.INSTANCE, + // window agg + BatchExecWindowAggregateRule.INSTANCE, + // join BatchExecHashJoinRule.INSTANCE, BatchExecSortMergeJoinRule.INSTANCE, BatchExecNestedLoopJoinRule.INSTANCE, BatchExecSingleRowJoinRule.INSTANCE, - BatchExecCorrelateRule.INSTANCE, - BatchExecOverAggregateRule.INSTANCE, - BatchExecWindowAggregateRule.INSTANCE, BatchExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN, BatchExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN, + // correlate + BatchExecCorrelateRule.INSTANCE, + // sink BatchExecSinkRule.INSTANCE ) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala index c030929f069c6d..ce99a988d42088 100644 --- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala @@ -293,27 +293,39 @@ object FlinkStreamRuleSets { */ val PHYSICAL_OPT_RULES: RuleSet = RuleSets.ofList( FlinkExpandConversionRule.STREAM_INSTANCE, + // source StreamExecDataStreamScanRule.INSTANCE, StreamExecTableSourceScanRule.INSTANCE, StreamExecIntermediateTableScanRule.INSTANCE, StreamExecValuesRule.INSTANCE, + // calc StreamExecCalcRule.INSTANCE, + // union StreamExecUnionRule.INSTANCE, + // sort StreamExecSortRule.INSTANCE, StreamExecLimitRule.INSTANCE, StreamExecSortLimitRule.INSTANCE, - StreamExecRankRule.INSTANCE, StreamExecTemporalSortRule.INSTANCE, + // rank + StreamExecRankRule.INSTANCE, StreamExecDeduplicateRule.RANK_INSTANCE, + // expand + StreamExecExpandRule.INSTANCE, + // group agg StreamExecGroupAggregateRule.INSTANCE, + // over agg StreamExecOverAggregateRule.INSTANCE, + // window agg StreamExecGroupWindowAggregateRule.INSTANCE, - StreamExecExpandRule.INSTANCE, + // join StreamExecJoinRule.INSTANCE, StreamExecWindowJoinRule.INSTANCE, - StreamExecCorrelateRule.INSTANCE, StreamExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN, StreamExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN, + // correlate + StreamExecCorrelateRule.INSTANCE, + // sink StreamExecSinkRule.INSTANCE ) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRule.scala new file mode 100644 index 00000000000000..c263371e27cf05 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRule.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.Window.Group +import org.apache.calcite.rel.logical.{LogicalProject, LogicalWindow} +import org.apache.calcite.rel.{RelCollation, RelNode} +import org.apache.calcite.rex.RexInputRef + +import java.util +import java.util.Comparator + +import scala.collection.JavaConversions._ + +/** + * Planner rule that makes the over window groups which have the same shuffle keys and order keys + * together. 
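In other words, the rule sorts the window's groups by (partition keys, order keys) so that groups sharing the same keys become adjacent, which lets later physical planning typically reuse one sort/exchange for neighbouring groups. A simplified stand-in illustration follows; WinGroup is a hypothetical type, not Calcite's Window.Group, and the real comparator works on bit sets and RelCollations.

// Simplified illustration of the reordering idea only.
case class WinGroup(partitionKeys: Seq[Int], orderKeys: Seq[Int])

object WindowGroupReorderSketch {
  // sort groups so that identical (partition, order) keys end up adjacent
  def reorder(groups: Seq[WinGroup]): Seq[WinGroup] =
    groups.sortBy(g => (g.partitionKeys.mkString(","), g.orderKeys.mkString(",")))

  def main(args: Array[String]): Unit = {
    val groups = Seq(
      WinGroup(Seq(1), Seq(0)), // PARTITION BY b ORDER BY a
      WinGroup(Seq(2), Seq(0)), // PARTITION BY c ORDER BY a
      WinGroup(Seq(1), Seq(2))) // PARTITION BY b ORDER BY c
    // after reordering, both groups partitioned on field 1 are adjacent,
    // so a single shuffle/sort on field 1 can serve them back to back
    reorder(groups).foreach(println)
  }
}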
+ */ +class WindowGroupReorderRule extends RelOptRule( + operand(classOf[LogicalWindow], + operand(classOf[RelNode], any)), + "ExchangeWindowGroupRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val window: LogicalWindow = call.rel(0) + window.groups.size() > 1 + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val window: LogicalWindow = call.rel(0) + val input: RelNode = call.rel(1) + val oldGroups: util.List[Group] = new util.ArrayList(window.groups) + val sequenceGroups: util.List[Group] = new util.ArrayList(window.groups) + + sequenceGroups.sort(new Comparator[Group] { + override def compare(o1: Group, o2: Group): Int = { + val keyComp = o1.keys.compareTo(o2.keys) + if (keyComp == 0) { + compareRelCollation(o1.orderKeys, o2.orderKeys) + } else { + keyComp + } + } + }) + + if (!sequenceGroups.equals(oldGroups) && !sequenceGroups.reverse.equals(oldGroups)) { + var offset = input.getRowType.getFieldCount + val aggTypeIndexes = oldGroups.map { group => + val aggCount = group.aggCalls.size() + val typeIndexes = (0 until aggCount).map(_ + offset).toArray + offset += aggCount + typeIndexes + } + + offset = input.getRowType.getFieldCount + val mapToOldTypeIndexes = (0 until offset).toArray ++ + sequenceGroups.flatMap { newGroup => + val aggCount = newGroup.aggCalls.size() + val oldIndex = oldGroups.indexOf(newGroup) + offset += aggCount + (0 until aggCount).map { + aggIndex => aggTypeIndexes(oldIndex)(aggIndex) + } + }.toArray[Int] + + val oldRowTypeFields = window.getRowType.getFieldList + val newFieldList = new util.ArrayList[util.Map.Entry[String, RelDataType]] + mapToOldTypeIndexes.foreach { index => + newFieldList.add(oldRowTypeFields.get(index)) + } + val intermediateRowType = window.getCluster.getTypeFactory.createStructType(newFieldList) + val newLogicalWindow = LogicalWindow.create( + window.getCluster.getPlanner.emptyTraitSet(), + input, + window.constants, + intermediateRowType, + sequenceGroups) + + val mapToNewTypeIndexes = mapToOldTypeIndexes.zipWithIndex.sortBy(_._1) + + val projects = mapToNewTypeIndexes.map { index => + new RexInputRef(index._2, newFieldList.get(index._2).getValue) + } + val project = LogicalProject.create( + newLogicalWindow, + projects.toList, + window.getRowType) + call.transformTo(project) + } + } + + private def compareRelCollation(o1: RelCollation, o2: RelCollation): Int = { + val comp = o1.compareTo(o2) + if (comp == 0) { + val collations1 = o1.getFieldCollations + val collations2 = o2.getFieldCollations + for (index <- 0 until collations1.length) { + val collation1 = collations1(index) + val collation2 = collations2(index) + val direction = collation1.direction.shortString.compareTo(collation2.direction.shortString) + if (direction == 0) { + val nullDirection = collation1.nullDirection.nullComparison.compare( + collation2.nullDirection.nullComparison) + if (nullDirection != 0) { + return nullDirection + } + } else { + return direction + } + } + } + comp + } +} + +object WindowGroupReorderRule { + val INSTANCE = new WindowGroupReorderRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/FlinkExpandConversionRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/FlinkExpandConversionRule.scala index 9b7d61a89f132a..97f3332ecd0ff1 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/FlinkExpandConversionRule.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/FlinkExpandConversionRule.scala @@ -81,14 +81,16 @@ class FlinkExpandConversionRule(flinkConvention: Convention) call: RelOptRuleCall): Unit = { node match { case batchRel: BatchPhysicalRel => - var otherChoice = batchRel.satisfyTraitsByInput(requiredTraits) - if (otherChoice != null) { - // It is possible only push down distribution instead of push down both distribution and - // collation. So it is necessary to check whether collation satisfy requirement. - val requiredCollation = requiredTraits.getTrait(RelCollationTraitDef.INSTANCE) - otherChoice = satisfyCollation(flinkConvention, otherChoice, requiredCollation) - checkSatisfyRequiredTrait(otherChoice, requiredTraits) - call.transformTo(otherChoice) + val otherChoice = batchRel.satisfyTraits(requiredTraits) + otherChoice match { + case Some(newRel) => + // It is possible only push down distribution instead of push down both distribution and + // collation. So it is necessary to check whether collation satisfy requirement. + val requiredCollation = requiredTraits.getTrait(RelCollationTraitDef.INSTANCE) + val finalRel = satisfyCollation(flinkConvention, newRel, requiredCollation) + checkSatisfyRequiredTrait(finalRel, requiredTraits) + call.transformTo(finalRel) + case _ => // do nothing } case _ => // ignore } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala new file mode 100644 index 00000000000000..379e5b4a302ca7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchExecLocalHashAggregate} + +import org.apache.calcite.plan.RelOptRule._ +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.RelNode + +/** + * There maybe exist a subTree like localHashAggregate -> globalHashAggregate which the middle + * shuffle is removed. The rule could remove redundant localHashAggregate node. 
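Conceptually: once the shuffle between the local and the global hash aggregate has been removed, the local aggregate no longer contributes anything, and the global aggregate can consume the input directly as a single-phase (isMerge = false) aggregate. Below is a toy model of that rewrite using hypothetical plan types, not the planner's RelNodes; the actual rule follows.

// Toy model of the rewrite: a global hash aggregate sitting directly on a
// local hash aggregate collapses into a single-phase aggregate over the input.
sealed trait Plan
case class Scan(table: String) extends Plan
case class LocalHashAgg(grouping: Seq[Int], input: Plan) extends Plan
case class GlobalHashAgg(grouping: Seq[Int], isMerge: Boolean, input: Plan) extends Plan

object RemoveRedundantLocalHashAggSketch {
  def rewrite(plan: Plan): Plan = plan match {
    case GlobalHashAgg(_, true, LocalHashAgg(localGrouping, input)) =>
      // keep the local aggregate's grouping, drop the merge phase
      GlobalHashAgg(localGrouping, isMerge = false, input)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val before = GlobalHashAgg(Seq(0), isMerge = true, LocalHashAgg(Seq(0), Scan("x")))
    println(rewrite(before)) // GlobalHashAgg(List(0),false,Scan(x))
  }
}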
+ */ +class RemoveRedundantLocalHashAggRule extends RelOptRule( + operand(classOf[BatchExecHashAggregate], + operand(classOf[BatchExecLocalHashAggregate], + operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), + "RemoveRedundantLocalHashAggRule") { + + override def onMatch(call: RelOptRuleCall): Unit = { + val globalAgg = call.rels(0).asInstanceOf[BatchExecHashAggregate] + val localAgg = call.rels(1).asInstanceOf[BatchExecLocalHashAggregate] + val inputOfLocalAgg = localAgg.getInput + val newGlobalAgg = new BatchExecHashAggregate( + globalAgg.getCluster, + call.builder(), + globalAgg.getTraitSet, + inputOfLocalAgg, + globalAgg.getRowType, + inputOfLocalAgg.getRowType, + inputOfLocalAgg.getRowType, + localAgg.getGrouping, + localAgg.getAuxGrouping, + globalAgg.getAggCallToAggFunction, + isMerge = false) + call.transformTo(newGlobalAgg) + } +} + +object RemoveRedundantLocalHashAggRule { + val INSTANCE = new RemoveRedundantLocalHashAggRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRule.scala new file mode 100644 index 00000000000000..304710db9e7764 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRule.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.physical.batch.BatchExecRank + +import org.apache.calcite.plan.RelOptRule._ +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.RelNode + +import scala.collection.JavaConversions._ + +/** + * Planner rule that matches a global [[BatchExecRank]] on a local [[BatchExecRank]], + * and merge them into a global [[BatchExecRank]]. 
+ */ +class RemoveRedundantLocalRankRule extends RelOptRule( + operand(classOf[BatchExecRank], + operand(classOf[BatchExecRank], + operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), + "RemoveRedundantLocalRankRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val globalRank: BatchExecRank = call.rel(0) + val localRank: BatchExecRank = call.rel(1) + globalRank.isGlobal && !localRank.isGlobal && + globalRank.rankType == localRank.rankType && + globalRank.partitionKey == localRank.partitionKey && + globalRank.orderKey == globalRank.orderKey && + globalRank.rankEnd == localRank.rankEnd + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val globalRank: BatchExecRank = call.rel(0) + val inputOfLocalRank: RelNode = call.rel(2) + val newGlobalRank = globalRank.copy(globalRank.getTraitSet, List(inputOfLocalRank)) + call.transformTo(newGlobalRank) + } +} + +object RemoveRedundantLocalRankRule { + val INSTANCE: RelOptRule = new RemoveRedundantLocalRankRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala new file mode 100644 index 00000000000000..4f719ebd8896ae --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchExecSort, BatchExecSortAggregate} + +import org.apache.calcite.plan.RelOptRule._ +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand} +import org.apache.calcite.rel.RelNode + +/** + * There maybe exist a subTree like localSortAggregate -> globalSortAggregate, or + * localSortAggregate -> sort -> globalSortAggregate which the middle shuffle is removed. + * The rule could remove redundant localSortAggregate node. 
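As with the hash-aggregate case above, both variants build the new single-phase global aggregate directly on the local aggregate's input; they differ only in the operand pattern (with or without a BatchExecSort between the two aggregates), which is why the local aggregate and its input are read from different call.rels indices in the two subclasses below. A toy sketch with hypothetical plan types, not the planner's RelNodes:

// Toy model of the two patterns handled by this rule.
sealed trait Plan
case class Scan(table: String) extends Plan
case class Sort(keys: Seq[Int], input: Plan) extends Plan
case class LocalSortAgg(grouping: Seq[Int], input: Plan) extends Plan
case class GlobalSortAgg(grouping: Seq[Int], isMerge: Boolean, input: Plan) extends Plan

object RemoveRedundantLocalSortAggSketch {
  def rewrite(plan: Plan): Plan = plan match {
    // variant 1: localSortAggregate -> globalSortAggregate
    case GlobalSortAgg(_, true, LocalSortAgg(g, input)) =>
      GlobalSortAgg(g, isMerge = false, input)
    // variant 2: localSortAggregate -> sort -> globalSortAggregate
    case GlobalSortAgg(_, true, Sort(_, LocalSortAgg(g, input))) =>
      GlobalSortAgg(g, isMerge = false, input)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val withSort = GlobalSortAgg(Seq(0), isMerge = true,
      Sort(Seq(0), LocalSortAgg(Seq(0), Scan("x"))))
    println(rewrite(withSort)) // GlobalSortAgg(List(0),false,Scan(x))
  }
}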
+ */ +abstract class RemoveRedundantLocalSortAggRule( + operand: RelOptRuleOperand, + ruleName: String) extends RelOptRule(operand, ruleName) { + + override def onMatch(call: RelOptRuleCall): Unit = { + val globalAgg = getOriginalGlobalAgg(call) + val localAgg = getOriginalLocalAgg(call) + val inputOfLocalAgg = getOriginalInputOfLocalAgg(call) + val newGlobalAgg = new BatchExecSortAggregate( + globalAgg.getCluster, + call.builder(), + globalAgg.getTraitSet, + inputOfLocalAgg, + globalAgg.getRowType, + inputOfLocalAgg.getRowType, + inputOfLocalAgg.getRowType, + localAgg.getGrouping, + localAgg.getAuxGrouping, + globalAgg.getAggCallToAggFunction, + isMerge = false) + call.transformTo(newGlobalAgg) + } + + private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate + + private[table] def getOriginalLocalAgg(call: RelOptRuleCall): BatchExecLocalSortAggregate + + private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode + +} + +class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSortAggRule( + operand(classOf[BatchExecSortAggregate], + operand(classOf[BatchExecLocalSortAggregate], + operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), + "RemoveRedundantLocalSortAggWithoutSortRule") { + + override private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate = { + call.rels(0).asInstanceOf[BatchExecSortAggregate] + } + + override private[table] def getOriginalLocalAgg( + call: RelOptRuleCall): BatchExecLocalSortAggregate = { + call.rels(1).asInstanceOf[BatchExecLocalSortAggregate] + } + + override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = { + call.rels(2) + } + +} + +class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAggRule( + operand(classOf[BatchExecSortAggregate], + operand(classOf[BatchExecSort], + operand(classOf[BatchExecLocalSortAggregate], + operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any)))), + "RemoveRedundantLocalSortAggWithSortRule") { + + override private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate = { + call.rels(0).asInstanceOf[BatchExecSortAggregate] + } + + override private[table] def getOriginalLocalAgg( + call: RelOptRuleCall): BatchExecLocalSortAggregate = { + call.rels(2).asInstanceOf[BatchExecLocalSortAggregate] + } + + override private[table] def getOriginalInputOfLocalAgg(call: RelOptRuleCall): RelNode = { + call.rels(3) + } + +} + +object RemoveRedundantLocalSortAggRule { + val WITHOUT_SORT = new RemoveRedundantLocalSortAggWithoutSortRule + val WITH_SORT = new RemoveRedundantLocalSortAggWithSortRule +} diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml index a58970bfd39d20..911bbd2f41c24e 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DagOptimizationTest.xml @@ -726,15 +726,14 @@ Sink(fields=[a1, b, c1]) :- Exchange(distribution=[hash[a3]], exchange_mode=[BATCH]) : +- Calc(select=[a AS a3, c AS c1], where=[AND(>=(a, 0), <(b, 5))]) : +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c], reuse_id=[1]) - +- 
Exchange(distribution=[hash[a1]]) - +- Calc(select=[a AS a1, b]) - +- HashJoin(joinType=[InnerJoin], where=[=(a, a2)], select=[a, b, a2], build=[right]) - :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) - : +- Calc(select=[a, b], where=[<=(a, 10)]) - : +- Reused(reference_id=[1]) - +- Exchange(distribution=[hash[a2]]) - +- Calc(select=[a AS a2], where=[AND(>=(a, 0), >=(b, 5))]) - +- Reused(reference_id=[1]) + +- Calc(select=[a AS a1, b]) + +- HashJoin(joinType=[InnerJoin], where=[=(a, a2)], select=[a, b, a2], build=[right]) + :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) + : +- Calc(select=[a, b], where=[<=(a, 10)]) + : +- Reused(reference_id=[1]) + +- Exchange(distribution=[hash[a2]]) + +- Calc(select=[a AS a2], where=[AND(>=(a, 0), >=(b, 5))]) + +- Reused(reference_id=[1]) ]]> @@ -784,21 +783,19 @@ Sink(fields=[a, b, c]) :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) : +- Calc(select=[a], where=[<=(a, 10)]) : +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c], reuse_id=[1]) - +- Exchange(distribution=[hash[a3]]) - +- Calc(select=[a3, b AS b1, c1]) - +- HashJoin(joinType=[InnerJoin], where=[=(a1, a3)], select=[a3, c1, a1, b], build=[right]) - :- Exchange(distribution=[hash[a3]], exchange_mode=[BATCH]) - : +- Calc(select=[a AS a3, c AS c1], where=[AND(>=(a, 0), <(b, 5))]) - : +- Reused(reference_id=[1]) - +- Exchange(distribution=[hash[a1]]) - +- Calc(select=[a AS a1, b]) - +- HashJoin(joinType=[InnerJoin], where=[=(a, a2)], select=[a, b, a2], build=[right]) - :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) - : +- Calc(select=[a, b], where=[<=(a, 10)]) - : +- Reused(reference_id=[1]) - +- Exchange(distribution=[hash[a2]]) - +- Calc(select=[a AS a2], where=[AND(>=(a, 0), >=(b, 5))]) - +- Reused(reference_id=[1]) + +- Calc(select=[a3, b AS b1, c1]) + +- HashJoin(joinType=[InnerJoin], where=[=(a1, a3)], select=[a3, c1, a1, b], build=[right]) + :- Exchange(distribution=[hash[a3]], exchange_mode=[BATCH]) + : +- Calc(select=[a AS a3, c AS c1], where=[AND(>=(a, 0), <(b, 5))]) + : +- Reused(reference_id=[1]) + +- Calc(select=[a AS a1, b]) + +- HashJoin(joinType=[InnerJoin], where=[=(a, a2)], select=[a, b, a2], build=[right]) + :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) + : +- Calc(select=[a, b], where=[<=(a, 10)]) + : +- Reused(reference_id=[1]) + +- Exchange(distribution=[hash[a2]]) + +- Calc(select=[a AS a2], where=[AND(>=(a, 0), >=(b, 5))]) + +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml index f035f65cbbc552..69ec9d6436a93f 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/DeadlockBreakupTest.xml @@ -89,22 +89,21 @@ LogicalProject(a=[$0], b=[$1], c=[$2], a0=[$3], b0=[$4], c0=[$5], a1=[$6], b1=[$ HashJoin(joinType=[InnerJoin], where=[=(c, c1)], select=[a, b, c, a0, b0, c0, a1, b1, c1, a00, b00, c00], build=[right]) :- Exchange(distribution=[hash[c]], exchange_mode=[BATCH]) : +- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, c, a0, b0, c0], build=[left]) -: :- Exchange(distribution=[hash[a]]) -: : +- Calc(select=[a, b, c], where=[>(b, 10)]) -: : +- 
SortAggregate(isMerge=[true], groupBy=[a], select=[a, Final_SUM(sum$0) AS b, Final_MAX(max$1) AS c], reuse_id=[1]) -: : +- Sort(orderBy=[a ASC]) -: : +- Exchange(distribution=[hash[a]]) -: : +- LocalSortAggregate(groupBy=[a], select=[a, Partial_SUM(b) AS sum$0, Partial_MAX(c) AS max$1]) -: : +- Sort(orderBy=[a ASC]) -: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH], reuse_id=[2]) +: :- Calc(select=[a, b, c], where=[>(b, 10)]) +: : +- SortAggregate(isMerge=[true], groupBy=[a], select=[a, Final_SUM(sum$0) AS b, Final_MAX(max$1) AS c], reuse_id=[1]) +: : +- Sort(orderBy=[a ASC]) +: : +- Exchange(distribution=[hash[a]]) +: : +- LocalSortAggregate(groupBy=[a], select=[a, Partial_SUM(b) AS sum$0, Partial_MAX(c) AS max$1]) +: : +- Sort(orderBy=[a ASC]) +: : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +: +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) : +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[c]]) +- HashJoin(joinType=[InnerJoin], where=[=(a, a0)], select=[a, b, c, a0, b0, c0], build=[left]) - :- Exchange(distribution=[hash[a]]) - : +- Calc(select=[a, b, c], where=[<(b, 5)]) - : +- Reused(reference_id=[1]) - +- Reused(reference_id=[2]) + :- Calc(select=[a, b, c], where=[<(b, 5)]) + : +- Reused(reference_id=[1]) + +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) + +- Reused(reference_id=[1]) ]]> @@ -141,17 +140,16 @@ Calc(select=[a, b]) :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) : +- Calc(select=[a], where=[=(b, 5:BIGINT)]) : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c], reuse_id=[1]) - +- Exchange(distribution=[hash[a]]) - +- HashJoin(joinType=[LeftSemiJoin], where=[=(a, a0)], select=[a, b], build=[left]) - :- Exchange(distribution=[hash[a]]) - : +- Calc(select=[a, b]) - : +- Limit(offset=[0], fetch=[10], global=[true]) - : +- Exchange(distribution=[single]) - : +- Limit(offset=[0], fetch=[10], global=[false]) - : +- Reused(reference_id=[1]) - +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) - +- Calc(select=[a], where=[>(b, 5)]) - +- Reused(reference_id=[1]) + +- HashJoin(joinType=[LeftSemiJoin], where=[=(a, a0)], select=[a, b], build=[left]) + :- Exchange(distribution=[hash[a]]) + : +- Calc(select=[a, b]) + : +- Limit(offset=[0], fetch=[10], global=[true]) + : +- Exchange(distribution=[single]) + : +- Limit(offset=[0], fetch=[10], global=[false]) + : +- Reused(reference_id=[1]) + +- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) + +- Calc(select=[a], where=[>(b, 5)]) + +- Reused(reference_id=[1]) ]]> @@ -298,13 +296,11 @@ HashJoin(joinType=[InnerJoin], where=[=(c, c0)], select=[a, b, c, a0, b0, c0], b :- Exchange(distribution=[hash[c]], exchange_mode=[BATCH]) : +- Calc(select=[w0$o0 AS a, b, c]) : +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MAX($2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[b, c, $2, w0$o0]) -: +- Sort(orderBy=[b ASC], reuse_id=[1]) -: +- Exchange(distribution=[hash[b]]) -: +- Calc(select=[b, c, CASE(>(w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS $2]) -: +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1]) -: +- Sort(orderBy=[b ASC]) -: +- Exchange(distribution=[hash[b]]) -: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, 
b, c]) +: +- Calc(select=[b, c, CASE(>(w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS $2], reuse_id=[1]) +: +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1]) +: +- Sort(orderBy=[b ASC]) +: +- Exchange(distribution=[hash[b]]) +: +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[w0$o0 AS a, b, c]) +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MIN($2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[b, c, $2, w0$o0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.xml new file mode 100644 index 00000000000000..283177ea21f527 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.xml @@ -0,0 +1,708 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + (COUNT($1) OVER (PARTITION BY $0 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($1) OVER (PARTITION BY $0 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:BIGINT)):DOUBLE, COUNT($1) OVER (PARTITION BY $0 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], rn=[RANK() OVER (PARTITION BY $0 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalAggregate(group=[{0}], sum_b=[SUM($1)]) + +- LogicalProject(a=[$0], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w0$o2 AS rn]) ++- OverAggregate(partitionBy=[a], orderBy=[a ASC], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1, RANK(*) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, sum_b, w0$o0, w0$o1, w0$o2]) + +- SortAggregate(isMerge=[true], groupBy=[a], select=[a, Final_SUM(sum$0) AS sum_b]) + +- Sort(orderBy=[a ASC]) + +- Exchange(distribution=[hash[a]]) + +- LocalSortAggregate(groupBy=[a], select=[a, Partial_SUM(b) AS sum$0]) + +- Sort(orderBy=[a ASC]) + +- Calc(select=[a, b]) + +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.xml new file mode 100644 index 00000000000000..e23d8a044f1fab --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.xml @@ -0,0 +1,1264 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + (COUNT($2) OVER (PARTITION BY $0, $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 0), $SUM0($2) OVER (PARTITION BY $0, $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), null:BIGINT)):DOUBLE, COUNT($2) OVER (PARTITION BY $0, $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))], rn=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST, $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], c=[$1]) ++- LogicalAggregate(group=[{0, 1}], sum_b=[SUM($2)]) + +- LogicalProject(a=[$0], c=[$2], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c]) ++- OverAggregate(partitionBy=[c], orderBy=[a ASC, c ASC], window#0=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, sum_b, w0$o0, w0$o1, w1$o0]) + +- Sort(orderBy=[c ASC, a ASC]) + +- Exchange(distribution=[hash[c]]) + +- OverAggregate(partitionBy=[a, c], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[a, c, sum_b, w0$o0, w0$o1]) + +- Sort(orderBy=[a ASC, c ASC]) + +- HashAggregate(isMerge=[true], groupBy=[a, c], select=[a, c, Final_SUM(sum$0) AS sum_b]) + +- Exchange(distribution=[hash[a, c]]) + +- LocalHashAggregate(groupBy=[a, c], select=[a, c, Partial_SUM(b) AS sum$0]) + +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + + + + (COUNT($1) OVER (PARTITION BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 0), $SUM0($1) OVER (PARTITION BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), null:BIGINT)):DOUBLE, COUNT($1) OVER (PARTITION BY $0 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))], rn=[RANK() OVER (PARTITION BY $0 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], c=[$0]) ++- LogicalAggregate(group=[{0}], sum_b=[SUM($1)]) + +- LogicalProject(c=[$2], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c]) ++- OverAggregate(partitionBy=[c], orderBy=[], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], window#1=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, sum_b, w0$o0, w0$o1, w1$o0]) + +- Sort(orderBy=[c ASC]) + +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_SUM(sum$0) AS sum_b]) + +- Exchange(distribution=[hash[c]]) + +- LocalHashAggregate(groupBy=[c], select=[c, Partial_SUM(b) AS sum$0]) + +- Calc(select=[c, b]) + +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + (COUNT($2) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 0), $SUM0($2) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), null:BIGINT)):DOUBLE, COUNT($2) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))], rn=[RANK() OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], c=[$1]) ++- LogicalAggregate(group=[{0, 1}], sum_b=[SUM($2)]) + +- LogicalProject(a=[$0], c=[$2], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) 
+]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c]) ++- OverAggregate(partitionBy=[c], orderBy=[], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], window#1=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, sum_b, w0$o0, w0$o1, w1$o0]) + +- Calc(select=[c, sum_b]) + +- Sort(orderBy=[c ASC]) + +- HashAggregate(isMerge=[true], groupBy=[a, c], select=[a, c, Final_SUM(sum$0) AS sum_b]) + +- Exchange(distribution=[hash[c]]) + +- LocalHashAggregate(groupBy=[a, c], select=[a, c, Partial_SUM(b) AS sum$0]) + +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + (SELECT sum(b) * 0.1 FROM x)]]> + + + ($2, $SCALAR_QUERY({ +LogicalProject(EXPR$0=[*($0, 0.1:DECIMAL(2, 1))]) + LogicalAggregate(group=[{}], agg#0=[SUM($0)]) + LogicalProject(b=[$1]) + LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +}))]) + +- LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)], agg#1=[SUM($1)]) + +- LogicalProject(c=[$2], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($f2, $f0)], select=[EXPR$0, $f2, $f0], build=[right], singleRowJoin=[true]) + :- Calc(select=[EXPR$0, $f2]) + : +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_AVG(sum$0, count$1) AS EXPR$0, Final_SUM(sum$2) AS $f2]) + : +- Exchange(distribution=[hash[c]]) + : +- LocalHashAggregate(groupBy=[c], select=[c, Partial_AVG(b) AS (sum$0, count$1), Partial_SUM(b) AS sum$2]) + : +- Calc(select=[c, b]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Exchange(distribution=[broadcast]) + +- SortAggregate(isMerge=[false], select=[SINGLE_VALUE(EXPR$0) AS $f0]) + +- Calc(select=[*($f0, 0.1:DECIMAL(2, 1)) AS EXPR$0]) + +- SortAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) + +- Exchange(distribution=[single]) + +- LocalSortAggregate(select=[Partial_SUM(b) AS sum$0]) + +- Calc(select=[b]) + +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + 10 group by c +UNION ALL +SELECT count(d) as cnt, f FROM y WHERE e < 100 group by f) +SELECT r1.c, r1.cnt, r2.c, r2.cnt FROM r r1, r r2 WHERE r1.c = r2.c and r1.cnt < 10 + ]]> + + + ($1, 10)]) + : : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + : +- LogicalProject(cnt=[$1], f=[$0]) + : +- LogicalAggregate(group=[{0}], cnt=[COUNT($1)]) + : +- LogicalProject(f=[$2], d=[$0]) + : +- LogicalFilter(condition=[<($1, 100)]) + : +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) + +- LogicalUnion(all=[true]) + :- LogicalProject(cnt=[$1], c=[$0]) + : +- LogicalAggregate(group=[{0}], cnt=[COUNT($1)]) + : +- LogicalProject(c=[$2], a=[$0]) + : +- LogicalFilter(condition=[>($1, 10)]) + : +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + +- LogicalProject(cnt=[$1], f=[$0]) + +- LogicalAggregate(group=[{0}], cnt=[COUNT($1)]) + +- LogicalProject(f=[$2], d=[$0]) + +- LogicalFilter(condition=[<($1, 100)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + (b, 10)]) + : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, 
c)]]], fields=[a, b, c]) + : +- Calc(select=[cnt, f], where=[<(cnt, 10)]) + : +- HashAggregate(isMerge=[true], groupBy=[f], select=[f, Final_COUNT(count$0) AS cnt]) + : +- Exchange(distribution=[hash[f]]) + : +- LocalHashAggregate(groupBy=[f], select=[f, Partial_COUNT(d) AS count$0]) + : +- Calc(select=[f, d], where=[<(e, 100)]) + : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Union(all=[true], union=[cnt, c]) + :- Calc(select=[cnt, c]) + : +- HashAggregate(isMerge=[true], groupBy=[c], select=[c, Final_COUNT(count$0) AS cnt]) + : +- Exchange(distribution=[hash[c]]) + : +- LocalHashAggregate(groupBy=[c], select=[c, Partial_COUNT(a) AS count$0]) + : +- Calc(select=[c, a], where=[>(b, 10)]) + : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- Calc(select=[cnt, f]) + +- HashAggregate(isMerge=[true], groupBy=[f], select=[f, Final_COUNT(count$0) AS cnt]) + +- Exchange(distribution=[hash[f]]) + +- LocalHashAggregate(groupBy=[f], select=[f, Partial_COUNT(d) AS count$0]) + +- Calc(select=[f, d], where=[<(e, 100)]) + +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) +]]> + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml index 302bc021ffe8cd..a1cc277f0b20d6 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml @@ -240,28 +240,26 @@ LogicalProject(c=[$0], e=[$1], avg_b=[$2], sum_b=[$3], psum=[$4], nsum=[$5], avg (w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c, e], where=[AND(<>(c, _UTF-16LE'':VARCHAR(65536) CHARACTER SET "UTF-16LE"), >(-(sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0)), 3))]) - : : +- OverAggregate(partitionBy=[c, e], orderBy=[], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], window#1=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, e, sum_b, w0$o0, w0$o1, w1$o0], reuse_id=[1]) - : : +- Sort(orderBy=[c ASC, e ASC], reuse_id=[2]) - : : +- Exchange(distribution=[hash[c, e]]) - : : +- HashAggregate(isMerge=[true], groupBy=[c, e], select=[c, e, Final_SUM(sum$0) AS sum_b]) - : : +- Exchange(distribution=[hash[c, e]]) - : : +- LocalHashAggregate(groupBy=[c, e], select=[c, e, Partial_SUM(b) AS sum$0]) - : : +- Calc(select=[c, e, b]) - : : +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e], build=[right]) - : : :- Exchange(distribution=[hash[a]]) - : : : +- Calc(select=[a, b, c], where=[IS NOT NULL(c)]) - : : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - : : +- Exchange(distribution=[hash[d]]) - : : +- Calc(select=[d, e], where=[>(e, 10)]) - : : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - : +- Exchange(distribution=[hash[c, e, $f5]], exchange_mode=[BATCH]) - : +- Calc(select=[sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, c, e, +(w1$o0, 1) AS $f5]) - : +- Reused(reference_id=[1]) + :- Calc(select=[sum_b, avg_b, rn, c, e, sum_b0, avg_b0]) + : +- HashJoin(joinType=[InnerJoin], where=[AND(=(c, c0), 
=(e, e0), =(rn, $f5))], select=[sum_b, avg_b, rn, c, e, sum_b0, avg_b0, c0, e0, $f5], build=[left]) + : :- Exchange(distribution=[hash[c, e, rn]]) + : : +- Calc(select=[sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, w1$o0 AS rn, c, e], where=[AND(<>(c, _UTF-16LE'':VARCHAR(65536) CHARACTER SET "UTF-16LE"), >(-(sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0)), 3))]) + : : +- OverAggregate(partitionBy=[c, e], orderBy=[], window#0=[COUNT(sum_b) AS w0$o0, $SUM0(sum_b) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], window#1=[RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, e, sum_b, w0$o0, w0$o1, w1$o0], reuse_id=[1]) + : : +- Sort(orderBy=[c ASC, e ASC], reuse_id=[2]) + : : +- HashAggregate(isMerge=[true], groupBy=[c, e], select=[c, e, Final_SUM(sum$0) AS sum_b]) + : : +- Exchange(distribution=[hash[c, e]]) + : : +- LocalHashAggregate(groupBy=[c, e], select=[c, e, Partial_SUM(b) AS sum$0]) + : : +- Calc(select=[c, e, b]) + : : +- HashJoin(joinType=[InnerJoin], where=[=(a, d)], select=[a, b, c, d, e], build=[right]) + : : :- Exchange(distribution=[hash[a]]) + : : : +- Calc(select=[a, b, c], where=[IS NOT NULL(c)]) + : : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : : +- Exchange(distribution=[hash[d]]) + : : +- Calc(select=[d, e], where=[>(e, 10)]) + : : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + : +- Exchange(distribution=[hash[c, e, $f5]], exchange_mode=[BATCH]) + : +- Calc(select=[sum_b, /(CAST(CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT)), w0$o0) AS avg_b, c, e, +(w1$o0, 1) AS $f5]) + : +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[c, e, $f5]], exchange_mode=[BATCH]) +- Calc(select=[sum_b, c, e, -(w0$o0, 1) AS $f5]) +- OverAggregate(partitionBy=[c, e], orderBy=[c ASC, e ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[c, e, sum_b, w0$o0]) @@ -651,8 +649,7 @@ HashJoin(joinType=[InnerJoin], where=[=(a, d0)], select=[a, b, c, d, e, f, a0, b : : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : +- Exchange(distribution=[hash[d]]) : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) -+- Exchange(distribution=[hash[d]]) - +- Reused(reference_id=[1]) ++- Reused(reference_id=[1]) ]]> @@ -1061,16 +1058,14 @@ LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], e=[$4], f=[$5], a0=[$6], b0=[$7], diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/GroupingSetsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/GroupingSetsTest.xml index 07fc423e4ca3d3..7113b095454ab6 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/GroupingSetsTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/GroupingSetsTest.xml @@ -279,18 +279,16 @@ LogicalProject(deptno=[$0], gender=[$2], min_name=[$3]) ($f5, 2), AND(=(gender, _UTF-16LE'M':VARCHAR(65536) CHARACTER SET "UTF-16LE"), =(deptno, 10)))]) -+- SortAggregate(isMerge=[true], groupBy=[deptno, gender, deptno0, $e], select=[deptno, gender, deptno0, $e, Final_MIN(min$0) AS min_name, Final_COUNT(count1$1) AS $f5]) ++- SortAggregate(isMerge=[false], groupBy=[deptno, gender, deptno0, $e], select=[deptno, gender, deptno0, $e, 
MIN(ename) AS min_name, COUNT(*) AS $f5]) +- Sort(orderBy=[deptno ASC, gender ASC, deptno0 ASC, $e ASC]) +- Exchange(distribution=[hash[deptno, gender, deptno0, $e]]) - +- LocalSortAggregate(groupBy=[deptno, gender, deptno0, $e], select=[deptno, gender, deptno0, $e, Partial_MIN(ename) AS min$0, Partial_COUNT(*) AS count1$1]) - +- Sort(orderBy=[deptno ASC, gender ASC, deptno0 ASC, $e ASC]) - +- Expand(projects=[{ename=[$0], deptno=[$1], gender=[$2], deptno0=[$3], $e=[0]}, {ename=[$0], deptno=[$1], gender=[$2], deptno0=[null], $e=[1]}, {ename=[$0], deptno=[$1], gender=[null], deptno0=[$3], $e=[2]}, {ename=[$0], deptno=[$1], gender=[null], deptno0=[null], $e=[3]}, {ename=[$0], deptno=[null], gender=[$2], deptno0=[$3], $e=[4]}, {ename=[$0], deptno=[null], gender=[$2], deptno0=[null], $e=[5]}, {ename=[$0], deptno=[null], gender=[null], deptno0=[$3], $e=[6]}, {ename=[$0], deptno=[null], gender=[null], deptno0=[null], $e=[7]}], projects=[{ename, deptno, gender, deptno0, 0 AS $e}, {ename, deptno, gender, null AS deptno0, 1 AS $e}, {ename, deptno, null AS gender, deptno0, 2 AS $e}, {ename, deptno, null AS gender, null AS deptno0, 3 AS $e}, {ename, null AS deptno, gender, deptno0, 4 AS $e}, {ename, null AS deptno, gender, null AS deptno0, 5 AS $e}, {ename, null AS deptno, null AS gender, deptno0, 6 AS $e}, {ename, null AS deptno, null AS gender, null AS deptno0, 7 AS $e}]) - +- HashJoin(joinType=[InnerJoin], where=[=(deptno, deptno0)], select=[ename, deptno, gender, deptno0], build=[right]) - :- Exchange(distribution=[hash[deptno]]) - : +- TableSourceScan(table=[[emp, source: [TestTableSource(ename, deptno, gender)]]], fields=[ename, deptno, gender]) - +- Exchange(distribution=[hash[deptno]]) - +- Calc(select=[deptno]) - +- TableSourceScan(table=[[dept, source: [TestTableSource(deptno, dname)]]], fields=[deptno, dname]) + +- Expand(projects=[{ename=[$0], deptno=[$1], gender=[$2], deptno0=[$3], $e=[0]}, {ename=[$0], deptno=[$1], gender=[$2], deptno0=[null], $e=[1]}, {ename=[$0], deptno=[$1], gender=[null], deptno0=[$3], $e=[2]}, {ename=[$0], deptno=[$1], gender=[null], deptno0=[null], $e=[3]}, {ename=[$0], deptno=[null], gender=[$2], deptno0=[$3], $e=[4]}, {ename=[$0], deptno=[null], gender=[$2], deptno0=[null], $e=[5]}, {ename=[$0], deptno=[null], gender=[null], deptno0=[$3], $e=[6]}, {ename=[$0], deptno=[null], gender=[null], deptno0=[null], $e=[7]}], projects=[{ename, deptno, gender, deptno0, 0 AS $e}, {ename, deptno, gender, null AS deptno0, 1 AS $e}, {ename, deptno, null AS gender, deptno0, 2 AS $e}, {ename, deptno, null AS gender, null AS deptno0, 3 AS $e}, {ename, null AS deptno, gender, deptno0, 4 AS $e}, {ename, null AS deptno, gender, null AS deptno0, 5 AS $e}, {ename, null AS deptno, null AS gender, deptno0, 6 AS $e}, {ename, null AS deptno, null AS gender, null AS deptno0, 7 AS $e}]) + +- HashJoin(joinType=[InnerJoin], where=[=(deptno, deptno0)], select=[ename, deptno, gender, deptno0], build=[right]) + :- Exchange(distribution=[hash[deptno]]) + : +- TableSourceScan(table=[[emp, source: [TestTableSource(ename, deptno, gender)]]], fields=[ename, deptno, gender]) + +- Exchange(distribution=[hash[deptno]]) + +- Calc(select=[deptno]) + +- TableSourceScan(table=[[dept, source: [TestTableSource(deptno, dname)]]], fields=[deptno, dname]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml 
index 5a6196a8ef2055..c10ad212c4798b 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.xml @@ -16,7 +16,40 @@ See the License for the specific language governing permissions and limitations under the License. --> - + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, /(CAST(CASE(>(w2$o0, 0:BIGINT), w2$o1, null:INTEGER)), w2$o0) AS EXPR$2, w0$o2 AS EXPR$3, w2$o2 AS EXPR$4]) ++- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[COUNT(a) AS w2$o0, $SUM0(a) AS w2$o1, MIN(a) AS w2$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2, w1$o0, w2$o0, w2$o1, w2$o2]) + +- Sort(orderBy=[c ASC, a ASC]) + +- Exchange(distribution=[hash[c]]) + +- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[MAX(a) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2, w1$o0]) + +- Sort(orderBy=[b ASC, c ASC]) + +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, RANK(*) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2]) + +- Sort(orderBy=[b ASC, a ASC]) + +- Exchange(distribution=[hash[b]]) + +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, w2$o0 AS EXPR$2, w0$o2 AS EXPR$3, /(CAST(CASE(>(w3$o0, 0:BIGINT), w3$o1, null:INTEGER)), w3$o0) AS EXPR$4]) -+- OverAggregate(orderBy=[b ASC], window#0=[COUNT(a) AS w3$o0, $SUM0(a) AS w3$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2, w1$o0, w2$o0, w3$o0, w3$o1]) - +- Sort(orderBy=[b ASC]) - +- Exchange(distribution=[single]) - +- OverAggregate(orderBy=[c ASC, a ASC], window#0=[MIN(a) AS w2$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2, w1$o0, w2$o0]) - +- Sort(orderBy=[c ASC, a ASC]) - +- Exchange(distribution=[single]) - +- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[MAX(a) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2, w1$o0]) ++- 
OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[MAX(a) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o2, w1$o0, w0$o1, w3$o0, w3$o1, w2$o0, w0$o0]) + +- Sort(orderBy=[c ASC, a ASC]) + +- Exchange(distribution=[hash[c]]) + +- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[COUNT(a) AS w3$o0, $SUM0(a) AS w3$o1, RANK(*) AS w2$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o2, w1$o0, w0$o1, w3$o0, w3$o1, w2$o0]) + +- Sort(orderBy=[b ASC, c ASC]) + +- Exchange(distribution=[hash[b]]) + +- OverAggregate(orderBy=[c ASC, a ASC], window#0=[MIN(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o2, w1$o0, w0$o1]) +- Sort(orderBy=[c ASC, a ASC]) - +- Exchange(distribution=[hash[c]]) - +- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, RANK(*) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w0$o2]) - +- Sort(orderBy=[b ASC, c ASC]) - +- Exchange(distribution=[hash[b]]) - +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + +- OverAggregate(orderBy=[b ASC], window#0=[COUNT(a) AS w0$o2, $SUM0(a) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o2, w1$o0]) + +- Sort(orderBy=[b ASC]) + +- Exchange(distribution=[single]) + +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> @@ -102,14 +134,12 @@ LogicalProject(EXPR$0=[COUNT() OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RAN (w1$o0, 0:BIGINT), w1$o1, null:INTEGER) AS EXPR$1, w2$o0 AS EXPR$2, CASE(>(w3$o0, 0:BIGINT), w3$o1, null:INTEGER) AS EXPR$3, w4$o0 AS EXPR$4]) -+- OverAggregate(partitionBy=[c], orderBy=[c ASC], window#0=[COUNT(*) AS w4$o0 ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING], select=[a, c, w0$o0, w1$o0, w1$o1, w2$o0, w3$o0, w3$o1, w4$o0]) - +- Sort(orderBy=[c ASC]) - +- Exchange(distribution=[hash[c]]) - +- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN -1 PRECEDING AND 10 FOLLOWING], window#1=[COUNT(a) AS w1$o0, $SUM0(a) AS w1$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#2=[RANK(*) AS w2$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#3=[COUNT(a) AS w3$o0, $SUM0(a) AS w3$o1 RANG BETWEEN 1 PRECEDING AND 10 FOLLOWING], select=[a, c, w0$o0, w1$o0, w1$o1, w2$o0, w3$o0, w3$o1]) - +- Sort(orderBy=[c ASC, a ASC]) - +- Exchange(distribution=[hash[c]]) - +- Calc(select=[a, c]) - +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ++- OverAggregate(partitionBy=[c], orderBy=[c ASC], window#0=[COUNT(*) AS w4$o0 ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING], select=[a, c, w0$o0, w1$o0, w1$o1, w3$o1, w2$o0, w3$o0, w4$o0]) + +- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN -1 PRECEDING AND 10 FOLLOWING], window#1=[COUNT(a) AS w1$o0, $SUM0(a) AS w1$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#2=[COUNT(a) AS w3$o1, $SUM0(a) AS w2$o0 RANG BETWEEN 1 PRECEDING AND 10 FOLLOWING], window#3=[RANK(*) AS w3$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, w0$o0, w1$o0, w1$o1, w3$o1, w2$o0, w3$o0]) + +- Sort(orderBy=[c ASC, a ASC]) + +- Exchange(distribution=[hash[c]]) + +- Calc(select=[a, c]) + +- TableSourceScan(table=[[MyTable, source: 
[TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> @@ -185,36 +215,6 @@ Calc(select=[c, w0$o0 AS $1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c]) +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -]]> - - - - - - - - (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MIN($0) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$2=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) -+- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) -]]> - - - (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, w0$o2 AS EXPR$2]) -+- OverAggregate(partitionBy=[b], window#0=[MIN(a) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[a, b, w0$o0, w0$o1, w0$o2, w1$o0]) - +- Sort(orderBy=[b ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, MAX(a) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o0, w0$o1, w0$o2]) - +- Sort(orderBy=[b ASC, a ASC]) - +- Exchange(distribution=[hash[b]]) - +- Calc(select=[a, b]) - +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> @@ -417,33 +417,38 @@ Calc(select=[CASE(>(w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS E ]]> - + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) +LogicalProject(EXPR$0=[CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) ]]> (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS EXPR$1]) -+- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, MAX(a) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o0, w0$o1, w0$o2]) - +- Sort(orderBy=[b ASC, a ASC]) - +- Exchange(distribution=[hash[b]]) - +- Calc(select=[a, b]) - +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +Calc(select=[CASE(>(w0$o0, 
0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, /(CAST(CASE(>(w2$o0, 0:BIGINT), w2$o1, null:INTEGER)), w2$o0) AS EXPR$2, w0$o2 AS EXPR$3, w1$o1 AS EXPR$4]) ++- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[COUNT(a) AS w2$o0, $SUM0(a) AS w2$o1, RANK(*) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o1, w0$o2, w1$o1, w0$o0, w2$o0, w2$o1, w1$o0]) + +- Sort(orderBy=[b ASC, c ASC]) + +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MAX(a) AS w1$o1, MIN(a) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o1, w0$o2, w1$o1, w0$o0]) + +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o1, $SUM0(a) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o1, w0$o2]) + +- Sort(orderBy=[b ASC, a ASC]) + +- Exchange(distribution=[hash[b]]) + +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> - + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, /(CAST(CASE(>(w2$o0, 0:BIGINT), w2$o1, null:INTEGER)), w2$o0) AS EXPR$2, w3$o0 AS EXPR$3, w4$o0 AS EXPR$4]) -+- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MIN(a) AS w4$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w1$o0, w2$o0, w2$o1, w3$o0, w4$o0]) - +- Sort(orderBy=[b ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[a ASC, b ASC], window#0=[RANK(*) AS w3$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w1$o0, w2$o0, w2$o1, w3$o0]) - +- Sort(orderBy=[b ASC, a ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[a ASC, c ASC], window#0=[COUNT(a) AS w2$o0, $SUM0(a) AS w2$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w1$o0, w2$o0, w2$o1]) - +- Sort(orderBy=[b ASC, a ASC, c ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[MAX(a) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1, w1$o0]) - +- Sort(orderBy=[b ASC, a ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0, w0$o1]) - +- Sort(orderBy=[b ASC, c ASC]) - +- Exchange(distribution=[hash[b]]) - +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ++- OverAggregate(partitionBy=[b], orderBy=[c ASC], window#0=[COUNT(a) AS w1$o0, $SUM0(a) AS w3$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w2$o0, w0$o0, w2$o1, w4$o0, w0$o1, w1$o0, w3$o0]) + +- Sort(orderBy=[b ASC, c ASC]) + +- OverAggregate(partitionBy=[b], orderBy=[b ASC], window#0=[MIN(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w2$o0, w0$o0, w2$o1, w4$o0, w0$o1]) + +- OverAggregate(partitionBy=[b], orderBy=[a ASC, c ASC], window#0=[COUNT(a) AS w2$o1, $SUM0(a) AS w4$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w2$o0, w0$o0, w2$o1, w4$o0]) + +- Sort(orderBy=[b ASC, a ASC, c ASC]) + +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[MAX(a) AS w2$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED 
PRECEDING AND CURRENT ROW], select=[a, b, c, w2$o0, w0$o0]) + +- Sort(orderBy=[b ASC, a ASC]) + +- Exchange(distribution=[hash[b]]) + +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> - + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) ]]> (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS EXPR$1]) ++- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, MAX(a) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o0, w0$o1, w0$o2]) + +- Sort(orderBy=[b ASC, a ASC]) +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[c ASC, a DESC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0]) - +- Sort(orderBy=[b ASC, c ASC, a DESC]) - +- Exchange(distribution=[hash[b]]) + +- Calc(select=[a, b]) + +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS EXPR$1, /(CAST(CASE(>(w1$o0, 0:BIGINT), w1$o1, null:INTEGER)), w1$o0) AS EXPR$2, w0$o3 AS EXPR$3, w1$o2 AS EXPR$4]) ++- OverAggregate(partitionBy=[b], orderBy=[a DESC], window#0=[COUNT(a) AS w1$o0, $SUM0(a) AS w1$o1, MIN(a) AS w1$o2 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o0, w0$o1, w0$o2, w0$o3, w1$o0, w1$o1, w1$o2]) + +- Sort(orderBy=[b ASC, a DESC]) + +- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, MAX(a) AS w0$o2, RANK(*) AS w0$o3 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], window#1=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o0, w0$o1, w0$o2, w0$o3]) + +- Sort(orderBy=[b ASC, a ASC]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[a, b]) +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> - + + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 
NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MIN($0) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$2=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + (w0$o0, 0:BIGINT), w0$o1, null:INTEGER) AS EXPR$0, w1$o0 AS EXPR$1, w0$o2 AS EXPR$2]) ++- OverAggregate(partitionBy=[b], orderBy=[a ASC], window#0=[COUNT(a) AS w1$o0, $SUM0(a) AS w0$o0, MAX(a) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, w0$o2, w1$o0, w0$o0, w0$o1]) +- Sort(orderBy=[b ASC, a ASC]) - +- Exchange(distribution=[hash[b]]) - +- OverAggregate(partitionBy=[b], orderBy=[a DESC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, c, w0$o0]) - +- Sort(orderBy=[b ASC, a DESC]) - +- Exchange(distribution=[hash[b]]) + +- OverAggregate(partitionBy=[b], window#0=[MIN(a) AS w0$o2 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[a, b, w0$o2]) + +- Sort(orderBy=[b ASC]) + +- Exchange(distribution=[hash[b]]) + +- Calc(select=[a, b]) +- TableSourceScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml index a2bb48731b8978..39ec889a551d1b 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml @@ -1054,13 +1054,12 @@ Calc(select=[b]) : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) +- Exchange(distribution=[broadcast]) +- SortAggregate(isMerge=[false], select=[SINGLE_VALUE(EXPR$0) AS $f0]) - +- Exchange(distribution=[single]) - +- Calc(select=[*(0.5:DECIMAL(2, 1), $f0) AS EXPR$0]) - +- HashAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) - +- Exchange(distribution=[single]) - +- LocalHashAggregate(select=[Partial_SUM(j) AS sum$0]) - +- Calc(select=[j], where=[<(i, 100)]) - +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) + +- Calc(select=[*(0.5:DECIMAL(2, 1), $f0) AS EXPR$0]) + +- SortAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) + +- Exchange(distribution=[single]) + +- LocalSortAggregate(select=[Partial_SUM(j) AS sum$0]) + +- Calc(select=[j], where=[<(i, 100)]) + +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml index aefff811e83a09..441d2f168b27ba 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopJoinTest.xml @@ -124,8 +124,8 @@ LogicalProject(d=[$0], e=[$1], f=[$2], g=[$3], h=[$4], 
a=[$6], b=[$7], c=[$8]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml index 117d8804f38582..8ad7125a2cf8dc 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml @@ -1000,12 +1000,11 @@ Calc(select=[a]) : :- Exchange(distribution=[hash[e]]) : : +- Calc(select=[d, e]) : : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - : +- Exchange(distribution=[hash[j]]) - : +- HashAggregate(isMerge=[true], groupBy=[j], select=[j]) - : +- Exchange(distribution=[hash[j]]) - : +- LocalHashAggregate(groupBy=[j], select=[j]) - : +- Calc(select=[j]) - : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : +- HashAggregate(isMerge=[true], groupBy=[j], select=[j]) + : +- Exchange(distribution=[hash[j]]) + : +- LocalHashAggregate(groupBy=[j], select=[j]) + : +- Calc(select=[j]) + : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) +- Exchange(distribution=[hash[i, k]]) +- Calc(select=[i, k]) +- Reused(reference_id=[1]) @@ -1330,13 +1329,12 @@ Calc(select=[b]) : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) +- Exchange(distribution=[broadcast]) +- SortAggregate(isMerge=[false], select=[SINGLE_VALUE(EXPR$0) AS $f0]) - +- Exchange(distribution=[single]) - +- Calc(select=[*(0.5:DECIMAL(2, 1), $f0) AS EXPR$0]) - +- HashAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) - +- Exchange(distribution=[single]) - +- LocalHashAggregate(select=[Partial_SUM(j) AS sum$0]) - +- Calc(select=[j], where=[<(i, 100)]) - +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) + +- Calc(select=[*(0.5:DECIMAL(2, 1), $f0) AS EXPR$0]) + +- SortAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) + +- Exchange(distribution=[single]) + +- LocalSortAggregate(select=[Partial_SUM(j) AS sum$0]) + +- Calc(select=[j], where=[<(i, 100)]) + +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) ]]> @@ -1611,15 +1609,14 @@ Calc(select=[a]) :- Exchange(distribution=[hash[b]]) : +- Calc(select=[a, b]) : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - +- Exchange(distribution=[hash[e]]) - +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) - +- Exchange(distribution=[hash[e]]) - +- LocalHashAggregate(groupBy=[e], select=[e]) - +- Union(all=[true], union=[e]) - :- Calc(select=[e], where=[>(d, 10)]) - : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) - +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) + +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) + +- Exchange(distribution=[hash[e]]) + +- LocalHashAggregate(groupBy=[e], select=[e]) + +- Union(all=[true], union=[e]) + :- Calc(select=[e], where=[>(d, 10)]) + : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) + +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, 
k)]]], fields=[i, j, k]) ]]> @@ -1933,19 +1930,18 @@ LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 100))]) =(CAST(c), 1:BIGINT)]) -: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: : +- Exchange(distribution=[hash[j]]) -: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) -: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) -: +- Exchange(distribution=[hash[i]]) -: +- Calc(select=[i], where=[<(j, 100)]) -: +- Reused(reference_id=[1]) +:- HashJoin(joinType=[LeftAntiJoin], where=[=(a, i)], select=[a, b, c], build=[right]) +: :- Exchange(distribution=[hash[a]]) +: : +- HashJoin(joinType=[LeftAntiJoin], where=[=(b, j)], select=[a, b, c], build=[right]) +: : :- Exchange(distribution=[hash[b]]) +: : : +- Calc(select=[a, b, c], where=[>=(CAST(c), 1:BIGINT)]) +: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +: : +- Exchange(distribution=[hash[j]]) +: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) +: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) +: +- Exchange(distribution=[hash[i]]) +: +- Calc(select=[i], where=[<(j, 100)]) +: +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[d]]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> @@ -2213,12 +2209,11 @@ LogicalFilter(condition=[=($cor0.a, $0)]) HashJoin(joinType=[LeftAntiJoin], where=[=(a, c)], select=[a, b], build=[right]) :- Exchange(distribution=[hash[a]]) : +- TableSourceScan(table=[[leftT, source: [TestTableSource(a, b)]]], fields=[a, b]) -+- Exchange(distribution=[hash[c]]) - +- HashAggregate(isMerge=[true], groupBy=[c], select=[c]) - +- Exchange(distribution=[hash[c]]) - +- LocalHashAggregate(groupBy=[c], select=[c]) - +- Calc(select=[c]) - +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ++- HashAggregate(isMerge=[true], groupBy=[c], select=[c]) + +- Exchange(distribution=[hash[c]]) + +- LocalHashAggregate(groupBy=[c], select=[c]) + +- Calc(select=[c]) + +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ]]> @@ -2312,39 +2307,36 @@ Calc(select=[b]) :- Exchange(distribution=[hash[c]]) : +- Calc(select=[b, c, CASE(OR(=(c0, 0), AND(<>(c0, 0), IS NULL(i0), >=(ck, c0), IS NOT NULL(a))), 1, OR(=(c1, 0), AND(<>(c1, 0), IS NULL(i), >=(ck0, c1), IS NOT NULL(a))), 2, 3) AS $f3]) : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, EXPR$0)], select=[a, b, c, c0, ck, i0, c1, ck0, EXPR$0, i], build=[right]) - : :- Exchange(distribution=[hash[a]]) - : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, c1, ck0], build=[right], singleRowJoin=[true]) - : : :- Calc(select=[a, b, c, c0, ck, i0]) - : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0], build=[right]) - : : : :- Exchange(distribution=[hash[a]]) - : : : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) - : : : : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - : : : : +- Exchange(distribution=[broadcast]) - : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : : : +- Exchange(distribution=[single]) - : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) - : : : : +- Calc(select=[i]) - : : : : +- 
TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- LocalHashAggregate(groupBy=[i], select=[i]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- Reused(reference_id=[1]) - : : +- Exchange(distribution=[broadcast]) - : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[CAST(j) AS EXPR$0]) - : : +- Reused(reference_id=[1]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- Calc(select=[EXPR$0, true AS i]) - : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) - : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) - : +- Reused(reference_id=[1]) + : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, c1, ck0], build=[right], singleRowJoin=[true]) + : : :- Calc(select=[a, b, c, c0, ck, i0]) + : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0], build=[right]) + : : : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) + : : : : :- Exchange(distribution=[hash[a]]) + : : : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : : : : +- Exchange(distribution=[broadcast]) + : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : : : +- Exchange(distribution=[single]) + : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) + : : : : +- Calc(select=[i]) + : : : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) + : : : +- Exchange(distribution=[hash[i]]) + : : : +- LocalHashAggregate(groupBy=[i], select=[i]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- Reused(reference_id=[1]) + : : +- Exchange(distribution=[broadcast]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) + : : +- Calc(select=[CAST(j) AS EXPR$0]) + : : +- Reused(reference_id=[1]) + : +- Calc(select=[EXPR$0, true AS i]) + : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) + : +- Exchange(distribution=[hash[EXPR$0]]) + : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) + : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) + : +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[f]]) +- Calc(select=[d, f]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) @@ -2557,36 +2549,33 @@ Calc(select=[b]) +- NestedLoopJoin(joinType=[LeftAntiJoin], where=[AND(OR(=(b, e), IS NULL(b), IS NULL(e)), OR(=($f3, d), IS NULL(d)))], select=[b, $f3], build=[right]) :- Calc(select=[b, CASE(OR(=(c0, 0), AND(<>(c0, 0), IS NULL(i0), >=(ck, c0), IS NOT NULL(a))), 1, OR(=(c, 0), AND(<>(c, 0), IS NULL(i), 
>=(ck0, c), IS NOT NULL(a))), 2, 3) AS $f3]) : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, j)], select=[a, b, c0, ck, i0, c, ck0, j, i], build=[right]) - : :- Exchange(distribution=[hash[a]]) - : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c0, ck, i0, c, ck0], build=[right], singleRowJoin=[true]) - : : :- Calc(select=[a, b, c AS c0, ck, i0]) - : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, ck, i, i0], build=[right]) - : : : :- Exchange(distribution=[hash[a]]) - : : : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, ck], build=[right], singleRowJoin=[true]) - : : : : :- Calc(select=[a, b]) - : : : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - : : : : +- Exchange(distribution=[broadcast]) - : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : : : +- Exchange(distribution=[single]) - : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) - : : : : +- TableSourceScan(table=[[t1, source: [TestTableSource(i)]]], fields=[i], reuse_id=[1]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- LocalHashAggregate(groupBy=[i], select=[i]) - : : : +- Reused(reference_id=[1]) - : : +- Exchange(distribution=[broadcast]) - : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(j) AS count$1]) - : : +- TableSourceScan(table=[[t2, source: [TestTableSource(j)]]], fields=[j], reuse_id=[2]) - : +- Exchange(distribution=[hash[j]]) - : +- Calc(select=[j, true AS i]) - : +- HashAggregate(isMerge=[true], groupBy=[j], select=[j]) - : +- Exchange(distribution=[hash[j]]) - : +- LocalHashAggregate(groupBy=[j], select=[j]) - : +- Reused(reference_id=[2]) + : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c0, ck, i0, c, ck0], build=[right], singleRowJoin=[true]) + : : :- Calc(select=[a, b, c AS c0, ck, i0]) + : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, ck, i, i0], build=[right]) + : : : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, ck], build=[right], singleRowJoin=[true]) + : : : : :- Exchange(distribution=[hash[a]]) + : : : : : +- Calc(select=[a, b]) + : : : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : : : : +- Exchange(distribution=[broadcast]) + : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : : : +- Exchange(distribution=[single]) + : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) + : : : : +- TableSourceScan(table=[[t1, source: [TestTableSource(i)]]], fields=[i], reuse_id=[1]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) + : : : +- Exchange(distribution=[hash[i]]) + : : : +- LocalHashAggregate(groupBy=[i], select=[i]) + : : : +- Reused(reference_id=[1]) + : : +- Exchange(distribution=[broadcast]) + : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : +- Exchange(distribution=[single]) + : : +- 
LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(j) AS count$1]) + : : +- TableSourceScan(table=[[t2, source: [TestTableSource(j)]]], fields=[j], reuse_id=[2]) + : +- Calc(select=[j, true AS i]) + : +- HashAggregate(isMerge=[true], groupBy=[j], select=[j]) + : +- Exchange(distribution=[hash[j]]) + : +- LocalHashAggregate(groupBy=[j], select=[j]) + : +- Reused(reference_id=[2]) +- Exchange(distribution=[broadcast]) +- Calc(select=[e, d]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml index c687cc57777036..e84cca4a9b40da 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashJoinTest.xml @@ -59,8 +59,8 @@ LogicalProject(d=[$0], e=[$1], f=[$2], g=[$3], h=[$4], a=[$6], b=[$7], c=[$8]) @@ -1324,15 +1322,14 @@ Calc(select=[a]) :- Exchange(distribution=[hash[b]]) : +- Calc(select=[a, b]) : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - +- Exchange(distribution=[hash[e]]) - +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) - +- Exchange(distribution=[hash[e]]) - +- LocalHashAggregate(groupBy=[e], select=[e]) - +- Union(all=[true], union=[e]) - :- Calc(select=[e], where=[>(d, 10)]) - : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) - +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) + +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) + +- Exchange(distribution=[hash[e]]) + +- LocalHashAggregate(groupBy=[e], select=[e]) + +- Union(all=[true], union=[e]) + :- Calc(select=[e], where=[>(d, 10)]) + : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) + +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) ]]> @@ -1439,19 +1436,18 @@ LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 100))]) =(CAST(c), 1:BIGINT)]) -: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: : +- Exchange(distribution=[hash[j]]) -: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) -: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) -: +- Exchange(distribution=[hash[i]]) -: +- Calc(select=[i], where=[<(j, 100)]) -: +- Reused(reference_id=[1]) +:- HashJoin(joinType=[LeftAntiJoin], where=[=(a, i)], select=[a, b, c], build=[right]) +: :- Exchange(distribution=[hash[a]]) +: : +- HashJoin(joinType=[LeftAntiJoin], where=[=(b, j)], select=[a, b, c], build=[right]) +: : :- Exchange(distribution=[hash[b]]) +: : : +- Calc(select=[a, b, c], where=[>=(CAST(c), 1:BIGINT)]) +: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +: : +- Exchange(distribution=[hash[j]]) +: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) +: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) +: +- Exchange(distribution=[hash[i]]) +: +- Calc(select=[i], 
where=[<(j, 100)]) +: +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[d]]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> @@ -1805,12 +1801,11 @@ LogicalFilter(condition=[=($cor0.a, $0)]) HashJoin(joinType=[LeftAntiJoin], where=[=(a, c)], select=[a, b], build=[right]) :- Exchange(distribution=[hash[a]]) : +- TableSourceScan(table=[[leftT, source: [TestTableSource(a, b)]]], fields=[a, b]) -+- Exchange(distribution=[hash[c]]) - +- HashAggregate(isMerge=[true], groupBy=[c], select=[c]) - +- Exchange(distribution=[hash[c]]) - +- LocalHashAggregate(groupBy=[c], select=[c]) - +- Calc(select=[c]) - +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ++- HashAggregate(isMerge=[true], groupBy=[c], select=[c]) + +- Exchange(distribution=[hash[c]]) + +- LocalHashAggregate(groupBy=[c], select=[c]) + +- Calc(select=[c]) + +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ]]> @@ -1904,39 +1899,36 @@ Calc(select=[b]) :- Exchange(distribution=[hash[c]]) : +- Calc(select=[b, c, CASE(OR(=(c0, 0), AND(<>(c0, 0), IS NULL(i0), >=(ck, c0), IS NOT NULL(a))), 1, OR(=(c1, 0), AND(<>(c1, 0), IS NULL(i), >=(ck0, c1), IS NOT NULL(a))), 2, 3) AS $f3]) : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, EXPR$0)], select=[a, b, c, c0, ck, i0, c1, ck0, EXPR$0, i], build=[right]) - : :- Exchange(distribution=[hash[a]]) - : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, c1, ck0], build=[right], singleRowJoin=[true]) - : : :- Calc(select=[a, b, c, c0, ck, i0]) - : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0], build=[right]) - : : : :- Exchange(distribution=[hash[a]]) - : : : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) - : : : : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - : : : : +- Exchange(distribution=[broadcast]) - : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : : : +- Exchange(distribution=[single]) - : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) - : : : : +- Calc(select=[i]) - : : : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- LocalHashAggregate(groupBy=[i], select=[i]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- Reused(reference_id=[1]) - : : +- Exchange(distribution=[broadcast]) - : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[CAST(j) AS EXPR$0]) - : : +- Reused(reference_id=[1]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- Calc(select=[EXPR$0, true AS i]) - : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) - : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) - : +- Reused(reference_id=[1]) + : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, 
c1, ck0], build=[right], singleRowJoin=[true]) + : : :- Calc(select=[a, b, c, c0, ck, i0]) + : : : +- HashJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0], build=[right]) + : : : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) + : : : : :- Exchange(distribution=[hash[a]]) + : : : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : : : : +- Exchange(distribution=[broadcast]) + : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : : : +- Exchange(distribution=[single]) + : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) + : : : : +- Calc(select=[i]) + : : : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) + : : : +- Exchange(distribution=[hash[i]]) + : : : +- LocalHashAggregate(groupBy=[i], select=[i]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- Reused(reference_id=[1]) + : : +- Exchange(distribution=[broadcast]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) + : : +- Calc(select=[CAST(j) AS EXPR$0]) + : : +- Reused(reference_id=[1]) + : +- Calc(select=[EXPR$0, true AS i]) + : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) + : +- Exchange(distribution=[hash[EXPR$0]]) + : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) + : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) + : +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[f]]) +- Calc(select=[d, f]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.xml index 14e1cdcc310467..d1b60a4f380393 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SingleRowJoinTest.xml @@ -184,36 +184,6 @@ Calc(select=[a2]) +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0]) +- Calc(select=[0 AS $f0]) +- TableSourceScan(table=[[B, source: [TestTableSource(b1, b2)]]], fields=[b1, b2]) -]]> - - - - - - - - - - - @@ -274,13 +244,12 @@ Calc(select=[a2, EXPR$1]) : +- TableSourceScan(table=[[A, source: [TestTableSource(a1, a2)]]], fields=[a1, a2], reuse_id=[1]) +- Exchange(distribution=[broadcast]) +- SortAggregate(isMerge=[false], select=[SINGLE_VALUE(EXPR$0) AS $f0]) - +- Exchange(distribution=[single]) - +- Calc(select=[*($f0, 0.1:DECIMAL(2, 1)) AS EXPR$0]) - +- HashAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) - +- Exchange(distribution=[single]) - +- LocalHashAggregate(select=[Partial_SUM(a1) AS sum$0]) - +- Calc(select=[a1]) - +- Reused(reference_id=[1]) + +- Calc(select=[*($f0, 0.1:DECIMAL(2, 1)) AS EXPR$0]) + +- SortAggregate(isMerge=[true], select=[Final_SUM(sum$0) AS $f0]) + +- Exchange(distribution=[single]) + +- 
LocalSortAggregate(select=[Partial_SUM(a1) AS sum$0]) + +- Calc(select=[a1]) + +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml index 51724cbb2f2b85..dadd6be1741835 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeJoinTest.xml @@ -59,8 +59,8 @@ LogicalProject(d=[$0], e=[$1], f=[$2], g=[$3], h=[$4], a=[$6], b=[$7], c=[$8]) @@ -1409,15 +1407,14 @@ Calc(select=[a]) :- Exchange(distribution=[hash[b]]) : +- Calc(select=[a, b]) : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - +- Exchange(distribution=[hash[e]]) - +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) - +- Exchange(distribution=[hash[e]]) - +- LocalHashAggregate(groupBy=[e], select=[e]) - +- Union(all=[true], union=[e]) - :- Calc(select=[e], where=[>(d, 10)]) - : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) - +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) - +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) + +- HashAggregate(isMerge=[true], groupBy=[e], select=[e]) + +- Exchange(distribution=[hash[e]]) + +- LocalHashAggregate(groupBy=[e], select=[e]) + +- Union(all=[true], union=[e]) + :- Calc(select=[e], where=[>(d, 10)]) + : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) + +- Calc(select=[CAST(i) AS i], where=[<(i, 100)]) + +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k]) ]]> @@ -1524,19 +1521,18 @@ LogicalFilter(condition=[AND(=($cor0.a, $0), <($1, 100))]) =(CAST(c), 1:BIGINT)]) -: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -: : +- Exchange(distribution=[hash[j]]) -: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) -: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) -: +- Exchange(distribution=[hash[i]]) -: +- Calc(select=[i], where=[<(j, 100)]) -: +- Reused(reference_id=[1]) +:- SortMergeJoin(joinType=[LeftAntiJoin], where=[=(a, i)], select=[a, b, c]) +: :- Exchange(distribution=[hash[a]]) +: : +- SortMergeJoin(joinType=[LeftAntiJoin], where=[=(b, j)], select=[a, b, c]) +: : :- Exchange(distribution=[hash[b]]) +: : : +- Calc(select=[a, b, c], where=[>=(CAST(c), 1:BIGINT)]) +: : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +: : +- Exchange(distribution=[hash[j]]) +: : +- Calc(select=[j], where=[>(CAST(k), 50:BIGINT)]) +: : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) +: +- Exchange(distribution=[hash[i]]) +: +- Calc(select=[i], where=[<(j, 100)]) +: +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[d]]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> @@ -1890,12 +1886,11 @@ LogicalFilter(condition=[=($cor0.a, $0)]) SortMergeJoin(joinType=[LeftAntiJoin], where=[=(a, c)], select=[a, b]) :- Exchange(distribution=[hash[a]]) : +- TableSourceScan(table=[[leftT, source: [TestTableSource(a, b)]]], fields=[a, b]) -+- Exchange(distribution=[hash[c]]) - +- 
HashAggregate(isMerge=[true], groupBy=[c], select=[c]) - +- Exchange(distribution=[hash[c]]) - +- LocalHashAggregate(groupBy=[c], select=[c]) - +- Calc(select=[c]) - +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ++- HashAggregate(isMerge=[true], groupBy=[c], select=[c]) + +- Exchange(distribution=[hash[c]]) + +- LocalHashAggregate(groupBy=[c], select=[c]) + +- Calc(select=[c]) + +- TableSourceScan(table=[[rightT, source: [TestTableSource(c, d)]]], fields=[c, d]) ]]> @@ -1989,39 +1984,36 @@ Calc(select=[b]) :- Exchange(distribution=[hash[c]]) : +- Calc(select=[b, c, CASE(OR(=(c0, 0), AND(<>(c0, 0), IS NULL(i0), >=(ck, c0), IS NOT NULL(a))), 1, OR(=(c1, 0), AND(<>(c1, 0), IS NULL(i), >=(ck0, c1), IS NOT NULL(a))), 2, 3) AS $f3]) : +- SortMergeJoin(joinType=[LeftOuterJoin], where=[=(a, EXPR$0)], select=[a, b, c, c0, ck, i0, c1, ck0, EXPR$0, i]) - : :- Exchange(distribution=[hash[a]]) - : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, c1, ck0], build=[right], singleRowJoin=[true]) - : : :- Calc(select=[a, b, c, c0, ck, i0]) - : : : +- SortMergeJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0]) - : : : :- Exchange(distribution=[hash[a]]) - : : : : +- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) - : : : : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) - : : : : +- Exchange(distribution=[broadcast]) - : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : : : +- Exchange(distribution=[single]) - : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) - : : : : +- Calc(select=[i]) - : : : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) - : : : +- Exchange(distribution=[hash[i]]) - : : : +- LocalHashAggregate(groupBy=[i], select=[i]) - : : : +- Calc(select=[i, true AS i0]) - : : : +- Reused(reference_id=[1]) - : : +- Exchange(distribution=[broadcast]) - : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[CAST(j) AS EXPR$0]) - : : +- Reused(reference_id=[1]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- Calc(select=[EXPR$0, true AS i]) - : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) - : +- Exchange(distribution=[hash[EXPR$0]]) - : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) - : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) - : +- Reused(reference_id=[1]) + : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck, i0, c1, ck0], build=[right], singleRowJoin=[true]) + : : :- Calc(select=[a, b, c, c0, ck, i0]) + : : : +- SortMergeJoin(joinType=[LeftOuterJoin], where=[=(a, i)], select=[a, b, c, c0, ck, i, i0]) + : : : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0, ck], build=[right], singleRowJoin=[true]) + : : : : :- Exchange(distribution=[hash[a]]) + : : : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) + : : : : +- Exchange(distribution=[broadcast]) 
+ : : : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : : : +- Exchange(distribution=[single]) + : : : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(i) AS count$1]) + : : : : +- Calc(select=[i]) + : : : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- HashAggregate(isMerge=[true], groupBy=[i], select=[i]) + : : : +- Exchange(distribution=[hash[i]]) + : : : +- LocalHashAggregate(groupBy=[i], select=[i]) + : : : +- Calc(select=[i, true AS i0]) + : : : +- Reused(reference_id=[1]) + : : +- Exchange(distribution=[broadcast]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) + : : +- Calc(select=[CAST(j) AS EXPR$0]) + : : +- Reused(reference_id=[1]) + : +- Calc(select=[EXPR$0, true AS i]) + : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) + : +- Exchange(distribution=[hash[EXPR$0]]) + : +- LocalHashAggregate(groupBy=[EXPR$0], select=[EXPR$0]) + : +- Calc(select=[CAST(j) AS EXPR$0, true AS i]) + : +- Reused(reference_id=[1]) +- Exchange(distribution=[hash[f]]) +- Calc(select=[d, f]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.xml new file mode 100644 index 00000000000000..652470593ff718 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.xml @@ -0,0 +1,280 @@ + + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$1=[$6], EXPR$2=[/(CAST(CASE(>($7, 0), $8, null:INTEGER)):DOUBLE, $7)], EXPR$3=[$5], EXPR$4=[$9]) ++- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), RANK()])], window#1=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [MAX($0)])], window#2=[window(partition {2} order by [0 ASC-nulls-first] range between 
UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), MIN($0)])]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[MIN($0) OVER (ORDER BY $2 NULLS FIRST, $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[/(CAST(CASE(>(COUNT($0) OVER (ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$1=[$6], EXPR$2=[$7], EXPR$3=[$5], EXPR$4=[/(CAST(CASE(>($8, 0), $9, null:INTEGER)):DOUBLE, $8)]) ++- LogicalProject(a=[$0], b=[$1], c=[$2], w0$o0=[$6], w0$o1=[$7], w0$o2=[$8], w1$o0=[$9], w2$o0=[$5], w3$o0=[$3], w3$o1=[$4]) + +- LogicalWindow(window#0=[window(partition {} order by [1 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0)])], window#1=[window(partition {} order by [2 ASC-nulls-first, 0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [MIN($0)])], window#2=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), RANK()])], window#3=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [MAX($0)])]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$1=[$5], EXPR$2=[/(CAST(CASE(>($7, 0), $8, null:INTEGER)):DOUBLE, $7)], EXPR$3=[$6], EXPR$4=[$9]) ++- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), MAX($0), RANK()])], window#1=[window(partition {2} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and 
CURRENT ROW aggs [COUNT($0), $SUM0($0), MIN($0)])]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$2=[RANK() OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST, $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$3=[CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING), null:INTEGER)], EXPR$4=[COUNT() OVER (PARTITION BY $2 ORDER BY $2 NULLS FIRST ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$2=[$5], EXPR$3=[CASE(>($6, 0), $7, null:INTEGER)], EXPR$4=[$8]) ++- LogicalProject(a=[$0], c=[$1], w0$o0=[$2], w1$o0=[$3], w1$o1=[$4], w2$o0=[$7], w3$o0=[$5], w3$o1=[$6], w4$o0=[$8]) + +- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between $2 PRECEDING and $3 FOLLOWING aggs [COUNT()])], window#1=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0)])], window#2=[window(partition {1} order by [0 ASC-nulls-first] range between $4 PRECEDING and $3 FOLLOWING aggs [COUNT($0), $SUM0($0)])], window#3=[window(partition {1} order by [0 ASC-nulls-first, 1 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#4=[window(partition {1} order by [1 ASC-nulls-first] rows between $4 PRECEDING and $3 FOLLOWING aggs [COUNT()])]) + +- LogicalProject(a=[$0], c=[$2]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$1=[$6], EXPR$2=[/(CAST(CASE(>($8, 0), $9, null:INTEGER)):DOUBLE, $8)], EXPR$3=[$5], EXPR$4=[$7]) ++- LogicalProject(a=[$0], b=[$1], c=[$2], w0$o0=[$7], w0$o1=[$8], w0$o2=[$9], w1$o0=[$5], w1$o1=[$6], w2$o0=[$3], w2$o1=[$4]) + +- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0)])], window#1=[window(partition {1} order by [1 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW 
aggs [MAX($0), MIN($0)])], window#2=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), RANK()])]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST, $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST, $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST, $2 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST, $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $1 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0), $4, null:INTEGER)], EXPR$1=[$5], EXPR$2=[/(CAST(CASE(>($6, 0), $7, null:INTEGER)):DOUBLE, $6)], EXPR$3=[$8], EXPR$4=[$9]) ++- LogicalProject(a=[$0], b=[$1], c=[$2], w0$o0=[$8], w0$o1=[$9], w1$o0=[$3], w2$o0=[$5], w2$o1=[$6], w3$o0=[$4], w4$o0=[$7]) + +- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [MAX($0)])], window#1=[window(partition {1} order by [0 ASC-nulls-first, 1 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [RANK()])], window#2=[window(partition {1} order by [0 ASC-nulls-first, 2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0)])], window#3=[window(partition {1} order by [1 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [MIN($0)])], window#4=[window(partition {1} order by [2 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0)])]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$2=[/(CAST(CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)):DOUBLE, COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW))], EXPR$3=[RANK() OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)], EXPR$4=[MIN($0) OVER (PARTITION BY $1 ORDER BY $0 DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($2, 0), $3, null:INTEGER)], 
EXPR$1=[$4], EXPR$2=[/(CAST(CASE(>($6, 0), $7, null:INTEGER)):DOUBLE, $6)], EXPR$3=[$5], EXPR$4=[$8]) ++- LogicalWindow(window#0=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), MAX($0), RANK()])], window#1=[window(partition {1} order by [0 DESC-nulls-last] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), MIN($0)])]) + +- LogicalProject(a=[$0], b=[$1]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + + + + + + + + + + + + (COUNT($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), null:INTEGER)], EXPR$1=[MIN($0) OVER (PARTITION BY $1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$2=[MAX($0) OVER (PARTITION BY $1 ORDER BY $0 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) ++- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($2, 0), $3, null:INTEGER)], EXPR$1=[$5], EXPR$2=[$4]) ++- LogicalProject(a=[$0], b=[$1], w0$o0=[$3], w0$o1=[$4], w0$o2=[$5], w1$o0=[$2]) + +- LogicalWindow(window#0=[window(partition {1} order by [] range between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING aggs [MIN($0)])], window#1=[window(partition {1} order by [0 ASC-nulls-first] range between UNBOUNDED PRECEDING and CURRENT ROW aggs [COUNT($0), $SUM0($0), MAX($0)])]) + +- LogicalProject(a=[$0], b=[$1]) + +- LogicalTableScan(table=[[MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml new file mode 100644 index 00000000000000..2777bf1ae7b640 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.xml @@ -0,0 +1,84 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml new file mode 100644 index 00000000000000..68d7f219e826de --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.xml @@ -0,0 +1,114 @@ + + + + + + + + + + + + + + + + + = 2 + ]]> + + + =($1, 2))]) + +- LogicalProject(a=[$0], rk=[RANK() OVER (PARTITION BY $0 ORDER BY $1 NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)]) + +- LogicalAggregate(group=[{0}], agg#0=[SUM($1)]) + +- LogicalProject(a=[$0], b=[$1]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml new file mode 
100644 index 00000000000000..6e791c8a685ce6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.xml @@ -0,0 +1,86 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala new file mode 100644 index 00000000000000..60902273fe6b30 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.batch.sql + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions, Types} +import org.apache.flink.table.plan.stats.TableStats +import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.StringSplit +import org.apache.flink.table.util.{TableFunc1, TableTestBase} + +import com.google.common.collect.ImmutableSet +import org.junit.{Before, Test} + +class RemoveCollationTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource("x", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("a", "b", "c"), + tableStats = Some(new TableStats(100L)) + ) + util.addTableSource("y", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("d", "e", "f"), + tableStats = Some(new TableStats(100L)) + ) + util.addTableSource("t1", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("a1", "b1", "c1"), + tableStats = Some(new TableStats(100L)) + ) + util.addTableSource("t2", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("d1", "e1", "f1"), + tableStats = Some(new TableStats(100L)) + ) + + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED, true) + } + + @Test + def testRemoveCollation_OverWindowAgg(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,HashAgg") + val sqlQuery = + """ + | SELECT + | SUM(b) sum_b, + | AVG(SUM(b)) OVER (PARTITION BY a order by a) avg_b, + | RANK() OVER (PARTITION BY a ORDER BY a) rn + | FROM x + | GROUP BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Aggregate(): Unit = { + util.tableEnv.getConfig.getConf.setString( + 
TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Aggregate_1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Sort(): Unit = { + util.tableEnv.getConfig.getConf.setBoolean(TableConfigOptions.SQL_EXEC_SORT_RANGE_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT a, b, COUNT(c) AS cnt FROM x GROUP BY a, b) + |SELECT * FROM r ORDER BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Aggregate_3(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + util.tableEnv.getConfig.getConf.setBoolean(TableConfigOptions.SQL_EXEC_SORT_RANGE_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x ORDER BY a, b) + |SELECT a, b, COUNT(c) AS cnt FROM r GROUP BY a, b + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Rank_1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + val sqlQuery = + """ + |SELECT a, SUM(b) FROM ( + | SELECT * FROM ( + | SELECT a, b, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM x) + | WHERE rk <= 10 + |) GROUP BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Rank_2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + val sqlQuery = + """ + |SELECT a, b, MAX(c) FROM ( + | SELECT * FROM ( + | SELECT a, b, c, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM x) + | WHERE rk <= 10 + |) GROUP BY a, b + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Rank_3(): Unit = { + // TODO remove local rank for single distribution input + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, c, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM ( + | SELECT a, b, c FROM x ORDER BY a, b + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Rank_4(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, c, RANK() OVER(PARTITION BY a ORDER BY a) rk FROM ( + | SELECT a, COUNT(c) AS c FROM x GROUP BY a + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Rank_Singleton(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + val sqlQuery = + """ + |SELECT COUNT(a), SUM(b) FROM ( + | SELECT * FROM ( + | SELECT a, b, RANK() OVER(ORDER BY b) rk FROM x) + | WHERE rk <= 10 + |) + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_MultipleSortMergeJoins1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + + val sql = + """ + |select * from + | x join y on a = d + | join t1 on a = a1 + | left outer join t2 on a = d1 + 
""".stripMargin + + util.verifyPlan(sql) + } + + @Test + def testRemoveCollation_MultipleSortMergeJoins_MultiJoinKeys1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + + val sql = + """ + |select * from + | x join y on a = d and b = e + | join t1 on a = a1 and b = b1 + | left outer join t2 on a = d1 and b = e1 + """.stripMargin + + util.verifyPlan(sql) + } + + @Test + def testRemoveCollation_MultipleSortMergeJoins2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + + val sql = + """ + |select * from + | x join y on a = d + | join t1 on d = a1 + | left outer join t2 on a1 = d1 + """.stripMargin + + util.verifyPlan(sql) + } + + @Test + def testRemoveCollation_MultipleSortMergeJoins_MultiJoinKeys2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + + val sql = + """ + |select * from + | x join y on a = d and b = e + | join t1 on d = a1 and e = b1 + | left outer join t2 on a1 = d1 and b1 = e1 + """.stripMargin + + util.verifyPlan(sql) + } + + @Test + def testRemoveCollation_MultipleSortMergeJoins3(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + util.addTableSource("tb1", + Array[TypeInformation[_]]( + Types.STRING, Types.STRING, Types.STRING, Types.STRING, Types.STRING), + Array("id", "key", "tb2_ids", "tb3_ids", "name"), + uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + ) + util.addTableSource("tb2", + Array[TypeInformation[_]](Types.STRING, Types.STRING), + Array("id", "name"), + uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + ) + util.addTableSource("tb3", + Array[TypeInformation[_]](Types.STRING, Types.STRING), + Array("id", "name"), + uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + ) + util.addTableSource("tb4", + Array[TypeInformation[_]](Types.STRING, Types.STRING), + Array("id", "name"), + uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + ) + util.addTableSource("tb5", + Array[TypeInformation[_]](Types.STRING, Types.STRING), + Array("id", "name"), + uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + ) + util.tableEnv.registerFunction("split", new StringSplit()) + + val sql = + """ + |with v1 as ( + | select id, tb2_id from tb1, LATERAL TABLE(split(tb2_ids)) AS T(tb2_id) + |), + |v2 as ( + | select id, tb3_id from tb1, LATERAL TABLE(split(tb3_ids)) AS T(tb3_id) + |), + | + |join_tb2 as ( + | select tb1_id, concat_agg(tb2_name, ',') as tb2_names + | from ( + | select v1.id as tb1_id, tb2.name as tb2_name + | from v1 left outer join tb2 on tb2_id = tb2.id + | ) group by tb1_id + |), + | + |join_tb3 as ( + | select tb1_id, concat_agg(tb3_name, ',') as tb3_names + | from ( + | select v2.id as tb1_id, tb3.name as tb3_name + | from v2 left outer join tb3 on tb3_id = tb3.id + | ) group by tb1_id + |) + | + |select + | tb1.id, + | tb1.tb2_ids, + | tb1.tb3_ids, + | tb1.name, + | tb2_names, + | tb3_names, + | tb4.name, + | tb5.name + | from tb1 + | left outer join join_tb2 on tb1.id = join_tb2.tb1_id + | left outer join join_tb3 on tb1.id = join_tb3.tb1_id + | left outer join tb4 on tb1.key = tb4.id + | left outer join tb5 on tb1.key = tb5.id + """.stripMargin + + util.verifyPlan(sql) + } + + @Test + def testRemoveCollation_Correlate1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + 
TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin,HashAgg") + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f1, f, cnt FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Correlate2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin,HashAgg") + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f, f1 FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f AND f LIKE '%llo%' + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveCollation_Correlate3(): Unit = { + // do not remove shuffle + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin,HashAgg") + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f1 FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f1 + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala new file mode 100644 index 00000000000000..29b14b9813c9da --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala @@ -0,0 +1,547 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.batch.sql + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions, Types} +import org.apache.flink.table.plan.stats.TableStats +import org.apache.flink.table.util.{TableFunc1, TableTestBase} + +import org.junit.{Before, Test} + +class RemoveShuffleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource("x", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("a", "b", "c"), + tableStats = Some(new TableStats(100L)) + ) + util.addTableSource("y", + Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), + Array("d", "e", "f"), + tableStats = Some(new TableStats(100L)) + ) + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_REUSE_SUB_PLAN_ENABLED, false) + } + + @Test + def testRemoveHashShuffle_OverWindowAgg(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,SortAgg") + val sqlQuery = + """ + | SELECT + | SUM(b) sum_b, + | AVG(SUM(b)) OVER (PARTITION BY c) avg_b, + | RANK() OVER (PARTITION BY c ORDER BY c) rn, + | c + | FROM x + | GROUP BY c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_MultiOverWindowAgg(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,SortAgg") + val sqlQuery = + """ + | SELECT + | SUM(b) sum_b, + | AVG(SUM(b)) OVER (PARTITION BY a, c) avg_b, + | RANK() OVER (PARTITION BY c ORDER BY a, c) rn, + | c + | FROM x + | GROUP BY a, c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_OverWindowAgg_PartialKey(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + // push down HashExchange[c] into HashAgg + val sqlQuery = + """ + | SELECT + | SUM(b) sum_b, + | AVG(SUM(b)) OVER (PARTITION BY c) avg_b, + | RANK() OVER (PARTITION BY c ORDER BY c) rn, + | c + | FROM x + | GROUP BY a, c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Agg_PartialKey(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + // push down HashExchange[c] into HashAgg + val sqlQuery = + """ + | WITH r AS (SELECT a, c, count(b) as cnt FROM x GROUP BY a, c) + | SELECT count(cnt) FROM r group by c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashAggregate(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashAggregate_1(): Unit = { + 
util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a, d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashAggregate_2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortAggregate(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,HashAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortAggregate_1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,HashAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a, d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortAggregate_2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,HashAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortMergeJoin(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortMergeJoin_LOJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x left join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def 
testRemoveHashShuffle_SortMergeJoin_ROJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x right join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_SortMergeJoin_FOJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x full join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashJoin(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_BroadcastHashJoin(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashJoin_LOJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x left join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashJoin_ROJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x right join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashJoin_FOJ(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x full join (SELECT * FROM y WHERE e = 2) r on a = d) + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_HashJoin_1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, 
"NestedLoopJoin,SortMergeJoin") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r1 AS (SELECT a, c, sum(b) FROM x group by a, c), + |r2 AS (SELECT a, c, sum(b) FROM x group by a, c) + |SELECT * FROM r1, r2 WHERE r1.a = r2.a and r1.c = r2.c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_NestedLoopJoin(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,SortMergeJoin") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT * FROM r r1, r r2 WHERE r1.a = r2.d + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Join_PartialKey(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + val sqlQuery = + """ + |WITH r AS (SELECT d, count(f) as cnt FROM y GROUP BY d) + |SELECT * FROM x, r WHERE x.a = r.d AND x.b = r.cnt + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveSingleExchange_Agg(): Unit = { + val sqlQuery = "SELECT avg(b) FROM x GROUP BY c HAVING sum(b) > (SELECT sum(b) * 0.1 FROM x)" + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Union(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "NestedLoopJoin,SortMergeJoin,SortAgg") + val sqlQuery = + """ + |WITH r AS ( + |SELECT count(a) as cnt, c FROM x WHERE b > 10 group by c + |UNION ALL + |SELECT count(d) as cnt, f FROM y WHERE e < 100 group by f) + |SELECT r1.c, r1.cnt, r2.c, r2.cnt FROM r r1, r r2 WHERE r1.c = r2.c and r1.cnt < 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM ( + | SELECT a, SUM(b) AS b FROM x GROUP BY a + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank_PartialKey1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + val sqlQuery = + """ + |SELECT a, SUM(b) FROM ( + | SELECT * FROM ( + | SELECT a, b, c, RANK() OVER(PARTITION BY a, c ORDER BY b) rk FROM x) + | WHERE rk <= 10 + |) GROUP BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank_PartialKey2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, false) + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, c, RANK() OVER(PARTITION BY a, c ORDER BY b) rk FROM ( + | SELECT a, SUM(b) AS b, COUNT(c) AS c FROM x GROUP BY a + | ) + |) WHERE rk <= 10 + """.stripMargin + 
util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank_PartialKey3(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, c, RANK() OVER(PARTITION BY a, c ORDER BY b) rk FROM ( + | SELECT a, SUM(b) AS b, COUNT(c) AS c FROM x GROUP BY a + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank_Singleton1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, RANK() OVER(ORDER BY b) rk FROM ( + | SELECT COUNT(a) AS a, SUM(b) AS b FROM x + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Rank_Singleton2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM ( + | SELECT COUNT(a) AS a, SUM(b) AS b FROM x + | ) + |) WHERE rk <= 10 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Correlate1(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f1, f, cnt FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Correlate2(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f, f1 FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f AND f LIKE '%llo%' + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveHashShuffle_Correlate3(): Unit = { + // do not remove shuffle + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + util.tableEnv.registerFunction("split", new TableFunc1) + val sqlQuery = + """ + |WITH r AS (SELECT f, count(f) as cnt FROM y GROUP BY f), + | v as (SELECT f1 FROM r, LATERAL TABLE(split(f)) AS T(f1)) + |SELECT * FROM x, v WHERE c = f1 + """.stripMargin + util.verifyPlan(sqlQuery) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala index d7d45596eec615..0a8dbbeea58a92 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/OverAggregateTest.scala @@ -61,7 +61,22 @@ class OverAggregateTest extends TableTestBase { } @Test - def testDiffPartitionKeysWithDiffOrderKeys(): Unit = { + def testDiffPartitionKeysWithDiffOrderKeys1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a), + | MAX(a) OVER (PARTITION BY b ORDER BY c), + | AVG(a) OVER (PARTITION BY c ORDER BY a), + | RANK() OVER (PARTITION BY b ORDER BY a), + | MIN(a) OVER (PARTITION BY c ORDER BY a) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testDiffPartitionKeysWithDiffOrderKeys2(): Unit = { val sqlQuery = """ |SELECT @@ -76,7 +91,22 @@ class OverAggregateTest extends TableTestBase { } @Test - def testSamePartitionKeysWithDiffOrderKeys(): Unit = { + def testSamePartitionKeysWithDiffOrderKeys1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY c), + | MAX(a) OVER (PARTITION BY b ORDER BY b), + | AVG(a) OVER (PARTITION BY b ORDER BY a), + | RANK() OVER (PARTITION BY b ORDER BY c), + | MIN(a) OVER (PARTITION BY b ORDER BY b) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithDiffOrderKeys2(): Unit = { val sqlQuery = """ |SELECT @@ -116,7 +146,22 @@ class OverAggregateTest extends TableTestBase { } @Test - def testSamePartitionKeysWithSameOrderKeysDiffDirection(): Unit = { + def testSamePartitionKeysWithSameOrderKeysDiffDirection1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a ASC), + | MAX(a) OVER (PARTITION BY b ORDER BY a ASC), + | AVG(a) OVER (PARTITION BY b ORDER BY a DESC), + | RANK() OVER (PARTITION BY b ORDER BY a ASC), + | MIN(a) OVER (PARTITION BY b ORDER BY a DESC) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithSameOrderKeysDiffDirection2(): Unit = { val sqlQuery = """ |SELECT diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.scala new file mode 100644 index 00000000000000..723b0720756b57 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/WindowGroupReorderRuleTest.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.plan.optimize.program._ +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +/** + * Test for [[WindowGroupReorderRule]]. + */ +class WindowGroupReorderRuleTest extends TableTestBase { + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.buildBatchProgram(FlinkBatchProgram.LOGICAL) + util.addTableSource[(Int, Int, String)]("MyTable", 'a, 'b, 'c) + } + + @Test + def testSamePartitionKeysWithSameOrderKeysPrefix(): Unit = { + val sqlQuery = + """ + |SELECT a, + | RANK() OVER (PARTITION BY b ORDER BY c, a DESC), + | RANK() OVER (PARTITION BY b ORDER BY c, b) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithDiffOrderKeys1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY c), + | MAX(a) OVER (PARTITION BY b ORDER BY b), + | AVG(a) OVER (PARTITION BY b ORDER BY a), + | RANK() OVER (PARTITION BY b ORDER BY c), + | MIN(a) OVER (PARTITION BY b ORDER BY b) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithDiffOrderKeys2(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY c), + | MAX(a) OVER (PARTITION BY b ORDER BY a), + | AVG(a) OVER (PARTITION BY b ORDER BY a, c), + | RANK() OVER (PARTITION BY b ORDER BY a, b), + | MIN(a) OVER (PARTITION BY b ORDER BY b) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithSameOrderKeysDiffDirection1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a ASC), + | MAX(a) OVER (PARTITION BY b ORDER BY a ASC), + | AVG(a) OVER (PARTITION BY b ORDER BY a DESC), + | RANK() OVER (PARTITION BY b ORDER BY a ASC), + | MIN(a) OVER (PARTITION BY b ORDER BY a DESC) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithSameOrderKeysDiffDirection2(): Unit = { + val sqlQuery = + """ + |SELECT + | RANK() OVER (PARTITION BY b ORDER BY a DESC), + | RANK() OVER (PARTITION BY b ORDER BY a ASC) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSamePartitionKeysWithSameOrderKeysWithEmptyOrder(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a), + | MIN(a) OVER (PARTITION BY b), + | MAX(a) OVER (PARTITION BY b ORDER BY a) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testDiffPartitionKeysWithSameOrderKeys(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a), + | MAX(a) OVER (PARTITION BY b ORDER BY a), + | AVG(a) OVER (PARTITION BY c ORDER BY a), + | RANK() OVER (PARTITION BY b ORDER BY a), + | MIN(a) OVER (PARTITION BY c ORDER BY a) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testDiffPartitionKeysWithDiffOrderKeys1(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY a), + | MAX(a) OVER (PARTITION BY b ORDER BY c), + | AVG(a) OVER (PARTITION BY c ORDER BY a), + | RANK() OVER (PARTITION BY b ORDER BY a), + | MIN(a) OVER (PARTITION BY c ORDER BY a) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def 
testDiffPartitionKeysWithDiffOrderKeys2(): Unit = { + val sqlQuery = + """ + |SELECT + | SUM(a) OVER (PARTITION BY b ORDER BY c), + | MAX(a) OVER (PARTITION BY c ORDER BY a), + | MIN(a) OVER (ORDER BY c, a), + | RANK() OVER (PARTITION BY b ORDER BY c), + | AVG(a) OVER (ORDER BY b) + |FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testMultiOverWindowRangeType(): Unit = { + val sqlQuery = + """ + |SELECT + | COUNT(*) OVER (PARTITION BY c ORDER BY a RANGE BETWEEN -1 PRECEDING AND 10 FOLLOWING), + | SUM(a) OVER (PARTITION BY c ORDER BY a), + | RANK() OVER (PARTITION BY c ORDER BY a, c), + | SUM(a) OVER (PARTITION BY c ORDER BY a RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING), + | COUNT(*) OVER (PARTITION BY c ORDER BY c ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) + | FROM MyTable + """.stripMargin + util.verifyPlan(sqlQuery) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala new file mode 100644 index 00000000000000..acf9972e538ca5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalHashAggRuleTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions} +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +/** + * Test for [[RemoveRedundantLocalHashAggRule]]. 
+ */ +class RemoveRedundantLocalHashAggRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("y", 'd, 'e, 'f) + } + + @Test + def testRemoveRedundantLocalHashAgg_ShuffleKeyFromJoin(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,SortAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT SUM(b) FROM r GROUP BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveRedundantLocalHashAgg_ShuffleKeyFromRank(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + util.tableEnv.getConfig.getConf.setBoolean( + PlannerConfigOptions.SQL_OPTIMIZER_SHUFFLE_PARTIAL_KEY_ENABLED, true) + val sqlQuery = + """ + |SELECT a, SUM(b) FROM ( + | SELECT * FROM ( + | SELECT a, b, c, RANK() OVER (PARTITION BY a, c ORDER BY b) rk FROM x) + | WHERE rk <= 10 + |) GROUP BY a + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.scala new file mode 100644 index 00000000000000..35e50ad3774fd5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalRankRuleTest.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.api.scala._ +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +/** + * Tests for [[RemoveRedundantLocalRankRule]]. 
+ */ +class RemoveRedundantLocalRankRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) + } + + @Test + def testSameRankRange(): Unit = { + val sqlQuery = + """ + |SELECT a FROM ( + | SELECT a, RANK() OVER(PARTITION BY a ORDER BY SUM(b)) rk FROM x GROUP BY a + |) WHERE rk <= 5 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testDiffRankRange(): Unit = { + val sqlQuery = + """ + |SELECT a FROM ( + | SELECT a, RANK() OVER(PARTITION BY a ORDER BY SUM(b)) rk FROM x GROUP BY a + |) WHERE rk <= 5 and rk >= 2 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testMultiRanks(): Unit = { + val sqlQuery = + """ + |SELECT * FROM ( + | SELECT a, b, rk, RANK() OVER(PARTITION BY a ORDER BY b) rk1 FROM ( + | SELECT a, b, RANK() OVER(PARTITION BY a ORDER BY b) rk FROM x + | ) WHERE rk <= 5 + |) WHERE rk1 <= 5 + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala new file mode 100644 index 00000000000000..36c05abe8d45e7 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/physical/batch/RemoveRedundantLocalSortAggRuleTest.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.physical.batch + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions} +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +/** + * Test for [[RemoveRedundantLocalSortAggRule]]. 
+ */ +class RemoveRedundantLocalSortAggRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("y", 'd, 'e, 'f) + } + + @Test + def testRemoveRedundantLocalSortAggWithSort(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortMergeJoin,NestedLoopJoin,HashAgg") + // disable BroadcastHashJoin + util.tableEnv.getConfig.getConf.setLong( + PlannerConfigOptions.SQL_OPTIMIZER_HASH_JOIN_BROADCAST_THRESHOLD, -1) + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testRemoveRedundantLocalSortAggWithoutSort(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin,NestedLoopJoin,HashAgg") + val sqlQuery = + """ + |WITH r AS (SELECT * FROM x, y WHERE a = d AND c LIKE 'He%') + |SELECT sum(b) FROM r group by a + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} From 6af47370be952e4dca3b18d71f5206269ac78fc8 Mon Sep 17 00:00:00 2001 From: "Abdul Qadeer (abqadeer)" Date: Tue, 9 Apr 2019 19:19:15 -0700 Subject: [PATCH 02/92] [FLINK-12167] Reset context classloader in run and getOptimizedPlan This closes #8154. --- .../flink/client/program/ClusterClient.java | 99 ++++++++++--------- 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java index 0da79974b0e332..0969ae570ad15b 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/ClusterClient.java @@ -204,19 +204,24 @@ public static String getOptimizedPlanAsJson(Optimizer compiler, PackagedProgram public static FlinkPlan getOptimizedPlan(Optimizer compiler, PackagedProgram prog, int parallelism) throws CompilerException, ProgramInvocationException { - Thread.currentThread().setContextClassLoader(prog.getUserCodeClassLoader()); - if (prog.isUsingProgramEntryPoint()) { - return getOptimizedPlan(compiler, prog.getPlanWithJars(), parallelism); - } else if (prog.isUsingInteractiveMode()) { - // temporary hack to support the optimizer plan preview - OptimizerPlanEnvironment env = new OptimizerPlanEnvironment(compiler); - if (parallelism > 0) { - env.setParallelism(parallelism); - } + final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + try { + Thread.currentThread().setContextClassLoader(prog.getUserCodeClassLoader()); + if (prog.isUsingProgramEntryPoint()) { + return getOptimizedPlan(compiler, prog.getPlanWithJars(), parallelism); + } else if (prog.isUsingInteractiveMode()) { + // temporary hack to support the optimizer plan preview + OptimizerPlanEnvironment env = new OptimizerPlanEnvironment(compiler); + if (parallelism > 0) { + env.setParallelism(parallelism); + } - return env.getOptimizedPlan(prog); - } else { - throw new RuntimeException("Couldn't determine program mode."); + return env.getOptimizedPlan(prog); + } else { + throw new RuntimeException("Couldn't determine program mode."); + } + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader); } } @@ -247,44 +252,48 @@ public static OptimizedPlan getOptimizedPlan(Optimizer 
compiler, Plan p, int par */ public JobSubmissionResult run(PackagedProgram prog, int parallelism) throws ProgramInvocationException, ProgramMissingJobException { - Thread.currentThread().setContextClassLoader(prog.getUserCodeClassLoader()); - if (prog.isUsingProgramEntryPoint()) { - - final JobWithJars jobWithJars = prog.getPlanWithJars(); - - return run(jobWithJars, parallelism, prog.getSavepointSettings()); - } - else if (prog.isUsingInteractiveMode()) { - log.info("Starting program in interactive mode (detached: {})", isDetached()); - - final List libraries = prog.getAllLibraries(); - - ContextEnvironmentFactory factory = new ContextEnvironmentFactory(this, libraries, - prog.getClasspaths(), prog.getUserCodeClassLoader(), parallelism, isDetached(), - prog.getSavepointSettings()); - ContextEnvironment.setAsContext(factory); - - try { - // invoke main method - prog.invokeInteractiveModeForExecution(); - if (lastJobExecutionResult == null && factory.getLastEnvCreated() == null) { - throw new ProgramMissingJobException("The program didn't contain a Flink job."); - } - if (isDetached()) { - // in detached mode, we execute the whole user code to extract the Flink job, afterwards we run it here - return ((DetachedEnvironment) factory.getLastEnvCreated()).finalizeExecute(); + final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + try { + Thread.currentThread().setContextClassLoader(prog.getUserCodeClassLoader()); + if (prog.isUsingProgramEntryPoint()) { + final JobWithJars jobWithJars = prog.getPlanWithJars(); + return run(jobWithJars, parallelism, prog.getSavepointSettings()); + } + else if (prog.isUsingInteractiveMode()) { + log.info("Starting program in interactive mode (detached: {})", isDetached()); + + final List libraries = prog.getAllLibraries(); + + ContextEnvironmentFactory factory = new ContextEnvironmentFactory(this, libraries, + prog.getClasspaths(), prog.getUserCodeClassLoader(), parallelism, isDetached(), + prog.getSavepointSettings()); + ContextEnvironment.setAsContext(factory); + + try { + // invoke main method + prog.invokeInteractiveModeForExecution(); + if (lastJobExecutionResult == null && factory.getLastEnvCreated() == null) { + throw new ProgramMissingJobException("The program didn't contain a Flink job."); + } + if (isDetached()) { + // in detached mode, we execute the whole user code to extract the Flink job, afterwards we run it here + return ((DetachedEnvironment) factory.getLastEnvCreated()).finalizeExecute(); + } + else { + // in blocking mode, we execute all Flink jobs contained in the user code and then return here + return this.lastJobExecutionResult; + } } - else { - // in blocking mode, we execute all Flink jobs contained in the user code and then return here - return this.lastJobExecutionResult; + finally { + ContextEnvironment.unsetContext(); } } - finally { - ContextEnvironment.unsetContext(); + else { + throw new ProgramInvocationException("PackagedProgram does not have a valid invocation mode."); } } - else { - throw new ProgramInvocationException("PackagedProgram does not have a valid invocation mode."); + finally { + Thread.currentThread().setContextClassLoader(contextClassLoader); } } From 4f558e4f225195fefe61718ef09b989f395987ba Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 24 May 2019 15:58:13 +0800 Subject: [PATCH 03/92] [FLINK-12152] Make the vcore that Application Master used configurable for Flink on YARN This closes #8438. 
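
For context, the option introduced by this patch is read through Flink's regular ConfigOption machinery. Below is a minimal usage sketch (illustrative only, not part of the patch; the class name and the value 2 are invented for the example) showing how a deployment could request more vcores for the application master once YarnConfigOptions.APP_MASTER_VCORES from the diff below exists:

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.yarn.configuration.YarnConfigOptions;

    public class AppMasterVcoresExample {
        public static void main(String[] args) {
            Configuration flinkConfiguration = new Configuration();

            // Request 2 virtual cores for the YARN application master.
            // AbstractYarnClusterDescriptor rejects values larger than the
            // maximum number of vcores available in the YARN cluster.
            flinkConfiguration.setInteger(YarnConfigOptions.APP_MASTER_VCORES, 2);

            // Equivalent flink-conf.yaml entry: yarn.appmaster.vcores: 2
            int requestedVcores = flinkConfiguration.getInteger(YarnConfigOptions.APP_MASTER_VCORES);
            System.out.println("Requested AM vcores: " + requestedVcores);
        }
    }

If the option is not set, its default of 1 preserves the previous hard-coded behavior of capability.setVirtualCores(1).
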
--- .../_includes/generated/yarn_config_configuration.html | 5 +++++ .../flink/yarn/AbstractYarnClusterDescriptor.java | 10 +++++++++- .../flink/yarn/configuration/YarnConfigOptions.java | 8 ++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/_includes/generated/yarn_config_configuration.html b/docs/_includes/generated/yarn_config_configuration.html index 40dfc09d3612b4..6615c325522f63 100644 --- a/docs/_includes/generated/yarn_config_configuration.html +++ b/docs/_includes/generated/yarn_config_configuration.html @@ -32,6 +32,11 @@ -1 The port where the application master RPC system is listening. + +
yarn.appmaster.vcores
+ 1 + The number of virtual cores (vcores) used by YARN application master. +
yarn.containers.vcores
-1 diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java b/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java index 0f244961dc2174..3135ecf0708023 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/AbstractYarnClusterDescriptor.java @@ -242,6 +242,14 @@ private void isReadyForDeployment(ClusterSpecification clusterSpecification) thr throw new YarnDeploymentException("Couldn't get cluster description, please check on the YarnConfiguration", e); } + int configuredAmVcores = flinkConfiguration.getInteger(YarnConfigOptions.APP_MASTER_VCORES); + if (configuredAmVcores > numYarnMaxVcores) { + throw new IllegalConfigurationException( + String.format("The number of requested virtual cores for application master %d" + + " exceeds the maximum number of virtual cores %d available in the Yarn Cluster.", + configuredAmVcores, numYarnMaxVcores)); + } + int configuredVcores = flinkConfiguration.getInteger(YarnConfigOptions.VCORES, clusterSpecification.getSlotsPerTaskManager()); // don't configure more than the maximum configured number of vcores if (configuredVcores > numYarnMaxVcores) { @@ -971,7 +979,7 @@ public ApplicationReport startAppMaster( // Set up resource type requirements for ApplicationMaster Resource capability = Records.newRecord(Resource.class); capability.setMemory(clusterSpecification.getMasterMemoryMB()); - capability.setVirtualCores(1); + capability.setVirtualCores(flinkConfiguration.getInteger(YarnConfigOptions.APP_MASTER_VCORES)); final String customApplicationName = customName != null ? customName : applicationName; diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/configuration/YarnConfigOptions.java b/flink-yarn/src/main/java/org/apache/flink/yarn/configuration/YarnConfigOptions.java index 0f46a572256b3e..c0b6cfebf4bf3e 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/configuration/YarnConfigOptions.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/configuration/YarnConfigOptions.java @@ -49,6 +49,14 @@ public class YarnConfigOptions { .defaultValue(-1) .withDescription("The port where the application master RPC system is listening."); + /** + * The vcores used by YARN application master. + */ + public static final ConfigOption APP_MASTER_VCORES = + key("yarn.appmaster.vcores") + .defaultValue(1) + .withDescription("The number of virtual cores (vcores) used by YARN application master."); + /** * Defines whether user-jars are included in the system class path for per-job-clusters as well as their positioning * in the path. They can be positioned at the beginning ("FIRST"), at the end ("LAST"), or be positioned based on From dac36482e2bc7b16d0c1e8dc8ae986d5cf6c1db6 Mon Sep 17 00:00:00 2001 From: Paul Lam Date: Wed, 12 Dec 2018 11:39:35 +0800 Subject: [PATCH 04/92] [FLINK-11126][YARN][security] Filter out AMRMToken in the TaskManager credentials This closes #7895. 
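
The core of this change is that the AMRMToken, which only the application master should hold, is no longer shipped with the TaskManager container credentials. The following is a condensed sketch of the filtering approach that Utils.createTaskExecutorContext applies in the diff below (the helper class and method name are illustrative, not part of the patch):

    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.security.Credentials;
    import org.apache.hadoop.security.token.Token;
    import org.apache.hadoop.security.token.TokenIdentifier;
    import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;

    final class TokenFilterSketch {

        /** Copies every token except the AMRMToken into a fresh Credentials object. */
        static Credentials withoutAmRmToken(Credentials source) {
            Credentials filtered = new Credentials();
            for (Token<? extends TokenIdentifier> token : source.getAllTokens()) {
                if (!token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
                    filtered.addToken(new Text(token.getIdentifier()), token);
                }
            }
            return filtered;
        }
    }

The secured YARN ITCase extended below then asserts that the JobManager container still carries the AMRMToken while the TaskManager container does not.
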
--- .../java/org/apache/flink/yarn/UtilsTest.java | 88 +++++++++++++++++++ .../yarn/YARNSessionFIFOSecuredITCase.java | 19 ++++ .../org/apache/flink/yarn/YarnTestBase.java | 54 +++++++++++- .../java/org/apache/flink/yarn/Utils.java | 13 ++- 4 files changed, 171 insertions(+), 3 deletions(-) diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/UtilsTest.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/UtilsTest.java index 1262096f5d2ae7..3a3144ee4b475c 100644 --- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/UtilsTest.java +++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/UtilsTest.java @@ -21,28 +21,50 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ResourceManagerOptions; +import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.clusterframework.ContaineredTaskManagerParameters; import org.apache.flink.util.TestLogger; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; import org.apache.log4j.AppenderSkeleton; import org.apache.log4j.Level; import org.apache.log4j.spi.LoggingEvent; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; import java.io.File; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + /** * Tests for various utilities. 
*/ public class UtilsTest extends TestLogger { private static final Logger LOG = LoggerFactory.getLogger(UtilsTest.class); + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @Test public void testUberjarLocator() { File dir = YarnTestBase.findFile("..", new YarnTestBase.RootDirFilenameFilter()); @@ -136,6 +158,72 @@ public void testGetEnvironmentVariablesErroneous() { Assert.assertEquals(0, res.size()); } + @Test + public void testCreateTaskExecutorCredentials() throws Exception { + File root = temporaryFolder.getRoot(); + File home = new File(root, "home"); + boolean created = home.mkdir(); + assertTrue(created); + + Configuration flinkConf = new Configuration(); + YarnConfiguration yarnConf = new YarnConfiguration(); + + Map env = new HashMap<>(); + env.put(YarnConfigKeys.ENV_APP_ID, "foo"); + env.put(YarnConfigKeys.ENV_CLIENT_HOME_DIR, home.getAbsolutePath()); + env.put(YarnConfigKeys.ENV_CLIENT_SHIP_FILES, ""); + env.put(YarnConfigKeys.ENV_FLINK_CLASSPATH, ""); + env.put(YarnConfigKeys.ENV_HADOOP_USER_NAME, "foo"); + env.put(YarnConfigKeys.FLINK_JAR_PATH, root.toURI().toString()); + env = Collections.unmodifiableMap(env); + + File credentialFile = temporaryFolder.newFile("container_tokens"); + final Text amRmTokenKind = AMRMTokenIdentifier.KIND_NAME; + final Text hdfsDelegationTokenKind = new Text("HDFS_DELEGATION_TOKEN"); + final Text service = new Text("test-service"); + Credentials amCredentials = new Credentials(); + amCredentials.addToken(amRmTokenKind, new Token<>(new byte[4], new byte[4], amRmTokenKind, service)); + amCredentials.addToken(hdfsDelegationTokenKind, new Token<>(new byte[4], new byte[4], + hdfsDelegationTokenKind, service)); + amCredentials.writeTokenStorageFile(new org.apache.hadoop.fs.Path(credentialFile.getAbsolutePath()), yarnConf); + + ContaineredTaskManagerParameters tmParams = new ContaineredTaskManagerParameters(64, + 64, 16, 1, new HashMap<>(1)); + Configuration taskManagerConf = new Configuration(); + + String workingDirectory = root.getAbsolutePath(); + Class taskManagerMainClass = YarnTaskExecutorRunner.class; + ContainerLaunchContext ctx; + + final Map originalEnv = System.getenv(); + try { + Map systemEnv = new HashMap<>(originalEnv); + systemEnv.put("HADOOP_TOKEN_FILE_LOCATION", credentialFile.getAbsolutePath()); + CommonTestUtils.setEnv(systemEnv); + ctx = Utils.createTaskExecutorContext(flinkConf, yarnConf, env, tmParams, + taskManagerConf, workingDirectory, taskManagerMainClass, LOG); + } finally { + CommonTestUtils.setEnv(originalEnv); + } + + Credentials credentials = new Credentials(); + try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(ctx.getTokens().array()))) { + credentials.readTokenStorageStream(dis); + } + Collection> tokens = credentials.getAllTokens(); + boolean hasHdfsDelegationToken = false; + boolean hasAmRmToken = false; + for (Token token : tokens) { + if (token.getKind().equals(amRmTokenKind)) { + hasAmRmToken = true; + } else if (token.getKind().equals(hdfsDelegationTokenKind)) { + hasHdfsDelegationToken = true; + } + } + assertTrue(hasHdfsDelegationToken); + assertFalse(hasAmRmToken); + } + // // --------------- Tools to test if a certain string has been logged with Log4j. 
------------- // See : http://stackoverflow.com/questions/3717402/how-to-test-w-junit-that-warning-was-logged-w-log4j diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionFIFOSecuredITCase.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionFIFOSecuredITCase.java index d9a79b60eee434..54a5532a8b7601 100644 --- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionFIFOSecuredITCase.java +++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionFIFOSecuredITCase.java @@ -26,7 +26,10 @@ import org.apache.flink.test.util.SecureTestEnvironment; import org.apache.flink.test.util.TestingSecurityContext; +import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; + import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler; import org.apache.hadoop.yarn.server.resourcemanager.scheduler.fifo.FifoScheduler; import org.hamcrest.Matchers; @@ -39,6 +42,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.concurrent.Callable; /** @@ -116,6 +120,21 @@ public void testDetachedMode() throws InterruptedException, IOException { "The JobManager and the TaskManager should both run with Kerberos.", jobManagerRunsWithKerberos && taskManagerRunsWithKerberos, Matchers.is(true)); + + final List amRMTokens = Lists.newArrayList(AMRMTokenIdentifier.KIND_NAME.toString()); + final String jobmanagerContainerId = getContainerIdByLogName("jobmanager.log"); + final String taskmanagerContainerId = getContainerIdByLogName("taskmanager.log"); + final boolean jobmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, jobmanagerContainerId); + final boolean taskmanagerWithAmRmToken = verifyTokenKindInContainerCredentials(amRMTokens, taskmanagerContainerId); + + Assert.assertThat( + "The JobManager should have AMRMToken.", + jobmanagerWithAmRmToken, + Matchers.is(true)); + Assert.assertThat( + "The TaskManager should not have AMRMToken.", + taskmanagerWithAmRmToken, + Matchers.is(false)); } /* For secure cluster testing, it is enough to run only one test and override below test methods diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java index 9aad148700a422..8af5e85e79eb72 100644 --- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java +++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java @@ -33,6 +33,9 @@ import org.apache.hadoop.fs.CommonConfigurationKeysPublic; import org.apache.hadoop.fs.FileUtil; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.service.Service; import org.apache.hadoop.yarn.api.records.ApplicationReport; import org.apache.hadoop.yarn.api.records.ContainerId; @@ -73,6 +76,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -464,9 +468,8 @@ public boolean accept(File dir, String name) { } File f = new File(dir.getAbsolutePath() + "/" + name); LOG.info("Searching in {}", f.getAbsolutePath()); - try { + try (Scanner scanner = new Scanner(f)) { Set foundSet = new 
HashSet<>(mustHave.length); - Scanner scanner = new Scanner(f); while (scanner.hasNextLine()) { final String lineFromFile = scanner.nextLine(); for (String str : mustHave) { @@ -493,6 +496,53 @@ public boolean accept(File dir, String name) { } } + public static boolean verifyTokenKindInContainerCredentials(final Collection tokens, final String containerId) + throws IOException { + File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY)); + if (!cwd.exists() || !cwd.isDirectory()) { + return false; + } + + File containerTokens = findFile(cwd.getAbsolutePath(), new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.equals(containerId + ".tokens"); + } + }); + + if (containerTokens != null) { + LOG.info("Verifying tokens in {}", containerTokens.getAbsolutePath()); + + Credentials tmCredentials = Credentials.readTokenStorageFile(containerTokens, new Configuration()); + + Collection> userTokens = tmCredentials.getAllTokens(); + Set tokenKinds = new HashSet<>(4); + for (Token token : userTokens) { + tokenKinds.add(token.getKind().toString()); + } + + return tokenKinds.containsAll(tokens); + } else { + LOG.warn("Unable to find credential file for container {}", containerId); + return false; + } + } + + public static String getContainerIdByLogName(String logName) { + File cwd = new File("target/" + YARN_CONFIGURATION.get(TEST_CLUSTER_NAME_KEY)); + File containerLog = findFile(cwd.getAbsolutePath(), new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.equals(logName); + } + }); + if (containerLog != null) { + return containerLog.getParentFile().getName(); + } else { + throw new IllegalStateException("No container has log named " + logName); + } + } + public static void sleep(int time) { try { Thread.sleep(time); diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/Utils.java b/flink-yarn/src/main/java/org/apache/flink/yarn/Utils.java index 261bc971eaa5aa..eea6b9a88a53e8 100644 --- a/flink-yarn/src/main/java/org/apache/flink/yarn/Utils.java +++ b/flink-yarn/src/main/java/org/apache/flink/yarn/Utils.java @@ -45,6 +45,7 @@ import org.apache.hadoop.yarn.api.records.LocalResourceType; import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; import org.apache.hadoop.yarn.util.ConverterUtils; import org.apache.hadoop.yarn.util.Records; import org.slf4j.Logger; @@ -567,7 +568,17 @@ static ContainerLaunchContext createTaskExecutorContext( new File(fileLocation), HadoopUtils.getHadoopConfiguration(flinkConfig)); - cred.writeTokenStorageToStream(dob); + // Filter out AMRMToken before setting the tokens to the TaskManager container context. 
+ Credentials taskManagerCred = new Credentials(); + Collection> userTokens = cred.getAllTokens(); + for (Token token : userTokens) { + if (!token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) { + final Text id = new Text(token.getIdentifier()); + taskManagerCred.addToken(id, token); + } + } + + taskManagerCred.writeTokenStorageToStream(dob); ByteBuffer securityTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); ctx.setTokens(securityTokens); } catch (Throwable t) { From f0ca1cc02ee7330febca3ed5f7a9e30141ec8ac3 Mon Sep 17 00:00:00 2001 From: Bo WANG Date: Mon, 27 May 2019 17:25:47 +0800 Subject: [PATCH 05/92] [FLINK-12414][runtime] Implement SchedulingTopology adapter --- .../DefaultSchedulingExecutionVertex.java | 80 ++++++++ .../DefaultSchedulingResultPartition.java | 106 ++++++++++ ...utionGraphToSchedulingTopologyAdapter.java | 150 ++++++++++++++ .../strategy/SchedulingResultPartition.java | 2 +- .../DefaultSchedulingExecutionVertexTest.java | 112 +++++++++++ .../DefaultSchedulingResultPartitionTest.java | 102 ++++++++++ ...nGraphToSchedulingTopologyAdapterTest.java | 187 ++++++++++++++++++ 7 files changed, 738 insertions(+), 1 deletion(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartition.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java new file mode 100644 index 00000000000000..4b13d70c97d5b9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Default implementation of {@link SchedulingExecutionVertex}. + */ +class DefaultSchedulingExecutionVertex implements SchedulingExecutionVertex { + + private final ExecutionVertexID executionVertexId; + + private final List consumedPartitions; + + private final List producedPartitions; + + private final Supplier stateSupplier; + + DefaultSchedulingExecutionVertex( + ExecutionVertexID executionVertexId, + List producedPartitions, + Supplier stateSupplier) { + this.executionVertexId = checkNotNull(executionVertexId); + this.consumedPartitions = new ArrayList<>(); + this.stateSupplier = checkNotNull(stateSupplier); + this.producedPartitions = checkNotNull(producedPartitions); + } + + @Override + public ExecutionVertexID getId() { + return executionVertexId; + } + + @Override + public ExecutionState getState() { + return stateSupplier.get(); + } + + @Override + public Collection getConsumedResultPartitions() { + return Collections.unmodifiableCollection(consumedPartitions); + } + + @Override + public Collection getProducedResultPartitions() { + return Collections.unmodifiableCollection(producedPartitions); + } + + void addConsumedPartition(X partition) { + consumedPartitions.add(partition); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartition.java new file mode 100644 index 00000000000000..45a80dcddb0f7f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartition.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.EMPTY; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.PRODUCING; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Default implementation of {@link SchedulingResultPartition}. + */ +class DefaultSchedulingResultPartition implements SchedulingResultPartition { + + private final IntermediateResultPartitionID resultPartitionId; + + private final IntermediateDataSetID intermediateDataSetId; + + private final ResultPartitionType partitionType; + + private SchedulingExecutionVertex producer; + + private final List consumers; + + DefaultSchedulingResultPartition( + IntermediateResultPartitionID partitionId, + IntermediateDataSetID intermediateDataSetId, + ResultPartitionType partitionType) { + this.resultPartitionId = checkNotNull(partitionId); + this.intermediateDataSetId = checkNotNull(intermediateDataSetId); + this.partitionType = checkNotNull(partitionType); + this.consumers = new ArrayList<>(); + } + + @Override + public IntermediateResultPartitionID getId() { + return resultPartitionId; + } + + @Override + public IntermediateDataSetID getResultId() { + return intermediateDataSetId; + } + + @Override + public ResultPartitionType getPartitionType() { + return partitionType; + } + + @Override + public ResultPartitionState getState() { + switch (producer.getState()) { + case RUNNING: + return PRODUCING; + case FINISHED: + return DONE; + default: + return EMPTY; + } + } + + @Override + public SchedulingExecutionVertex getProducer() { + return producer; + } + + @Override + public Collection getConsumers() { + return Collections.unmodifiableCollection(consumers); + } + + void addConsumer(SchedulingExecutionVertex vertex) { + consumers.add(checkNotNull(vertex)); + } + + void setProducer(SchedulingExecutionVertex vertex) { + producer = checkNotNull(vertex); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java new file mode 100644 index 00000000000000..abf94697fbeee8 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.ExecutionEdge; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; +import org.apache.flink.runtime.scheduler.strategy.SchedulingTopology; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Adapter of {@link ExecutionGraph} to {@link SchedulingTopology}. + */ +public class ExecutionGraphToSchedulingTopologyAdapter implements SchedulingTopology { + + private final Map executionVerticesById; + + private final List executionVerticesList; + + private final Map resultPartitionsById; + + public ExecutionGraphToSchedulingTopologyAdapter(ExecutionGraph graph) { + checkNotNull(graph, "execution graph can not be null"); + + this.executionVerticesById = new HashMap<>(); + this.executionVerticesList = new ArrayList<>(graph.getTotalNumberOfVertices()); + Map tmpResultPartitionsById = new HashMap<>(); + Map executionVertexMap = new HashMap<>(); + + for (ExecutionVertex vertex : graph.getAllExecutionVertices()) { + List producedPartitions = generateProducedSchedulingResultPartition(vertex.getProducedPartitions()); + + producedPartitions.forEach(partition -> tmpResultPartitionsById.put(partition.getId(), partition)); + + DefaultSchedulingExecutionVertex schedulingVertex = generateSchedulingExecutionVertex(vertex, producedPartitions); + this.executionVerticesById.put(schedulingVertex.getId(), schedulingVertex); + this.executionVerticesList.add(schedulingVertex); + executionVertexMap.put(vertex, schedulingVertex); + } + this.resultPartitionsById = tmpResultPartitionsById; + + connectVerticesToConsumedPartitions(executionVertexMap, tmpResultPartitionsById); + } + + @Override + public Iterable getVertices() { + return executionVerticesList; + } + + @Override + public Optional getVertex(ExecutionVertexID executionVertexId) { + return Optional.ofNullable(executionVerticesById.get(executionVertexId)); + } + + @Override + public Optional getResultPartition(IntermediateResultPartitionID intermediateResultPartitionId) { + return Optional.ofNullable(resultPartitionsById.get(intermediateResultPartitionId)); + } + + private static List generateProducedSchedulingResultPartition( + Map producedIntermediatePartitions) { + + List producedSchedulingPartitions = new ArrayList<>(producedIntermediatePartitions.size()); + + producedIntermediatePartitions.values().forEach( + irp -> 
producedSchedulingPartitions.add( + new DefaultSchedulingResultPartition( + irp.getPartitionId(), + irp.getIntermediateResult().getId(), + irp.getResultType()))); + + return producedSchedulingPartitions; + } + + private static DefaultSchedulingExecutionVertex generateSchedulingExecutionVertex( + ExecutionVertex vertex, + List producedPartitions) { + + DefaultSchedulingExecutionVertex schedulingVertex = new DefaultSchedulingExecutionVertex( + new ExecutionVertexID(vertex.getJobvertexId(), vertex.getParallelSubtaskIndex()), + producedPartitions, + new ExecutionStateSupplier(vertex)); + + producedPartitions.forEach(partition -> partition.setProducer(schedulingVertex)); + + return schedulingVertex; + } + + private static void connectVerticesToConsumedPartitions( + Map executionVertexMap, + Map resultPartitions) { + + for (Map.Entry mapEntry : executionVertexMap.entrySet()) { + final DefaultSchedulingExecutionVertex schedulingVertex = mapEntry.getValue(); + final ExecutionVertex executionVertex = mapEntry.getKey(); + + for (int index = 0; index < executionVertex.getNumberOfInputs(); index++) { + for (ExecutionEdge edge : executionVertex.getInputEdges(index)) { + DefaultSchedulingResultPartition partition = resultPartitions.get(edge.getSource().getPartitionId()); + schedulingVertex.addConsumedPartition(partition); + partition.addConsumer(schedulingVertex); + } + } + } + } + + private static class ExecutionStateSupplier implements Supplier { + + private final ExecutionVertex executionVertex; + + ExecutionStateSupplier(ExecutionVertex vertex) { + executionVertex = checkNotNull(vertex); + } + + @Override + public ExecutionState get() { + return executionVertex.getExecutionState(); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java index aefc5613d7ffeb..86ea8ba015a02a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java @@ -77,7 +77,7 @@ public interface SchedulingResultPartition { */ enum ResultPartitionState { /** - * Producer is not yet running. + * Producer is not yet running or in abnormal state. */ EMPTY, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java new file mode 100644 index 00000000000000..e9af7c6980c3ab --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; +import org.apache.flink.util.TestLogger; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.function.Supplier; + +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link DefaultSchedulingExecutionVertex}. + */ +public class DefaultSchedulingExecutionVertexTest extends TestLogger { + + private final TestExecutionStateSupplier stateSupplier = new TestExecutionStateSupplier(); + + private DefaultSchedulingExecutionVertex producerVertex; + + private DefaultSchedulingExecutionVertex consumerVertex; + + private IntermediateResultPartitionID intermediateResultPartitionId; + + @Before + public void setUp() throws Exception { + + intermediateResultPartitionId = new IntermediateResultPartitionID(); + + DefaultSchedulingResultPartition schedulingResultPartition = new DefaultSchedulingResultPartition( + intermediateResultPartitionId, + new IntermediateDataSetID(), + BLOCKING); + producerVertex = new DefaultSchedulingExecutionVertex( + new ExecutionVertexID(new JobVertexID(), 0), + Collections.singletonList(schedulingResultPartition), + stateSupplier); + schedulingResultPartition.setProducer(producerVertex); + consumerVertex = new DefaultSchedulingExecutionVertex( + new ExecutionVertexID(new JobVertexID(), 0), + Collections.emptyList(), + stateSupplier); + consumerVertex.addConsumedPartition(schedulingResultPartition); + } + + @Test + public void testGetExecutionState() { + for (ExecutionState state : ExecutionState.values()) { + stateSupplier.setExecutionState(state); + assertEquals(state, producerVertex.getState()); + } + } + + @Test + public void testGetProducedResultPartitions() { + IntermediateResultPartitionID partitionIds1 = producerVertex + .getProducedResultPartitions().stream().findAny().map(SchedulingResultPartition::getId) + .orElseThrow(() -> new IllegalArgumentException("can not find result partition")); + assertEquals(partitionIds1, intermediateResultPartitionId); + } + + @Test + public void testGetConsumedResultPartitions() { + IntermediateResultPartitionID partitionIds1 = consumerVertex + .getConsumedResultPartitions().stream().findAny().map(SchedulingResultPartition::getId) + .orElseThrow(() -> new IllegalArgumentException("can not find result partition")); + assertEquals(partitionIds1, intermediateResultPartitionId); + } + + /** + * A simple implementation of {@link Supplier} for testing. 
+ */ + static class TestExecutionStateSupplier implements Supplier { + + private ExecutionState executionState; + + void setExecutionState(ExecutionState state) { + executionState = state; + } + + @Override + public ExecutionState get() { + return executionState; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java new file mode 100644 index 00000000000000..d114b2ea28f07b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; +import org.apache.flink.util.TestLogger; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.function.Supplier; + +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.EMPTY; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.PRODUCING; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link DefaultSchedulingResultPartition}. 
+ */ +public class DefaultSchedulingResultPartitionTest extends TestLogger { + + private static final TestExecutionStateSupplier stateProvider = new TestExecutionStateSupplier(); + + private final IntermediateResultPartitionID resultPartitionId = new IntermediateResultPartitionID(); + private final IntermediateDataSetID intermediateResultId = new IntermediateDataSetID(); + + private DefaultSchedulingResultPartition resultPartition; + + @Before + public void setUp() { + resultPartition = new DefaultSchedulingResultPartition( + resultPartitionId, + intermediateResultId, + BLOCKING); + + DefaultSchedulingExecutionVertex producerVertex = new DefaultSchedulingExecutionVertex( + new ExecutionVertexID(new JobVertexID(), 0), + Collections.singletonList(resultPartition), + stateProvider); + resultPartition.setProducer(producerVertex); + } + + @Test + public void testGetPartitionState() { + for (ExecutionState state : ExecutionState.values()) { + stateProvider.setExecutionState(state); + SchedulingResultPartition.ResultPartitionState partitionState = resultPartition.getState(); + switch (state) { + case RUNNING: + assertEquals(PRODUCING, partitionState); + break; + case FINISHED: + assertEquals(DONE, partitionState); + break; + default: + assertEquals(EMPTY, partitionState); + break; + } + } + } + + /** + * A simple implementation of {@link Supplier} for testing. + */ + private static class TestExecutionStateSupplier implements Supplier { + + private ExecutionState executionState; + + void setExecutionState(ExecutionState state) { + executionState = state; + } + + @Override + public ExecutionState get() { + return executionState; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java new file mode 100644 index 00000000000000..0861f13d64c23a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adapter; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.executiongraph.ExecutionEdge; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.executiongraph.TestRestartStrategy; +import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; +import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition; +import org.apache.flink.runtime.scheduler.strategy.SchedulingTopology; +import org.apache.flink.util.TestLogger; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static junit.framework.TestCase.assertTrue; +import static org.apache.flink.api.common.InputDependencyConstraint.ALL; +import static org.apache.flink.api.common.InputDependencyConstraint.ANY; +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex; +import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createSimpleTestGraph; +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; +import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +/** + * Unit tests for {@link ExecutionGraphToSchedulingTopologyAdapter}. 
+ */ +public class ExecutionGraphToSchedulingTopologyAdapterTest extends TestLogger { + + private final SimpleAckingTaskManagerGateway taskManagerGateway = new SimpleAckingTaskManagerGateway(); + + private final TestRestartStrategy triggeredRestartStrategy = TestRestartStrategy.manuallyTriggered(); + + private ExecutionGraph executionGraph; + + private ExecutionGraphToSchedulingTopologyAdapter adapter; + + @Before + public void setUp() throws Exception { + JobVertex[] jobVertices = new JobVertex[2]; + int parallelism = 3; + jobVertices[0] = createNoOpVertex(parallelism); + jobVertices[1] = createNoOpVertex(parallelism); + jobVertices[1].connectNewDataSetAsInput(jobVertices[0], ALL_TO_ALL, BLOCKING); + jobVertices[0].setInputDependencyConstraint(ALL); + jobVertices[1].setInputDependencyConstraint(ANY); + executionGraph = createSimpleTestGraph( + new JobID(), + taskManagerGateway, + triggeredRestartStrategy, + jobVertices); + adapter = new ExecutionGraphToSchedulingTopologyAdapter(executionGraph); + } + + @Test + public void testConstructor() { + // implicitly tests order constraint of getVertices() + assertGraphEquals(executionGraph, adapter); + } + + @Test + public void testGetResultPartition() { + for (ExecutionVertex vertex : executionGraph.getAllExecutionVertices()) { + for (Map.Entry entry : vertex.getProducedPartitions().entrySet()) { + IntermediateResultPartition partition = entry.getValue(); + SchedulingResultPartition schedulingResultPartition = adapter.getResultPartition(entry.getKey()) + .orElseThrow(() -> new IllegalArgumentException("can not find partition " + entry.getKey())); + + assertPartitionEquals(partition, schedulingResultPartition); + } + } + } + + private static void assertGraphEquals( + ExecutionGraph originalGraph, + SchedulingTopology adaptedTopology) { + + Iterator originalVertices = originalGraph.getAllExecutionVertices().iterator(); + Iterator adaptedVertices = adaptedTopology.getVertices().iterator(); + + while (originalVertices.hasNext()) { + ExecutionVertex originalVertex = originalVertices.next(); + SchedulingExecutionVertex adaptedVertex = adaptedVertices.next(); + + assertVertexEquals(originalVertex, adaptedVertex); + + List originalConsumedPartitions = IntStream.range(0, originalVertex.getNumberOfInputs()) + .mapToObj(originalVertex::getInputEdges) + .flatMap(Arrays::stream) + .map(ExecutionEdge::getSource) + .collect(Collectors.toList()); + Collection adaptedConsumedPartitions = adaptedVertex.getConsumedResultPartitions(); + + assertPartitionsEquals(originalConsumedPartitions, adaptedConsumedPartitions); + + Collection originalProducedPartitions = originalVertex.getProducedPartitions().values(); + Collection adaptedProducedPartitions = adaptedVertex.getProducedResultPartitions(); + + assertPartitionsEquals(originalProducedPartitions, adaptedProducedPartitions); + } + + assertFalse("Number of adapted vertices exceeds number of original vertices.", adaptedVertices.hasNext()); + } + + private static void assertPartitionsEquals( + Collection originalPartitions, + Collection adaptedPartitions) { + + assertEquals(originalPartitions.size(), adaptedPartitions.size()); + + for (IntermediateResultPartition originalPartition : originalPartitions) { + SchedulingResultPartition adaptedPartition = adaptedPartitions.stream() + .filter(adapted -> adapted.getId().equals(originalPartition.getPartitionId())) + .findAny() + .orElseThrow(() -> new AssertionError("Could not find matching adapted partition for " + originalPartition)); + + 
assertPartitionEquals(originalPartition, adaptedPartition); + + List originalConsumers = originalPartition.getConsumers().stream() + .flatMap(Collection::stream) + .map(ExecutionEdge::getTarget) + .collect(Collectors.toList()); + Collection adaptedConsumers = adaptedPartition.getConsumers(); + + for (ExecutionVertex originalConsumer : originalConsumers) { + // it is sufficient to verify that some vertex exists with the correct ID here, + // since deep equality is verified later in the main loop + // this DOES rely on an implicit assumption that the vertices objects returned by the topology are + // identical to those stored in the partition + ExecutionVertexID originalId = new ExecutionVertexID(originalConsumer.getJobvertexId(), originalConsumer.getParallelSubtaskIndex()); + assertTrue(adaptedConsumers.stream().anyMatch(adaptedConsumer -> adaptedConsumer.getId().equals(originalId))); + } + } + } + + private static void assertPartitionEquals( + IntermediateResultPartition originalPartition, + SchedulingResultPartition adaptedPartition) { + + assertEquals(originalPartition.getPartitionId(), adaptedPartition.getId()); + assertEquals(originalPartition.getIntermediateResult().getId(), adaptedPartition.getResultId()); + assertEquals(originalPartition.getResultType(), adaptedPartition.getPartitionType()); + assertVertexEquals( + originalPartition.getProducer(), + adaptedPartition.getProducer()); + } + + private static void assertVertexEquals( + ExecutionVertex originalVertex, + SchedulingExecutionVertex adaptedVertex) { + assertEquals( + new ExecutionVertexID(originalVertex.getJobvertexId(), originalVertex.getParallelSubtaskIndex()), + adaptedVertex.getId()); + } +} From be7f6db3c25bec9b1f611cd5bc50602782d9a000 Mon Sep 17 00:00:00 2001 From: leesf <490081539@qq.com> Date: Mon, 27 May 2019 20:16:14 +0800 Subject: [PATCH 06/92] [FLINK-12267][runtime] Port SimpleSlotTest to new code base --- .../runtime/instance/SimpleSlotTest.java | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SimpleSlotTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SimpleSlotTest.java index 89fa90a581b9db..affe5cb6b8e606 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SimpleSlotTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/instance/SimpleSlotTest.java @@ -18,16 +18,15 @@ package org.apache.flink.runtime.instance; -import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway; +import org.apache.flink.runtime.jobmanager.slots.TestingSlotOwner; +import org.apache.flink.runtime.jobmaster.LogicalSlot; import org.apache.flink.runtime.jobmaster.TestingPayload; -import org.apache.flink.runtime.taskmanager.TaskManagerLocation; +import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation; import org.apache.flink.util.TestLogger; import org.junit.Test; -import java.net.InetAddress; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -44,7 +43,7 @@ public void testStateTransitions() { SimpleSlot slot = getSlot(); assertTrue(slot.isAlive()); - slot.releaseSlot(); + slot.releaseSlot(null); assertFalse(slot.isAlive()); assertTrue(slot.isCanceled()); assertTrue(slot.isReleased()); @@ -112,7 +111,7 @@ public void testSetExecutionVertex() { // assign to released { SimpleSlot slot = 
getSlot(); - slot.releaseSlot(); + slot.releaseSlot(null); assertFalse(slot.tryAssignPayload(payload1)); assertNull(slot.getPayload()); @@ -124,19 +123,13 @@ public void testSetExecutionVertex() { } } - public static SimpleSlot getSlot() throws Exception { - ResourceID resourceID = ResourceID.generate(); - HardwareDescription hardwareDescription = new HardwareDescription(4, 2L*1024*1024*1024, 1024*1024*1024, 512*1024*1024); - InetAddress address = InetAddress.getByName("127.0.0.1"); - TaskManagerLocation connection = new TaskManagerLocation(resourceID, address, 10001); - - Instance instance = new Instance( - new SimpleAckingTaskManagerGateway(), - connection, - new InstanceID(), - hardwareDescription, - 1); - - return instance.allocateSimpleSlot(); + public static SimpleSlot getSlot() { + final TestingSlotOwner slotOwner = new TestingSlotOwner(); + slotOwner.setReturnAllocatedSlotConsumer((LogicalSlot logicalSlot) -> ((SimpleSlot) logicalSlot).markReleased()); + return new SimpleSlot( + slotOwner, + new LocalTaskManagerLocation(), + 0, + new SimpleAckingTaskManagerGateway()); } } From 5f27cb2600027de18f3c4891f561f38dd60f1d88 Mon Sep 17 00:00:00 2001 From: Chesnay Schepler Date: Mon, 27 May 2019 11:48:15 +0200 Subject: [PATCH 07/92] [FLINK-12618][build] Rework jdk.tools exclusion Replicate jdk.tools exclusion in every module that requires them. Remove exclusion from root pom to prevent side-effects. --- .../flink-connector-filesystem/pom.xml | 26 +++++++++++ flink-connectors/flink-hbase/pom.xml | 23 ++++++++++ flink-filesystems/flink-hadoop-fs/pom.xml | 27 ++++++++++++ flink-fs-tests/pom.xml | 25 +++++++++++ flink-shaded-hadoop/pom.xml | 24 ++++++++++ flink-yarn/pom.xml | 24 ++++++++++ pom.xml | 44 ------------------- 7 files changed, 149 insertions(+), 44 deletions(-) diff --git a/flink-connectors/flink-connector-filesystem/pom.xml b/flink-connectors/flink-connector-filesystem/pom.xml index a0e88007d60961..950664b8274605 100644 --- a/flink-connectors/flink-connector-filesystem/pom.xml +++ b/flink-connectors/flink-connector-filesystem/pom.xml @@ -161,4 +161,30 @@ under the License. + + + java9 + + 9 + + + + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + test-jar + + + + jdk.tools + jdk.tools + + + + + + + diff --git a/flink-connectors/flink-hbase/pom.xml b/flink-connectors/flink-hbase/pom.xml index dd3589ab5296a3..d41bb45c30bf96 100644 --- a/flink-connectors/flink-hbase/pom.xml +++ b/flink-connectors/flink-hbase/pom.xml @@ -315,6 +315,29 @@ under the License. + + java9 + + 9 + + + + + + org.apache.hadoop + hadoop-minicluster + test + + + + jdk.tools + jdk.tools + + + + + + diff --git a/flink-filesystems/flink-hadoop-fs/pom.xml b/flink-filesystems/flink-hadoop-fs/pom.xml index 23c6e3be0dd6aa..fb42546ecec690 100644 --- a/flink-filesystems/flink-hadoop-fs/pom.xml +++ b/flink-filesystems/flink-hadoop-fs/pom.xml @@ -91,4 +91,31 @@ under the License. + + + + java9 + + 9 + + + + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + test-jar + + + + jdk.tools + jdk.tools + + + + + + + diff --git a/flink-fs-tests/pom.xml b/flink-fs-tests/pom.xml index fa3aeb81795433..4c5e3e518e3960 100644 --- a/flink-fs-tests/pom.xml +++ b/flink-fs-tests/pom.xml @@ -142,5 +142,30 @@ under the License. 
+ + + java9 + + 9 + + + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + test-jar + + + + jdk.tools + jdk.tools + + + + + + + diff --git a/flink-shaded-hadoop/pom.xml b/flink-shaded-hadoop/pom.xml index aa4ebcf3dc9c25..e480039680d64e 100644 --- a/flink-shaded-hadoop/pom.xml +++ b/flink-shaded-hadoop/pom.xml @@ -191,5 +191,29 @@ under the License. + + + java9 + + 9 + + + + + + org.apache.hadoop + hadoop-annotations + + + + jdk.tools + jdk.tools + + + + + + + diff --git a/flink-yarn/pom.xml b/flink-yarn/pom.xml index a6a5a2304f0ca4..25285dbb85a298 100644 --- a/flink-yarn/pom.xml +++ b/flink-yarn/pom.xml @@ -155,6 +155,30 @@ under the License. + + java9 + + 9 + + + + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + test-jar + + + + jdk.tools + jdk.tools + + + + + + diff --git a/pom.xml b/pom.xml index 38b3601f18a5bb..ba3fe72f4efe9a 100644 --- a/pom.xml +++ b/pom.xml @@ -741,50 +741,6 @@ under the License. 9 - - - - org.apache.hadoop - hadoop-common - ${hadoop.version} - - - - jdk.tools - jdk.tools - - - - - - org.apache.hadoop - hadoop-common - ${hadoop.version} - test-jar - - - - jdk.tools - jdk.tools - - - - - - org.apache.hadoop - hadoop-annotations - ${hadoop.version} - - - - jdk.tools - jdk.tools - - - - - - From 9d1c1d55841479cd3beffd22870714a31ab9ac46 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Wed, 22 May 2019 10:25:02 +0200 Subject: [PATCH 08/92] [FLINK-12254][table] Update cast() and TypeLiteralExpression to new type system --- .../table/expressions/ApiExpressionUtils.java | 30 +++------------ .../expressions/TypeLiteralExpression.java | 21 +++++----- .../table/expressions/ExpressionBuilder.java | 3 +- .../table/expressions/RexNodeConverter.java | 7 +++- .../flink/table/sources/TableSourceUtil.scala | 19 +++++----- .../sources/tsextractors/ExistingField.scala | 15 +++++--- .../rules/ResolveCallByArgumentsRule.java | 5 ++- .../flink/table/api/scala/expressionDsl.scala | 25 ++++++++---- .../PlannerExpressionConverter.scala | 3 +- .../PlannerExpressionParserImpl.scala | 38 +++++++++++++++---- 10 files changed, 95 insertions(+), 71 deletions(-) diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java index 1ef8db301bc881..47cfe77ed9a343 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java @@ -23,14 +23,12 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.types.DataType; import org.apache.flink.table.typeutils.TimeIntervalTypeInfo; import java.util.Arrays; import java.util.Optional; -import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.CAST; -import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIMES; - /** * Utilities for API-specific {@link Expression}s. 
*/ @@ -61,8 +59,8 @@ public static ValueLiteralExpression valueLiteral(Object value, TypeInformation< return new ValueLiteralExpression(value, type); } - public static TypeLiteralExpression typeLiteral(TypeInformation type) { - return new TypeLiteralExpression(type); + public static TypeLiteralExpression typeLiteral(DataType dataType) { + return new TypeLiteralExpression(dataType); } public static SymbolExpression symbol(TableSymbol symbol) { @@ -85,17 +83,7 @@ public static Expression toMonthInterval(Expression e, int multiplier) { // check for constant return ExpressionUtils.extractValue(e, BasicTypeInfo.INT_TYPE_INFO) .map((v) -> (Expression) valueLiteral(v * multiplier, TimeIntervalTypeInfo.INTERVAL_MONTHS)) - .orElse( - call( - CAST, - call( - TIMES, - e, - valueLiteral(multiplier) - ), - typeLiteral(TimeIntervalTypeInfo.INTERVAL_MONTHS) - ) - ); + .orElseThrow(() -> new ValidationException("Only constant intervals are supported: " + e)); } public static Expression toMilliInterval(Expression e, long multiplier) { @@ -110,15 +98,7 @@ public static Expression toMilliInterval(Expression e, long multiplier) { } else if (longInterval.isPresent()) { return longInterval.get(); } - return call( - CAST, - call( - TIMES, - e, - valueLiteral(multiplier) - ), - typeLiteral(TimeIntervalTypeInfo.INTERVAL_MONTHS) - ); + throw new ValidationException("Only constant intervals are supported:" + e); } public static Expression toRowInterval(Expression e) { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java index e7ff10cc8bbb47..f50ab006307a76 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java @@ -19,8 +19,7 @@ package org.apache.flink.table.expressions; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.table.utils.TypeStringUtils; +import org.apache.flink.table.types.DataType; import org.apache.flink.util.Preconditions; import java.util.Collections; @@ -28,19 +27,19 @@ import java.util.Objects; /** - * Expression that wraps {@link TypeInformation} as a literal. + * Expression that wraps {@link DataType} as a literal. 
*/ @PublicEvolving public final class TypeLiteralExpression implements Expression { - private final TypeInformation type; + private final DataType dataType; - public TypeLiteralExpression(TypeInformation type) { - this.type = Preconditions.checkNotNull(type); + public TypeLiteralExpression(DataType dataType) { + this.dataType = Preconditions.checkNotNull(dataType, "Data type must not be null."); } - public TypeInformation getType() { - return type; + public DataType getDataType() { + return dataType; } @Override @@ -62,16 +61,16 @@ public boolean equals(Object o) { return false; } TypeLiteralExpression that = (TypeLiteralExpression) o; - return Objects.equals(type, that.type); + return dataType.equals(that.dataType); } @Override public int hashCode() { - return Objects.hash(type); + return Objects.hash(dataType); } @Override public String toString() { - return TypeStringUtils.writeTypeInfo(type); + return dataType.toString(); } } diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/ExpressionBuilder.java index cc5333e7dcec90..3f011cbe40247b 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/ExpressionBuilder.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/ExpressionBuilder.java @@ -41,6 +41,7 @@ import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.REINTERPRET_CAST; import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIMES; import static org.apache.flink.table.expressions.InternalFunctionDefinitions.THROW_EXCEPTION; +import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType; /** * Builder for {@link Expression}s. 
@@ -130,7 +131,7 @@ public static Expression reinterpretCast(Expression child, Expression type, } public static TypeLiteralExpression typeLiteral(TypeInformation type) { - return new TypeLiteralExpression(type); + return new TypeLiteralExpression(fromLegacyInfoToDataType(type)); } public static Expression concat(Expression input1, Expression input2) { diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java index bbd3d129fc011a..034ae67ad92348 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java @@ -48,6 +48,7 @@ import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR; import static org.apache.flink.table.calcite.FlinkTypeFactory.toInternalType; import static org.apache.flink.table.type.TypeConverters.createInternalTypeFromTypeInfo; +import static org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo; import static org.apache.flink.table.typeutils.TypeCheckUtils.isString; import static org.apache.flink.table.typeutils.TypeCheckUtils.isTemporal; import static org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval; @@ -88,7 +89,8 @@ private RexNode visitScalarFunc(CallExpression call) { TypeLiteralExpression type = (TypeLiteralExpression) call.getChildren().get(1); return relBuilder.getRexBuilder().makeAbstractCast( typeFactory.createTypeFromInternalType( - createInternalTypeFromTypeInfo(type.getType()), + createInternalTypeFromTypeInfo( + fromDataTypeToLegacyInfo(type.getDataType())), child.getType().isNullable()), child); } else if (call.getFunctionDefinition().equals(BuiltInFunctionDefinitions.REINTERPRET_CAST)) { @@ -97,7 +99,8 @@ private RexNode visitScalarFunc(CallExpression call) { RexNode checkOverflow = call.getChildren().get(2).accept(this); return relBuilder.getRexBuilder().makeReinterpretCast( typeFactory.createTypeFromInternalType( - createInternalTypeFromTypeInfo(type.getType()), + createInternalTypeFromTypeInfo( + fromDataTypeToLegacyInfo(type.getDataType())), child.getType().isNullable()), child, checkOverflow); diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala index f6b22f6ed1a605..a4dc4c74417f82 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala @@ -18,15 +18,6 @@ package org.apache.flink.table.sources -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.table.`type`.TypeConverters -import org.apache.flink.table.api.{Types, ValidationException} -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.`type`.InternalType -import org.apache.flink.table.`type`.InternalTypes.{BYTE, PROCTIME_BATCH_MARKER, PROCTIME_INDICATOR, PROCTIME_STREAM_MARKER, ROWTIME_BATCH_MARKER, ROWTIME_INDICATOR, ROWTIME_STREAM_MARKER} -import org.apache.flink.table.expressions.{BuiltInFunctionDefinitions, CallExpression, 
PlannerResolvedFieldReference, ResolvedFieldReference, RexNodeConverter, TypeLiteralExpression} - import com.google.common.collect.ImmutableList import org.apache.calcite.plan.RelOptCluster import org.apache.calcite.rel.RelNode @@ -34,6 +25,14 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.logical.LogicalValues import org.apache.calcite.rex.{RexLiteral, RexNode} import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.table.`type`.InternalTypes._ +import org.apache.flink.table.`type`.{InternalType, TypeConverters} +import org.apache.flink.table.api.{Types, ValidationException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.expressions._ +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import scala.collection.JavaConversions._ @@ -277,7 +276,7 @@ object TableSourceUtil { // add cast to requested type and convert expression to RexNode val castExpression = new CallExpression( BuiltInFunctionDefinitions.CAST, - List(expression, new TypeLiteralExpression(resultType))) + List(expression, new TypeLiteralExpression(fromLegacyInfoToDataType(resultType)))) val rexExpression = castExpression.accept(new RexNodeConverter(relBuilder)) relBuilder.clear() rexExpression diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/tsextractors/ExistingField.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/tsextractors/ExistingField.scala index 1e1388e6ec5b2c..d76334c85812b2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/tsextractors/ExistingField.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/tsextractors/ExistingField.scala @@ -18,15 +18,16 @@ package org.apache.flink.table.sources.tsextractors +import java.util + import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.`type`.DecimalType import org.apache.flink.table.api.{Types, ValidationException} import org.apache.flink.table.descriptors.Rowtime import org.apache.flink.table.expressions._ -import org.apache.flink.table.`type`.DecimalType +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import org.apache.flink.table.typeutils.DecimalTypeInfo -import java.util - import scala.collection.JavaConversions._ /** @@ -80,13 +81,17 @@ final class ExistingField(val field: String) extends TimestampExtractor { ) new CallExpression( BuiltInFunctionDefinitions.CAST, - List(innerDiv, new TypeLiteralExpression(Types.SQL_TIMESTAMP))) + List( + innerDiv, + new TypeLiteralExpression(fromLegacyInfoToDataType(Types.SQL_TIMESTAMP)))) case Types.SQL_TIMESTAMP => fieldReferenceExpr case Types.STRING => new CallExpression( BuiltInFunctionDefinitions.CAST, - List(fieldReferenceExpr, new TypeLiteralExpression(Types.SQL_TIMESTAMP))) + List( + fieldReferenceExpr, + new TypeLiteralExpression(fromLegacyInfoToDataType(Types.SQL_TIMESTAMP)))) } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java index bebb68336fe1a7..7749b75b44d9ad 100644 --- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/expressions/rules/ResolveCallByArgumentsRule.java @@ -36,6 +36,7 @@ import static java.util.Arrays.asList; import static org.apache.flink.table.expressions.ApiExpressionUtils.typeLiteral; +import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType; import static org.apache.flink.table.util.JavaScalaConversionUtil.toJava; /** @@ -108,7 +109,9 @@ private Expression castIfNeeded(PlannerExpression childExpression, TypeInformati } else if (TypeCoercion.canSafelyCast(actualType, expectedType)) { return new CallExpression( BuiltInFunctionDefinitions.CAST, - asList(childExpression, typeLiteral(expectedType)) + asList( + childExpression, + typeLiteral(fromLegacyInfoToDataType(expectedType))) ); } else { throw new ValidationException(String.format("Incompatible type of argument: %s Expected: %s", diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index a38d9569cfa1b8..34d228f1797ab2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -24,10 +24,12 @@ import java.sql.{Date, Time, Timestamp} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.api.{Over, Table, ValidationException} import org.apache.flink.table.expressions.ApiExpressionUtils._ -import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.{WITH_COLUMNS, RANGE_TO, E => FDE, UUID => FDUUID, _} +import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.{RANGE_TO, WITH_COLUMNS, E => FDE, UUID => FDUUID, _} import org.apache.flink.table.expressions._ import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getResultTypeOfAggregateFunction} import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedAggregateFunction} +import org.apache.flink.table.types.DataType +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import _root_.scala.language.implicitConversions @@ -262,15 +264,22 @@ trait ImplicitExpressionOperations { def collect: Expression = call(COLLECT, expr) /** - * Converts a value to a given type. + * Converts a value to a given data type. * - * e.g. "42".cast(Types.INT) leads to 42. + * e.g. "42".cast(DataTypes.INT()) leads to 42. * * @return casted expression */ - def cast(toType: TypeInformation[_]): Expression = + def cast(toType: DataType): Expression = call(CAST, expr, typeLiteral(toType)) + /** + * @deprecated Use [[cast(DataType)]] instead. + */ + @deprecated + def cast(toType: TypeInformation[_]): Expression = + call(CAST, expr, typeLiteral(fromLegacyInfoToDataType(toType))) + /** * Specifies a name for an expression i.e. a field. * @@ -683,18 +692,20 @@ trait ImplicitExpressionOperations { /** * Parses a date string in the form "yyyy-MM-dd" to a SQL Date. 
*/ - def toDate: Expression = call(CAST, expr, typeLiteral(SqlTimeTypeInfo.DATE)) + def toDate: Expression = + call(CAST, expr, typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.DATE))) /** * Parses a time string in the form "HH:mm:ss" to a SQL Time. */ - def toTime: Expression = call(CAST, expr, typeLiteral(SqlTimeTypeInfo.TIME)) + def toTime: Expression = + call(CAST, expr, typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIME))) /** * Parses a timestamp string in the form "yyyy-MM-dd HH:mm:ss[.SSS]" to a SQL Timestamp. */ def toTimestamp: Expression = - call(CAST, expr, typeLiteral(SqlTimeTypeInfo.TIMESTAMP)) + call(CAST, expr, typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIMESTAMP))) /** * Extracts parts of a time point or time interval. Returns the part as a long value. diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala index 2eb870643d099c..01340481e57704 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala @@ -21,6 +21,7 @@ package org.apache.flink.table.expressions import org.apache.flink.table.api.{TableException, ValidationException} import org.apache.flink.table.expressions.BuiltInFunctionDefinitions._ import org.apache.flink.table.expressions.{E => PlannerE, UUID => PlannerUUID} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import _root_.scala.collection.JavaConverters._ @@ -39,7 +40,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp assert(children.size == 2) return Cast( children.head.accept(this), - children(1).asInstanceOf[TypeLiteralExpression].getType) + fromDataTypeToLegacyInfo(children(1).asInstanceOf[TypeLiteralExpression].getDataType)) case WINDOW_START => assert(children.size == 1) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionParserImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionParserImpl.scala index 3850c3fca21a6a..74c0089c7387d4 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionParserImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionParserImpl.scala @@ -23,6 +23,7 @@ import _root_.java.util.{List => JList} import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} import org.apache.flink.table.api._ import org.apache.flink.table.expressions.ApiExpressionUtils._ +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import _root_.scala.collection.JavaConversions._ import _root_.scala.language.implicitConversions @@ -260,7 +261,7 @@ object PlannerExpressionParserImpl extends JavaTokenParsers lazy val suffixCast: PackratParser[Expression] = composite ~ "." 
~ CAST ~ "(" ~ dataType ~ ")" ^^ { case e ~ _ ~ _ ~ _ ~ dt ~ _ => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(dt)) + call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(fromLegacyInfoToDataType(dt))) } lazy val suffixTrim: PackratParser[Expression] = @@ -331,17 +332,26 @@ object PlannerExpressionParserImpl extends JavaTokenParsers lazy val suffixToDate: PackratParser[Expression] = composite <~ "." ~ TO_DATE ~ opt("()") ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.DATE)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.DATE))) } lazy val suffixToTimestamp: PackratParser[Expression] = composite <~ "." ~ TO_TIMESTAMP ~ opt("()") ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.TIMESTAMP)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIMESTAMP))) } lazy val suffixToTime: PackratParser[Expression] = composite <~ "." ~ TO_TIME ~ opt("()") ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.TIME)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIME))) } lazy val suffixTimeInterval : PackratParser[Expression] = @@ -420,7 +430,10 @@ object PlannerExpressionParserImpl extends JavaTokenParsers lazy val prefixCast: PackratParser[Expression] = CAST ~ "(" ~ expression ~ "," ~ dataType ~ ")" ^^ { case _ ~ _ ~ e ~ _ ~ dt ~ _ => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(dt)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(dt))) } lazy val prefixIf: PackratParser[Expression] = @@ -500,17 +513,26 @@ object PlannerExpressionParserImpl extends JavaTokenParsers lazy val prefixToDate: PackratParser[Expression] = TO_DATE ~ "(" ~> expression <~ ")" ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.DATE)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.DATE))) } lazy val prefixToTimestamp: PackratParser[Expression] = TO_TIMESTAMP ~ "(" ~> expression <~ ")" ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.TIMESTAMP)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIMESTAMP))) } lazy val prefixToTime: PackratParser[Expression] = TO_TIME ~ "(" ~> expression <~ ")" ^^ { e => - call(BuiltInFunctionDefinitions.CAST, e, typeLiteral(SqlTimeTypeInfo.TIME)) + call( + BuiltInFunctionDefinitions.CAST, + e, + typeLiteral(fromLegacyInfoToDataType(SqlTimeTypeInfo.TIME))) } lazy val prefixDistinct: PackratParser[Expression] = From b336b6d70d861b94995ebc9f863835fb4bc7fabb Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 27 May 2019 11:11:01 +0200 Subject: [PATCH 09/92] [FLINK-12254][table] Improve documentation about the deprecated type system This closes #8510. 
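The extended @deprecated notes below all steer users from the TypeInformation-based methods to their DataType-based counterparts. A rough sketch of that migration on TableSchema (illustrative only; the field name is an assumption, not taken from the patch):

import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableSchema;

public class SchemaMigrationSketch {
	public static void main(String[] args) {
		// Old type system: TypeInformation-based builder method, now carrying the longer @deprecated note.
		TableSchema legacy = TableSchema.builder()
			.field("amount", Types.INT)
			.build();

		// New type system: the recommended DataType-based method.
		TableSchema current = TableSchema.builder()
			.field("amount", DataTypes.INT())
			.build();

		System.out.println(legacy);
		System.out.println(current);
	}
}

The same pattern applies in the expression DSL, where "42".cast(Types.INT) becomes "42".cast(DataTypes.INT()).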
--- .../apache/flink/table/api/TableSchema.java | 24 +++++++++++++++---- .../flink/table/expressions/Expression.java | 2 +- .../expressions/TypeLiteralExpression.java | 4 ++++ .../flink/table/api/scala/expressionDsl.scala | 8 +++++-- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/TableSchema.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/TableSchema.java index c6804f9a386cd9..8d6122f4d2ca05 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/TableSchema.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/TableSchema.java @@ -118,7 +118,11 @@ public DataType[] getFieldDataTypes() { } /** - * @deprecated Use {@link #getFieldDataTypes()} instead. + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #getFieldDataTypes()} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new + * type system consistently to avoid unintended behavior. See the website documentation + * for more information. */ @Deprecated public TypeInformation[] getFieldTypes() { @@ -138,7 +142,11 @@ public Optional getFieldDataType(int fieldIndex) { } /** - * @deprecated Use {@link #getFieldDataType(int)}} instead. + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #getFieldDataType(int)} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new + * type system consistently to avoid unintended behavior. See the website documentation + * for more information. */ @Deprecated public Optional> getFieldType(int fieldIndex) { @@ -159,7 +167,11 @@ public Optional getFieldDataType(String fieldName) { } /** - * @deprecated Use {@link #getFieldDataType(String)} instead. + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #getFieldDataType(String)} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new + * type system consistently to avoid unintended behavior. See the website documentation + * for more information. */ @Deprecated public Optional> getFieldType(String fieldName) { @@ -306,7 +318,11 @@ public Builder field(String name, DataType dataType) { } /** - * @deprecated Use {@link #field(String, DataType)} instead. + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #field(String, DataType)} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new + * type system consistently to avoid unintended behavior. See the website documentation + * for more information. 
*/ @Deprecated public Builder field(String name, TypeInformation typeInfo) { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/Expression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/Expression.java index 57c73d48d6706d..b15d6927ef52ab 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/Expression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/Expression.java @@ -30,7 +30,7 @@ * consists of zero, one, or more sub-expressions. Expressions might be literal values, function calls, * or field references. * - *
<p>
Expressions are part of the API. Thus, values and return types are expressed as instances of + *
<p>
Expressions are part of the API. Thus, value types and return types are expressed as instances of * {@link DataType}. */ @PublicEvolving diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java index f50ab006307a76..c60c4821af16e4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/TypeLiteralExpression.java @@ -28,6 +28,10 @@ /** * Expression that wraps {@link DataType} as a literal. + * + *
<p>
Expressing a type is primarily needed for casting operations. This expression simplifies the + * {@link Expression} design as it makes {@link CallExpression} the only expression that takes + * subexpressions. */ @PublicEvolving public final class TypeLiteralExpression implements Expression { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index 34d228f1797ab2..8d77d883c204d8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Time, Timestamp} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation} -import org.apache.flink.table.api.{Over, Table, ValidationException} +import org.apache.flink.table.api.{DataTypes, Over, Table, ValidationException} import org.apache.flink.table.expressions.ApiExpressionUtils._ import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.{RANGE_TO, WITH_COLUMNS, E => FDE, UUID => FDUUID, _} import org.apache.flink.table.expressions._ @@ -274,7 +274,11 @@ trait ImplicitExpressionOperations { call(CAST, expr, typeLiteral(toType)) /** - * @deprecated Use [[cast(DataType)]] instead. + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use [[cast(DataType)]] instead which uses the new type system + * based on [[DataTypes]]. Please make sure to use either the old or the new type + * system consistently to avoid unintended behavior. See the website documentation + * for more information. 
*/ @deprecated def cast(toType: TypeInformation[_]): Expression = From 3a5bf89384ed07431d15285ef40e751daf9d0c83 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 11 Jan 2019 17:53:40 +0800 Subject: [PATCH 10/92] [FLINK-11283] Accessing the key when processing connected keyed stream --- docs/dev/stream/operators/process_function.md | 2 +- docs/dev/stream/side_output.md | 1 + .../api/datastream/ConnectedStreams.java | 72 ++- .../functions/co/KeyedCoProcessFunction.java | 152 +++++ .../operators/co/KeyedCoProcessOperator.java | 34 +- .../co/LegacyKeyedCoProcessOperator.java | 193 ++++++ .../co/KeyedCoProcessOperatorTest.java | 77 ++- .../co/LegacyKeyedCoProcessOperatorTest.java | 576 ++++++++++++++++++ .../api/scala/ConnectedStreams.scala | 31 +- .../stream/StreamExecWindowJoin.scala | 6 +- .../DataStreamJoinToCoProcessTranslator.scala | 4 +- ...dCoProcessOperatorWithWatermarkDelay.scala | 6 +- .../runtime/harness/JoinHarnessTest.scala | 52 +- ...edCoProcessOperatorWithWatermarkDelay.java | 3 +- .../join/ProcTimeBoundedStreamJoinTest.java | 4 +- .../join/RowTimeBoundedStreamJoinTest.java | 4 +- .../streaming/runtime/SideOutputITCase.java | 101 ++- 17 files changed, 1246 insertions(+), 72 deletions(-) create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedCoProcessFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperator.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperatorTest.java diff --git a/docs/dev/stream/operators/process_function.md b/docs/dev/stream/operators/process_function.md index 0216d845d45c2e..cea5533c85863b 100644 --- a/docs/dev/stream/operators/process_function.md +++ b/docs/dev/stream/operators/process_function.md @@ -58,7 +58,7 @@ stream.keyBy(...).process(new MyProcessFunction()) ## Low-level Joins -To realize low-level operations on two inputs, applications can use `CoProcessFunction`. This +To realize low-level operations on two inputs, applications can use `CoProcessFunction` or `KeyedCoProcessFunction`. This function is bound to two different inputs and gets individual calls to `processElement1(...)` and `processElement2(...)` for records from the two different inputs. 
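For orientation, a minimal sketch of a two-input keyed function written against the new KeyedCoProcessFunction introduced by this patch; the key type, element types, stream names and key selectors below are assumptions for illustration, not part of the change:

import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction;
import org.apache.flink.util.Collector;

// Key type String, first input Integer, second input String, output String.
public class KeyAwareCoProcess extends KeyedCoProcessFunction<String, Integer, String, String> {

	@Override
	public void processElement1(Integer value, Context ctx, Collector<String> out) {
		// ctx.getCurrentKey() is the key accessor this change adds.
		out.collect("left " + ctx.getCurrentKey() + ": " + value);
	}

	@Override
	public void processElement2(String value, Context ctx, Collector<String> out) {
		out.collect("right " + ctx.getCurrentKey() + ": " + value);
	}
}

// Assumed wiring, given a DataStream<Integer> lefts and a DataStream<String> rights:
//   lefts.keyBy(i -> String.valueOf(i))
//        .connect(rights.keyBy(s -> s))
//        .process(new KeyAwareCoProcess());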
diff --git a/docs/dev/stream/side_output.md b/docs/dev/stream/side_output.md index 7ca195d7022472..faa9d312fc363d 100644 --- a/docs/dev/stream/side_output.md +++ b/docs/dev/stream/side_output.md @@ -60,6 +60,7 @@ Emitting data to a side output is possible from the following functions: - [ProcessFunction]({{ site.baseurl }}/dev/stream/operators/process_function.html) - [KeyedProcessFunction]({{ site.baseurl }}/dev/stream/operators/process_function.html#the-keyedprocessfunction) - CoProcessFunction +- KeyedCoProcessFunction - [ProcessWindowFunction]({{ site.baseurl }}/dev/stream/operators/windows.html#processwindowfunction) - ProcessAllWindowFunction diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java index 0ada54a8f4cc9b..d4a34c96fed768 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/ConnectedStreams.java @@ -28,11 +28,13 @@ import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.api.operators.co.CoProcessOperator; import org.apache.flink.streaming.api.operators.co.CoStreamFlatMap; import org.apache.flink.streaming.api.operators.co.CoStreamMap; import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator; +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; import static java.util.Objects.requireNonNull; @@ -331,7 +333,7 @@ public SingleOutputStreamOperator process( TwoInputStreamOperator operator; if ((inputStream1 instanceof KeyedStream) && (inputStream2 instanceof KeyedStream)) { - operator = new KeyedCoProcessOperator<>(inputStream1.clean(coProcessFunction)); + operator = new LegacyKeyedCoProcessOperator<>(inputStream1.clean(coProcessFunction)); } else { operator = new CoProcessOperator<>(inputStream1.clean(coProcessFunction)); } @@ -339,6 +341,74 @@ public SingleOutputStreamOperator process( return transform("Co-Process", outputType, operator); } + /** + * Applies the given {@link KeyedCoProcessFunction} on the connected input keyed streams, + * thereby creating a transformed output stream. + * + *
<p>
The function will be called for every element in the input keyed streams and can produce zero or + * more output elements. Contrary to the {@link #flatMap(CoFlatMapFunction)} function, this + * function can also query the time and set timers. When reacting to the firing of set timers + * the function can directly emit elements and/or register yet more timers. + * + * @param keyedCoProcessFunction The {@link KeyedCoProcessFunction} that is called for each element + * in the stream. + * + * @param The type of elements emitted by the {@code CoProcessFunction}. + * + * @return The transformed {@link DataStream}. + */ + @PublicEvolving + public SingleOutputStreamOperator process( + KeyedCoProcessFunction keyedCoProcessFunction) { + + TypeInformation outTypeInfo = TypeExtractor.getBinaryOperatorReturnType( + keyedCoProcessFunction, + KeyedCoProcessFunction.class, + 1, + 2, + 3, + TypeExtractor.NO_INDEX, + getType1(), + getType2(), + Utils.getCallLocationName(), + true); + + return process(keyedCoProcessFunction, outTypeInfo); + } + + /** + * Applies the given {@link KeyedCoProcessFunction} on the connected input streams, + * thereby creating a transformed output stream. + * + *
<p>
The function will be called for every element in the input streams and can produce zero + * or more output elements. Contrary to the {@link #flatMap(CoFlatMapFunction)} function, + * this function can also query the time and set timers. When reacting to the firing of set + * timers the function can directly emit elements and/or register yet more timers. + * + * @param keyedCoProcessFunction The {@link KeyedCoProcessFunction} that is called for each element + * in the stream. + * + * @param The type of elements emitted by the {@code CoProcessFunction}. + * + * @return The transformed {@link DataStream}. + */ + @Internal + public SingleOutputStreamOperator process( + KeyedCoProcessFunction keyedCoProcessFunction, + TypeInformation outputType) { + + TwoInputStreamOperator operator; + + if ((inputStream1 instanceof KeyedStream) && (inputStream2 instanceof KeyedStream)) { + operator = new KeyedCoProcessOperator<>(inputStream1.clean(keyedCoProcessFunction)); + } else { + throw new UnsupportedOperationException("KeyedCoProcessFunction can only be used " + + "when both input streams are of type KeyedStream."); + } + + return transform("Co-Keyed-Process", outputType, operator); + } + @PublicEvolving public SingleOutputStreamOperator transform(String functionName, TypeInformation outTypeInfo, diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedCoProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedCoProcessFunction.java new file mode 100644 index 00000000000000..d7efb67b2db356 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedCoProcessFunction.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions.co; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +/** + * A function that processes elements of two keyed streams and produces a single output one. + * + *
<p>
The function will be called for every element in the input streams and can produce + * zero or more output elements. Contrary to the {@link CoFlatMapFunction}, this function can also + * query the time (both event and processing) and set timers, through the provided {@link Context}. + * When reacting to the firing of set timers the function can emit yet more elements. + * + *
<p>
An example use-case for connected streams would be the application of a set of rules that change + * over time ({@code stream A}) to the elements contained in another stream (stream {@code B}). The rules + * contained in {@code stream A} can be stored in the state and wait for new elements to arrive on + * {@code stream B}. Upon reception of a new element on {@code stream B}, the function can now apply the + * previously stored rules to the element and directly emit a result, and/or register a timer that + * will trigger an action in the future. + * + * @param Type of the key. + * @param Type of the first input. + * @param Type of the second input. + * @param Output type. + */ +@PublicEvolving +public abstract class KeyedCoProcessFunction extends AbstractRichFunction { + + private static final long serialVersionUID = 1L; + + /** + * This method is called for each element in the first of the connected streams. + * + *
<p>
This function can output zero or more elements using the {@link Collector} parameter + * and also update internal state or set timers using the {@link Context} parameter. + * + * @param value The stream element + * @param ctx A {@link Context} that allows querying the timestamp of the element, + * querying the {@link TimeDomain} of the firing timer and getting a + * {@link TimerService} for registering timers and querying the time. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processElement1(IN1 value, Context ctx, Collector out) throws Exception; + + /** + * This method is called for each element in the second of the connected streams. + * + *
<p>
This function can output zero or more elements using the {@link Collector} parameter + * and also update internal state or set timers using the {@link Context} parameter. + * + * @param value The stream element + * @param ctx A {@link Context} that allows querying the timestamp of the element, + * querying the {@link TimeDomain} of the firing timer and getting a + * {@link TimerService} for registering timers and querying the time. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector to emit resulting elements to + * @throws Exception The function may throw exceptions which cause the streaming program + * to fail and go into recovery. + */ + public abstract void processElement2(IN2 value, Context ctx, Collector out) throws Exception; + + /** + * Called when a timer set using {@link TimerService} fires. + * + * @param timestamp The timestamp of the firing timer. + * @param ctx An {@link OnTimerContext} that allows querying the timestamp of the firing timer, + * querying the {@link TimeDomain} of the firing timer and getting a + * {@link TimerService} for registering timers and querying the time. + * The context is only valid during the invocation of this method, do not store it. + * @param out The collector for returning result values. + * + * @throws Exception This method may throw exceptions. Throwing an exception will cause the operation + * to fail and may trigger recovery. + */ + public void onTimer(long timestamp, OnTimerContext ctx, Collector out) throws Exception {} + + /** + * Information available in an invocation of {@link #processElement1(Object, Context, Collector)}/ + * {@link #processElement2(Object, Context, Collector)} + * or {@link #onTimer(long, OnTimerContext, Collector)}. + */ + public abstract class Context { + + /** + * Timestamp of the element currently being processed or timestamp of a firing timer. + * + *
<p>
This might be {@code null}, for example if the time characteristic of your program + * is set to {@link org.apache.flink.streaming.api.TimeCharacteristic#ProcessingTime}. + */ + public abstract Long timestamp(); + + /** + * A {@link TimerService} for querying time and registering timers. + */ + public abstract TimerService timerService(); + + /** + * Emits a record to the side output identified by the {@link OutputTag}. + * + * @param outputTag the {@code OutputTag} that identifies the side output to emit to. + * @param value The record to emit. + */ + public abstract void output(OutputTag outputTag, X value); + + /** + * Get key of the element being processed. + */ + public abstract K getCurrentKey(); + } + + /** + * Information available in an invocation of {@link #onTimer(long, OnTimerContext, Collector)}. + */ + public abstract class OnTimerContext extends Context { + /** + * The {@link TimeDomain} of the firing timer. + */ + public abstract TimeDomain timeDomain(); + + /** + * Get key of the firing timer. + */ + @Override + public abstract K getCurrentKey(); + + } + +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperator.java index c7d981cf595fc0..eee3ea0c054cf1 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperator.java @@ -23,7 +23,7 @@ import org.apache.flink.streaming.api.SimpleTimerService; import org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; -import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimer; import org.apache.flink.streaming.api.operators.InternalTimerService; @@ -38,23 +38,23 @@ /** * A {@link org.apache.flink.streaming.api.operators.StreamOperator} for executing keyed - * {@link CoProcessFunction CoProcessFunctions}. + * {@link KeyedCoProcessFunction KeyedCoProcessFunction}. 
*/ @Internal public class KeyedCoProcessOperator - extends AbstractUdfStreamOperator> + extends AbstractUdfStreamOperator> implements TwoInputStreamOperator, Triggerable { private static final long serialVersionUID = 1L; private transient TimestampedCollector collector; - private transient ContextImpl context; + private transient ContextImpl context; - private transient OnTimerContextImpl onTimerContext; + private transient OnTimerContextImpl onTimerContext; - public KeyedCoProcessOperator(CoProcessFunction coProcessFunction) { - super(coProcessFunction); + public KeyedCoProcessOperator(KeyedCoProcessFunction keyedCoProcessFunction) { + super(keyedCoProcessFunction); } @Override @@ -111,13 +111,13 @@ protected TimestampedCollector getCollector() { return collector; } - private class ContextImpl extends CoProcessFunction.Context { + private class ContextImpl extends KeyedCoProcessFunction.Context { private final TimerService timerService; private StreamRecord element; - ContextImpl(CoProcessFunction function, TimerService timerService) { + ContextImpl(KeyedCoProcessFunction function, TimerService timerService) { function.super(); this.timerService = checkNotNull(timerService); } @@ -146,17 +146,22 @@ public void output(OutputTag outputTag, X value) { output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); } + + @Override + public K getCurrentKey() { + return (K) KeyedCoProcessOperator.this.getCurrentKey(); + } } - private class OnTimerContextImpl extends CoProcessFunction.OnTimerContext { + private class OnTimerContextImpl extends KeyedCoProcessFunction.OnTimerContext { private final TimerService timerService; private TimeDomain timeDomain; - private InternalTimer timer; + private InternalTimer timer; - OnTimerContextImpl(CoProcessFunction function, TimerService timerService) { + OnTimerContextImpl(KeyedCoProcessFunction function, TimerService timerService) { function.super(); this.timerService = checkNotNull(timerService); } @@ -186,5 +191,10 @@ public TimeDomain timeDomain() { checkState(timeDomain != null); return timeDomain; } + + @Override + public K getCurrentKey() { + return timer.getKey(); + } } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperator.java new file mode 100644 index 00000000000000..a08d1a0abfd12e --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperator.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.SimpleTimerService; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimer; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * A {@link org.apache.flink.streaming.api.operators.StreamOperator} for executing keyed + * {@link CoProcessFunction CoProcessFunctions}. + * + * @deprecated Replaced by {@link KeyedCoProcessOperator} which takes {@code KeyedCoProcessFunction} + */ +@Deprecated +@Internal +public class LegacyKeyedCoProcessOperator + extends AbstractUdfStreamOperator> + implements TwoInputStreamOperator, Triggerable { + + private static final long serialVersionUID = 1L; + + private transient TimestampedCollector collector; + + private transient ContextImpl context; + + private transient OnTimerContextImpl onTimerContext; + + public LegacyKeyedCoProcessOperator(CoProcessFunction flatMapper) { + super(flatMapper); + } + + @Override + public void open() throws Exception { + super.open(); + collector = new TimestampedCollector<>(output); + + InternalTimerService internalTimerService = + getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this); + + TimerService timerService = new SimpleTimerService(internalTimerService); + + context = new ContextImpl<>(userFunction, timerService); + onTimerContext = new OnTimerContextImpl<>(userFunction, timerService); + } + + @Override + public void processElement1(StreamRecord element) throws Exception { + collector.setTimestamp(element); + context.element = element; + userFunction.processElement1(element.getValue(), context, collector); + context.element = null; + } + + @Override + public void processElement2(StreamRecord element) throws Exception { + collector.setTimestamp(element); + context.element = element; + userFunction.processElement2(element.getValue(), context, collector); + context.element = null; + } + + @Override + public void onEventTime(InternalTimer timer) throws Exception { + collector.setAbsoluteTimestamp(timer.getTimestamp()); + onTimerContext.timeDomain = TimeDomain.EVENT_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + @Override + public void onProcessingTime(InternalTimer timer) throws Exception { + collector.eraseTimestamp(); + onTimerContext.timeDomain = TimeDomain.PROCESSING_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + protected TimestampedCollector 
getCollector() { + return collector; + } + + private class ContextImpl extends CoProcessFunction.Context { + + private final TimerService timerService; + + private StreamRecord element; + + ContextImpl(CoProcessFunction function, TimerService timerService) { + function.super(); + this.timerService = checkNotNull(timerService); + } + + @Override + public Long timestamp() { + checkState(element != null); + + if (element.hasTimestamp()) { + return element.getTimestamp(); + } else { + return null; + } + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public void output(OutputTag outputTag, X value) { + if (outputTag == null) { + throw new IllegalArgumentException("OutputTag must not be null."); + } + + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + } + + private class OnTimerContextImpl extends CoProcessFunction.OnTimerContext { + + private final TimerService timerService; + + private TimeDomain timeDomain; + + private InternalTimer timer; + + OnTimerContextImpl(CoProcessFunction function, TimerService timerService) { + function.super(); + this.timerService = checkNotNull(timerService); + } + + @Override + public Long timestamp() { + checkState(timer != null); + return timer.getTimestamp(); + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public void output(OutputTag outputTag, X value) { + if (outputTag == null) { + throw new IllegalArgumentException("OutputTag must not be null."); + } + + output.collect(outputTag, new StreamRecord<>(value, timer.getTimestamp())); + } + + @Override + public TimeDomain timeDomain() { + checkState(timeDomain != null); + return timeDomain; + } + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperatorTest.java index 1034bfa9e9d61e..caac556cba77d5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/KeyedCoProcessOperatorTest.java @@ -26,7 +26,7 @@ import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; -import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; @@ -144,9 +144,9 @@ public void testEventTimeTimers() throws Exception { expectedOutput.add(new StreamRecord<>("INPUT1:17", 42L)); expectedOutput.add(new StreamRecord<>("INPUT2:18", 42L)); - expectedOutput.add(new StreamRecord<>("1777", 5L)); + expectedOutput.add(new StreamRecord<>("17:1777", 5L)); expectedOutput.add(new Watermark(5L)); - expectedOutput.add(new StreamRecord<>("1777", 6L)); + expectedOutput.add(new StreamRecord<>("18:1777", 6L)); expectedOutput.add(new Watermark(6L)); TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); @@ -334,6 +334,41 @@ public void testSnapshotAndRestore() throws Exception { testHarness.close(); } + @Test + public void testGetCurrentKeyFromContext() throws 
Exception { + KeyedCoProcessOperator operator = + new KeyedCoProcessOperator<>(new AppendCurrentKeyProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement1(new StreamRecord<>(5)); + testHarness.processElement1(new StreamRecord<>(6)); + testHarness.processElement2(new StreamRecord<>("hello")); + testHarness.processElement2(new StreamRecord<>("world")); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("5,5")); + expectedOutput.add(new StreamRecord<>("6,6")); + expectedOutput.add(new StreamRecord<>("hello,hello")); + expectedOutput.add(new StreamRecord<>("world,world")); + + TestHarnessUtil.assertOutputEquals( + "Output was not correct.", + expectedOutput, + testHarness.getOutput()); + + testHarness.close(); + } + private static class IntToStringKeySelector implements KeySelector { private static final long serialVersionUID = 1L; @@ -352,7 +387,7 @@ public T getKey(T value) throws Exception { } } - private static class WatermarkQueryingProcessFunction extends CoProcessFunction { + private static class WatermarkQueryingProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -374,7 +409,7 @@ public void onTimer( } } - private static class EventTimeTriggeringProcessFunction extends CoProcessFunction { + private static class EventTimeTriggeringProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -397,11 +432,11 @@ public void onTimer( Collector out) throws Exception { assertEquals(TimeDomain.EVENT_TIME, ctx.timeDomain()); - out.collect("" + 1777); + out.collect(ctx.getCurrentKey() + ":" + 1777); } } - private static class EventTimeTriggeringStatefulProcessFunction extends CoProcessFunction { + private static class EventTimeTriggeringStatefulProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -444,7 +479,7 @@ public void onTimer( } } - private static class ProcessingTimeTriggeringProcessFunction extends CoProcessFunction { + private static class ProcessingTimeTriggeringProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -471,7 +506,7 @@ public void onTimer( } } - private static class ProcessingTimeQueryingProcessFunction extends CoProcessFunction { + private static class ProcessingTimeQueryingProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -493,7 +528,7 @@ public void onTimer( } } - private static class ProcessingTimeTriggeringStatefulProcessFunction extends CoProcessFunction { + private static class ProcessingTimeTriggeringStatefulProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -536,7 +571,7 @@ public void onTimer( } } - private static class BothTriggeringProcessFunction extends CoProcessFunction { + private static class BothTriggeringProcessFunction extends KeyedCoProcessFunction { private static final long serialVersionUID = 1L; @@ -566,4 +601,24 @@ public void onTimer( } } } + + private static class AppendCurrentKeyProcessFunction extends KeyedCoProcessFunction { + + @Override + public void processElement1( + Integer value, + Context ctx, + Collector out) throws Exception { + 
out.collect(value + "," + ctx.getCurrentKey()); + } + + @Override + public void processElement2( + String value, + Context ctx, + Collector out) throws Exception { + out.collect(value + "," + ctx.getCurrentKey()); + } + } + } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperatorTest.java new file mode 100644 index 00000000000000..3697c571d68bd4 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/LegacyKeyedCoProcessOperatorTest.java @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *
+ * http://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness; +import org.apache.flink.util.Collector; +import org.apache.flink.util.TestLogger; + +import org.junit.Test; + +import java.io.IOException; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.junit.Assert.assertEquals; + +/** + * Tests {@link LegacyKeyedCoProcessOperator}. + */ +public class LegacyKeyedCoProcessOperatorTest extends TestLogger { + + @Test + public void testTimestampAndWatermarkQuerying() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new WatermarkQueryingProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processWatermark1(new Watermark(17)); + testHarness.processWatermark2(new Watermark(17)); + testHarness.processElement1(new StreamRecord<>(5, 12L)); + + testHarness.processWatermark1(new Watermark(42)); + testHarness.processWatermark2(new Watermark(42)); + testHarness.processElement2(new StreamRecord<>("6", 13L)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new Watermark(17L)); + expectedOutput.add(new StreamRecord<>("5WM:17 TS:12", 12L)); + expectedOutput.add(new Watermark(42L)); + expectedOutput.add(new StreamRecord<>("6WM:42 TS:13", 13L)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + @Test + public void testTimestampAndProcessingTimeQuerying() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new ProcessingTimeQueryingProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.setProcessingTime(17); + testHarness.processElement1(new StreamRecord<>(5)); + + testHarness.setProcessingTime(42); + testHarness.processElement2(new StreamRecord<>("6")); + + ConcurrentLinkedQueue expectedOutput 
= new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("5PT:17 TS:null")); + expectedOutput.add(new StreamRecord<>("6PT:42 TS:null")); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + @Test + public void testEventTimeTimers() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new EventTimeTriggeringProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement1(new StreamRecord<>(17, 42L)); + testHarness.processElement2(new StreamRecord<>("18", 42L)); + + testHarness.processWatermark1(new Watermark(5)); + testHarness.processWatermark2(new Watermark(5)); + + testHarness.processWatermark1(new Watermark(6)); + testHarness.processWatermark2(new Watermark(6)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("INPUT1:17", 42L)); + expectedOutput.add(new StreamRecord<>("INPUT2:18", 42L)); + expectedOutput.add(new StreamRecord<>("1777", 5L)); + expectedOutput.add(new Watermark(5L)); + expectedOutput.add(new StreamRecord<>("1777", 6L)); + expectedOutput.add(new Watermark(6L)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + @Test + public void testProcessingTimeTimers() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new ProcessingTimeTriggeringProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement1(new StreamRecord<>(17)); + testHarness.processElement2(new StreamRecord<>("18")); + + testHarness.setProcessingTime(5); + testHarness.setProcessingTime(6); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("INPUT1:17")); + expectedOutput.add(new StreamRecord<>("INPUT2:18")); + expectedOutput.add(new StreamRecord<>("1777")); + expectedOutput.add(new StreamRecord<>("1777")); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + /** + * Verifies that we don't have leakage between different keys. 
+ */ + @Test + public void testEventTimeTimerWithState() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new EventTimeTriggeringStatefulProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processWatermark1(new Watermark(1)); + testHarness.processWatermark2(new Watermark(1)); + testHarness.processElement1(new StreamRecord<>(17, 0L)); // should set timer for 6 + testHarness.processElement1(new StreamRecord<>(13, 0L)); // should set timer for 6 + + testHarness.processWatermark1(new Watermark(2)); + testHarness.processWatermark2(new Watermark(2)); + testHarness.processElement1(new StreamRecord<>(13, 1L)); // should delete timer + testHarness.processElement2(new StreamRecord<>("42", 1L)); // should set timer for 7 + + testHarness.processWatermark1(new Watermark(6)); + testHarness.processWatermark2(new Watermark(6)); + + testHarness.processWatermark1(new Watermark(7)); + testHarness.processWatermark2(new Watermark(7)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new Watermark(1L)); + expectedOutput.add(new StreamRecord<>("INPUT1:17", 0L)); + expectedOutput.add(new StreamRecord<>("INPUT1:13", 0L)); + expectedOutput.add(new Watermark(2L)); + expectedOutput.add(new StreamRecord<>("INPUT2:42", 1L)); + expectedOutput.add(new StreamRecord<>("STATE:17", 6L)); + expectedOutput.add(new Watermark(6L)); + expectedOutput.add(new StreamRecord<>("STATE:42", 7L)); + expectedOutput.add(new Watermark(7L)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + /** + * Verifies that we don't have leakage between different keys. 
+ */ + @Test + public void testProcessingTimeTimerWithState() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new ProcessingTimeTriggeringStatefulProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.setProcessingTime(1); + testHarness.processElement1(new StreamRecord<>(17)); // should set timer for 6 + testHarness.processElement1(new StreamRecord<>(13)); // should set timer for 6 + + testHarness.setProcessingTime(2); + testHarness.processElement1(new StreamRecord<>(13)); // should delete timer again + testHarness.processElement2(new StreamRecord<>("42")); // should set timer for 7 + + testHarness.setProcessingTime(6); + testHarness.setProcessingTime(7); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("INPUT1:17")); + expectedOutput.add(new StreamRecord<>("INPUT1:13")); + expectedOutput.add(new StreamRecord<>("INPUT2:42")); + expectedOutput.add(new StreamRecord<>("STATE:17")); + expectedOutput.add(new StreamRecord<>("STATE:42")); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + @Test + public void testSnapshotAndRestore() throws Exception { + + LegacyKeyedCoProcessOperator operator = + new LegacyKeyedCoProcessOperator<>(new BothTriggeringProcessFunction()); + + TwoInputStreamOperatorTestHarness testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement1(new StreamRecord<>(5, 12L)); + testHarness.processElement2(new StreamRecord<>("5", 12L)); + + // snapshot and restore from scratch + OperatorSubtaskState snapshot = testHarness.snapshot(0, 0); + + testHarness.close(); + + operator = new LegacyKeyedCoProcessOperator<>(new BothTriggeringProcessFunction()); + + testHarness = new KeyedTwoInputStreamOperatorTestHarness<>( + operator, + new IntToStringKeySelector<>(), + new IdentityKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + + testHarness.setup(); + testHarness.initializeState(snapshot); + testHarness.open(); + + testHarness.setProcessingTime(5); + testHarness.processWatermark1(new Watermark(6)); + testHarness.processWatermark2(new Watermark(6)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + expectedOutput.add(new StreamRecord<>("PROC:1777")); + expectedOutput.add(new StreamRecord<>("EVENT:1777", 6L)); + expectedOutput.add(new Watermark(6)); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + + testHarness.close(); + } + + /** + * A key selector which convert a integer key to string. + * + */ + private static class IntToStringKeySelector implements KeySelector { + private static final long serialVersionUID = 1L; + + @Override + public String getKey(Integer value) throws Exception { + return "" + value; + } + } + + /** + * A identity key selector. 
+ */ + private static class IdentityKeySelector implements KeySelector { + private static final long serialVersionUID = 1L; + + @Override + public T getKey(T value) throws Exception { + return value; + } + } + + private static class WatermarkQueryingProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + out.collect(value + "WM:" + ctx.timerService().currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + out.collect(value + "WM:" + ctx.timerService().currentWatermark() + " TS:" + ctx.timestamp()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + } + } + + private static class EventTimeTriggeringProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + out.collect("INPUT1:" + value); + ctx.timerService().registerEventTimeTimer(5); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + out.collect("INPUT2:" + value); + ctx.timerService().registerEventTimeTimer(6); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + + assertEquals(TimeDomain.EVENT_TIME, ctx.timeDomain()); + out.collect("" + 1777); + } + } + + private static class EventTimeTriggeringStatefulProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + private final ValueStateDescriptor state = + new ValueStateDescriptor<>("seen-element", StringSerializer.INSTANCE); + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + handleValue(value, out, ctx.timerService(), 1); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + handleValue(value, out, ctx.timerService(), 2); + } + + private void handleValue( + Object value, + Collector out, + TimerService timerService, + int channel) throws IOException { + final ValueState state = getRuntimeContext().getState(this.state); + if (state.value() == null) { + out.collect("INPUT" + channel + ":" + value); + state.update(String.valueOf(value)); + timerService.registerEventTimeTimer(timerService.currentWatermark() + 5); + } else { + state.clear(); + timerService.deleteEventTimeTimer(timerService.currentWatermark() + 4); + } + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + assertEquals(TimeDomain.EVENT_TIME, ctx.timeDomain()); + out.collect("STATE:" + getRuntimeContext().getState(state).value()); + } + } + + private static class ProcessingTimeTriggeringProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + out.collect("INPUT1:" + value); + ctx.timerService().registerProcessingTimeTimer(5); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + out.collect("INPUT2:" + value); + ctx.timerService().registerProcessingTimeTimer(6); + } + + @Override + public void onTimer( + long timestamp, + 
OnTimerContext ctx, + Collector out) throws Exception { + + assertEquals(TimeDomain.PROCESSING_TIME, ctx.timeDomain()); + out.collect("" + 1777); + } + } + + private static class ProcessingTimeQueryingProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + out.collect(value + "PT:" + ctx.timerService().currentProcessingTime() + " TS:" + ctx.timestamp()); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + out.collect(value + "PT:" + ctx.timerService().currentProcessingTime() + " TS:" + ctx.timestamp()); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + } + } + + private static class ProcessingTimeTriggeringStatefulProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + private final ValueStateDescriptor state = + new ValueStateDescriptor<>("seen-element", StringSerializer.INSTANCE); + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + handleValue(value, out, ctx.timerService(), 1); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + handleValue(value, out, ctx.timerService(), 2); + } + + private void handleValue( + Object value, + Collector out, + TimerService timerService, + int channel) throws IOException { + final ValueState state = getRuntimeContext().getState(this.state); + if (state.value() == null) { + out.collect("INPUT" + channel + ":" + value); + state.update(String.valueOf(value)); + timerService.registerProcessingTimeTimer(timerService.currentProcessingTime() + 5); + } else { + state.clear(); + timerService.deleteProcessingTimeTimer(timerService.currentProcessingTime() + 4); + } + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + assertEquals(TimeDomain.PROCESSING_TIME, ctx.timeDomain()); + out.collect("STATE:" + getRuntimeContext().getState(state).value()); + } + } + + private static class BothTriggeringProcessFunction extends CoProcessFunction { + + private static final long serialVersionUID = 1L; + + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + ctx.timerService().registerProcessingTimeTimer(3); + ctx.timerService().registerEventTimeTimer(6); + ctx.timerService().deleteProcessingTimeTimer(3); + } + + @Override + public void processElement2(String value, Context ctx, Collector out) throws Exception { + ctx.timerService().registerEventTimeTimer(4); + ctx.timerService().registerProcessingTimeTimer(5); + ctx.timerService().deleteEventTimeTimer(4); + } + + @Override + public void onTimer( + long timestamp, + OnTimerContext ctx, + Collector out) throws Exception { + if (TimeDomain.EVENT_TIME.equals(ctx.timeDomain())) { + out.collect("EVENT:1777"); + } else { + out.collect("PROC:1777"); + } + } + } +} diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala index cebda5ddd8a3af..68514b77b46cbd 100644 --- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala +++ 
b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/ConnectedStreams.scala @@ -23,7 +23,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.KeySelector import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.streaming.api.datastream.{ConnectedStreams => JavaCStream, DataStream => JavaStream} -import org.apache.flink.streaming.api.functions.co.{CoFlatMapFunction, CoMapFunction, CoProcessFunction} +import org.apache.flink.streaming.api.functions.co._ import org.apache.flink.streaming.api.operators.TwoInputStreamOperator import org.apache.flink.util.Collector @@ -109,10 +109,6 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { * this function can also query the time and set timers. When reacting to the firing of set * timers the function can directly emit elements and/or register yet more timers. * - * A [[RichCoProcessFunction]] - * can be used to gain access to features provided by the - * [[org.apache.flink.api.common.functions.RichFunction]] interface. - * * @param coProcessFunction The [[CoProcessFunction]] that is called for each element * in the stream. * @return The transformed [[DataStream]]. @@ -130,6 +126,31 @@ class ConnectedStreams[IN1, IN2](javaStream: JavaCStream[IN1, IN2]) { asScalaStream(javaStream.process(coProcessFunction, outType)) } + /** + * Applies the given [[KeyedCoProcessFunction]] on the connected input keyed streams, + * thereby creating a transformed output stream. + * + * The function will be called for every element in the input keyed streams and can produce + * zero or more output elements. Contrary to the [[flatMap(CoFlatMapFunction)]] function, this + * function can also query the time and set timers. When reacting to the firing of set timers + * the function can directly emit elements and/or register yet more timers. + * + * @param keyedCoProcessFunction The [[KeyedCoProcessFunction]] that is called for each element + * in the stream. + * @return The transformed [[DataStream]]. + */ + @PublicEvolving + def process[K, R: TypeInformation]( + keyedCoProcessFunction: KeyedCoProcessFunction[K, IN1, IN2, R]) : DataStream[R] = { + if (keyedCoProcessFunction == null) { + throw new NullPointerException("KeyedCoProcessFunction function must not be null.") + } + + val outType : TypeInformation[R] = implicitly[TypeInformation[R]] + + asScalaStream(javaStream.process(keyedCoProcessFunction, outType)) + } + /** * Applies a CoFlatMap transformation on these connected streams. 
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala index 4da273f3a0d0fa..25f37a7cc63f47 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecWindowJoin.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.plan.nodes.physical.stream import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, MapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator import org.apache.flink.streaming.api.operators.{StreamFlatMap, StreamMap, TwoInputStreamOperator} import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation, TwoInputTransformation, UnionTransformation} import org.apache.flink.table.api.{StreamTableEnvironment, TableException} @@ -33,13 +33,11 @@ import org.apache.flink.table.plan.util.{JoinTypeUtil, KeySelectorUtil, RelExpla import org.apache.flink.table.runtime.join.{FlinkJoinType, KeyedCoProcessOperatorWithWatermarkDelay, OuterJoinPaddingUtil, ProcTimeBoundedStreamJoin, RowTimeBoundedStreamJoin} import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.util.Collector - import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{JoinInfo, JoinRelType} import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode - import java.util import scala.collection.JavaConversions._ @@ -293,7 +291,7 @@ class StreamExecWindowJoin( leftPlan, rightPlan, "Co-Process", - new KeyedCoProcessOperator(procJoinFunc). + new LegacyKeyedCoProcessOperator(procJoinFunc). 
asInstanceOf[TwoInputStreamOperator[BaseRow,BaseRow,BaseRow]], returnTypeInfo, leftPlan.getParallelism diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala index e8f9ff4efdb700..846e452e95b218 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala @@ -23,7 +23,7 @@ import org.apache.calcite.rex.{RexBuilder, RexNode} import org.apache.flink.api.common.functions.FlatJoinFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.streaming.api.operators.TwoInputStreamOperator -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator import org.apache.flink.table.api.{StreamQueryConfig, TableConfig} import org.apache.flink.table.codegen.{FunctionCodeGenerator, GeneratedFunction} import org.apache.flink.table.plan.schema.RowSchema @@ -146,6 +146,6 @@ class DataStreamJoinToCoProcessTranslator( genFunction.code, queryConfig) } - new KeyedCoProcessOperator(joinFunction) + new LegacyKeyedCoProcessOperator(joinFunction) } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala index f25de256e96ac7..72aec4824961bf 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/operators/KeyedCoProcessOperatorWithWatermarkDelay.scala @@ -18,16 +18,16 @@ package org.apache.flink.table.runtime.operators import org.apache.flink.streaming.api.functions.co.CoProcessFunction -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator import org.apache.flink.streaming.api.watermark.Watermark /** - * A [[KeyedCoProcessOperator]] that supports holding back watermarks with a static delay. + * A [[LegacyKeyedCoProcessOperator]] that supports holding back watermarks with a static delay. 
*/ class KeyedCoProcessOperatorWithWatermarkDelay[KEY, IN1, IN2, OUT]( private val flatMapper: CoProcessFunction[IN1, IN2, OUT], private val watermarkDelay: Long = 0L) - extends KeyedCoProcessOperator[KEY, IN1, IN2, OUT](flatMapper) { + extends LegacyKeyedCoProcessOperator[KEY, IN1, IN2, OUT](flatMapper) { /** emits watermark without delay */ def emitWithoutDelay(mark: Watermark): Unit = output.emitWatermark(mark) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala index 4619c759c31afc..533eb05721d020 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/JoinHarnessTest.scala @@ -23,7 +23,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import org.apache.flink.api.common.time.Time import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.java.operators.join.JoinType -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness @@ -150,8 +150,8 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new ProcTimeBoundedStreamJoin( JoinType.INNER, -10, 20, rowType, rowType, "TestJoinFunction", funcCode) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -240,8 +240,8 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new ProcTimeBoundedStreamJoin( JoinType.INNER, -10, -5, rowType, rowType, "TestJoinFunction", funcCode) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -322,7 +322,7 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new RowTimeBoundedStreamJoin( JoinType.INNER, -10, 20, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) - val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + val operator: LegacyKeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow]( joinProcessFunc, joinProcessFunc.getMaxOutputDelay) @@ -417,7 +417,7 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new RowTimeBoundedStreamJoin( JoinType.INNER, -10, -7, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) - val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + 
val operator: LegacyKeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow]( joinProcessFunc, joinProcessFunc.getMaxOutputDelay) @@ -495,7 +495,7 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new RowTimeBoundedStreamJoin( JoinType.LEFT_OUTER, -5, 9, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) - val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + val operator: LegacyKeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow]( joinProcessFunc, joinProcessFunc.getMaxOutputDelay) @@ -605,7 +605,7 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new RowTimeBoundedStreamJoin( JoinType.RIGHT_OUTER, -5, 9, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) - val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + val operator: LegacyKeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow]( joinProcessFunc, joinProcessFunc.getMaxOutputDelay) @@ -714,7 +714,7 @@ class JoinHarnessTest extends HarnessTestBase { val joinProcessFunc = new RowTimeBoundedStreamJoin( JoinType.FULL_OUTER, -5, 9, 0, rowType, rowType, "TestJoinFunction", funcCode, 0, 0) - val operator: KeyedCoProcessOperator[String, CRow, CRow, CRow] = + val operator: LegacyKeyedCoProcessOperator[String, CRow, CRow, CRow] = new KeyedCoProcessOperatorWithWatermarkDelay[String, CRow, CRow, CRow]( joinProcessFunc, joinProcessFunc.getMaxOutputDelay) @@ -834,8 +834,8 @@ class JoinHarnessTest extends HarnessTestBase { funcCode, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -925,8 +925,8 @@ class JoinHarnessTest extends HarnessTestBase { funcCode, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1022,8 +1022,8 @@ class JoinHarnessTest extends HarnessTestBase { true, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1134,8 +1134,8 @@ class JoinHarnessTest extends HarnessTestBase { true, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, 
CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1280,8 +1280,8 @@ class JoinHarnessTest extends HarnessTestBase { false, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1392,8 +1392,8 @@ class JoinHarnessTest extends HarnessTestBase { false, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1536,8 +1536,8 @@ class JoinHarnessTest extends HarnessTestBase { funcCode, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, @@ -1700,8 +1700,8 @@ class JoinHarnessTest extends HarnessTestBase { funcCodeWithNonEqualPred2, queryConfig) - val operator: KeyedCoProcessOperator[Integer, CRow, CRow, CRow] = - new KeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) + val operator: LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow] = + new LegacyKeyedCoProcessOperator[Integer, CRow, CRow, CRow](joinProcessFunc) val testHarness: KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow] = new KeyedTwoInputStreamOperatorTestHarness[Integer, CRow, CRow, CRow]( operator, diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/KeyedCoProcessOperatorWithWatermarkDelay.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/KeyedCoProcessOperatorWithWatermarkDelay.java index f76c27a987d98c..a4f4f4f6d814f3 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/KeyedCoProcessOperatorWithWatermarkDelay.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/KeyedCoProcessOperatorWithWatermarkDelay.java @@ -20,6 +20,7 @@ import org.apache.flink.streaming.api.functions.co.CoProcessFunction; import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator; +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.util.Preconditions; @@ -30,7 +31,7 @@ * A {@link KeyedCoProcessOperator} that supports holding back watermarks with a static delay. 
*/ public class KeyedCoProcessOperatorWithWatermarkDelay - extends KeyedCoProcessOperator { + extends LegacyKeyedCoProcessOperator { private static final long serialVersionUID = -7435774708099223442L; diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamJoinTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamJoinTest.java index e854be225a9fa7..189bdb4e2505ec 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamJoinTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/ProcTimeBoundedStreamJoinTest.java @@ -19,7 +19,7 @@ package org.apache.flink.table.runtime.join; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator; +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator; import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; import org.apache.flink.table.dataformat.BaseRow; import org.apache.flink.table.runtime.util.BinaryRowKeySelector; @@ -169,7 +169,7 @@ public void testProcTimeInnerJoinWithNegativeBounds() throws Exception { private KeyedTwoInputStreamOperatorTestHarness createTestHarness( ProcTimeBoundedStreamJoin windowJoinFunc) throws Exception { - KeyedCoProcessOperator operator = new KeyedCoProcessOperator<>( + LegacyKeyedCoProcessOperator operator = new LegacyKeyedCoProcessOperator<>( windowJoinFunc); KeyedTwoInputStreamOperatorTestHarness testHarness = new KeyedTwoInputStreamOperatorTestHarness<>(operator, keySelector, keySelector, keyType); diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/RowTimeBoundedStreamJoinTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/RowTimeBoundedStreamJoinTest.java index ef248e9439501f..dadce5e21dcc6c 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/RowTimeBoundedStreamJoinTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/RowTimeBoundedStreamJoinTest.java @@ -19,7 +19,7 @@ package org.apache.flink.table.runtime.join; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.streaming.api.operators.co.KeyedCoProcessOperator; +import org.apache.flink.streaming.api.operators.co.LegacyKeyedCoProcessOperator; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; import org.apache.flink.table.dataformat.BaseRow; @@ -377,7 +377,7 @@ public void testRowTimeFullOuterJoin() throws Exception { private KeyedTwoInputStreamOperatorTestHarness createTestHarness( RowTimeBoundedStreamJoin windowJoinFunc) throws Exception { - KeyedCoProcessOperator operator = new KeyedCoProcessOperatorWithWatermarkDelay<>( + LegacyKeyedCoProcessOperator operator = new KeyedCoProcessOperatorWithWatermarkDelay<>( windowJoinFunc, windowJoinFunc.getMaxOutputDelay()); KeyedTwoInputStreamOperatorTestHarness testHarness = new KeyedTwoInputStreamOperatorTestHarness<>(operator, keySelector, keySelector, keyType); diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/SideOutputITCase.java 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/SideOutputITCase.java index 29f2c8c16b8227..13876ed4f39e22 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/SideOutputITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/SideOutputITCase.java @@ -28,6 +28,7 @@ import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; import org.apache.flink.streaming.api.functions.ProcessFunction; import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction; @@ -501,7 +502,7 @@ public void processElement( * Test keyed CoProcessFunction side output. */ @Test - public void testKeyedCoProcessFunctionSideOutput() throws Exception { + public void testLegacyKeyedCoProcessFunctionSideOutput() throws Exception { final OutputTag sideOutputTag = new OutputTag("side"){}; TestListResultSink sideOutputResultSink = new TestListResultSink<>(); @@ -542,11 +543,56 @@ public void processElement2(Integer value, Context ctx, Collector out) assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult()); } + /** + * Test keyed KeyedCoProcessFunction side output. + */ + @Test + public void testKeyedCoProcessFunctionSideOutput() throws Exception { + final OutputTag sideOutputTag = new OutputTag("side"){}; + + TestListResultSink sideOutputResultSink = new TestListResultSink<>(); + TestListResultSink resultSink = new TestListResultSink<>(); + + StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment(); + see.setParallelism(3); + + DataStream ds1 = see.fromCollection(elements); + DataStream ds2 = see.fromCollection(elements); + + SingleOutputStreamOperator passThroughtStream = ds1 + .keyBy(i -> i) + .connect(ds2.keyBy(i -> i)) + .process(new KeyedCoProcessFunction() { + @Override + public void processElement1(Integer value, Context ctx, Collector out) throws Exception { + if (value < 3) { + out.collect(value); + ctx.output(sideOutputTag, "sideout1-" + ctx.getCurrentKey() + "-" + String.valueOf(value)); + } + } + + @Override + public void processElement2(Integer value, Context ctx, Collector out) throws Exception { + if (value >= 3) { + out.collect(value); + ctx.output(sideOutputTag, "sideout2-" + ctx.getCurrentKey() + "-" + String.valueOf(value)); + } + } + }); + + passThroughtStream.getSideOutput(sideOutputTag).addSink(sideOutputResultSink); + passThroughtStream.addSink(resultSink); + see.execute(); + + assertEquals(Arrays.asList("sideout1-1-1", "sideout1-2-2", "sideout2-3-3", "sideout2-4-4", "sideout2-5-5"), sideOutputResultSink.getSortedResult()); + assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult()); + } + /** * Test keyed CoProcessFunction side output with multiple consumers. 
*/ @Test - public void testKeyedCoProcessFunctionSideOutputWithMultipleConsumers() throws Exception { + public void testLegacyKeyedCoProcessFunctionSideOutputWithMultipleConsumers() throws Exception { final OutputTag sideOutputTag1 = new OutputTag("side1"){}; final OutputTag sideOutputTag2 = new OutputTag("side2"){}; @@ -591,6 +637,57 @@ public void processElement2(Integer value, Context ctx, Collector out) assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult()); } + /** + * Test keyed KeyedCoProcessFunction side output with multiple consumers. + */ + @Test + public void testKeyedCoProcessFunctionSideOutputWithMultipleConsumers() throws Exception { + final OutputTag sideOutputTag1 = new OutputTag("side1"){}; + final OutputTag sideOutputTag2 = new OutputTag("side2"){}; + + TestListResultSink sideOutputResultSink1 = new TestListResultSink<>(); + TestListResultSink sideOutputResultSink2 = new TestListResultSink<>(); + TestListResultSink resultSink = new TestListResultSink<>(); + + StreamExecutionEnvironment see = StreamExecutionEnvironment.getExecutionEnvironment(); + see.setParallelism(3); + + DataStream ds1 = see.fromCollection(elements); + DataStream ds2 = see.fromCollection(elements); + + SingleOutputStreamOperator passThroughtStream = ds1 + .keyBy(i -> i) + .connect(ds2.keyBy(i -> i)) + .process(new KeyedCoProcessFunction() { + @Override + public void processElement1(Integer value, Context ctx, Collector out) + throws Exception { + if (value < 4) { + out.collect(value); + ctx.output(sideOutputTag1, "sideout1-" + ctx.getCurrentKey() + "-" + String.valueOf(value)); + } + } + + @Override + public void processElement2(Integer value, Context ctx, Collector out) + throws Exception { + if (value >= 4) { + out.collect(value); + ctx.output(sideOutputTag2, "sideout2-" + ctx.getCurrentKey() + "-" + String.valueOf(value)); + } + } + }); + + passThroughtStream.getSideOutput(sideOutputTag1).addSink(sideOutputResultSink1); + passThroughtStream.getSideOutput(sideOutputTag2).addSink(sideOutputResultSink2); + passThroughtStream.addSink(resultSink); + see.execute(); + + assertEquals(Arrays.asList("sideout1-1-1", "sideout1-2-2", "sideout1-3-3"), sideOutputResultSink1.getSortedResult()); + assertEquals(Arrays.asList("sideout2-4-4", "sideout2-5-5"), sideOutputResultSink2.getSortedResult()); + assertEquals(Arrays.asList(1, 2, 3, 4, 5), resultSink.getSortedResult()); + } + /** * Test ProcessFunction side outputs with wrong {@code OutputTag}. 
*/ From ec61cf21276fb6eea599b40ea999613ffeef48ef Mon Sep 17 00:00:00 2001 From: Huang Xingbo Date: Fri, 17 May 2019 11:17:23 +0800 Subject: [PATCH 11/92] [FLINK-12327][python] Adds support to submit Python Table API job in CliFrontend This closes #8472 --- docs/ops/cli.md | 66 ++++++ docs/ops/cli.zh.md | 63 ++++- .../apache/flink/client/cli/CliFrontend.java | 56 +++-- .../flink/client/cli/CliFrontendParser.java | 29 +++ .../flink/client/cli/ProgramOptions.java | 74 +++++- .../flink/client/program/PackagedProgram.java | 32 ++- .../flink/client/python/PythonDriver.java | 168 +++++++++++++ .../client/python/PythonGatewayServer.java | 21 +- .../flink/client/python/PythonUtil.java | 223 ++++++++++++++++++ .../flink/client/python/PythonDriverTest.java | 104 ++++++++ .../flink/client/python/PythonUtilTest.java | 118 +++++++++ flink-dist/pom.xml | 3 + flink-dist/src/main/assemblies/bin.xml | 7 + flink-python/pyflink/find_flink_home.py | 3 + flink-python/pyflink/java_gateway.py | 14 +- .../pyflink/table/examples/batch/__init__.py | 17 ++ .../table/examples/batch/word_count.py | 79 +++++++ 17 files changed, 1032 insertions(+), 45 deletions(-) create mode 100644 flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java create mode 100644 flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java create mode 100644 flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java create mode 100644 flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java create mode 100644 flink-python/pyflink/table/examples/batch/__init__.py create mode 100644 flink-python/pyflink/table/examples/batch/word_count.py diff --git a/docs/ops/cli.md b/docs/ops/cli.md index b414b30dd91914..505207dd8b2187 100644 --- a/docs/ops/cli.md +++ b/docs/ops/cli.md @@ -47,6 +47,12 @@ available. {:toc} ## Examples +### Job Submission Examples +----------------------------- + +These examples about how to submit a job in CLI. +
+
- Run example program with no arguments: @@ -88,6 +94,53 @@ available. ./examples/batch/WordCount.jar \ --input hdfs:///user/hamlet.txt --output hdfs:///user/wordcount_out +
+ +
+ +- Run Python Table program: + + ./bin/flink run -py examples/python/table/batch/word_count.py -j + +- Run Python Table program with pyFiles: + + ./bin/flink run -py examples/python/table/batch/word_count.py -j \ + -pyfs file:///user.txt,hdfs:///$namenode_address/username.txt + +- Run Python Table program with pyFiles and pyModule: + + ./bin/flink run -pym batch.word_count -pyfs examples/python/table/batch -j + +- Run Python Table program with parallelism 16: + + ./bin/flink run -p 16 -py examples/python/table/batch/word_count.py -j + +- Run Python Table program with flink log output disabled: + + ./bin/flink run -q -py examples/python/table/batch/word_count.py -j + +- Run Python Table program in detached mode: + + ./bin/flink run -d examples/python/table/batch/word_count.py -j + +- Run Python Table program on a specific JobManager: + + ./bin/flink run -m myJMHost:8081 \ + -py examples/python/table/batch/word_count.py \ + -j + +- Run Python Table program using a [per-job YARN cluster]({{site.baseurl}}/ops/deployment/yarn_setup.html#run-a-single-flink-job-on-hadoop-yarn) with 2 TaskManagers: + + ./bin/flink run -m yarn-cluster -yn 2 \ + -py examples/python/table/batch/word_count.py \ + -j +
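For readers who have not seen a Python Table API script before, the sketch below shows the general shape of a program that the `-py` commands above could submit. It is only an illustration, not the `word_count.py` bundled by this patch (its source is not shown here); the API names used (`ExecutionEnvironment`, `BatchTableEnvironment`, the `FileSystem`/`OldCsv`/`Schema` descriptors, `t_env.execute`) and the file paths are assumptions based on the Python Table API of this Flink version and may differ from the shipped example.

```python
# Illustrative sketch only, not the bundled examples/python/table/batch/word_count.py.
# Assumes the batch Python Table API of this Flink version; paths are placeholders.
from pyflink.dataset import ExecutionEnvironment
from pyflink.table import BatchTableEnvironment, TableConfig, DataTypes
from pyflink.table.descriptors import FileSystem, OldCsv, Schema

exec_env = ExecutionEnvironment.get_execution_environment()
exec_env.set_parallelism(1)
t_env = BatchTableEnvironment.create(exec_env, TableConfig())

# Source: a CSV file with one word per line.
t_env.connect(FileSystem().path('/tmp/word_count_input')) \
    .with_format(OldCsv().field('word', DataTypes.STRING())) \
    .with_schema(Schema().field('word', DataTypes.STRING())) \
    .register_table_source('source')

# Sink: a CSV file receiving (word, cnt) pairs.
t_env.connect(FileSystem().path('/tmp/word_count_output')) \
    .with_format(OldCsv()
                 .field('word', DataTypes.STRING())
                 .field('cnt', DataTypes.BIGINT())) \
    .with_schema(Schema()
                 .field('word', DataTypes.STRING())
                 .field('cnt', DataTypes.BIGINT())) \
    .register_table_sink('sink')

# Count how often each word occurs and write the result to the sink.
t_env.scan('source') \
    .group_by('word') \
    .select('word, count(1) as cnt') \
    .insert_into('sink')

t_env.execute('word_count')
```

Assuming such a script is saved locally, it is submitted with the same `./bin/flink run -py <script> -j <flink-python jar>` pattern shown above; `-pyfs` and `-pym` only change how the entry point and its dependencies are located.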
+ +### Job Management Examples +----------------------------- + +These examples about how to manage a job in CLI. + - Display the optimized execution plan for the WordCount example program as JSON: ./bin/flink info ./examples/batch/WordCount.jar \ @@ -251,6 +304,19 @@ Action "run" compiles and runs a program. program. Optional flag to override the default value specified in the configuration. + -py,--python Python script with the program entry + point.The dependent resources can be + configured with the `--pyFiles` option. + -pyfs,--pyFiles Attach custom python files for job. + Comma can be used as the separator to + specify multiple files. The standard + python resource file suffixes such as + .py/.egg/.zip are all supported. + (eg:--pyFiles file:///tmp/myresource.zip + ,hdfs:///$namenode_address/myresource2.zip) + -pym,--pyModule Python module with the program entry + point. This option must be used in + conjunction with ` --pyFiles`. -q,--sysoutLogging If present, suppress logging output to standard out. -s,--fromSavepoint Path to a savepoint to restore the job diff --git a/docs/ops/cli.zh.md b/docs/ops/cli.zh.md index 7c020479756a81..93f16fb62fd7d2 100644 --- a/docs/ops/cli.zh.md +++ b/docs/ops/cli.zh.md @@ -47,6 +47,12 @@ available. {:toc} ## Examples +### 作业提交示例 +----------------------------- + +这些示例是关于如何通过脚本提交一个作业 +
+
- Run example program with no arguments: @@ -82,11 +88,57 @@ available. ./examples/batch/WordCount.jar \ --input file:///home/user/hamlet.txt --output file:///home/user/wordcount_out -- Run example program using a [per-job YARN cluster]({{site.baseurl}}/ops/deployment/yarn_setup.html#run-a-single-flink-job-on-hadoop-yarn) with 2 TaskManagers: +- Run example program using a [per-job YARN cluster]({{site.baseurl}}/zh/ops/deployment/yarn_setup.html#run-a-single-flink-job-on-hadoop-yarn) with 2 TaskManagers: ./bin/flink run -m yarn-cluster -yn 2 \ ./examples/batch/WordCount.jar \ --input hdfs:///user/hamlet.txt --output hdfs:///user/wordcount_out + +
+ +
+ +- 提交一个Python Table的作业: + + ./bin/flink run -py WordCount.py -j + +- 提交一个有多个依赖的Python Table的作业: + + ./bin/flink run -py examples/python/table/batch/word_count.py -j \ + -pyfs file:///user.txt,hdfs:///$namenode_address/username.txt + +- 提交一个有多个依赖的Python Table的作业,Python作业的主入口通过pym选项指定: + + ./bin/flink run -pym batch.word_count -pyfs examples/python/table/batch -j + +- 提交一个指定并发度为16的Python Table的作业: + + ./bin/flink run -p 16 -py examples/python/table/batch/word_count.py -j + +- 提交一个关闭flink日志输出的Python Table的作业: + + ./bin/flink run -q -py examples/python/table/batch/word_count.py -j + +- 提交一个运行在detached模式下的Python Table的作业: + + ./bin/flink run -d examples/python/table/batch/word_count.py -j + +- 提交一个运行在指定JobManager上的Python Table的作业: + + ./bin/flink run -m myJMHost:8081 \ + -py examples/python/table/batch/word_count.py \ + -j + +- 提交一个运行在有两个TaskManager的[per-job YARN cluster]({{site.baseurl}}/ops/deployment/yarn_setup.html#run-a-single-flink-job-on-hadoop-yarn)的Python Table的作业: + + ./bin/flink run -m yarn-cluster -yn 2 \ + -py examples/python/table/batch/word_count.py \ + -j + +
+ +### 作业管理示例 +----------------------------- - Display the optimized execution plan for the WordCount example program as JSON: @@ -251,6 +303,15 @@ Action "run" compiles and runs a program. program. Optional flag to override the default value specified in the configuration. + -py,--python 指定Python作业的入口,依赖的资源文件可以通过 + `--pyFiles`进行指定。 + -pyfs,--pyFiles 指定Python作业依赖的一些自定义的python文件, + 如果有多个文件,可以通过逗号(,)进行分隔。支持 + 常用的python资源文件,例如(.py/.egg/.zip)。 + (例如:--pyFiles file:///tmp/myresource.zip + ,hdfs:///$namenode_address/myresource2.zip) + -pym,--pyModule 指定python程序的运行的模块入口,这个选项必须配合 + `--pyFiles`一起使用。 -q,--sysoutLogging If present, suppress logging output to standard out. -s,--fromSavepoint Path to a savepoint to restore the job diff --git a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java index c6b5c9abe4a5a1..c591e6e5d5f202 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java +++ b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java @@ -32,6 +32,7 @@ import org.apache.flink.client.program.ProgramInvocationException; import org.apache.flink.client.program.ProgramMissingJobException; import org.apache.flink.client.program.ProgramParametrizationException; +import org.apache.flink.client.python.PythonDriver; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.CoreOptions; @@ -185,8 +186,11 @@ protected void run(String[] args) throws Exception { return; } - if (runOptions.getJarFilePath() == null) { - throw new CliArgsException("The program JAR file was not specified."); + if (!runOptions.isPython()) { + // Java program should be specified a JAR file + if (runOptions.getJarFilePath() == null) { + throw new CliArgsException("Java program should be specified a JAR file."); + } } final PackagedProgram program; @@ -771,12 +775,42 @@ PackagedProgram buildProgram(ProgramOptions options) throws FileNotFoundExceptio String jarFilePath = options.getJarFilePath(); List classpaths = options.getClasspaths(); - if (jarFilePath == null) { - throw new IllegalArgumentException("The program JAR file was not specified."); + String entryPointClass; + File jarFile = null; + if (options.isPython()) { + // If the job is specified a jar file + if (jarFilePath != null) { + jarFile = getJarFile(jarFilePath); + } + // The entry point class of python job is PythonDriver + entryPointClass = PythonDriver.class.getCanonicalName(); + } else { + if (jarFilePath == null) { + throw new IllegalArgumentException("The program JAR file was not specified."); + } + jarFile = getJarFile(jarFilePath); + // Get assembler class + entryPointClass = options.getEntryPointClassName(); } - File jarFile = new File(jarFilePath); + PackagedProgram program = entryPointClass == null ? + new PackagedProgram(jarFile, classpaths, programArgs) : + new PackagedProgram(jarFile, classpaths, entryPointClass, programArgs); + + program.setSavepointRestoreSettings(options.getSavepointRestoreSettings()); + + return program; + } + /** + * Gets the JAR file from the path. + * + * @param jarFilePath The path of JAR file + * @return The JAR file + * @throws FileNotFoundException The JAR file does not exist. 
+ */ + private File getJarFile(String jarFilePath) throws FileNotFoundException { + File jarFile = new File(jarFilePath); // Check if JAR file exists if (!jarFile.exists()) { throw new FileNotFoundException("JAR file does not exist: " + jarFile); @@ -784,17 +818,7 @@ PackagedProgram buildProgram(ProgramOptions options) throws FileNotFoundExceptio else if (!jarFile.isFile()) { throw new FileNotFoundException("JAR file is not a file: " + jarFile); } - - // Get assembler class - String entryPointClass = options.getEntryPointClassName(); - - PackagedProgram program = entryPointClass == null ? - new PackagedProgram(jarFile, classpaths, programArgs) : - new PackagedProgram(jarFile, classpaths, entryPointClass, programArgs); - - program.setSavepointRestoreSettings(options.getSavepointRestoreSettings()); - - return program; + return jarFile; } // -------------------------------------------------------------------------------------------- diff --git a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontendParser.java b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontendParser.java index cea399808f583b..5872a54bd0199a 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontendParser.java +++ b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontendParser.java @@ -118,6 +118,20 @@ public class CliFrontendParser { public static final Option STOP_AND_DRAIN = new Option("d", "drain", false, "Send MAX_WATERMARK before taking the savepoint and stopping the pipelne."); + static final Option PY_OPTION = new Option("py", "python", true, + "Python script with the program entry point. " + + "The dependent resources can be configured with the `--pyFiles` option."); + + static final Option PYFILES_OPTION = new Option("pyfs", "pyFiles", true, + "Attach custom python files for job. " + + "Comma can be used as the separator to specify multiple files. " + + "The standard python resource file suffixes such as .py/.egg/.zip are all supported." + + "(eg: --pyFiles file:///tmp/myresource.zip,hdfs:///$namenode_address/myresource2.zip)"); + + static final Option PYMODULE_OPTION = new Option("pym", "pyModule", true, + "Python module with the program entry point. 
" + + "This option must be used in conjunction with `--pyFiles`."); + static { HELP_OPTION.setRequired(false); @@ -165,6 +179,15 @@ public class CliFrontendParser { STOP_WITH_SAVEPOINT.setOptionalArg(true); STOP_AND_DRAIN.setRequired(false); + + PY_OPTION.setRequired(false); + PY_OPTION.setArgName("python"); + + PYFILES_OPTION.setRequired(false); + PYFILES_OPTION.setArgName("pyFiles"); + + PYMODULE_OPTION.setRequired(false); + PYMODULE_OPTION.setArgName("pyModule"); } private static final Options RUN_OPTIONS = getRunCommandOptions(); @@ -186,6 +209,9 @@ private static Options getProgramSpecificOptions(Options options) { options.addOption(DETACHED_OPTION); options.addOption(SHUTDOWN_IF_ATTACHED_OPTION); options.addOption(YARN_DETACHED_OPTION); + options.addOption(PY_OPTION); + options.addOption(PYFILES_OPTION); + options.addOption(PYMODULE_OPTION); return options; } @@ -196,6 +222,9 @@ private static Options getProgramSpecificOptionsWithoutDeprecatedOptions(Options options.addOption(LOGGING_OPTION); options.addOption(DETACHED_OPTION); options.addOption(SHUTDOWN_IF_ATTACHED_OPTION); + options.addOption(PY_OPTION); + options.addOption(PYFILES_OPTION); + options.addOption(PYMODULE_OPTION); return options; } diff --git a/flink-clients/src/main/java/org/apache/flink/client/cli/ProgramOptions.java b/flink-clients/src/main/java/org/apache/flink/client/cli/ProgramOptions.java index da03d64048cbc1..30b38675e39bd2 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/cli/ProgramOptions.java +++ b/flink-clients/src/main/java/org/apache/flink/client/cli/ProgramOptions.java @@ -36,6 +36,9 @@ import static org.apache.flink.client.cli.CliFrontendParser.JAR_OPTION; import static org.apache.flink.client.cli.CliFrontendParser.LOGGING_OPTION; import static org.apache.flink.client.cli.CliFrontendParser.PARALLELISM_OPTION; +import static org.apache.flink.client.cli.CliFrontendParser.PYFILES_OPTION; +import static org.apache.flink.client.cli.CliFrontendParser.PYMODULE_OPTION; +import static org.apache.flink.client.cli.CliFrontendParser.PY_OPTION; import static org.apache.flink.client.cli.CliFrontendParser.SHUTDOWN_IF_ATTACHED_OPTION; import static org.apache.flink.client.cli.CliFrontendParser.YARN_DETACHED_OPTION; @@ -62,17 +65,71 @@ public abstract class ProgramOptions extends CommandLineOptions { private final SavepointRestoreSettings savepointSettings; + /** + * Flag indicating whether the job is a Python job. + */ + private final boolean isPython; + protected ProgramOptions(CommandLine line) throws CliArgsException { super(line); String[] args = line.hasOption(ARGS_OPTION.getOpt()) ? - line.getOptionValues(ARGS_OPTION.getOpt()) : - line.getArgs(); + line.getOptionValues(ARGS_OPTION.getOpt()) : + line.getArgs(); + + isPython = line.hasOption(PY_OPTION.getOpt()) | line.hasOption(PYMODULE_OPTION.getOpt()); + // If specified the option -py(--python) + if (line.hasOption(PY_OPTION.getOpt())) { + // Cannot use option -py and -pym simultaneously. + if (line.hasOption(PYMODULE_OPTION.getOpt())) { + throw new CliArgsException("Cannot use option -py and -pym simultaneously."); + } + // The cli cmd args which will be transferred to PythonDriver will be transformed as follows: + // CLI cmd : -py ${python.py} pyfs [optional] ${py-files} [optional] ${other args}. + // PythonDriver args: py ${python.py} [optional] pyfs [optional] ${py-files} [optional] ${other args}. + // -------------------------------transformed------------------------------------------------------- + // e.g. 
-py wordcount.py(CLI cmd) -----------> py wordcount.py(PythonDriver args) + // e.g. -py wordcount.py -pyfs file:///AAA.py,hdfs:///BBB.py --input in.txt --output out.txt(CLI cmd) + // -----> py wordcount.py pyfs file:///AAA.py,hdfs:///BBB.py --input in.txt --output out.txt(PythonDriver args) + String[] newArgs; + int argIndex; + if (line.hasOption(PYFILES_OPTION.getOpt())) { + newArgs = new String[args.length + 4]; + newArgs[2] = PYFILES_OPTION.getOpt(); + newArgs[3] = line.getOptionValue(PYFILES_OPTION.getOpt()); + argIndex = 4; + } else { + newArgs = new String[args.length + 2]; + argIndex = 2; + } + newArgs[0] = PY_OPTION.getOpt(); + newArgs[1] = line.getOptionValue(PY_OPTION.getOpt()); + System.arraycopy(args, 0, newArgs, argIndex, args.length); + args = newArgs; + } + + // If specified the option -pym(--pyModule) + if (line.hasOption(PYMODULE_OPTION.getOpt())) { + // If you specify the option -pym, you should specify the option --pyFiles simultaneously. + if (!line.hasOption(PYFILES_OPTION.getOpt())) { + throw new CliArgsException("-pym must be used in conjunction with `--pyFiles`"); + } + // The cli cmd args which will be transferred to PythonDriver will be transformed as follows: + // CLI cmd : -pym ${py-module} -pyfs ${py-files} [optional] ${other args}. + // PythonDriver args: pym ${py-module} pyfs ${py-files} [optional] ${other args}. + // e.g. -pym AAA.fun -pyfs AAA.zip(CLI cmd) ----> pym AAA.fun -pyfs AAA.zip(PythonDriver args) + String[] newArgs = new String[args.length + 4]; + newArgs[0] = PYMODULE_OPTION.getOpt(); + newArgs[1] = line.getOptionValue(PYMODULE_OPTION.getOpt()); + newArgs[2] = PYFILES_OPTION.getOpt(); + newArgs[3] = line.getOptionValue(PYFILES_OPTION.getOpt()); + System.arraycopy(args, 0, newArgs, 4, args.length); + args = newArgs; + } if (line.hasOption(JAR_OPTION.getOpt())) { this.jarFilePath = line.getOptionValue(JAR_OPTION.getOpt()); - } - else if (args.length > 0) { + } else if (!isPython && args.length > 0) { jarFilePath = args[0]; args = Arrays.copyOfRange(args, 1, args.length); } @@ -95,7 +152,7 @@ else if (args.length > 0) { this.classpaths = classpaths; this.entryPointClass = line.hasOption(CLASS_OPTION.getOpt()) ? - line.getOptionValue(CLASS_OPTION.getOpt()) : null; + line.getOptionValue(CLASS_OPTION.getOpt()) : null; if (line.hasOption(PARALLELISM_OPTION.getOpt())) { String parString = line.getOptionValue(PARALLELISM_OPTION.getOpt()); @@ -156,4 +213,11 @@ public boolean isShutdownOnAttachedExit() { public SavepointRestoreSettings getSavepointRestoreSettings() { return savepointSettings; } + + /** + * Indicates whether the job is a Python job. 
+ */ + public boolean isPython() { + return isPython; + } } diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java b/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java index 8f5ccba993c144..77b5d295159d64 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.Plan; import org.apache.flink.api.common.Program; import org.apache.flink.api.common.ProgramDescription; +import org.apache.flink.client.python.PythonDriver; import org.apache.flink.optimizer.Optimizer; import org.apache.flink.optimizer.dag.DataSinkNode; import org.apache.flink.optimizer.plandump.PlanJSONDumpGenerator; @@ -90,6 +91,11 @@ public class PackagedProgram { private SavepointRestoreSettings savepointSettings = SavepointRestoreSettings.none(); + /** + * Flag indicating whether the job is a Python job. + */ + private final boolean isPython; + /** * Creates an instance that wraps the plan defined in the jar file using the given * argument. @@ -169,18 +175,21 @@ public PackagedProgram(File jarFile, @Nullable String entryPointClassName, Strin * may be a missing / wrong class or manifest files. */ public PackagedProgram(File jarFile, List classpaths, @Nullable String entryPointClassName, String... args) throws ProgramInvocationException { - if (jarFile == null) { - throw new IllegalArgumentException("The jar file must not be null."); - } + // Whether the job is a Python job. + isPython = entryPointClassName != null && entryPointClassName.equals(PythonDriver.class.getCanonicalName()); - URL jarFileUrl; - try { - jarFileUrl = jarFile.getAbsoluteFile().toURI().toURL(); - } catch (MalformedURLException e1) { - throw new IllegalArgumentException("The jar file path is invalid."); - } + URL jarFileUrl = null; + if (jarFile != null) { + try { + jarFileUrl = jarFile.getAbsoluteFile().toURI().toURL(); + } catch (MalformedURLException e1) { + throw new IllegalArgumentException("The jar file path is invalid."); + } - checkJarFile(jarFileUrl); + checkJarFile(jarFileUrl); + } else if (!isPython) { + throw new IllegalArgumentException("The jar file must not be null."); + } this.jarFile = jarFileUrl; this.args = args == null ? new String[0] : args; @@ -191,7 +200,7 @@ public PackagedProgram(File jarFile, List classpaths, @Nullable String entr } // now that we have an entry point, we can extract the nested jar files (if any) - this.extractedTempLibraries = extractContainedLibraries(jarFileUrl); + this.extractedTempLibraries = jarFileUrl == null ? Collections.emptyList() : extractContainedLibraries(jarFileUrl); this.classpaths = classpaths; this.userCodeClassLoader = JobWithJars.buildUserCodeClassLoader(getAllLibraries(), classpaths, getClass().getClassLoader()); @@ -233,6 +242,7 @@ public PackagedProgram(Class entryPointClass, String... 
args) throws ProgramI // load the entry point class this.mainClass = entryPointClass; + isPython = entryPointClass == PythonDriver.class; // if the entry point is a program, instantiate the class and get the plan if (Program.class.isAssignableFrom(this.mainClass)) { diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java b/flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java new file mode 100644 index 00000000000000..e43a24eec98ea0 --- /dev/null +++ b/flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.client.python; + +import org.apache.flink.core.fs.Path; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import py4j.GatewayServer; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A main class used to launch Python applications. It executes python as a + * subprocess and then has it connect back to the JVM to access system properties, etc. + */ +public class PythonDriver { + private static final Logger LOG = LoggerFactory.getLogger(PythonDriver.class); + + public static void main(String[] args) { + // the python job needs at least 2 args. + // e.g. py a.py ... + // e.g. pym a.b -pyfs a.zip ... + if (args.length < 2) { + LOG.error("Required at least two arguments, only python file or python module is available."); + System.exit(1); + } + // parse args + Map> parsedArgs = parseOptions(args); + // start gateway server + GatewayServer gatewayServer = startGatewayServer(); + // prepare python env + + // map filename to its Path + Map filePathMap = new HashMap<>(); + // commands which will be exec in python progress. + List commands = constructPythonCommands(filePathMap, parsedArgs); + try { + // prepare the exec environment of python progress. + PythonUtil.PythonEnvironment pythonEnv = PythonUtil.preparePythonEnvironment(filePathMap); + // set env variable PYFLINK_GATEWAY_PORT for connecting of python gateway in python progress. + pythonEnv.systemEnv.put("PYFLINK_GATEWAY_PORT", String.valueOf(gatewayServer.getListeningPort())); + // start the python process. + Process pythonProcess = PythonUtil.startPythonProcess(pythonEnv, commands); + int exitCode = pythonProcess.waitFor(); + if (exitCode != 0) { + throw new RuntimeException("Python process exits with code: " + exitCode); + } + } catch (Throwable e) { + LOG.error("Run python process failed", e); + } finally { + gatewayServer.shutdown(); + } + } + + /** + * Creates a GatewayServer run in a daemon thread. 
+ * + * @return The created GatewayServer + */ + public static GatewayServer startGatewayServer() { + InetAddress localhost = InetAddress.getLoopbackAddress(); + GatewayServer gatewayServer = new GatewayServer.GatewayServerBuilder() + .javaPort(0) + .javaAddress(localhost) + .build(); + Thread thread = new Thread(gatewayServer::start); + thread.setName("py4j-gateway"); + thread.setDaemon(true); + thread.start(); + try { + thread.join(); + } catch (InterruptedException e) { + LOG.error("The gateway server thread join failed.", e); + System.exit(1); + } + return gatewayServer; + } + + /** + * Constructs the Python commands which will be executed in python process. + * + * @param filePathMap stores python file name to its path + * @param parsedArgs parsed args + */ + public static List constructPythonCommands(Map filePathMap, Map> parsedArgs) { + List commands = new ArrayList<>(); + if (parsedArgs.containsKey("py")) { + String pythonFile = parsedArgs.get("py").get(0); + Path pythonFilePath = new Path(pythonFile); + filePathMap.put(pythonFilePath.getName(), pythonFilePath); + commands.add(pythonFilePath.getName()); + } + if (parsedArgs.containsKey("pym")) { + String pyModule = parsedArgs.get("pym").get(0); + commands.add("-m"); + commands.add(pyModule); + } + if (parsedArgs.containsKey("pyfs")) { + List pyFiles = parsedArgs.get("pyfs"); + for (String pyFile : pyFiles) { + Path pyFilePath = new Path(pyFile); + filePathMap.put(pyFilePath.getName(), pyFilePath); + } + } + if (parsedArgs.containsKey("args")) { + commands.addAll(parsedArgs.get("args")); + } + return commands; + } + + /** + * Parses the args to the map format. + * + * @param args ["py", "xxx.py", + * "pyfs", "a.py,b.py,c.py", + * "--input", "in.txt"] + * @return {"py"->List("xxx.py"),"pyfs"->List("a.py","b.py","c.py"),"args"->List("--input","in.txt")} + */ + public static Map> parseOptions(String[] args) { + Map> parsedArgs = new HashMap<>(); + int argIndex = 0; + boolean isEntrypointSpecified = false; + // valid args should include python or pyModule field and their value. + if (args[0].equals("py") || args[0].equals("pym")) { + parsedArgs.put(args[0], Collections.singletonList(args[1])); + argIndex = 2; + isEntrypointSpecified = true; + } + if (isEntrypointSpecified && args.length > 2 && args[2].equals("pyfs")) { + List pyFilesList = new ArrayList<>(Arrays.asList(args[3].split(","))); + parsedArgs.put(args[2], pyFilesList); + argIndex = 4; + } + if (!isEntrypointSpecified) { + throw new RuntimeException("The Python entrypoint has not been specified. It can be specified with option -py or -pym"); + } + // if arg include other args, the key "args" will map to other args. + if (args.length > argIndex) { + List otherArgList = new ArrayList<>(args.length - argIndex); + otherArgList.addAll(Arrays.asList(args).subList(argIndex, args.length)); + parsedArgs.put("args", otherArgList); + } + return parsedArgs; + } +} diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java b/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java index 6432a67ccdb3fe..64f2ef1d382819 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java +++ b/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java @@ -25,6 +25,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.net.InetAddress; +import java.nio.file.Files; /** * The Py4j Gateway Server provides RPC service for user's python process. 
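The second hunk below replaces the createNewFile-based handshake write with a write-to-temporary-file-then-rename scheme, so the reader can never observe a half-written connection file. The consumer of that file (the Python side of the gateway, which is told the port of the Java RPC server through it) only has to read back the single int written with writeInt. A minimal, hypothetical Java reader is sketched here purely to illustrate the file format; it is not part of the patch, and the path argument is a placeholder.

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

/**
 * Hypothetical reader for a handshake file produced with DataOutputStream#writeInt(boundPort).
 * Illustration of the file format only; not part of the patch.
 */
public class HandshakeFileReaderSketch {

    public static int readPort(String handshakeFilePath) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(handshakeFilePath))) {
            // A single 4-byte big-endian int: the bound port of the Java gateway server.
            return in.readInt();
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println("gateway port: " + readPort(args[0]));
    }
}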
@@ -56,15 +57,17 @@ public static void main(String[] args) throws IOException { // Tells python side the port of our java rpc server String handshakeFilePath = System.getenv("_PYFLINK_CONN_INFO_PATH"); File handshakeFile = new File(handshakeFilePath); - if (handshakeFile.createNewFile()) { - FileOutputStream fileOutputStream = new FileOutputStream(handshakeFile); - DataOutputStream stream = new DataOutputStream(fileOutputStream); - stream.writeInt(boundPort); - stream.close(); - fileOutputStream.close(); - } else { - System.out.println("Can't create handshake file: " + handshakeFilePath + ", now exit..."); - return; + File tmpPath = Files.createTempFile(handshakeFile.getParentFile().toPath(), + "connection", ".info").toFile(); + FileOutputStream fileOutputStream = new FileOutputStream(tmpPath); + DataOutputStream stream = new DataOutputStream(fileOutputStream); + stream.writeInt(boundPort); + stream.close(); + fileOutputStream.close(); + + if (!tmpPath.renameTo(handshakeFile)) { + System.out.println("Unable to write connection information to handshake file: " + handshakeFilePath + ", now exit..."); + System.exit(1); } // Exit on EOF or broken pipe. This ensures that the server dies diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java b/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java new file mode 100644 index 00000000000000..b9012a38fa4bf3 --- /dev/null +++ b/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.client.python; + +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.util.FileUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.file.FileSystems; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * The util class help to prepare Python env and run the python process. + */ +public final class PythonUtil { + private static final Logger LOG = LoggerFactory.getLogger(PythonUtil.class); + + private static final String FLINK_OPT_DIR = System.getenv("FLINK_OPT_DIR"); + + private static final String FLINK_OPT_DIR_PYTHON = FLINK_OPT_DIR + File.separator + "python"; + + /** + * Wraps Python exec environment. 
+ */ + public static class PythonEnvironment { + public String workingDirectory; + + public String pythonExec = "python"; + + public String pythonPath; + + Map systemEnv = new HashMap<>(); + } + + /** + * The hook thread that delete the tmp working dir of python process after the python process shutdown. + */ + private static class ShutDownPythonHook extends Thread { + private Process p; + private String pyFileDir; + + public ShutDownPythonHook(Process p, String pyFileDir) { + this.p = p; + this.pyFileDir = pyFileDir; + } + + public void run() { + + p.destroyForcibly(); + + if (pyFileDir != null) { + File pyDir = new File(pyFileDir); + FileUtils.deleteDirectoryQuietly(pyDir); + } + } + } + + + /** + * Prepares PythonEnvironment to start python process. + * + * @param filePathMap map file name to its file path. + * @return PythonEnvironment the Python environment which will be executed in Python process. + */ + public static PythonEnvironment preparePythonEnvironment(Map filePathMap) { + PythonEnvironment env = new PythonEnvironment(); + + // 1. setup temporary local directory for the user files + String tmpDir = System.getProperty("java.io.tmpdir") + + File.separator + "pyflink" + UUID.randomUUID(); + + Path tmpDirPath = new Path(tmpDir); + try { + FileSystem fs = tmpDirPath.getFileSystem(); + if (fs.exists(tmpDirPath)) { + fs.delete(tmpDirPath, true); + } + fs.mkdirs(tmpDirPath); + } catch (IOException e) { + LOG.error("Prepare tmp directory failed.", e); + } + + env.workingDirectory = tmpDirPath.toString(); + + StringBuilder pythonPathEnv = new StringBuilder(); + + pythonPathEnv.append(env.workingDirectory); + + // 2. create symbolLink in the working directory for the pyflink dependency libs. + List pythonLibs = getLibFiles(FLINK_OPT_DIR_PYTHON); + for (java.nio.file.Path libPath : pythonLibs) { + java.nio.file.Path symbolicLinkFilePath = FileSystems.getDefault().getPath(env.workingDirectory, + libPath.getFileName().toString()); + createSymbolicLinkForPyflinkLib(libPath, symbolicLinkFilePath); + pythonPathEnv.append(File.pathSeparator); + pythonPathEnv.append(symbolicLinkFilePath.toString()); + } + + // 3. copy relevant python files to tmp dir and set them in PYTHONPATH. + filePathMap.forEach((sourceFileName, sourcePath) -> { + Path targetPath = new Path(tmpDirPath, sourceFileName); + try { + FileUtils.copy(sourcePath, targetPath, true); + } catch (IOException e) { + LOG.error("Copy files to tmp dir failed", e); + } + String targetFileName = targetPath.toString(); + pythonPathEnv.append(File.pathSeparator); + pythonPathEnv.append(targetFileName); + + }); + + env.pythonPath = pythonPathEnv.toString(); + return env; + } + + /** + * Gets pyflink dependent libs in specified directory. + * + * @param libDir The lib directory + */ + public static List getLibFiles(String libDir) { + final List libFiles = new ArrayList<>(); + SimpleFileVisitor finder = new SimpleFileVisitor() { + @Override + public FileVisitResult visitFile(java.nio.file.Path file, BasicFileAttributes attrs) throws IOException { + // exclude .txt file + if (!file.toString().endsWith(".txt")) { + libFiles.add(file); + } + return FileVisitResult.CONTINUE; + } + }; + try { + Files.walkFileTree(FileSystems.getDefault().getPath(libDir), finder); + } catch (IOException e) { + LOG.error("Gets pyflink dependent libs failed.", e); + } + return libFiles; + } + + /** + * Creates symbolLink in working directory for pyflink lib. + * + * @param libPath the pyflink lib file path. + * @param symbolicLinkPath the symbolic link to pyflink lib. 
+ */ + public static void createSymbolicLinkForPyflinkLib(java.nio.file.Path libPath, java.nio.file.Path symbolicLinkPath) { + try { + Files.createSymbolicLink(symbolicLinkPath, libPath); + } catch (IOException e) { + LOG.error("Create symbol link for pyflink lib failed.", e); + LOG.info("Try to copy pyflink lib to working directory"); + try { + Files.copy(libPath, symbolicLinkPath); + } catch (IOException ex) { + LOG.error("Copy pylink lib to working directory failed", ex); + } + } + } + + /** + * Starts python process. + * + * @param pythonEnv the python Environment which will be in a process. + * @param commands the commands that python process will execute. + * @return the process represent the python process. + * @throws IOException Thrown if an error occurred when python process start. + */ + public static Process startPythonProcess(PythonEnvironment pythonEnv, List commands) throws IOException { + ProcessBuilder pythonProcessBuilder = new ProcessBuilder(); + Map env = pythonProcessBuilder.environment(); + env.put("PYTHONPATH", pythonEnv.pythonPath); + pythonEnv.systemEnv.forEach(env::put); + commands.add(0, pythonEnv.pythonExec); + pythonProcessBuilder.command(commands); + // set the working directory. + pythonProcessBuilder.directory(new File(pythonEnv.workingDirectory)); + // redirect the stderr to stdout + pythonProcessBuilder.redirectErrorStream(true); + // set the child process the output same as the parent process. + pythonProcessBuilder.redirectOutput(ProcessBuilder.Redirect.INHERIT); + Process process = pythonProcessBuilder.start(); + if (!process.isAlive()) { + throw new RuntimeException("Failed to start Python process. "); + } + + // Make sure that the python sub process will be killed when JVM exit + ShutDownPythonHook hook = new ShutDownPythonHook(process, pythonEnv.workingDirectory); + Runtime.getRuntime().addShutdownHook(hook); + + return process; + } +} diff --git a/flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java b/flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java new file mode 100644 index 00000000000000..0b6f570e75a447 --- /dev/null +++ b/flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.client.python; + +import org.apache.flink.core.fs.Path; + +import org.junit.Assert; +import org.junit.Test; +import py4j.GatewayServer; + +import java.io.IOException; +import java.net.Socket; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests for the {@link PythonDriver}. 
+ */ +public class PythonDriverTest { + @Test + public void testStartGatewayServer() { + GatewayServer gatewayServer = PythonDriver.startGatewayServer(); + try { + Socket socket = new Socket("localhost", gatewayServer.getListeningPort()); + assert socket.isConnected(); + } catch (IOException e) { + throw new RuntimeException("Connect Gateway Server failed"); + } finally { + gatewayServer.shutdown(); + } + } + + @Test + public void testConstructCommands() { + Map filePathMap = new HashMap<>(); + Map> parseArgs = new HashMap<>(); + parseArgs.put("py", Collections.singletonList("xxx.py")); + List pyFilesList = new ArrayList<>(); + pyFilesList.add("a.py"); + pyFilesList.add("b.py"); + pyFilesList.add("c.py"); + parseArgs.put("pyfs", pyFilesList); + List otherArgs = new ArrayList<>(); + otherArgs.add("--input"); + otherArgs.add("in.txt"); + parseArgs.put("args", otherArgs); + List commands = PythonDriver.constructPythonCommands(filePathMap, parseArgs); + Path pythonPath = filePathMap.get("xxx.py"); + Assert.assertNotNull(pythonPath); + Assert.assertEquals(pythonPath.getName(), "xxx.py"); + Path aPyFilePath = filePathMap.get("a.py"); + Assert.assertNotNull(aPyFilePath); + Assert.assertEquals(aPyFilePath.getName(), "a.py"); + Path bPyFilePath = filePathMap.get("b.py"); + Assert.assertNotNull(bPyFilePath); + Assert.assertEquals(bPyFilePath.getName(), "b.py"); + Path cPyFilePath = filePathMap.get("c.py"); + Assert.assertNotNull(cPyFilePath); + Assert.assertEquals(cPyFilePath.getName(), "c.py"); + Assert.assertEquals(3, commands.size()); + Assert.assertEquals(commands.get(0), "xxx.py"); + Assert.assertEquals(commands.get(1), "--input"); + Assert.assertEquals(commands.get(2), "in.txt"); + } + + @Test + public void testParseOptions() { + String[] args = {"py", "xxx.py", "pyfs", "a.py,b.py,c.py", "--input", "in.txt"}; + Map> parsedArgs = PythonDriver.parseOptions(args); + List pythonMainFile = parsedArgs.get("py"); + Assert.assertNotNull(pythonMainFile); + Assert.assertEquals(1, pythonMainFile.size()); + Assert.assertEquals(pythonMainFile.get(0), args[1]); + List pyFilesList = parsedArgs.get("pyfs"); + Assert.assertEquals(3, pyFilesList.size()); + String[] pyFiles = args[3].split(","); + for (int i = 0; i < pyFiles.length; i++) { + assert pyFilesList.get(i).equals(pyFiles[i]); + } + List otherArgs = parsedArgs.get("args"); + for (int i = 4; i < args.length; i++) { + Assert.assertEquals(otherArgs.get(i - 4), args[i]); + } + } +} diff --git a/flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java b/flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java new file mode 100644 index 00000000000000..4b14cede4e310c --- /dev/null +++ b/flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.client.python; + +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +/** + * Tests for the {@link PythonUtil}. + */ +public class PythonUtilTest { + private Path sourceTmpDirPath; + private Path targetTmpDirPath; + private FileSystem sourceFs; + private FileSystem targetFs; + + @Before + public void prepareTestEnvironment() { + String sourceTmpDir = System.getProperty("java.io.tmpdir") + + File.separator + "source_" + UUID.randomUUID(); + String targetTmpDir = System.getProperty("java.io.tmpdir") + + File.separator + "target_" + UUID.randomUUID(); + + sourceTmpDirPath = new Path(sourceTmpDir); + targetTmpDirPath = new Path(targetTmpDir); + try { + sourceFs = sourceTmpDirPath.getFileSystem(); + if (sourceFs.exists(sourceTmpDirPath)) { + sourceFs.delete(sourceTmpDirPath, true); + } + sourceFs.mkdirs(sourceTmpDirPath); + targetFs = targetTmpDirPath.getFileSystem(); + if (targetFs.exists(targetTmpDirPath)) { + targetFs.delete(targetTmpDirPath, true); + } + targetFs.mkdirs(targetTmpDirPath); + } catch (IOException e) { + throw new RuntimeException("initial PythonUtil test environment failed"); + } + } + + @Test + public void testStartPythonProcess() { + PythonUtil.PythonEnvironment pythonEnv = new PythonUtil.PythonEnvironment(); + pythonEnv.workingDirectory = targetTmpDirPath.toString(); + pythonEnv.pythonPath = targetTmpDirPath.toString(); + List commands = new ArrayList<>(); + Path pyPath = new Path(targetTmpDirPath, "word_count.py"); + try { + targetFs.create(pyPath, FileSystem.WriteMode.OVERWRITE); + File pyFile = new File(pyPath.toString()); + String pyProgram = "#!/usr/bin/python\n" + + "# -*- coding: UTF-8 -*-\n" + + "import sys\n" + + "\n" + + "if __name__=='__main__':\n" + + "\tfilename = sys.argv[1]\n" + + "\tfo = open(filename, \"w\")\n" + + "\tfo.write( \"hello world\")\n" + + "\tfo.close()"; + Files.write(pyFile.toPath(), pyProgram.getBytes(), StandardOpenOption.WRITE); + Path result = new Path(targetTmpDirPath, "word_count_result.txt"); + commands.add(pyFile.getName()); + commands.add(result.getName()); + Process pythonProcess = PythonUtil.startPythonProcess(pythonEnv, commands); + int exitCode = pythonProcess.waitFor(); + if (exitCode != 0) { + throw new RuntimeException("Python process exits with code: " + exitCode); + } + String cmdResult = new String(Files.readAllBytes(new File(result.toString()).toPath())); + Assert.assertEquals(cmdResult, "hello world"); + pythonProcess.destroyForcibly(); + targetFs.delete(pyPath, true); + targetFs.delete(result, true); + } catch (IOException | InterruptedException e) { + throw new RuntimeException("test start Python process failed " + e.getMessage()); + } + } + + @After + public void cleanEnvironment() { + try { + sourceFs.delete(sourceTmpDirPath, true); + targetFs.delete(targetTmpDirPath, true); + } catch (IOException e) { + throw new RuntimeException("delete tmp dir failed " + e.getMessage()); + } + } +} diff --git a/flink-dist/pom.xml b/flink-dist/pom.xml index 52a6466e513929..1350f10958894f 100644 --- a/flink-dist/pom.xml +++ b/flink-dist/pom.xml @@ -572,6 +572,9 @@ 
under the License. py4j org.apache.flink.api.python.py4j + + py4j.* + diff --git a/flink-dist/src/main/assemblies/bin.xml b/flink-dist/src/main/assemblies/bin.xml index 3eb5698c19f5fd..788ec1bbdfc00f 100644 --- a/flink-dist/src/main/assemblies/bin.xml +++ b/flink-dist/src/main/assemblies/bin.xml @@ -242,6 +242,13 @@ under the License. 0755 + + + ../flink-python/pyflink/table/examples + examples/python/table + 0755 + + diff --git a/flink-python/pyflink/find_flink_home.py b/flink-python/pyflink/find_flink_home.py index 049136864a3fe7..98064908cbe752 100644 --- a/flink-python/pyflink/find_flink_home.py +++ b/flink-python/pyflink/find_flink_home.py @@ -28,6 +28,9 @@ def _find_flink_home(): # If the environment has set FLINK_HOME, trust it. if 'FLINK_HOME' in os.environ: return os.environ['FLINK_HOME'] + elif 'FLINK_ROOT_DIR' in os.environ: + os.environ['FLINK_HOME'] = os.environ['FLINK_ROOT_DIR'] + return os.environ['FLINK_ROOT_DIR'] else: try: flink_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)) + "/../../") diff --git a/flink-python/pyflink/java_gateway.py b/flink-python/pyflink/java_gateway.py index e5c8330e551db9..b218d231662704 100644 --- a/flink-python/pyflink/java_gateway.py +++ b/flink-python/pyflink/java_gateway.py @@ -28,7 +28,6 @@ from py4j.java_gateway import java_import, JavaGateway, GatewayParameters from pyflink.find_flink_home import _find_flink_home - _gateway = None _lock = RLock() @@ -46,6 +45,9 @@ def get_gateway(): _gateway = JavaGateway(gateway_parameters=gateway_param) else: _gateway = launch_gateway() + + # import the flink view + import_flink_view(_gateway) return _gateway @@ -97,6 +99,14 @@ def preexec_func(): gateway = JavaGateway( gateway_parameters=GatewayParameters(port=gateway_port, auto_convert=True)) + return gateway + + +def import_flink_view(gateway): + """ + import the classes used by PyFlink. + :param gateway:gateway connected to JavaGateWayServer + """ # Import the classes used by PyFlink java_import(gateway.jvm, "org.apache.flink.table.api.*") java_import(gateway.jvm, "org.apache.flink.table.api.java.*") @@ -109,5 +119,3 @@ def preexec_func(): java_import(gateway.jvm, "org.apache.flink.api.java.ExecutionEnvironment") java_import(gateway.jvm, "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment") - - return gateway diff --git a/flink-python/pyflink/table/examples/batch/__init__.py b/flink-python/pyflink/table/examples/batch/__init__.py new file mode 100644 index 00000000000000..65b48d4d79b4e3 --- /dev/null +++ b/flink-python/pyflink/table/examples/batch/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ diff --git a/flink-python/pyflink/table/examples/batch/word_count.py b/flink-python/pyflink/table/examples/batch/word_count.py new file mode 100644 index 00000000000000..a324af4747c1d5 --- /dev/null +++ b/flink-python/pyflink/table/examples/batch/word_count.py @@ -0,0 +1,79 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import os +import tempfile + +from pyflink.table import TableEnvironment, TableConfig +from pyflink.table.table_sink import CsvTableSink +from pyflink.table.table_source import CsvTableSource +from pyflink.table.types import DataTypes + + +# TODO: the word_count.py is just a test example for CLI. +# After pyflink have aligned Java Table API Connectors, this example will be improved. +def word_count(): + tmp_dir = tempfile.gettempdir() + source_path = tmp_dir + '/streaming.csv' + if os.path.isfile(source_path): + os.remove(source_path) + content = "line Licensed to the Apache Software Foundation ASF under one " \ + "line or more contributor license agreements See the NOTICE file " \ + "line distributed with this work for additional information " \ + "line regarding copyright ownership The ASF licenses this file " \ + "to you under the Apache License Version the " \ + "License you may not use this file except in compliance " \ + "with the License" + + with open(source_path, 'w') as f: + for word in content.split(" "): + f.write(",".join([word, "1"])) + f.write("\n") + f.flush() + f.close() + + t_config = TableConfig.Builder().as_batch_execution().set_parallelism(1).build() + t_env = TableEnvironment.create(t_config) + + field_names = ["word", "cout"] + field_types = [DataTypes.STRING, DataTypes.LONG] + + # register Orders table in table environment + t_env.register_table_source( + "Word", + CsvTableSource(source_path, field_names, field_types)) + + # register Results table in table environment + tmp_dir = tempfile.gettempdir() + tmp_csv = tmp_dir + '/streaming2.csv' + if os.path.isfile(tmp_csv): + os.remove(tmp_csv) + + t_env.register_table_sink( + "Results", + field_names, field_types, CsvTableSink(tmp_csv)) + + t_env.scan("Word") \ + .group_by("word") \ + .select("word, count(1) as count") \ + .insert_into("Results") + + t_env.execute() + + +if __name__ == '__main__': + word_count() From 8510e294802c096418cdc121b21d64de630c201f Mon Sep 17 00:00:00 2001 From: sunjincheng121 Date: Tue, 28 May 2019 15:26:14 +0800 Subject: [PATCH 12/92] [hotfix][python]fix command error for python API doc, and function call bug in table environment. 
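Taken together, the client-side changes above follow one pattern: the CLI rewrites -py/-pyfs/-pym into arguments for PythonDriver, the driver opens a py4j gateway and publishes its listening port to the child process through the PYFLINK_GATEWAY_PORT environment variable, and PythonUtil assembles a PYTHONPATH and starts the interpreter with a ProcessBuilder. The standalone sketch below reproduces that launch-and-call-back shape with plain JDK classes only; it is a simplified illustration rather than the Flink code, and the script name, environment variable name and working directory are placeholders.

import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.List;

/**
 * Minimal sketch of "open a local port, export it via an environment variable,
 * start a child Python process and fail if it exits abnormally".
 */
public class ChildPythonLauncherSketch {

    public static void main(String[] args) throws IOException, InterruptedException {
        // 1. Bind an ephemeral loopback port; this stands in for the py4j GatewayServer.
        try (ServerSocket callbackServer = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())) {
            int port = callbackServer.getLocalPort();

            // 2. Build the child command line; "job.py" is a placeholder entry point.
            List<String> command = new ArrayList<>();
            command.add("python");
            command.add("job.py");

            ProcessBuilder builder = new ProcessBuilder(command);
            // Publish the port so the child process knows where to connect back.
            builder.environment().put("CALLBACK_PORT", String.valueOf(port));
            builder.directory(new File(System.getProperty("java.io.tmpdir")));
            // Merge stderr into stdout and inherit the parent's output, as the driver does.
            builder.redirectErrorStream(true);
            builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);

            // 3. Wait for the child and surface a non-zero exit code as a failure.
            Process python = builder.start();
            int exitCode = python.waitFor();
            if (exitCode != 0) {
                throw new RuntimeException("Python process exited with code " + exitCode);
            }
        }
    }
}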
--- docs/ops/cli.md | 2 +- docs/ops/cli.zh.md | 2 +- .../src/main/java/org/apache/flink/client/cli/CliFrontend.java | 2 +- .../main/java/org/apache/flink/client/python/PythonUtil.java | 2 +- flink-python/pyflink/table/table_environment.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ops/cli.md b/docs/ops/cli.md index 505207dd8b2187..ab1fd1621ba610 100644 --- a/docs/ops/cli.md +++ b/docs/ops/cli.md @@ -121,7 +121,7 @@ These examples about how to submit a job in CLI. - Run Python Table program in detached mode: - ./bin/flink run -d examples/python/table/batch/word_count.py -j + ./bin/flink run -d -py examples/python/table/batch/word_count.py -j - Run Python Table program on a specific JobManager: diff --git a/docs/ops/cli.zh.md b/docs/ops/cli.zh.md index 93f16fb62fd7d2..ec6001b3b84d42 100644 --- a/docs/ops/cli.zh.md +++ b/docs/ops/cli.zh.md @@ -121,7 +121,7 @@ available. - 提交一个运行在detached模式下的Python Table的作业: - ./bin/flink run -d examples/python/table/batch/word_count.py -j + ./bin/flink run -d -py examples/python/table/batch/word_count.py -j - 提交一个运行在指定JobManager上的Python Table的作业: diff --git a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java index c591e6e5d5f202..fe641fe78134ad 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java +++ b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java @@ -786,7 +786,7 @@ PackagedProgram buildProgram(ProgramOptions options) throws FileNotFoundExceptio entryPointClass = PythonDriver.class.getCanonicalName(); } else { if (jarFilePath == null) { - throw new IllegalArgumentException("The program JAR file was not specified."); + throw new IllegalArgumentException("Java program should be specified a JAR file."); } jarFile = getJarFile(jarFilePath); // Get assembler class diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java b/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java index b9012a38fa4bf3..9fecd499966731 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java +++ b/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java @@ -96,7 +96,7 @@ public static PythonEnvironment preparePythonEnvironment(Map fileP // 1. 
setup temporary local directory for the user files String tmpDir = System.getProperty("java.io.tmpdir") + - File.separator + "pyflink" + UUID.randomUUID(); + File.separator + "pyflink" + File.separator + UUID.randomUUID(); Path tmpDirPath = new Path(tmpDir); try { diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py index 337a0e4f7241eb..45e268847f7471 100644 --- a/flink-python/pyflink/table/table_environment.py +++ b/flink-python/pyflink/table/table_environment.py @@ -374,7 +374,7 @@ def create(cls, table_config): j_execution_env, table_config._j_table_config) t_env = BatchTableEnvironment(j_tenv) - if table_config.parallelism is not None: + if table_config.parallelism() is not None: t_env._j_tenv.execEnv().setParallelism(table_config.parallelism()) return t_env From 58987dd16c7e8af36e935e811f716d2f843de5ca Mon Sep 17 00:00:00 2001 From: "zhuzhu.zz" Date: Wed, 15 May 2019 18:38:37 +0800 Subject: [PATCH 13/92] [FLINK-12068][runtime] Implement region backtracking for region failover strategy --- .../failover/flip1/FailoverRegion.java | 20 +- .../flip1/RestartPipelinedRegionStrategy.java | 164 ++++++-- .../ResultPartitionAvailabilityChecker.java | 35 ++ .../RestartPipelinedRegionStrategyTest.java | 388 +++++++++++++++++- .../failover/flip1/TestFailoverTopology.java | 12 +- 5 files changed, 569 insertions(+), 50 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ResultPartitionAvailabilityChecker.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverRegion.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverRegion.java index 88a76582c05d2d..9a013df03b1c43 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverRegion.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverRegion.java @@ -20,10 +20,7 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; -import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; -import java.util.Map; import java.util.Set; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -33,18 +30,21 @@ */ public class FailoverRegion { + /** All vertex IDs in this region. */ + private final Set executionVertexIDs; + /** All vertices in this region. */ - private final Map executionVertices; + private final Set executionVertices; /** * Creates a new failover region containing a set of vertices. 
* * @param executionVertices to be contained in this region */ - public FailoverRegion(Collection executionVertices) { - checkNotNull(executionVertices); - this.executionVertices = new HashMap<>(); - executionVertices.forEach(v -> this.executionVertices.put(v.getExecutionVertexID(), v)); + public FailoverRegion(Set executionVertices) { + this.executionVertices = checkNotNull(executionVertices); + this.executionVertexIDs = new HashSet<>(); + executionVertices.forEach(v -> this.executionVertexIDs.add(v.getExecutionVertexID())); } /** @@ -53,7 +53,7 @@ public FailoverRegion(Collection executionVertices) { * @return IDs of all vertices in this region */ public Set getAllExecutionVertexIDs() { - return executionVertices.keySet(); + return executionVertexIDs; } /** @@ -62,6 +62,6 @@ public Set getAllExecutionVertexIDs() { * @return all vertices in this region */ public Set getAllExecutionVertices() { - return new HashSet<>(executionVertices.values()); + return executionVertices; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategy.java index e7040c3312f83b..8e063277bdd8d8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategy.java @@ -18,20 +18,22 @@ package org.apache.flink.runtime.executiongraph.failover.flip1; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.io.network.partition.PartitionException; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.Collections; +import java.util.ArrayDeque; import java.util.HashMap; import java.util.HashSet; import java.util.IdentityHashMap; -import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Queue; import java.util.Set; -import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -47,23 +49,46 @@ public class RestartPipelinedRegionStrategy implements FailoverStrategy { /** The topology containing info about all the vertices and edges. */ private final FailoverTopology topology; + /** All failover regions. */ + private final IdentityHashMap regions; + /** Maps execution vertex id to failover region. */ - private final Map regions; + private final Map vertexToRegionMap; + + /** The checker helps to query result partition availability. */ + private final RegionFailoverResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker; /** * Creates a new failover strategy to restart pipelined regions that works on the given topology. + * The result partitions are always considered to be available if no data consumption error happens. * * @param topology containing info about all the vertices and edges */ + @VisibleForTesting public RestartPipelinedRegionStrategy(FailoverTopology topology) { + this(topology, resultPartitionID -> true); + } + + /** + * Creates a new failover strategy to restart pipelined regions that works on the given topology. 
+ * + * @param topology containing info about all the vertices and edges + * @param resultPartitionAvailabilityChecker helps to query result partition availability + */ + public RestartPipelinedRegionStrategy( + FailoverTopology topology, + ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker) { + this.topology = checkNotNull(topology); - this.regions = new HashMap<>(); + this.regions = new IdentityHashMap<>(); + this.vertexToRegionMap = new HashMap<>(); + this.resultPartitionAvailabilityChecker = new RegionFailoverResultPartitionAvailabilityChecker( + resultPartitionAvailabilityChecker); // build regions based on the given topology LOG.info("Start building failover regions."); buildFailoverRegions(); } - // ------------------------------------------------------------------------ // region building // ------------------------------------------------------------------------ @@ -131,25 +156,28 @@ private void buildFailoverRegions() { for (HashSet regionVertices : distinctRegions.keySet()) { LOG.debug("Creating a failover region with {} vertices.", regionVertices.size()); final FailoverRegion failoverRegion = new FailoverRegion(regionVertices); + regions.put(failoverRegion, null); for (FailoverVertex vertex : regionVertices) { - this.regions.put(vertex.getExecutionVertexID(), failoverRegion); + vertexToRegionMap.put(vertex.getExecutionVertexID(), failoverRegion); } } - LOG.info("Created {} failover regions.", distinctRegions.size()); + + LOG.info("Created {} failover regions.", regions.size()); } private void buildOneRegionForAllVertices() { LOG.warn("Cannot decompose the topology into individual failover regions due to use of " + "Co-Location constraints (iterations). Job will fail over as one holistic unit."); - final List allVertices = new ArrayList<>(); + final Set allVertices = new HashSet<>(); for (FailoverVertex vertex : topology.getFailoverVertices()) { allVertices.add(vertex); } final FailoverRegion region = new FailoverRegion(allVertices); + regions.put(region, null); for (FailoverVertex vertex : topology.getFailoverVertices()) { - regions.put(vertex.getExecutionVertexID(), region); + vertexToRegionMap.put(vertex.getExecutionVertexID(), region); } } @@ -162,9 +190,9 @@ private void buildOneRegionForAllVertices() { * In this strategy, all task vertices in 'involved' regions are proposed to be restarted. * The 'involved' regions are calculated with rules below: * 1. The region containing the failed task is always involved - * 2. TODO: If an input result partition of an involved region is not available, i.e. Missing or Corrupted, + * 2. If an input result partition of an involved region is not available, i.e. Missing or Corrupted, * the region containing the partition producer task is involved - * 3. TODO: If a region is involved, all of its consumer regions are involved + * 3. 
If a region is involved, all of its consumer regions are involved * * @param executionVertexId ID of the failed task * @param cause cause of the failure @@ -172,30 +200,87 @@ private void buildOneRegionForAllVertices() { */ @Override public Set getTasksNeedingRestart(ExecutionVertexID executionVertexId, Throwable cause) { - final FailoverRegion failedRegion = regions.get(executionVertexId); + LOG.info("Calculating tasks to restart to recover the failed task {}.", executionVertexId); + + final FailoverRegion failedRegion = vertexToRegionMap.get(executionVertexId); if (failedRegion == null) { // TODO: show the task name in the log throw new IllegalStateException("Can not find the failover region for task " + executionVertexId, cause); } - // TODO: if the failure cause is data consumption error, mark the corresponding data partition to be unavailable + // if the failure cause is data consumption error, mark the corresponding data partition to be failed, + // so that the failover process will try to recover it + Optional dataConsumptionException = ExceptionUtils.findThrowable( + cause, PartitionException.class); + if (dataConsumptionException.isPresent()) { + resultPartitionAvailabilityChecker.markResultPartitionFailed( + dataConsumptionException.get().getPartitionId().getPartitionId()); + } + + // calculate the tasks to restart based on the result of regions to restart + Set tasksToRestart = new HashSet<>(); + for (FailoverRegion region : getRegionsToRestart(failedRegion)) { + tasksToRestart.addAll(region.getAllExecutionVertexIDs()); + } - return getRegionsToRestart(failedRegion).stream().flatMap( - r -> r.getAllExecutionVertexIDs().stream()).collect(Collectors.toSet()); + // the previous failed partition will be recovered. remove its failed state from the checker + if (dataConsumptionException.isPresent()) { + resultPartitionAvailabilityChecker.removeResultPartitionFromFailedState( + dataConsumptionException.get().getPartitionId().getPartitionId()); + } + + LOG.info("{} tasks should be restarted to recover the failed task {}. ", tasksToRestart.size(), executionVertexId); + return tasksToRestart; } /** * All 'involved' regions are proposed to be restarted. * The 'involved' regions are calculated with rules below: * 1. The region containing the failed task is always involved - * 2. TODO: If an input result partition of an involved region is not available, i.e. Missing or Corrupted, + * 2. If an input result partition of an involved region is not available, i.e. Missing or Corrupted, * the region containing the partition producer task is involved - * 3. TODO: If a region is involved, all of its consumer regions are involved + * 3. 
If a region is involved, all of its consumer regions are involved */ - private Set getRegionsToRestart(FailoverRegion regionToRestart) { - return Collections.singleton(regionToRestart); + private Set getRegionsToRestart(FailoverRegion failedRegion) { + IdentityHashMap regionsToRestart = new IdentityHashMap<>(); + IdentityHashMap visitedRegions = new IdentityHashMap<>(); + + // start from the failed region to visit all involved regions + Queue regionsToVisit = new ArrayDeque<>(); + visitedRegions.put(failedRegion, null); + regionsToVisit.add(failedRegion); + while (!regionsToVisit.isEmpty()) { + FailoverRegion regionToRestart = regionsToVisit.poll(); + + // an involved region should be restarted + regionsToRestart.put(regionToRestart, null); - // TODO: implement backtracking logic + // if a needed input result partition is not available, its producer region is involved + for (FailoverVertex vertex : regionToRestart.getAllExecutionVertices()) { + for (FailoverEdge inEdge : vertex.getInputEdges()) { + if (!resultPartitionAvailabilityChecker.isAvailable(inEdge.getResultPartitionID())) { + FailoverRegion producerRegion = vertexToRegionMap.get(inEdge.getSourceVertex().getExecutionVertexID()); + if (!visitedRegions.containsKey(producerRegion)) { + visitedRegions.put(producerRegion, null); + regionsToVisit.add(producerRegion); + } + } + } + } + + // all consumer regions of an involved region should be involved + for (FailoverVertex vertex : regionToRestart.getAllExecutionVertices()) { + for (FailoverEdge outEdge : vertex.getOutputEdges()) { + FailoverRegion consumerRegion = vertexToRegionMap.get(outEdge.getTargetVertex().getExecutionVertexID()); + if (!visitedRegions.containsKey(consumerRegion)) { + visitedRegions.put(consumerRegion, null); + regionsToVisit.add(consumerRegion); + } + } + } + } + + return regionsToRestart.keySet(); } // ------------------------------------------------------------------------ @@ -208,7 +293,38 @@ private Set getRegionsToRestart(FailoverRegion regionToRestart) * @return the failover region that contains the given execution vertex */ @VisibleForTesting - FailoverRegion getFailoverRegion(ExecutionVertexID vertexID) { - return regions.get(vertexID); + public FailoverRegion getFailoverRegion(ExecutionVertexID vertexID) { + return vertexToRegionMap.get(vertexID); + } + + /** + * A stateful {@link ResultPartitionAvailabilityChecker} which maintains the failed partitions which are not available. + */ + private static class RegionFailoverResultPartitionAvailabilityChecker implements ResultPartitionAvailabilityChecker { + + /** Result partition state checker from the shuffle master. */ + private final ResultPartitionAvailabilityChecker resultPartitionAvailabilityChecker; + + /** Records partitions which has caused {@link PartitionException}. 
*/ + private final HashSet failedPartitions; + + RegionFailoverResultPartitionAvailabilityChecker(ResultPartitionAvailabilityChecker checker) { + this.resultPartitionAvailabilityChecker = checkNotNull(checker); + this.failedPartitions = new HashSet<>(); + } + + @Override + public boolean isAvailable(IntermediateResultPartitionID resultPartitionID) { + return !failedPartitions.contains(resultPartitionID) && + resultPartitionAvailabilityChecker.isAvailable(resultPartitionID); + } + + public void markResultPartitionFailed(IntermediateResultPartitionID resultPartitionID) { + failedPartitions.add(resultPartitionID); + } + + public void removeResultPartitionFromFailedState(IntermediateResultPartitionID resultPartitionID) { + failedPartitions.remove(resultPartitionID); + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ResultPartitionAvailabilityChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ResultPartitionAvailabilityChecker.java new file mode 100644 index 00000000000000..286ad3e8d1ec41 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ResultPartitionAvailabilityChecker.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + +/** + * This checker helps to query result partition availability. + */ +interface ResultPartitionAvailabilityChecker { + + /** + * Returns whether the given partition is available. 
+ * + * @param resultPartitionID ID of the result partition to query + * @return whether the given partition is available + */ + boolean isAvailable(IntermediateResultPartitionID resultPartitionID); +} + diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategyTest.java index 9a3f5954a2c520..5a9c844d6ed4db 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionStrategyTest.java @@ -18,10 +18,22 @@ package org.apache.flink.runtime.executiongraph.failover.flip1; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.io.network.partition.consumer.PartitionConnectionException; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.util.TestLogger; import org.junit.Test; +import java.util.HashSet; +import java.util.Iterator; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; /** @@ -30,37 +42,387 @@ public class RestartPipelinedRegionStrategyTest extends TestLogger { /** - * Tests for scenes that a task fails for its own error, in which case only the - * region containing the failed task should be restarted. + * Tests for scenes that a task fails for its own error, in which case the + * region containing the failed task and its consumer regions should be restarted. *
-	 *     (v1)
+	 *     (v1) -+-> (v4)
+	 *           x
+	 *     (v2) -+-> (v5)
 	 *
-	 *     (v2)
+	 *     (v3) -+-> (v6)
 	 *
-	 *     (v3)
+	 *           ^
+	 *           |
+	 *       (blocking)
 	 * 
+ * Each vertex is in an individual region. */ @Test - public void testRegionFailoverForTaskInternalErrors() throws Exception { + public void testRegionFailoverForRegionInternalErrors() throws Exception { TestFailoverTopology.Builder topologyBuilder = new TestFailoverTopology.Builder(); TestFailoverTopology.TestFailoverVertex v1 = topologyBuilder.newVertex(); TestFailoverTopology.TestFailoverVertex v2 = topologyBuilder.newVertex(); TestFailoverTopology.TestFailoverVertex v3 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v4 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v5 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v6 = topologyBuilder.newVertex(); + + topologyBuilder.connect(v1, v4, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v1, v5, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v2, v4, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v2, v5, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v3, v6, ResultPartitionType.BLOCKING); FailoverTopology topology = topologyBuilder.build(); RestartPipelinedRegionStrategy strategy = new RestartPipelinedRegionStrategy(topology); - FailoverRegion r1 = strategy.getFailoverRegion(v1.getExecutionVertexID()); - FailoverRegion r2 = strategy.getFailoverRegion(v2.getExecutionVertexID()); - FailoverRegion r3 = strategy.getFailoverRegion(v3.getExecutionVertexID()); - - assertEquals(r1.getAllExecutionVertexIDs(), + // when v1 fails, {v1,v4,v5} should be restarted + HashSet expectedResult = new HashSet<>(); + expectedResult.add(v1.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, strategy.getTasksNeedingRestart(v1.getExecutionVertexID(), new Exception("Test failure"))); - assertEquals(r2.getAllExecutionVertexIDs(), + + // when v2 fails, {v2,v4,v5} should be restarted + expectedResult.clear(); + expectedResult.add(v2.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, strategy.getTasksNeedingRestart(v2.getExecutionVertexID(), new Exception("Test failure"))); - assertEquals(r3.getAllExecutionVertexIDs(), + + // when v3 fails, {v3,v6} should be restarted + expectedResult.clear(); + expectedResult.add(v3.getExecutionVertexID()); + expectedResult.add(v6.getExecutionVertexID()); + assertEquals(expectedResult, strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure"))); + + // when v4 fails, {v4} should be restarted + expectedResult.clear(); + expectedResult.add(v4.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v4.getExecutionVertexID(), new Exception("Test failure"))); + + // when v5 fails, {v5} should be restarted + expectedResult.clear(); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v5.getExecutionVertexID(), new Exception("Test failure"))); + + // when v6 fails, {v6} should be restarted + expectedResult.clear(); + expectedResult.add(v6.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v6.getExecutionVertexID(), new Exception("Test failure"))); + } + + /** + * Tests for scenes that a task fails for data consumption error, in which case the + * region containing the failed task, the region containing the unavailable result partition + * and all their consumer regions 
should be restarted. + *
+	 *     (v1) -+-> (v4)
+	 *           x
+	 *     (v2) -+-> (v5)
+	 *
+	 *     (v3) -+-> (v6)
+	 *
+	 *           ^
+	 *           |
+	 *       (blocking)
+	 * 
+ * Each vertex is in an individual region. + */ + @Test + public void testRegionFailoverForDataConsumptionErrors() throws Exception { + TestFailoverTopology.Builder topologyBuilder = new TestFailoverTopology.Builder(); + + TestFailoverTopology.TestFailoverVertex v1 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v2 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v3 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v4 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v5 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v6 = topologyBuilder.newVertex(); + + topologyBuilder.connect(v1, v4, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v1, v5, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v2, v4, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v2, v5, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v3, v6, ResultPartitionType.BLOCKING); + + FailoverTopology topology = topologyBuilder.build(); + + RestartPipelinedRegionStrategy strategy = new RestartPipelinedRegionStrategy(topology); + + // when v4 fails to consume data from v1, {v1,v4,v5} should be restarted + HashSet expectedResult = new HashSet<>(); + Iterator v4InputEdgeIterator = v4.getInputEdges().iterator(); + expectedResult.add(v1.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v4.getExecutionVertexID(), + new PartitionConnectionException( + new ResultPartitionID( + v4InputEdgeIterator.next().getResultPartitionID(), + new ExecutionAttemptID()), + new Exception("Test failure")))); + + // when v4 fails to consume data from v2, {v2,v4,v5} should be restarted + expectedResult.clear(); + expectedResult.add(v2.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v4.getExecutionVertexID(), + new PartitionNotFoundException( + new ResultPartitionID( + v4InputEdgeIterator.next().getResultPartitionID(), + new ExecutionAttemptID())))); + + // when v5 fails to consume data from v1, {v1,v4,v5} should be restarted + expectedResult.clear(); + Iterator v5InputEdgeIterator = v5.getInputEdges().iterator(); + expectedResult.add(v1.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v5.getExecutionVertexID(), + new PartitionConnectionException( + new ResultPartitionID( + v5InputEdgeIterator.next().getResultPartitionID(), + new ExecutionAttemptID()), + new Exception("Test failure")))); + + // when v5 fails to consume data from v2, {v2,v4,v5} should be restarted + expectedResult.clear(); + expectedResult.add(v2.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v5.getExecutionVertexID(), + new PartitionNotFoundException( + new ResultPartitionID( + v5InputEdgeIterator.next().getResultPartitionID(), + new ExecutionAttemptID())))); + + // when v6 fails to consume data from v3, {v3,v6} should be restarted + expectedResult.clear(); + Iterator v6InputEdgeIterator = v6.getInputEdges().iterator(); + expectedResult.add(v3.getExecutionVertexID()); + 
expectedResult.add(v6.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v6.getExecutionVertexID(), + new PartitionConnectionException( + new ResultPartitionID( + v6InputEdgeIterator.next().getResultPartitionID(), + new ExecutionAttemptID()), + new Exception("Test failure")))); + } + + /** + * Tests to verify region failover results regarding different input result partition availability combinations. + *
+	 *     (v1) --rp1--\
+	 *                 (v3)
+	 *     (v2) --rp2--/
+	 *
+	 *             ^
+	 *             |
+	 *         (blocking)
+	 * 
+ * Each vertex is in an individual region. + * rp1, rp2 are result partitions. + */ + @Test + public void testRegionFailoverForVariousResultPartitionAvailabilityCombinations() throws Exception { + TestFailoverTopology.Builder topologyBuilder = new TestFailoverTopology.Builder(); + + TestFailoverTopology.TestFailoverVertex v1 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v2 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v3 = topologyBuilder.newVertex(); + + topologyBuilder.connect(v1, v3, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v2, v3, ResultPartitionType.BLOCKING); + + FailoverTopology topology = topologyBuilder.build(); + + TestResultPartitionAvailabilityChecker availabilityChecker = new TestResultPartitionAvailabilityChecker(); + RestartPipelinedRegionStrategy strategy = new RestartPipelinedRegionStrategy(topology, availabilityChecker); + + IntermediateResultPartitionID rp1ID = v1.getOutputEdges().iterator().next().getResultPartitionID(); + IntermediateResultPartitionID rp2ID = v2.getOutputEdges().iterator().next().getResultPartitionID(); + + // ------------------------------------------------- + // Combination1: (rp1 == available, rp == available) + // ------------------------------------------------- + availabilityChecker.failedPartitions.clear(); + + // when v1 fails, {v1,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v1.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v2 fails, {v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v2.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v3 fails, {v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v3.getExecutionVertexID())); + + // ------------------------------------------------- + // Combination2: (rp1 == unavailable, rp == available) + // ------------------------------------------------- + availabilityChecker.failedPartitions.clear(); + availabilityChecker.markResultPartitionFailed(rp1ID); + + // when v1 fails, {v1,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v1.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v2 fails, {v1,v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v2.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v3 fails, {v1,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v3.getExecutionVertexID())); + + // ------------------------------------------------- + // Combination3: (rp1 == available, rp == unavailable) + // ------------------------------------------------- + availabilityChecker.failedPartitions.clear(); + availabilityChecker.markResultPartitionFailed(rp2ID); + + // when v1 fails, {v1,v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v1.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v2.getExecutionVertexID(), 
v3.getExecutionVertexID())); + + // when v2 fails, {v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v2.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v3 fails, {v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // ------------------------------------------------- + // Combination4: (rp1 == unavailable, rp == unavailable) + // ------------------------------------------------- + availabilityChecker.failedPartitions.clear(); + availabilityChecker.markResultPartitionFailed(rp1ID); + availabilityChecker.markResultPartitionFailed(rp2ID); + + // when v1 fails, {v1,v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v1.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v2 fails, {v1,v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v2.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v2.getExecutionVertexID(), v3.getExecutionVertexID())); + + // when v3 fails, {v1,v2,v3} should be restarted + assertThat( + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure")), + containsInAnyOrder(v1.getExecutionVertexID(), v2.getExecutionVertexID(), v3.getExecutionVertexID())); + } + + /** + * Tests region failover scenes for topology with multiple vertices. + *
+	 *     (v1) ---> (v2) --|--> (v3) ---> (v4) --|--> (v5) ---> (v6)
+	 *
+	 *           ^          ^          ^          ^          ^
+	 *           |          |          |          |          |
+	 *     (pipelined) (blocking) (pipelined) (blocking) (pipelined)
+	 * 
+ * Component 1: 1,2; component 2: 3,4; component 3: 5,6 + */ + @Test + public void testRegionFailoverForMultipleVerticesRegions() throws Exception { + TestFailoverTopology.Builder topologyBuilder = new TestFailoverTopology.Builder(); + + TestFailoverTopology.TestFailoverVertex v1 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v2 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v3 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v4 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v5 = topologyBuilder.newVertex(); + TestFailoverTopology.TestFailoverVertex v6 = topologyBuilder.newVertex(); + + topologyBuilder.connect(v1, v2, ResultPartitionType.PIPELINED); + topologyBuilder.connect(v2, v3, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v3, v4, ResultPartitionType.PIPELINED); + topologyBuilder.connect(v4, v5, ResultPartitionType.BLOCKING); + topologyBuilder.connect(v5, v6, ResultPartitionType.PIPELINED); + + FailoverTopology topology = topologyBuilder.build(); + + RestartPipelinedRegionStrategy strategy = new RestartPipelinedRegionStrategy(topology); + + // when v3 fails due to internal error, {v3,v4,v5,v6} should be restarted + HashSet expectedResult = new HashSet<>(); + expectedResult.add(v3.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + expectedResult.add(v6.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), new Exception("Test failure"))); + + // when v3 fails to consume from v2, {v1,v2,v3,v4,v5,v6} should be restarted + expectedResult.clear(); + expectedResult.add(v1.getExecutionVertexID()); + expectedResult.add(v2.getExecutionVertexID()); + expectedResult.add(v3.getExecutionVertexID()); + expectedResult.add(v4.getExecutionVertexID()); + expectedResult.add(v5.getExecutionVertexID()); + expectedResult.add(v6.getExecutionVertexID()); + assertEquals(expectedResult, + strategy.getTasksNeedingRestart(v3.getExecutionVertexID(), + new PartitionConnectionException( + new ResultPartitionID( + v3.getInputEdges().iterator().next().getResultPartitionID(), + new ExecutionAttemptID()), + new Exception("Test failure")))); + } + + // ------------------------------------------------------------------------ + // utilities + // ------------------------------------------------------------------------ + + private static class TestResultPartitionAvailabilityChecker implements ResultPartitionAvailabilityChecker { + + private final HashSet failedPartitions; + + public TestResultPartitionAvailabilityChecker() { + this.failedPartitions = new HashSet<>(); + } + + @Override + public boolean isAvailable(IntermediateResultPartitionID resultPartitionID) { + return !failedPartitions.contains(resultPartitionID); + } + + public void markResultPartitionFailed(IntermediateResultPartitionID resultPartitionID) { + failedPartitions.add(resultPartitionID); + } + + public void removeResultPartitionFromFailedState(IntermediateResultPartitionID resultPartitionID) { + failedPartitions.remove(resultPartitionID); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/TestFailoverTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/TestFailoverTopology.java index 9daec8eb8f97f8..6b8b5fde18760e 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/TestFailoverTopology.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/TestFailoverTopology.java @@ -130,9 +130,7 @@ public static class Builder { private Collection vertices = new ArrayList<>(); public TestFailoverVertex newVertex() { - TestFailoverVertex testFailoverVertex = newVertex(UUID.randomUUID().toString()); - vertices.add(testFailoverVertex); - return testFailoverVertex; + return newVertex(UUID.randomUUID().toString()); } public TestFailoverVertex newVertex(String name) { @@ -149,6 +147,14 @@ public Builder connect(TestFailoverVertex source, TestFailoverVertex target, Res return this; } + public Builder connect(TestFailoverVertex source, TestFailoverVertex target, ResultPartitionType partitionType, IntermediateResultPartitionID partitionID) { + FailoverEdge edge = new TestFailoverEdge(partitionID, partitionType, source, target); + source.addOuputEdge(edge); + target.addInputEdge(edge); + + return this; + } + public Builder setContainsCoLocationConstraints(boolean containsCoLocationConstraints) { this.containsCoLocationConstraints = containsCoLocationConstraints; return this; From eb65515caaca21d852503c91bc4af5d30df90b36 Mon Sep 17 00:00:00 2001 From: Chesnay Schepler Date: Tue, 28 May 2019 10:15:45 +0200 Subject: [PATCH 14/92] [FLINK-12644][travis] Setup jdk9 sticky e2e tests --- .travis.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.travis.yml b/.travis.yml index aa1e0b8be87d4d..c35acd5851a8ea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -299,3 +299,10 @@ jobs: env: PROFILE="-Dinclude-kinesis" script: ./tools/travis/nightly.sh split_heavy.sh name: heavy + - # E2E profiles - Java 9 + - if: type = cron + stage: test + jdk: "openjdk9" + env: PROFILE="-De2e-metrics -Dinclude-kinesis -Djdk9" + script: ./tools/travis/nightly.sh split_sticky.sh + name: misc - jdk 9 From b28865eac0ca20b028875ff4cf851812a864195b Mon Sep 17 00:00:00 2001 From: Chesnay Schepler Date: Tue, 28 May 2019 10:25:51 +0200 Subject: [PATCH 15/92] [FLINK-12644][travis] Setup jdk9 cp/ha e2e tests --- .travis.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index c35acd5851a8ea..1fbed25d308dd9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -304,5 +304,13 @@ jobs: stage: test jdk: "openjdk9" env: PROFILE="-De2e-metrics -Dinclude-kinesis -Djdk9" + script: ./tools/travis/nightly.sh split_ha.sh + name: ha - jdk9 + - if: type = cron + env: PROFILE="-Dinclude-kinesis -Djdk9" script: ./tools/travis/nightly.sh split_sticky.sh - name: misc - jdk 9 + name: sticky - jdk 9 + - if: type = cron + env: PROFILE="-Dinclude-kinesis -Djdk9" + script: ./tools/travis/nightly.sh split_checkpoints.sh + name: checkpoints - jdk 9 From b333ddc16967ae7428229d659130b61de906ef5f Mon Sep 17 00:00:00 2001 From: godfrey he Date: Tue, 28 May 2019 17:21:34 +0800 Subject: [PATCH 16/92] [FLINK-12600][table-planner-blink] Introduce various deterministic rewriting rule, which includes: 1. FlinkLimit0RemoveRule, that rewrites `limit 0` to empty Values 2. FlinkRewriteSubQueryRule, that rewrites a Filter with condition: `(select count from T) > 0` to a Filter with condition: `exists(select * from T)` 3. ReplaceIntersectWithSemiJoinRule, that rewrites distinct Intersect to a distinct Aggregate on a SEMI Join 4. 
ReplaceMinusWithAntiJoinRule, that rewrites distinct Minus to a distinct Aggregate on an ANTI Join This closes #8520 --- .../table/plan/rules/FlinkBatchRuleSets.scala | 13 +- .../plan/rules/FlinkStreamRuleSets.scala | 13 +- .../rules/logical/FlinkCalcMergeRule.scala | 2 +- .../rules/logical/FlinkLimit0RemoveRule.scala | 50 ++ .../rules/logical/FlinkPruneEmptyRules.scala | 70 ++ .../logical/FlinkRewriteSubQueryRule.scala | 168 +++++ .../ReplaceIntersectWithSemiJoinRule.scala | 61 ++ .../ReplaceMinusWithAntiJoinRule.scala | 61 ++ .../ReplaceSetOpWithJoinRuleBase.scala | 58 ++ .../flink/table/plan/batch/sql/LimitTest.xml | 23 +- .../table/plan/batch/sql/SetOperatorsTest.xml | 224 +++++++ .../table/plan/batch/sql/SortLimitTest.xml | 20 +- .../table/plan/batch/sql/SubplanReuseTest.xml | 39 ++ .../logical/FlinkLimit0RemoveRuleTest.xml | 214 ++++++ .../logical/FlinkPruneEmptyRulesTest.xml | 63 ++ .../ReplaceIntersectWithSemiJoinRuleTest.xml | 123 ++++ .../ReplaceMinusWithAntiJoinRuleTest.xml | 123 ++++ .../subquery/FlinkRewriteSubQueryRuleTest.xml | 612 ++++++++++++++++++ .../flink/table/plan/stream/sql/LimitTest.xml | 60 +- .../plan/stream/sql/SetOperatorsTest.xml | 226 +++++++ .../table/plan/stream/sql/SortLimitTest.xml | 80 +-- .../plan/stream/sql/SubplanReuseTest.xml | 37 ++ .../table/plan/batch/sql/LimitTest.scala | 1 - .../plan/batch/sql/SetOperatorsTest.scala | 129 ++++ .../table/plan/batch/sql/SortLimitTest.scala | 1 - .../plan/batch/sql/SubplanReuseTest.scala | 3 +- .../logical/FlinkLimit0RemoveRuleTest.scala | 101 +++ .../logical/FlinkPruneEmptyRulesTest.scala | 73 +++ ...ReplaceIntersectWithSemiJoinRuleTest.scala | 84 +++ .../ReplaceMinusWithAntiJoinRuleTest.scala | 82 +++ .../FlinkRewriteSubQueryRuleTest.scala | 211 ++++++ .../table/plan/stream/sql/LimitTest.scala | 1 - .../plan/stream/sql/SetOperatorsTest.scala | 127 ++++ .../table/plan/stream/sql/SortLimitTest.scala | 1 - .../plan/stream/sql/SubplanReuseTest.scala | 3 +- .../apache/flink/table/plan/util/pojos.scala | 10 + .../batch/sql/Limit0RemoveITCase.scala | 98 +++ .../stream/sql/Limit0RemoveITCase.scala | 187 ++++++ .../table/runtime/utils/StreamTestSink.scala | 2 +- .../table/runtime/utils/TestSinkUtil.scala | 16 +- 40 files changed, 3303 insertions(+), 167 deletions(-) create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRules.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkRewriteSubQueryRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.xml create mode 100644 
flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/Limit0RemoveITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/Limit0RemoveITCase.scala diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala index d7e45b30794b5b..1401605d0d5703 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala @@ -34,6 +34,7 @@ object FlinkBatchRuleSets { val SEMI_JOIN_RULES: RuleSet = RuleSets.ofList( SimplifyFilterConditionRule.EXTENDED, + FlinkRewriteSubQueryRule.FILTER, FlinkSubQueryRemoveRule.FILTER, JoinConditionTypeCoerceRule.INSTANCE, FlinkJoinPushExpressionsRule.INSTANCE @@ -116,7 +117,9 @@ object FlinkBatchRuleSets { new CoerceInputsRule(classOf[LogicalIntersect], false), //ensure except set operator have the same row type new CoerceInputsRule(classOf[LogicalMinus], false), - ConvertToNotInOrInRule.INSTANCE + ConvertToNotInOrInRule.INSTANCE, + // optimize limit 0 + FlinkLimit0RemoveRule.INSTANCE )).asJava) /** @@ -159,7 +162,7 @@ object FlinkBatchRuleSets { PruneEmptyRules.AGGREGATE_INSTANCE, PruneEmptyRules.FILTER_INSTANCE, PruneEmptyRules.JOIN_LEFT_INSTANCE, - PruneEmptyRules.JOIN_RIGHT_INSTANCE, + FlinkPruneEmptyRules.JOIN_RIGHT_INSTANCE, PruneEmptyRules.PROJECT_INSTANCE, PruneEmptyRules.SORT_INSTANCE, PruneEmptyRules.UNION_INSTANCE @@ -260,7 +263,11 @@ object 
FlinkBatchRuleSets { // semi/anti join transpose rule FlinkSemiAntiJoinJoinTransposeRule.INSTANCE, FlinkSemiAntiJoinProjectTransposeRule.INSTANCE, - FlinkSemiAntiJoinFilterTransposeRule.INSTANCE + FlinkSemiAntiJoinFilterTransposeRule.INSTANCE, + + // set operators + ReplaceIntersectWithSemiJoinRule.INSTANCE, + ReplaceMinusWithAntiJoinRule.INSTANCE ) /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala index ce99a988d42088..419306ec9534b7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala @@ -34,6 +34,7 @@ object FlinkStreamRuleSets { val SEMI_JOIN_RULES: RuleSet = RuleSets.ofList( SimplifyFilterConditionRule.EXTENDED, + FlinkRewriteSubQueryRule.FILTER, FlinkSubQueryRemoveRule.FILTER, JoinConditionTypeCoerceRule.INSTANCE, FlinkJoinPushExpressionsRule.INSTANCE @@ -118,7 +119,9 @@ object FlinkStreamRuleSets { new CoerceInputsRule(classOf[LogicalIntersect], false), //ensure except set operator have the same row type new CoerceInputsRule(classOf[LogicalMinus], false), - ConvertToNotInOrInRule.INSTANCE + ConvertToNotInOrInRule.INSTANCE, + // optimize limit 0 + FlinkLimit0RemoveRule.INSTANCE ) ).asJava) @@ -157,7 +160,7 @@ object FlinkStreamRuleSets { PruneEmptyRules.AGGREGATE_INSTANCE, PruneEmptyRules.FILTER_INSTANCE, PruneEmptyRules.JOIN_LEFT_INSTANCE, - PruneEmptyRules.JOIN_RIGHT_INSTANCE, + FlinkPruneEmptyRules.JOIN_RIGHT_INSTANCE, PruneEmptyRules.PROJECT_INSTANCE, PruneEmptyRules.SORT_INSTANCE, PruneEmptyRules.UNION_INSTANCE @@ -232,7 +235,11 @@ object FlinkStreamRuleSets { // semi/anti join transpose rule FlinkSemiAntiJoinJoinTransposeRule.INSTANCE, FlinkSemiAntiJoinProjectTransposeRule.INSTANCE, - FlinkSemiAntiJoinFilterTransposeRule.INSTANCE + FlinkSemiAntiJoinFilterTransposeRule.INSTANCE, + + // set operators + ReplaceIntersectWithSemiJoinRule.INSTANCE, + ReplaceMinusWithAntiJoinRule.INSTANCE ) /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkCalcMergeRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkCalcMergeRule.scala index 8deddfb095beb5..550c318fb73d7d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkCalcMergeRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkCalcMergeRule.scala @@ -29,7 +29,7 @@ import org.apache.calcite.tools.RelBuilderFactory import scala.collection.JavaConversions._ /** - * This rules is copied from Calcite's [[org.apache.calcite.rel.rules.CalcMergeRule]]. + * This rule is copied from Calcite's [[org.apache.calcite.rel.rules.CalcMergeRule]]. * * Modification: * - Condition in the merged program will be simplified if it exists. 
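As a rough illustration of how one of the newly registered rules can be exercised on its own (a minimal sketch assuming a plain Calcite HepPlanner setup; the helper name `applyLimit0Remove` is made up for illustration and is not code from this patch):

import org.apache.calcite.plan.hep.{HepPlanner, HepProgramBuilder}
import org.apache.calcite.rel.RelNode
import org.apache.flink.table.plan.rules.logical.FlinkLimit0RemoveRule

// Sketch only: build a one-rule HEP program and apply it to a plan.
// A Sort whose fetch is the literal 0 (i.e. a `LIMIT 0`) is replaced by an empty Values node.
def applyLimit0Remove(input: RelNode): RelNode = {
  val program = new HepProgramBuilder()
    .addRuleInstance(FlinkLimit0RemoveRule.INSTANCE)
    .build()
  val planner = new HepPlanner(program)
  planner.setRoot(input)
  planner.findBestExp()
}

Applied to the plan of a query such as `SELECT * FROM T LIMIT 0`, the result is an empty Values, which the PruneEmptyRules/FlinkPruneEmptyRules instances registered above can then propagate through the rest of the plan.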
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRule.scala new file mode 100644 index 00000000000000..73190fa05b8748 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRule.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.core.Sort +import org.apache.calcite.rex.RexLiteral + +/** + * Planner rule that rewrites `limit 0` to empty [[org.apache.calcite.rel.core.Values]]. + */ +class FlinkLimit0RemoveRule extends RelOptRule( + operand(classOf[Sort], any()), + "FlinkLimit0RemoveRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val sort: Sort = call.rel(0) + sort.fetch != null && RexLiteral.intValue(sort.fetch) == 0 + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val sort: Sort = call.rel(0) + val emptyValues = call.builder().values(sort.getRowType).build() + call.transformTo(emptyValues) + + // New plan is absolutely better than old plan. + call.getPlanner.setImportance(sort, 0.0) + } +} + +object FlinkLimit0RemoveRule { + val INSTANCE = new FlinkLimit0RemoveRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRules.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRules.scala new file mode 100644 index 00000000000000..f85087baf82671 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRules.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, none, operand, some} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.{Join, JoinRelType, Values} + +object FlinkPruneEmptyRules { + + /** + * This rule is copied from Calcite's + * [[org.apache.calcite.rel.rules.PruneEmptyRules#JOIN_RIGHT_INSTANCE]]. + * + * Modification: + * - Handles ANTI join specially. + * + * Rule that converts a [[Join]] to empty if its right child is empty. + * + *

Examples: + * • Join(Scan(Emp), Empty, INNER) becomes Empty
+ */ + val JOIN_RIGHT_INSTANCE: RelOptRule = new RelOptRule( + operand(classOf[Join], + some(operand(classOf[RelNode], any), + operand(classOf[Values], none))), + "FlinkPruneEmptyRules(right)") { + + override def matches(call: RelOptRuleCall): Boolean = { + val right: Values = call.rel(2) + Values.IS_EMPTY.apply(right) + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val join: Join = call.rel(0) + join.getJoinType match { + case JoinRelType.ANTI => + // "select * from emp where deptno not in (select deptno from dept where 1=0)" + // return emp + call.transformTo(call.builder().push(join.getLeft).build) + case _ => + if (join.getJoinType.generatesNullsOnRight) { + // "select * from emp left join dept" is not necessarily empty if dept is empty + } else { + call.transformTo(call.builder.push(join).empty.build) + } + } + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkRewriteSubQueryRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkRewriteSubQueryRule.scala new file mode 100644 index 00000000000000..8d97e34c794d35 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkRewriteSubQueryRule.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operandJ} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.{Aggregate, Filter, RelFactories} +import org.apache.calcite.rex.{RexShuttle, _} +import org.apache.calcite.sql.SqlKind +import org.apache.calcite.sql.`type`.SqlTypeFamily +import org.apache.calcite.sql.fun.SqlCountAggFunction +import org.apache.calcite.tools.RelBuilderFactory + +import scala.collection.JavaConversions._ + +/** + * Planner rule that rewrites scalar query in filter like: + * `select * from T1 where (select count(*) from T2) > 0` + * to + * `select * from T1 where exists (select * from T2)`, + * which could be converted to SEMI join by [[FlinkSubQueryRemoveRule]]. + * + * Without this rule, the original query will be rewritten to a filter on a join on an aggregate + * by [[org.apache.calcite.rel.rules.SubQueryRemoveRule]]. 
the full logical plan is + * {{{ + * LogicalProject(a=[$0], b=[$1], c=[$2]) + * +- LogicalJoin(condition=[$3], joinType=[semi]) + * :- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + * +- LogicalProject($f0=[IS NOT NULL($0)]) + * +- LogicalAggregate(group=[{}], m=[MIN($0)]) + * +- LogicalProject(i=[true]) + * +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) + * }}} + */ +class FlinkRewriteSubQueryRule( + operand: RelOptRuleOperand, + relBuilderFactory: RelBuilderFactory, + description: String) + extends RelOptRule(operand, relBuilderFactory, description) { + + override def onMatch(call: RelOptRuleCall): Unit = { + val filter: Filter = call.rel(0) + val condition = filter.getCondition + val newCondition = rewriteScalarQuery(condition) + if (RexUtil.eq(condition, newCondition)) { + return + } + + val newFilter = filter.copy(filter.getTraitSet, filter.getInput, newCondition) + call.transformTo(newFilter) + } + + // scalar query like: `(select count(*) from T) > 0` can be converted to `exists(select * from T)` + def rewriteScalarQuery(condition: RexNode): RexNode = { + condition.accept(new RexShuttle() { + override def visitCall(call: RexCall): RexNode = { + val subQuery = getSupportedScalarQuery(call) + subQuery match { + case Some(sq) => + val aggInput = sq.rel.getInput(0) + RexSubQuery.exists(aggInput) + case _ => super.visitCall(call) + } + } + }) + } + + private def isScalarQuery(n: RexNode): Boolean = n.isA(SqlKind.SCALAR_QUERY) + + private def getSupportedScalarQuery(call: RexCall): Option[RexSubQuery] = { + // check the RexNode is a RexLiteral which's value is between 0 and 1 + def isBetween0And1(n: RexNode, include0: Boolean, include1: Boolean): Boolean = { + n match { + case l: RexLiteral => + l.getTypeName.getFamily match { + case SqlTypeFamily.NUMERIC if l.getValue != null => + val v = l.getValue.toString.toDouble + (0.0 < v && v < 1.0) || (include0 && v == 0.0) || (include1 && v == 1.0) + case _ => false + } + case _ => false + } + } + + // check the RelNode is a Aggregate which has only count aggregate call with empty args + def isCountStarAggWithoutGroupBy(n: RelNode): Boolean = { + n match { + case agg: Aggregate => + if (agg.getGroupCount == 0 && agg.getAggCallList.size() == 1) { + val aggCall = agg.getAggCallList.head + !aggCall.isDistinct && + aggCall.filterArg < 0 && + aggCall.getArgList.isEmpty && + aggCall.getAggregation.isInstanceOf[SqlCountAggFunction] + } else { + false + } + case _ => false + } + } + + call.getKind match { + // (select count(*) from T) > X (X is between 0 (inclusive) and 1 (exclusive)) + case SqlKind.GREATER_THAN if isScalarQuery(call.operands.head) => + val subQuery = call.operands.head.asInstanceOf[RexSubQuery] + if (isCountStarAggWithoutGroupBy(subQuery.rel) && + isBetween0And1(call.operands.last, include0 = true, include1 = false)) { + Some(subQuery) + } else { + None + } + // (select count(*) from T) >= X (X is between 0 (exclusive) and 1 (inclusive)) + case SqlKind.GREATER_THAN_OR_EQUAL if isScalarQuery(call.operands.head) => + val subQuery = call.operands.head.asInstanceOf[RexSubQuery] + if (isCountStarAggWithoutGroupBy(subQuery.rel) && + isBetween0And1(call.operands.last, include0 = false, include1 = true)) { + Some(subQuery) + } else { + None + } + // X < (select count(*) from T) (X is between 0 (inclusive) and 1 (exclusive)) + case SqlKind.LESS_THAN if isScalarQuery(call.operands.last) => + val subQuery = call.operands.last.asInstanceOf[RexSubQuery] + if (isCountStarAggWithoutGroupBy(subQuery.rel) && + 
isBetween0And1(call.operands.head, include0 = true, include1 = false)) { + Some(subQuery) + } else { + None + } + // X <= (select count(*) from T) (X is between 0 (exclusive) and 1 (inclusive)) + case SqlKind.LESS_THAN_OR_EQUAL if isScalarQuery(call.operands.last) => + val subQuery = call.operands.last.asInstanceOf[RexSubQuery] + if (isCountStarAggWithoutGroupBy(subQuery.rel) && + isBetween0And1(call.operands.head, include0 = false, include1 = true)) { + Some(subQuery) + } else { + None + } + case _ => None + } + } +} + +object FlinkRewriteSubQueryRule { + + val FILTER = new FlinkRewriteSubQueryRule( + operandJ(classOf[Filter], null, RexUtil.SubQueryFinder.FILTER_PREDICATE, any), + RelFactories.LOGICAL_BUILDER, + "FlinkRewriteSubQueryRule:Filter") + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala new file mode 100644 index 00000000000000..6ac7deae091508 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRule.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.core.{Aggregate, Intersect, Join, JoinRelType} + +import scala.collection.JavaConversions._ + +/** + * Planner rule that replaces distinct [[Intersect]] with + * a distinct [[Aggregate]] on a SEMI [[Join]]. + * + *
Note: Not support Intersect All. + */ +class ReplaceIntersectWithSemiJoinRule extends ReplaceSetOpWithJoinRuleBase( + classOf[Intersect], + "ReplaceIntersectWithSemiJoinRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val intersect: Intersect = call.rel(0) + // not support intersect all now. + intersect.isDistinct + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val intersect: Intersect = call.rel(0) + val left = intersect.getInput(0) + val right = intersect.getInput(1) + + val relBuilder = call.builder + val keys = 0 until left.getRowType.getFieldCount + val conditions = generateCondition(relBuilder, left, right, keys) + + relBuilder.push(left) + relBuilder.push(right) + relBuilder.join(JoinRelType.SEMI, conditions).aggregate(relBuilder.groupKey(keys: _*)) + val rel = relBuilder.build() + call.transformTo(rel) + } +} + +object ReplaceIntersectWithSemiJoinRule { + val INSTANCE: RelOptRule = new ReplaceIntersectWithSemiJoinRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala new file mode 100644 index 00000000000000..c322cf98e603b5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRule.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.core._ + +import scala.collection.JavaConversions._ + +/** + * Planner rule that replaces distinct [[Minus]] (SQL keyword: EXCEPT) with + * a distinct [[Aggregate]] on an ANTI [[Join]]. + * + *
Note: Not support Minus All. + */ +class ReplaceMinusWithAntiJoinRule extends ReplaceSetOpWithJoinRuleBase( + classOf[Minus], + "ReplaceMinusWithAntiJoinRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val minus: Minus = call.rel(0) + // not support minus all now. + minus.isDistinct + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val minus: Minus = call.rel(0) + val left = minus.getInput(0) + val right = minus.getInput(1) + + val relBuilder = call.builder + val keys = 0 until left.getRowType.getFieldCount + val conditions = generateCondition(relBuilder, left, right, keys) + + relBuilder.push(left) + relBuilder.push(right) + relBuilder.join(JoinRelType.ANTI, conditions).aggregate(relBuilder.groupKey(keys: _*)) + val rel = relBuilder.build() + call.transformTo(rel) + } +} + +object ReplaceMinusWithAntiJoinRule { + val INSTANCE: RelOptRule = new ReplaceMinusWithAntiJoinRule +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala new file mode 100644 index 00000000000000..1f400a0cd69310 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/ReplaceSetOpWithJoinRuleBase.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptUtil} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.{RelFactories, SetOp} +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder + +/** + * Base class that replace [[SetOp]] to [[org.apache.calcite.rel.core.Join]]. 
+ */ +abstract class ReplaceSetOpWithJoinRuleBase[T <: SetOp]( + clazz: Class[T], + description: String) + extends RelOptRule( + operand(clazz, any), + RelFactories.LOGICAL_BUILDER, + description) { + + protected def generateCondition( + relBuilder: RelBuilder, + left: RelNode, + right: RelNode, + keys: Seq[Int]): Seq[RexNode] = { + val rexBuilder = relBuilder.getRexBuilder + val leftTypes = RelOptUtil.getFieldTypeList(left.getRowType) + val rightTypes = RelOptUtil.getFieldTypeList(right.getRowType) + val conditions = keys.map { key => + val leftRex = rexBuilder.makeInputRef(leftTypes.get(key), key) + val rightRex = rexBuilder.makeInputRef(rightTypes.get(key), leftTypes.size + key) + val equalCond = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftRex, rightRex) + relBuilder.or( + equalCond, + relBuilder.and(relBuilder.isNull(leftRex), relBuilder.isNull(rightRex))) + } + conditions + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/LimitTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/LimitTest.xml index 731f1b58c08505..54f02926dd098a 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/LimitTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/LimitTest.xml @@ -29,11 +29,7 @@ LogicalSort(fetch=[0]) @@ -92,11 +88,7 @@ LogicalSort(offset=[10], fetch=[0]) @@ -134,11 +126,7 @@ LogicalSort(offset=[0], fetch=[0]) @@ -155,10 +143,7 @@ LogicalSort(fetch=[0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml new file mode 100644 index 00000000000000..a3fada8bb2e8fe --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.xml @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 THEN b ELSE NULL END FROM A]]> + + + ($2, 0), $1, null:RecordType:peek_no_expand(INTEGER _1, VARCHAR(65536) CHARACTER SET "UTF-16LE" _2))]) + +- LogicalTableScan(table=[[A, source: [TestTableSource(a, b, c)]]]) +]]> + + + (c, 0), b, null:RecordType:peek_no_expand(INTEGER _1, VARCHAR(65536) CHARACTER SET "UTF-16LE" _2)) AS EXPR$0]) + +- Reused(reference_id=[1]) +]]> + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SortLimitTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SortLimitTest.xml index e4f1be9d1dee0d..94c59b310a3851 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SortLimitTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SortLimitTest.xml @@ -49,10 +49,7 @@ LogicalSort(sort0=[$0], sort1=[$1], dir0=[DESC-nulls-last], dir1=[ASC-nulls-firs @@ -88,10 +85,7 @@ LogicalSort(sort0=[$0], dir0=[DESC-nulls-last], fetch=[0]) @@ -147,10 +141,7 @@ LogicalSort(sort0=[$0], dir0=[DESC-nulls-last], fetch=[0]) @@ -207,10 +198,7 @@ LogicalSort(sort0=[$0], sort1=[$1], dir0=[DESC-nulls-last], dir1=[ASC-nulls-firs diff --git 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml index a1cc277f0b20d6..a1cbb1ab907788 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.xml @@ -1066,6 +1066,45 @@ SortMergeJoin(joinType=[InnerJoin], where=[=(a, d0)], select=[a, b, c, d, e, f, : +- Exchange(distribution=[hash[d]]) : +- TableSourceScan(table=[[y, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) +- Reused(reference_id=[1]) +]]> + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.xml new file mode 100644 index 00000000000000..2a2949ec8fa686 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.xml @@ -0,0 +1,214 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.xml new file mode 100644 index 00000000000000..ddc7d13f1ead3a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.xml @@ -0,0 +1,63 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.xml new file mode 100644 index 00000000000000..0d5a4ec958b179 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.xml @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1]]> + + + ($0, 1)]) + +- LogicalIntersect(all=[false]) + :- LogicalProject(a=[$0], b=[$1], c=[$2]) + : +- LogicalTableScan(table=[[T1, source: [TestTableSource(a, b, c)]]]) + +- LogicalProject(d=[$0], e=[$1], f=[$2]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(d, e, f)]]]) +]]> + + + ($0, 1)]) + +- LogicalAggregate(group=[{0, 1, 2}]) + +- LogicalJoin(condition=[AND(OR(=($0, $3), AND(IS NULL($0), IS NULL($3))), OR(=($1, $4), AND(IS NULL($1), IS NULL($4))), OR(=($2, $5), AND(IS NULL($2), IS NULL($5))))], joinType=[semi]) + :- LogicalProject(a=[$0], b=[$1], c=[$2]) + : +- LogicalTableScan(table=[[T1, source: [TestTableSource(a, b, c)]]]) + +- LogicalProject(d=[$0], e=[$1], f=[$2]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(d, e, f)]]]) +]]> + + + diff --git 
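The ReplaceIntersectWithSemiJoinRuleTest plan above shows the mirror-image rewrite: INTERSECT (DISTINCT) becomes a semi join over the same null-safe conditions, again followed by a distinct aggregate on all columns. The corresponding collection-level sketch, under the same assumptions as the previous one (illustrative only):

// INTERSECT (DISTINCT) modelled as semi join + distinct aggregate:
// keep the distinct left rows that do have a null-safe-equal match on the right.
def intersect(left: Seq[Seq[Any]], right: Seq[Seq[Any]]): Seq[Seq[Any]] = {
  def nullSafeEq(l: Any, r: Any): Boolean =
    (l == null && r == null) || (l != null && l == r)
  left
    .filter(lRow => right.exists(rRow =>
      lRow.indices.forall(i => nullSafeEq(lRow(i), rRow(i)))))
    .distinct
}
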
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.xml new file mode 100644 index 00000000000000..02494da64f0344 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.xml @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.xml new file mode 100644 index 00000000000000..a4ef3200a0f5fd --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.xml @@ -0,0 +1,612 @@ + + + + + + 10) > 0]]> + + + ($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)]) + LogicalProject(e=[$1]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0)]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0)]) + +- LogicalJoin(condition=[true], joinType=[left]) + :- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + +- LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)]) + +- LogicalProject(e=[$1]) + +- LogicalFilter(condition=[>($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10 GROUP BY f) > 0]]> + + + ($SCALAR_QUERY({ +LogicalProject(EXPR$0=[$1]) + LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) + LogicalProject(f=[$2]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0)]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($3, 0)]) + +- LogicalJoin(condition=[true], joinType=[left]) + :- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) + +- LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)]) + +- LogicalProject(EXPR$0=[$1]) + +- LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) + +- LogicalProject(f=[$2]) + +- LogicalFilter(condition=[>($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 0 +ORDER BY i_product_name +LIMIT 100 + ]]> + + + =($0, 738), <=($0, +(738, 40)), >($SCALAR_QUERY({ +LogicalAggregate(group=[{}], item_cnt=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[OR(AND(=($1, $cor0.i_manufact), OR(AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'powder'), =($4, _UTF-16LE'khaki')), OR(=($5, _UTF-16LE'Ounce'), =($5, _UTF-16LE'Oz')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))), AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'brown'), =($4, _UTF-16LE'honeydew')), OR(=($5, _UTF-16LE'Bunch'), =($5, _UTF-16LE'Ton')), OR(=($6, _UTF-16LE'N/A'), =($6, _UTF-16LE'small'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'floral'), =($4, _UTF-16LE'deep')), OR(=($5, _UTF-16LE'N/A'), =($5, _UTF-16LE'Dozen')), OR(=($6, _UTF-16LE'petite'), =($6, _UTF-16LE'large'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'light'), =($4, _UTF-16LE'cornflower')), OR(=($5, _UTF-16LE'Box'), =($5, _UTF-16LE'Pound')), OR(=($6, 
_UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))))), AND(=($1, $cor0.i_manufact), OR(AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'midnight'), =($4, _UTF-16LE'snow')), OR(=($5, _UTF-16LE'Pallet'), =($5, _UTF-16LE'Gross')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))), AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'cyan'), =($4, _UTF-16LE'papaya')), OR(=($5, _UTF-16LE'Cup'), =($5, _UTF-16LE'Dram')), OR(=($6, _UTF-16LE'N/A'), =($6, _UTF-16LE'small'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'orange'), =($4, _UTF-16LE'frosted')), OR(=($5, _UTF-16LE'Each'), =($5, _UTF-16LE'Tbl')), OR(=($6, _UTF-16LE'petite'), =($6, _UTF-16LE'large'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'forest'), =($4, _UTF-16LE'ghost')), OR(=($5, _UTF-16LE'Lb'), =($5, _UTF-16LE'Bundle')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))))))]) + LogicalTableScan(table=[[item, source: [TestTableSource(i_manufact_id, i_manufact, i_product_name, i_category, i_color, i_units, i_size)]]]) +}), 0))], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[item, source: [TestTableSource(i_manufact_id, i_manufact, i_product_name, i_category, i_color, i_units, i_size)]]]) +]]> + + + =($0, 738), <=($0, +(738, 40)))]) + +- LogicalJoin(condition=[=($7, $1)], joinType=[semi]) + :- LogicalTableScan(table=[[item, source: [TestTableSource(i_manufact_id, i_manufact, i_product_name, i_category, i_color, i_units, i_size)]]]) + +- LogicalProject(i_manufact=[$1]) + +- LogicalFilter(condition=[OR(AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'powder'), =($4, _UTF-16LE'khaki')), OR(=($5, _UTF-16LE'Ounce'), =($5, _UTF-16LE'Oz')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))), AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'brown'), =($4, _UTF-16LE'honeydew')), OR(=($5, _UTF-16LE'Bunch'), =($5, _UTF-16LE'Ton')), OR(=($6, _UTF-16LE'N/A'), =($6, _UTF-16LE'small'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'floral'), =($4, _UTF-16LE'deep')), OR(=($5, _UTF-16LE'N/A'), =($5, _UTF-16LE'Dozen')), OR(=($6, _UTF-16LE'petite'), =($6, _UTF-16LE'large'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'light'), =($4, _UTF-16LE'cornflower')), OR(=($5, _UTF-16LE'Box'), =($5, _UTF-16LE'Pound')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))), AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'midnight'), =($4, _UTF-16LE'snow')), OR(=($5, _UTF-16LE'Pallet'), =($5, _UTF-16LE'Gross')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))), AND(=($3, _UTF-16LE'Women'), OR(=($4, _UTF-16LE'cyan'), =($4, _UTF-16LE'papaya')), OR(=($5, _UTF-16LE'Cup'), =($5, _UTF-16LE'Dram')), OR(=($6, _UTF-16LE'N/A'), =($6, _UTF-16LE'small'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'orange'), =($4, _UTF-16LE'frosted')), OR(=($5, _UTF-16LE'Each'), =($5, _UTF-16LE'Tbl')), OR(=($6, _UTF-16LE'petite'), =($6, _UTF-16LE'large'))), AND(=($3, _UTF-16LE'Men'), OR(=($4, _UTF-16LE'forest'), =($4, _UTF-16LE'ghost')), OR(=($5, _UTF-16LE'Lb'), =($5, _UTF-16LE'Bundle')), OR(=($6, _UTF-16LE'medium'), =($6, _UTF-16LE'extra large'))))]) + +- LogicalTableScan(table=[[item, source: [TestTableSource(i_manufact_id, i_manufact, i_product_name, i_category, i_color, i_units, i_size)]]]) +]]> + + + + + 0]]> + + + ($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0)], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + 
+ + + + + 0.9]]> + + + ($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0.9:DECIMAL(2, 1))], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + = 1]]> + + + =($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 1)], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + = 0.1]]> + + + =($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0.1:DECIMAL(2, 1))], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 10) > 0]]> + + + ($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0)]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10) > 0.9]]> + + + ($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0.9:DECIMAL(2, 1))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10) >= 1]]> + + + =($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 1)]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10) >= 0.1]]> + + + =($SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + LogicalProject($f0=[0]) + LogicalFilter(condition=[>($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}), 0.1:DECIMAL(2, 1))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10)]]> + + + ($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10)]]> + + + ($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10)]]> + + + ($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + 
($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10)]]> + + + ($0, 10)]) + LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +}))]) + +- LogicalTableScan(table=[[x, source: [TestTableSource(a, b, c)]]]) +]]> + + + ($0, 10)]) + +- LogicalTableScan(table=[[y, source: [TestTableSource(d, e, f)]]]) +]]> + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/LimitTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/LimitTest.xml index 285ee01f918750..3b6ffc6fef3f62 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/LimitTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/LimitTest.xml @@ -29,10 +29,7 @@ LogicalSort(fetch=[0]) @@ -89,10 +86,7 @@ LogicalSort(offset=[10], fetch=[0]) @@ -109,10 +103,7 @@ LogicalSort(offset=[0], fetch=[0]) @@ -129,10 +120,7 @@ LogicalSort(fetch=[0]) @@ -149,10 +137,7 @@ LogicalSort(fetch=[0]) @@ -169,10 +154,7 @@ LogicalSort(fetch=[0]) @@ -189,10 +171,7 @@ LogicalSort(fetch=[0]) @@ -209,10 +188,7 @@ LogicalSort(fetch=[0]) @@ -229,10 +205,7 @@ LogicalSort(fetch=[0]) @@ -249,10 +222,7 @@ LogicalSort(fetch=[0]) @@ -269,10 +239,7 @@ LogicalSort(fetch=[0]) @@ -289,10 +256,7 @@ LogicalSort(fetch=[0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml new file mode 100644 index 00000000000000..d30ecddfbc4178 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.xml @@ -0,0 +1,226 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 THEN b ELSE NULL END FROM A]]> + + + ($2, 0), $1, null:RecordType:peek_no_expand(INTEGER _1, VARCHAR(65536) CHARACTER SET "UTF-16LE" _2))]) + +- LogicalTableScan(table=[[A, source: [TestTableSource(a, b, c)]]]) +]]> + + + (c, 0), b, null:RecordType:peek_no_expand(INTEGER _1, VARCHAR(65536) CHARACTER SET "UTF-16LE" _2)) AS EXPR$0]) + +- Reused(reference_id=[1]) +]]> + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SortLimitTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SortLimitTest.xml index 0faf861113b780..256ca264b453fa 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SortLimitTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SortLimitTest.xml @@ -217,10 +217,7 @@ LogicalProject(a=[$0]) @@ -259,10 +256,7 @@ LogicalProject(a=[$0]) @@ -301,10 +295,7 @@ LogicalProject(a=[$0]) @@ -427,10 +418,7 @@ LogicalProject(a=[$0]) @@ -574,10 +562,7 @@ LogicalProject(a=[$0]) @@ -616,10 +601,7 @@ LogicalProject(a=[$0]) @@ -658,10 +640,7 @@ LogicalProject(a=[$0]) @@ -784,10 +763,7 @@ LogicalProject(a=[$0]) @@ -826,10 +802,7 @@ LogicalProject(a=[$0]) @@ -868,10 +841,7 @@ LogicalProject(a=[$0]) @@ -910,10 +880,7 @@ LogicalProject(a=[$0]) @@ -952,10 +919,7 @@ LogicalProject(a=[$0]) @@ -994,10 +958,7 @@ 
LogicalProject(a=[$0]) @@ -1036,10 +997,7 @@ LogicalProject(a=[$0]) @@ -1078,10 +1036,7 @@ LogicalProject(a=[$0]) @@ -1225,10 +1180,7 @@ LogicalProject(a=[$0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml index 07639da0ebab88..9f1f8945f00fe5 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.xml @@ -644,6 +644,43 @@ Union(all=[true], union=[a, b]) : +- TableSourceScan(table=[[x, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Calc(select=[a, *(b, 2) AS b], where=[<(b, 10)]) +- Reused(reference_id=[1]) +]]> + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/LimitTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/LimitTest.scala index 0cfcfe7bebf10f..9f5a4e5cd5932a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/LimitTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/LimitTest.scala @@ -28,7 +28,6 @@ class LimitTest extends TableTestBase { private val util = batchTestUtil() util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) - // TODO optimize `limit 0` @Test def testLimitWithoutOffset(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala new file mode 100644 index 00000000000000..635b9b74fbd6e2 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SetOperatorsTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.batch.sql + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.GenericTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{TableConfigOptions, TableException, ValidationException} +import org.apache.flink.table.plan.util.NonPojo +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +class SetOperatorsTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def before(): Unit = { + util.tableEnv.getConfig.getConf.setString( + TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + util.addTableSource[(Int, Long, String)]("T1", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("T2", 'd, 'e, 'f) + util.addTableSource[(Int, Long, Int, String, Long)]("T3", 'a, 'b, 'd, 'c, 'e) + } + + @Test(expected = classOf[ValidationException]) + def testUnionDifferentColumnSize(): Unit = { + // must fail. Union inputs have different column size. + util.verifyPlan("SELECT * FROM T1 UNION ALL SELECT * FROM T3") + } + + @Test(expected = classOf[ValidationException]) + def testUnionDifferentFieldTypes(): Unit = { + // must fail. Union inputs have different field types. + util.verifyPlan("SELECT a, b, c FROM T1 UNION ALL SELECT d, c, e FROM T3") + } + + @Test(expected = classOf[TableException]) + def testIntersectAll(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2") + } + + @Test(expected = classOf[ValidationException]) + def testIntersectDifferentFieldTypes(): Unit = { + // must fail. Intersect inputs have different field types. + util.verifyPlan("SELECT a, b, c FROM T1 INTERSECT SELECT d, c, e FROM T3") + } + + @Test(expected = classOf[TableException]) + def testMinusAll(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2") + } + + @Test(expected = classOf[ValidationException]) + def testMinusDifferentFieldTypes(): Unit = { + // must fail. Minus inputs have different field types. 
+ util.verifyPlan("SELECT a, b, c FROM T1 EXCEPT SELECT d, c, e FROM T3") + } + + @Test + def testIntersect(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2 WHERE 1=0") + } + + @Test + def testMinus(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2") + } + + @Test + def testMinusLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2") + } + + @Test + def testMinusRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2 WHERE 1=0") + } + + @Test + def testMinusWithNestedTypes(): Unit = { + util.addTableSource[(Long, (Int, String), Array[Boolean])]("MyTable", 'a, 'b, 'c) + util.verifyPlan("SELECT * FROM MyTable EXCEPT SELECT * FROM MyTable") + } + + @Test + def testUnionNullableTypes(): Unit = { + util.addTableSource[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) + util.verifyPlan("SELECT a FROM A UNION ALL SELECT CASE WHEN c > 0 THEN b ELSE NULL END FROM A") + } + + @Test + def testUnionAnyType(): Unit = { + val util = batchTestUtil() + util.addTableSource("A", + Array[TypeInformation[_]]( + new GenericTypeInfo(classOf[NonPojo]), + new GenericTypeInfo(classOf[NonPojo])), + Array("a", "b")) + util.verifyPlan("SELECT a FROM A UNION ALL SELECT b FROM A") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SortLimitTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SortLimitTest.scala index 4ebdf44b020d1d..4aff8d09a83616 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SortLimitTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SortLimitTest.scala @@ -29,7 +29,6 @@ class SortLimitTest extends TableTestBase { private val util = batchTestUtil() util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.tableEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_EXEC_SORT_DEFAULT_LIMIT, 200) - // TODO optimize `limit 0` @Test def testNonRangeSortWithoutOffset(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala index f85d8de49dd3f5..5de63bf27b1579 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/SubplanReuseTest.scala @@ -368,8 +368,7 @@ class SubplanReuseTest extends TableTestBase { util.verifyPlan(sqlQuery) } - @Test(expected = classOf[TableException]) - // INTERSECT is not supported now + @Test def testSubplanReuseWithDynamicFunction(): Unit = { val sqlQuery = util.tableEnv.sqlQuery( """ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.scala new file mode 100644 index 00000000000000..60af8607e9110f --- /dev/null 
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkLimit0RemoveRuleTest.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.util.TableTestBase + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkLimit0RemoveRule]]. + */ +class FlinkLimit0RemoveRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val programs = new FlinkChainedProgram[BatchOptimizeContext]() + programs.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + FlinkSubQueryRemoveRule.FILTER, + FlinkLimit0RemoveRule.INSTANCE)) + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + } + + @Test + def testSimpleLimitZero(): Unit = { + util.verifyPlan("SELECT * FROM MyTable LIMIT 0") + } + + @Test + def testLimitZeroWithOrderBy(): Unit = { + util.verifyPlan("SELECT * FROM MyTable ORDER BY a LIMIT 0") + } + + @Test + def testLimitZeroWithOffset(): Unit = { + util.verifyPlan("SELECT * FROM MyTable ORDER BY a LIMIT 0 OFFSET 10") + } + + @Test + def testLimitZeroWithSelect(): Unit = { + util.verifyPlan("SELECT * FROM (SELECT a FROM MyTable LIMIT 0)") + } + + @Test + def testLimitZeroWithIn(): Unit = { + util.verifyPlan("SELECT * FROM MyTable WHERE a IN (SELECT a FROM MyTable LIMIT 0)") + } + + @Test + def testLimitZeroWithNotIn(): Unit = { + util.verifyPlan("SELECT * FROM MyTable WHERE a NOT IN (SELECT a FROM MyTable LIMIT 0)") + } + + @Test + def testLimitZeroWithExists(): Unit = { + util.verifyPlan("SELECT * FROM MyTable WHERE EXISTS (SELECT a FROM MyTable LIMIT 0)") + } + + @Test + def testLimitZeroWithNotExists(): Unit = { + util.verifyPlan("SELECT * FROM MyTable WHERE NOT EXISTS (SELECT a FROM MyTable LIMIT 0)") + } + + @Test + def testLimitZeroWithJoin(): Unit = { + util.verifyPlan("SELECT * FROM MyTable INNER JOIN (SELECT * FROM MyTable Limit 0) ON TRUE") + } +} diff --git 
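FlinkLimit0RemoveRuleTest above exercises the rule against plain LIMIT 0 queries as well as IN / NOT IN / EXISTS / NOT EXISTS subqueries; the expectations only make sense once a LIMIT 0 subquery is treated as an empty relation, which makes IN unsatisfiable and NOT IN trivially true (null handling aside). A tiny collection-level sketch of that intuition (plain Scala, illustrative names only, not the planner's implementation):

object Limit0Sketch {
  def main(args: Array[String]): Unit = {
    val t1 = Seq(1, 2, 3)                              // SELECT a FROM MyTable
    val limit0 = Seq.empty[Int]                        // SELECT a FROM MyTable LIMIT 0
    println(t1.filter(x => limit0.contains(x)))        // IN      -> List()
    println(t1.filterNot(x => limit0.contains(x)))     // NOT IN  -> List(1, 2, 3)
  }
}
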
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.scala new file mode 100644 index 00000000000000..3f15f189686dd2 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkPruneEmptyRulesTest.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.util.TableTestBase + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules.{PruneEmptyRules, ReduceExpressionsRule} +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkPruneEmptyRules]]. 
+ */ +class FlinkPruneEmptyRulesTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val programs = new FlinkChainedProgram[BatchOptimizeContext]() + programs.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + FlinkSubQueryRemoveRule.FILTER, + ReduceExpressionsRule.FILTER_INSTANCE, + ReduceExpressionsRule.PROJECT_INSTANCE, + PruneEmptyRules.FILTER_INSTANCE, + PruneEmptyRules.PROJECT_INSTANCE, + FlinkPruneEmptyRules.JOIN_RIGHT_INSTANCE)) + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + util.addTableSource[(Int, Long, String)]("T1", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("T2", 'd, 'e, 'f) + } + + @Test + def testSemiJoinRightIsEmpty(): Unit = { + util.verifyPlan("SELECT * FROM T1 WHERE a IN (SELECT d FROM T2 WHERE 1=0)") + } + + @Test + def testAntiJoinRightIsEmpty(): Unit = { + util.verifyPlan("SELECT * FROM T1 WHERE a NOT IN (SELECT d FROM T2 WHERE 1=0)") + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala new file mode 100644 index 00000000000000..c867de85baf746 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceIntersectWithSemiJoinRuleTest.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.TableException +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram, + FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.util.TableTestBase + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[ReplaceIntersectWithSemiJoinRule]]. 
+ */ +class ReplaceIntersectWithSemiJoinRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val programs = new FlinkChainedProgram[BatchOptimizeContext]() + programs.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList(ReplaceIntersectWithSemiJoinRule.INSTANCE)) + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + + util.addTableSource[(Int, Long, String)]("T1", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("T2", 'd, 'e, 'f) + } + + @Test + def testIntersectAll(): Unit = { + util.verifyPlanNotExpected("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2", "joinType=[semi]") + } + + @Test + def testIntersect(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectWithFilter(): Unit = { + util.verifyPlan("SELECT c FROM ((SELECT * FROM T1) INTERSECT (SELECT * FROM T2)) WHERE a > 1") + } + + @Test + def testIntersectLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2 WHERE 1=0") + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala new file mode 100644 index 00000000000000..f5b6fbe951172b --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ReplaceMinusWithAntiJoinRuleTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.util.TableTestBase + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[ReplaceMinusWithAntiJoinRule]]. 
+ */ +class ReplaceMinusWithAntiJoinRuleTest extends TableTestBase { + + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val programs = new FlinkChainedProgram[BatchOptimizeContext]() + programs.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList(ReplaceMinusWithAntiJoinRule.INSTANCE)) + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + + util.addTableSource[(Int, Long, String)]("T1", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("T2", 'd, 'e, 'f) + } + + @Test + def testExceptAll(): Unit = { + util.verifyPlanNotExpected("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2", "joinType=[anti]") + } + + @Test + def testExcept(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2") + } + + @Test + def testExceptWithFilter(): Unit = { + util.verifyPlan("SELECT c FROM (SELECT * FROM T1 EXCEPT (SELECT * FROM T2)) WHERE b < 2") + } + + @Test + def testExceptLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2") + } + + @Test + def testExceptRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2 WHERE 1=0") + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.scala new file mode 100644 index 00000000000000..255e2ba29970a0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/subquery/FlinkRewriteSubQueryRuleTest.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical.subquery + +import org.apache.flink.api.scala._ + +import org.junit.{Before, Test} + +/** + * Test for [[org.apache.flink.table.plan.rules.logical.FlinkRewriteSubQueryRule]]. 
+ */ +class FlinkRewriteSubQueryRuleTest extends SubQueryTestBase { + + @Before + def setup(): Unit = { + util.addTableSource[(Int, Long, String)]("x", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("y", 'd, 'e, 'f) + } + + @Test + def testNotCountStarInScalarQuery(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(e) FROM y WHERE d > 10) > 0") + } + + @Test + def testNotEmptyGroupByInScalarQuery(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10 GROUP BY f) > 0") + } + + @Test + def testUnsupportedConversionWithUnexpectedComparisonNumber(): Unit = { + // without correlation + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) > 1", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) >= 0", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) > -1", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE 0 <= (SELECT COUNT(*) FROM y WHERE d > 10)", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE -1 < (SELECT COUNT(*) FROM y WHERE d > 10)", "joinType=[semi]") + + // with correlation + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) > 1", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) >= 0", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE 1 < (SELECT COUNT(*) FROM y WHERE a = d)", "joinType=[semi]") + util.verifyPlanNotExpected( + "SELECT * FROM x WHERE 0 <= (SELECT COUNT(*) FROM y WHERE a = d)", "joinType=[semi]") + } + + @Test + def testSupportedConversionWithoutCorrelation1(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) > 0") + } + + @Test + def testSupportedConversionWithoutCorrelation2(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) > 0.9") + } + + @Test + def testSupportedConversionWithoutCorrelation3(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) >= 1") + } + + @Test + def testSupportedConversionWithoutCorrelation4(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE d > 10) >= 0.1") + } + + @Test + def testSupportedConversionWithoutCorrelation5(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0 < (SELECT COUNT(*) FROM y WHERE d > 10)") + } + + @Test + def testSupportedConversionWithoutCorrelation6(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0.99 < (SELECT COUNT(*) FROM y WHERE d > 10)") + } + + @Test + def testSupportedConversionWithoutCorrelation7(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 1 <= (SELECT COUNT(*) FROM y WHERE d > 10)") + } + + @Test + def testSupportedConversionWithoutCorrelation8(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0.01 <= (SELECT COUNT(*) FROM y WHERE d > 10)") + } + + @Test + def testSupportedConversionWithCorrelation1(): Unit = { + // with correlation + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) > 0") + } + + @Test + def testSupportedConversionWithCorrelation2(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) > 0.9") + } + + @Test + def testSupportedConversionWithCorrelation3(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) >= 1") + } + + @Test + def 
testSupportedConversionWithCorrelation4(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE (SELECT COUNT(*) FROM y WHERE a = d) >= 0.1") + } + + @Test + def testSupportedConversionWithCorrelation5(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0 < (SELECT COUNT(*) FROM y WHERE a = d)") + } + + @Test + def testSupportedConversionWithCorrelation6(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0.99 < (SELECT COUNT(*) FROM y WHERE a = d)") + } + + @Test + def testSupportedConversionWithCorrelation7(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 1 <= (SELECT COUNT(*) FROM y WHERE a = d)") + } + + @Test + def testSupportedConversionWithCorrelation8(): Unit = { + util.verifyPlan("SELECT * FROM x WHERE 0.01 <= (SELECT COUNT(*) FROM y WHERE a = d)") + } + + @Test + def testSqlFromTpcDsQ41(): Unit = { + util.addTableSource[(Int, String, String, String, String, String, String)]("item", + 'i_manufact_id, 'i_manufact, 'i_product_name, 'i_category, 'i_color, 'i_units, 'i_size) + val sqlQuery = + """ + |SELECT DISTINCT (i_product_name) + |FROM item i1 + |WHERE i_manufact_id BETWEEN 738 AND 738 + 40 + | AND (SELECT count(*) AS item_cnt + |FROM item + |WHERE (i_manufact = i1.i_manufact AND + | ((i_category = 'Women' AND + | (i_color = 'powder' OR i_color = 'khaki') AND + | (i_units = 'Ounce' OR i_units = 'Oz') AND + | (i_size = 'medium' OR i_size = 'extra large') + | ) OR + | (i_category = 'Women' AND + | (i_color = 'brown' OR i_color = 'honeydew') AND + | (i_units = 'Bunch' OR i_units = 'Ton') AND + | (i_size = 'N/A' OR i_size = 'small') + | ) OR + | (i_category = 'Men' AND + | (i_color = 'floral' OR i_color = 'deep') AND + | (i_units = 'N/A' OR i_units = 'Dozen') AND + | (i_size = 'petite' OR i_size = 'large') + | ) OR + | (i_category = 'Men' AND + | (i_color = 'light' OR i_color = 'cornflower') AND + | (i_units = 'Box' OR i_units = 'Pound') AND + | (i_size = 'medium' OR i_size = 'extra large') + | ))) OR + | (i_manufact = i1.i_manufact AND + | ((i_category = 'Women' AND + | (i_color = 'midnight' OR i_color = 'snow') AND + | (i_units = 'Pallet' OR i_units = 'Gross') AND + | (i_size = 'medium' OR i_size = 'extra large') + | ) OR + | (i_category = 'Women' AND + | (i_color = 'cyan' OR i_color = 'papaya') AND + | (i_units = 'Cup' OR i_units = 'Dram') AND + | (i_size = 'N/A' OR i_size = 'small') + | ) OR + | (i_category = 'Men' AND + | (i_color = 'orange' OR i_color = 'frosted') AND + | (i_units = 'Each' OR i_units = 'Tbl') AND + | (i_size = 'petite' OR i_size = 'large') + | ) OR + | (i_category = 'Men' AND + | (i_color = 'forest' OR i_color = 'ghost') AND + | (i_units = 'Lb' OR i_units = 'Bundle') AND + | (i_size = 'medium' OR i_size = 'extra large') + | )))) > 0 + |ORDER BY i_product_name + |LIMIT 100 + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/LimitTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/LimitTest.scala index fe440f468586c3..b98d7b20846e9a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/LimitTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/LimitTest.scala @@ -28,7 +28,6 @@ class LimitTest extends TableTestBase { private val util = streamTestUtil() util.addDataStream[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime, 'rowtime) - // TODO optimize `limit 0` @Test def testLimitWithoutOffset(): 
Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala new file mode 100644 index 00000000000000..85f4d4019b33a3 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SetOperatorsTest.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.stream.sql + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.GenericTypeInfo +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{TableException, ValidationException} +import org.apache.flink.table.plan.util.NonPojo +import org.apache.flink.table.util.TableTestBase + +import org.junit.{Before, Test} + +class SetOperatorsTest extends TableTestBase { + + private val util = streamTestUtil() + + @Before + def before(): Unit = { + util.addTableSource[(Int, Long, String)]("T1", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String)]("T2", 'd, 'e, 'f) + util.addTableSource[(Int, Long, Int, String, Long)]("T3", 'a, 'b, 'd, 'c, 'e) + } + + @Test(expected = classOf[ValidationException]) + def testUnionDifferentColumnSize(): Unit = { + // must fail. Union inputs have different column size. + util.verifyPlan("SELECT * FROM T1 UNION ALL SELECT * FROM T3") + } + + @Test(expected = classOf[ValidationException]) + def testUnionDifferentFieldTypes(): Unit = { + // must fail. Union inputs have different field types. + util.verifyPlan("SELECT a, b, c FROM T1 UNION ALL SELECT d, c, e FROM T3") + } + + @Test(expected = classOf[TableException]) + def testIntersectAll(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT ALL SELECT f FROM T2") + } + + @Test(expected = classOf[ValidationException]) + def testIntersectDifferentFieldTypes(): Unit = { + // must fail. Intersect inputs have different field types. + util.verifyPlan("SELECT a, b, c FROM T1 INTERSECT SELECT d, c, e FROM T3") + } + + @Test(expected = classOf[TableException]) + def testMinusAll(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT ALL SELECT f FROM T2") + } + + @Test(expected = classOf[ValidationException]) + def testMinusDifferentFieldTypes(): Unit = { + // must fail. Minus inputs have different field types. 
+ util.verifyPlan("SELECT a, b, c FROM T1 EXCEPT SELECT d, c, e FROM T3") + } + + @Test + def testIntersect(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 INTERSECT SELECT f FROM T2") + } + + @Test + def testIntersectRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 INTERSECT SELECT f FROM T2 WHERE 1=0") + } + + @Test + def testMinus(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2") + } + + @Test + def testMinusLeftIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 WHERE 1=0 EXCEPT SELECT f FROM T2") + } + + @Test + def testMinusRightIsEmpty(): Unit = { + util.verifyPlan("SELECT c FROM T1 EXCEPT SELECT f FROM T2 WHERE 1=0") + } + + @Test + def testMinusWithNestedTypes(): Unit = { + util.addTableSource[(Long, (Int, String), Array[Boolean])]("MyTable", 'a, 'b, 'c) + util.verifyPlan("SELECT * FROM MyTable EXCEPT SELECT * FROM MyTable") + } + + @Test + def testUnionNullableTypes(): Unit = { + util.addTableSource[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) + util.verifyPlan("SELECT a FROM A UNION ALL SELECT CASE WHEN c > 0 THEN b ELSE NULL END FROM A") + } + + @Test + def testUnionAnyType(): Unit = { + val util = batchTestUtil() + util.addTableSource("A", + Array[TypeInformation[_]]( + new GenericTypeInfo(classOf[NonPojo]), + new GenericTypeInfo(classOf[NonPojo])), + Array("a", "b")) + util.verifyPlan("SELECT a FROM A UNION ALL SELECT b FROM A") + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SortLimitTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SortLimitTest.scala index 16335073df6b9f..285fc461d35bfa 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SortLimitTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SortLimitTest.scala @@ -27,7 +27,6 @@ class SortLimitTest extends TableTestBase { private val util = streamTestUtil() util.addDataStream[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime, 'rowtime) - // TODO optimize `limit 0` @Test def testSortProcessingTimeAscWithOffSet0AndLimit1(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.scala index a4db284fa52679..7613792e5fae4e 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/SubplanReuseTest.scala @@ -284,8 +284,7 @@ class SubplanReuseTest extends TableTestBase { util.verifyPlan(sqlQuery) } - @Test(expected = classOf[TableException]) - // INTERSECT is not supported now + @Test def testSubplanReuseWithDynamicFunction(): Unit = { val sqlQuery = util.tableEnv.sqlQuery( """ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/util/pojos.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/util/pojos.scala index 9704b457f01dba..5d2e62e2417567 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/util/pojos.scala +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/util/pojos.scala @@ -39,3 +39,13 @@ class MyPojo() { override def toString = s"MyPojo($f1, $f2)" } + +class NonPojo { + val x = new java.util.HashMap[String, String]() + + override def toString: String = x.toString + + override def hashCode(): Int = super.hashCode() + + override def equals(obj: scala.Any): Boolean = super.equals(obj) +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/Limit0RemoveITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/Limit0RemoveITCase.scala new file mode 100644 index 00000000000000..def6c2914e5525 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/Limit0RemoveITCase.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.batch.sql + +import org.apache.flink.table.api.TableException +import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.TestData.numericType + +import org.junit.{Before, Test} + +import java.math.{BigDecimal => JBigDecimal} + +import scala.collection.Seq + + +class Limit0RemoveITCase extends BatchTestBase { + + @Before + def before(): Unit = { + lazy val numericData = Seq( + row(null, 1L, 1.0f, 1.0d, JBigDecimal.valueOf(1)), + row(2, null, 2.0f, 2.0d, JBigDecimal.valueOf(2)), + row(3, 3L, null, 3.0d, JBigDecimal.valueOf(3)), + row(3, 3L, 4.0f, null, JBigDecimal.valueOf(3)) + ) + + registerCollection("t1", numericData, numericType, "a, b, c, d, e") + registerCollection("t2", numericData, numericType, "a, b, c, d, e") + } + + @Test + def testSimpleLimitRemove(): Unit = { + val sqlQuery = "SELECT * FROM t1 LIMIT 0" + checkResult(sqlQuery, Seq()) + } + + @Test + def testLimitRemoveWithOrderBy(): Unit = { + val sqlQuery = "SELECT * FROM t1 ORDER BY a LIMIT 0" + checkResult(sqlQuery, Seq()) + } + + @Test + def testLimitRemoveWithJoin(): Unit = { + val sqlQuery = "SELECT * FROM t1 JOIN (SELECT * FROM t2 LIMIT 0) ON true" + checkResult(sqlQuery, Seq()) + } + + @Test + def testLimitRemoveWithIn(): Unit = { + val sqlQuery = "SELECT * FROM t1 WHERE a IN (SELECT a FROM t2 LIMIT 0)" + checkResult(sqlQuery, Seq()) + } + + @Test + def testLimitRemoveWithNotIn(): Unit = { + val sqlQuery = "SELECT a FROM t1 WHERE a NOT IN (SELECT a FROM t2 LIMIT 0)" + checkResult(sqlQuery, Seq(row(2), row(3), row(3), row(null))) + } + + @Test(expected = classOf[TableException]) + // TODO remove exception after translateToPlanInternal is implemented in BatchExecNestedLoopJoin + def 
testLimitRemoveWithExists(): Unit = { + val sqlQuery = "SELECT * FROM t1 WHERE EXISTS (SELECT a FROM t2 LIMIT 0)" + checkResult(sqlQuery, Seq()) + } + + @Test(expected = classOf[TableException]) + // TODO remove exception after translateToPlanInternal is implemented in BatchExecNestedLoopJoin + def testLimitRemoveWithNotExists(): Unit = { + val sqlQuery = "SELECT * FROM t1 WHERE NOT EXISTS (SELECT a FROM t2 LIMIT 0)" + checkResult(sqlQuery, Seq(row(2), row(3), row(3), row(null))) + } + + @Test + def testLimitRemoveWithSelect(): Unit = { + val sqlQuery = "SELECT * FROM (SELECT a FROM t2 LIMIT 0)" + checkResult(sqlQuery, Seq()) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/Limit0RemoveITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/Limit0RemoveITCase.scala new file mode 100644 index 00000000000000..6a402d5014978a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/Limit0RemoveITCase.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{TableConfigOptions, TableException} +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.runtime.utils.{StreamingTestBase, TestSinkUtil, TestingAppendTableSink, TestingUpsertTableSink} + +import org.junit.Assert.assertEquals +import org.junit.{Before, Test} + +class Limit0RemoveITCase extends StreamingTestBase() { + + @Before + def setup(): Unit = { + tEnv.getConfig.getConf.setBoolean(TableConfigOptions.SQL_EXEC_SOURCE_VALUES_INPUT_ENABLED, true) + } + + @Test + def testSimpleLimitRemove(): Unit = { + val ds = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table = ds.toTable(tEnv, 'a) + tEnv.registerTable("MyTable", table) + + val sql = "SELECT * FROM MyTable LIMIT 0" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getAppendResults.size) + } + + @Test + def testLimitRemoveWithOrderBy(): Unit = { + val ds = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table = ds.toTable(tEnv, 'a) + tEnv.registerTable("MyTable", table) + + val sql = "SELECT * FROM MyTable ORDER BY a LIMIT 0" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getAppendResults.size) + } + + @Test + def testLimitRemoveWithSelect(): Unit = { + val ds = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table = ds.toTable(tEnv, 'a) + tEnv.registerTable("MyTable", table) + + val sql = "select a2 from (select cast(a as int) a2 from MyTable limit 0)" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getAppendResults.size) + } + + @Test + def testLimitRemoveWithIn(): Unit = { + val ds1 = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table1 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable1", table1) + + val ds2 = env.fromCollection(Seq(1, 2, 3)) + val table2 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable2", table2) + + val sql = "SELECT * FROM MyTable1 WHERE a IN (SELECT a FROM MyTable2 LIMIT 0)" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getAppendResults.size) + } + + @Test + def testLimitRemoveWithNotIn(): Unit = { + val ds1 = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table1 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable1", table1) + + val ds2 = env.fromCollection(Seq(1, 2, 3)) + val table2 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable2", table2) + + val sql = "SELECT * FROM MyTable1 WHERE a NOT IN (SELECT a FROM MyTable2 LIMIT 0)" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + val expected = Seq("1", "2", "3", "4", "5", "6") + assertEquals(expected, sink.getAppendResults.sorted) + } + + @Test(expected = classOf[TableException]) + // TODO remove exception after translateToPlanInternal is implemented in StreamExecJoin + def testLimitRemoveWithExists(): Unit = { + val ds1 = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table1 = ds1.toTable(tEnv, 'a) + 
tEnv.registerTable("MyTable1", table1) + + val ds2 = env.fromCollection(Seq(1, 2, 3)) + val table2 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable2", table2) + + val sql = "SELECT * FROM MyTable1 WHERE EXISTS (SELECT a FROM MyTable2 LIMIT 0)" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingUpsertTableSink(Array(0))) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getRawResults.size) + } + + @Test(expected = classOf[TableException]) + // TODO remove exception after translateToPlanInternal is implemented in StreamExecJoin + def testLimitRemoveWithNotExists(): Unit = { + val ds1 = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table1 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable1", table1) + + val ds2 = env.fromCollection(Seq(1, 2, 3)) + val table2 = ds1.toTable(tEnv, 'a) + tEnv.registerTable("MyTable2", table2) + + val sql = "SELECT * FROM MyTable1 WHERE NOT EXISTS (SELECT a FROM MyTable2 LIMIT 0)" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingUpsertTableSink(Array(0))) + tEnv.writeToSink(result, sink) + tEnv.execute() + + val expected = Seq("1", "2", "3", "4", "5", "6") + assertEquals(expected, sink.getUpsertResults.sorted) + } + + @Test + def testLimitRemoveWithJoin(): Unit = { + val ds1 = env.fromCollection(Seq(1, 2, 3, 4, 5, 6)) + val table1 = ds1.toTable(tEnv, 'a1) + tEnv.registerTable("MyTable1", table1) + + val ds2 = env.fromCollection(Seq(1, 2, 3)) + val table2 = ds1.toTable(tEnv, 'a2) + tEnv.registerTable("MyTable2", table2) + + val sql = "SELECT a1 FROM MyTable1 INNER JOIN (SELECT a2 FROM MyTable2 LIMIT 0) ON true" + + val result = tEnv.sqlQuery(sql) + val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink()) + tEnv.writeToSink(result, sink) + tEnv.execute() + + assertEquals(0, sink.getAppendResults.size) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala index 1839ba65434164..4f71696df71858 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestSink.scala @@ -249,7 +249,7 @@ final class TestingUpsertSink(keys: Array[Int], tz: TimeZone) } } -final class TestingUpsertTableSink(keys: Array[Int], tz: TimeZone) +final class TestingUpsertTableSink(val keys: Array[Int], val tz: TimeZone) extends UpsertStreamTableSink[BaseRow] { var fNames: Array[String] = _ var fTypes: Array[TypeInformation[_]] = _ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestSinkUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestSinkUtil.scala index 958c2646fa06e3..b1affecd90c083 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestSinkUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestSinkUtil.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.table.`type`.TypeConverters.createExternalTypeInfoFromInternalType -import org.apache.flink.table.api.{Table, TableImpl} +import org.apache.flink.table.api.{Table, TableException, TableImpl} 
import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataformat.GenericRow import org.apache.flink.table.runtime.utils.JavaPojos.Pojo1 @@ -40,9 +40,17 @@ object TestSinkUtil { val rowType = table.asInstanceOf[TableImpl].getRelNode.getRowType val fieldNames = rowType.getFieldNames.asScala.toArray val fieldTypes = rowType.getFieldList.asScala - .map(field => FlinkTypeFactory.toInternalType(field.getType)) - .map(createExternalTypeInfoFromInternalType).toArray - new TestingAppendTableSink().configure(fieldNames, fieldTypes).asInstanceOf[T] + .map(field => FlinkTypeFactory.toInternalType(field.getType)) + .map(createExternalTypeInfoFromInternalType).toArray + sink match { + case _: TestingAppendTableSink => + new TestingAppendTableSink().configure(fieldNames, fieldTypes).asInstanceOf[T] + case s: TestingUpsertTableSink => + new TestingUpsertTableSink(s.keys, s.tz).configure(fieldNames, fieldTypes).asInstanceOf[T] + case _: TestingRetractTableSink => + new TestingRetractTableSink().configure(fieldNames, fieldTypes).asInstanceOf[T] + case _ => throw new TableException(s"Unsupported sink: $sink") + } } def fieldToString(field: Any, tz: TimeZone): String = { From 6489237bd64af60c207fcbdd9da51594ea5c00a8 Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Fri, 24 May 2019 14:39:48 +0200 Subject: [PATCH 17/92] [hotfix] Correct redirect from /ops/deployment/oss.html to /ops/filesystems/oss.html --- docs/redirects/oss.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/redirects/oss.md b/docs/redirects/oss.md index 3b34502a591dcb..b9df3e8737be4a 100644 --- a/docs/redirects/oss.md +++ b/docs/redirects/oss.md @@ -1,8 +1,8 @@ --- title: "Aliyun Object Storage Service (OSS)" layout: redirect -redirect: /ops/deployment/oss.html -permalink: /ops/filesystems/oss.html +redirect: /ops/filesystems/oss.html +permalink: /ops/deployment/oss.html --- \ No newline at end of file +--> From e00ec88601583d370e14d7d969b20ab1cbc6ce3e Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 4 Apr 2019 16:00:46 -0400 Subject: [PATCH 18/92] [FLINK-12115][fs] Add support for AzureFS Check for http enabled storage accounts in AzureFS IT tests Add AzureFS standalone E2E test --- docs/ops/filesystems/azure.md | 77 ++++++ docs/ops/filesystems/azure.zh.md | 77 ++++++ docs/ops/filesystems/index.md | 9 +- docs/ops/filesystems/index.zh.md | 7 +- flink-dist/src/main/assemblies/opt.xml | 7 + .../test-scripts/test_azure_fs.sh | 83 +++++++ .../flink-azure-fs-hadoop/pom.xml | 155 ++++++++++++ .../fs/azurefs/AbstractAzureFSFactory.java | 85 +++++++ .../flink/fs/azurefs/AzureFSFactory.java | 30 +++ .../fs/azurefs/SecureAzureFSFactory.java | 30 +++ ...org.apache.flink.core.fs.FileSystemFactory | 17 ++ .../flink/fs/azurefs/AzureFSFactoryTest.java | 94 ++++++++ .../AzureFileSystemBehaviorITCase.java | 220 ++++++++++++++++++ .../src/test/resources/log4j-test.properties | 27 +++ .../runtime/fs/hdfs/HadoopFileSystem.java | 4 +- .../runtime/util}/HadoopConfigLoader.java | 2 +- flink-filesystems/flink-s3-fs-base/pom.xml | 2 +- .../common/AbstractS3FileSystemFactory.java | 1 + .../fs/s3/common/S3EntropyFsFactoryTest.java | 1 + flink-filesystems/flink-s3-fs-hadoop/pom.xml | 5 + .../fs/s3hadoop/S3FileSystemFactory.java | 2 +- .../fs/s3hadoop/HadoopS3FileSystemTest.java | 2 +- flink-filesystems/flink-s3-fs-presto/pom.xml | 6 + .../fs/s3presto/S3FileSystemFactory.java | 2 +- .../fs/s3presto/PrestoS3FileSystemTest.java | 2 +- flink-filesystems/pom.xml | 1 + 26 files changed, 937 insertions(+), 11 
deletions(-) create mode 100644 docs/ops/filesystems/azure.md create mode 100644 docs/ops/filesystems/azure.zh.md create mode 100755 flink-end-to-end-tests/test-scripts/test_azure_fs.sh create mode 100644 flink-filesystems/flink-azure-fs-hadoop/pom.xml create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AbstractAzureFSFactory.java create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AzureFSFactory.java create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/SecureAzureFSFactory.java create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFSFactoryTest.java create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFileSystemBehaviorITCase.java create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/test/resources/log4j-test.properties rename flink-filesystems/{flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common => flink-hadoop-fs/src/main/java/org/apache/flink/runtime/util}/HadoopConfigLoader.java (99%) diff --git a/docs/ops/filesystems/azure.md b/docs/ops/filesystems/azure.md new file mode 100644 index 00000000000000..36720c80a11fa2 --- /dev/null +++ b/docs/ops/filesystems/azure.md @@ -0,0 +1,77 @@ +--- +title: "Azure Blob Storage" +nav-title: Azure Blob Storage +nav-parent_id: filesystems +nav-pos: 3 +--- + + +[Azure Blob Storage](https://docs.microsoft.com/en-us/azure/storage/) is a Microsoft-managed service providing cloud storage for a variety of use cases. +You can use Azure Blob Storage with Flink for **reading** and **writing data** as well in conjunction with the [streaming **state backends**]({{ site.baseurl }}/ops/state/state_backends.html) + +* This will be replaced by the TOC +{:toc} + +You can use Azure Blob Storage objects like regular files by specifying paths in the following format: + +{% highlight plain %} +wasb://@$.blob.core.windows.net/ + +// SSL encrypted access +wasbs://@$.blob.core.windows.net/ +{% endhighlight %} + +Below shows how to use Azure Blob Storage with Flink: + +{% highlight java %} +// Read from Azure Blob storage +env.readTextFile("wasb://@$.blob.core.windows.net/"); + +// Write to Azure Blob storage +stream.writeAsText("wasb://@$.blob.core.windows.net/") + +// Use Azure Blob Storage as FsStatebackend +env.setStateBackend(new FsStateBackend("wasb://@$.blob.core.windows.net/")); +{% endhighlight %} + +### Shaded Hadoop Azure Blob Storage file system + +To use `flink-azure-fs-hadoop,` copy the respective JAR file from the opt directory to the lib directory of your Flink distribution before starting Flink, e.g. + +{% highlight bash %} +cp ./opt/flink-azure-fs-hadoop-{{ site.version }}.jar ./lib/ +{% endhighlight %} + +`flink-azure-fs-hadoop` registers default FileSystem wrappers for URIs with the *wasb://* and *wasbs://* (SSL encrypted access) scheme. + +#### Configurations setup +After setting up the Azure Blob Storage FileSystem wrapper, you need to configure credentials to make sure that Flink is allowed to access Azure Blob Storage. 
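As a minimal sketch (not part of this patch; `youraccount` and the key value are placeholders), the same credential key can also be set programmatically on Flink's `Configuration` before any file system is used — this mirrors how the `AzureFileSystemBehaviorITCase` added later in this series configures itself:

{% highlight java %}
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.FileSystem;

// Sketch only: account name and access key below are placeholders for your storage account.
Configuration conf = new Configuration();
conf.setString("fs.azure.account.key.youraccount.blob.core.windows.net", "<your-access-key>");

// Makes the key visible to all file systems created afterwards
// (same effect as putting the entry into flink-conf.yaml).
FileSystem.initialize(conf);
{% endhighlight %}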
+ +To allow for easy adoption, you can use the same configuration keys in `flink-conf.yaml` as in Hadoop's `core-site.xml` + +You can see the configuration keys in the [Hadoop Azure Blob Storage documentation](https://hadoop.apache.org/docs/current/hadoop-azure/index.html#Configuring_Credentials). + +There are some required configurations that must be added to `flink-conf.yaml`: + +{% highlight yaml %} +fs.azure.account.key.youraccount.blob.core.windows.net: Azure Blob Storage access key +{% endhighlight %} + +{% top %} diff --git a/docs/ops/filesystems/azure.zh.md b/docs/ops/filesystems/azure.zh.md new file mode 100644 index 00000000000000..36720c80a11fa2 --- /dev/null +++ b/docs/ops/filesystems/azure.zh.md @@ -0,0 +1,77 @@ +--- +title: "Azure Blob Storage" +nav-title: Azure Blob Storage +nav-parent_id: filesystems +nav-pos: 3 +--- + + +[Azure Blob Storage](https://docs.microsoft.com/en-us/azure/storage/) is a Microsoft-managed service providing cloud storage for a variety of use cases. +You can use Azure Blob Storage with Flink for **reading** and **writing data** as well in conjunction with the [streaming **state backends**]({{ site.baseurl }}/ops/state/state_backends.html) + +* This will be replaced by the TOC +{:toc} + +You can use Azure Blob Storage objects like regular files by specifying paths in the following format: + +{% highlight plain %} +wasb://@$.blob.core.windows.net/ + +// SSL encrypted access +wasbs://@$.blob.core.windows.net/ +{% endhighlight %} + +Below shows how to use Azure Blob Storage with Flink: + +{% highlight java %} +// Read from Azure Blob storage +env.readTextFile("wasb://@$.blob.core.windows.net/"); + +// Write to Azure Blob storage +stream.writeAsText("wasb://@$.blob.core.windows.net/") + +// Use Azure Blob Storage as FsStatebackend +env.setStateBackend(new FsStateBackend("wasb://@$.blob.core.windows.net/")); +{% endhighlight %} + +### Shaded Hadoop Azure Blob Storage file system + +To use `flink-azure-fs-hadoop,` copy the respective JAR file from the opt directory to the lib directory of your Flink distribution before starting Flink, e.g. + +{% highlight bash %} +cp ./opt/flink-azure-fs-hadoop-{{ site.version }}.jar ./lib/ +{% endhighlight %} + +`flink-azure-fs-hadoop` registers default FileSystem wrappers for URIs with the *wasb://* and *wasbs://* (SSL encrypted access) scheme. + +#### Configurations setup +After setting up the Azure Blob Storage FileSystem wrapper, you need to configure credentials to make sure that Flink is allowed to access Azure Blob Storage. + +To allow for easy adoption, you can use the same configuration keys in `flink-conf.yaml` as in Hadoop's `core-site.xml` + +You can see the configuration keys in the [Hadoop Azure Blob Storage documentation](https://hadoop.apache.org/docs/current/hadoop-azure/index.html#Configuring_Credentials). + +There are some required configurations that must be added to `flink-conf.yaml`: + +{% highlight yaml %} +fs.azure.account.key.youraccount.blob.core.windows.net: Azure Blob Storage access key +{% endhighlight %} + +{% top %} diff --git a/docs/ops/filesystems/index.md b/docs/ops/filesystems/index.md index 0d4a1bebf4dd67..eb4087dd8abee4 100644 --- a/docs/ops/filesystems/index.md +++ b/docs/ops/filesystems/index.md @@ -25,7 +25,7 @@ under the License. --> Apache Flink uses file systems to consume and persistently store data, both for the results of applications and for fault tolerance and recovery. 
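A minimal sketch of how a path's URI scheme selects the file system at runtime (illustrative only, assuming the Azure credentials from the page above are already configured; after the `HadoopFileSystem` change further down in this series, `wasb`/`wasbs` paths report `FileSystemKind.OBJECT_STORE`):

{% highlight java %}
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.FileSystemKind;
import org.apache.flink.core.fs.Path;

// Sketch only: container and account are placeholders; credentials must already be configured.
Path path = new Path("wasbs://yourcontainer@youraccount.blob.core.windows.net/words");
FileSystem fs = path.getFileSystem();   // resolved via the factory registered for "wasbs"
FileSystemKind kind = fs.getKind();     // FileSystemKind.OBJECT_STORE for wasb/wasbs
{% endhighlight %}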
-These are some of most of the popular file systems, including *local*, *hadoop-compatible*, *S3*, *MapR FS*, *OpenStack Swift FS* and *Aliyun OSS*. +These are some of most of the popular file systems, including *local*, *hadoop-compatible*, *S3*, *MapR FS*, *OpenStack Swift FS*, *Aliyun OSS* and *Azure Blob Storage*. The file system used for a particular file is determined by its URI scheme. For example, `file:///home/user/text.txt` refers to a file in the local file system, while `hdfs://namenode:50010/data/user/text.txt` is a file in a specific HDFS cluster. @@ -43,12 +43,17 @@ Flink ships with implementations for the following file systems: - **S3**: Flink directly provides file systems to talk to Amazon S3 with two alternative implementations, `flink-s3-fs-presto` and `flink-s3-fs-hadoop`. Both implementations are self-contained with no dependency footprint. - - **MapR FS**: The MapR file system *"maprfs://"* is automatically available when the MapR libraries are in the classpath. + - **MapR FS**: The MapR file system *"maprfs://"* is automatically available when the MapR libraries are in the classpath. - **OpenStack Swift FS**: Flink directly provides a file system to talk to the OpenStack Swift file system, registered under the scheme *"swift://"*. The implementation of `flink-swift-fs-hadoop` is based on the [Hadoop Project](https://hadoop.apache.org/) but is self-contained with no dependency footprint. To use it when using Flink as a library, add the respective maven dependency (`org.apache.flink:flink-swift-fs-hadoop:{{ site.version }}` When starting a Flink application from the Flink binaries, copy or move the respective jar file from the `opt` folder to the `lib` folder. + + - **Azure Blob Storage**: + Flink directly provides a file system to work with Azure Blob Storage. + This filesystem is registered under the scheme *"wasb(s)://"*. + The implementation is self-contained with no dependency footprint. ## HDFS and Hadoop File System support diff --git a/docs/ops/filesystems/index.zh.md b/docs/ops/filesystems/index.zh.md index 0d4a1bebf4dd67..414c82f773e44e 100644 --- a/docs/ops/filesystems/index.zh.md +++ b/docs/ops/filesystems/index.zh.md @@ -25,7 +25,7 @@ under the License. --> Apache Flink uses file systems to consume and persistently store data, both for the results of applications and for fault tolerance and recovery. -These are some of most of the popular file systems, including *local*, *hadoop-compatible*, *S3*, *MapR FS*, *OpenStack Swift FS* and *Aliyun OSS*. +These are some of most of the popular file systems, including *local*, *hadoop-compatible*, *S3*, *MapR FS*, *OpenStack Swift FS*, *Aliyun OSS* and *Azure Blob Storage*. The file system used for a particular file is determined by its URI scheme. For example, `file:///home/user/text.txt` refers to a file in the local file system, while `hdfs://namenode:50010/data/user/text.txt` is a file in a specific HDFS cluster. @@ -49,6 +49,11 @@ Flink ships with implementations for the following file systems: The implementation of `flink-swift-fs-hadoop` is based on the [Hadoop Project](https://hadoop.apache.org/) but is self-contained with no dependency footprint. To use it when using Flink as a library, add the respective maven dependency (`org.apache.flink:flink-swift-fs-hadoop:{{ site.version }}` When starting a Flink application from the Flink binaries, copy or move the respective jar file from the `opt` folder to the `lib` folder. 
+ + - **Azure Blob Storage**: + Flink directly provides a file system to work with Azure Blob Storage. + This filesystem is registered under the scheme *"wasb(s)://"*. + The implementation is self-contained with no dependency footprint. ## HDFS and Hadoop File System support diff --git a/flink-dist/src/main/assemblies/opt.xml b/flink-dist/src/main/assemblies/opt.xml index 1ce6e8899e819f..e28acd8c962329 100644 --- a/flink-dist/src/main/assemblies/opt.xml +++ b/flink-dist/src/main/assemblies/opt.xml @@ -154,6 +154,13 @@ 0644 + + ../flink-filesystems/flink-azure-fs-hadoop/target/flink-azure-fs-hadoop-${project.version}.jar + opt/ + flink-azure-fs-hadoop-${project.version}.jar + 0644 + + ../flink-queryable-state/flink-queryable-state-runtime/target/flink-queryable-state-runtime_${scala.binary.version}-${project.version}.jar diff --git a/flink-end-to-end-tests/test-scripts/test_azure_fs.sh b/flink-end-to-end-tests/test-scripts/test_azure_fs.sh new file mode 100755 index 00000000000000..40c6962d4a4c08 --- /dev/null +++ b/flink-end-to-end-tests/test-scripts/test_azure_fs.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Tests for Azure file system. + +# To run single test, export IT_CASE_AZURE_ACCOUNT, IT_CASE_AZURE_ACCESS_KEY, IT_CASE_AZURE_CONTAINER to +# the appropriate values and run: +# flink-end-to-end-tests/run-single-test.sh skip flink-end-to-end-tests/test-scripts/test_azure_fs.sh + +source "$(dirname "$0")"/common.sh + +if [[ -z "$IT_CASE_AZURE_ACCOUNT" ]]; then + echo "Did not find Azure storage account environment variable, NOT running the e2e test." + exit 0 +else + echo "Found Azure storage account $IT_CASE_AZURE_ACCOUNT, running the e2e test." +fi + +if [[ -z "$IT_CASE_AZURE_ACCESS_KEY" ]]; then + echo "Did not find Azure storage access key environment variable, NOT running the e2e test." + exit 0 +else + echo "Found Azure storage access key $IT_CASE_AZURE_ACCESS_KEY, running the e2e test." +fi + +if [[ -z "$IT_CASE_AZURE_CONTAINER" ]]; then + echo "Did not find Azure storage container environment variable, NOT running the e2e test." + exit 0 +else + echo "Found Azure storage container $IT_CASE_AZURE_CONTAINER, running the e2e test." +fi + +AZURE_TEST_DATA_WORDS_URI="wasbs://$IT_CASE_AZURE_CONTAINER@$IT_CASE_AZURE_ACCOUNT.blob.core.windows.net/words" + +################################### +# Setup Flink Azure access. 
+# +# Globals: +# FLINK_DIR +# IT_CASE_AZURE_ACCOUNT +# IT_CASE_AZURE_ACCESS_KEY +# Returns: +# None +################################### +function azure_setup { + # make sure we delete the file at the end + function azure_cleanup { + rm $FLINK_DIR/lib/flink-azure-fs*.jar + + # remove any leftover settings + sed -i -e 's/fs.azure.account.key.*//' "$FLINK_DIR/conf/flink-conf.yaml" + } + trap azure_cleanup EXIT + + echo "Copying flink azure jars and writing out configs" + cp $FLINK_DIR/opt/flink-azure-fs-hadoop-*.jar $FLINK_DIR/lib/ + echo "fs.azure.account.key.$IT_CASE_AZURE_ACCOUNT.blob.core.windows.net: $IT_CASE_AZURE_ACCESS_KEY" >> "$FLINK_DIR/conf/flink-conf.yaml" +} + +azure_setup + +echo "Starting Flink cluster.." +start_cluster + +$FLINK_DIR/bin/flink run -p 1 $FLINK_DIR/examples/batch/WordCount.jar --input $AZURE_TEST_DATA_WORDS_URI --output $TEST_DATA_DIR/out/wc_out + +check_result_hash "WordCountWithAzureFS" $TEST_DATA_DIR/out/wc_out "72a690412be8928ba239c2da967328a5" diff --git a/flink-filesystems/flink-azure-fs-hadoop/pom.xml b/flink-filesystems/flink-azure-fs-hadoop/pom.xml new file mode 100644 index 00000000000000..37567ce633ecd3 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/pom.xml @@ -0,0 +1,155 @@ + + + + + 4.0.0 + + + org.apache.flink + flink-filesystems + 1.9-SNAPSHOT + .. + + + flink-azure-fs-hadoop + flink-azure-fs-hadoop + + jar + + + + 2.7.0 + 1.16.0 + 2.9.4 + + + + + + org.apache.flink + flink-core + ${project.version} + provided + + + + org.apache.flink + flink-hadoop-fs + ${project.version} + + + + org.apache.hadoop + hadoop-azure + ${fs.azure.version} + + + + + org.apache.hadoop + hadoop-hdfs + ${hadoop.version} + + + + + com.microsoft.azure + azure + ${fs.azure.sdk.version} + test + + + com.fasterxml.jackson.core + jackson-core + ${fs.jackson.core.version} + test + + + com.google.guava + guava + ${guava.version} + test + + + + + org.apache.flink + flink-core + ${project.version} + test + test-jar + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + package + + shade + + + false + + + *:* + + + + + org.apache.hadoop + org.apache.flink.fs.shaded.hadoop.org.apache.hadoop + + + + com.microsoft.azure.storage + org.apache.flink.fs.shaded.com.microsoft.azure.storage + + + + + * + + properties.dtd + PropertyList-1.0.dtd + mozilla/** + META-INF/maven/** + META-INF/LICENSE.txt + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + core-default.xml + hdfs-default.xml + + + + + + + + + + diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AbstractAzureFSFactory.java b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AbstractAzureFSFactory.java new file mode 100644 index 00000000000000..7ae9df8cc4507d --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AbstractAzureFSFactory.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.azurefs; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemFactory; +import org.apache.flink.runtime.fs.hdfs.HadoopFileSystem; +import org.apache.flink.runtime.util.HadoopConfigLoader; + +import org.apache.hadoop.fs.azure.NativeAzureFileSystem; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.util.Collections; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Abstract factory for AzureFS. Subclasses override to specify + * the correct scheme (wasb / wasbs). Based on Azure HDFS support in the + * hadoop-azure module. + */ +public abstract class AbstractAzureFSFactory implements FileSystemFactory { + private static final Logger LOG = LoggerFactory.getLogger(AzureFSFactory.class); + + private static final String[] FLINK_CONFIG_PREFIXES = { "fs.azure.", "azure." }; + private static final String HADOOP_CONFIG_PREFIX = "fs.azure."; + + private static final String[][] MIRRORED_CONFIG_KEYS = {}; + private static final Set PACKAGE_PREFIXES_TO_SHADE = Collections.emptySet(); + private static final Set CONFIG_KEYS_TO_SHADE = Collections.emptySet(); + private static final String FLINK_SHADING_PREFIX = ""; + + private final HadoopConfigLoader configLoader; + + private Configuration flinkConfig; + + public AbstractAzureFSFactory() { + this.configLoader = new HadoopConfigLoader(FLINK_CONFIG_PREFIXES, MIRRORED_CONFIG_KEYS, + HADOOP_CONFIG_PREFIX, PACKAGE_PREFIXES_TO_SHADE, CONFIG_KEYS_TO_SHADE, FLINK_SHADING_PREFIX); + } + + @Override + public void configure(Configuration config) { + flinkConfig = config; + configLoader.setFlinkConfig(config); + } + + @Override + public FileSystem create(URI fsUri) throws IOException { + checkNotNull(fsUri, "passed file system URI object should not be null"); + LOG.info("Trying to load and instantiate Azure File System"); + return new HadoopFileSystem(createInitializedAzureFS(fsUri, flinkConfig)); + } + + // uri is of the form: wasb(s)://yourcontainer@youraccount.blob.core.windows.net/testDir + private org.apache.hadoop.fs.FileSystem createInitializedAzureFS(URI fsUri, Configuration flinkConfig) throws IOException { + org.apache.hadoop.conf.Configuration hadoopConfig = configLoader.getOrLoadHadoopConfig(); + + org.apache.hadoop.fs.FileSystem azureFS = new NativeAzureFileSystem(); + azureFS.initialize(fsUri, hadoopConfig); + + return azureFS; + } +} diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AzureFSFactory.java b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AzureFSFactory.java new file mode 100644 index 00000000000000..5f6246d7f93501 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/AzureFSFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.azurefs; + +/** + * A factory for the Azure file system over HTTP. + */ +public class AzureFSFactory extends AbstractAzureFSFactory { + + @Override + public String getScheme() { + return "wasb"; + } +} diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/SecureAzureFSFactory.java b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/SecureAzureFSFactory.java new file mode 100644 index 00000000000000..7130a879c44f96 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/main/java/org/apache/flink/fs/azurefs/SecureAzureFSFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.azurefs; + +/** + * A factory for the Azure file system over HTTPs. + */ +public class SecureAzureFSFactory extends AbstractAzureFSFactory { + + @Override + public String getScheme() { + return "wasbs"; + } +} diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory b/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory new file mode 100644 index 00000000000000..4d6a19aa54e6cc --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.fs.azurefs.AzureFSFactory +org.apache.flink.fs.azurefs.SecureAzureFSFactory diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFSFactoryTest.java b/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFSFactoryTest.java new file mode 100644 index 00000000000000..01b79b5884f4fe --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFSFactoryTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.azurefs; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.TestLogger; + +import org.apache.hadoop.fs.azure.AzureException; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +/** + * Tests for the AzureFSFactory. + */ +@RunWith(Parameterized.class) +public class AzureFSFactoryTest extends TestLogger { + + @Parameterized.Parameter + public String scheme; + + @Parameterized.Parameters(name = "Scheme = {0}") + public static List parameters() { + return Arrays.asList("wasb", "wasbs"); + } + + @Rule + public final ExpectedException exception = ExpectedException.none(); + + private AbstractAzureFSFactory getFactory(String scheme) { + return scheme.equals("wasb") ? 
new AzureFSFactory() : new SecureAzureFSFactory(); + } + + @Test + public void testNullFsURI() throws Exception { + URI uri = null; + AbstractAzureFSFactory factory = getFactory(scheme); + + exception.expect(NullPointerException.class); + exception.expectMessage("passed file system URI object should not be null"); + + factory.create(uri); + } + + // missing credentials + @Test + public void testCreateFsWithAuthorityMissingCreds() throws Exception { + String uriString = String.format("%s://yourcontainer@youraccount.blob.core.windows.net/testDir", scheme); + final URI uri = URI.create(uriString); + + exception.expect(AzureException.class); + + AbstractAzureFSFactory factory = getFactory(scheme); + Configuration config = new Configuration(); + config.setInteger("fs.azure.io.retry.max.retries", 0); + factory.configure(config); + factory.create(uri); + } + + @Test + public void testCreateFsWithMissingAuthority() throws Exception { + String uriString = String.format("%s:///my/path", scheme); + final URI uri = URI.create(uriString); + + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Cannot initialize WASB file system, URI authority not recognized."); + + AbstractAzureFSFactory factory = getFactory(scheme); + factory.configure(new Configuration()); + factory.create(uri); + } +} diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFileSystemBehaviorITCase.java b/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFileSystemBehaviorITCase.java new file mode 100644 index 00000000000000..6c65be90cbd637 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/test/java/org/apache/flink/fs/azurefs/AzureFileSystemBehaviorITCase.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.fs.azurefs; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemBehaviorTestSuite; +import org.apache.flink.core.fs.FileSystemKind; +import org.apache.flink.core.fs.Path; +import org.apache.flink.util.StringUtils; + +import com.microsoft.azure.credentials.ApplicationTokenCredentials; +import com.microsoft.azure.credentials.AzureTokenCredentials; +import com.microsoft.azure.management.Azure; +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; + +import static org.apache.flink.core.fs.FileSystemTestUtils.checkPathEventualExistence; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * An implementation of the {@link FileSystemBehaviorTestSuite} for Azure based + * file system. + */ +@RunWith(Parameterized.class) +public class AzureFileSystemBehaviorITCase extends FileSystemBehaviorTestSuite { + + @Parameterized.Parameter + public String scheme; + + private static final String CONTAINER = System.getenv("ARTIFACTS_AZURE_CONTAINER"); + private static final String ACCOUNT = System.getenv("ARTIFACTS_AZURE_STORAGE_ACCOUNT"); + private static final String ACCESS_KEY = System.getenv("ARTIFACTS_AZURE_ACCESS_KEY"); + private static final String RESOURCE_GROUP = System.getenv("ARTIFACTS_AZURE_RESOURCE_GROUP"); + private static final String SUBSCRIPTION_ID = System.getenv("ARTIFACTS_AZURE_SUBSCRIPTION_ID"); + private static final String TOKEN_CREDENTIALS_FILE = System.getenv("ARTIFACTS_AZURE_TOKEN_CREDENTIALS_FILE"); + + private static final String TEST_DATA_DIR = "tests-" + UUID.randomUUID(); + + // Azure Blob Storage defaults to https only storage accounts. We check if http support has been + // enabled on a best effort basis and test http if so. + @Parameterized.Parameters(name = "Scheme = {0}") + public static List parameters() throws IOException { + boolean httpsOnly = isHttpsTrafficOnly(); + return httpsOnly ? Arrays.asList("wasbs") : Arrays.asList("wasb", "wasbs"); + } + + private static boolean isHttpsTrafficOnly() throws IOException { + if (StringUtils.isNullOrWhitespaceOnly(RESOURCE_GROUP) || StringUtils.isNullOrWhitespaceOnly(TOKEN_CREDENTIALS_FILE)) { + // default to https only, as some fields are missing + return true; + } + + Assume.assumeTrue("Azure storage account not configured, skipping test...", !StringUtils.isNullOrWhitespaceOnly(ACCOUNT)); + + AzureTokenCredentials credentials = ApplicationTokenCredentials.fromFile(new File(TOKEN_CREDENTIALS_FILE)); + Azure azure = + StringUtils.isNullOrWhitespaceOnly(SUBSCRIPTION_ID) ? 
+ Azure.authenticate(credentials).withDefaultSubscription() : + Azure.authenticate(credentials).withSubscription(SUBSCRIPTION_ID); + + return azure.storageAccounts().getByResourceGroup(RESOURCE_GROUP, ACCOUNT).inner().enableHttpsTrafficOnly(); + } + + @BeforeClass + public static void checkCredentialsAndSetup() throws IOException { + // check whether credentials and container details exist + Assume.assumeTrue("Azure container not configured, skipping test...", !StringUtils.isNullOrWhitespaceOnly(CONTAINER)); + Assume.assumeTrue("Azure access key not configured, skipping test...", !StringUtils.isNullOrWhitespaceOnly(ACCESS_KEY)); + + // initialize configuration with valid credentials + final Configuration conf = new Configuration(); + // fs.azure.account.key.youraccount.blob.core.windows.net = ACCESS_KEY + conf.setString("fs.azure.account.key." + ACCOUNT + ".blob.core.windows.net", ACCESS_KEY); + FileSystem.initialize(conf); + } + + @AfterClass + public static void clearFsConfig() throws IOException { + FileSystem.initialize(new Configuration()); + } + + @Override + public FileSystem getFileSystem() throws Exception { + return getBasePath().getFileSystem(); + } + + @Override + public Path getBasePath() { + // wasb(s)://yourcontainer@youraccount.blob.core.windows.net/testDataDir + String uriString = scheme + "://" + CONTAINER + '@' + ACCOUNT + ".blob.core.windows.net/" + TEST_DATA_DIR; + return new Path(uriString); + } + + @Test + public void testSimpleFileWriteAndRead() throws Exception { + final long deadline = System.nanoTime() + 30_000_000_000L; // 30 secs + + final String testLine = "Hello Upload!"; + + final Path path = new Path(getBasePath() + "/test.txt"); + final FileSystem fs = path.getFileSystem(); + + try { + try (FSDataOutputStream out = fs.create(path, FileSystem.WriteMode.OVERWRITE); + OutputStreamWriter writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) { + writer.write(testLine); + } + + // just in case, wait for the path to exist + checkPathEventualExistence(fs, path, true, deadline); + + try (FSDataInputStream in = fs.open(path); + InputStreamReader ir = new InputStreamReader(in, StandardCharsets.UTF_8); + BufferedReader reader = new BufferedReader(ir)) { + String line = reader.readLine(); + assertEquals(testLine, line); + } + } + finally { + fs.delete(path, false); + } + + // now file must be gone + checkPathEventualExistence(fs, path, false, deadline); + } + + @Test + public void testDirectoryListing() throws Exception { + final long deadline = System.nanoTime() + 30_000_000_000L; // 30 secs + + final Path directory = new Path(getBasePath() + "/testdir/"); + final FileSystem fs = directory.getFileSystem(); + + // directory must not yet exist + assertFalse(fs.exists(directory)); + + try { + // create directory + assertTrue(fs.mkdirs(directory)); + + checkPathEventualExistence(fs, directory, true, deadline); + + // directory empty + assertEquals(0, fs.listStatus(directory).length); + + // create some files + final int numFiles = 3; + for (int i = 0; i < numFiles; i++) { + Path file = new Path(directory, "/file-" + i); + try (FSDataOutputStream out = fs.create(file, FileSystem.WriteMode.OVERWRITE); + OutputStreamWriter writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) { + writer.write("hello-" + i + "\n"); + } + // just in case, wait for the file to exist (should then also be reflected in the + // directory's file list below) + checkPathEventualExistence(fs, file, true, deadline); + } + + FileStatus[] files = fs.listStatus(directory); + 
assertNotNull(files); + assertEquals(3, files.length); + + for (FileStatus status : files) { + assertFalse(status.isDir()); + } + + // now that there are files, the directory must exist + assertTrue(fs.exists(directory)); + } + finally { + // clean up + fs.delete(directory, true); + } + + // now directory must be gone + checkPathEventualExistence(fs, directory, false, deadline); + } + + @Override + public FileSystemKind getFileSystemKind() { + return FileSystemKind.OBJECT_STORE; + } +} diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/test/resources/log4j-test.properties b/flink-filesystems/flink-azure-fs-hadoop/src/test/resources/log4j-test.properties new file mode 100644 index 00000000000000..2be35890d31500 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/test/resources/log4j-test.properties @@ -0,0 +1,27 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +log4j.rootLogger=OFF, testlogger + +# testlogger is set to be a ConsoleAppender. 
+log4j.appender.testlogger=org.apache.log4j.ConsoleAppender +log4j.appender.testlogger.target = System.err +log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout +log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopFileSystem.java b/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopFileSystem.java index 6a5976aa9ad427..1135e011bf4fd7 100644 --- a/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopFileSystem.java +++ b/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopFileSystem.java @@ -224,8 +224,8 @@ public static org.apache.hadoop.fs.Path toHadoopPath(Path path) { static FileSystemKind getKindForScheme(String scheme) { scheme = scheme.toLowerCase(Locale.US); - if (scheme.startsWith("s3") || scheme.startsWith("emr") || scheme.startsWith("oss")) { - // the Amazon S3 storage or Aliyun OSS storage + if (scheme.startsWith("s3") || scheme.startsWith("emr") || scheme.startsWith("oss") || scheme.startsWith("wasb")) { + // the Amazon S3 storage or Aliyun OSS storage or Azure Blob Storage return FileSystemKind.OBJECT_STORE; } else if (scheme.startsWith("http") || scheme.startsWith("ftp")) { diff --git a/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/HadoopConfigLoader.java b/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/util/HadoopConfigLoader.java similarity index 99% rename from flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/HadoopConfigLoader.java rename to flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/util/HadoopConfigLoader.java index 1bbb7574277f44..aa8fdfe64457a6 100644 --- a/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/HadoopConfigLoader.java +++ b/flink-filesystems/flink-hadoop-fs/src/main/java/org/apache/flink/runtime/util/HadoopConfigLoader.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.fs.s3.common; +package org.apache.flink.runtime.util; import org.apache.flink.configuration.Configuration; diff --git a/flink-filesystems/flink-s3-fs-base/pom.xml b/flink-filesystems/flink-s3-fs-base/pom.xml index 0b640a4197e3a5..00d408672932a4 100644 --- a/flink-filesystems/flink-s3-fs-base/pom.xml +++ b/flink-filesystems/flink-s3-fs-base/pom.xml @@ -166,7 +166,7 @@ under the License. 
org.apache.flink:flink-hadoop-fs - org/apache/flink/runtime/util/** + org/apache/flink/runtime/util/HadoopUtils org/apache/flink/runtime/fs/hdfs/HadoopRecoverable* diff --git a/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/AbstractS3FileSystemFactory.java b/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/AbstractS3FileSystemFactory.java index ff575be6f55c81..a576a96ae9c88f 100644 --- a/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/AbstractS3FileSystemFactory.java +++ b/flink-filesystems/flink-s3-fs-base/src/main/java/org/apache/flink/fs/s3/common/AbstractS3FileSystemFactory.java @@ -26,6 +26,7 @@ import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.FileSystemFactory; import org.apache.flink.fs.s3.common.writer.S3AccessHelper; +import org.apache.flink.runtime.util.HadoopConfigLoader; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; diff --git a/flink-filesystems/flink-s3-fs-base/src/test/java/org/apache/flink/fs/s3/common/S3EntropyFsFactoryTest.java b/flink-filesystems/flink-s3-fs-base/src/test/java/org/apache/flink/fs/s3/common/S3EntropyFsFactoryTest.java index 943de1d8897274..ebf3b672b5194e 100644 --- a/flink-filesystems/flink-s3-fs-base/src/test/java/org/apache/flink/fs/s3/common/S3EntropyFsFactoryTest.java +++ b/flink-filesystems/flink-s3-fs-base/src/test/java/org/apache/flink/fs/s3/common/S3EntropyFsFactoryTest.java @@ -20,6 +20,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.fs.s3.common.writer.S3AccessHelper; +import org.apache.flink.runtime.util.HadoopConfigLoader; import org.apache.flink.util.TestLogger; import org.apache.hadoop.fs.FileSystem; diff --git a/flink-filesystems/flink-s3-fs-hadoop/pom.xml b/flink-filesystems/flink-s3-fs-hadoop/pom.xml index 9a5a80c053d957..e7cf95e232b862 100644 --- a/flink-filesystems/flink-s3-fs-hadoop/pom.xml +++ b/flink-filesystems/flink-s3-fs-hadoop/pom.xml @@ -106,6 +106,11 @@ under the License. 
com.amazon org.apache.flink.fs.s3base.shaded.com.amazon + + + org.apache.flink.runtime.util + org.apache.flink.fs.s3hadoop.common + diff --git a/flink-filesystems/flink-s3-fs-hadoop/src/main/java/org/apache/flink/fs/s3hadoop/S3FileSystemFactory.java b/flink-filesystems/flink-s3-fs-hadoop/src/main/java/org/apache/flink/fs/s3hadoop/S3FileSystemFactory.java index 2637e7b2e23cb8..6cad0511794d35 100644 --- a/flink-filesystems/flink-s3-fs-hadoop/src/main/java/org/apache/flink/fs/s3hadoop/S3FileSystemFactory.java +++ b/flink-filesystems/flink-s3-fs-hadoop/src/main/java/org/apache/flink/fs/s3hadoop/S3FileSystemFactory.java @@ -20,8 +20,8 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.fs.s3.common.AbstractS3FileSystemFactory; -import org.apache.flink.fs.s3.common.HadoopConfigLoader; import org.apache.flink.fs.s3.common.writer.S3AccessHelper; +import org.apache.flink.runtime.util.HadoopConfigLoader; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.s3a.S3AFileSystem; diff --git a/flink-filesystems/flink-s3-fs-hadoop/src/test/java/org/apache/flink/fs/s3hadoop/HadoopS3FileSystemTest.java b/flink-filesystems/flink-s3-fs-hadoop/src/test/java/org/apache/flink/fs/s3hadoop/HadoopS3FileSystemTest.java index 4471b3868dfdd7..57500f37ca70d2 100644 --- a/flink-filesystems/flink-s3-fs-hadoop/src/test/java/org/apache/flink/fs/s3hadoop/HadoopS3FileSystemTest.java +++ b/flink-filesystems/flink-s3-fs-hadoop/src/test/java/org/apache/flink/fs/s3hadoop/HadoopS3FileSystemTest.java @@ -19,7 +19,7 @@ package org.apache.flink.fs.s3hadoop; import org.apache.flink.configuration.Configuration; -import org.apache.flink.fs.s3.common.HadoopConfigLoader; +import org.apache.flink.runtime.util.HadoopConfigLoader; import org.junit.Test; diff --git a/flink-filesystems/flink-s3-fs-presto/pom.xml b/flink-filesystems/flink-s3-fs-presto/pom.xml index 8f88bbfdac03d7..3fc8e03d124be2 100644 --- a/flink-filesystems/flink-s3-fs-presto/pom.xml +++ b/flink-filesystems/flink-s3-fs-presto/pom.xml @@ -290,6 +290,12 @@ under the License. 
com.google org.apache.flink.fs.s3presto.shaded.com.google + + + + org.apache.flink.runtime.util + org.apache.flink.fs.s3presto.common + diff --git a/flink-filesystems/flink-s3-fs-presto/src/main/java/org/apache/flink/fs/s3presto/S3FileSystemFactory.java b/flink-filesystems/flink-s3-fs-presto/src/main/java/org/apache/flink/fs/s3presto/S3FileSystemFactory.java index c0c1beb6afe19a..5a1ffeef6127d6 100644 --- a/flink-filesystems/flink-s3-fs-presto/src/main/java/org/apache/flink/fs/s3presto/S3FileSystemFactory.java +++ b/flink-filesystems/flink-s3-fs-presto/src/main/java/org/apache/flink/fs/s3presto/S3FileSystemFactory.java @@ -20,8 +20,8 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.fs.s3.common.AbstractS3FileSystemFactory; -import org.apache.flink.fs.s3.common.HadoopConfigLoader; import org.apache.flink.fs.s3.common.writer.S3AccessHelper; +import org.apache.flink.runtime.util.HadoopConfigLoader; import org.apache.flink.util.FlinkRuntimeException; import com.facebook.presto.hive.s3.PrestoS3FileSystem; diff --git a/flink-filesystems/flink-s3-fs-presto/src/test/java/org/apache/flink/fs/s3presto/PrestoS3FileSystemTest.java b/flink-filesystems/flink-s3-fs-presto/src/test/java/org/apache/flink/fs/s3presto/PrestoS3FileSystemTest.java index 093efc8efabcdd..f3117a277d6bad 100644 --- a/flink-filesystems/flink-s3-fs-presto/src/test/java/org/apache/flink/fs/s3presto/PrestoS3FileSystemTest.java +++ b/flink-filesystems/flink-s3-fs-presto/src/test/java/org/apache/flink/fs/s3presto/PrestoS3FileSystemTest.java @@ -21,7 +21,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.fs.s3.common.FlinkS3FileSystem; -import org.apache.flink.fs.s3.common.HadoopConfigLoader; +import org.apache.flink.runtime.util.HadoopConfigLoader; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.AWSStaticCredentialsProvider; diff --git a/flink-filesystems/pom.xml b/flink-filesystems/pom.xml index 8da2ceac34bf30..c84e8535b839b2 100644 --- a/flink-filesystems/pom.xml +++ b/flink-filesystems/pom.xml @@ -47,6 +47,7 @@ under the License. flink-s3-fs-presto flink-swift-fs-hadoop flink-oss-fs-hadoop + flink-azure-fs-hadoop From 09dae1ff4380b8584c9fc52b8aac32edc96eaa2e Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Fri, 24 May 2019 15:32:12 +0200 Subject: [PATCH 19/92] [FLINK-12115][fs] Decrease dependency footprint of flink-azure-fs-hadoop --- .../flink-azure-fs-hadoop/pom.xml | 92 ++++++++++++++----- 1 file changed, 67 insertions(+), 25 deletions(-) diff --git a/flink-filesystems/flink-azure-fs-hadoop/pom.xml b/flink-filesystems/flink-azure-fs-hadoop/pom.xml index 37567ce633ecd3..781e85106947a8 100644 --- a/flink-filesystems/flink-azure-fs-hadoop/pom.xml +++ b/flink-filesystems/flink-azure-fs-hadoop/pom.xml @@ -34,7 +34,6 @@ under the License. - 2.7.0 1.16.0 2.9.4 @@ -55,16 +54,21 @@ under the License. - org.apache.hadoop - hadoop-azure - ${fs.azure.version} + org.apache.flink + flink-fs-hadoop-shaded + ${project.version} - org.apache.hadoop - hadoop-hdfs - ${hadoop.version} + hadoop-azure + ${fs.hadoopshaded.version} + + + org.apache.hadoop + hadoop-common + + @@ -74,18 +78,6 @@ under the License. ${fs.azure.sdk.version} test - - com.fasterxml.jackson.core - jackson-core - ${fs.jackson.core.version} - test - - - com.google.guava - guava - ${guava.version} - test - @@ -119,17 +111,69 @@ under the License. 
+ org.apache.hadoop - org.apache.flink.fs.shaded.hadoop.org.apache.hadoop + org.apache.flink.fs.shaded.hadoop3.org.apache.hadoop - + + + + org.apache.commons + org.apache.flink.fs.shaded.hadoop3.org.apache.commons + + + - com.microsoft.azure.storage - org.apache.flink.fs.shaded.com.microsoft.azure.storage + com.microsoft.azure + org.apache.flink.fs.azure.shaded.com.microsoft.azure + + + + + org.apache.httpcomponents + org.apache.flink.fs.azure.shaded.org.apache.httpcomponents + + + commons-logging + org.apache.flink.fs.azure.shaded.commons-logging + + + commons-codec + org.apache.flink.fs.azure.shaded.commons-codec + + + com.fasterxml + org.apache.flink.fs.azure.shaded.com.fasterxml + + + com.google + org.apache.flink.fs.azure.shaded.com.google + + + org.eclipse + org.apache.flink.fs.azure.shaded.org.eclipse + + + + + org.apache.flink.runtime.fs.hdfs + org.apache.flink.fs.azure.common.hadoop + + + + org.apache.flink.runtime.util + org.apache.flink.fs.azure.common + + org.apache.flink:flink-hadoop-fs + + org/apache/flink/runtime/util/HadoopUtils + org/apache/flink/runtime/fs/hdfs/HadoopRecoverable* + + * @@ -141,8 +185,6 @@ under the License. META-INF/*.SF META-INF/*.DSA META-INF/*.RSA - core-default.xml - hdfs-default.xml From ec2c1a2944274c47bd000c569a903221d9b951f0 Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Fri, 24 May 2019 15:32:30 +0200 Subject: [PATCH 20/92] [FLINK-12115][fs] Add NOTICE file for flink-azure-fs-hadoop This closes #8537. This closes #8117. --- NOTICE-binary | 22 +++++++++++++++++++ .../src/main/resources/META-INF/NOTICE | 21 ++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/NOTICE diff --git a/NOTICE-binary b/NOTICE-binary index 2c20a5e674a7cb..a49077725d3044 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -305,6 +305,28 @@ See bundled license files for details. - font-awesome:4.5.0 (Font) +flink-azure-fs-hadoop +Copyright 2014-2019 The Apache Software Foundation + +This project includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +This project bundles the following dependencies under the Apache Software License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt) + +- com.fasterxml.jackson.core:jackson-annotations:2.7.0 +- com.fasterxml.jackson.core:jackson-core:2.7.8 +- com.fasterxml.jackson.core:jackson-databind:2.7.8 +- com.google.guava:guava:11.0.2 +- com.microsoft.azure:azure-keyvault-core:0.8.0 +- com.microsoft.azure:azure-storage:5.4.0 +- commons-codec:commons-codec:1.10 +- commons-logging:commons-logging:1.1.3 +- org.apache.hadoop:hadoop-azure:3.1.0 +- org.apache.httpcomponents:httpclient:4.5.3 +- org.apache.httpcomponents:httpcore:4.4.6 +- org.eclipse.jetty:jetty-util:9.3.19.v20170502 +- org.eclipse.jetty:jetty-util-ajax:9.3.19.v20170502 + flink-swift-fs-hadoop Copyright 2014-2019 The Apache Software Foundation diff --git a/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/NOTICE b/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/NOTICE new file mode 100644 index 00000000000000..92356263091899 --- /dev/null +++ b/flink-filesystems/flink-azure-fs-hadoop/src/main/resources/META-INF/NOTICE @@ -0,0 +1,21 @@ +flink-azure-fs-hadoop +Copyright 2014-2019 The Apache Software Foundation + +This project includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
+ +This project bundles the following dependencies under the Apache Software License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt) + +- com.fasterxml.jackson.core:jackson-annotations:2.7.0 +- com.fasterxml.jackson.core:jackson-core:2.7.8 +- com.fasterxml.jackson.core:jackson-databind:2.7.8 +- com.google.guava:guava:11.0.2 +- com.microsoft.azure:azure-keyvault-core:0.8.0 +- com.microsoft.azure:azure-storage:5.4.0 +- commons-codec:commons-codec:1.10 +- commons-logging:commons-logging:1.1.3 +- org.apache.hadoop:hadoop-azure:3.1.0 +- org.apache.httpcomponents:httpclient:4.5.3 +- org.apache.httpcomponents:httpcore:4.4.6 +- org.eclipse.jetty:jetty-util:9.3.19.v20170502 +- org.eclipse.jetty:jetty-util-ajax:9.3.19.v20170502 From 4fe936dc8f65d57645eea908337653c76d83f400 Mon Sep 17 00:00:00 2001 From: Chesnay Schepler Date: Mon, 27 May 2019 14:59:18 +0200 Subject: [PATCH 21/92] [FLINK-12636][rest] Fail stability test on compatible modifications --- .../compatibility/RestAPIStabilityTest.java | 12 +++++++ .../src/test/resources/rest_api_v1.snapshot | 34 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java index a388b493fc4612..be47449e80252c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java @@ -131,6 +131,18 @@ private static void assertCompatible(final RestAPISnapshot old, final RestAPISna "To update the snapshot, re-run this test with -D" + REGENERATE_SNAPSHOT_PROPERTY + " being set."); } } + + // check for entirely new calls, for which the snapshot should be updated + for (final JsonNode curCall : cur.calls) { + final List> compatibilityCheckResults = old.calls.stream() + .map(oldCall -> Tuple2.of(curCall, checkCompatibility(oldCall, curCall))) + .collect(Collectors.toList()); + + if (compatibilityCheckResults.stream().noneMatch(result -> result.f1.getBackwardCompatibility() == Compatibility.IDENTICAL)) { + Assert.fail("The API was modified in a compatible way, but the snapshot was not updated. 
" + + "To update the snapshot, re-run this test with -D" + REGENERATE_SNAPSHOT_PROPERTY + " being set."); + } + } } private static void fail(final JsonNode oldCall, final List> compatibilityCheckResults) { diff --git a/flink-runtime/src/test/resources/rest_api_v1.snapshot b/flink-runtime/src/test/resources/rest_api_v1.snapshot index 7a110044267dce..a4a7b47f000fce 100644 --- a/flink-runtime/src/test/resources/rest_api_v1.snapshot +++ b/flink-runtime/src/test/resources/rest_api_v1.snapshot @@ -1268,6 +1268,40 @@ } } } + }, { + "url" : "/jobs/:jobid/stop-with-savepoint", + "method" : "POST", + "status-code" : "202 Accepted", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ { + "key" : "jobid" + } ] + }, + "query-parameters" : { + "queryParameters" : [ ] + }, + "request" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:rest:messages:job:savepoints:stop:StopWithSavepointRequestBody", + "properties" : { + "targetDirectory" : { + "type" : "string" + }, + "endOfEventTime" : { + "type" : "boolean" + } + } + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:rest:handler:async:TriggerResponse", + "properties" : { + "request-id" : { + "type" : "any" + } + } + } }, { "url" : "/jobs/:jobid/vertices/:vertexid", "method" : "GET", From 3558bac8b2cd9609642414a0bf96d622653d144f Mon Sep 17 00:00:00 2001 From: Chesnay Schepler Date: Mon, 27 May 2019 14:49:15 +0200 Subject: [PATCH 22/92] [FLINK-12635][rest] Move stability test to runtime-web --- flink-runtime-web/pom.xml | 6 + .../rest/compatibility/Compatibility.java | 0 .../CompatibilityCheckResult.java | 0 .../compatibility/CompatibilityRoutine.java | 0 .../compatibility/CompatibilityRoutines.java | 0 .../compatibility/RestAPIStabilityTest.java | 0 .../src/test/resources/rest_api_v1.snapshot | 288 ++++++++++++++++++ flink-runtime/pom.xml | 6 - 8 files changed, 294 insertions(+), 6 deletions(-) rename {flink-runtime => flink-runtime-web}/src/test/java/org/apache/flink/runtime/rest/compatibility/Compatibility.java (100%) rename {flink-runtime => flink-runtime-web}/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityCheckResult.java (100%) rename {flink-runtime => flink-runtime-web}/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutine.java (100%) rename {flink-runtime => flink-runtime-web}/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutines.java (100%) rename {flink-runtime => flink-runtime-web}/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java (100%) rename {flink-runtime => flink-runtime-web}/src/test/resources/rest_api_v1.snapshot (90%) diff --git a/flink-runtime-web/pom.xml b/flink-runtime-web/pom.xml index d6db2ff8a35026..5207ab8bdb9c23 100644 --- a/flink-runtime-web/pom.xml +++ b/flink-runtime-web/pom.xml @@ -111,6 +111,12 @@ under the License. 
test + + org.apache.flink + flink-shaded-jackson-module-jsonSchema + test + + diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/Compatibility.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/Compatibility.java similarity index 100% rename from flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/Compatibility.java rename to flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/Compatibility.java diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityCheckResult.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityCheckResult.java similarity index 100% rename from flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityCheckResult.java rename to flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityCheckResult.java diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutine.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutine.java similarity index 100% rename from flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutine.java rename to flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutine.java diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutines.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutines.java similarity index 100% rename from flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutines.java rename to flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/CompatibilityRoutines.java diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java b/flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java similarity index 100% rename from flink-runtime/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java rename to flink-runtime-web/src/test/java/org/apache/flink/runtime/rest/compatibility/RestAPIStabilityTest.java diff --git a/flink-runtime/src/test/resources/rest_api_v1.snapshot b/flink-runtime-web/src/test/resources/rest_api_v1.snapshot similarity index 90% rename from flink-runtime/src/test/resources/rest_api_v1.snapshot rename to flink-runtime-web/src/test/resources/rest_api_v1.snapshot index a4a7b47f000fce..20ef479e2aca65 100644 --- a/flink-runtime/src/test/resources/rest_api_v1.snapshot +++ b/flink-runtime-web/src/test/resources/rest_api_v1.snapshot @@ -51,6 +51,294 @@ } } } + }, { + "url" : "/jars", + "method" : "GET", + "status-code" : "200 OK", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ ] + }, + "query-parameters" : { + "queryParameters" : [ ] + }, + "request" : { + "type" : "any" + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarListInfo", + "properties" : { + "address" : { + "type" : "string" + }, + "files" : { + "type" : "array", + "items" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarListInfo:JarFileInfo", + "properties" : { + "id" : { + "type" : "string" + }, + "name" : { + "type" : "string" + }, + "uploaded" : { + "type" 
: "integer" + }, + "entry" : { + "type" : "array", + "items" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarListInfo:JarEntryInfo", + "properties" : { + "name" : { + "type" : "string" + }, + "description" : { + "type" : "string" + } + } + } + } + } + } + } + } + } + }, { + "url" : "/jars/upload", + "method" : "POST", + "status-code" : "200 OK", + "file-upload" : true, + "path-parameters" : { + "pathParameters" : [ ] + }, + "query-parameters" : { + "queryParameters" : [ ] + }, + "request" : { + "type" : "any" + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarUploadResponseBody", + "properties" : { + "filename" : { + "type" : "string" + }, + "status" : { + "type" : "string", + "enum" : [ "success" ] + } + } + } + }, { + "url" : "/jars/:jarid", + "method" : "DELETE", + "status-code" : "200 OK", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ { + "key" : "jarid" + } ] + }, + "query-parameters" : { + "queryParameters" : [ ] + }, + "request" : { + "type" : "any" + }, + "response" : { + "type" : "any" + } + }, { + "url" : "/jars/:jarid/plan", + "method" : "GET", + "status-code" : "200 OK", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ { + "key" : "jarid" + } ] + }, + "query-parameters" : { + "queryParameters" : [ { + "key" : "program-args", + "mandatory" : false + }, { + "key" : "programArg", + "mandatory" : false + }, { + "key" : "entry-class", + "mandatory" : false + }, { + "key" : "parallelism", + "mandatory" : false + } ] + }, + "request" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarPlanRequestBody", + "properties" : { + "entryClass" : { + "type" : "string" + }, + "programArgs" : { + "type" : "string" + }, + "programArgsList" : { + "type" : "array", + "items" : { + "type" : "string" + } + }, + "parallelism" : { + "type" : "integer" + }, + "jobId" : { + "type" : "any" + } + } + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:rest:messages:JobPlanInfo", + "properties" : { + "plan" : { + "type" : "any" + } + } + } + }, { + "url" : "/jars/:jarid/plan", + "method" : "GET", + "status-code" : "200 OK", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ { + "key" : "jarid" + } ] + }, + "query-parameters" : { + "queryParameters" : [ { + "key" : "program-args", + "mandatory" : false + }, { + "key" : "programArg", + "mandatory" : false + }, { + "key" : "entry-class", + "mandatory" : false + }, { + "key" : "parallelism", + "mandatory" : false + } ] + }, + "request" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarPlanRequestBody", + "properties" : { + "entryClass" : { + "type" : "string" + }, + "programArgs" : { + "type" : "string" + }, + "programArgsList" : { + "type" : "array", + "items" : { + "type" : "string" + } + }, + "parallelism" : { + "type" : "integer" + }, + "jobId" : { + "type" : "any" + } + } + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:rest:messages:JobPlanInfo", + "properties" : { + "plan" : { + "type" : "any" + } + } + } + }, { + "url" : "/jars/:jarid/run", + "method" : "POST", + "status-code" : "200 OK", + "file-upload" : false, + "path-parameters" : { + "pathParameters" : [ { + "key" : "jarid" + } ] + }, + "query-parameters" : { + "queryParameters" : [ { + "key" : "allowNonRestoredState", + "mandatory" : false + 
}, { + "key" : "savepointPath", + "mandatory" : false + }, { + "key" : "program-args", + "mandatory" : false + }, { + "key" : "programArg", + "mandatory" : false + }, { + "key" : "entry-class", + "mandatory" : false + }, { + "key" : "parallelism", + "mandatory" : false + } ] + }, + "request" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarRunRequestBody", + "properties" : { + "entryClass" : { + "type" : "string" + }, + "programArgs" : { + "type" : "string" + }, + "programArgsList" : { + "type" : "array", + "items" : { + "type" : "string" + } + }, + "parallelism" : { + "type" : "integer" + }, + "jobId" : { + "type" : "any" + }, + "allowNonRestoredState" : { + "type" : "boolean" + }, + "savepointPath" : { + "type" : "string" + } + } + }, + "response" : { + "type" : "object", + "id" : "urn:jsonschema:org:apache:flink:runtime:webmonitor:handlers:JarRunResponseBody", + "properties" : { + "jobid" : { + "type" : "any" + } + } + } }, { "url" : "/jobmanager/config", "method" : "GET", diff --git a/flink-runtime/pom.xml b/flink-runtime/pom.xml index 6e7b5a3af63cbb..a75c20f59da0df 100644 --- a/flink-runtime/pom.xml +++ b/flink-runtime/pom.xml @@ -297,12 +297,6 @@ under the License. test - - org.apache.flink - flink-shaded-jackson-module-jsonSchema - test - - com.typesafe.akka akka-testkit_${scala.binary.version} From 97510aaa8444895a0cc4df7461889fb66d5ffc01 Mon Sep 17 00:00:00 2001 From: Rui Li Date: Mon, 27 May 2019 19:13:21 +0800 Subject: [PATCH 23/92] [FLINK-12418][hive] Add input/output format and SerDeLib information when creating Hive table in HiveCatalog and add 'hive-exec' as provided dependency To set input/output formats and SerDe lib when creating Hive tables in HiveCatalog, so that we can access these tables later. Also added 'hive-exec' as provided dependency. This closes #8553. --- flink-connectors/flink-connector-hive/pom.xml | 6 ++-- .../flink/table/catalog/hive/HiveCatalog.java | 34 +++++++++++++------ .../hive/HiveCatalogHiveMetadataTest.java | 19 +++++++++++ 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index 1b672d4a691511..25e475efeee3f4 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -239,13 +239,11 @@ under the License. - - org.apache.hive hive-exec ${hive.version} - test + provided org.apache.hive @@ -334,6 +332,8 @@ under the License. 
+ + org.apache.derby derby diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index d2387f08fae732..159499ca136b96 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -18,6 +18,7 @@ package org.apache.flink.table.catalog.hive; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.catalog.AbstractCatalogTable; import org.apache.flink.table.catalog.AbstractCatalogView; @@ -67,10 +68,12 @@ import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; import org.apache.hadoop.hive.metastore.api.Partition; import org.apache.hadoop.hive.metastore.api.PrincipalType; -import org.apache.hadoop.hive.metastore.api.SerDeInfo; import org.apache.hadoop.hive.metastore.api.StorageDescriptor; import org.apache.hadoop.hive.metastore.api.Table; import org.apache.hadoop.hive.metastore.api.UnknownDBException; +import org.apache.hadoop.hive.ql.io.StorageFormatDescriptor; +import org.apache.hadoop.hive.ql.io.StorageFormatFactory; +import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.thrift.TException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,6 +93,8 @@ public class HiveCatalog implements Catalog { private static final Logger LOG = LoggerFactory.getLogger(HiveCatalog.class); private static final String DEFAULT_DB = "default"; + private static final StorageFormatFactory storageFormatFactory = new StorageFormatFactory(); + private static final String DEFAULT_HIVE_TABLE_STORAGE_FORMAT = "TextFile"; // Prefix used to distinguish properties created by Hive and Flink, // as Hive metastore has its own properties created upon table creation and migration between different versions of metastore. @@ -474,7 +479,8 @@ public boolean tableExists(ObjectPath tablePath) throws CatalogException { } } - private Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { + @VisibleForTesting + Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { try { return client.getTable(tablePath.getDatabaseName(), tablePath.getObjectName()); } catch (NoSuchObjectException e) { @@ -534,9 +540,9 @@ private static CatalogBaseTable instantiateHiveCatalogTable(Table hiveTable) { } private static Table instantiateHiveTable(ObjectPath tablePath, CatalogBaseTable table) { - Table hiveTable = new Table(); - hiveTable.setDbName(tablePath.getDatabaseName()); - hiveTable.setTableName(tablePath.getObjectName()); + // let Hive set default parameters for us, e.g. serialization.format + Table hiveTable = org.apache.hadoop.hive.ql.metadata.Table.getEmptyTable(tablePath.getDatabaseName(), + tablePath.getObjectName()); hiveTable.setCreateTime((int) (System.currentTimeMillis() / 1000)); Map properties = new HashMap<>(table.getProperties()); @@ -549,11 +555,8 @@ private static Table instantiateHiveTable(ObjectPath tablePath, CatalogBaseTabl hiveTable.setParameters(properties); // Hive table's StorageDescriptor - // TODO: This is very basic Hive table. - // [FLINK-11479] Add input/output format and SerDeLib information for Hive tables. 
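// [Editor's note] The lines removed/added just below replace the bare StorageDescriptor with one
// whose input/output formats and SerDe library come from Hive's StorageFormatFactory. This is a
// minimal sketch of that approach, kept separate from the patch itself; it assumes the same Hive
// classes the patch imports (StorageFormatFactory, StorageFormatDescriptor, LazySimpleSerDe) and a
// StorageDescriptor whose SerDeInfo is already initialized, as it is for Table.getEmptyTable(...).

import org.apache.hadoop.hive.metastore.api.StorageDescriptor;
import org.apache.hadoop.hive.ql.io.StorageFormatDescriptor;
import org.apache.hadoop.hive.ql.io.StorageFormatFactory;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;

class TextFileStorageFormatSketch {
    // Configures the descriptor for Hive's "TextFile" format, the default this patch hard-codes.
    static void applyTextFileFormat(StorageDescriptor sd) {
        StorageFormatDescriptor format = new StorageFormatFactory().get("TextFile");
        sd.setInputFormat(format.getInputFormat());
        sd.setOutputFormat(format.getOutputFormat());
        String serde = format.getSerde();
        // TextFile does not declare a SerDe, so fall back to LazySimpleSerDe, as the patch does.
        sd.getSerdeInfo().setSerializationLib(serde != null ? serde : LazySimpleSerDe.class.getName());
    }
}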
- StorageDescriptor sd = new StorageDescriptor(); - hiveTable.setSd(sd); - sd.setSerdeInfo(new SerDeInfo(null, null, new HashMap<>())); + StorageDescriptor sd = hiveTable.getSd(); + setStorageFormat(sd, properties); List allColumns = HiveTableUtil.createHiveColumns(table.getSchema()); @@ -590,6 +593,17 @@ private static Table instantiateHiveTable(ObjectPath tablePath, CatalogBaseTabl return hiveTable; } + private static void setStorageFormat(StorageDescriptor sd, Map properties) { + // TODO: allow user to specify storage format. Simply use text format for now + String storageFormatName = DEFAULT_HIVE_TABLE_STORAGE_FORMAT; + StorageFormatDescriptor storageFormatDescriptor = storageFormatFactory.get(storageFormatName); + checkArgument(storageFormatDescriptor != null, "Unknown storage format " + storageFormatName); + sd.setInputFormat(storageFormatDescriptor.getInputFormat()); + sd.setOutputFormat(storageFormatDescriptor.getOutputFormat()); + String serdeLib = storageFormatDescriptor.getSerde(); + sd.getSerdeInfo().setSerializationLib(serdeLib != null ? serdeLib : LazySimpleSerDe.class.getName()); + } + /** * Filter out Hive-created properties, and return Flink-created properties. */ diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveCatalogHiveMetadataTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveCatalogHiveMetadataTest.java index e2d95a3c1f1f71..5a80d8504d00fd 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveCatalogHiveMetadataTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveCatalogHiveMetadataTest.java @@ -24,13 +24,17 @@ import org.apache.flink.table.catalog.CatalogTable; import org.apache.flink.table.catalog.CatalogTestBase; import org.apache.flink.table.catalog.CatalogView; +import org.apache.flink.util.StringUtils; +import org.apache.hadoop.hive.metastore.api.Table; import org.junit.BeforeClass; +import org.junit.Test; import java.io.IOException; import java.util.HashMap; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** @@ -51,6 +55,21 @@ public static void init() throws IOException { public void testCreateTable_Streaming() throws Exception { } + @Test + // verifies that input/output formats and SerDe are set for Hive tables + public void testCreateTable_StorageFormatSet() throws Exception { + catalog.createDatabase(db1, createDb(), false); + catalog.createTable(path1, createTable(), false); + + Table hiveTable = ((HiveCatalog) catalog).getHiveTable(path1); + String inputFormat = hiveTable.getSd().getInputFormat(); + String outputFormat = hiveTable.getSd().getOutputFormat(); + String serde = hiveTable.getSd().getSerdeInfo().getSerializationLib(); + assertFalse(StringUtils.isNullOrWhitespaceOnly(inputFormat)); + assertFalse(StringUtils.isNullOrWhitespaceOnly(outputFormat)); + assertFalse(StringUtils.isNullOrWhitespaceOnly(serde)); + } + // ------ utils ------ @Override From d3e5bf69c3d3f105c1ebe7c46bd4a74e5407137f Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Fri, 24 May 2019 12:04:13 -0700 Subject: [PATCH 24/92] [FLINK-9172][sql-client][tabe] Support catalogs in SQL-Client yaml config file This PR adds support for basic catalog entries in SQL-Client yaml config file, adds CatalogFactory and CatalogDescriptor, and hooks them up with SQL Client thru table factory discovery 
service. This closes #8541. --- .../table/client/config/Environment.java | 37 +++++- .../client/config/entries/CatalogEntry.java | 70 ++++++++++ .../client/config/entries/ConfigEntry.java | 2 +- .../client/config/entries/ExecutionEntry.java | 14 ++ .../gateway/local/ExecutionContext.java | 33 +++++ .../test/assembly/test-table-factories.xml | 13 -- .../client/gateway/local/DependencyTest.java | 48 +++++++ .../client/gateway/local/EnvironmentTest.java | 29 +++++ .../gateway/local/ExecutionContextTest.java | 20 +++ ....apache.flink.table.factories.TableFactory | 1 + .../resources/test-sql-client-catalogs.yaml | 123 ++++++++++++++++++ .../resources/test-sql-client-defaults.yaml | 4 +- .../resources/test-sql-client-factory.yaml | 4 +- .../table/descriptors/CatalogDescriptor.java | 62 +++++++++ .../CatalogDescriptorValidator.java | 45 +++++++ .../flink/table/factories/CatalogFactory.java | 40 ++++++ .../table/factories/TableFactoryService.java | 1 + .../descriptors/CatalogDescriptorTest.java | 108 +++++++++++++++ .../factories/CatalogFactoryServiceTest.java | 80 ++++++++++++ .../factories/utils/TestCatalogFactory.java | 58 +++++++++ ....apache.flink.table.factories.TableFactory | 1 + 21 files changed, 776 insertions(+), 17 deletions(-) create mode 100644 flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/CatalogEntry.java create mode 100644 flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/CatalogFactory.java create mode 100644 flink-table/flink-table-common/src/test/java/org/apache/flink/table/descriptors/CatalogDescriptorTest.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/CatalogFactoryServiceTest.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/utils/TestCatalogFactory.java diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/Environment.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/Environment.java index 12a4f22ed8341e..64c9453856236f 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/Environment.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/Environment.java @@ -19,6 +19,7 @@ package org.apache.flink.table.client.config; import org.apache.flink.table.client.SqlClientException; +import org.apache.flink.table.client.config.entries.CatalogEntry; import org.apache.flink.table.client.config.entries.DeploymentEntry; import org.apache.flink.table.client.config.entries.ExecutionEntry; import org.apache.flink.table.client.config.entries.FunctionEntry; @@ -37,7 +38,7 @@ /** * Environment configuration that represents the content of an environment file. Environment files - * define tables, execution, and deployment behavior. An environment might be defined by default or + * define catalogs, tables, execution, and deployment behavior. An environment might be defined by default or * as part of a session. Environments can be merged or enriched with properties (e.g. from CLI command). * *

In future versions, we might restrict the merging or enrichment of deployment properties to not @@ -49,6 +50,8 @@ public class Environment { public static final String DEPLOYMENT_ENTRY = "deployment"; + private Map catalogs; + private Map tables; private Map functions; @@ -58,12 +61,31 @@ public class Environment { private DeploymentEntry deployment; public Environment() { + this.catalogs = Collections.emptyMap(); this.tables = Collections.emptyMap(); this.functions = Collections.emptyMap(); this.execution = ExecutionEntry.DEFAULT_INSTANCE; this.deployment = DeploymentEntry.DEFAULT_INSTANCE; } + public Map getCatalogs() { + return catalogs; + } + + public void setCatalogs(List> catalogs) { + this.catalogs = new HashMap<>(catalogs.size()); + + catalogs.forEach(config -> { + final CatalogEntry catalog = CatalogEntry.create(config); + if (this.catalogs.containsKey(catalog.getName())) { + throw new SqlClientException( + String.format("Cannot create catalog '%s' because a catalog with this name is already registered.", + catalog.getName())); + } + this.catalogs.put(catalog.getName(), catalog); + }); + } + public Map getTables() { return tables; } @@ -117,6 +139,11 @@ public DeploymentEntry getDeployment() { @Override public String toString() { final StringBuilder sb = new StringBuilder(); + sb.append("===================== Catalogs =====================\n"); + catalogs.forEach((name, catalog) -> { + sb.append("- ").append(CatalogEntry.CATALOG_NAME).append(": ").append(name).append("\n"); + catalog.asMap().forEach((k, v) -> sb.append(" ").append(k).append(": ").append(v).append('\n')); + }); sb.append("===================== Tables =====================\n"); tables.forEach((name, table) -> { sb.append("- ").append(TableEntry.TABLES_NAME).append(": ").append(name).append("\n"); @@ -164,6 +191,11 @@ public static Environment parse(String content) throws IOException { public static Environment merge(Environment env1, Environment env2) { final Environment mergedEnv = new Environment(); + // merge catalogs + final Map catalogs = new HashMap<>(env1.getCatalogs()); + catalogs.putAll(env2.getCatalogs()); + mergedEnv.catalogs = catalogs; + // merge tables final Map tables = new LinkedHashMap<>(env1.getTables()); tables.putAll(env2.getTables()); @@ -192,6 +224,9 @@ public static Environment enrich( Map views) { final Environment enrichedEnv = new Environment(); + // merge catalogs + enrichedEnv.catalogs = new LinkedHashMap<>(env.getCatalogs()); + // merge tables enrichedEnv.tables = new LinkedHashMap<>(env.getTables()); enrichedEnv.tables.putAll(views); diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/CatalogEntry.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/CatalogEntry.java new file mode 100644 index 00000000000000..b385c6c03f21ca --- /dev/null +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/CatalogEntry.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.client.config.entries; + +import org.apache.flink.table.client.config.ConfigUtil; +import org.apache.flink.table.descriptors.DescriptorProperties; + +import java.util.Collections; +import java.util.Map; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; + +/** + * Describes a catalog configuration entry. + */ +public class CatalogEntry extends ConfigEntry { + + public static final String CATALOG_NAME = "name"; + + private final String name; + + protected CatalogEntry(String name, DescriptorProperties properties) { + super(properties); + this.name = name; + } + + public String getName() { + return name; + } + + @Override + protected void validate(DescriptorProperties properties) { + properties.validateString(CATALOG_TYPE, false, 1); + properties.validateInt(CATALOG_PROPERTY_VERSION, true, 0); + + // further validation is performed by the discovered factory + } + + public static CatalogEntry create(Map config) { + return create(ConfigUtil.normalizeYaml(config)); + } + + private static CatalogEntry create(DescriptorProperties properties) { + properties.validateString(CATALOG_NAME, false, 1); + + final String name = properties.getString(CATALOG_NAME); + + final DescriptorProperties cleanedProperties = + properties.withoutKeys(Collections.singletonList(CATALOG_NAME)); + + return new CatalogEntry(name, cleanedProperties); + } +} diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ConfigEntry.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ConfigEntry.java index 13d227da2b4757..614aab0d19cd75 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ConfigEntry.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ConfigEntry.java @@ -26,7 +26,7 @@ import java.util.Objects; /** - * Describes an environment configuration entry (such as table, functions, views). Config entries + * Describes an environment configuration entry (such as catalogs, table, functions, views). Config entries * are similar to {@link org.apache.flink.table.descriptors.Descriptor} but apply to SQL Client's * environment files only. 
*/ diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ExecutionEntry.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ExecutionEntry.java index d9d113ddd75e0b..a1d47a0d8b6376 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ExecutionEntry.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/config/entries/ExecutionEntry.java @@ -97,6 +97,10 @@ public class ExecutionEntry extends ConfigEntry { private static final String EXECUTION_RESTART_STRATEGY_MAX_FAILURES_PER_INTERVAL = "restart-strategy.max-failures-per-interval"; + private static final String EXECUTION_CURRNET_CATALOG = "current-catalog"; + + private static final String EXECUTION_CURRNET_DATABASE = "current-database"; + private ExecutionEntry(DescriptorProperties properties) { super(properties); } @@ -133,6 +137,8 @@ protected void validate(DescriptorProperties properties) { properties.validateLong(EXECUTION_RESTART_STRATEGY_DELAY, true, 0); properties.validateLong(EXECUTION_RESTART_STRATEGY_FAILURE_RATE_INTERVAL, true, 1); properties.validateInt(EXECUTION_RESTART_STRATEGY_MAX_FAILURES_PER_INTERVAL, true, 1); + properties.validateString(EXECUTION_CURRNET_CATALOG, true, 1); + properties.validateString(EXECUTION_CURRNET_DATABASE, true, 1); } public boolean isStreamingExecution() { @@ -230,6 +236,14 @@ public RestartStrategies.RestartStrategyConfiguration getRestartStrategy() { EXECUTION_RESTART_STRATEGY_TYPE_VALUE_FALLBACK)); } + public Optional getCurrentCatalog() { + return properties.getOptionalString(EXECUTION_CURRNET_CATALOG); + } + + public Optional getCurrentDatabase() { + return properties.getOptionalString(EXECUTION_CURRNET_DATABASE); + } + public boolean isChangelogMode() { return properties.getOptionalString(EXECUTION_RESULT_MODE) .map((v) -> v.equals(EXECUTION_RESULT_MODE_VALUE_CHANGELOG)) diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/ExecutionContext.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/ExecutionContext.java index 8c9a34c4347076..ead3b180658f71 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/ExecutionContext.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/ExecutionContext.java @@ -46,6 +46,7 @@ import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.java.BatchTableEnvironment; import org.apache.flink.table.api.java.StreamTableEnvironment; +import org.apache.flink.table.catalog.Catalog; import org.apache.flink.table.client.config.Environment; import org.apache.flink.table.client.config.entries.DeploymentEntry; import org.apache.flink.table.client.config.entries.ExecutionEntry; @@ -58,6 +59,7 @@ import org.apache.flink.table.client.gateway.SqlExecutionException; import org.apache.flink.table.factories.BatchTableSinkFactory; import org.apache.flink.table.factories.BatchTableSourceFactory; +import org.apache.flink.table.factories.CatalogFactory; import org.apache.flink.table.factories.StreamTableSinkFactory; import org.apache.flink.table.factories.StreamTableSourceFactory; import org.apache.flink.table.factories.TableFactoryService; @@ -77,6 +79,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; /** @@ -92,6 +95,7 @@ 
public class ExecutionContext { private final Environment mergedEnv; private final List dependencies; private final ClassLoader classLoader; + private final Map catalogs; private final Map> tableSources; private final Map> tableSinks; private final Map functions; @@ -114,6 +118,12 @@ public ExecutionContext(Environment defaultEnvironment, SessionContext sessionCo dependencies.toArray(new URL[dependencies.size()]), this.getClass().getClassLoader()); + // create catalogs + catalogs = new LinkedHashMap<>(); + mergedEnv.getCatalogs().forEach((name, entry) -> + catalogs.put(name, createCatalog(name, entry.asMap(), classLoader)) + ); + // create table sources & sinks. tableSources = new LinkedHashMap<>(); tableSinks = new LinkedHashMap<>(); @@ -174,6 +184,10 @@ public EnvironmentInstance createEnvironmentInstance() { } } + public Map getCatalogs() { + return catalogs; + } + public Map> getTableSources() { return tableSources; } @@ -227,6 +241,12 @@ private static ClusterSpecification createClusterSpecification(CustomCommandLine } } + private Catalog createCatalog(String name, Map catalogProperties, ClassLoader classLoader) { + final CatalogFactory factory = + TableFactoryService.find(CatalogFactory.class, catalogProperties, classLoader); + return factory.createCatalog(name, catalogProperties); + } + private static TableSource createTableSource(ExecutionEntry execution, Map sourceProperties, ClassLoader classLoader) { if (execution.isStreamingExecution()) { final StreamTableSourceFactory factory = (StreamTableSourceFactory) @@ -281,6 +301,19 @@ private EnvironmentInstance() { throw new SqlExecutionException("Unsupported execution type specified."); } + // register catalogs + catalogs.forEach(tableEnv::registerCatalog); + + Optional potentialCurrentCatalog = mergedEnv.getExecution().getCurrentCatalog(); + if (potentialCurrentCatalog.isPresent()) { + tableEnv.useCatalog(potentialCurrentCatalog.get()); + } + + Optional potentialCurrentDatabase = mergedEnv.getExecution().getCurrentDatabase(); + if (potentialCurrentDatabase.isPresent()) { + tableEnv.useDatabase(potentialCurrentDatabase.get()); + } + // create query config queryConfig = createQueryConfig(); diff --git a/flink-table/flink-sql-client/src/test/assembly/test-table-factories.xml b/flink-table/flink-sql-client/src/test/assembly/test-table-factories.xml index fafffae4a2d2b7..a45e74426b0e5e 100644 --- a/flink-table/flink-sql-client/src/test/assembly/test-table-factories.xml +++ b/flink-table/flink-sql-client/src/test/assembly/test-table-factories.xml @@ -24,19 +24,6 @@ under the License. 
jar false - - - ${project.build.testOutputDirectory} - / - - - org/apache/flink/table/client/gateway/utils/TestTableSourceFactory.class - org/apache/flink/table/client/gateway/utils/TestTableSourceFactory$*.class - org/apache/flink/table/client/gateway/utils/TestTableSinkFactory.class - org/apache/flink/table/client/gateway/utils/TestTableSinkFactory$*.class - - - diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java index 2fe4975b5f3f56..109246c8909608 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java @@ -22,20 +22,27 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.api.Types; +import org.apache.flink.table.catalog.Catalog; +import org.apache.flink.table.catalog.GenericInMemoryCatalog; import org.apache.flink.table.client.config.Environment; import org.apache.flink.table.client.gateway.SessionContext; import org.apache.flink.table.client.gateway.utils.EnvironmentFileUtil; import org.apache.flink.table.client.gateway.utils.TestTableSinkFactoryBase; import org.apache.flink.table.client.gateway.utils.TestTableSourceFactoryBase; +import org.apache.flink.table.descriptors.DescriptorProperties; +import org.apache.flink.table.factories.CatalogFactory; import org.junit.Test; import java.net.URL; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; import static org.junit.Assert.assertEquals; /** @@ -46,6 +53,8 @@ public class DependencyTest { public static final String CONNECTOR_TYPE_VALUE = "test-connector"; public static final String TEST_PROPERTY = "test-property"; + public static final String CATALOG_TYPE_TEST = "DependencyTest"; + private static final String FACTORY_ENVIRONMENT_FILE = "test-sql-client-factory.yaml"; private static final String TABLE_FACTORY_JAR_FILE = "table-factories-test-jar.jar"; @@ -99,4 +108,43 @@ public TestTableSinkFactory() { super(CONNECTOR_TYPE_VALUE, TEST_PROPERTY); } } + + /** + * External catalog that can be discovered if classloading is correct. + */ + public static class TestCatalogFactory implements CatalogFactory { + + @Override + public Map requiredContext() { + final Map context = new HashMap<>(); + context.put(CATALOG_TYPE, CATALOG_TYPE_TEST); + return context; + } + + @Override + public List supportedProperties() { + final List properties = new ArrayList<>(); + properties.add(TEST_PROPERTY); + return properties; + } + + @Override + public Catalog createCatalog(String name, Map properties) { + final DescriptorProperties params = new DescriptorProperties(true); + params.putProperties(properties); + return new TestCatalog(name); + } + } + + /** + * Test catalog. 
+ */ + public static class TestCatalog extends GenericInMemoryCatalog { + + private static final String TEST_DATABASE_NAME = "mydatabase"; + + public TestCatalog(String name) { + super(name, TEST_DATABASE_NAME); + } + } } diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/EnvironmentTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/EnvironmentTest.java index e6df0c355a06a6..4cddb535363954 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/EnvironmentTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/EnvironmentTest.java @@ -18,16 +18,22 @@ package org.apache.flink.table.client.gateway.local; +import org.apache.flink.table.client.SqlClientException; import org.apache.flink.table.client.config.Environment; import org.apache.flink.table.client.gateway.utils.EnvironmentFileUtil; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import static org.apache.flink.table.client.config.entries.CatalogEntry.CATALOG_NAME; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -39,6 +45,9 @@ public class EnvironmentTest { private static final String DEFAULTS_ENVIRONMENT_FILE = "test-sql-client-defaults.yaml"; private static final String FACTORY_ENVIRONMENT_FILE = "test-sql-client-factory.yaml"; + @Rule + public ExpectedException exception = ExpectedException.none(); + @Test public void testMerging() throws Exception { final Map replaceVars1 = new HashMap<>(); @@ -70,4 +79,24 @@ public void testMerging() throws Exception { assertTrue(merged.getExecution().isStreamingExecution()); assertEquals(16, merged.getExecution().getMaxParallelism()); } + + @Test + public void testDuplicateCatalog() { + exception.expect(SqlClientException.class); + exception.expectMessage("Cannot create catalog 'catalog2' because a catalog with this name is already registered."); + Environment env = new Environment(); + env.setCatalogs(Arrays.asList( + createCatalog("catalog1", "test"), + createCatalog("catalog2", "test"), + createCatalog("catalog2", "test"))); + } + + private static Map createCatalog(String name, String type) { + Map prop = new HashMap<>(); + + prop.put(CATALOG_NAME, name); + prop.put(CATALOG_TYPE, type); + + return prop; + } } diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java index 22aadc644edb57..21cf7d04fd56d6 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java @@ -52,6 +52,7 @@ public class ExecutionContextTest { private static final String DEFAULTS_ENVIRONMENT_FILE = "test-sql-client-defaults.yaml"; + private static final String CATALOGS_ENVIRONMENT_FILE = "test-sql-client-catalogs.yaml"; private static final String STREAMING_ENVIRONMENT_FILE = "test-sql-client-streaming.yaml"; @Test @@ -70,6 +71,16 @@ public void testExecutionConfig() throws Exception { 
assertEquals(1_000, failureRateStrategy.getDelayBetweenAttemptsInterval().toMilliseconds()); } + @Test + public void testCatalogs() throws Exception { + final String catalogName = "catalog1"; + final ExecutionContext context = createCatalogExecutionContext(); + final TableEnvironment tableEnv = context.createEnvironmentInstance().getTableEnvironment(); + + assertEquals(tableEnv.getCurrentCatalog(), catalogName); + assertEquals(tableEnv.getCurrentDatabase(), "mydatabase"); + } + @Test public void testFunctions() throws Exception { final ExecutionContext context = createDefaultExecutionContext(); @@ -173,6 +184,15 @@ private ExecutionContext createDefaultExecutionContext() throws Exception return createExecutionContext(DEFAULTS_ENVIRONMENT_FILE, replaceVars); } + private ExecutionContext createCatalogExecutionContext() throws Exception { + final Map replaceVars = new HashMap<>(); + replaceVars.put("$VAR_EXECUTION_TYPE", "streaming"); + replaceVars.put("$VAR_RESULT_MODE", "changelog"); + replaceVars.put("$VAR_UPDATE_MODE", "update-mode: append"); + replaceVars.put("$VAR_MAX_ROWS", "100"); + return createExecutionContext(CATALOGS_ENVIRONMENT_FILE, replaceVars); + } + private ExecutionContext createStreamingExecutionContext() throws Exception { final Map replaceVars = new HashMap<>(); replaceVars.put("$VAR_CONNECTOR_TYPE", DummyTableSourceFactory.CONNECTOR_TYPE_VALUE); diff --git a/flink-table/flink-sql-client/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory b/flink-table/flink-sql-client/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory index 54d6fc7840daa6..7d81c838500794 100644 --- a/flink-table/flink-sql-client/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory +++ b/flink-table/flink-sql-client/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory @@ -15,3 +15,4 @@ org.apache.flink.table.client.gateway.utils.DummyTableSinkFactory org.apache.flink.table.client.gateway.utils.DummyTableSourceFactory +org.apache.flink.table.client.gateway.local.DependencyTest$TestCatalogFactory diff --git a/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml new file mode 100644 index 00000000000000..e915930812baa0 --- /dev/null +++ b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml @@ -0,0 +1,123 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +#============================================================================== +# TEST ENVIRONMENT FILE +# General purpose default environment file. +#============================================================================== + +# this file has variables that can be filled with content by replacing $VAR_XXX + +tables: + - name: TableNumber1 + type: source-table + $VAR_UPDATE_MODE + schema: + - name: IntegerField1 + type: INT + - name: StringField1 + type: VARCHAR + connector: + type: filesystem + path: "$VAR_SOURCE_PATH1" + format: + type: csv + fields: + - name: IntegerField1 + type: INT + - name: StringField1 + type: VARCHAR + line-delimiter: "\n" + comment-prefix: "#" + - name: TestView1 + type: view + query: SELECT scalarUDF(IntegerField1) FROM default_catalog.default_database.TableNumber1 + - name: TableNumber2 + # Test backwards compatibility ("source" -> "source-table") + type: source + $VAR_UPDATE_MODE + schema: + - name: IntegerField2 + type: INT + - name: StringField2 + type: VARCHAR + connector: + type: filesystem + path: "$VAR_SOURCE_PATH2" + format: + type: csv + fields: + - name: IntegerField2 + type: INT + - name: StringField2 + type: VARCHAR + line-delimiter: "\n" + comment-prefix: "#" + - name: TestView2 + type: view + query: SELECT * FROM default_catalog.default_database.TestView1 + +functions: + - name: scalarUDF + from: class + class: org.apache.flink.table.client.gateway.utils.UserDefinedFunctions$ScalarUDF + constructor: + - 5 + - name: aggregateUDF + from: class + class: org.apache.flink.table.client.gateway.utils.UserDefinedFunctions$AggregateUDF + constructor: + - StarryName + - false + - class: java.lang.Integer + constructor: + - class: java.lang.String + constructor: + - type: VARCHAR + value: 3 + - name: tableUDF + from: class + class: org.apache.flink.table.client.gateway.utils.UserDefinedFunctions$TableUDF + constructor: + - type: LONG + value: 5 + +execution: + type: "$VAR_EXECUTION_TYPE" + time-characteristic: event-time + periodic-watermarks-interval: 99 + parallelism: 1 + max-parallelism: 16 + min-idle-state-retention: 0 + max-idle-state-retention: 0 + result-mode: "$VAR_RESULT_MODE" + max-table-result-rows: "$VAR_MAX_ROWS" + restart-strategy: + type: failure-rate + max-failures-per-interval: 10 + failure-rate-interval: 99000 + delay: 1000 + current-catalog: catalog1 + current-database: mydatabase + +deployment: + response-timeout: 5000 + +catalogs: + - name: catalog1 + type: DependencyTest diff --git a/flink-table/flink-sql-client/src/test/resources/test-sql-client-defaults.yaml b/flink-table/flink-sql-client/src/test/resources/test-sql-client-defaults.yaml index 7c7a7caaf0c800..9e0582be44d254 100644 --- a/flink-table/flink-sql-client/src/test/resources/test-sql-client-defaults.yaml +++ b/flink-table/flink-sql-client/src/test/resources/test-sql-client-defaults.yaml @@ -134,4 +134,6 @@ execution: deployment: response-timeout: 5000 - +catalogs: + - name: catalog1 + type: DependencyTest diff --git a/flink-table/flink-sql-client/src/test/resources/test-sql-client-factory.yaml b/flink-table/flink-sql-client/src/test/resources/test-sql-client-factory.yaml index 33c2e6b9649ef2..de60538f1eedcf 100644 --- a/flink-table/flink-sql-client/src/test/resources/test-sql-client-factory.yaml +++ b/flink-table/flink-sql-client/src/test/resources/test-sql-client-factory.yaml @@ -51,4 +51,6 @@ execution: deployment: response-timeout: 5000 - +catalogs: + - name: catalog2 + type: 
DependencyTest diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java new file mode 100644 index 00000000000000..18b433ec4da43c --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptors; + +import org.apache.flink.annotation.PublicEvolving; + +import java.util.Map; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; + +/** + * Describes a catalog of tables, views, and functions. + */ +@PublicEvolving +public abstract class CatalogDescriptor extends DescriptorBase { + + private final String type; + + private final int propertyVersion; + + /** + * Constructs a {@link CatalogDescriptor}. + * + * @param type string that identifies this catalog + * @param propertyVersion property version for backwards compatibility + */ + public CatalogDescriptor(String type, int propertyVersion) { + this.type = type; + this.propertyVersion = propertyVersion; + } + + @Override + public final Map toProperties() { + final DescriptorProperties properties = new DescriptorProperties(); + properties.putString(CATALOG_TYPE, type); + properties.putLong(CATALOG_PROPERTY_VERSION, propertyVersion); + properties.putProperties(toCatalogProperties()); + return properties.asMap(); + } + + /** + * Converts this descriptor into a set of catalog properties. + */ + protected abstract Map toCatalogProperties(); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java new file mode 100644 index 00000000000000..723dcb013a0758 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptors; + +import org.apache.flink.annotation.Internal; + +/** + * Validator for {@link CatalogDescriptor}. + */ +@Internal +public abstract class CatalogDescriptorValidator implements DescriptorValidator { + + /** + * Key for describing the type of the catalog. Usually used for factory discovery. + */ + public static final String CATALOG_TYPE = "type"; + + /** + * Key for describing the property version. This property can be used for backwards + * compatibility in case the property format changes. + */ + public static final String CATALOG_PROPERTY_VERSION = "property-version"; + + @Override + public void validate(DescriptorProperties properties) { + properties.validateString(CATALOG_TYPE, false, 1); + properties.validateInt(CATALOG_PROPERTY_VERSION, true, 0); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/CatalogFactory.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/CatalogFactory.java new file mode 100644 index 00000000000000..6ff65fb72761bc --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/CatalogFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.factories; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.catalog.Catalog; + +import java.util.Map; + +/** + * A factory to create configured catalog instances based on string-based properties. See + * also {@link TableFactory} for more information. + */ +@PublicEvolving +public interface CatalogFactory extends TableFactory { + + /** + * Creates and configures a {@link Catalog} using the given properties. + * + * @param properties normalized properties describing an external catalog. + * @return the configured catalog.
+ */ + Catalog createCatalog(String name, Map properties); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/TableFactoryService.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/TableFactoryService.java index 30d34b7f1db9f4..ba6eec4fbb17d4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/TableFactoryService.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/TableFactoryService.java @@ -211,6 +211,7 @@ private static List filterByContext( plainContext.remove(METADATA_PROPERTY_VERSION); plainContext.remove(STATISTICS_PROPERTY_VERSION); plainContext.remove(CATALOG_PROPERTY_VERSION); + plainContext.remove(org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION); // check if required context is met return plainContext.keySet().stream().allMatch(e -> properties.containsKey(e) && properties.get(e).equals(plainContext.get(e))); diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/descriptors/CatalogDescriptorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/descriptors/CatalogDescriptorTest.java new file mode 100644 index 00000000000000..f294b4d0269689 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/descriptors/CatalogDescriptorTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptors; + +import org.apache.flink.table.api.ValidationException; + +import org.junit.Test; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; + + +/** + * Tests for the {@link CatalogDescriptor} descriptor and + * {@link CatalogDescriptorValidator} validator. 
+ */ +public class CatalogDescriptorTest extends DescriptorTestBase { + + private static final String CATALOG_TYPE_VALUE = "CatalogDescriptorTest"; + private static final int CATALOG_PROPERTY_VERSION_VALUE = 1; + private static final String CATALOG_FOO = "foo"; + private static final String CATALOG_FOO_VALUE = "foo-1"; + + @Test(expected = ValidationException.class) + public void testMissingCatalogType() { + removePropertyAndVerify(descriptors().get(0), CATALOG_TYPE); + } + + @Test(expected = ValidationException.class) + public void testMissingFoo() { + removePropertyAndVerify(descriptors().get(0), CATALOG_FOO); + } + + @Override + protected List descriptors() { + final Descriptor minimumDesc = new TestCatalogDescriptor(CATALOG_FOO_VALUE); + return Collections.singletonList(minimumDesc); + } + + @Override + protected List> properties() { + final Map minimumProps = new HashMap<>(); + minimumProps.put(CATALOG_TYPE, CATALOG_TYPE_VALUE); + minimumProps.put(CATALOG_PROPERTY_VERSION, "" + CATALOG_PROPERTY_VERSION_VALUE); + minimumProps.put(CATALOG_FOO, CATALOG_FOO_VALUE); + return Collections.singletonList(minimumProps); + } + + @Override + protected DescriptorValidator validator() { + return new TestCatalogDescriptorValidator(); + } + + /** + * CatalogDescriptor for test. + */ + private class TestCatalogDescriptor extends CatalogDescriptor { + private String foo; + + public TestCatalogDescriptor(@Nullable String foo) { + super(CATALOG_TYPE_VALUE, CATALOG_PROPERTY_VERSION_VALUE); + this.foo = foo; + } + + @Override + protected Map toCatalogProperties() { + DescriptorProperties properties = new DescriptorProperties(); + if (foo != null) { + properties.putString(CATALOG_FOO, foo); + } + return properties.asMap(); + } + } + + /** + * CatalogDescriptorValidator for test. + */ + private class TestCatalogDescriptorValidator extends CatalogDescriptorValidator { + @Override + public void validate(DescriptorProperties properties) { + super.validate(properties); + properties.validateString(CATALOG_FOO, false, 1); + } + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/CatalogFactoryServiceTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/CatalogFactoryServiceTest.java new file mode 100644 index 00000000000000..44d924a5efe2c0 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/CatalogFactoryServiceTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.factories; + +import org.apache.flink.table.api.NoMatchingTableFactoryException; +import org.apache.flink.table.factories.utils.TestCatalogFactory; + +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; +import static org.apache.flink.table.factories.utils.TestCatalogFactory.CATALOG_TYPE_TEST; +import static org.junit.Assert.assertEquals; + +/** + * Tests for testing external catalog discovery using {@link TableFactoryService}. + * The tests assume the catalog factory {@link CatalogFactory} is registered. + */ +public class CatalogFactoryServiceTest { + @Test + public void testValidProperties() { + Map props = properties(); + + assertEquals( + TableFactoryService.find(CatalogFactory.class, props).getClass(), + TestCatalogFactory.class); + } + + @Test(expected = NoMatchingTableFactoryException.class) + public void testInvalidContext() { + Map props = properties(); + props.put(CATALOG_TYPE, "unknown-catalog-type"); + TableFactoryService.find(CatalogFactory.class, props); + } + + @Test + public void testDifferentContextVersion() { + Map props = properties(); + props.put(CATALOG_PROPERTY_VERSION, "2"); + + // the catalog should still be found + assertEquals( + TableFactoryService.find(CatalogFactory.class, props).getClass(), + TestCatalogFactory.class); + } + + @Test(expected = NoMatchingTableFactoryException.class) + public void testUnsupportedProperty() { + Map props = properties(); + props.put("unknown-property", "/new/path"); + TableFactoryService.find(CatalogFactory.class, props); + } + + private Map properties() { + Map properties = new HashMap<>(); + + properties.put(CATALOG_TYPE, CATALOG_TYPE_TEST); + properties.put(CATALOG_PROPERTY_VERSION, "1"); + return properties; + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/utils/TestCatalogFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/utils/TestCatalogFactory.java new file mode 100644 index 00000000000000..af0f4761716cf8 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/factories/utils/TestCatalogFactory.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.factories.utils; + +import org.apache.flink.table.catalog.Catalog; +import org.apache.flink.table.catalog.GenericInMemoryCatalog; +import org.apache.flink.table.factories.CatalogFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; + +/** + * Catalog factory for testing. + */ +public class TestCatalogFactory implements CatalogFactory { + + public static final String CATALOG_TYPE_TEST = "test"; + + @Override + public Catalog createCatalog(String name, Map properties) { + return new GenericInMemoryCatalog(name); + } + + @Override + public Map requiredContext() { + Map context = new HashMap<>(); + context.put(CATALOG_TYPE, CATALOG_TYPE_TEST); + context.put(CATALOG_PROPERTY_VERSION, "1"); + + return context; + } + + @Override + public List supportedProperties() { + return Collections.emptyList(); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory b/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory index dab44d237b36e1..c5fe13f414cd67 100644 --- a/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory +++ b/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.TableFactory @@ -21,3 +21,4 @@ org.apache.flink.table.factories.utils.TestTableFormatFactory org.apache.flink.table.factories.utils.TestAmbiguousTableFormatFactory org.apache.flink.table.factories.utils.TestExternalCatalogFactory org.apache.flink.table.catalog.TestExternalTableSourceFactory +org.apache.flink.table.factories.utils.TestCatalogFactory From 5b91db9fe01066980be78df9e8e8f177c8bc6543 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 27 May 2019 11:47:13 +0200 Subject: [PATCH 25/92] [hotfix][table-common] Update CHAR and BINARY in accordance with the SQL standard --- .../org/apache/flink/table/api/DataTypes.java | 2 +- .../flink/table/types/logical/BinaryType.java | 26 ++++++++++++++- .../flink/table/types/logical/CharType.java | 33 ++++++++++++++++--- .../flink/table/types/LogicalTypesTest.java | 2 +- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/DataTypes.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/DataTypes.java index dc7c3192f3fe42..288a2f62ff722a 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/DataTypes.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/DataTypes.java @@ -81,7 +81,7 @@ public final class DataTypes { /** * Data type of a fixed-length character string {@code CHAR(n)} where {@code n} is the number - * of code points. {@code n} must have a value between 1 and 255 (both inclusive). + * of code points. {@code n} must have a value between 1 and {@link Integer#MAX_VALUE} (both inclusive). 
* * @see CharType */ diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/BinaryType.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/BinaryType.java index e342d9e3262433..25dbcd47529aa4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/BinaryType.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/BinaryType.java @@ -32,10 +32,15 @@ *
<p>
The serialized string representation is {@code BINARY(n)} where {@code n} is the number of * bytes. {@code n} must have a value between 1 and {@link Integer#MAX_VALUE} (both inclusive). If * no length is specified, {@code n} is equal to 1. + * + *
<p>
For expressing a zero-length binary string literal, this type does also support {@code n} to + * be 0. However, this is not exposed through the API. */ @PublicEvolving public final class BinaryType extends LogicalType { + public static final int EMPTY_LITERAL_LENGTH = 0; + public static final int MIN_LENGTH = 1; public static final int MAX_LENGTH = Integer.MAX_VALUE; @@ -72,13 +77,32 @@ public BinaryType() { this(DEFAULT_LENGTH); } + /** + * Helper constructor for {@link #ofEmptyLiteral()} and {@link #copy(boolean)}. + */ + private BinaryType(int length, boolean isNullable) { + super(isNullable, LogicalTypeRoot.BINARY); + this.length = length; + } + + /** + * The SQL standard defines that character string literals are allowed to be zero-length strings + * (i.e., to contain no characters) even though it is not permitted to declare a type that is zero. + * For consistent behavior, the same logic applies to binary strings. + * + *
<p>
This method enables this special kind of binary string. + */ + public static BinaryType ofEmptyLiteral() { + return new BinaryType(EMPTY_LITERAL_LENGTH, false); + } + public int getLength() { return length; } @Override public LogicalType copy(boolean isNullable) { - return new BinaryType(isNullable, length); + return new BinaryType(length, isNullable); } @Override diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/CharType.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/CharType.java index fe870ce18d7a71..8385d05dfddb3b 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/CharType.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/CharType.java @@ -30,17 +30,22 @@ * Logical type of a fixed-length character string. * *
<p>
The serialized string representation is {@code CHAR(n)} where {@code n} is the number of - * code points. {@code n} must have a value between 1 and 255 (both inclusive). If no length is - * specified, {@code n} is equal to 1. + * code points. {@code n} must have a value between 1 and {@link Integer#MAX_VALUE} (both inclusive). + * If no length is specified, {@code n} is equal to 1. + * + *
<p>
For expressing a zero-length character string literal, this type does also support {@code n} + * to be 0. However, this is not exposed through the API. * *
<p>
A conversion from and to {@code byte[]} assumes UTF-8 encoding. */ @PublicEvolving public final class CharType extends LogicalType { + public static final int EMPTY_LITERAL_LENGTH = 0; + public static final int MIN_LENGTH = 1; - public static final int MAX_LENGTH = 255; + public static final int MAX_LENGTH = Integer.MAX_VALUE; public static final int DEFAULT_LENGTH = 1; @@ -57,7 +62,7 @@ public final class CharType extends LogicalType { public CharType(boolean isNullable, int length) { super(isNullable, LogicalTypeRoot.CHAR); - if (length < MIN_LENGTH || length > MAX_LENGTH) { + if (length < MIN_LENGTH) { throw new ValidationException( String.format( "Character string length must be between %d and %d (both inclusive).", @@ -75,13 +80,31 @@ public CharType() { this(DEFAULT_LENGTH); } + /** + * Helper constructor for {@link #ofEmptyLiteral()} and {@link #copy(boolean)}. + */ + private CharType(int length, boolean isNullable) { + super(isNullable, LogicalTypeRoot.CHAR); + this.length = length; + } + + /** + * The SQL standard defines that character string literals are allowed to be zero-length strings + * (i.e., to contain no characters) even though it is not permitted to declare a type that is zero. + * + *
<p>
This method enables this special kind of character string. + */ + public static CharType ofEmptyLiteral() { + return new CharType(EMPTY_LITERAL_LENGTH, false); + } + public int getLength() { return length; } @Override public LogicalType copy(boolean isNullable) { - return new CharType(isNullable, length); + return new CharType(length, isNullable); } @Override diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java index 853d37f5a5671e..600b7189b42b79 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java @@ -86,7 +86,7 @@ public void testCharType() { new Class[]{String.class, byte[].class}, new Class[]{String.class, byte[].class}, new LogicalType[]{}, - new CharType(12) + new CharType(Integer.MAX_VALUE) ); } From 1b525fdd83d9fe6090d07b170370f7900d2367d5 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 27 May 2019 11:47:55 +0200 Subject: [PATCH 26/92] [hotfix][table-common] Add missing getter to TimeType --- .../java/org/apache/flink/table/types/logical/TimeType.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/TimeType.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/TimeType.java index f55ddf3e07efb3..15763d3abca1f4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/TimeType.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/TimeType.java @@ -89,6 +89,10 @@ public TimeType() { this(DEFAULT_PRECISION); } + public int getPrecision() { + return precision; + } + @Override public LogicalType copy(boolean isNullable) { return new TimeType(isNullable, precision); From c5eb8a7bdf379514fb50fc6703ae984d3cacb64b Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 27 May 2019 12:10:55 +0200 Subject: [PATCH 27/92] [hotfix][table-common] Add more logical type check utilities --- .../logical/utils/LogicalTypeChecks.java | 183 ++++++++++++++++-- 1 file changed, 171 insertions(+), 12 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java index 39589224529fef..24f15787891331 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java @@ -19,21 +19,41 @@ package org.apache.flink.table.types.logical.utils; import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.logical.BinaryType; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.DayTimeIntervalType; +import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.LocalZonedTimestampType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.TimeType; import 
org.apache.flink.table.types.logical.TimestampKind; import org.apache.flink.table.types.logical.TimestampType; +import org.apache.flink.table.types.logical.VarBinaryType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.table.types.logical.YearMonthIntervalType; import org.apache.flink.table.types.logical.ZonedTimestampType; /** - * Utilities for checking {@link LogicalType}. + * Utilities for checking {@link LogicalType} and avoiding a lot of type casting and repetitive work. */ @Internal public final class LogicalTypeChecks { - private static final TimeAttributeChecker TIME_ATTRIBUTE_CHECKER = new TimeAttributeChecker(); + private static final TimestampKindExtractor TIMESTAMP_KIND_EXTRACTOR = new TimestampKindExtractor(); + + private static final LengthExtractor LENGTH_EXTRACTOR = new LengthExtractor(); + + private static final PrecisionExtractor PRECISION_EXTRACTOR = new PrecisionExtractor(); + + private static final ScaleExtractor SCALE_EXTRACTOR = new ScaleExtractor(); + + private static final YearPrecisionExtractor YEAR_PRECISION_EXTRACTOR = new YearPrecisionExtractor(); + + private static final DayPrecisionExtractor DAY_PRECISION_EXTRACTOR = new DayPrecisionExtractor(); + + private static final FractionalPrecisionExtractor FRACTIONAL_PRECISION_EXTRACTOR = new FractionalPrecisionExtractor(); public static boolean hasRoot(LogicalType logicalType, LogicalTypeRoot typeRoot) { return logicalType.getTypeRoot() == typeRoot; @@ -44,15 +64,63 @@ public static boolean hasFamily(LogicalType logicalType, LogicalTypeFamily famil } public static boolean isTimeAttribute(LogicalType logicalType) { - return logicalType.accept(TIME_ATTRIBUTE_CHECKER) != TimestampKind.REGULAR; + return logicalType.accept(TIMESTAMP_KIND_EXTRACTOR) != TimestampKind.REGULAR; } public static boolean isRowtimeAttribute(LogicalType logicalType) { - return logicalType.accept(TIME_ATTRIBUTE_CHECKER) == TimestampKind.ROWTIME; + return logicalType.accept(TIMESTAMP_KIND_EXTRACTOR) == TimestampKind.ROWTIME; } public static boolean isProctimeAttribute(LogicalType logicalType) { - return logicalType.accept(TIME_ATTRIBUTE_CHECKER) == TimestampKind.PROCTIME; + return logicalType.accept(TIMESTAMP_KIND_EXTRACTOR) == TimestampKind.PROCTIME; + } + + public static int getLength(LogicalType logicalType) { + return logicalType.accept(LENGTH_EXTRACTOR); + } + + public static boolean hasLength(LogicalType logicalType, int length) { + return getLength(logicalType) == length; + } + + public static int getPrecision(LogicalType logicalType) { + return logicalType.accept(PRECISION_EXTRACTOR); + } + + public static boolean hasPrecision(LogicalType logicalType, int precision) { + return getPrecision(logicalType) == precision; + } + + public static int getScale(LogicalType logicalType) { + return logicalType.accept(SCALE_EXTRACTOR); + } + + public static boolean hasScale(LogicalType logicalType, int scale) { + return getScale(logicalType) == scale; + } + + public static int getYearPrecision(LogicalType logicalType) { + return logicalType.accept(YEAR_PRECISION_EXTRACTOR); + } + + public static boolean hasYearPrecision(LogicalType logicalType, int yearPrecision) { + return getYearPrecision(logicalType) == yearPrecision; + } + + public static int getDayPrecision(LogicalType logicalType) { + return logicalType.accept(DAY_PRECISION_EXTRACTOR); + } + + public static boolean hasDayPrecision(LogicalType logicalType, int yearPrecision) { + return getDayPrecision(logicalType) == yearPrecision; + } + + public static int 
getFractionalPrecision(LogicalType logicalType) { + return logicalType.accept(FRACTIONAL_PRECISION_EXTRACTOR); + } + + public static boolean hasFractionalPrecision(LogicalType logicalType, int fractionalPrecision) { + return getFractionalPrecision(logicalType) == fractionalPrecision; } private LogicalTypeChecks() { @@ -61,7 +129,104 @@ private LogicalTypeChecks() { // -------------------------------------------------------------------------------------------- - private static class TimeAttributeChecker extends LogicalTypeDefaultVisitor { + /** + * Extracts an attribute of logical types that define that attribute. + */ + private static class Extractor extends LogicalTypeDefaultVisitor { + @Override + protected T defaultMethod(LogicalType logicalType) { + throw new IllegalArgumentException( + String.format( + "Invalid use of extractor %s. Called on logical type: %s", + this.getClass().getName(), + logicalType)); + } + } + + private static class LengthExtractor extends Extractor { + + @Override + public Integer visit(CharType charType) { + return charType.getLength(); + } + + @Override + public Integer visit(VarCharType varCharType) { + return varCharType.getLength(); + } + + @Override + public Integer visit(BinaryType binaryType) { + return binaryType.getLength(); + } + + @Override + public Integer visit(VarBinaryType varBinaryType) { + return varBinaryType.getLength(); + } + } + + private static class PrecisionExtractor extends Extractor { + + @Override + public Integer visit(DecimalType decimalType) { + return decimalType.getPrecision(); + } + + @Override + public Integer visit(TimeType timeType) { + return timeType.getPrecision(); + } + + @Override + public Integer visit(TimestampType timestampType) { + return timestampType.getPrecision(); + } + + @Override + public Integer visit(ZonedTimestampType zonedTimestampType) { + return zonedTimestampType.getPrecision(); + } + + @Override + public Integer visit(LocalZonedTimestampType localZonedTimestampType) { + return localZonedTimestampType.getPrecision(); + } + } + + private static class ScaleExtractor extends Extractor { + + @Override + public Integer visit(DecimalType decimalType) { + return decimalType.getScale(); + } + } + + private static class YearPrecisionExtractor extends Extractor { + + @Override + public Integer visit(YearMonthIntervalType yearMonthIntervalType) { + return yearMonthIntervalType.getYearPrecision(); + } + } + + private static class DayPrecisionExtractor extends Extractor { + + @Override + public Integer visit(DayTimeIntervalType dayTimeIntervalType) { + return dayTimeIntervalType.getDayPrecision(); + } + } + + private static class FractionalPrecisionExtractor extends Extractor { + + @Override + public Integer visit(DayTimeIntervalType dayTimeIntervalType) { + return dayTimeIntervalType.getFractionalPrecision(); + } + } + + private static class TimestampKindExtractor extends Extractor { @Override public TimestampKind visit(TimestampType timestampType) { @@ -77,11 +242,5 @@ public TimestampKind visit(ZonedTimestampType zonedTimestampType) { public TimestampKind visit(LocalZonedTimestampType localZonedTimestampType) { return localZonedTimestampType.getKind(); } - - @Override - protected TimestampKind defaultMethod(LogicalType logicalType) { - // we don't verify that type is actually a timestamp - return TimestampKind.REGULAR; - } } } From 551b8ef507109218a428baac2d5682fec2381a4f Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 27 May 2019 12:45:13 +0200 Subject: [PATCH 28/92] [hotfix][table-common] Add a value to 
data type converter --- .../types/utils/ValueDataTypeConverter.java | 231 ++++++++++++++++++ .../types/ValueDataTypeConverterTest.java | 160 ++++++++++++ 2 files changed, 391 insertions(+) create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ValueDataTypeConverter.java create mode 100644 flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ValueDataTypeConverterTest.java diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ValueDataTypeConverter.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ValueDataTypeConverter.java new file mode 100644 index 00000000000000..3b311cef688711 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ValueDataTypeConverter.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.utils; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.AtomicDataType; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.BinaryType; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.LogicalTypeFamily; + +import java.math.BigDecimal; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Stream; + +/** + * Value-based data type extractor that supports extraction of clearly identifiable data types for + * input conversion. + * + *
<p>
This converter is more precise than {@link ClassDataTypeConverter} because it also considers + * nullability, length, precision, and scale of values. + */ +@Internal +public final class ValueDataTypeConverter { + + /** + * Returns the clearly identifiable data type if possible. For example, {@code 12L} can be + * expressed as {@code DataTypes.BIGINT().notNull()}. However, for example, {@code null} could + * be any type and is not supported. + * + *
<p>
All types of the {@link LogicalTypeFamily#PREDEFINED} family and arrays are supported. + */ + public static Optional extractDataType(Object value) { + if (value == null) { + return Optional.empty(); + } + + DataType convertedDataType = null; + + if (value instanceof String) { + convertedDataType = convertToCharType((String) value); + } + + // byte arrays have higher priority than regular arrays + else if (value instanceof byte[]) { + convertedDataType = convertToBinaryType((byte[]) value); + } + + else if (value instanceof BigDecimal) { + convertedDataType = convertToDecimalType((BigDecimal) value); + } + + else if (value instanceof java.time.LocalTime) { + convertedDataType = convertToTimeType((java.time.LocalTime) value); + } + + else if (value instanceof java.time.LocalDateTime) { + convertedDataType = convertToTimestampType(((java.time.LocalDateTime) value).getNano()); + } + + else if (value instanceof java.sql.Timestamp) { + convertedDataType = convertToTimestampType(((java.sql.Timestamp) value).getNanos()); + } + + else if (value instanceof java.time.ZonedDateTime) { + convertedDataType = convertToZonedTimestampType(((java.time.ZonedDateTime) value).getNano()); + } + + else if (value instanceof java.time.OffsetDateTime) { + convertedDataType = convertToZonedTimestampType(((java.time.OffsetDateTime) value).getNano()); + } + + else if (value instanceof java.time.Instant) { + convertedDataType = convertToLocalZonedTimestampType(((java.time.Instant) value).getNano()); + } + + else if (value instanceof java.time.Period) { + convertedDataType = convertToYearMonthIntervalType(((java.time.Period) value).getYears()); + } + + else if (value instanceof java.time.Duration) { + final java.time.Duration duration = (java.time.Duration) value; + convertedDataType = convertToDayTimeIntervalType(duration.toDays(), duration.getNano()); + } + + else if (value instanceof Object[]) { + // don't let the class-based extraction kick in if array elements differ + return convertToArrayType((Object[]) value) + .map(dt -> dt.notNull().bridgedTo(value.getClass())); + } + + final Optional resultType; + if (convertedDataType != null) { + resultType = Optional.of(convertedDataType); + } else { + // class-based extraction is possible for BOOLEAN, TINYINT, SMALLINT, INT, FLOAT, DOUBLE, + // DATE, TIME with java.sql.Time, and arrays of primitive types + resultType = ClassDataTypeConverter.extractDataType(value.getClass()); + } + return resultType.map(dt -> dt.notNull().bridgedTo(value.getClass())); + } + + private static DataType convertToCharType(String string) { + if (string.isEmpty()) { + return new AtomicDataType(CharType.ofEmptyLiteral()); + } + return DataTypes.CHAR(string.length()); + } + + private static DataType convertToBinaryType(byte[] bytes) { + if (bytes.length == 0) { + return new AtomicDataType(BinaryType.ofEmptyLiteral()); + } + return DataTypes.BINARY(bytes.length); + } + + private static DataType convertToDecimalType(BigDecimal decimal) { + // let underlying layers check if precision and scale are supported + return DataTypes.DECIMAL(decimal.precision(), decimal.scale()); + } + + private static DataType convertToTimeType(java.time.LocalTime time) { + return DataTypes.TIME(fractionalSecondPrecision(time.getNano())); + } + + private static DataType convertToTimestampType(int nanos) { + return DataTypes.TIMESTAMP(fractionalSecondPrecision(nanos)); + } + + private static DataType convertToZonedTimestampType(int nanos) { + return DataTypes.TIMESTAMP_WITH_TIME_ZONE(fractionalSecondPrecision(nanos)); + } + 
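// Illustrative sketch only (not a diff hunk of this patch): the calls below show how the
// value-based extraction described in the class Javadoc behaves, following the conversion
// rules implemented in extractDataType(Object) and exercised by ValueDataTypeConverterTest.
// The literal values and local variable names are examples chosen for illustration.
//
//   Optional<DataType> charType = ValueDataTypeConverter.extractDataType("Hello");
//   // -> CHAR(5) NOT NULL, bridged to java.lang.String
//
//   Optional<DataType> decimalType = ValueDataTypeConverter.extractDataType(new BigDecimal("12.123"));
//   // -> DECIMAL(5, 3) NOT NULL
//
//   Optional<DataType> none = ValueDataTypeConverter.extractDataType(null);
//   // -> Optional.empty(), because a null value could be of any type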
+ private static DataType convertToLocalZonedTimestampType(int nanos) { + return DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(fractionalSecondPrecision(nanos)); + } + + private static DataType convertToYearMonthIntervalType(int years) { + return DataTypes.INTERVAL(DataTypes.YEAR(yearPrecision(years)), DataTypes.MONTH()); + } + + private static DataType convertToDayTimeIntervalType(long days, int nanos) { + return DataTypes.INTERVAL( + DataTypes.DAY(dayPrecision(days)), + DataTypes.SECOND(fractionalSecondPrecision(nanos))); + } + + private static Optional convertToArrayType(Object[] array) { + // fallback to class based-extraction if no values exist + if (array.length == 0 || Stream.of(array).allMatch(Objects::isNull)) { + return extractElementTypeFromClass(array); + } + + return extractElementTypeFromValues(array); + } + + private static Optional extractElementTypeFromValues(Object[] array) { + DataType elementType = null; + for (Object element : array) { + // null values are wildcard array elements + if (element == null) { + continue; + } + + final Optional possibleElementType = extractDataType(element); + if (!possibleElementType.isPresent()) { + return Optional.empty(); + } + + // for simplification, we assume that array elements can always be nullable + // otherwise mismatches could occur when dealing with nested arrays + final DataType extractedElementType = possibleElementType.get().nullable(); + + // ensure that all elements have the same type; + // in theory the logic could be improved by converting an array with elements + // [CHAR(1), CHAR(2)] into an array of CHAR(2) but this can lead to value + // modification (i.e. adding spaces) which is not intended. + if (elementType != null && !extractedElementType.equals(elementType)) { + return Optional.empty(); + } + elementType = extractedElementType; + } + + return Optional.ofNullable(elementType) + .map(DataTypes::ARRAY); + } + + private static Optional extractElementTypeFromClass(Object[] array) { + final Optional possibleElementType = + ClassDataTypeConverter.extractDataType(array.getClass().getComponentType()); + + // for simplification, we assume that array elements can always be nullable + return possibleElementType + .map(DataType::nullable) + .map(DataTypes::ARRAY); + } + + private static int fractionalSecondPrecision(int nanos) { + return String.format("%09d", nanos).replaceAll("0+$", "").length(); + } + + private static int yearPrecision(int years) { + return String.valueOf(years).length(); + } + + private static int dayPrecision(long days) { + return String.valueOf(days).length(); + } + + private ValueDataTypeConverter() { + // no instantiation + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ValueDataTypeConverterTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ValueDataTypeConverterTest.java new file mode 100644 index 00000000000000..9c001811ea0924 --- /dev/null +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ValueDataTypeConverterTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.types.logical.BinaryType; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.utils.ValueDataTypeConverter; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.annotation.Nullable; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.Period; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ValueDataTypeConverter}. + */ +@RunWith(Parameterized.class) +public class ValueDataTypeConverterTest { + + @Parameterized.Parameters(name = "[{index}] value: {0} type: {1}") + public static List testData() { + return Arrays.asList( + new Object[][]{ + + {"Hello World", DataTypes.CHAR(11)}, + + {"", new AtomicDataType(CharType.ofEmptyLiteral())}, + + {new byte[]{1, 2, 3}, DataTypes.BINARY(3)}, + + {new byte[0], new AtomicDataType(BinaryType.ofEmptyLiteral())}, + + {BigDecimal.ZERO, DataTypes.DECIMAL(1, 0)}, + + {new BigDecimal("12.123"), DataTypes.DECIMAL(5, 3)}, + + {12, DataTypes.INT()}, + + {LocalTime.of(13, 24, 25, 1000), DataTypes.TIME(6)}, + + {LocalTime.of(13, 24, 25, 0), DataTypes.TIME(0)}, + + {LocalTime.of(13, 24, 25, 1), DataTypes.TIME(9)}, + + {LocalTime.of(13, 24, 25, 999_999_999), DataTypes.TIME(9)}, + + {LocalDateTime.of(2019, 11, 11, 13, 24, 25, 1001), DataTypes.TIMESTAMP(9)}, + + { + ZonedDateTime.of(2019, 11, 11, 13, 24, 25, 1001, ZoneId.systemDefault()), + DataTypes.TIMESTAMP_WITH_TIME_ZONE(9).bridgedTo(ZonedDateTime.class) + }, + + { + OffsetDateTime.of(2019, 11, 11, 13, 24, 25, 1001, ZoneOffset.UTC), + DataTypes.TIMESTAMP_WITH_TIME_ZONE(9).bridgedTo(OffsetDateTime.class) + }, + + { + Instant.ofEpochMilli(12345602021L), + DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3).bridgedTo(Instant.class) + }, + + { + Period.ofYears(1000), + DataTypes.INTERVAL(DataTypes.YEAR(4), DataTypes.MONTH()).bridgedTo(Period.class) + }, + + { + Duration.ofMillis(1100), + DataTypes.INTERVAL(DataTypes.DAY(1), DataTypes.SECOND(1)).bridgedTo(Duration.class) + }, + + { + Duration.ofDays(42), + DataTypes.INTERVAL(DataTypes.DAY(2), DataTypes.SECOND(0)).bridgedTo(Duration.class) + }, + + { + Timestamp.valueOf("2018-01-01 12:13:14.123"), + DataTypes.TIMESTAMP(3).bridgedTo(java.sql.Timestamp.class) + }, + + {new Integer[]{1, 2, 3}, DataTypes.ARRAY(DataTypes.INT())}, + + {new Integer[]{1, null, 3}, DataTypes.ARRAY(DataTypes.INT())}, + + { + new BigDecimal[]{new BigDecimal("12.1234"), new BigDecimal("42.4321"), new BigDecimal("20.0000")}, + DataTypes.ARRAY(DataTypes.DECIMAL(6, 4)) + }, + + { + new BigDecimal[]{null, new BigDecimal("42.4321")}, + DataTypes.ARRAY(DataTypes.DECIMAL(6, 4)) + }, + + {new Integer[0], 
DataTypes.ARRAY(DataTypes.INT())}, + + { + new Integer[][]{new Integer[]{1, null, 3}, new Integer[0], new Integer[]{1}}, + DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.INT())) + }, + + { + new BigDecimal[0], + null + }, + } + ); + } + + @Parameterized.Parameter + public Object value; + + @Parameterized.Parameter(1) + public @Nullable DataType dataType; + + @Test + public void testClassToDataTypeConversion() { + assertEquals( + Optional.ofNullable(dataType).map(DataType::notNull), + ValueDataTypeConverter.extractDataType(value)); + } +} From 1b9fb2b6d07d54848e57d42333382a4553852491 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Tue, 28 May 2019 16:21:30 +0200 Subject: [PATCH 29/92] [hotfix][table-common] Fix invalid class to data type conversion --- .../apache/flink/table/types/utils/ClassDataTypeConverter.java | 2 +- .../apache/flink/table/types/ClassDataTypeConverterTest.java | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ClassDataTypeConverter.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ClassDataTypeConverter.java index a71c682f6a205c..7001168ae681b0 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ClassDataTypeConverter.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/ClassDataTypeConverter.java @@ -56,7 +56,7 @@ public final class ClassDataTypeConverter { addDefaultDataType(double.class, DataTypes.DOUBLE()); addDefaultDataType(java.sql.Date.class, DataTypes.DATE()); addDefaultDataType(java.time.LocalDate.class, DataTypes.DATE()); - addDefaultDataType(java.sql.Time.class, DataTypes.TIME(3)); + addDefaultDataType(java.sql.Time.class, DataTypes.TIME(0)); addDefaultDataType(java.time.LocalTime.class, DataTypes.TIME(9)); addDefaultDataType(java.sql.Timestamp.class, DataTypes.TIMESTAMP(9)); addDefaultDataType(java.time.LocalDateTime.class, DataTypes.TIMESTAMP(9)); diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ClassDataTypeConverterTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ClassDataTypeConverterTest.java index d889dec9eeb537..46781cb1cf268b 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ClassDataTypeConverterTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/ClassDataTypeConverterTest.java @@ -52,6 +52,8 @@ public static List testData() { {Long.class, DataTypes.BIGINT().nullable().bridgedTo(Long.class)}, + {java.sql.Time.class, DataTypes.TIME(0).nullable().bridgedTo(java.sql.Time.class)}, + {BigDecimal.class, null}, { From 07e75de29969503b7feb639771d86ada6f5fbdf0 Mon Sep 17 00:00:00 2001 From: Zhu Zhu Date: Wed, 29 May 2019 14:39:21 +0800 Subject: [PATCH 30/92] [FLINK-12413][runtime] Implement ExecutionFailureHandler * Implement ExecutionFailureHandler * Throws exception when getting tasks or restart delay from the failure handling result when the restarting is suppressed; Renames verticesToBeRestarted to verticesToRestart * Address the comments * make verticesToRestart in FailureHandlingResult unmodifiable * Support checking for nested unrecoverable throwable; address comments --- .../flip1/ExecutionFailureHandler.java | 86 +++++++ .../failover/flip1/FailoverStrategy.java | 18 ++ .../failover/flip1/FailureHandlingResult.java | 137 +++++++++++ .../flip1/RestartBackoffTimeStrategy.java | 62 +++++ 
.../throwable/ThrowableClassifier.java | 27 +++ .../ThrowableClassifierTest.java | 33 +++ .../flip1/ExecutionFailureHandlerTest.java | 219 ++++++++++++++++++ .../flip1/FailureHandlingResultTest.java | 85 +++++++ 8 files changed, 667 insertions(+) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandler.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResult.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartBackoffTimeStrategy.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandlerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResultTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandler.java new file mode 100644 index 00000000000000..322599238b335d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandler.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.throwable.ThrowableClassifier; +import org.apache.flink.runtime.throwable.ThrowableType; + +import java.util.Optional; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * This handler deals with task failures to return a {@link FailureHandlingResult} which contains tasks + * to restart to recover from failures. + */ +public class ExecutionFailureHandler { + + /** Strategy to judge which tasks should be restarted. */ + private final FailoverStrategy failoverStrategy; + + /** Strategy to judge whether and when a restarting should be done. */ + private final RestartBackoffTimeStrategy restartBackoffTimeStrategy; + + /** + * Creates the handler to deal with task failures. 
+ * + * @param failoverStrategy helps to decide tasks to restart on task failures + * @param restartBackoffTimeStrategy helps to decide whether to restart failed tasks and the restarting delay + */ + public ExecutionFailureHandler( + FailoverStrategy failoverStrategy, + RestartBackoffTimeStrategy restartBackoffTimeStrategy) { + + this.failoverStrategy = checkNotNull(failoverStrategy); + this.restartBackoffTimeStrategy = checkNotNull(restartBackoffTimeStrategy); + } + + /** + * Return result of failure handling. Can be a set of task vertices to restart + * and a delay of the restarting. Or that the failure is not recoverable and the reason for it. + * + * @param failedTask is the ID of the failed task vertex + * @param cause of the task failure + * @return result of the failure handling + */ + public FailureHandlingResult getFailureHandlingResult(ExecutionVertexID failedTask, Throwable cause) { + if (isUnrecoverableError(cause)) { + return FailureHandlingResult.unrecoverable(new JobException("The failure is not recoverable", cause)); + } + + restartBackoffTimeStrategy.notifyFailure(cause); + if (restartBackoffTimeStrategy.canRestart()) { + return FailureHandlingResult.restartable( + failoverStrategy.getTasksNeedingRestart(failedTask, cause), + restartBackoffTimeStrategy.getBackoffTime()); + } else { + return FailureHandlingResult.unrecoverable( + new JobException("Failed task restarting is suppressed by " + restartBackoffTimeStrategy, cause)); + } + } + + @VisibleForTesting + static boolean isUnrecoverableError(Throwable cause) { + Optional unrecoverableError = ThrowableClassifier.findThrowableOfThrowableType( + cause, ThrowableType.NonRecoverableError); + return unrecoverableError.isPresent(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverStrategy.java index 2fd4ce72c6afe4..0bd0c01b0abe9f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailoverStrategy.java @@ -34,4 +34,22 @@ public interface FailoverStrategy { * @return set of IDs of vertices to restart */ Set getTasksNeedingRestart(ExecutionVertexID executionVertexId, Throwable cause); + + // ------------------------------------------------------------------------ + // factory + // ------------------------------------------------------------------------ + + /** + * The factory to instantiate {@link FailoverStrategy}. + */ + interface Factory { + + /** + * Instantiates the {@link FailoverStrategy}. + * + * @param topology of the graph to failover + * @return The instantiated failover strategy. + */ + FailoverStrategy create(FailoverTopology topology); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResult.java new file mode 100644 index 00000000000000..51e487dfd182f1 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResult.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; + +import java.util.Collections; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Result containing the tasks to restart upon a task failure. + * Also contains the reason if the failure is not recoverable(non-recoverable + * failure type or restarting suppressed by restart strategy). + */ +public class FailureHandlingResult { + + /** Task vertices to restart to recover from the failure. */ + private final Set verticesToRestart; + + /** Delay before the restarting can be conducted. */ + private final long restartDelayMS; + + /** Reason why the failure is not recoverable. */ + private final Throwable error; + + /** + * Creates a result of a set of tasks to restart to recover from the failure. + * + * @param verticesToRestart containing task vertices to restart to recover from the failure + * @param restartDelayMS indicate a delay before conducting the restart + */ + private FailureHandlingResult(Set verticesToRestart, long restartDelayMS) { + checkState(restartDelayMS >= 0); + + this.verticesToRestart = Collections.unmodifiableSet(checkNotNull(verticesToRestart)); + this.restartDelayMS = restartDelayMS; + this.error = null; + } + + /** + * Creates a result that the failure is not recoverable and no restarting should be conducted. + * + * @param error reason why the failure is not recoverable + */ + private FailureHandlingResult(Throwable error) { + this.verticesToRestart = null; + this.restartDelayMS = -1; + this.error = checkNotNull(error); + } + + /** + * Returns the tasks to restart. + * + * @return the tasks to restart + */ + public Set getVerticesToRestart() { + if (canRestart()) { + return verticesToRestart; + } else { + throw new IllegalStateException("Cannot get vertices to restart when the restarting is suppressed."); + } + } + + /** + * Returns the delay before the restarting. + * + * @return the delay before the restarting + */ + public long getRestartDelayMS() { + if (canRestart()) { + return restartDelayMS; + } else { + throw new IllegalStateException("Cannot get restart delay when the restarting is suppressed."); + } + } + + /** + * Returns reason why the restarting cannot be conducted. + * + * @return reason why the restarting cannot be conducted + */ + public Throwable getError() { + if (canRestart()) { + throw new IllegalStateException("Cannot get error when the restarting is accepted."); + } else { + return error; + } + } + + /** + * Returns whether the restarting can be conducted. + * + * @return whether the restarting can be conducted + */ + public boolean canRestart() { + return error == null; + } + + /** + * Creates a result of a set of tasks to restart to recover from the failure. 
+ * + * @param verticesToRestart containing task vertices to restart to recover from the failure + * @param restartDelayMS indicate a delay before conducting the restart + * @return result of a set of tasks to restart to recover from the failure + */ + public static FailureHandlingResult restartable(Set verticesToRestart, long restartDelayMS) { + return new FailureHandlingResult(verticesToRestart, restartDelayMS); + } + + /** + * Creates a result that the failure is not recoverable and no restarting should be conducted. + * + * @param error reason why the failure is not recoverable + * @return result indicating the failure is not recoverable + */ + public static FailureHandlingResult unrecoverable(Throwable error) { + return new FailureHandlingResult(error); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartBackoffTimeStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartBackoffTimeStrategy.java new file mode 100644 index 00000000000000..cd7a39b60bf55e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartBackoffTimeStrategy.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +/** + * Strategy to decide whether to restart failed tasks and the delay to do the restarting. + */ +public interface RestartBackoffTimeStrategy { + + /** + * Returns whether a restart should be conducted. + * + * @return whether a restart should be conducted + */ + boolean canRestart(); + + /** + * Returns the delay to do the restarting. + * + * @return the delay to do the restarting + */ + long getBackoffTime(); + + /** + * Notify the strategy about the task failure cause. + * + * @param cause of the task failure + */ + void notifyFailure(Throwable cause); + + // ------------------------------------------------------------------------ + // factory + // ------------------------------------------------------------------------ + + /** + * The factory to instantiate {@link RestartBackoffTimeStrategy}. + */ + interface Factory { + + /** + * Instantiates the {@link RestartBackoffTimeStrategy}. + * + * @return The instantiated restart strategy. 
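Aside: the three new pieces above compose into a small failure-handling pipeline. The following is a rough sketch only, with placeholder strategies and a hypothetical FailureHandlingSketch wrapper (neither is part of this patch), showing how a caller could wire them together and consume the result:

    import org.apache.flink.runtime.executiongraph.failover.flip1.ExecutionFailureHandler;
    import org.apache.flink.runtime.executiongraph.failover.flip1.FailoverStrategy;
    import org.apache.flink.runtime.executiongraph.failover.flip1.FailureHandlingResult;
    import org.apache.flink.runtime.executiongraph.failover.flip1.RestartBackoffTimeStrategy;
    import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;

    import java.util.Collections;

    class FailureHandlingSketch {

        static void handleFailure(ExecutionVertexID failedVertex, Throwable cause) {
            // Placeholder: restart only the failed vertex; a real strategy would consult the topology.
            FailoverStrategy failoverStrategy = (vertexId, failure) -> Collections.singleton(vertexId);

            // Placeholder: always allow restarts after a fixed 1s backoff.
            RestartBackoffTimeStrategy restartStrategy = new RestartBackoffTimeStrategy() {
                @Override public boolean canRestart() { return true; }
                @Override public long getBackoffTime() { return 1_000L; }
                @Override public void notifyFailure(Throwable failure) { /* no bookkeeping in this sketch */ }
            };

            ExecutionFailureHandler handler = new ExecutionFailureHandler(failoverStrategy, restartStrategy);
            FailureHandlingResult result = handler.getFailureHandlingResult(failedVertex, cause);

            if (result.canRestart()) {
                // schedule result.getVerticesToRestart() after result.getRestartDelayMS() milliseconds
            } else {
                // restarting is suppressed or the failure is non-recoverable: fail with result.getError()
            }
        }
    }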
+ */ + RestartBackoffTimeStrategy create(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/throwable/ThrowableClassifier.java b/flink-runtime/src/main/java/org/apache/flink/runtime/throwable/ThrowableClassifier.java index 4a17b5f619989d..a08c28f171dd25 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/throwable/ThrowableClassifier.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/throwable/ThrowableClassifier.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.throwable; +import java.util.Optional; + /** * Helper class, given a exception do the classification. */ @@ -33,4 +35,29 @@ public static ThrowableType getThrowableType(Throwable cause) { final ThrowableAnnotation annotation = cause.getClass().getAnnotation(ThrowableAnnotation.class); return annotation == null ? ThrowableType.RecoverableError : annotation.value(); } + + /** + * Checks whether a throwable chain contains a specific throwable type and returns the corresponding throwable. + * + * @param throwable the throwable chain to check. + * @param throwableType the throwable type to search for in the chain. + * @return Optional throwable of the throwable type if available, otherwise empty + */ + public static Optional findThrowableOfThrowableType(Throwable throwable, ThrowableType throwableType) { + if (throwable == null || throwableType == null) { + return Optional.empty(); + } + + Throwable t = throwable; + while (t != null) { + final ThrowableAnnotation annotation = t.getClass().getAnnotation(ThrowableAnnotation.class); + if (annotation != null && annotation.value() == throwableType) { + return Optional.of(t); + } else { + t = t.getCause(); + } + } + + return Optional.empty(); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ThrowableClassifierTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ThrowableClassifierTest.java index 57b330e3997909..517321a5c3b3e3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ThrowableClassifierTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ThrowableClassifierTest.java @@ -27,6 +27,8 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; /** * Test throwable classifier @@ -68,6 +70,34 @@ public void testThrowableType_InheritError() { ThrowableClassifier.getThrowableType(new Sub_ThrowableType_PartitionDataMissingError_Exception())); } + @Test + public void testFindThrowableOfThrowableType() { + // no throwable type + assertFalse(ThrowableClassifier.findThrowableOfThrowableType( + new Exception(), + ThrowableType.RecoverableError).isPresent()); + + // no recoverable throwable type + assertFalse(ThrowableClassifier.findThrowableOfThrowableType( + new ThrowableType_PartitionDataMissingError_Exception(), + ThrowableType.RecoverableError).isPresent()); + + // direct recoverable throwable + assertTrue(ThrowableClassifier.findThrowableOfThrowableType( + new ThrowableType_RecoverableFailure_Exception(), + ThrowableType.RecoverableError).isPresent()); + + // nested recoverable throwable + assertTrue(ThrowableClassifier.findThrowableOfThrowableType( + new Exception(new ThrowableType_RecoverableFailure_Exception()), + ThrowableType.RecoverableError).isPresent()); + + // inherit recoverable throwable + assertTrue(ThrowableClassifier.findThrowableOfThrowableType( + new 
Sub_ThrowableType_RecoverableFailure_Exception(), + ThrowableType.RecoverableError).isPresent()); + } + @ThrowableAnnotation(ThrowableType.PartitionDataMissingError) private class ThrowableType_PartitionDataMissingError_Exception extends Exception { } @@ -82,4 +112,7 @@ private class ThrowableType_RecoverableFailure_Exception extends Exception { private class Sub_ThrowableType_PartitionDataMissingError_Exception extends ThrowableType_PartitionDataMissingError_Exception { } + + private class Sub_ThrowableType_RecoverableFailure_Exception extends ThrowableType_RecoverableFailure_Exception { + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandlerTest.java new file mode 100644 index 00000000000000..ed37496f75a3d8 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/ExecutionFailureHandlerTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +import org.apache.flink.runtime.execution.SuppressRestartsException; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.util.TestLogger; +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Tests for {@link ExecutionFailureHandler}. + */ +public class ExecutionFailureHandlerTest extends TestLogger { + + /** + * Tests the case that task restarting is accepted. 
+ */ + @Test + public void testNormalFailureHandling() { + // failover strategy which always suggests restarting the given tasks + Set tasksToRestart = new HashSet<>(); + tasksToRestart.add(new ExecutionVertexID(new JobVertexID(), 0)); + FailoverStrategy failoverStrategy = new TestFailoverStrategy(tasksToRestart); + + // restart strategy which accepts restarting + boolean canRestart = true; + long restartDelayMs = 1234; + RestartBackoffTimeStrategy restartStrategy = new TestRestartBackoffTimeStrategy(canRestart, restartDelayMs); + ExecutionFailureHandler executionFailureHandler = new ExecutionFailureHandler(failoverStrategy, restartStrategy); + + // trigger a task failure + FailureHandlingResult result = executionFailureHandler.getFailureHandlingResult( + new ExecutionVertexID(new JobVertexID(), 0), + new Exception("test failure")); + + // verify results + assertTrue(result.canRestart()); + assertEquals(restartDelayMs, result.getRestartDelayMS()); + assertEquals(tasksToRestart, result.getVerticesToRestart()); + try { + result.getError(); + fail("Cannot get error when the restarting is accepted"); + } catch (IllegalStateException ex) { + // expected + } + } + + /** + * Tests the case that task restarting is suppressed. + */ + @Test + public void testRestartingSuppressedFailureHandlingResult() { + // failover strategy which always suggests restarting the given tasks + Set tasksToRestart = new HashSet<>(); + tasksToRestart.add(new ExecutionVertexID(new JobVertexID(), 0)); + FailoverStrategy failoverStrategy = new TestFailoverStrategy(tasksToRestart); + + // restart strategy which suppresses restarting + boolean canRestart = false; + long restartDelayMs = 1234; + RestartBackoffTimeStrategy restartStrategy = new TestRestartBackoffTimeStrategy(canRestart, restartDelayMs); + ExecutionFailureHandler executionFailureHandler = new ExecutionFailureHandler(failoverStrategy, restartStrategy); + + // trigger a task failure + FailureHandlingResult result = executionFailureHandler.getFailureHandlingResult( + new ExecutionVertexID(new JobVertexID(), 0), + new Exception("test failure")); + + // verify results + assertFalse(result.canRestart()); + assertNotNull(result.getError()); + assertFalse(ExecutionFailureHandler.isUnrecoverableError(result.getError())); + try { + result.getVerticesToRestart(); + fail("get tasks to restart is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + try { + result.getRestartDelayMS(); + fail("get restart delay is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + } + + /** + * Tests the case that the failure is non-recoverable type. 
+ */ + @Test + public void testNonRecoverableFailureHandlingResult() { + // failover strategy which always suggests restarting the given tasks + Set tasksToRestart = new HashSet<>(); + tasksToRestart.add(new ExecutionVertexID(new JobVertexID(), 0)); + FailoverStrategy failoverStrategy = new TestFailoverStrategy(tasksToRestart); + + // restart strategy which accepts restarting + boolean canRestart = true; + long restartDelayMs = 1234; + RestartBackoffTimeStrategy restartStrategy = new TestRestartBackoffTimeStrategy(canRestart, restartDelayMs); + ExecutionFailureHandler executionFailureHandler = new ExecutionFailureHandler(failoverStrategy, restartStrategy); + + // trigger an unrecoverable task failure + FailureHandlingResult result = executionFailureHandler.getFailureHandlingResult( + new ExecutionVertexID(new JobVertexID(), 0), + new Exception(new SuppressRestartsException(new Exception("test failure")))); + + // verify results + assertFalse(result.canRestart()); + assertNotNull(result.getError()); + assertTrue(ExecutionFailureHandler.isUnrecoverableError(result.getError())); + try { + result.getVerticesToRestart(); + fail("get tasks to restart is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + try { + result.getRestartDelayMS(); + fail("get restart delay is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + } + + /** + * Tests the check for unrecoverable error. + */ + @Test + public void testUnrecoverableErrorCheck() { + // normal error + assertFalse(ExecutionFailureHandler.isUnrecoverableError(new Exception())); + + // direct unrecoverable error + assertTrue(ExecutionFailureHandler.isUnrecoverableError(new SuppressRestartsException(new Exception()))); + + // nested unrecoverable error + assertTrue(ExecutionFailureHandler.isUnrecoverableError( + new Exception(new SuppressRestartsException(new Exception())))); + } + + // ------------------------------------------------------------------------ + // utilities + // ------------------------------------------------------------------------ + + /** + * A FailoverStrategy implementation for tests. It always suggest restarting the given task set on construction. + */ + private class TestFailoverStrategy implements FailoverStrategy { + + private final Set tasksToRestart; + + public TestFailoverStrategy(Set tasksToRestart) { + this.tasksToRestart = checkNotNull(tasksToRestart); + } + + @Override + public Set getTasksNeedingRestart(ExecutionVertexID executionVertexId, Throwable cause) { + return tasksToRestart; + } + } + + /** + * A RestartBackoffTimeStrategy implementation for tests. 
+ */ + private class TestRestartBackoffTimeStrategy implements RestartBackoffTimeStrategy { + + private final boolean canRestart; + + private final long backoffTime; + + public TestRestartBackoffTimeStrategy(boolean canRestart, long backoffTime) { + this.canRestart = canRestart; + this.backoffTime = backoffTime; + } + + @Override + public boolean canRestart() { + return canRestart; + } + + @Override + public long getBackoffTime() { + return backoffTime; + } + + @Override + public void notifyFailure(Throwable cause) { + // ignore + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResultTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResultTest.java new file mode 100644 index 00000000000000..89436556324ffb --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/failover/flip1/FailureHandlingResultTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.executiongraph.failover.flip1; + +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.util.TestLogger; +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Tests for {@link FailureHandlingResult}. + */ +public class FailureHandlingResultTest extends TestLogger { + + /** + * Tests normal FailureHandlingResult. + */ + @Test + public void testNormalFailureHandlingResult() { + // create a normal FailureHandlingResult + Set tasks = new HashSet<>(); + tasks.add(new ExecutionVertexID(new JobVertexID(), 0)); + long delay = 1234; + FailureHandlingResult result = FailureHandlingResult.restartable(tasks, delay); + + assertTrue(result.canRestart()); + assertEquals(delay, result.getRestartDelayMS()); + assertEquals(tasks, result.getVerticesToRestart()); + try { + result.getError(); + fail("Cannot get error when the restarting is accepted"); + } catch (IllegalStateException ex) { + // expected + } + } + + /** + * Tests FailureHandlingResult which suppresses restarts. 
+ */ + @Test + public void testRestartingSuppressedFailureHandlingResult() { + // create a FailureHandlingResult with error + Throwable error = new Exception("test error"); + FailureHandlingResult result = FailureHandlingResult.unrecoverable(error); + + assertFalse(result.canRestart()); + assertEquals(error, result.getError()); + try { + result.getVerticesToRestart(); + fail("get tasks to restart is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + try { + result.getRestartDelayMS(); + fail("get restart delay is not allowed when restarting is suppressed"); + } catch (IllegalStateException ex) { + // expected + } + } +} From 793a78407aa22530448efbf18b714952eac40aba Mon Sep 17 00:00:00 2001 From: Thomas Weise Date: Wed, 22 May 2019 21:42:15 -0700 Subject: [PATCH 31/92] [FLINK-10921] [kinesis] Shard watermark synchronization in Kinesis consumer --- .../kinesis/FlinkKinesisConsumer.java | 18 +- .../config/ConsumerConfigConstants.java | 11 + .../internals/DynamoDBStreamsDataFetcher.java | 1 + .../kinesis/internals/KinesisDataFetcher.java | 292 ++++++++++++++++-- .../util/JobManagerWatermarkTracker.java | 179 +++++++++++ .../kinesis/util/RecordEmitter.java | 269 ++++++++++++++++ .../kinesis/util/WatermarkTracker.java | 114 +++++++ .../FlinkKinesisConsumerMigrationTest.java | 2 +- .../kinesis/FlinkKinesisConsumerTest.java | 185 +++++++++++ .../kinesis/internals/ShardConsumerTest.java | 9 +- .../testutils/TestableKinesisDataFetcher.java | 1 + .../util/JobManagerWatermarkTrackerTest.java | 101 ++++++ .../kinesis/util/RecordEmitterTest.java | 81 +++++ .../kinesis/util/WatermarkTrackerTest.java | 108 +++++++ 14 files changed, 1342 insertions(+), 29 deletions(-) create mode 100644 flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTracker.java create mode 100644 flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitter.java create mode 100644 flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTracker.java create mode 100644 flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTrackerTest.java create mode 100644 flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitterTest.java create mode 100644 flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTrackerTest.java diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java index 3c5e3c7303f7e7..5b24ded6d6a2c5 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java @@ -45,6 +45,7 @@ import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; +import 
org.apache.flink.streaming.connectors.kinesis.util.WatermarkTracker; import org.apache.flink.util.InstantiationUtil; import org.slf4j.Logger; @@ -126,6 +127,7 @@ public class FlinkKinesisConsumer extends RichParallelSourceFunction imple private KinesisShardAssigner shardAssigner = KinesisDataFetcher.DEFAULT_SHARD_ASSIGNER; private AssignerWithPeriodicWatermarks periodicWatermarkAssigner; + private WatermarkTracker watermarkTracker; // ------------------------------------------------------------------------ // Runtime state @@ -254,6 +256,20 @@ public void setPeriodicWatermarkAssigner( ClosureCleaner.clean(this.periodicWatermarkAssigner, true); } + public WatermarkTracker getWatermarkTracker() { + return this.watermarkTracker; + } + + /** + * Set the global watermark tracker. When set, it will be used by the fetcher + * to align the shard consumers by event time. + * @param watermarkTracker + */ + public void setWatermarkTracker(WatermarkTracker watermarkTracker) { + this.watermarkTracker = watermarkTracker; + ClosureCleaner.clean(this.watermarkTracker, true); + } + // ------------------------------------------------------------------------ // Source life cycle // ------------------------------------------------------------------------ @@ -448,7 +464,7 @@ protected KinesisDataFetcher createFetcher( Properties configProps, KinesisDeserializationSchema deserializationSchema) { - return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema, shardAssigner, periodicWatermarkAssigner); + return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema, shardAssigner, periodicWatermarkAssigner, watermarkTracker); } @VisibleForTesting diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ConsumerConfigConstants.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ConsumerConfigConstants.java index 41ac6b877a9549..2f5be979c3b05e 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ConsumerConfigConstants.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/config/ConsumerConfigConstants.java @@ -125,6 +125,15 @@ public SentinelSequenceNumber toSentinelSequenceNumber() { /** The interval after which to consider a shard idle for purposes of watermark generation. */ public static final String SHARD_IDLE_INTERVAL_MILLIS = "flink.shard.idle.interval"; + /** The interval for periodically synchronizing the shared watermark state. */ + public static final String WATERMARK_SYNC_MILLIS = "flink.watermark.sync.interval"; + + /** The maximum delta allowed for the reader to advance ahead of the shared global watermark. */ + public static final String WATERMARK_LOOKAHEAD_MILLIS = "flink.watermark.lookahead.millis"; + + /** The maximum number of records that will be buffered before suspending consumption of a shard. 
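For orientation, a minimal sketch of how these new knobs could be used together (imports omitted; the stream name, region, timestamp extraction, and aggregate name are made-up values, and JobManagerWatermarkTracker is the tracker implementation added further down in this patch):

    Properties config = new Properties();
    config.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
    // synchronize the shared watermark every 10s, let readers run at most 20s ahead of it
    config.setProperty(ConsumerConfigConstants.WATERMARK_SYNC_MILLIS, "10000");
    config.setProperty(ConsumerConfigConstants.WATERMARK_LOOKAHEAD_MILLIS, "20000");
    // buffer at most 500 records per shard before the shard consumer thread blocks
    config.setProperty(ConsumerConfigConstants.WATERMARK_SYNC_QUEUE_CAPACITY, "500");

    FlinkKinesisConsumer<String> consumer =
        new FlinkKinesisConsumer<>("example-stream", new SimpleStringSchema(), config);

    // event-time alignment needs both a periodic watermark assigner and a watermark tracker
    consumer.setPeriodicWatermarkAssigner(
        new BoundedOutOfOrdernessTimestampExtractor<String>(Time.seconds(5)) {
            @Override
            public long extractTimestamp(String element) {
                return Long.parseLong(element); // assumes records carry an epoch-millis payload
            }
        });
    consumer.setWatermarkTracker(new JobManagerWatermarkTracker("kinesis-shard-watermarks"));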
*/ + public static final String WATERMARK_SYNC_QUEUE_CAPACITY = "flink.watermark.sync.queue.capacity"; + // ------------------------------------------------------------------------ // Default values for consumer configuration // ------------------------------------------------------------------------ @@ -173,6 +182,8 @@ public SentinelSequenceNumber toSentinelSequenceNumber() { public static final long DEFAULT_SHARD_IDLE_INTERVAL_MILLIS = -1; + public static final long DEFAULT_WATERMARK_SYNC_MILLIS = 30_000; + /** * To avoid shard iterator expires in {@link ShardConsumer}s, the value for the configured * getRecords interval can not exceed 5 minutes, which is the expire time for retrieved iterators. diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/DynamoDBStreamsDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/DynamoDBStreamsDataFetcher.java index c2b7be352b1cd5..5620142e9405a4 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/DynamoDBStreamsDataFetcher.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/DynamoDBStreamsDataFetcher.java @@ -64,6 +64,7 @@ public DynamoDBStreamsDataFetcher(List streams, deserializationSchema, shardAssigner, null, + null, new AtomicReference<>(), new ArrayList<>(), createInitialSubscribedStreamsToLastDiscoveredShardsState(streams), diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java index 8c8d94ac3f1548..eae315358650f3 100644 --- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java @@ -38,6 +38,9 @@ import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy; import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface; import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema; +import org.apache.flink.streaming.connectors.kinesis.util.RecordEmitter; +import org.apache.flink.streaming.connectors.kinesis.util.WatermarkTracker; +import org.apache.flink.streaming.runtime.operators.windowing.TimestampedValue; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.util.InstantiationUtil; @@ -191,6 +194,9 @@ public class KinesisDataFetcher { private volatile boolean running = true; private final AssignerWithPeriodicWatermarks periodicWatermarkAssigner; + private final WatermarkTracker watermarkTracker; + private final transient RecordEmitter recordEmitter; + private transient boolean isIdle; /** * The watermark related state for each shard consumer. Entries in this map will be created when shards @@ -206,6 +212,14 @@ public class KinesisDataFetcher { */ private long lastWatermark = Long.MIN_VALUE; + /** + * The next watermark used for synchronization. 
+ * For purposes of global watermark calculation, we need to consider the next watermark based + * on the buffered records vs. the last emitted watermark to allow for progress. + */ + private long nextWatermark = Long.MIN_VALUE; + + /** * The time span since last consumed record, after which a shard will be considered idle for purpose of watermark * calculation. A positive value will allow the watermark to progress even when some shards don't receive new records. @@ -219,6 +233,82 @@ public interface FlinkKinesisProxyFactory { KinesisProxyInterface create(Properties configProps); } + /** + * The wrapper that holds the watermark handling related parameters + * of a record produced by the shard consumer thread. + * + * @param + */ + private static class RecordWrapper extends TimestampedValue { + int shardStateIndex; + SequenceNumber lastSequenceNumber; + long timestamp; + Watermark watermark; + + private RecordWrapper(T record, long timestamp) { + super(record, timestamp); + this.timestamp = timestamp; + } + + @Override + public long getTimestamp() { + return timestamp; + } + } + + /** Kinesis data fetcher specific, asynchronous record emitter. */ + private class AsyncKinesisRecordEmitter extends RecordEmitter> { + + private AsyncKinesisRecordEmitter() { + this(DEFAULT_QUEUE_CAPACITY); + } + + private AsyncKinesisRecordEmitter(int queueCapacity) { + super(queueCapacity); + } + + @Override + public void emit(RecordWrapper record, RecordQueue> queue) { + emitRecordAndUpdateState(record); + ShardWatermarkState sws = shardWatermarks.get(queue.getQueueId()); + sws.lastEmittedRecordWatermark = record.watermark; + } + } + + /** Synchronous emitter for use w/o watermark synchronization. */ + private class SyncKinesisRecordEmitter extends AsyncKinesisRecordEmitter { + private final ConcurrentHashMap>> queues = + new ConcurrentHashMap<>(); + + @Override + public RecordQueue> getQueue(int producerIndex) { + return queues.computeIfAbsent(producerIndex, (key) -> { + return new RecordQueue>() { + @Override + public void put(RecordWrapper record) { + emit(record, this); + } + + @Override + public int getQueueId() { + return producerIndex; + } + + @Override + public int getSize() { + return 0; + } + + @Override + public RecordWrapper peek() { + return null; + } + + }; + }); + } + } + /** * Creates a Kinesis Data Fetcher. 
* @@ -234,7 +324,8 @@ public KinesisDataFetcher(List streams, Properties configProps, KinesisDeserializationSchema deserializationSchema, KinesisShardAssigner shardAssigner, - AssignerWithPeriodicWatermarks periodicWatermarkAssigner) { + AssignerWithPeriodicWatermarks periodicWatermarkAssigner, + WatermarkTracker watermarkTracker) { this(streams, sourceContext, sourceContext.getCheckpointLock(), @@ -243,6 +334,7 @@ public KinesisDataFetcher(List streams, deserializationSchema, shardAssigner, periodicWatermarkAssigner, + watermarkTracker, new AtomicReference<>(), new ArrayList<>(), createInitialSubscribedStreamsToLastDiscoveredShardsState(streams), @@ -258,6 +350,7 @@ protected KinesisDataFetcher(List streams, KinesisDeserializationSchema deserializationSchema, KinesisShardAssigner shardAssigner, AssignerWithPeriodicWatermarks periodicWatermarkAssigner, + WatermarkTracker watermarkTracker, AtomicReference error, List subscribedShardsState, HashMap subscribedStreamsToLastDiscoveredShardIds, @@ -272,6 +365,7 @@ protected KinesisDataFetcher(List streams, this.deserializationSchema = checkNotNull(deserializationSchema); this.shardAssigner = checkNotNull(shardAssigner); this.periodicWatermarkAssigner = periodicWatermarkAssigner; + this.watermarkTracker = watermarkTracker; this.kinesisProxyFactory = checkNotNull(kinesisProxyFactory); this.kinesis = kinesisProxyFactory.create(configProps); @@ -284,6 +378,17 @@ protected KinesisDataFetcher(List streams, this.shardConsumersExecutor = createShardConsumersThreadPool(runtimeContext.getTaskNameWithSubtasks()); + this.recordEmitter = createRecordEmitter(configProps); + } + + private RecordEmitter createRecordEmitter(Properties configProps) { + if (periodicWatermarkAssigner != null && watermarkTracker != null) { + int queueCapacity = Integer.parseInt(configProps.getProperty( + ConsumerConfigConstants.WATERMARK_SYNC_QUEUE_CAPACITY, + Integer.toString(AsyncKinesisRecordEmitter.DEFAULT_QUEUE_CAPACITY))); + return new AsyncKinesisRecordEmitter(queueCapacity); + } + return new SyncKinesisRecordEmitter(); } /** @@ -380,16 +485,37 @@ public void runFetcher() throws Exception { ProcessingTimeService timerService = ((StreamingRuntimeContext) runtimeContext).getProcessingTimeService(); LOG.info("Starting periodic watermark emitter with interval {}", periodicWatermarkIntervalMillis); new PeriodicWatermarkEmitter(timerService, periodicWatermarkIntervalMillis).start(); + if (watermarkTracker != null) { + // setup global watermark tracking + long watermarkSyncMillis = Long.parseLong( + getConsumerConfiguration().getProperty(ConsumerConfigConstants.WATERMARK_SYNC_MILLIS, + Long.toString(ConsumerConfigConstants.DEFAULT_WATERMARK_SYNC_MILLIS))); + watermarkTracker.setUpdateTimeoutMillis(watermarkSyncMillis * 3); // synchronization latency + watermarkTracker.open(runtimeContext); + new WatermarkSyncCallback(timerService, watermarkSyncMillis).start(); + // emit records ahead of watermark to offset synchronization latency + long lookaheadMillis = Long.parseLong( + getConsumerConfiguration().getProperty(ConsumerConfigConstants.WATERMARK_LOOKAHEAD_MILLIS, + Long.toString(0))); + recordEmitter.setMaxLookaheadMillis(Math.max(lookaheadMillis, watermarkSyncMillis * 3)); + } } this.shardIdleIntervalMillis = Long.parseLong( getConsumerConfiguration().getProperty(ConsumerConfigConstants.SHARD_IDLE_INTERVAL_MILLIS, Long.toString(ConsumerConfigConstants.DEFAULT_SHARD_IDLE_INTERVAL_MILLIS))); + + // run record emitter in separate thread since main thread is used for discovery + Thread 
thread = new Thread(this.recordEmitter); + thread.setName("recordEmitter-" + runtimeContext.getTaskNameWithSubtasks()); + thread.setDaemon(true); + thread.start(); } // ------------------------------------------------------------------------ // finally, start the infinite shard discovery and consumer launching loop; // we will escape from this loop only when shutdownFetcher() or stopWithError() is called + // TODO: have this thread emit the records for tracking backpressure final long discoveryIntervalMillis = Long.valueOf( configProps.getProperty( @@ -490,6 +616,11 @@ public void shutdownFetcher() { mainThread.interrupt(); // the main thread may be sleeping for the discovery interval } + if (watermarkTracker != null) { + watermarkTracker.close(); + } + this.recordEmitter.stop(); + if (LOG.isInfoEnabled()) { LOG.info("Shutting down the shard consumer threads of subtask {} ...", indexOfThisConsumerSubtask); } @@ -603,28 +734,48 @@ protected KinesisDeserializationSchema getClonedDeserializationSchema() { * @param lastSequenceNumber the last sequence number value to update */ protected void emitRecordAndUpdateState(T record, long recordTimestamp, int shardStateIndex, SequenceNumber lastSequenceNumber) { - // track per shard watermarks and emit timestamps extracted from the record, - // when a watermark assigner was configured. - if (periodicWatermarkAssigner != null) { - ShardWatermarkState sws = shardWatermarks.get(shardStateIndex); - Preconditions.checkNotNull( - sws, "shard watermark state initialized in registerNewSubscribedShardState"); + ShardWatermarkState sws = shardWatermarks.get(shardStateIndex); + Preconditions.checkNotNull( + sws, "shard watermark state initialized in registerNewSubscribedShardState"); + Watermark watermark = null; + if (sws.periodicWatermarkAssigner != null) { recordTimestamp = sws.periodicWatermarkAssigner.extractTimestamp(record, sws.lastRecordTimestamp); - sws.lastRecordTimestamp = recordTimestamp; - sws.lastUpdated = getCurrentTimeMillis(); + // track watermark per record since extractTimestamp has side effect + watermark = sws.periodicWatermarkAssigner.getCurrentWatermark(); } + sws.lastRecordTimestamp = recordTimestamp; + sws.lastUpdated = getCurrentTimeMillis(); + RecordWrapper recordWrapper = new RecordWrapper<>(record, recordTimestamp); + recordWrapper.shardStateIndex = shardStateIndex; + recordWrapper.lastSequenceNumber = lastSequenceNumber; + recordWrapper.watermark = watermark; + try { + sws.emitQueue.put(recordWrapper); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + /** + * Actual record emission called from the record emitter. + * + *


Responsible for tracking per shard watermarks and emit timestamps extracted from + * the record, when a watermark assigner was configured. + * + * @param rw + */ + private void emitRecordAndUpdateState(RecordWrapper rw) { synchronized (checkpointLock) { - if (record != null) { - sourceContext.collectWithTimestamp(record, recordTimestamp); + if (rw.getValue() != null) { + sourceContext.collectWithTimestamp(rw.getValue(), rw.timestamp); } else { LOG.warn("Skipping non-deserializable record at sequence number {} of shard {}.", - lastSequenceNumber, - subscribedShardsState.get(shardStateIndex).getStreamShardHandle()); + rw.lastSequenceNumber, + subscribedShardsState.get(rw.shardStateIndex).getStreamShardHandle()); } - - updateState(shardStateIndex, lastSequenceNumber); + updateState(rw.shardStateIndex, rw.lastSequenceNumber); } } @@ -689,6 +840,7 @@ public int registerNewSubscribedShardState(KinesisStreamShardState newSubscribed } catch (Exception e) { throw new RuntimeException("Failed to instantiate new WatermarkAssigner", e); } + sws.emitQueue = recordEmitter.getQueue(shardStateIndex); sws.lastUpdated = getCurrentTimeMillis(); sws.lastRecordTimestamp = Long.MIN_VALUE; shardWatermarks.put(shardStateIndex, sws); @@ -721,41 +873,57 @@ protected long getCurrentTimeMillis() { protected void emitWatermark() { LOG.debug("Evaluating watermark for subtask {} time {}", indexOfThisConsumerSubtask, getCurrentTimeMillis()); long potentialWatermark = Long.MAX_VALUE; + long potentialNextWatermark = Long.MAX_VALUE; long idleTime = (shardIdleIntervalMillis > 0) ? getCurrentTimeMillis() - shardIdleIntervalMillis : Long.MAX_VALUE; for (Map.Entry e : shardWatermarks.entrySet()) { + Watermark w = e.getValue().lastEmittedRecordWatermark; // consider only active shards, or those that would advance the watermark - Watermark w = e.getValue().periodicWatermarkAssigner.getCurrentWatermark(); - if (w != null && (e.getValue().lastUpdated >= idleTime || w.getTimestamp() > lastWatermark)) { + if (w != null && (e.getValue().lastUpdated >= idleTime + || e.getValue().emitQueue.getSize() > 0 + || w.getTimestamp() > lastWatermark)) { potentialWatermark = Math.min(potentialWatermark, w.getTimestamp()); + // for sync, use the watermark of the next record, when available + // otherwise watermark may stall when record is blocked by synchronization + RecordEmitter.RecordQueue> q = e.getValue().emitQueue; + RecordWrapper nextRecord = q.peek(); + Watermark nextWatermark = (nextRecord != null) ? 
nextRecord.watermark : w; + potentialNextWatermark = Math.min(potentialNextWatermark, nextWatermark.getTimestamp()); } } // advance watermark if possible (watermarks can only be ascending) if (potentialWatermark == Long.MAX_VALUE) { if (shardWatermarks.isEmpty() || shardIdleIntervalMillis > 0) { - LOG.debug("No active shard for subtask {}, marking the source idle.", + LOG.info("No active shard for subtask {}, marking the source idle.", indexOfThisConsumerSubtask); // no active shard, signal downstream operators to not wait for a watermark sourceContext.markAsTemporarilyIdle(); + isIdle = true; } - } else if (potentialWatermark > lastWatermark) { - LOG.debug("Emitting watermark {} from subtask {}", - potentialWatermark, - indexOfThisConsumerSubtask); - sourceContext.emitWatermark(new Watermark(potentialWatermark)); - lastWatermark = potentialWatermark; + } else { + if (potentialWatermark > lastWatermark) { + LOG.debug("Emitting watermark {} from subtask {}", + potentialWatermark, + indexOfThisConsumerSubtask); + sourceContext.emitWatermark(new Watermark(potentialWatermark)); + lastWatermark = potentialWatermark; + isIdle = false; + } + nextWatermark = potentialNextWatermark; } } /** Per shard tracking of watermark and last activity. */ private static class ShardWatermarkState { private AssignerWithPeriodicWatermarks periodicWatermarkAssigner; + private RecordEmitter.RecordQueue> emitQueue; private volatile long lastRecordTimestamp; private volatile long lastUpdated; + private volatile Watermark lastEmittedRecordWatermark; } /** @@ -785,6 +953,82 @@ public void onProcessingTime(long timestamp) { } } + /** Timer task to update shared watermark state. */ + private class WatermarkSyncCallback implements ProcessingTimeCallback { + + private final ProcessingTimeService timerService; + private final long interval; + private final MetricGroup shardMetricsGroup; + private long lastGlobalWatermark = Long.MIN_VALUE; + private long propagatedLocalWatermark = Long.MIN_VALUE; + private long logIntervalMillis = 60_000; + private int stalledWatermarkIntervalCount = 0; + private long lastLogged; + + WatermarkSyncCallback(ProcessingTimeService timerService, long interval) { + this.timerService = checkNotNull(timerService); + this.interval = interval; + this.shardMetricsGroup = consumerMetricGroup.addGroup("subtaskId", + String.valueOf(indexOfThisConsumerSubtask)); + this.shardMetricsGroup.gauge("localWatermark", () -> nextWatermark); + this.shardMetricsGroup.gauge("globalWatermark", () -> lastGlobalWatermark); + } + + public void start() { + LOG.info("Registering watermark tracker with interval {}", interval); + timerService.registerTimer(timerService.getCurrentProcessingTime() + interval, this); + } + + @Override + public void onProcessingTime(long timestamp) { + if (nextWatermark != Long.MIN_VALUE) { + long globalWatermark = lastGlobalWatermark; + // TODO: refresh watermark while idle + if (!(isIdle && nextWatermark == propagatedLocalWatermark)) { + globalWatermark = watermarkTracker.updateWatermark(nextWatermark); + propagatedLocalWatermark = nextWatermark; + } else { + LOG.info("WatermarkSyncCallback subtask: {} is idle", indexOfThisConsumerSubtask); + } + + if (timestamp - lastLogged > logIntervalMillis) { + lastLogged = System.currentTimeMillis(); + LOG.info("WatermarkSyncCallback subtask: {} local watermark: {}" + + ", global watermark: {}, delta: {} timeouts: {}, emitter: {}", + indexOfThisConsumerSubtask, + nextWatermark, + globalWatermark, + nextWatermark - globalWatermark, + 
watermarkTracker.getUpdateTimeoutCount(), + recordEmitter.printInfo()); + + // Following is for debugging non-reproducible issue with stalled watermark + if (globalWatermark == nextWatermark && globalWatermark == lastGlobalWatermark + && stalledWatermarkIntervalCount++ > 5) { + // subtask blocks watermark, log to aid troubleshooting + stalledWatermarkIntervalCount = 0; + for (Map.Entry e : shardWatermarks.entrySet()) { + RecordEmitter.RecordQueue> q = e.getValue().emitQueue; + RecordWrapper nextRecord = q.peek(); + if (nextRecord != null) { + LOG.info("stalled watermark {} key {} next watermark {} next timestamp {}", + nextWatermark, + e.getKey(), + nextRecord.watermark, + nextRecord.timestamp); + } + } + } + } + + lastGlobalWatermark = globalWatermark; + recordEmitter.setCurrentWatermark(globalWatermark); + } + // schedule next callback + timerService.registerTimer(timerService.getCurrentProcessingTime() + interval, this); + } + } + /** * Registers a metric group associated with the shard id of the provided {@link KinesisStreamShardState shardState}. * diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTracker.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTracker.java new file mode 100644 index 00000000000000..f150bb0d23b159 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTracker.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.runtime.taskexecutor.GlobalAggregateManager; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * A {@link WatermarkTracker} that shares state through {@link GlobalAggregateManager}. 
+ */ +@PublicEvolving +public class JobManagerWatermarkTracker extends WatermarkTracker { + + private GlobalAggregateManager aggregateManager; + private final String aggregateName; + private final WatermarkAggregateFunction aggregateFunction = new WatermarkAggregateFunction(); + private final long logAccumulatorIntervalMillis; + private long updateTimeoutCount; + + public JobManagerWatermarkTracker(String aggregateName) { + this(aggregateName, -1); + } + + public JobManagerWatermarkTracker(String aggregateName, long logAccumulatorIntervalMillis) { + super(); + this.aggregateName = aggregateName; + this.logAccumulatorIntervalMillis = logAccumulatorIntervalMillis; + } + + @Override + public long updateWatermark(long localWatermark) { + WatermarkUpdate update = new WatermarkUpdate(); + update.id = getSubtaskId(); + update.watermark = localWatermark; + try { + byte[] resultBytes = aggregateManager.updateGlobalAggregate(aggregateName, + InstantiationUtil.serializeObject(update), aggregateFunction); + WatermarkResult result = InstantiationUtil.deserializeObject(resultBytes, + this.getClass().getClassLoader()); + this.updateTimeoutCount += result.updateTimeoutCount; + return result.watermark; + } catch (ClassNotFoundException | IOException ex) { + throw new RuntimeException(ex); + } + } + + @Override + public void open(RuntimeContext context) { + super.open(context); + this.aggregateFunction.updateTimeoutMillis = super.getUpdateTimeoutMillis(); + this.aggregateFunction.logAccumulatorIntervalMillis = logAccumulatorIntervalMillis; + Preconditions.checkArgument(context instanceof StreamingRuntimeContext); + StreamingRuntimeContext runtimeContext = (StreamingRuntimeContext) context; + this.aggregateManager = runtimeContext.getGlobalAggregateManager(); + } + + public long getUpdateTimeoutCount() { + return updateTimeoutCount; + } + + /** Watermark aggregation input. */ + protected static class WatermarkUpdate implements Serializable { + protected long watermark = Long.MIN_VALUE; + protected String id; + } + + /** Watermark aggregation result. */ + protected static class WatermarkResult implements Serializable { + protected long watermark = Long.MIN_VALUE; + protected long updateTimeoutCount = 0; + } + + /** + * Aggregate function for computing a combined watermark of parallel subtasks. 
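A quick usage sketch for the tracker itself, independent of the fetcher (variable names are illustrative; the runtime context must be a StreamingRuntimeContext, as open() checks):

    // 'streamingRuntimeContext' is assumed to come from the enclosing source function.
    WatermarkTracker tracker = new JobManagerWatermarkTracker("kinesis-shard-watermarks");
    tracker.setUpdateTimeoutMillis(60_000);   // subtask entries older than this are ignored in the aggregate
    tracker.open(streamingRuntimeContext);

    // Each subtask reports its local watermark and receives the aggregated watermark back
    // (the minimum across subtasks with fresh updates, per the aggregate function below).
    long globalWatermark = tracker.updateWatermark(localWatermark);

    tracker.close();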
+ */ + private static class WatermarkAggregateFunction implements + AggregateFunction, byte[]> { + + private long updateTimeoutMillis = DEFAULT_UPDATE_TIMEOUT_MILLIS; + private long logAccumulatorIntervalMillis = -1; + + // TODO: wrap accumulator + static long addCount; + static long lastLogged; + private static final Logger LOG = LoggerFactory.getLogger(WatermarkAggregateFunction.class); + + @Override + public Map createAccumulator() { + return new HashMap<>(); + } + + @Override + public Map add(byte[] valueBytes, Map accumulator) { + addCount++; + final WatermarkUpdate value; + try { + value = InstantiationUtil.deserializeObject(valueBytes, this.getClass().getClassLoader()); + } catch (Exception e) { + throw new RuntimeException(e); + } + WatermarkState ws = accumulator.get(value.id); + if (ws == null) { + accumulator.put(value.id, ws = new WatermarkState()); + } + ws.watermark = value.watermark; + ws.lastUpdated = System.currentTimeMillis(); + return accumulator; + } + + @Override + public byte[] getResult(Map accumulator) { + long updateTimeoutCount = 0; + long currentTime = System.currentTimeMillis(); + long globalWatermark = Long.MAX_VALUE; + for (Map.Entry e : accumulator.entrySet()) { + WatermarkState ws = e.getValue(); + if (ws.lastUpdated + updateTimeoutMillis < currentTime) { + // ignore outdated entry + updateTimeoutCount++; + continue; + } + globalWatermark = Math.min(ws.watermark, globalWatermark); + } + + WatermarkResult result = new WatermarkResult(); + result.watermark = globalWatermark == Long.MAX_VALUE ? Long.MIN_VALUE : globalWatermark; + result.updateTimeoutCount = updateTimeoutCount; + + if (logAccumulatorIntervalMillis > 0) { + if (currentTime - lastLogged > logAccumulatorIntervalMillis) { + lastLogged = System.currentTimeMillis(); + LOG.info("WatermarkAggregateFunction added: {}, timeout: {}, map: {}", + addCount, updateTimeoutCount, accumulator); + } + } + + try { + return InstantiationUtil.serializeObject(result); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Map merge(Map accumulatorA, Map accumulatorB) { + // not required + throw new UnsupportedOperationException(); + } + } + +} diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitter.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitter.java new file mode 100644 index 00000000000000..17344b1e2ed0bb --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitter.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.runtime.operators.windowing.TimestampedValue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.PriorityQueue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Emitter that handles event time synchronization between producer threads. + * + *
Records are organized into per producer queues that will block when capacity is exhausted. + * + *
Records are emitted by selecting the oldest available element of all producer queues, + * as long as the timestamp does not exceed the current shared watermark plus allowed lookahead interval. + * + * @param + */ +@Internal +public abstract class RecordEmitter implements Runnable { + private static final Logger LOG = LoggerFactory.getLogger(RecordEmitter.class); + + /** + * The default capacity of a single queue. + * + *
Larger queue size can lead to higher throughput, but also to + * very high heap space consumption, depending on the size of elements. + * + *
Note that this is difficult to tune, because it does not take into account + * the size of individual objects. + */ + public static final int DEFAULT_QUEUE_CAPACITY = 100; + + private final int queueCapacity; + private final ConcurrentHashMap> queues = new ConcurrentHashMap<>(); + private final ConcurrentHashMap, Boolean> emptyQueues = new ConcurrentHashMap<>(); + private final PriorityQueue> heads = new PriorityQueue<>(this::compareHeadElement); + private volatile boolean running = true; + private volatile long maxEmitTimestamp = Long.MAX_VALUE; + private long maxLookaheadMillis = 60 * 1000; // one minute + private long idleSleepMillis = 100; + private final Object condition = new Object(); + + public RecordEmitter(int queueCapacity) { + this.queueCapacity = queueCapacity; + } + + private int compareHeadElement(AsyncRecordQueue left, AsyncRecordQueue right) { + return Long.compare(left.headTimestamp, right.headTimestamp); + } + + /** + * Accepts records from readers. + * + * @param + */ + public interface RecordQueue { + void put(T record) throws InterruptedException; + + int getQueueId(); + + int getSize(); + + T peek(); + } + + private class AsyncRecordQueue implements RecordQueue { + private final ArrayBlockingQueue queue; + private final int queueId; + long headTimestamp; + + private AsyncRecordQueue(int queueId) { + super(); + this.queue = new ArrayBlockingQueue<>(queueCapacity); + this.queueId = queueId; + this.headTimestamp = Long.MAX_VALUE; + } + + public void put(T record) throws InterruptedException { + queue.put(record); + synchronized (condition) { + condition.notify(); + } + } + + public int getQueueId() { + return queueId; + } + + public int getSize() { + return queue.size(); + } + + public T peek() { + return queue.peek(); + } + + } + + /** + * The queue for the given producer (i.e. Kinesis shard consumer thread). + * + *
<p>
The producer may hold on to the queue for subsequent records. + * + * @param producerIndex + * @return + */ + public RecordQueue getQueue(int producerIndex) { + return queues.computeIfAbsent(producerIndex, (key) -> { + AsyncRecordQueue q = new AsyncRecordQueue<>(producerIndex); + emptyQueues.put(q, false); + return q; + }); + } + + /** + * How far ahead of the watermark records should be emitted. + * + *
<p>
Needs to account for the latency of obtaining the global watermark. + * + * @param maxLookaheadMillis + */ + public void setMaxLookaheadMillis(long maxLookaheadMillis) { + this.maxLookaheadMillis = maxLookaheadMillis; + LOG.info("[setMaxLookaheadMillis] Max lookahead millis set to {}", maxLookaheadMillis); + } + + /** + * Set the current watermark. + * + *
<p>
This watermark will be used to control which records to emit from the queues of pending + * elements. When an element timestamp is ahead of the watermark by at least {@link + * #maxLookaheadMillis}, elements in that queue will wait until the watermark advances. + * + * @param watermark + */ + public void setCurrentWatermark(long watermark) { + maxEmitTimestamp = watermark + maxLookaheadMillis; + synchronized (condition) { + condition.notify(); + } + LOG.info( + "[setCurrentWatermark] Current watermark set to {}, maxEmitTimestamp = {}", + watermark, + maxEmitTimestamp); + } + + /** + * Run loop, does not return unless {@link #stop()} was called. + */ + @Override + public void run() { + LOG.info("Starting emitter with maxLookaheadMillis: {}", this.maxLookaheadMillis); + + // emit available records, ordered by timestamp + AsyncRecordQueue min = heads.poll(); + runLoop: + while (running) { + // find a queue to emit from + while (min == null) { + // check new or previously empty queues + if (!emptyQueues.isEmpty()) { + for (AsyncRecordQueue queueHead : emptyQueues.keySet()) { + if (!queueHead.queue.isEmpty()) { + emptyQueues.remove(queueHead); + queueHead.headTimestamp = queueHead.queue.peek().getTimestamp(); + heads.offer(queueHead); + } + } + } + min = heads.poll(); + if (min == null) { + synchronized (condition) { + // wait for work + try { + condition.wait(idleSleepMillis); + } catch (InterruptedException e) { + continue runLoop; + } + } + } + } + + // wait until ready to emit min or another queue receives elements + while (min.headTimestamp > maxEmitTimestamp) { + synchronized (condition) { + // wait until ready to emit + try { + condition.wait(idleSleepMillis); + } catch (InterruptedException e) { + continue runLoop; + } + if (min.headTimestamp > maxEmitTimestamp && !emptyQueues.isEmpty()) { + // see if another queue can make progress + heads.offer(min); + min = null; + continue runLoop; + } + } + } + + // emit up to queue capacity records + // cap on empty queues since older records may arrive + AsyncRecordQueue nextQueue = heads.poll(); + T record; + int emitCount = 0; + while ((record = min.queue.poll()) != null) { + emit(record, min); + // track last record emitted, even when queue becomes empty + min.headTimestamp = record.getTimestamp(); + // potentially switch to next queue + if ((nextQueue != null && min.headTimestamp > nextQueue.headTimestamp) + || (min.headTimestamp > maxEmitTimestamp) + || (emitCount++ > queueCapacity && !emptyQueues.isEmpty())) { + break; + } + } + if (record == null) { + this.emptyQueues.put(min, true); + } else { + heads.offer(min); + } + min = nextQueue; + } + } + + public void stop() { + running = false; + } + + /** Emit the record. This is specific to a connector, like the Kinesis data fetcher. */ + public abstract void emit(T record, RecordQueue source); + + /** Summary of emit queues that can be used for logging. 
*/ + public String printInfo() { + StringBuffer sb = new StringBuffer(); + sb.append(String.format("queues: %d, empty: %d", + queues.size(), emptyQueues.size())); + for (Map.Entry> e : queues.entrySet()) { + AsyncRecordQueue q = e.getValue(); + sb.append(String.format("\n%d timestamp: %s size: %d", + e.getValue().queueId, q.headTimestamp, q.queue.size())); + } + return sb.toString(); + } + +} diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTracker.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTracker.java new file mode 100644 index 00000000000000..f4207c7dac305d --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTracker.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; + +import java.io.Closeable; +import java.io.Serializable; + +/** + * The watermark tracker is responsible for aggregating watermarks across distributed operators. + *
<p>
It can be used for subtasks of a single Flink source as well as for multiple heterogeneous + * sources or other operators. + *
<p>
The class essentially functions like a distributed hash table that enclosing operators can + * use to adopt their processing / IO rates. + */ +@PublicEvolving +public abstract class WatermarkTracker implements Closeable, Serializable { + + public static final long DEFAULT_UPDATE_TIMEOUT_MILLIS = 60_000; + + /** + * Subtasks that have not provided a watermark update within the configured interval will be + * considered idle and excluded from target watermark calculation. + */ + private long updateTimeoutMillis = DEFAULT_UPDATE_TIMEOUT_MILLIS; + + /** + * Unique id for the subtask. + * Using string (instead of subtask index) so synchronization can spawn across multiple sources. + */ + private String subtaskId; + + /** Watermark state. */ + protected static class WatermarkState { + protected long watermark = Long.MIN_VALUE; + protected long lastUpdated; + + public long getWatermark() { + return watermark; + } + + @Override + public String toString() { + return "WatermarkState{watermark=" + watermark + ", lastUpdated=" + lastUpdated + '}'; + } + } + + protected String getSubtaskId() { + return this.subtaskId; + } + + protected long getUpdateTimeoutMillis() { + return this.updateTimeoutMillis; + } + + public abstract long getUpdateTimeoutCount(); + + /** + * Subtasks that have not provided a watermark update within the configured interval will be + * considered idle and excluded from target watermark calculation. + * + * @param updateTimeoutMillis + */ + public void setUpdateTimeoutMillis(long updateTimeoutMillis) { + this.updateTimeoutMillis = updateTimeoutMillis; + } + + /** + * Set the current watermark of the owning subtask and return the global low watermark based on + * the current state snapshot. Periodically called by the enclosing consumer instance, which is + * responsible for any timer management etc. 
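+ *
+ * <p>A minimal sketch of that periodic call, assuming the enclosing consumer also drives a
+ * {@code RecordEmitter} (variable names are illustrative):
+ * <pre>{@code
+ * long globalWatermark = tracker.updateWatermark(lastLocalWatermark);
+ * recordEmitter.setCurrentWatermark(globalWatermark);
+ * }</pre>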
+ * + * @param localWatermark + * @return + */ + public abstract long updateWatermark(final long localWatermark); + + protected long getCurrentTime() { + return System.currentTimeMillis(); + } + + public void open(RuntimeContext context) { + if (context instanceof StreamingRuntimeContext) { + this.subtaskId = ((StreamingRuntimeContext) context).getOperatorUniqueID() + + "-" + context.getIndexOfThisSubtask(); + } else { + this.subtaskId = context.getTaskNameWithSubtasks(); + } + } + + @Override + public void close() { + // no work to do here + } + +} diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java index b38eef13fbad43..1ce05d189fd5f9 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java @@ -418,7 +418,7 @@ public TestFetcher( HashMap testStateSnapshot, List testInitialDiscoveryShards) { - super(streams, sourceContext, runtimeContext, configProps, deserializationSchema, DEFAULT_SHARD_ASSIGNER, null); + super(streams, sourceContext, runtimeContext, configProps, deserializationSchema, DEFAULT_SHARD_ASSIGNER, null, null); this.testStateSnapshot = testStateSnapshot; this.testInitialDiscoveryShards = testInitialDiscoveryShards; diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java index d36d68a57b679e..cbcd8b4e302faa 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java @@ -52,6 +52,7 @@ import org.apache.flink.streaming.connectors.kinesis.testutils.TestUtils; import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer; import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil; +import org.apache.flink.streaming.connectors.kinesis.util.WatermarkTracker; import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; import org.apache.flink.streaming.util.CollectingSourceContext; @@ -60,6 +61,7 @@ import com.amazonaws.services.kinesis.model.HashKeyRange; import com.amazonaws.services.kinesis.model.SequenceNumberRange; import com.amazonaws.services.kinesis.model.Shard; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -79,6 +81,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -737,6 +740,7 @@ protected KinesisDataFetcher createFetcher( deserializationSchema, getShardAssigner(), getPeriodicWatermarkAssigner(), + null, new AtomicReference<>(), new ArrayList<>(), subscribedStreamsToLastDiscoveredShardIds, @@ -775,6 
+779,10 @@ protected KinesisDataFetcher createFetcher( public void emitWatermark(Watermark mark) { watermarks.add(mark); } + + @Override + public void markAsTemporarilyIdle() { + } }; new Thread( @@ -817,6 +825,164 @@ public void emitWatermark(Watermark mark) { assertThat(watermarks, org.hamcrest.Matchers.contains(new Watermark(-3), new Watermark(5))); } + @Test + public void testSourceSynchronization() throws Exception { + + final String streamName = "fakeStreamName"; + final Time maxOutOfOrderness = Time.milliseconds(5); + final long autoWatermarkInterval = 1_000; + final long watermarkSyncInterval = autoWatermarkInterval + 1; + + HashMap subscribedStreamsToLastDiscoveredShardIds = new HashMap<>(); + subscribedStreamsToLastDiscoveredShardIds.put(streamName, null); + + final KinesisDeserializationSchema deserializationSchema = + new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()); + Properties props = new Properties(); + props.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1"); + props.setProperty(ConsumerConfigConstants.SHARD_GETRECORDS_INTERVAL_MILLIS, Long.toString(10L)); + props.setProperty(ConsumerConfigConstants.WATERMARK_SYNC_MILLIS, + Long.toString(watermarkSyncInterval)); + props.setProperty(ConsumerConfigConstants.WATERMARK_LOOKAHEAD_MILLIS, Long.toString(5)); + + BlockingQueue shard1 = new LinkedBlockingQueue(); + BlockingQueue shard2 = new LinkedBlockingQueue(); + + Map>> streamToQueueMap = new HashMap<>(); + streamToQueueMap.put(streamName, Lists.newArrayList(shard1, shard2)); + + // override createFetcher to mock Kinesis + FlinkKinesisConsumer sourceFunc = + new FlinkKinesisConsumer(streamName, deserializationSchema, props) { + @Override + protected KinesisDataFetcher createFetcher( + List streams, + SourceFunction.SourceContext sourceContext, + RuntimeContext runtimeContext, + Properties configProps, + KinesisDeserializationSchema deserializationSchema) { + + KinesisDataFetcher fetcher = + new KinesisDataFetcher( + streams, + sourceContext, + sourceContext.getCheckpointLock(), + runtimeContext, + configProps, + deserializationSchema, + getShardAssigner(), + getPeriodicWatermarkAssigner(), + getWatermarkTracker(), + new AtomicReference<>(), + new ArrayList<>(), + subscribedStreamsToLastDiscoveredShardIds, + (props) -> FakeKinesisBehavioursFactory.blockingQueueGetRecords( + streamToQueueMap) + ) {}; + return fetcher; + } + }; + + sourceFunc.setShardAssigner( + (streamShardHandle, i) -> { + // shardId-000000000000 + return Integer.parseInt( + streamShardHandle.getShard().getShardId().substring("shardId-".length())); + }); + + sourceFunc.setPeriodicWatermarkAssigner(new TestTimestampExtractor(maxOutOfOrderness)); + + sourceFunc.setWatermarkTracker(new TestWatermarkTracker()); + + // there is currently no test harness specifically for sources, + // so we overlay the source thread here + AbstractStreamOperatorTestHarness testHarness = + new AbstractStreamOperatorTestHarness( + new StreamSource(sourceFunc), 1, 1, 0); + testHarness.setTimeCharacteristic(TimeCharacteristic.EventTime); + testHarness.getExecutionConfig().setAutoWatermarkInterval(autoWatermarkInterval); + + testHarness.initializeEmptyState(); + testHarness.open(); + + final ConcurrentLinkedQueue results = testHarness.getOutput(); + + @SuppressWarnings("unchecked") + SourceFunction.SourceContext sourceContext = new CollectingSourceContext( + testHarness.getCheckpointLock(), results) { + @Override + public void markAsTemporarilyIdle() { + } + + @Override + public void emitWatermark(Watermark 
mark) { + results.add(mark); + } + }; + + new Thread( + () -> { + try { + sourceFunc.run(sourceContext); + } catch (InterruptedException e) { + // expected on cancel + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .start(); + + ArrayList expectedResults = new ArrayList<>(); + + final long record1 = 1; + shard1.put(Long.toString(record1)); + expectedResults.add(Long.toString(record1)); + awaitRecordCount(results, expectedResults.size()); + + // at this point we know the fetcher was initialized + final KinesisDataFetcher fetcher = org.powermock.reflect.Whitebox.getInternalState(sourceFunc, "fetcher"); + + // trigger watermark emit + testHarness.setProcessingTime(testHarness.getProcessingTime() + autoWatermarkInterval); + expectedResults.add(new Watermark(-4)); + // verify watermark + awaitRecordCount(results, expectedResults.size()); + assertThat(results, org.hamcrest.Matchers.contains(expectedResults.toArray())); + assertEquals(0, TestWatermarkTracker.WATERMARK.get()); + + // trigger sync + testHarness.setProcessingTime(testHarness.getProcessingTime() + 1); + TestWatermarkTracker.assertSingleWatermark(-4); + + final long record2 = record1 + (watermarkSyncInterval * 3) + 1; + shard1.put(Long.toString(record2)); + + // TODO: check for record received instead + Thread.sleep(100); + + // Advance the watermark. Since the new record is past global watermark + threshold, + // it won't be emitted and the watermark does not advance + testHarness.setProcessingTime(testHarness.getProcessingTime() + autoWatermarkInterval); + assertThat(results, org.hamcrest.Matchers.contains(expectedResults.toArray())); + assertEquals(3000L, (long) org.powermock.reflect.Whitebox.getInternalState(fetcher, "nextWatermark")); + TestWatermarkTracker.assertSingleWatermark(-4); + + // Trigger global watermark sync + testHarness.setProcessingTime(testHarness.getProcessingTime() + 1); + expectedResults.add(Long.toString(record2)); + awaitRecordCount(results, expectedResults.size()); + assertThat(results, org.hamcrest.Matchers.contains(expectedResults.toArray())); + TestWatermarkTracker.assertSingleWatermark(3000); + + // Trigger watermark update and emit + testHarness.setProcessingTime(testHarness.getProcessingTime() + autoWatermarkInterval); + expectedResults.add(new Watermark(3000)); + assertThat(results, org.hamcrest.Matchers.contains(expectedResults.toArray())); + + sourceFunc.cancel(); + testHarness.close(); + } + private void awaitRecordCount(ConcurrentLinkedQueue queue, int count) throws Exception { long timeoutMillis = System.currentTimeMillis() + 10_000; while (System.currentTimeMillis() < timeoutMillis && queue.size() < count) { @@ -837,4 +1003,23 @@ public long extractTimestamp(String element) { } } + private static class TestWatermarkTracker extends WatermarkTracker { + + private static final AtomicLong WATERMARK = new AtomicLong(); + + @Override + public long getUpdateTimeoutCount() { + return 0; + } + + @Override + public long updateWatermark(long localWatermark) { + WATERMARK.set(localWatermark); + return localWatermark; + } + + static void assertSingleWatermark(long expected) { + Assert.assertEquals(expected, WATERMARK.get()); + } + } } diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java index dbc71182b04430..93886f935246b3 100644 --- 
a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java @@ -113,9 +113,10 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedState() { KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")), Mockito.mock(KinesisProxyInterface.class)); + int shardIndex = fetcher.registerNewSubscribedShardState(subscribedShardsStateUnderTest.get(0)); new ShardConsumer<>( fetcher, - 0, + shardIndex, subscribedShardsStateUnderTest.get(0).getStreamShardHandle(), subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(), FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9, 500L), @@ -151,9 +152,10 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithUnexpectedExpired KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")), Mockito.mock(KinesisProxyInterface.class)); + int shardIndex = fetcher.registerNewSubscribedShardState(subscribedShardsStateUnderTest.get(0)); new ShardConsumer<>( fetcher, - 0, + shardIndex, subscribedShardsStateUnderTest.get(0).getStreamShardHandle(), subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(), // Get a total of 1000 records with 9 getRecords() calls, @@ -195,9 +197,10 @@ public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithAdaptiveReads() { KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")), Mockito.mock(KinesisProxyInterface.class)); + int shardIndex = fetcher.registerNewSubscribedShardState(subscribedShardsStateUnderTest.get(0)); new ShardConsumer<>( fetcher, - 0, + shardIndex, subscribedShardsStateUnderTest.get(0).getStreamShardHandle(), subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(), // Initial number of records to fetch --> 10 diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java index f1fd06903d4b7c..3bb11bd2c92c9a 100644 --- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java @@ -74,6 +74,7 @@ public TestableKinesisDataFetcher( deserializationSchema, DEFAULT_SHARD_ASSIGNER, null, + null, thrownErrorUnderTest, subscribedShardsStateUnderTest, subscribedStreamsToLastDiscoveredShardIdsStateUnderTest, diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTrackerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTrackerTest.java new file mode 100644 index 00000000000000..b793b541c125fe --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/JobManagerWatermarkTrackerTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +/** Test for {@link JobManagerWatermarkTracker}. */ +public class JobManagerWatermarkTrackerTest { + + private static MiniCluster flink; + + @BeforeClass + public static void setUp() throws Exception { + final Configuration config = new Configuration(); + config.setInteger(RestOptions.PORT, 0); + + final MiniClusterConfiguration miniClusterConfiguration = new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(1) + .setNumSlotsPerTaskManager(1) + .build(); + + flink = new MiniCluster(miniClusterConfiguration); + + flink.start(); + } + + @AfterClass + public static void tearDown() throws Exception { + if (flink != null) { + flink.close(); + } + } + + @Test + public void testUpateWatermark() throws Exception { + final Configuration clientConfiguration = new Configuration(); + clientConfiguration.setInteger(RestOptions.RETRY_MAX_ATTEMPTS, 0); + + final StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment( + flink.getRestAddress().get().getHost(), + flink.getRestAddress().get().getPort(), + clientConfiguration); + + env.addSource(new TestSourceFunction(new JobManagerWatermarkTracker("fakeId"))) + .addSink(new SinkFunction() {}); + env.execute(); + } + + private static class TestSourceFunction extends RichSourceFunction { + + private final JobManagerWatermarkTracker tracker; + + public TestSourceFunction(JobManagerWatermarkTracker tracker) { + this.tracker = tracker; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + tracker.open(getRuntimeContext()); + } + + @Override + public void run(SourceContext ctx) { + Assert.assertEquals(998, tracker.updateWatermark(998)); + Assert.assertEquals(999, tracker.updateWatermark(999)); + } + + @Override + public void cancel() { + } + } + +} diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitterTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitterTest.java new file mode 100644 index 00000000000000..1948237566e347 --- /dev/null 
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/RecordEmitterTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.streaming.runtime.operators.windowing.TimestampedValue; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** Test for {@link RecordEmitter}. */ +public class RecordEmitterTest { + + static List results = Collections.synchronizedList(new ArrayList<>()); + + private class TestRecordEmitter extends RecordEmitter { + + private TestRecordEmitter() { + super(DEFAULT_QUEUE_CAPACITY); + } + + @Override + public void emit(TimestampedValue record, RecordQueue queue) { + results.add(record); + } + } + + @Test + public void test() throws Exception { + + TestRecordEmitter emitter = new TestRecordEmitter(); + + final TimestampedValue one = new TimestampedValue<>("one", 1); + final TimestampedValue two = new TimestampedValue<>("two", 2); + final TimestampedValue five = new TimestampedValue<>("five", 5); + final TimestampedValue ten = new TimestampedValue<>("ten", 10); + + final RecordEmitter.RecordQueue queue0 = emitter.getQueue(0); + final RecordEmitter.RecordQueue queue1 = emitter.getQueue(1); + + queue0.put(one); + queue0.put(five); + queue0.put(ten); + + queue1.put(two); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.submit(emitter); + + long timeout = System.currentTimeMillis() + 10_000; + while (results.size() != 4 && System.currentTimeMillis() < timeout) { + Thread.sleep(100); + } + emitter.stop(); + executor.shutdownNow(); + + Assert.assertThat(results, Matchers.contains(one, five, two, ten)); + } + +} diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTrackerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTrackerTest.java new file mode 100644 index 00000000000000..3d59a45e9a6093 --- /dev/null +++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/util/WatermarkTrackerTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kinesis.util; + +import org.apache.flink.streaming.util.MockStreamingRuntimeContext; + +import org.apache.commons.lang3.mutable.MutableLong; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +/** Test for {@link WatermarkTracker}. */ +public class WatermarkTrackerTest { + + WatermarkTracker.WatermarkState wm1 = new WatermarkTracker.WatermarkState(); + MutableLong clock = new MutableLong(0); + + private class TestWatermarkTracker extends WatermarkTracker { + /** + * The watermarks of all sub tasks that participate in the synchronization. + */ + private final Map watermarks = new HashMap<>(); + + private long updateTimeoutCount = 0; + + @Override + protected long getCurrentTime() { + return clock.longValue(); + } + + @Override + public long updateWatermark(final long localWatermark) { + refreshWatermarkSnapshot(this.watermarks); + + long currentTime = getCurrentTime(); + String subtaskId = this.getSubtaskId(); + + WatermarkState ws = watermarks.get(subtaskId); + if (ws == null) { + watermarks.put(subtaskId, ws = new WatermarkState()); + } + ws.lastUpdated = currentTime; + ws.watermark = Math.max(ws.watermark, localWatermark); + saveWatermark(subtaskId, ws); + + long globalWatermark = ws.watermark; + for (Map.Entry e : watermarks.entrySet()) { + ws = e.getValue(); + if (ws.lastUpdated + getUpdateTimeoutMillis() < currentTime) { + // ignore outdated subtask + updateTimeoutCount++; + continue; + } + globalWatermark = Math.min(ws.watermark, globalWatermark); + } + return globalWatermark; + } + + protected void refreshWatermarkSnapshot(Map watermarks) { + watermarks.put("wm1", wm1); + } + + protected void saveWatermark(String id, WatermarkState ws) { + // do nothing + } + + public long getUpdateTimeoutCount() { + return updateTimeoutCount; + } + } + + @Test + public void test() { + long watermark = 0; + TestWatermarkTracker ws = new TestWatermarkTracker(); + ws.open(new MockStreamingRuntimeContext(false, 1, 0)); + Assert.assertEquals(Long.MIN_VALUE, ws.updateWatermark(Long.MIN_VALUE)); + Assert.assertEquals(Long.MIN_VALUE, ws.updateWatermark(watermark)); + // timeout wm1 + clock.add(WatermarkTracker.DEFAULT_UPDATE_TIMEOUT_MILLIS + 1); + Assert.assertEquals(watermark, ws.updateWatermark(watermark)); + Assert.assertEquals(watermark, ws.updateWatermark(watermark - 1)); + + // min watermark + wm1.watermark = watermark + 1; + wm1.lastUpdated = clock.longValue(); + Assert.assertEquals(watermark, ws.updateWatermark(watermark)); + Assert.assertEquals(watermark + 1, ws.updateWatermark(watermark + 1)); + } + +} From 370e0cb427c2908677d9f32a1fcfd84cd77bb445 Mon Sep 17 00:00:00 2001 From: godfrey he Date: Thu, 30 May 2019 15:36:51 +0800 Subject: [PATCH 32/92] [FLINK-12610][table-planner-blink] Introduce aggregate related planner rules, which includes: 1. 
AggregateCalcMergeRule, that recognizes Aggregate on top of a Calc and if possible aggregate through the calc or removes the calc 2. AggregateReduceGroupingRule, that reduces unless grouping columns 3. PruneAggregateCallRule, that removes unreferenced AggregateCall from Aggregate 4. FlinkAggregateRemoveRule, that is copied from Calcite's AggregateRemoveRule, and supports SUM, MIN, MAX, AUXILIARY_GROUP functions in non-empty group aggregate 5. FlinkAggregateJoinTransposeRule, that is copied from Calcite's AggregateJoinTransposeRule, and supports Left/Right outer join and aggregate with AUXILIARY_GROUP --- .../calcite/sql/SqlSplittableAggFunction.java | 374 +++++++ .../FlinkAggregateJoinTransposeRule.java | 593 +++++++++++ .../logical/FlinkAggregateRemoveRule.java | 131 +++ .../batch/BatchExecHashAggregate.scala | 8 +- .../BatchExecHashWindowAggregateBase.scala | 2 +- .../batch/BatchExecSortMergeJoin.scala | 2 +- .../table/plan/rules/FlinkBatchRuleSets.scala | 20 +- .../plan/rules/FlinkStreamRuleSets.scala | 20 +- .../logical/AggregateReduceGroupingRule.scala | 124 +++ .../FlinkAggregateJoinTransposeRule.scala | 73 -- .../logical/PruneAggregateCallRule.scala | 201 ++++ .../sql/agg/AggregateReduceGroupingTest.xml | 974 ++++++++++++++++++ .../join/BroadcastHashSemiAntiJoinTest.xml | 4 +- .../plan/batch/sql/join/LookupJoinTest.xml | 28 +- .../sql/join/NestedLoopSemiAntiJoinTest.xml | 30 +- .../plan/batch/sql/join/SemiAntiJoinTest.xml | 30 +- .../sql/join/ShuffledHashSemiAntiJoinTest.xml | 4 +- .../sql/join/SortMergeSemiAntiJoinTest.xml | 4 +- .../AggregateReduceGroupingRuleTest.xml | 909 ++++++++++++++++ .../CalcPruneAggregateCallRuleTest.xml | 377 +++++++ ...inkAggregateInnerJoinTransposeRuleTest.xml | 257 +++++ ...inkAggregateOuterJoinTransposeRuleTest.xml | 267 +++++ .../logical/FlinkAggregateRemoveRuleTest.xml | 528 ++++++++++ .../ProjectPruneAggregateCallRuleTest.xml | 379 +++++++ .../table/plan/stream/sql/join/JoinTest.xml | 99 +- .../plan/stream/sql/join/SemiAntiJoinTest.xml | 24 +- .../plan/batch/sql/RemoveCollationTest.scala | 20 +- .../plan/batch/sql/RemoveShuffleTest.scala | 6 +- .../sql/agg/AggregateReduceGroupingTest.scala | 24 + .../AggregateReduceGroupingTestBase.scala | 318 ++++++ .../AggregateReduceGroupingRuleTest.scala | 48 + .../CalcPruneAggregateCallRuleTest.scala | 55 + ...kAggregateInnerJoinTransposeRuleTest.scala | 150 +++ ...kAggregateOuterJoinTransposeRuleTest.scala | 124 +++ .../FlinkAggregateRemoveRuleTest.scala | 237 +++++ .../JoinDeriveNullFilterRuleTest.scala | 10 +- .../ProjectPruneAggregateCallRuleTest.scala | 51 + .../PruneAggregateCallRuleTestBase.scala | 175 ++++ .../table/runtime/batch/sql/CalcITCase.scala | 13 +- .../table/runtime/batch/sql/LimitITCase.scala | 2 +- .../runtime/batch/sql/OverWindowITCase.scala | 12 +- .../table/runtime/batch/sql/RankITCase.scala | 4 +- .../table/runtime/batch/sql/UnionITCase.scala | 7 +- .../agg/AggregateJoinTransposeITCase.scala | 205 ++++ .../agg/AggregateReduceGroupingITCase.scala | 405 ++++++++ .../batch/sql/agg/AggregateRemoveITCase.scala | 214 ++++ .../sql/agg/PruneAggregateCallITCase.scala | 132 +++ .../batch/sql/join/InnerJoinITCase.scala | 8 +- ...la => JoinConditionTypeCoerceITCase.scala} | 10 +- .../runtime/batch/sql/join/JoinITCase.scala | 18 +- .../batch/sql/join/OuterJoinITCase.scala | 6 +- .../runtime/stream/sql/AggregateITCase.scala | 9 +- .../stream/sql/AggregateRemoveITCase.scala | 254 +++++ .../table/runtime/stream/sql/CalcITCase.scala | 10 +- .../stream/sql/DeduplicateITCase.scala | 6 +- 
.../runtime/stream/sql/OverWindowITCase.scala | 17 +- .../stream/sql/PruneAggregateCallITCase.scala | 130 +++ .../stream/sql/TemporalSortITCase.scala | 2 +- .../utils/BatchScalaTableEnvUtil.scala | 20 +- .../runtime/utils/BatchTableEnvUtil.scala | 31 +- .../table/runtime/utils/BatchTestBase.scala | 20 +- .../runtime/utils/StreamTableEnvUtil.scala | 52 + .../table/runtime/utils/StreamTestData.scala | 109 -- .../flink/table/runtime/utils/TestData.scala | 146 +-- .../flink/table/util/TableTestBase.scala | 20 +- 65 files changed, 8024 insertions(+), 518 deletions(-) create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRule.java create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRule.scala delete mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRule.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.xml create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/common/AggregateReduceGroupingTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala create mode 100644 
flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRuleTestBase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateJoinTransposeITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateReduceGroupingITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateRemoveITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/PruneAggregateCallITCase.scala rename flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/{JoinConditionTypeCoerceRuleITCase.scala => JoinConditionTypeCoerceITCase.scala} (97%) create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateRemoveITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/PruneAggregateCallITCase.scala create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTableEnvUtil.scala delete mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestData.scala diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java new file mode 100644 index 00000000000000..a69a82d208e406 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.sql; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.mapping.Mappings; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; + +/** + * This file is copied from Calcite and made the following changes: + * 1. makeProperRexNodeForOuterJoin function added for CountSplitter and AbstractSumSplitter. + * 2. If the join type is left or right outer join then make the proper rexNode, or follow the previous logic. + * + * This copy can be removed once [CALCITE-2378] is fixed. + */ + +/** + * Aggregate function that can be split into partial aggregates. + * + *
<p>
For example, {@code COUNT(x)} can be split into {@code COUNT(x)} on + * subsets followed by {@code SUM} to combine those counts. + */ +public interface SqlSplittableAggFunction { + AggregateCall split(AggregateCall aggregateCall, + Mappings.TargetMapping mapping); + + /** Called to generate an aggregate for the other side of the join + * than the side aggregate call's arguments come from. Returns null if + * no aggregate is required. */ + AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e); + + /** Generates an aggregate call to merge sub-totals. + * + *
<p>
Most implementations will add a single aggregate call to + * {@code aggCalls}, and return a {@link RexInputRef} that points to it. + * + * @param rexBuilder Rex builder + * @param extra Place to define extra input expressions + * @param offset Offset due to grouping columns (and indicator columns if + * applicable) + * @param inputRowType Input row type + * @param aggregateCall Source aggregate call + * @param leftSubTotal Ordinal of the sub-total coming from the left side of + * the join, or -1 if there is no such sub-total + * @param rightSubTotal Ordinal of the sub-total coming from the right side + * of the join, or -1 if there is no such sub-total + * @param joinRelType the join type + * + * @return Aggregate call + */ + AggregateCall topSplit(RexBuilder rexBuilder, Registry extra, + int offset, RelDataType inputRowType, AggregateCall aggregateCall, + int leftSubTotal, int rightSubTotal, JoinRelType joinRelType); + + /** Generates an expression for the value of the aggregate function when + * applied to a single row. + * + *

<p>For example, if there is one row: + *

<ul>
+ * <li>{@code SUM(x)} is {@code x}
+ * <li>{@code MIN(x)} is {@code x}
+ * <li>{@code MAX(x)} is {@code x}
+ * <li>{@code COUNT(x)} is {@code CASE WHEN x IS NOT NULL THEN 1 ELSE 0 END 1}
+ * which can be simplified to {@code 1} if {@code x} is never null
+ * <li>{@code COUNT(*)} is 1
+ * </ul>
+ * + * @param rexBuilder Rex builder + * @param inputRowType Input row type + * @param aggregateCall Aggregate call + * + * @return Expression for single row + */ + RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, + AggregateCall aggregateCall); + + /** Collection in which one can register an element. Registering may return + * a reference to an existing element. + * + * @param element type */ + interface Registry { + int register(E e); + } + + /** Splitting strategy for {@code COUNT}. + * + *
<p>
COUNT splits into itself followed by SUM. (Actually + * SUM0, because the total needs to be 0, not null, if there are 0 rows.) + * This rule works for any number of arguments to COUNT, including COUNT(*). + */ + class CountSplitter implements SqlSplittableAggFunction { + public static final CountSplitter INSTANCE = new CountSplitter(); + + public AggregateCall split(AggregateCall aggregateCall, + Mappings.TargetMapping mapping) { + return aggregateCall.transform(mapping); + } + + public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) { + return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, + ImmutableIntList.of(), -1, + typeFactory.createSqlType(SqlTypeName.BIGINT), null); + } + + /** + * This new function create a proper RexNode for {@coide COUNT} Agg with OuterJoin Condition. + */ + private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder, + RelDataType inputRowType, + AggregateCall aggregateCall, + int index) { + RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index); + RexLiteral literal; + boolean isCountStar = aggregateCall.getArgList() == null || aggregateCall.getArgList().isEmpty(); + if (isCountStar) { + literal = rexBuilder.makeExactLiteral(BigDecimal.ONE); + } else { + literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO); + } + RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef); + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, + predicate, + literal, + rexBuilder.makeCast(aggregateCall.type, inputRef) + ); + } + + public AggregateCall topSplit(RexBuilder rexBuilder, + Registry extra, int offset, RelDataType inputRowType, + AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal, + JoinRelType joinRelType) { + final List merges = new ArrayList<>(); + if (leftSubTotal >= 0) { + // add support for right outer join + if (joinRelType == JoinRelType.RIGHT) { + merges.add( + makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal) + ); + } else { + // if it's a inner join, then do the previous logic + merges.add( + rexBuilder.makeInputRef(aggregateCall.type, leftSubTotal)); + } + } + if (rightSubTotal >= 0) { + // add support for left outer join + if (joinRelType == JoinRelType.LEFT) { + merges.add( + makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, rightSubTotal) + ); + } else { + // if it's a inner join, then do the previous logic + merges.add( + rexBuilder.makeInputRef(aggregateCall.type, rightSubTotal)); + } + } + RexNode node; + switch (merges.size()) { + case 1: + node = merges.get(0); + break; + case 2: + node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges); + break; + default: + throw new AssertionError("unexpected count " + merges); + } + int ordinal = extra.register(node); + return AggregateCall.create(SqlStdOperatorTable.SUM0, false, false, + ImmutableList.of(ordinal), -1, aggregateCall.type, + aggregateCall.name); + } + + /** + * {@inheritDoc} + * + *
<p>
{@code COUNT(*)}, and {@code COUNT} applied to all NOT NULL arguments, + * become {@code 1}; otherwise + * {@code CASE WHEN arg0 IS NOT NULL THEN 1 ELSE 0 END}. + */ + public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, + AggregateCall aggregateCall) { + final List predicates = new ArrayList<>(); + for (Integer arg : aggregateCall.getArgList()) { + final RelDataType type = inputRowType.getFieldList().get(arg).getType(); + if (type.isNullable()) { + predicates.add( + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, + rexBuilder.makeInputRef(type, arg))); + } + } + final RexNode predicate = + RexUtil.composeConjunction(rexBuilder, predicates, true); + if (predicate == null) { + return rexBuilder.makeExactLiteral(BigDecimal.ONE); + } else { + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate, + rexBuilder.makeExactLiteral(BigDecimal.ONE), + rexBuilder.makeExactLiteral(BigDecimal.ZERO)); + } + } + } + + /** Aggregate function that splits into two applications of itself. + * + *
<p>
Examples are MIN and MAX. */ + class SelfSplitter implements SqlSplittableAggFunction { + public static final SelfSplitter INSTANCE = new SelfSplitter(); + + public RexNode singleton(RexBuilder rexBuilder, + RelDataType inputRowType, AggregateCall aggregateCall) { + final int arg = aggregateCall.getArgList().get(0); + final RelDataTypeField field = inputRowType.getFieldList().get(arg); + return rexBuilder.makeInputRef(field.getType(), arg); + } + + public AggregateCall split(AggregateCall aggregateCall, + Mappings.TargetMapping mapping) { + return aggregateCall.transform(mapping); + } + + public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) { + return null; // no aggregate function required on other side + } + + public AggregateCall topSplit(RexBuilder rexBuilder, + Registry extra, int offset, RelDataType inputRowType, + AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal, + JoinRelType joinRelType) { + assert (leftSubTotal >= 0) != (rightSubTotal >= 0); + final int arg = leftSubTotal >= 0 ? leftSubTotal : rightSubTotal; + return aggregateCall.copy(ImmutableIntList.of(arg), -1); + } + } + + /** Common Splitting strategy for {@coide SUM} and {@coide SUM0}. */ + abstract class AbstractSumSplitter implements SqlSplittableAggFunction { + + public RexNode singleton(RexBuilder rexBuilder, + RelDataType inputRowType, AggregateCall aggregateCall) { + final int arg = aggregateCall.getArgList().get(0); + final RelDataTypeField field = inputRowType.getFieldList().get(arg); + return rexBuilder.makeInputRef(field.getType(), arg); + } + + public AggregateCall split(AggregateCall aggregateCall, + Mappings.TargetMapping mapping) { + return aggregateCall.transform(mapping); + } + + public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) { + return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, + ImmutableIntList.of(), -1, + typeFactory.createSqlType(SqlTypeName.BIGINT), null); + } + + /** + * This new function create a proper RexNode for {@coide SUM} Agg with OuterJoin Condition. 
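+ *
+ * <p>Rationale: on the null-generating side of a LEFT or RIGHT outer join the pushed-down
+ * partial aggregate yields NULL for unmatched rows; the CASE built here maps such a NULL to 0
+ * and otherwise casts the sub-total to the aggregate call's type, so that merging the partials
+ * with SUM0 still matches the result of aggregating the original join output.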
+ */ + private RexNode makeProperRexNodeForOuterJoin(RexBuilder rexBuilder, + RelDataType inputRowType, + AggregateCall aggregateCall, + int index) { + RexInputRef inputRef = rexBuilder.makeInputRef(inputRowType.getFieldList().get(index).getType(), index); + RexLiteral literal = rexBuilder.makeExactLiteral(BigDecimal.ZERO); + RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, inputRef); + return rexBuilder.makeCall(SqlStdOperatorTable.CASE, + predicate, + literal, + rexBuilder.makeCast(aggregateCall.type, inputRef) + ); + } + + public AggregateCall topSplit(RexBuilder rexBuilder, + Registry extra, int offset, RelDataType inputRowType, + AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal, + JoinRelType joinRelType) { + final List merges = new ArrayList<>(); + final List fieldList = inputRowType.getFieldList(); + if (leftSubTotal >= 0) { + // add support for left outer join + if (joinRelType == JoinRelType.RIGHT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) { + merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, leftSubTotal)); + } else { + // if it's a inner join, then do the previous logic + final RelDataType type = fieldList.get(leftSubTotal).getType(); + merges.add(rexBuilder.makeInputRef(type, leftSubTotal)); + } + } + if (rightSubTotal >= 0) { + // add support for right outer join + if (joinRelType == JoinRelType.LEFT && getMergeAggFunctionOfTopSplit() == SqlStdOperatorTable.SUM0) { + merges.add(makeProperRexNodeForOuterJoin(rexBuilder, inputRowType, aggregateCall, offset + rightSubTotal)); + } else { + // if it's a inner join, then do the previous logic + final RelDataType type = fieldList.get(rightSubTotal).getType(); + merges.add(rexBuilder.makeInputRef(type, rightSubTotal)); + } + } + RexNode node; + switch (merges.size()) { + case 1: + node = merges.get(0); + break; + case 2: + node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges); + node = rexBuilder.makeAbstractCast(aggregateCall.type, node); + break; + default: + throw new AssertionError("unexpected count " + merges); + } + int ordinal = extra.register(node); + return AggregateCall.create(getMergeAggFunctionOfTopSplit(), false, false, + ImmutableList.of(ordinal), -1, aggregateCall.type, + aggregateCall.name); + } + + protected abstract SqlAggFunction getMergeAggFunctionOfTopSplit(); + + } + + /** Splitting strategy for {@coide SUM}. */ + class SumSplitter extends AbstractSumSplitter { + + public static final SumSplitter INSTANCE = new SumSplitter(); + + @Override public SqlAggFunction getMergeAggFunctionOfTopSplit() { + return SqlStdOperatorTable.SUM; + } + + } + + /** Splitting strategy for {@code SUM0}. 
*/ + class Sum0Splitter extends AbstractSumSplitter { + + public static final Sum0Splitter INSTANCE = new Sum0Splitter(); + + @Override public SqlAggFunction getMergeAggFunctionOfTopSplit() { + return SqlStdOperatorTable.SUM0; + } + } +} + +// End SqlSplittableAggFunction.java diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java new file mode 100644 index 00000000000000..10c0b940eb53e5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical; + +import org.apache.flink.table.plan.util.AggregateUtil; +import org.apache.flink.util.Preconditions; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.plan.volcano.RelSubset; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalSnapshot; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlSplittableAggFunction; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.Bug; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; +import org.apache.calcite.util.mapping.Mapping; +import org.apache.calcite.util.mapping.Mappings; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import scala.Tuple2; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +/** + * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}. + * Modification: + * - Do not match temporal join since lookup table source doesn't support aggregate. + * - Support Left/Right Outer Join + * - Fix type mismatch error + * - Support aggregate with AUXILIARY_GROUP + */ + +/** + * Planner rule that pushes an + * {@link org.apache.calcite.rel.core.Aggregate} + * past a {@link org.apache.calcite.rel.core.Join}. + */ +public class FlinkAggregateJoinTransposeRule extends RelOptRule { + public static final FlinkAggregateJoinTransposeRule INSTANCE = + new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, + RelFactories.LOGICAL_BUILDER, false, false); + + /** Extended instance of the rule that can push down aggregate functions. */ + public static final FlinkAggregateJoinTransposeRule EXTENDED = + new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, + RelFactories.LOGICAL_BUILDER, true, false); + + public static final FlinkAggregateJoinTransposeRule LEFT_RIGHT_OUTER_JOIN_EXTENDED = + new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, + RelFactories.LOGICAL_BUILDER, true, true); + + private final boolean allowFunctions; + + private final boolean allowLeftOrRightOuterJoin; + + /** Creates an FlinkAggregateJoinTransposeRule. */ + public FlinkAggregateJoinTransposeRule(Class aggregateClass, + Class joinClass, RelBuilderFactory relBuilderFactory, + boolean allowFunctions, boolean allowLeftOrRightOuterJoin) { + super( + operandJ(aggregateClass, null, + aggregate -> aggregate.getGroupType() == Aggregate.Group.SIMPLE, + operand(joinClass, any())), + relBuilderFactory, null); + + this.allowFunctions = allowFunctions; + this.allowLeftOrRightOuterJoin = allowLeftOrRightOuterJoin; + } + + @Deprecated // to be removed before 2.0 + public FlinkAggregateJoinTransposeRule(Class aggregateClass, + RelFactories.AggregateFactory aggregateFactory, + Class joinClass, + RelFactories.JoinFactory joinFactory) { + this(aggregateClass, joinClass, + RelBuilder.proto(aggregateFactory, joinFactory), false, false); + } + + @Deprecated // to be removed before 2.0 + public FlinkAggregateJoinTransposeRule(Class aggregateClass, + RelFactories.AggregateFactory aggregateFactory, + Class joinClass, + RelFactories.JoinFactory joinFactory, + boolean allowFunctions) { + this(aggregateClass, joinClass, + RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions, false); + } + + @Deprecated // to be removed before 2.0 + public FlinkAggregateJoinTransposeRule(Class aggregateClass, + RelFactories.AggregateFactory aggregateFactory, + Class joinClass, + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory) { + this(aggregateClass, joinClass, + RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false, false); + } + + @Deprecated // to be removed before 2.0 + public FlinkAggregateJoinTransposeRule(Class aggregateClass, + RelFactories.AggregateFactory aggregateFactory, + Class joinClass, + RelFactories.JoinFactory joinFactory, + RelFactories.ProjectFactory projectFactory, + boolean allowFunctions) { + this(aggregateClass, joinClass, + RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), + allowFunctions, false); + } + + private boolean containsSnapshot(RelNode relNode) { + RelNode 
original = null; + if (relNode instanceof RelSubset) { + original = ((RelSubset) relNode).getOriginal(); + } else if (relNode instanceof HepRelVertex) { + original = ((HepRelVertex) relNode).getCurrentRel(); + } else { + original = relNode; + } + if (original instanceof LogicalSnapshot) { + return true; + } else if (original instanceof SingleRel) { + return containsSnapshot(((SingleRel) original).getInput()); + } else { + return false; + } + } + + @Override + public boolean matches(RelOptRuleCall call) { + // avoid push aggregates through dim join + Join join = call.rel(1); + RelNode right = join.getRight(); + // right tree should not contain temporal table + return !containsSnapshot(right); + } + + public void onMatch(RelOptRuleCall call) { + final Aggregate origAgg = call.rel(0); + final Join join = call.rel(1); + final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder(); + final RelBuilder relBuilder = call.builder(); + + boolean isLeftOrRightOuterJoin = + join.getJoinType() == JoinRelType.LEFT || join.getJoinType() == JoinRelType.RIGHT; + + if (join.getJoinType() != JoinRelType.INNER && !(allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin)) { + return; + } + + // converts an aggregate with AUXILIARY_GROUP to a regular aggregate. + // if the converted aggregate can be push down, + // AggregateReduceGroupingRule will try reduce grouping of new aggregates created by this rule + final Pair> newAggAndProject = toRegularAggregate(origAgg); + final Aggregate aggregate = newAggAndProject.left; + final List projectAfterAgg = newAggAndProject.right; + + // If any aggregate functions do not support splitting, bail out + // If any aggregate call has a filter or distinct, bail out + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) + == null) { + return; + } + if (allowLeftOrRightOuterJoin && isLeftOrRightOuterJoin) { + // todo do not support max/min agg until we've built the proper model + if (aggregateCall.getAggregation().kind == SqlKind.MAX || + aggregateCall.getAggregation().kind == SqlKind.MIN) { + return; + } + } + if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) { + return; + } + } + + if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) { + return; + } + + // Do the columns used by the join appear in the output of the aggregate? 
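Before the key-column check that follows, here is a standalone Scala sketch of the closure step it relies on (mirroring the keyColumns/populateEquivalences helpers defined later in this file). The integer column indices and the (x, y) equality pairs are made-up stand-ins for the RexNode predicates and ImmutableBitSet used by the real code:

object KeyColumnsSketch {
  // Each pair (x, y) stands for an "x = y" predicate pulled up from the join condition.
  def keyColumns(aggregateColumns: Set[Int], equalities: Seq[(Int, Int)]): Set[Int] = {
    // symmetric equivalence map: column -> columns it is equated with
    val equivalence: Map[Int, Set[Int]] =
      equalities
        .flatMap { case (x, y) => Seq(x -> y, y -> x) }
        .groupBy(_._1)
        .map { case (col, pairs) => col -> pairs.map(_._2).toSet }
    // widen every grouping column by its equivalents (a single pass, as in the rule)
    aggregateColumns.foldLeft(aggregateColumns) { (acc, col) =>
      acc ++ equivalence.getOrElse(col, Set.empty)
    }
  }

  def main(args: Array[String]): Unit = {
    // hypothetical shape: GROUP BY column 0, join condition "column 0 = column 3"
    println(keyColumns(Set(0), Seq(0 -> 3))) // Set(0, 3): the join key counts as covered too
  }
}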
+ final ImmutableBitSet aggregateColumns = aggregate.getGroupSet(); + final RelMetadataQuery mq = call.getMetadataQuery(); + ImmutableBitSet keyColumns; + if (!isLeftOrRightOuterJoin) { + keyColumns = keyColumns(aggregateColumns, + mq.getPulledUpPredicates(join).pulledUpPredicates); + } else { + // this is an incomplete implementation + if (isAggregateKeyApplicable(aggregateColumns, join)) { + keyColumns = keyColumns(aggregateColumns, + com.google.common.collect.ImmutableList.copyOf(RelOptUtil.conjunctions(join.getCondition()))); + } else { + keyColumns = aggregateColumns; + } + } + final ImmutableBitSet joinColumns = + RelOptUtil.InputFinder.bits(join.getCondition()); + final boolean allColumnsInAggregate = + keyColumns.contains(joinColumns); + final ImmutableBitSet belowAggregateColumns = + aggregateColumns.union(joinColumns); + + // Split join condition + final List leftKeys = com.google.common.collect.Lists.newArrayList(); + final List rightKeys = com.google.common.collect.Lists.newArrayList(); + final List filterNulls = com.google.common.collect.Lists.newArrayList(); + RexNode nonEquiConj = + RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), + join.getCondition(), leftKeys, rightKeys, filterNulls); + // If it contains non-equi join conditions, we bail out + if (!nonEquiConj.isAlwaysTrue()) { + return; + } + + // Push each aggregate function down to each side that contains all of its + // arguments. Note that COUNT(*), because it has no arguments, can go to + // both sides. + final Map map = new HashMap<>(); + final List sides = new ArrayList<>(); + int uniqueCount = 0; + int offset = 0; + int belowOffset = 0; + for (int s = 0; s < 2; s++) { + final Side side = new Side(); + final RelNode joinInput = join.getInput(s); + int fieldCount = joinInput.getRowType().getFieldCount(); + final ImmutableBitSet fieldSet = + ImmutableBitSet.range(offset, offset + fieldCount); + final ImmutableBitSet belowAggregateKeyNotShifted = + belowAggregateColumns.intersect(fieldSet); + for (Ord c : Ord.zip(belowAggregateKeyNotShifted)) { + map.put(c.e, belowOffset + c.i); + } + final Mappings.TargetMapping mapping = + s == 0 + ? Mappings.createIdentity(fieldCount) + : Mappings.createShiftMapping(fieldCount + offset, 0, offset, + fieldCount); + + final ImmutableBitSet belowAggregateKey = + belowAggregateKeyNotShifted.shift(-offset); + final boolean unique; + if (!allowFunctions) { + assert aggregate.getAggCallList().isEmpty(); + // If there are no functions, it doesn't matter as much whether we + // aggregate the inputs before the join, because there will not be + // any functions experiencing a cartesian product effect. + // + // But finding out whether the input is already unique requires a call + // to areColumnsUnique that currently (until [CALCITE-1048] "Make + // metadata more robust" is fixed) places a heavy load on + // the metadata system. + // + // So we choose to imagine the the input is already unique, which is + // untrue but harmless. 
+ // + Util.discard(Bug.CALCITE_1048_FIXED); + unique = true; + } else { + final Boolean unique0 = + mq.areColumnsUnique(joinInput, belowAggregateKey); + unique = unique0 != null && unique0; + } + if (unique) { + ++uniqueCount; + side.aggregate = false; + relBuilder.push(joinInput); + final Map belowAggregateKeyToNewProjectMap = new HashMap<>(); + final List projects = new ArrayList<>(); + for (Integer i : belowAggregateKey) { + belowAggregateKeyToNewProjectMap.put(i, projects.size()); + projects.add(relBuilder.field(i)); + } + for (Ord aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + if (!aggCall.e.getArgList().isEmpty() + && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) { + final RexNode singleton = splitter.singleton(rexBuilder, + joinInput.getRowType(), aggCall.e.transform(mapping)); + final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false); + + if (targetSingleton instanceof RexInputRef) { + final int index = ((RexInputRef) targetSingleton).getIndex(); + if (!belowAggregateKey.get(index)) { + projects.add(targetSingleton); + side.split.put(aggCall.i, projects.size() - 1); + } else { + side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index)); + } + } else { + projects.add(targetSingleton); + side.split.put(aggCall.i, projects.size() - 1); + } + } + } + relBuilder.project(projects); + side.newInput = relBuilder.build(); + } else { + side.aggregate = true; + List belowAggCalls = new ArrayList<>(); + final SqlSplittableAggFunction.Registry + belowAggCallRegistry = registry(belowAggCalls); + final int oldGroupKeyCount = aggregate.getGroupCount(); + final int newGroupKeyCount = belowAggregateKey.cardinality(); + for (Ord aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + final AggregateCall call1; + if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) { + final AggregateCall splitCall = splitter.split(aggCall.e, mapping); + call1 = splitCall.adaptTo( + joinInput, splitCall.getArgList(), splitCall.filterArg, + oldGroupKeyCount, newGroupKeyCount); + } else { + call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e); + } + if (call1 != null) { + side.split.put(aggCall.i, + belowAggregateKey.cardinality() + + belowAggCallRegistry.register(call1)); + } + } + side.newInput = relBuilder.push(joinInput) + .aggregate(relBuilder.groupKey(belowAggregateKey, null), + belowAggCalls) + .build(); + } + offset += fieldCount; + belowOffset += side.newInput.getRowType().getFieldCount(); + sides.add(side); + } + + if (uniqueCount == 2) { + // Both inputs to the join are unique. There is nothing to be gained by + // this rule. In fact, this aggregate+join may be the result of a previous + // invocation of this rule; if we continue we might loop forever. 
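The "Aggregate above to sum up the sub-totals" step that follows relies on the modified topSplit shown earlier: when both join inputs register a sub-total the two are multiplied, and on the null-generating side of a LEFT/RIGHT outer join a missing match is first coalesced to 0 by the generated CASE expression. A minimal standalone Scala sketch of that arithmetic, assuming SUM0 semantics and Long-typed sub-totals (both assumptions are mine, for illustration only):

object TopSplitMergeSketch {
  // left/right are the per-side SUM0 sub-totals once the aggregate has been pushed
  // below the join; None models the NULL produced for an unmatched outer-join row.
  def mergeSubTotals(left: Option[Long], right: Option[Long]): Long =
    left.getOrElse(0L) * right.getOrElse(0L) // CASE WHEN x IS NULL THEN 0 ELSE x END, then MULTIPLY

  def main(args: Array[String]): Unit = {
    println(mergeSubTotals(Some(7L), Some(3L))) // 21: both sides matched
    println(mergeSubTotals(Some(7L), None))     // 0: the other side had no matching rows
  }
}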
+ return; + } + + // Update condition + final Mapping mapping = (Mapping) Mappings.target( + map::get, + join.getRowType().getFieldCount(), + belowOffset); + final RexNode newCondition = + RexUtil.apply(mapping, join.getCondition()); + + // Create new join + relBuilder.push(sides.get(0).newInput) + .push(sides.get(1).newInput) + .join(join.getJoinType(), newCondition); + + // Aggregate above to sum up the sub-totals + final List newAggCalls = new ArrayList<>(); + final int groupIndicatorCount = + aggregate.getGroupCount() + aggregate.getIndicatorCount(); + final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount(); + final List projects = + new ArrayList<>( + rexBuilder.identityProjects(relBuilder.peek().getRowType())); + for (Ord aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + final Integer leftSubTotal = sides.get(0).split.get(aggCall.i); + final Integer rightSubTotal = sides.get(1).split.get(aggCall.i); + newAggCalls.add( + splitter.topSplit(rexBuilder, registry(projects), + groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, + leftSubTotal == null ? -1 : leftSubTotal, + rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth, join.getJoinType())); + } + + relBuilder.project(projects); + + boolean aggConvertedToProjects = false; + if (allColumnsInAggregate) { + // let's see if we can convert aggregate into projects + List projects2 = new ArrayList<>(); + for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) { + projects2.add(relBuilder.field(key)); + } + int aggCallIdx = projects2.size(); + for (AggregateCall newAggCall : newAggCalls) { + final SqlSplittableAggFunction splitter = + newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class); + if (splitter != null) { + final RelDataType rowType = relBuilder.peek().getRowType(); + final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall); + final RelDataType originalAggCallType = + aggregate.getRowType().getFieldList().get(aggCallIdx).getType(); + final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false); + projects2.add(targetSingleton); + } + aggCallIdx += 1; + } + if (projects2.size() + == aggregate.getGroupSet().cardinality() + newAggCalls.size()) { + // We successfully converted agg calls into projects. + relBuilder.project(projects2); + aggConvertedToProjects = true; + } + } + + if (!aggConvertedToProjects) { + relBuilder.aggregate( + relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), + Mappings.apply2(mapping, aggregate.getGroupSets())), + newAggCalls); + } + if (projectAfterAgg != null) { + relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames()); + } + + call.transformTo(relBuilder.build()); + } + + /** + * Convert aggregate with AUXILIARY_GROUP to regular aggregate. + * Return original aggregate and null project if the given aggregate does not contain AUXILIARY_GROUP, + * else new aggregate without AUXILIARY_GROUP and a project to permute output columns if needed. 
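toRegularAggregate below regroups on the sorted full group set, so a trailing project has to restore the original output order once the AUXILIARY_GROUP columns are folded back into the grouping. A standalone Scala sketch of that index permutation with made-up column numbers (the real code builds RexInputRefs carrying the proper field types):

object RegularAggregatePermutationSketch {
  // fullGroupSet lists the group columns in their original output order; the
  // rewritten aggregate emits them in sorted (bit-set) order instead.
  def permutationProject(fullGroupSet: Seq[Int], outputFieldCount: Int): Seq[Int] = {
    val sortedGroupSet = fullGroupSet.sorted                     // order used by the new aggregate
    val groupPart = fullGroupSet.map(sortedGroupSet.indexOf(_))  // restore the original order
    val aggCallPart = fullGroupSet.length until outputFieldCount // aggregate calls keep their slots
    groupPart ++ aggCallPart
  }

  def main(args: Array[String]): Unit = {
    // hypothetical: original grouping order (col3, col1) plus one regular aggregate call
    println(permutationProject(Seq(3, 1), outputFieldCount = 3)) // List(1, 0, 2)
  }
}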
+ */ + private Pair> toRegularAggregate(Aggregate aggregate) { + Tuple2> auxGroupAndRegularAggCalls = AggregateUtil.checkAndSplitAggCalls(aggregate); + final int[] auxGroup = auxGroupAndRegularAggCalls._1; + final Seq regularAggCalls = auxGroupAndRegularAggCalls._2; + if (auxGroup.length != 0) { + int[] fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(aggregate); + ImmutableBitSet newGroupSet = ImmutableBitSet.of(fullGroupSet); + List aggCalls = JavaConverters.seqAsJavaListConverter(regularAggCalls).asJava(); + final Aggregate newAgg = aggregate.copy( + aggregate.getTraitSet(), + aggregate.getInput(), + aggregate.indicator, + newGroupSet, + com.google.common.collect.ImmutableList.of(newGroupSet), + aggCalls); + final List aggFields = aggregate.getRowType().getFieldList(); + final List projectAfterAgg = new ArrayList<>(); + for (int i = 0; i < fullGroupSet.length; ++i) { + int group = fullGroupSet[i]; + int index = newGroupSet.indexOf(group); + projectAfterAgg.add(new RexInputRef(index, aggFields.get(i).getType())); + } + int fieldCntOfAgg = aggFields.size(); + for (int i = fullGroupSet.length; i < fieldCntOfAgg; ++i) { + projectAfterAgg.add(new RexInputRef(i, aggFields.get(i).getType())); + } + Preconditions.checkArgument(projectAfterAgg.size() == fieldCntOfAgg); + return new Pair<>(newAgg, projectAfterAgg); + } else { + return new Pair<>(aggregate, null); + } + } + + /** + * Computes the closure of a set of columns according to a given list of + * constraints. Each 'x = y' constraint causes bit y to be set if bit x is + * set, and vice versa. + */ + private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, + com.google.common.collect.ImmutableList predicates) { + SortedMap equivalence = new TreeMap<>(); + for (RexNode predicate : predicates) { + populateEquivalences(equivalence, predicate); + } + ImmutableBitSet keyColumns = aggregateColumns; + for (Integer aggregateColumn : aggregateColumns) { + final BitSet bitSet = equivalence.get(aggregateColumn); + if (bitSet != null) { + keyColumns = keyColumns.union(bitSet); + } + } + return keyColumns; + } + + private static void populateEquivalences(Map equivalence, + RexNode predicate) { + switch (predicate.getKind()) { + case EQUALS: + RexCall call = (RexCall) predicate; + final List operands = call.getOperands(); + if (operands.get(0) instanceof RexInputRef) { + final RexInputRef ref0 = (RexInputRef) operands.get(0); + if (operands.get(1) instanceof RexInputRef) { + final RexInputRef ref1 = (RexInputRef) operands.get(1); + populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex()); + populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex()); + } + } + } + } + + private static boolean isAggregateKeyApplicable(ImmutableBitSet aggregateKeys, Join join) { + JoinInfo joinInfo = join.analyzeCondition(); + return (join.getJoinType() == JoinRelType.LEFT && joinInfo.leftSet().contains(aggregateKeys)) || + (join.getJoinType() == JoinRelType.RIGHT && + joinInfo.rightSet().shift(join.getInput(0).getRowType().getFieldCount()) + .contains(aggregateKeys)); + } + + private static void populateEquivalence(Map equivalence, + int i0, int i1) { + BitSet bitSet = equivalence.get(i0); + if (bitSet == null) { + bitSet = new BitSet(); + equivalence.put(i0, bitSet); + } + bitSet.set(i1); + } + + /** + * Creates a {@link org.apache.calcite.sql.SqlSplittableAggFunction.Registry} + * that is a view of a list. 
+ */ + private static SqlSplittableAggFunction.Registry registry( + final List list) { + return new SqlSplittableAggFunction.Registry() { + public int register(E e) { + int i = list.indexOf(e); + if (i < 0) { + i = list.size(); + list.add(e); + } + return i; + } + }; + } + + /** Work space for an input to a join. */ + private static class Side { + final Map split = new HashMap<>(); + RelNode newInput; + boolean aggregate; + } +} + +// End FlinkAggregateJoinTransposeRule.java diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRule.java new file mode 100644 index 00000000000000..7394a0919ae611 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRule.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical; + +import org.apache.flink.table.functions.sql.internal.SqlAuxiliaryGroupAggFunction; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.runtime.SqlFunctions; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; + +import java.util.ArrayList; +import java.util.List; + +/** + * This rule is copied from Calcite's {@link org.apache.calcite.rel.rules.AggregateRemoveRule}. + * Modification: + * - only matches aggregate with SIMPLE group and non-empty group + * - supports SUM, MIN, MAX, AUXILIARY_GROUP aggregate functions with no filterArgs + */ + +/** + * Planner rule that removes + * a {@link org.apache.calcite.rel.core.Aggregate} + * if its aggregate functions are SUM, MIN, MAX, AUXILIARY_GROUP with no filterArgs, + * and the underlying relational expression is already distinct.
+ */ +public class FlinkAggregateRemoveRule extends RelOptRule { + public static final FlinkAggregateRemoveRule INSTANCE = + new FlinkAggregateRemoveRule(LogicalAggregate.class, + RelFactories.LOGICAL_BUILDER); + + //~ Constructors ----------------------------------------------------------- + + @Deprecated // to be removed before 2.0 + public FlinkAggregateRemoveRule(Class aggregateClass) { + this(aggregateClass, RelFactories.LOGICAL_BUILDER); + } + + /** + * Creates an FlinkAggregateRemoveRule. + */ + public FlinkAggregateRemoveRule(Class aggregateClass, + RelBuilderFactory relBuilderFactory) { + // REVIEW jvs 14-Mar-2006: We have to explicitly mention the child here + // to make sure the rule re-fires after the child changes (e.g. via + // ProjectRemoveRule), since that may change our information + // about whether the child is distinct. If we clean up the inference of + // distinct to make it correct up-front, we can get rid of the reference + // to the child here. + super( + operand(aggregateClass, + operand(RelNode.class, any())), + relBuilderFactory, null); + } + + @Override + public boolean matches(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final RelNode input = call.rel(1); + if (aggregate.getGroupCount() == 0 || aggregate.indicator || + aggregate.getGroupType() != Aggregate.Group.SIMPLE) { + return false; + } + for (AggregateCall aggCall : aggregate.getAggCallList()) { + SqlKind aggCallKind = aggCall.getAggregation().getKind(); + // TODO supports more AggregateCalls + boolean isAllowAggCall = aggCallKind == SqlKind.SUM || + aggCallKind == SqlKind.MIN || + aggCallKind == SqlKind.MAX || + aggCall.getAggregation() instanceof SqlAuxiliaryGroupAggFunction; + if (!isAllowAggCall || aggCall.filterArg >= 0 || aggCall.getArgList().size() != 1) { + return false; + } + } + + final RelMetadataQuery mq = call.getMetadataQuery(); + return SqlFunctions.isTrue(mq.areColumnsUnique(input, aggregate.getGroupSet())); + } + + //~ Methods ---------------------------------------------------------------- + + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final RelNode input = call.rel(1); + + // Distinct is "GROUP BY c1, c2" (where c1, c2 are a set of columns on + // which the input is unique, i.e. contain a key) and has no aggregate + // functions or the functions we enumerated. It can be removed. + final RelNode newInput = convert(input, aggregate.getTraitSet().simplify()); + + // If aggregate was projecting a subset of columns, add a project for the + // same effect. 
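The matches() check above reduces to a handful of structural conditions; the following self-contained Scala restatement is only illustrative (the case-class shape and string-encoded kinds are mine, and the real rule additionally asks the metadata query whether the grouping columns are unique on the input):

object AggregateRemoveCheckSketch {
  final case class AggCallInfo(kind: String, argCount: Int, hasFilter: Boolean)

  // The aggregate can be dropped when it groups on at least one column, the grouping
  // columns are already unique on the input, and every call is a single-argument
  // SUM/MIN/MAX or AUXILIARY_GROUP without a filter.
  def isRemovable(groupCount: Int, groupKeysUnique: Boolean, calls: Seq[AggCallInfo]): Boolean =
    groupCount > 0 && groupKeysUnique && calls.forall { c =>
      Set("SUM", "MIN", "MAX", "AUXILIARY_GROUP").contains(c.kind) &&
        c.argCount == 1 && !c.hasFilter
    }

  def main(args: Array[String]): Unit = {
    println(isRemovable(1, groupKeysUnique = true, Seq(AggCallInfo("MAX", 1, hasFilter = false))))   // true
    println(isRemovable(1, groupKeysUnique = true, Seq(AggCallInfo("COUNT", 1, hasFilter = false)))) // false: not supported yet
  }
}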
+ final RelBuilder relBuilder = call.builder(); + relBuilder.push(newInput); + List projectIndices = new ArrayList<>(aggregate.getGroupSet().asList()); + for (AggregateCall aggCall : aggregate.getAggCallList()) { + projectIndices.addAll(aggCall.getArgList()); + } + relBuilder.project(relBuilder.fields(projectIndices)); + // Create a project if some of the columns have become + // NOT NULL due to aggregate functions are removed + relBuilder.convert(aggregate.getRowType(), true); + call.transformTo(relBuilder.build()); + } +} + +// End FlinkAggregateRemoveRule.java diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala index 2827987790cfb9..3793ed35b4eb03 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashAggregate.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.plan.nodes.physical.batch import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation -import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig} +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfig, TableConfigOptions} import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} @@ -156,6 +156,10 @@ class BatchExecHashAggregate( } override def getParallelism(input: StreamTransformation[BaseRow], conf: TableConfig): Int = { - if (isFinal && grouping.length == 0) 1 else input.getParallelism + if (isFinal && grouping.length == 0) { + 1 + } else { + conf.getConf.getInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM) + } } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala index dcd90264431eb2..e450dff2478745 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecHashWindowAggregateBase.scala @@ -147,6 +147,6 @@ abstract class BatchExecHashWindowAggregateBase( getOperatorName, operator, outputType.toTypeInfo, - input.getParallelism) + tableEnv.getConfig.getConf.getInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM)) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala index 42d248b0e507cf..506e3be27311cc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSortMergeJoin.scala @@ -271,7 +271,7 @@ class BatchExecSortMergeJoin( getOperatorName, operator, FlinkTypeFactory.toInternalRowType(getRowType).toTypeInfo, - leftInput.getParallelism) + tableEnv.getConfig.getConf.getInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM)) } private def estimateOutputSize(relNode: RelNode): Double = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala index 1401605d0d5703..9852dbdfdaa061 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala @@ -213,11 +213,10 @@ object FlinkBatchRuleSets { * This RuleSet is a sub-set of [[LOGICAL_OPT_RULES]]. */ private val LOGICAL_RULES: RuleSet = RuleSets.ofList( - // aggregation and projection rules - AggregateProjectMergeRule.INSTANCE, - AggregateProjectPullUpConstantsRule.INSTANCE, // reorder sort and projection SortProjectTransposeRule.INSTANCE, + // remove unnecessary sort rule + SortRemoveRule.INSTANCE, // join rules FlinkJoinPushExpressionsRule.INSTANCE, @@ -227,8 +226,12 @@ object FlinkBatchRuleSets { // convert non-all union into all-union + distinct UnionToDistinctRule.INSTANCE, + // aggregation and projection rules + AggregateProjectMergeRule.INSTANCE, + AggregateProjectPullUpConstantsRule.INSTANCE, + // remove aggregation if it does not aggregate and input is already distinct - AggregateRemoveRule.INSTANCE, + FlinkAggregateRemoveRule.INSTANCE, // push aggregate through join FlinkAggregateJoinTransposeRule.EXTENDED, // aggregate union rule @@ -240,12 +243,15 @@ object FlinkBatchRuleSets { AggregateReduceFunctionsRule.INSTANCE, WindowAggregateReduceFunctionsRule.INSTANCE, + // reduce group by columns + AggregateReduceGroupingRule.INSTANCE, + // reduce useless aggCall + PruneAggregateCallRule.PROJECT_ON_AGGREGATE, + PruneAggregateCallRule.CALC_ON_AGGREGATE, + // expand grouping sets DecomposeGroupingSetsRule.INSTANCE, - // remove unnecessary sort rule - SortRemoveRule.INSTANCE, - // rank rules FlinkLogicalRankRule.CONSTANT_RANGE_INSTANCE, // transpose calc past rank to reduce rank input fields diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala index 419306ec9534b7..eae74a106509f7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala @@ -195,11 +195,10 @@ object FlinkStreamRuleSets { * This RuleSet is a sub-set of [[LOGICAL_OPT_RULES]]. 
*/ private val LOGICAL_RULES: RuleSet = RuleSets.ofList( - // aggregation and projection rules - AggregateProjectMergeRule.INSTANCE, - AggregateProjectPullUpConstantsRule.INSTANCE, // reorder sort and projection SortProjectTransposeRule.INSTANCE, + // remove unnecessary sort rule + SortRemoveRule.INSTANCE, // join rules FlinkJoinPushExpressionsRule.INSTANCE, @@ -209,8 +208,14 @@ // convert non-all union into all-union + distinct UnionToDistinctRule.INSTANCE, + // aggregation and projection rules + AggregateProjectMergeRule.INSTANCE, + AggregateProjectPullUpConstantsRule.INSTANCE, + // remove aggregation if it does not aggregate and input is already distinct - AggregateRemoveRule.INSTANCE, + FlinkAggregateRemoveRule.INSTANCE, + // push aggregate through join + FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED, // using variants of aggregate union rule AggregateUnionAggregateRule.AGG_ON_FIRST_INPUT, AggregateUnionAggregateRule.AGG_ON_SECOND_INPUT, @@ -219,12 +224,13 @@ AggregateReduceFunctionsRule.INSTANCE, WindowAggregateReduceFunctionsRule.INSTANCE, + // reduce useless aggCall + PruneAggregateCallRule.PROJECT_ON_AGGREGATE, + PruneAggregateCallRule.CALC_ON_AGGREGATE, + // expand grouping sets DecomposeGroupingSetsRule.INSTANCE, - // remove unnecessary sort rule - SortRemoveRule.INSTANCE, - // calc rules FilterCalcMergeRule.INSTANCE, ProjectCalcMergeRule.INSTANCE, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRule.scala new file mode 100644 index 00000000000000..e0fc4066d5bdaf --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRule.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable +import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery + +import com.google.common.collect.ImmutableList +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rel.core.Aggregate.Group +import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} +import org.apache.calcite.tools.RelBuilderFactory + +import scala.collection.JavaConversions._ +import scala.collection.mutable + +/** + * Planner rule that reduces useless grouping columns. + * + * Finds a (minimum) unique group for the grouping columns and uses it as the new grouping columns.
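To make the reduction concrete (a hypothetical schema, not taken from the patch): with t(a PRIMARY KEY, b, c), a query grouping on a, b, c can keep only a as the grouping key and carry b and c along as AUXILIARY_GROUP calls. A standalone Scala sketch of the partition the rule performs once the metadata query has reported the minimal unique group:

object ReduceGroupingSketch {
  // originalGrouping: grouping columns in output order;
  // uniqueGroup: the minimal unique subset reported by getUniqueGroups.
  // Returns (kept grouping columns, columns demoted to AUXILIARY_GROUP calls).
  def reduceGrouping(originalGrouping: Seq[Int], uniqueGroup: Set[Int]): (Seq[Int], Seq[Int]) =
    originalGrouping.partition(uniqueGroup.contains)

  def main(args: Array[String]): Unit = {
    // hypothetical: GROUP BY columns 0, 1, 4 where column 0 alone is already unique
    println(reduceGrouping(Seq(0, 1, 4), Set(0))) // (List(0),List(1, 4))
  }
}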
+ */ +class AggregateReduceGroupingRule(relBuilderFactory: RelBuilderFactory) extends RelOptRule( + operand(classOf[Aggregate], any), + relBuilderFactory, + "AggregateReduceGroupingRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: Aggregate = call.rel(0) + agg.getGroupCount > 1 && agg.getGroupType == Group.SIMPLE && !agg.indicator + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val agg: Aggregate = call.rel(0) + val aggRowType = agg.getRowType + val input = agg.getInput + val inputRowType = input.getRowType + val originalGrouping = agg.getGroupSet + val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) + val newGrouping = fmq.getUniqueGroups(input, originalGrouping) + val uselessGrouping = originalGrouping.except(newGrouping) + if (uselessGrouping.isEmpty) { + return + } + + // new agg: new grouping + aggCalls for dropped grouping + original aggCalls + val indexOldToNewMap = new mutable.HashMap[Int, Int]() + val newGroupingList = newGrouping.toList + var idxOfNewGrouping = 0 + var idxOfAggCallsForDroppedGrouping = newGroupingList.size() + originalGrouping.zipWithIndex.foreach { + case (column, oldIdx) => + val newIdx = if (newGroupingList.contains(column)) { + val p = idxOfNewGrouping + idxOfNewGrouping += 1 + p + } else { + val p = idxOfAggCallsForDroppedGrouping + idxOfAggCallsForDroppedGrouping += 1 + p + } + indexOldToNewMap += (oldIdx -> newIdx) + } + require(indexOldToNewMap.size == originalGrouping.cardinality()) + + // the indices of aggCalls (or NamedProperties for WindowAggregate) do not change + (originalGrouping.cardinality() until aggRowType.getFieldCount).foreach { + index => indexOldToNewMap += (index -> index) + } + + val aggCallsForDroppedGrouping = uselessGrouping.map { column => + val fieldType = inputRowType.getFieldList.get(column).getType + val fieldName = inputRowType.getFieldNames.get(column) + AggregateCall.create( + FlinkSqlOperatorTable.AUXILIARY_GROUP, + false, + false, + ImmutableList.of(column), + -1, + fieldType, + fieldName) + }.toList + + val newAggCalls = aggCallsForDroppedGrouping ++ agg.getAggCallList + val newAgg = agg.copy( + agg.getTraitSet, + input, + agg.indicator, // always false here + newGrouping, + ImmutableList.of(newGrouping), + newAggCalls + ) + val builder = call.builder() + builder.push(newAgg) + val projects = (0 until aggRowType.getFieldCount).map { + index => + val refIndex = indexOldToNewMap.getOrElse(index, + throw new IllegalArgumentException(s"Illegal index: $index")) + builder.field(refIndex) + } + builder.project(projects, aggRowType.getFieldNames) + call.transformTo(builder.build()) + } +} + +object AggregateReduceGroupingRule { + val INSTANCE = new AggregateReduceGroupingRule(RelFactories.LOGICAL_BUILDER) +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala deleted file mode 100644 index 6de7a8107d56c6..00000000000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateJoinTransposeRule.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.rules.logical - -import org.apache.calcite.plan.RelOptRuleCall -import org.apache.calcite.plan.hep.HepRelVertex -import org.apache.calcite.plan.volcano.RelSubset -import org.apache.calcite.rel.{RelNode, SingleRel} -import org.apache.calcite.rel.core.{Aggregate, Join, RelFactories} -import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalJoin, LogicalSnapshot} -import org.apache.calcite.rel.rules.AggregateJoinTransposeRule -import org.apache.calcite.tools.RelBuilderFactory - -/** - * Flink's [[AggregateJoinTransposeRule]] which does not match temporal join - * since lookup table source doesn't support aggregate. - */ -class FlinkAggregateJoinTransposeRule( - aggregateClass: Class[_ <: Aggregate], - joinClass: Class[_ <: Join], - factory: RelBuilderFactory, - allowFunctions: Boolean) - extends AggregateJoinTransposeRule(aggregateClass, joinClass, factory, allowFunctions) { - - override def matches(call: RelOptRuleCall): Boolean = { - val join: Join = call.rel(1) - if (containsSnapshot(join.getRight)) { - // avoid push aggregates through temporal join - false - } else { - super.matches(call) - } - } - - private def containsSnapshot(relNode: RelNode): Boolean = { - val original = relNode match { - case r: RelSubset => r.getOriginal - case r: HepRelVertex => r.getCurrentRel - case _ => relNode - } - original match { - case _: LogicalSnapshot => true - case r: SingleRel => containsSnapshot(r.getInput) - case _ => false - } - } -} - -object FlinkAggregateJoinTransposeRule { - - /** Extended instance of the rule that can push down aggregate functions. */ - val EXTENDED = new FlinkAggregateJoinTransposeRule( - classOf[LogicalAggregate], - classOf[LogicalJoin], - RelFactories.LOGICAL_BUILDER, - true) -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRule.scala new file mode 100644 index 00000000000000..a0b67ef0cd5f63 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRule.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import com.google.common.collect.{ImmutableList, Maps} +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.Aggregate.Group +import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, Project, RelFactories} +import org.apache.calcite.rex.{RexInputRef, RexNode, RexProgram, RexUtil} +import org.apache.calcite.runtime.Utilities +import org.apache.calcite.util.ImmutableBitSet +import org.apache.calcite.util.mapping.Mappings + +import java.util + +import scala.collection.JavaConversions._ + +/** + * Planner rule that removes unreferenced AggregateCall from Aggregate + */ +abstract class PruneAggregateCallRule[T <: RelNode](topClass: Class[T]) extends RelOptRule( + operand(topClass, + operand(classOf[Aggregate], any)), + RelFactories.LOGICAL_BUILDER, + s"PruneAggregateCallRule_${topClass.getCanonicalName}") { + + protected def getInputRefs(relOnAgg: T): ImmutableBitSet + + override def matches(call: RelOptRuleCall): Boolean = { + val relOnAgg: T = call.rel(0) + val agg: Aggregate = call.rel(1) + if (agg.indicator || agg.getGroupType != Group.SIMPLE || agg.getAggCallList.isEmpty || + // at least output one column + (agg.getGroupCount == 0 && agg.getAggCallList.size() == 1)) { + return false + } + val inputRefs = getInputRefs(relOnAgg) + val unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg) + unrefAggCallIndices.nonEmpty + } + + private def getUnrefAggCallIndices( + inputRefs: ImmutableBitSet, + agg: Aggregate): Array[Int] = { + val groupCount = agg.getGroupCount + agg.getAggCallList.indices.flatMap { index => + val aggCallOutputIndex = groupCount + index + if (inputRefs.get(aggCallOutputIndex)) { + Array.empty[Int] + } else { + Array(index) + } + }.toArray[Int] + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val relOnAgg: T = call.rel(0) + val agg: Aggregate = call.rel(1) + val inputRefs = getInputRefs(relOnAgg) + var unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg) + require(unrefAggCallIndices.nonEmpty) + + val newAggCalls: util.List[AggregateCall] = new util.ArrayList(agg.getAggCallList) + // remove unreferenced AggCall from original aggCalls + unrefAggCallIndices.sorted.reverse.foreach(i => newAggCalls.remove(i)) + + if (newAggCalls.isEmpty && agg.getGroupCount == 0) { + // at least output one column + newAggCalls.add(agg.getAggCallList.get(0)) + unrefAggCallIndices = unrefAggCallIndices.slice(1, unrefAggCallIndices.length) + } + + val newAgg = agg.copy( + agg.getTraitSet, + agg.getInput, + agg.indicator, + agg.getGroupSet, + ImmutableList.of(agg.getGroupSet), + newAggCalls + ) + + var newFieldIndex = 0 + // map old agg output index to new agg output index + val mapOldToNew = Maps.newHashMap[Integer, Integer]() + val fieldCountOfOldAgg = agg.getRowType.getFieldCount + val unrefAggCallOutputIndices = unrefAggCallIndices.map(_ + agg.getGroupCount) + (0 until fieldCountOfOldAgg).foreach { i => + if (!unrefAggCallOutputIndices.contains(i)) { + mapOldToNew.put(i, newFieldIndex) + newFieldIndex += 1 + } + } + require(mapOldToNew.size() == newAgg.getRowType.getFieldCount) + + val mapping = Mappings.target(mapOldToNew, fieldCountOfOldAgg, newAgg.getRowType.getFieldCount) + val newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg) + 
call.transformTo(newRelOnAgg) + } + + protected def createNewRel( + mapping: Mappings.TargetMapping, + project: T, + newAgg: RelNode): RelNode +} + +class ProjectPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Project]) { + override protected def getInputRefs(relOnAgg: Project): ImmutableBitSet = { + RelOptUtil.InputFinder.bits(relOnAgg.getProjects, null) + } + + override protected def createNewRel( + mapping: Mappings.TargetMapping, + project: Project, + newAgg: RelNode): RelNode = { + val newProjects = RexUtil.apply(mapping, project.getProjects).toList + if (projectsOnlyIdentity(newProjects, newAgg.getRowType.getFieldCount) && + Utilities.compare(project.getRowType.getFieldNames, newAgg.getRowType.getFieldNames) == 0) { + newAgg + } else { + project.copy(project.getTraitSet, newAgg, newProjects, project.getRowType) + } + } + + private def projectsOnlyIdentity( + projects: util.List[RexNode], + inputFieldCount: Int): Boolean = { + if (projects.size != inputFieldCount) { + return false + } + projects.zipWithIndex.forall { + case (project, index) => + project match { + case r: RexInputRef => r.getIndex == index + case _ => false + } + } + } +} + +class CalcPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Calc]) { + override protected def getInputRefs(relOnAgg: Calc): ImmutableBitSet = { + val program = relOnAgg.getProgram + val condition = if (program.getCondition != null) { + program.expandLocalRef(program.getCondition) + } else { + null + } + val projects = program.getProjectList.map(program.expandLocalRef) + RelOptUtil.InputFinder.bits(projects, condition) + } + + override protected def createNewRel( + mapping: Mappings.TargetMapping, + calc: Calc, + newAgg: RelNode): RelNode = { + val program = calc.getProgram + val newCondition = if (program.getCondition != null) { + RexUtil.apply(mapping, program.expandLocalRef(program.getCondition)) + } else { + null + } + val projects = program.getProjectList.map(program.expandLocalRef) + val newProjects = RexUtil.apply(mapping, projects).toList + val newProgram = RexProgram.create( + newAgg.getRowType, + newProjects, + newCondition, + program.getOutputRowType.getFieldNames, + calc.getCluster.getRexBuilder + ) + if (newProgram.isTrivial && + Utilities.compare(calc.getRowType.getFieldNames, newAgg.getRowType.getFieldNames) == 0) { + newAgg + } else { + calc.copy(calc.getTraitSet, newAgg, newProgram) + } + } +} + +object PruneAggregateCallRule { + val PROJECT_ON_AGGREGATE = new ProjectPruneAggregateCallRule + val CALC_ON_AGGREGATE = new CalcPruneAggregateCallRule +} diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml new file mode 100644 index 00000000000000..d815d86e44cdba --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.xml @@ -0,0 +1,974 @@
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml index 39ec889a551d1b..d50dab091b11d6 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/BroadcastHashSemiAntiJoinTest.xml @@ -72,9 +72,9 @@ HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(c, f))], select=[a, b, c :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[broadcast]) +- Calc(select=[f, e]) - +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0]) + +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f]) +- Exchange(distribution=[hash[d, e, f]]) - +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0]) + +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f]) +- Calc(select=[d, e, f], where=[<(d, 100)]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml index 56faaac15412e8..b398e66df044d9 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/LookupJoinTest.xml @@ -184,30 +184,6 @@ LogicalProject(a=[$0], b=[$1], c=[$2], proctime=[$3], id=[$4]) Calc(select=[a, b, c, PROCTIME() AS proctime, id]) +- LookupJoin(table=[TestTemporalTable(id, name, age)], joinType=[InnerJoin], async=[false], on=[a=id], where=[], select=[a, b, c, id]) +- BoundedStreamScan(table=[[T0]], fields=[a, b, c]) -]]> - - - - - - - - - - - @@ -322,9 +298,9 @@ Calc(select=[EXPR$0, EXPR$1, EXPR$2]) :- Exchange(distribution=[hash[a]], exchange_mode=[BATCH]) : +- LookupJoin(table=[TestTemporalTable(id, name, age)], joinType=[InnerJoin], async=[false], on=[a=id], where=[>(age, 10)], select=[b, a, id], reuse_id=[1]) : +- Calc(select=[b, a]) - : +- HashAggregate(isMerge=[true], groupBy=[a, b], select=[a, b, Final_SUM(sum$0) AS c, Final_SUM(sum$1) AS d]) + : +- HashAggregate(isMerge=[true], groupBy=[a, b], select=[a, b]) : +- Exchange(distribution=[hash[a, b]]) - : +- LocalHashAggregate(groupBy=[a, b], select=[a, b, Partial_SUM(c) AS sum$0, Partial_SUM(d) AS sum$1]) + : +- LocalHashAggregate(groupBy=[a, b], select=[a, b]) : +- BoundedStreamScan(table=[[T1]], fields=[a, b, c, d]) +- Exchange(distribution=[hash[a]]) +- Calc(select=[id AS a, b]) diff --git 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopSemiAntiJoinTest.xml index 33d6a025f728d5..d2bd6950df703c 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopSemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/NestedLoopSemiAntiJoinTest.xml @@ -72,9 +72,9 @@ NestedLoopJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(c, f))], select=[a :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[broadcast]) +- Calc(select=[f, e]) - +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0]) + +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f]) +- Exchange(distribution=[hash[d, e, f]]) - +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0]) + +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f]) +- Calc(select=[d, e, f], where=[<(d, 100)]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> @@ -521,9 +521,9 @@ NestedLoopJoin(joinType=[LeftAntiJoin], where=[<>(b, e)], select=[a, b, c], buil : : +- Exchange(distribution=[single]) : : +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0]) : : +- Calc(select=[true AS i]) -: : +- HashAggregate(isMerge=[true], groupBy=[l], select=[l, Final_COUNT(count$0) AS EXPR$0]) +: : +- HashAggregate(isMerge=[true], groupBy=[l], select=[l]) : : +- Exchange(distribution=[hash[l]]) -: : +- LocalHashAggregate(groupBy=[l], select=[l, Partial_COUNT(l) AS count$0]) +: : +- LocalHashAggregate(groupBy=[l], select=[l]) : : +- Calc(select=[l], where=[LIKE(n, _UTF-16LE'Test')]) : : +- TableSourceScan(table=[[t2, source: [TestTableSource(l, m, n)]]], fields=[l, m, n]) : +- Exchange(distribution=[broadcast]) @@ -676,12 +676,11 @@ Calc(select=[b]) : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0], build=[right], singleRowJoin=[true]) : : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[broadcast]) - : : +- Calc(select=[c]) - : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[broadcast]) : +- Calc(select=[true AS i]) : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) @@ -1202,12 +1201,11 @@ Calc(select=[b]) : : :- Calc(select=[a, b]) : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[broadcast]) - : : +- Calc(select=[c]) - : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, 
Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[broadcast]) : +- Calc(select=[true AS i]) : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml index 8ad7125a2cf8dc..6315d758f69fa3 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SemiAntiJoinTest.xml @@ -74,9 +74,9 @@ HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(c, f))], select=[a, b, c : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[hash[e, f]]) +- Calc(select=[f, e]) - +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0]) + +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f]) +- Exchange(distribution=[hash[d, e, f]]) - +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0]) + +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f]) +- Calc(select=[d, e, f], where=[<(d, 100)]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> @@ -531,9 +531,9 @@ NestedLoopJoin(joinType=[LeftAntiJoin], where=[<>(b, e)], select=[a, b, c], buil : : +- Exchange(distribution=[single]) : : +- LocalHashAggregate(select=[Partial_MIN(i) AS min$0]) : : +- Calc(select=[true AS i]) -: : +- HashAggregate(isMerge=[true], groupBy=[l], select=[l, Final_COUNT(count$0) AS EXPR$0]) +: : +- HashAggregate(isMerge=[true], groupBy=[l], select=[l]) : : +- Exchange(distribution=[hash[l]]) -: : +- LocalHashAggregate(groupBy=[l], select=[l, Partial_COUNT(l) AS count$0]) +: : +- LocalHashAggregate(groupBy=[l], select=[l]) : : +- Calc(select=[l], where=[LIKE(n, _UTF-16LE'Test')]) : : +- TableSourceScan(table=[[t2, source: [TestTableSource(l, m, n)]]], fields=[l, m, n]) : +- Exchange(distribution=[hash[k]]) @@ -691,12 +691,11 @@ Calc(select=[b]) : :- NestedLoopJoin(joinType=[InnerJoin], where=[true], select=[a, b, c, c0], build=[right], singleRowJoin=[true]) : : :- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[broadcast]) - : : +- Calc(select=[c]) - : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- 
SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[broadcast]) : +- Calc(select=[true AS i]) : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) @@ -1236,12 +1235,11 @@ Calc(select=[b]) : : :- Calc(select=[a, b]) : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[broadcast]) - : : +- Calc(select=[c]) - : : +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c, Final_COUNT(count$1) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0, Partial_COUNT(EXPR$0) AS count$1]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- SortAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c]) + : : +- Exchange(distribution=[single]) + : : +- LocalSortAggregate(select=[Partial_COUNT(*) AS count1$0]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[broadcast]) : +- Calc(select=[true AS i]) : +- HashAggregate(isMerge=[true], groupBy=[EXPR$0], select=[EXPR$0]) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashSemiAntiJoinTest.xml index 94b5ca751c7f82..0bc53f32958344 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashSemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/ShuffledHashSemiAntiJoinTest.xml @@ -74,9 +74,9 @@ HashJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), =(c, f))], select=[a, b, c : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[hash[e, f]]) +- Calc(select=[f, e]) - +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0]) + +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f]) +- Exchange(distribution=[hash[d, e, f]]) - +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0]) + +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f]) +- Calc(select=[d, e, f], where=[<(d, 100)]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeSemiAntiJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeSemiAntiJoinTest.xml index e1ec8917422894..8f510415c9b148 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeSemiAntiJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/batch/sql/join/SortMergeSemiAntiJoinTest.xml @@ -74,9 +74,9 @@ SortMergeJoin(joinType=[LeftSemiJoin], where=[AND(=(b, e), 
=(c, f))], select=[a, : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +- Exchange(distribution=[hash[e, f]]) +- Calc(select=[f, e]) - +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f, Final_MAX(max$0) AS EXPR$0]) + +- HashAggregate(isMerge=[true], groupBy=[d, e, f], select=[d, e, f]) +- Exchange(distribution=[hash[d, e, f]]) - +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f, Partial_MAX(e) AS max$0]) + +- LocalHashAggregate(groupBy=[d, e, f], select=[d, e, f]) +- Calc(select=[d, e, f], where=[<(d, 100)]) +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f]) ]]> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml new file mode 100644 index 00000000000000..97babfdc9f86a2 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.xml @@ -0,0 +1,909 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.xml new file mode 100644 index 00000000000000..3b635b3ac8ab83 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.xml @@ -0,0 +1,377 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 10 + ]]> + + + ($2, 10)]) + +- LogicalAggregate(group=[{0, 1}], c1=[COUNT($2)], d1=[SUM($3)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + ($t1, $t2)], proj#0..1=[{exprs}], $condition=[$t3]) ++- LogicalAggregate(group=[{0}], c1=[COUNT($2)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + + + 10 + ]]> + + + ($2, 10)]) + +- LogicalAggregate(group=[{0, 1}], c1=[COUNT($2)], d1=[SUM($3)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + ($t1, $t2)], c1=[$t1], a1=[$t0], $condition=[$t3]) ++- LogicalAggregate(group=[{0}], c1=[COUNT($2)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + 
0 + ]]> + + + ($3, 0)]) + +- LogicalAggregate(group=[{0, 1}], c2=[COUNT($2)], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + ($t2, $t3)], proj#0..2=[{exprs}], $condition=[$t4]) ++- LogicalAggregate(group=[{0, 1}], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + + + 0 + ]]> + + + ($3, 0)]) + +- LogicalAggregate(group=[{0, 1}], c2=[COUNT($2)], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + ($t2, $t3)], b2=[$t1], a2=[$t0], d2=[$t2], $condition=[$t4]) ++- LogicalAggregate(group=[{0, 1}], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml new file mode 100644 index 00000000000000..31daeed9781599 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.xml @@ -0,0 +1,257 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml new file mode 100644 index 00000000000000..1f0abf2f41ec0e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.xml @@ -0,0 +1,267 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.xml new file mode 100644 index 00000000000000..a6e8987b887ab9 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.xml @@ -0,0 +1,528 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0), MAX(b) FROM MyTable2 GROUP BY a]]> + + + ($1, 0))], b=[$1]) + +- LogicalTableScan(table=[[MyTable2, source: [TestTableSource(a, b, c)]]]) +]]> + + + (b, 0)) AS $f2, b]) + +- FlinkLogicalTableSourceScan(table=[[MyTable2, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.xml new file mode 100644 index 00000000000000..c77788a88ee5e9 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.xml @@ -0,0 +1,379 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 10 + ]]> + + + ($2, 10)]) + +- LogicalAggregate(group=[{0, 1}], c1=[COUNT($2)], d1=[SUM($3)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + ($1, 10)]) ++- LogicalAggregate(group=[{0}], c1=[COUNT($2)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + + + 10 + ]]> + + + ($2, 10)]) + +- LogicalAggregate(group=[{0, 1}], c1=[COUNT($2)], d1=[SUM($3)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + ($1, 10)]) + +- LogicalAggregate(group=[{0}], c1=[COUNT($2)]) + +- LogicalTableScan(table=[[T1, source: [TestTableSource(a1, b1, c1, d1)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + ]]> + + + ($3, 0)]) + +- LogicalAggregate(group=[{0, 1}], c2=[COUNT($2)], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + ($2, 0)]) ++- LogicalAggregate(group=[{0, 1}], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + + + 0 + ]]> + + + ($3, 0)]) + +- LogicalAggregate(group=[{0, 1}], c2=[COUNT($2)], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + ($2, 0)]) + +- LogicalAggregate(group=[{0, 1}], d2=[SUM($3)]) + +- LogicalTableScan(table=[[T2, source: [TestTableSource(a2, b2, c2, d2)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/JoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/JoinTest.xml index 807d5577ebc801..9c41ba5246c5a8 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/JoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/JoinTest.xml @@ -87,17 +87,15 @@ LogicalProject(a1=[$1], b1=[$3]) @@ -157,11 +155,10 @@ LogicalProject(a1=[$1], b1=[$2]) @@ -378,17 +373,15 @@ LogicalProject(a1=[$1], b1=[$3]) @@ -636,11 +629,10 @@ LogicalProject(a1=[$1], b1=[$2]) @@ -955,11 +945,10 @@ LogicalProject(a1=[$1], b1=[$2]) (b, e)], select=[a, b, c]) : : : +- TableSourceScan(table=[[r, source: [TestTableSource(d, e, f)]]], fields=[d, e, f], reuse_id=[1]) : : +- Exchange(distribution=[single]) : : +- Calc(select=[IS NOT NULL(m) AS $f0]) -: : +- GroupAggregate(select=[MIN_RETRACT(i) AS m]) +: : +- GroupAggregate(select=[MIN(i) AS m]) : : +- Exchange(distribution=[single]) : : +- Calc(select=[true AS i]) -: : +- GroupAggregate(groupBy=[l], select=[l, COUNT(l) AS EXPR$0]) +: : +- GroupAggregate(groupBy=[l], select=[l]) : : +- Exchange(distribution=[hash[l]]) 
: : +- Calc(select=[l], where=[LIKE(n, _UTF-16LE'Test')]) : : +- TableSourceScan(table=[[t2, source: [TestTableSource(l, m, n)]]], fields=[l, m, n]) @@ -690,11 +690,10 @@ Calc(select=[b]) : : :- Exchange(distribution=[single]) : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[single]) - : : +- Calc(select=[c]) - : : +- GroupAggregate(select=[COUNT(*) AS c, COUNT(EXPR$0) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- GroupAggregate(select=[COUNT(*) AS c]) + : : +- Exchange(distribution=[single]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[single]) : +- Calc(select=[i]) : +- GroupAggregate(groupBy=[EXPR$0, i], select=[EXPR$0, i]) @@ -1234,11 +1233,10 @@ Calc(select=[b]) : : : +- Calc(select=[a, b]) : : : +- TableSourceScan(table=[[l, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) : : +- Exchange(distribution=[single]) - : : +- Calc(select=[c]) - : : +- GroupAggregate(select=[COUNT(*) AS c, COUNT(EXPR$0) AS ck]) - : : +- Exchange(distribution=[single]) - : : +- Calc(select=[1 AS EXPR$0]) - : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) + : : +- GroupAggregate(select=[COUNT(*) AS c]) + : : +- Exchange(distribution=[single]) + : : +- Calc(select=[1 AS EXPR$0]) + : : +- TableSourceScan(table=[[t, source: [TestTableSource(i, j, k)]]], fields=[i, j, k], reuse_id=[1]) : +- Exchange(distribution=[single]) : +- Calc(select=[i]) : +- GroupAggregate(groupBy=[EXPR$0, i], select=[EXPR$0, i]) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala index 60902273fe6b30..fd9237a10110e0 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveCollationTest.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.plan.batch.sql import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions, Types} -import org.apache.flink.table.plan.stats.TableStats +import org.apache.flink.table.plan.stats.{FlinkStatistic, TableStats} import org.apache.flink.table.runtime.utils.JavaUserDefinedTableFunctions.StringSplit import org.apache.flink.table.util.{TableFunc1, TableTestBase} @@ -37,22 +37,22 @@ class RemoveCollationTest extends TableTestBase { util.addTableSource("x", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("a", "b", "c"), - tableStats = Some(new TableStats(100L)) + FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.addTableSource("y", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("d", "e", "f"), - tableStats = Some(new TableStats(100L)) + FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.addTableSource("t1", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("a1", "b1", "c1"), - tableStats = Some(new TableStats(100L)) + 
FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.addTableSource("t2", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("d1", "e1", "f1"), - tableStats = Some(new TableStats(100L)) + FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.tableEnv.getConfig.getConf.setBoolean( @@ -269,27 +269,27 @@ class RemoveCollationTest extends TableTestBase { Array[TypeInformation[_]]( Types.STRING, Types.STRING, Types.STRING, Types.STRING, Types.STRING), Array("id", "key", "tb2_ids", "tb3_ids", "name"), - uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("id"))).build() ) util.addTableSource("tb2", Array[TypeInformation[_]](Types.STRING, Types.STRING), Array("id", "name"), - uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("id"))).build() ) util.addTableSource("tb3", Array[TypeInformation[_]](Types.STRING, Types.STRING), Array("id", "name"), - uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("id"))).build() ) util.addTableSource("tb4", Array[TypeInformation[_]](Types.STRING, Types.STRING), Array("id", "name"), - uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("id"))).build() ) util.addTableSource("tb5", Array[TypeInformation[_]](Types.STRING, Types.STRING), Array("id", "name"), - uniqueKeys = Some(ImmutableSet.of(ImmutableSet.of("id"))) + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("id"))).build() ) util.tableEnv.registerFunction("split", new StringSplit()) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala index 29b14b9813c9da..e81190b3e801c1 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/RemoveShuffleTest.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.plan.batch.sql import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala._ import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions, Types} -import org.apache.flink.table.plan.stats.TableStats +import org.apache.flink.table.plan.stats.{FlinkStatistic, TableStats} import org.apache.flink.table.util.{TableFunc1, TableTestBase} import org.junit.{Before, Test} @@ -35,12 +35,12 @@ class RemoveShuffleTest extends TableTestBase { util.addTableSource("x", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("a", "b", "c"), - tableStats = Some(new TableStats(100L)) + FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.addTableSource("y", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING), Array("d", "e", "f"), - tableStats = Some(new TableStats(100L)) + FlinkStatistic.builder().tableStats(new TableStats(100L)).build() ) util.tableEnv.getConfig.getConf.setBoolean( PlannerConfigOptions.SQL_OPTIMIZER_REUSE_SUB_PLAN_ENABLED, false) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.scala 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.scala new file mode 100644 index 00000000000000..e55b8b498015b8 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/agg/AggregateReduceGroupingTest.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.batch.sql.agg + +import org.apache.flink.table.plan.common.AggregateReduceGroupingTestBase + +class AggregateReduceGroupingTest extends AggregateReduceGroupingTestBase { + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/common/AggregateReduceGroupingTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/common/AggregateReduceGroupingTestBase.scala new file mode 100644 index 00000000000000..687f73000f0bc5 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/common/AggregateReduceGroupingTestBase.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.common + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.api.Types +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.FlinkBatchProgram +import org.apache.flink.table.plan.rules.logical.FlinkAggregateRemoveRule +import org.apache.flink.table.plan.stats.{FlinkStatistic, TableStats} +import org.apache.flink.table.util.{BatchTableTestUtil, TableTestBase} + +import com.google.common.collect.ImmutableSet +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +abstract class AggregateReduceGroupingTestBase extends TableTestBase { + protected val util: BatchTableTestUtil = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource("T1", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING, Types.STRING), + Array("a1", "b1", "c1", "d1"), + FlinkStatistic.builder() + .tableStats(new TableStats(100000000)) + .uniqueKeys(ImmutableSet.of(ImmutableSet.of("a1"))) + .build() + ) + util.addTableSource("T2", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING), + Array("a2", "b2", "c2"), + FlinkStatistic.builder() + .tableStats(new TableStats(100000000)) + .uniqueKeys(ImmutableSet.of(ImmutableSet.of("b2"), ImmutableSet.of("a2", "b2"))) + .build() + ) + util.addTableSource("T3", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING, Types.LONG), + Array("a3", "b3", "c3", "d3"), + FlinkStatistic.builder() + .tableStats(new TableStats(1000)) + .build() + ) + util.addTableSource("T4", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING, Types.SQL_TIMESTAMP), + Array("a4", "b4", "c4", "d4"), + FlinkStatistic.builder() + .tableStats(new TableStats(100000000)) + .uniqueKeys(ImmutableSet.of(ImmutableSet.of("a4"))) + .build() + ) + } + + @Test + def testAggWithoutAggCall(): Unit = { + val programs = util.tableEnv.getConfig.getCalciteConfig.getBatchProgram + .getOrElse(FlinkBatchProgram.buildProgram(util.getTableEnv.getConfig.getConf)) + programs.getFlinkRuleSetProgram(FlinkBatchProgram.LOGICAL) + .get.remove(RuleSets.ofList(FlinkAggregateRemoveRule.INSTANCE)) // to prevent the agg from + // removing + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + util.verifyPlan("SELECT a1, b1, c1 FROM T1 GROUP BY a1, b1, c1") + } + + @Test + def testAggWithoutReduceGrouping(): Unit = { + util.verifyPlan("SELECT a3, b3, count(c3) FROM T3 GROUP BY a3, b3") + } + + @Test + def testSingleAggOnTableWithUniqueKey(): Unit = { + util.verifyPlan("SELECT a1, b1, count(c1) FROM T1 GROUP BY a1, b1") + } + + @Test + def testSingleAggOnTableWithoutUniqueKey(): Unit = { + util.verifyPlan("SELECT a3, b3, count(c3) FROM T3 GROUP BY a3, b3") + } + + @Test + def testSingleAggOnTableWithUniqueKeys(): Unit = { + util.verifyPlan("SELECT b2, c2, avg(a2) FROM T2 GROUP BY b2, c2") + } + + @Test + def testSingleAggWithConstantGroupKey(): Unit = { + util.verifyPlan("SELECT a1, b1, count(c1) FROM T1 GROUP BY a1, b1, 1, true") + } + + @Test + def testSingleAggOnlyConstantGroupKey(): Unit = { + util.verifyPlan("SELECT count(c1) FROM T1 GROUP BY 1, true") + } + + @Test + def testMultiAggs1(): Unit = { + util.verifyPlan("SELECT a1, b1, c1, d1, m, COUNT(*) FROM " + + "(SELECT a1, b1, c1, COUNT(d1) AS d1, MAX(d1) AS m FROM T1 GROUP BY a1, b1, c1) t " + + "GROUP BY a1, b1, c1, d1, m") + } + + @Test + def 
testMultiAggs2(): Unit = { + util.verifyPlan("SELECT a3, b3, c, s, a, COUNT(*) FROM " + + "(SELECT a3, b3, COUNT(c3) AS c, SUM(d3) AS s, AVG(d3) AS a FROM T3 GROUP BY a3, b3) t " + + "GROUP BY a3, b3, c, s, a") + } + + @Test + def testAggOnInnerJoin1(): Unit = { + util.verifyPlan("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1, T2 WHERE a1 = b2) t GROUP BY a1, b1, a2, b2") + } + + @Test + def testAggOnInnerJoin2(): Unit = { + util.verifyPlan("SELECT a2, b2, a3, b3, COUNT(c2), AVG(d3) FROM " + + "(SELECT * FROM T2, T3 WHERE b2 = a3) t GROUP BY a2, b2, a3, b3") + } + + @Test + def testAggOnInnerJoin3(): Unit = { + util.verifyPlan("SELECT a1, b1, a2, b2, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1, T2, T3 WHERE a1 = b2 AND a1 = a3) t GROUP BY a1, b1, a2, b2, a3, b3") + } + + @Test + def testAggOnLeftJoin1(): Unit = { + util.verifyPlan("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 LEFT JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2") + } + + @Test + def testAggOnLeftJoin2(): Unit = { + util.verifyPlan("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 LEFT JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3") + } + + @Test + def testAggOnLeftJoin3(): Unit = { + util.verifyPlan("SELECT a3, b3, a1, b1, COUNT(c1) FROM " + + "(SELECT * FROM T3 LEFT JOIN T1 ON a1 = a3) t GROUP BY a3, b3, a1, b1") + } + + @Test + def testAggOnRightJoin1(): Unit = { + util.verifyPlan("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 RIGHT JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2") + } + + @Test + def testAggOnRightJoin2(): Unit = { + util.verifyPlan("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 RIGHT JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3") + } + + @Test + def testAggOnRightJoin3(): Unit = { + util.verifyPlan("SELECT a3, b3, a1, b1, COUNT(c1) FROM " + + "(SELECT * FROM T3 RIGHT JOIN T1 ON a1 = a3) t GROUP BY a3, b3, a1, b1") + } + + @Test + def testAggOnFullJoin1(): Unit = { + util.verifyPlan("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 FULL OUTER JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2") + } + + @Test + def testAggOnFullJoin2(): Unit = { + util.verifyPlan("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 FULL OUTER JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3") + } + + @Test + def testAggOnOver(): Unit = { + util.verifyPlan("SELECT a1, b1, c, COUNT(d1) FROM " + + "(SELECT a1, b1, d1, COUNT(*) OVER (PARTITION BY c1) AS c FROM T1) t GROUP BY a1, b1, c") + } + + @Test + def testAggOnWindow1(): Unit = { + util.verifyPlan("SELECT a4, b4, COUNT(c4) FROM T4 " + + "GROUP BY a4, b4, TUMBLE(d4, INTERVAL '15' MINUTE)") + } + + @Test + def testAggOnWindow2(): Unit = { + util.verifyPlan("SELECT a4, c4, COUNT(b4), AVG(b4) FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)") + } + + @Test + def testAggOnWindow3(): Unit = { + util.verifyPlan("SELECT a4, c4, s, COUNT(b4) FROM " + + "(SELECT a4, c4, VAR_POP(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4, s") + } + + @Test + def testAggOnWindow4(): Unit = { + util.verifyPlan("SELECT a4, c4, e, COUNT(b4) FROM " + + "(SELECT a4, c4, VAR_POP(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4, e") + } + + @Test + def testAggOnWindow5(): Unit = { + 
util.verifyPlan("SELECT a4, b4, c4, COUNT(*) FROM " + + "(SELECT a4, c4, VAR_POP(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, b4, c4") + } + + @Test + def testAggWithGroupingSets1(): Unit = { + util.verifyPlan("SELECT a1, b1, c1, COUNT(d1) FROM T1 " + + "GROUP BY GROUPING SETS ((a1, b1), (a1, c1))") + } + + @Test + def testAggWithGroupingSets2(): Unit = { + util.verifyPlan("SELECT a1, SUM(b1) AS s FROM T1 GROUP BY GROUPING SETS((a1, c1), (a1), ())") + } + + @Test + def testAggWithGroupingSets3(): Unit = { + util.verifyPlan("SELECT a1, b1, c1, COUNT(d1) FROM T1 " + + "GROUP BY GROUPING SETS ((a1, b1, c1), (a1, b1, d1))") + } + + @Test + def testAggWithRollup(): Unit = { + util.verifyPlan("SELECT a1, b1, c1, COUNT(d1) FROM T1 GROUP BY ROLLUP (a1, b1, c1)") + } + + @Test + def testAggWithCube(): Unit = { + util.verifyPlan("SELECT a1, b1, c1, COUNT(d1) FROM T1 GROUP BY CUBE (a1, b1, c1)") + } + + @Test + def testSingleDistinctAgg1(): Unit = { + util.verifyPlan("SELECT a1, COUNT(DISTINCT c1) FROM T1 GROUP BY a1") + } + + @Test + def testSingleDistinctAgg2(): Unit = { + util.verifyPlan("SELECT a1, b1, COUNT(DISTINCT c1) FROM T1 GROUP BY a1, b1") + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg1(): Unit = { + util.verifyPlan("SELECT a1, COUNT(DISTINCT b1), SUM(b1) FROM T1 GROUP BY a1") + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg2(): Unit = { + util.verifyPlan("SELECT a1, c1, COUNT(DISTINCT b1), SUM(b1) FROM T1 GROUP BY a1, c1") + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg3(): Unit = { + util.verifyPlan("SELECT a1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1") + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg4(): Unit = { + util.verifyPlan("SELECT a1, d1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1, d1") + } + + @Test + def testMultiDistinctAggs1(): Unit = { + util.verifyPlan("SELECT a1, COUNT(DISTINCT b1), SUM(DISTINCT b1) FROM T1 GROUP BY a1") + } + + @Test + def testMultiDistinctAggs2(): Unit = { + util.verifyPlan("SELECT a1, d1, COUNT(DISTINCT c1), SUM(DISTINCT b1) FROM T1 GROUP BY a1, d1") + } + + @Test + def testMultiDistinctAggs3(): Unit = { + util.verifyPlan( + "SELECT a1, SUM(DISTINCT b1), MAX(DISTINCT b1), MIN(DISTINCT c1) FROM T1 GROUP BY a1") + } + + @Test + def testMultiDistinctAggs_WithNonDistinctAgg1(): Unit = { + util.verifyPlan( + "SELECT a1, d1, COUNT(DISTINCT c1), MAX(DISTINCT b1), SUM(b1) FROM T1 GROUP BY a1, d1") + } + +} + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.scala new file mode 100644 index 00000000000000..7c826cfc595460 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/AggregateReduceGroupingRuleTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.common.AggregateReduceGroupingTestBase +import org.apache.flink.table.plan.optimize.program.FlinkBatchProgram + +import org.apache.calcite.tools.RuleSets +import org.junit.Before + +/** + * Test for [[AggregateReduceGroupingRule]]. + */ +class AggregateReduceGroupingRuleTest extends AggregateReduceGroupingTestBase { + + @Before + override def setup(): Unit = { + util.buildBatchProgram(FlinkBatchProgram.LOGICAL_REWRITE) + + // remove FlinkAggregateRemoveRule to prevent the agg from removing + val programs = util.getTableEnv.getConfig.getCalciteConfig.getBatchProgram + .getOrElse(FlinkBatchProgram.buildProgram(util.getTableEnv.getConfig.getConf)) + programs.getFlinkRuleSetProgram(FlinkBatchProgram.LOGICAL).get + .remove(RuleSets.ofList(FlinkAggregateRemoveRule.INSTANCE)) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + super.setup() + } +} + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.scala new file mode 100644 index 00000000000000..8e901b72037ac1 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/CalcPruneAggregateCallRuleTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules.{FilterCalcMergeRule, FilterToCalcRule, ProjectCalcMergeRule, ProjectToCalcRule} +import org.apache.calcite.tools.RuleSets + +/** + * Test for [[PruneAggregateCallRule]]#CALC_ON_AGGREGATE. 
+ */ +class CalcPruneAggregateCallRuleTest extends PruneAggregateCallRuleTestBase { + + override def setup(): Unit = { + super.setup() + util.buildBatchProgram(FlinkBatchProgram.LOGICAL) + + val programs = util.getTableEnv.getConfig.getCalciteConfig.getBatchProgram.get + programs.addLast("rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + AggregateReduceGroupingRule.INSTANCE, + FilterCalcMergeRule.INSTANCE, + ProjectCalcMergeRule.INSTANCE, + FilterToCalcRule.INSTANCE, + ProjectToCalcRule.INSTANCE, + FlinkCalcMergeRule.INSTANCE, + PruneAggregateCallRule.CALC_ON_AGGREGATE) + ).build()) + + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala new file mode 100644 index 00000000000000..ea096e37046d60 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateInnerJoinTransposeRuleTest.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Types +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkChainedProgram, FlinkGroupProgramBuilder, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.util.TableTestBase + +import com.google.common.collect.ImmutableSet +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules._ +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkAggregateJoinTransposeRule]]. + * this class only test inner join. 
+ */ +class FlinkAggregateInnerJoinTransposeRuleTest extends TableTestBase { + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val program = new FlinkChainedProgram[BatchOptimizeContext]() + program.addLast( + "rules", + FlinkGroupProgramBuilder.newBuilder[BatchOptimizeContext] + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList(AggregateReduceGroupingRule.INSTANCE + )).build(), "reduce unless grouping") + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + AggregateReduceGroupingRule.INSTANCE, + FlinkFilterJoinRule.FILTER_ON_JOIN, + FlinkFilterJoinRule.JOIN, + FilterAggregateTransposeRule.INSTANCE, + FilterProjectTransposeRule.INSTANCE, + FilterMergeRule.INSTANCE, + AggregateProjectMergeRule.INSTANCE, + FlinkAggregateJoinTransposeRule.EXTENDED + )).build(), "aggregate join transpose") + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(program).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + util.addTableSource[(Int, Int, String)]("T", 'a, 'b, 'c) + util.addTableSource("T2", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING), + Array("a2", "b2", "c2"), + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("b2"))).build() + ) + } + + @Test + def testPushCountAggThroughJoinOverUniqueColumn(): Unit = { + util.verifyPlan("SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a") + } + + @Test + def testPushSumAggThroughJoinOverUniqueColumn(): Unit = { + util.verifyPlan("SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a") + } + + @Test + def testPushAggThroughJoinWithUniqueJoinKey(): Unit = { + val sqlQuery = + """ + |WITH T1 AS (SELECT a AS a1, COUNT(b) AS b1 FROM T GROUP BY a), + | T2 AS (SELECT COUNT(a) AS a2, b AS b2 FROM T GROUP BY b) + |SELECT MIN(a1), MIN(b1), MIN(a2), MIN(b2), a, b, COUNT(c) FROM + | (SELECT * FROM T1, T2, T WHERE a1 = b2 AND a1 = a) t GROUP BY a, b + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testSomeAggCallColumnsAndJoinConditionColumnsIsSame(): Unit = { + val sqlQuery = + """ + |SELECT MIN(a2), MIN(b2), a, b, COUNT(c2) FROM + | (SELECT * FROM T2, T WHERE b2 = a) t GROUP BY a, b + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsUnique1(): Unit = { + val sqlQuery = + """ + |SELECT a2, b2, c2, SUM(a) FROM (SELECT * FROM T2, T WHERE b2 = b) GROUP BY a2, b2, c2 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsUnique2(): Unit = { + val sqlQuery = + """ + |SELECT a2, b2, c, SUM(a) FROM (SELECT * FROM T2, T WHERE b2 = b) GROUP BY a2, b2, c + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsNotUnique1(): Unit = { + val sqlQuery = + """ + |SELECT a2, b2, c2, SUM(a) FROM (SELECT * FROM T2, T WHERE a2 = a) GROUP BY a2, b2, c2 + """.stripMargin + util.verifyPlan(sqlQuery) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsNotUnique2(): Unit = { + val sqlQuery = + """ + |SELECT a2, b2, c, SUM(a) FROM (SELECT * FROM T2, T WHERE a2 = a) GROUP BY a2, b2, c + """.stripMargin + util.verifyPlan(sqlQuery) + } + +} diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala new file mode 100644 index 00000000000000..486bea173ed25f --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateOuterJoinTransposeRuleTest.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.scala._ +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{FlinkChainedProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE, StreamOptimizeContext} +import org.apache.flink.table.util.TableTestBase + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules._ +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkAggregateJoinTransposeRule]]. + * this class only test left/right outer join. 
+ */ +class FlinkAggregateOuterJoinTransposeRuleTest extends TableTestBase { + + private val util = streamTestUtil() + + @Before + def setup(): Unit = { + val program = new FlinkChainedProgram[StreamOptimizeContext]() + program.addLast( + "rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + FlinkFilterJoinRule.FILTER_ON_JOIN, + FlinkFilterJoinRule.JOIN, + FilterAggregateTransposeRule.INSTANCE, + FilterProjectTransposeRule.INSTANCE, + FilterMergeRule.INSTANCE, + AggregateProjectMergeRule.INSTANCE, + FlinkAggregateJoinTransposeRule.LEFT_RIGHT_OUTER_JOIN_EXTENDED + )) + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceStreamProgram(program).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + util.addTableSource[(Int, Long, String, Int)]("T", 'a, 'b, 'c, 'd) + } + + @Test + def testPushCountAggThroughJoinOverUniqueColumn(): Unit = { + util.verifyPlan("SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a") + } + + @Test + def testPushSumAggThroughJoinOverUniqueColumn(): Unit = { + util.verifyPlan("SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A JOIN T AS B ON A.a=B.a") + } + + @Test + def testPushCountAggThroughLeftJoinOverUniqueColumn(): Unit = { + val sqlQuery = "SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushSumAggThroughLeftJoinOverUniqueColumn(): Unit = { + val sqlQuery = "SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushCountAllAggThroughLeftJoinOverUniqueColumn(): Unit = { + val sqlQuery = "SELECT COUNT(*) FROM (SELECT DISTINCT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByLeft(): Unit = { + val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushCountAggThroughLeftJoinOverUniqueColumnAndGroupByRight(): Unit = { + val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT DISTINCT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY B.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushCountAggThroughLeftJoinAndGroupByLeft(): Unit = { + val sqlQuery = "SELECT COUNT(B.b) FROM (SELECT a FROM T) AS A " + + "LEFT OUTER JOIN T AS B ON A.a=B.a GROUP BY A.a" + util.verifyPlan(sqlQuery) + } + + @Test + def testPushCountAggThroughRightJoin(): Unit = { + val sqlQuery = "SELECT COUNT(B.b) FROM T AS B RIGHT OUTER JOIN (SELECT a FROM T) AS A " + + "ON A.a=B.a GROUP BY A.a" + util.verifyPlan(sqlQuery) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.scala new file mode 100644 index 00000000000000..1c69669def2949 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/FlinkAggregateRemoveRuleTest.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Types +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalAggregate, FlinkLogicalCalc, FlinkLogicalExpand, FlinkLogicalJoin, FlinkLogicalSink, FlinkLogicalTableSourceScan, FlinkLogicalValues} +import org.apache.flink.table.plan.optimize.program._ +import org.apache.flink.table.plan.rules.FlinkBatchRuleSets +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.util.TableTestBase + +import com.google.common.collect.ImmutableSet +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules.{FilterCalcMergeRule, FilterToCalcRule, ProjectCalcMergeRule, ProjectToCalcRule, ReduceExpressionsRule} +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +/** + * Test for [[FlinkAggregateRemoveRule]]. 
+ */ +class FlinkAggregateRemoveRuleTest extends TableTestBase { + private val util = batchTestUtil() + + @Before + def setup(): Unit = { + val programs = new FlinkChainedProgram[BatchOptimizeContext]() + programs.addLast( + // rewrite sub-queries to joins + "subquery_rewrite", + FlinkGroupProgramBuilder.newBuilder[BatchOptimizeContext] + .addProgram(FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkBatchRuleSets.SEMI_JOIN_RULES) + .build(), "rewrite sub-queries to semi/anti join") + .build()) + + programs.addLast( + "rules", + // use volcano planner because + // rel.getCluster.getPlanner is volcano planner used in FlinkAggregateRemoveRule + FlinkVolcanoProgramBuilder.newBuilder + .add(RuleSets.ofList( + ReduceExpressionsRule.FILTER_INSTANCE, + FlinkAggregateExpandDistinctAggregatesRule.INSTANCE, + FilterCalcMergeRule.INSTANCE, + ProjectCalcMergeRule.INSTANCE, + FilterToCalcRule.INSTANCE, + ProjectToCalcRule.INSTANCE, + FlinkCalcMergeRule.INSTANCE, + FlinkAggregateRemoveRule.INSTANCE, + DecomposeGroupingSetsRule.INSTANCE, + AggregateReduceGroupingRule.INSTANCE, + FlinkLogicalAggregate.BATCH_CONVERTER, + FlinkLogicalCalc.CONVERTER, + FlinkLogicalJoin.CONVERTER, + FlinkLogicalValues.CONVERTER, + FlinkLogicalExpand.CONVERTER, + FlinkLogicalTableSourceScan.CONVERTER, + FlinkLogicalSink.CONVERTER)) + .setRequiredOutputTraits(Array(FlinkConventions.LOGICAL)) + .build()) + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + + util.addTableSource[(Int, Int, String)]("MyTable1", 'a, 'b, 'c) + util.addTableSource("MyTable2", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING), + Array("a", "b", "c"), + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("a"))).build() + ) + util.addTableSource("MyTable3", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING, Types.STRING), + Array("a", "b", "c", "d"), + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("a"))).build() + ) + } + + @Test + def testAggRemove_GroupKeyIsNotUnique(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, MAX(c) from MyTable1 GROUP BY a") + } + + @Test + def testAggRemove_WithoutFilter1(): Unit = { + util.verifyPlan("SELECT a, b + 1, c, s FROM (" + + "SELECT a, MIN(b) AS b, SUM(b) AS s, MAX(c) AS c FROM MyTable2 GROUP BY a)") + } + + @Test + def testAggRemove_WithoutFilter2(): Unit = { + util.verifyPlan("SELECT a, SUM(b) AS s FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_WithoutGroupBy1(): Unit = { + // can not remove agg + util.verifyPlan("SELECT MAX(a), SUM(b), MIN(c) FROM MyTable2") + } + + @Test + def testAggRemove_WithoutGroupBy2(): Unit = { + util.verifyPlan("SELECT MAX(a), SUM(b), MIN(c) FROM (VALUES (1, 2, 3)) T(a, b, c)") + } + + @Test + def testAggRemove_WithoutGroupBy3(): Unit = { + // can not remove agg + util.verifyPlan("SELECT * FROM MyTable2 WHERE EXISTS (SELECT SUM(a) FROM MyTable1 WHERE 1=2)") + } + + @Test + def testAggRemove_WithoutGroupBy4(): Unit = { + // can not remove agg + util.verifyPlan("SELECT SUM(a) FROM (SELECT a FROM MyTable2 WHERE 1=2)") + } + + @Test + def testAggRemove_WithoutAggCall(): Unit = { + util.verifyPlan("SELECT a, b FROM MyTable2 GROUP BY a, b") + } + + @Test + def testAggRemove_WithFilter(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, 
MIN(c) FILTER (WHERE b > 0), MAX(b) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_Count(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, COUNT(c) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_CountStar(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, COUNT(*) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_GroupSets1(): Unit = { + // a is unique + util.verifyPlan("SELECT a, SUM(b) AS s FROM MyTable3 GROUP BY GROUPING SETS((a, c), (a, d))") + } + + @Test + def testAggRemove_GroupSets2(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, SUM(b) AS s FROM MyTable3 GROUP BY GROUPING SETS((a, c), (a), ())") + } + + @Test + def testAggRemove_Rollup(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, SUM(b) AS s FROM MyTable3 GROUP BY ROLLUP(a, c, d)") + } + + @Test + def testAggRemove_Cube(): Unit = { + // can not remove agg + util.verifyPlan("SELECT a, SUM(b) AS s FROM MyTable3 GROUP BY CUBE(a, c, d)") + } + + @Test + def testAggRemove_SingleDistinctAgg1(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_SingleDistinctAgg2(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c) FROM MyTable2 GROUP BY a, b") + } + + @Test + def testAggRemove_SingleDistinctAgg_WithNonDistinctAgg1(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT b), SUM(b) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_SingleDistinctAgg_WithNonDistinctAgg2(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT b), SUM(b) FROM MyTable2 GROUP BY a, c") + } + + @Test + def testAggRemove_SingleDistinctAgg_WithNonDistinctAgg3(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c), SUM(b) FROM MyTable3 GROUP BY a") + } + + @Test + def testAggRemove_SingleDistinctAgg_WithNonDistinctAgg4(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c), SUM(b) FROM MyTable3 GROUP BY a, d") + } + + @Test + def testAggRemove_MultiDistinctAggs1(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT b), SUM(DISTINCT b) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_MultiDistinctAggs2(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c), SUM(DISTINCT b) FROM MyTable3 GROUP BY a, d") + } + + @Test + def testAggRemove_MultiDistinctAggs3(): Unit = { + util.verifyPlan( + "SELECT a, SUM(DISTINCT b), MAX(DISTINCT b), MIN(DISTINCT c) FROM MyTable2 GROUP BY a") + } + + @Test + def testAggRemove_MultiDistinctAggs_WithNonDistinctAgg1(): Unit = { + util.verifyPlan("SELECT a, COUNT(DISTINCT c), SUM(b) FROM MyTable3 GROUP BY a, d") + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/JoinDeriveNullFilterRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/JoinDeriveNullFilterRuleTest.scala index b70cdaa79c3d6f..dfb986f995203c 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/JoinDeriveNullFilterRuleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/JoinDeriveNullFilterRuleTest.scala @@ -21,7 +21,7 @@ package org.apache.flink.table.plan.rules.logical import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.{PlannerConfigOptions, Types} import org.apache.flink.table.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} -import 
org.apache.flink.table.plan.stats.{ColumnStats, TableStats} +import org.apache.flink.table.plan.stats.{ColumnStats, FlinkStatistic, TableStats} import org.apache.flink.table.util.TableTestBase import org.apache.calcite.plan.hep.HepMatchOrder @@ -55,19 +55,19 @@ class JoinDeriveNullFilterRuleTest extends TableTestBase { util.addTableSource("MyTable1", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING, Types.INT, Types.LONG), Array("a1", "b1", "c1", "d1", "e1"), - Some(new TableStats(1000000000, Map( + FlinkStatistic.builder().tableStats(new TableStats(1000000000, Map( "a1" -> new ColumnStats(null, 10000000L, 4.0, 4, null, null), "c1" -> new ColumnStats(null, 5000000L, 10.2, 16, null, null), "e1" -> new ColumnStats(null, 500000L, 8.0, 8, null, null) - )))) + ))).build()) util.addTableSource("MyTable2", Array[TypeInformation[_]](Types.INT, Types.LONG, Types.STRING, Types.INT, Types.LONG), Array("a2", "b2", "c2", "d2", "e2"), - Some(new TableStats(2000000000, Map( + FlinkStatistic.builder().tableStats(new TableStats(2000000000, Map( "b2" -> new ColumnStats(null, 10000000L, 8.0, 8, null, null), "c2" -> new ColumnStats(null, 3000000L, 18.6, 32, null, null), "e2" -> new ColumnStats(null, 1500000L, 8.0, 8, null, null) - )))) + ))).build()) } @Test diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.scala new file mode 100644 index 00000000000000..01893f361a1352 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/ProjectPruneAggregateCallRuleTest.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{FlinkBatchProgram, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules.ProjectFilterTransposeRule +import org.apache.calcite.tools.RuleSets + +/** + * Test for [[PruneAggregateCallRule]]#PROJECT_ON_AGGREGATE. 
+ */ +class ProjectPruneAggregateCallRuleTest extends PruneAggregateCallRuleTestBase { + + override def setup(): Unit = { + super.setup() + util.buildBatchProgram(FlinkBatchProgram.LOGICAL) + + val programs = util.getTableEnv.getConfig.getCalciteConfig.getBatchProgram.get + programs.addLast("rules", + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + AggregateReduceGroupingRule.INSTANCE, + ProjectFilterTransposeRule.INSTANCE, + PruneAggregateCallRule.PROJECT_ON_AGGREGATE) + ).build()) + + val calciteConfig = CalciteConfig.createBuilder(util.tableEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + util.tableEnv.getConfig.setCalciteConfig(calciteConfig) + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRuleTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRuleTestBase.scala new file mode 100644 index 00000000000000..1a3e9dc94cd3ae --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/rules/logical/PruneAggregateCallRuleTestBase.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.rules.logical + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Types +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.util.{BatchTableTestUtil, TableTestBase} + +import com.google.common.collect.ImmutableSet +import org.junit.{Before, Test} + +/** + * Base test class for [[PruneAggregateCallRule]]. 
+ */ +abstract class PruneAggregateCallRuleTestBase extends TableTestBase { + protected val util: BatchTableTestUtil = batchTestUtil() + + @Before + def setup(): Unit = { + util.addTableSource("T1", + Array[TypeInformation[_]](Types.INT, Types.INT, Types.STRING, Types.INT), + Array("a1", "b1", "c1", "d1"), + FlinkStatistic.builder().uniqueKeys(ImmutableSet.of(ImmutableSet.of("a1"))).build() + ) + util.addTableSource[(Int, Int, String, Long)]("T2", 'a2, 'b2, 'c2, 'd2) + } + + @Test + def testPruneRegularAggCall_WithoutFilter1(): Unit = { + val sql = + """ + |SELECT a2, b2, d2 FROM + | (SELECT a2, b2, COUNT(c2) as c2, SUM(d2) as d2 FROM T2 GROUP BY a2, b2) t + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneRegularAggCall_WithoutFilter2(): Unit = { + val sql = + """ + |SELECT b2, a2, d2 FROM + | (SELECT a2, b2, COUNT(c2) as c2, SUM(d2) as d2 FROM T2 GROUP BY a2, b2) t + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneRegularAggCall_WithoutFilter3(): Unit = { + val sql = + """ + |SELECT a2 as a, b2, d2 FROM + | (SELECT a2, b2, COUNT(c2) as c2, SUM(d2) as d2 FROM T2 GROUP BY a2, b2) t + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneRegularAggCall_WithFilter1(): Unit = { + val sql = + """ + |SELECT a2, b2, d2 FROM + | (SELECT a2, b2, COUNT(c2) as c2, SUM(d2) as d2 FROM T2 GROUP BY a2, b2) t + |WHERE d2 > 0 + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneRegularAggCall_WithFilter2(): Unit = { + val sql = + """ + |SELECT b2, a2, d2 FROM + | (SELECT a2, b2, COUNT(c2) as c2, SUM(d2) as d2 FROM T2 GROUP BY a2, b2) t + |WHERE d2 > 0 + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneAuxGroupAggCall_WithoutFilter1(): Unit = { + val sql = + """ + |SELECT a1, c1 FROM + | (SELECT a1, b1, COUNT(c1) as c1, SUM(d1) as d1 FROM T1 GROUP BY a1, b1) t + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneAuxGroupAggCall_WithoutFilter2(): Unit = { + val sql = + """ + |SELECT c1, a1 FROM + | (SELECT a1, b1, COUNT(c1) as c1, SUM(d1) as d1 FROM T1 GROUP BY a1, b1) t + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneAuxGroupAggCall_WithFilter1(): Unit = { + val sql = + """ + |SELECT a1, c1 FROM + | (SELECT a1, b1, COUNT(c1) as c1, SUM(d1) as d1 FROM T1 GROUP BY a1, b1) t + |WHERE c1 > 10 + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testPruneAuxGroupAggCall_WithFilter2(): Unit = { + val sql = + """ + |SELECT c1, a1 FROM + | (SELECT a1, b1, COUNT(c1) as c1, SUM(d1) as d1 FROM T1 GROUP BY a1, b1) t + |WHERE c1 > 10 + """.stripMargin + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithOneAggCall1(): Unit = { + val sql = "SELECT 1 FROM (SELECT SUM(a1) FROM T1) t" + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithOneAggCall2(): Unit = { + val sql = "SELECT * FROM T2 WHERE EXISTS (SELECT COUNT(*) FROM T1)" + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithOneAggCall3(): Unit = { + val sql = "SELECT * FROM T2 WHERE EXISTS (SELECT COUNT(*) FROM T1 WHERE 1=2)" + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithMoreThanOneAggCalls1(): Unit = { + val sql = "SELECT 1 FROM (SELECT SUM(a1), COUNT(*) FROM T1) t" + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithMoreThanOneAggCalls2(): Unit = { + val sql = "SELECT * FROM T2 WHERE EXISTS (SELECT SUM(a1), COUNT(*) FROM T1)" + util.verifyPlan(sql) + } + + @Test + def testEmptyGroupKey_WithMoreThanOneAggCalls3(): Unit = { + val sql = "SELECT * FROM T2 WHERE EXISTS 
(SELECT SUM(a1), COUNT(*) FROM T1 WHERE 1=2)" + util.verifyPlan(sql) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala index bfeb27a7e3a012..c732b13a2051eb 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/CalcITCase.scala @@ -45,9 +45,9 @@ class CalcITCase extends BatchTestBase { @Before def before(): Unit = { - registerCollection("Table3", data3, type3, nullablesOfData3, "a, b, c") - registerCollection("NullTable3", nullData3, type3, nullablesOfData3, "a, b, c") - registerCollection("SmallTable3", smallData3, type3, nullablesOfData3, "a, b, c") + registerCollection("Table3", data3, type3, "a, b, c", nullablesOfData3) + registerCollection("NullTable3", nullData3, type3, "a, b, c", nullablesOfData3) + registerCollection("SmallTable3", smallData3, type3, "a, b, c", nullablesOfData3) registerCollection("testTable", buildInData, buildInType, "a,b,c,d,e,f,g,h,i,j") } @@ -99,8 +99,9 @@ class CalcITCase extends BatchTestBase { def testManySelect(): Unit = { registerCollection( "ProjectionTestTable", - projectionTestData, projectionTestDataType, nullablesOfProjectionTestData, - "a, b, c, d, e, f, g, h") + projectionTestData, projectionTestDataType, + "a, b, c, d, e, f, g, h", + nullablesOfProjectionTestData) checkResult( """ |SELECT @@ -658,7 +659,7 @@ class CalcITCase extends BatchTestBase { def testValueConstructor(): Unit = { val data = Seq(row("foo", 12, UTCTimestamp("1984-07-12 14:34:24"))) val tpe = new RowTypeInfo(STRING_TYPE_INFO, INT_TYPE_INFO, TIMESTAMP) - registerCollection("MyTable", data, tpe, Array(false, false, false), "a, b, c") + registerCollection("MyTable", data, tpe, "a, b, c" , Array(false, false, false)) val table = parseQuery("SELECT ROW(a, b, c), ARRAY[12, b], MAP[a, c] FROM MyTable " + "WHERE (a, b, c) = ('foo', 12, TIMESTAMP '1984-07-12 14:34:24')") diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/LimitITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/LimitITCase.scala index ef7ce80c28a403..22b2c5e72aa72a 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/LimitITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/LimitITCase.scala @@ -29,7 +29,7 @@ class LimitITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("Table3", data3, type3, nullablesOfData3, "a, b, c") + registerCollection("Table3", data3, type3, "a, b, c", nullablesOfData3) // TODO support LimitableTableSource // val rowType = new RowTypeInfo(type3.getFieldTypes, Array("a", "b", "c")) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/OverWindowITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/OverWindowITCase.scala index 7ecea2fd60f4d1..5ca852f326776e 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/OverWindowITCase.scala +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/OverWindowITCase.scala @@ -42,11 +42,11 @@ class OverWindowITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("Table5", data5, type5, nullablesOfData5, "d, e, f, g, h") + registerCollection("Table5", data5, type5, "d, e, f, g, h", nullablesOfData5) registerCollection("ShuflledTable5", - Random.shuffle(data5), type5, nullablesOfData5, "sd, se, sf, sg, sh") - registerCollection("Table6", data6, type6, nullablesOfData6, "a, b, c, d, e, f") - registerCollection("NullTable5", nullData5, type5, nullablesOfNullData5, "d, e, f, g, h") + Random.shuffle(data5), type5, "sd, se, sf, sg, sh", nullablesOfData5) + registerCollection("Table6", data6, type6, "a, b, c, d, e, f", nullablesOfData6) + registerCollection("NullTable5", nullData5, type5, "d, e, f, g, h", nullablesOfNullData5) } @Test @@ -1606,7 +1606,7 @@ class OverWindowITCase extends BatchTestBase { row(null, 3L, 3, "NullTuple", 3L), row(null, 3L, 3, "NullTuple", 3L) ) - registerCollection("NullTable", nullData, type5, nullablesOfNullData5, "d, e, f, g, h") + registerCollection("NullTable", nullData, type5, "d, e, f, g, h", nullablesOfNullData5) checkResult( "SELECT h, d, count(*) over (partition by h order by d range between 0 PRECEDING and " + @@ -1826,7 +1826,7 @@ class OverWindowITCase extends BatchTestBase { row(null, 3L, 3, "NullTuple", 3L), row(null, 3L, 3, "NullTuple", 3L) ) - registerCollection("NullTable", nullData, type5, nullablesOfNullData5, "d, e, f, g, h") + registerCollection("NullTable", nullData, type5, "d, e, f, g, h", nullablesOfNullData5) checkResult( "SELECT h, d, count(*) over (partition by h order by d range between 1 PRECEDING and 2 " + "FOLLOWING) FROM NullTable", diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/RankITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/RankITCase.scala index 2d529f4eaa77bc..bed1f387c30341 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/RankITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/RankITCase.scala @@ -32,8 +32,8 @@ class RankITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("Table3", data3, type3, nullablesOfData3, "a, b, c") - registerCollection("Table5", data5, type5, nullablesOfData5, "a, b, c, d, e") + registerCollection("Table3", data3, type3, "a, b, c", nullablesOfData3) + registerCollection("Table5", data5, type5, "a, b, c, d, e", nullablesOfData5) registerCollection("Table2", data2_1, INT_DOUBLE, "a, b") } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/UnionITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/UnionITCase.scala index 874a13254e523f..00f9fdecffc4a5 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/UnionITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/UnionITCase.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.runtime.batch.sql -import 
org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO} import org.apache.flink.table.`type`.InternalTypes import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions} import org.apache.flink.table.dataformat.BinaryString.fromString @@ -46,9 +45,9 @@ class UnionITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("Table3", smallData3, type3, nullablesOfSmallData3, "a, b, c") - registerCollection("Table5", data5, type5, nullablesOfData5, "d, e, f, g, h") - registerCollection("Table6", data6, type6, Array(false, false, false), "a, b, c") + registerCollection("Table3", smallData3, type3, "a, b, c", nullablesOfSmallData3) + registerCollection("Table5", data5, type5, "d, e, f, g, h", nullablesOfData5) + registerCollection("Table6", data6, type6, "a, b, c", Array(false, false, false)) tEnv.getConfig.getConf.setString( TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateJoinTransposeITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateJoinTransposeITCase.scala new file mode 100644 index 00000000000000..14abd2c2856b9d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateJoinTransposeITCase.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.batch.sql.agg + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.{TableConfigOptions, TableException, Types} +import org.apache.flink.table.calcite.CalciteConfig +import org.apache.flink.table.plan.optimize.program.{BatchOptimizeContext, FlinkBatchProgram, FlinkGroupProgramBuilder, FlinkHepRuleSetProgramBuilder, HEP_RULES_EXECUTION_TYPE} +import org.apache.flink.table.plan.rules.logical.{AggregateReduceGroupingRule, FlinkAggregateJoinTransposeRule} +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.TestData._ + +import org.apache.calcite.plan.hep.HepMatchOrder +import org.apache.calcite.rel.rules._ +import org.apache.calcite.tools.RuleSets +import org.junit.{Before, Test} + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +class AggregateJoinTransposeITCase extends BatchTestBase { + + @Before + def before(): Unit = { + val programs = FlinkBatchProgram.buildProgram(tEnv.getConfig.getConf) + // remove FlinkAggregateJoinTransposeRule from logical program (volcano planner) + programs.getFlinkRuleSetProgram(FlinkBatchProgram.LOGICAL) + .getOrElse(throw new TableException(s"${FlinkBatchProgram.LOGICAL} does not exist")) + .remove(RuleSets.ofList(FlinkAggregateJoinTransposeRule.EXTENDED)) + + // add FlinkAggregateJoinTransposeRule to hep program + // to make sure that the aggregate is pushed down + programs.addBefore( + FlinkBatchProgram.LOGICAL, + "FlinkAggregateJoinTransposeRule", + FlinkGroupProgramBuilder.newBuilder[BatchOptimizeContext] + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + AggregateReduceGroupingRule.INSTANCE + )).build(), "reduce useless grouping") + .addProgram( + FlinkHepRuleSetProgramBuilder.newBuilder + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(RuleSets.ofList( + AggregateReduceGroupingRule.INSTANCE, + AggregateProjectMergeRule.INSTANCE, + FlinkAggregateJoinTransposeRule.EXTENDED + )).build(), "aggregate join transpose") + .build() + ) + val calciteConfig = CalciteConfig.createBuilder(tEnv.getConfig.getCalciteConfig) + .replaceBatchProgram(programs).build() + tEnv.getConfig.setCalciteConfig(calciteConfig) + + tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) + // HashJoin is disabled because its translateToPlanInternal method is not implemented yet + tEnv.getConfig.getConf.setString(TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin") + registerCollection("T3", data3, type3, "a, b, c", nullablesOfData3) + + registerCollection("MyTable", + Seq(row(1, 1L, "X"), + row(1, 2L, "Y"), + row(2, 3L, null), + row(2, 4L, "Z")), + new RowTypeInfo(Types.INT, Types.LONG, Types.STRING), + "a2, b2, c2", + Array(true, true, true), + FlinkStatistic.builder().uniqueKeys(Set(Set("b2").asJava).asJava).build() + ) + } + + @Test + def testPushCountAggThroughJoinOverUniqueColumn(): Unit = { + checkResult( + "SELECT COUNT(A.a) FROM (SELECT DISTINCT a FROM T3) AS A JOIN T3 AS B ON A.a=B.a", + Seq(row(21)) + ) + } + + @Test + def testPushSumAggThroughJoinOverUniqueColumn(): Unit = { + checkResult( + "SELECT SUM(A.a) FROM (SELECT DISTINCT a FROM T3) 
AS A JOIN T3 AS B ON A.a=B.a", + Seq(row(231)) + ) + } + + @Test + def testSomeAggCallColumnsAndJoinConditionColumnsIsSame(): Unit = { + checkResult( + "SELECT MIN(a2), MIN(b2), a, b, COUNT(c2) FROM " + + "(SELECT * FROM MyTable, T3 WHERE b2 = b) t GROUP BY b, a", + Seq(row(1, 1, 1, 1, 1), row(1, 2, 2, 2, 1), row(1, 2, 3, 2, 1), + row(2, 3, 4, 3, 0), row(2, 3, 5, 3, 0), row(2, 3, 6, 3, 0), + row(2, 4, 10, 4, 1), row(2, 4, 7, 4, 1), row(2, 4, 8, 4, 1), row(2, 4, 9, 4, 1)) + ) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsUnique1(): Unit = { + checkResult( + """ + |select a2, b2, c2, SUM(a) FROM ( + | SELECT * FROM MyTable, T3 WHERE b2 = b + |) GROUP BY a2, b2, c2 + """.stripMargin, + Seq(row(1, 1, "X", 1), row(1, 2, "Y", 5), row(2, 3, null, 15), row(2, 4, "Z", 34))) + + checkResult( + """ + |select a2, b2, c2, SUM(a), COUNT(c) FROM ( + | SELECT * FROM MyTable, T3 WHERE b2 = b + |) GROUP BY a2, b2, c2 + """.stripMargin, + Seq(row(1, 1, "X", 1, 1), row(1, 2, "Y", 5, 2), + row(2, 3, null, 15, 3), row(2, 4, "Z", 34, 4))) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsUnique2(): Unit = { + checkResult( + """ + |select a2, b2, c, SUM(a) FROM ( + | SELECT * FROM MyTable, T3 WHERE b2 = b + |) GROUP BY a2, b2, c + """.stripMargin, + Seq(row(1, 1, "Hi", 1), row(1, 2, "Hello world", 3), row(1, 2, "Hello", 2), + row(2, 3, "Hello world, how are you?", 4), row(2, 3, "I am fine.", 5), + row(2, 3, "Luke Skywalker", 6), row(2, 4, "Comment#1", 7), row(2, 4, "Comment#2", 8), + row(2, 4, "Comment#3", 9), row(2, 4, "Comment#4", 10))) + + checkResult( + """ + |select a2, b2, c, SUM(a), MAX(b) FROM ( + | SELECT * FROM MyTable, T3 WHERE b2 = b + |) GROUP BY a2, b2, c + """.stripMargin, + Seq(row(1, 1, "Hi", 1, 1), row(1, 2, "Hello world", 3, 2), row(1, 2, "Hello", 2, 2), + row(2, 3, "Hello world, how are you?", 4, 3), row(2, 3, "I am fine.", 5, 3), + row(2, 3, "Luke Skywalker", 6, 3), row(2, 4, "Comment#1", 7, 4), + row(2, 4, "Comment#2", 8, 4), row(2, 4, "Comment#3", 9, 4), row(2, 4, "Comment#4", 10, 4))) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsNotUnique1(): Unit = { + checkResult( + """ + |select a2, b2, c2, SUM(a) FROM ( + | SELECT * FROM MyTable, T3 WHERE a2 = a + |) GROUP BY a2, b2, c2 + """.stripMargin, + Seq(row(1, 1, "X", 1), row(1, 2, "Y", 1), row(2, 3, null, 2), row(2, 4, "Z", 2))) + + checkResult( + """ + |select a2, b2, c2, SUM(a), COUNT(c) FROM ( + | SELECT * FROM MyTable, T3 WHERE a2 = a + |) GROUP BY a2, b2, c2 + """.stripMargin, + Seq(row(1, 1, "X", 1, 1), row(1, 2, "Y", 1, 1), row(2, 3, null, 2, 1), row(2, 4, "Z", 2, 1))) + } + + @Test + def testAggregateWithAuxGroup_JoinKeyIsNotUnique2(): Unit = { + checkResult( + """ + |select a2, b2, c, SUM(a) FROM ( + | SELECT * FROM MyTable, T3 WHERE a2 = a + |) GROUP BY a2, b2, c + """.stripMargin, + Seq(row(1, 1, "Hi", 1), row(1, 2, "Hi", 1), row(2, 3, "Hello", 2), row(2, 4, "Hello", 2))) + + checkResult( + """ + |select a2, b2, c, SUM(a), MIN(b) FROM ( + | SELECT * FROM MyTable, T3 WHERE a2 = a + |) GROUP BY a2, b2, c + """.stripMargin, + Seq(row(1, 1, "Hi", 1, 1), row(1, 2, "Hi", 1, 1), + row(2, 3, "Hello", 2, 2), row(2, 4, "Hello", 2, 2))) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateReduceGroupingITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateReduceGroupingITCase.scala new file mode 100644 index 00000000000000..670a6d1589b416 --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateReduceGroupingITCase.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.batch.sql.agg + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.{PlannerConfigOptions, TableConfigOptions, Types} +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.util.DateTimeTestUtil.UTCTimestamp + +import org.junit.{Before, Test} + +import java.sql.Date + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +class AggregateReduceGroupingITCase extends BatchTestBase { + + @Before + def before(): Unit = { + registerCollection("T1", + Seq(row(2, 1, "A", null), + row(3, 2, "A", "Hi"), + row(5, 2, "B", "Hello"), + row(6, 3, "C", "Hello world")), + new RowTypeInfo(Types.INT, Types.INT, Types.STRING, Types.STRING), + "a1, b1, c1, d1", + Array(true, true, true, true), + FlinkStatistic.builder().uniqueKeys(Set(Set("a1").asJava).asJava).build() + ) + + registerCollection("T2", + Seq(row(1, 1, "X"), + row(1, 2, "Y"), + row(2, 3, null), + row(2, 4, "Z")), + new RowTypeInfo(Types.INT, Types.INT, Types.STRING), + "a2, b2, c2", + Array(true, true, true), + FlinkStatistic.builder() + .uniqueKeys(Set(Set("b2").asJava, Set("a2", "b2").asJava).asJava).build() + ) + + registerCollection("T3", + Seq(row(1, 10, "Hi", 1L), + row(2, 20, "Hello", 1L), + row(2, 20, "Hello world", 2L), + row(3, 10, "Hello world, how are you?", 1L), + row(4, 20, "I am fine.", 2L), + row(4, null, "Luke Skywalker", 2L)), + new RowTypeInfo(Types.INT, Types.INT, Types.STRING, Types.LONG), + "a3, b3, c3, d3", + Array(true, true, true, true), + FlinkStatistic.builder().uniqueKeys(Set(Set("a1").asJava).asJava).build() + ) + + registerCollection("T4", + Seq(row(1, 1, "A", UTCTimestamp("2018-06-01 10:05:30"), "Hi"), + row(2, 1, "B", UTCTimestamp("2018-06-01 10:10:10"), "Hello"), + row(3, 2, "B", UTCTimestamp("2018-06-01 10:15:25"), "Hello world"), + row(4, 3, "C", UTCTimestamp("2018-06-01 10:36:49"), "I am fine.")), + new RowTypeInfo(Types.INT, Types.INT, Types.STRING, Types.SQL_TIMESTAMP, Types.STRING), + "a4, b4, c4, d4, e4", + Array(true, true, true, true, true), + FlinkStatistic.builder().uniqueKeys(Set(Set("a4").asJava).asJava).build() + ) + + registerCollection("T5", + Seq(row(2, 1, "A", null), + row(3, 2, "B", "Hi"), + row(1, null, "C", "Hello"), + row(4, 3, "D", "Hello world"), + row(3, 1, "E", "Hello world, how are you?"), + row(5, null, "F", null), + row(7, 2, "I", "hahaha"), + row(6, 1, "J", "I am fine.")), + new 
RowTypeInfo(Types.INT, Types.INT, Types.STRING, Types.STRING), +      "a5, b5, c5, d5", +      Array(true, true, true, true), +      FlinkStatistic.builder().uniqueKeys(Set(Set("c5").asJava).asJava).build() +    ) + +    registerCollection("T6", +      (0 until 50000).map( +        i => row(i, 1L, if (i % 500 == 0) null else s"Hello$i", "Hello world", 10, +          new Date(i + 1531820000000L))), +      new RowTypeInfo(Types.INT, Types.LONG, Types.STRING, Types.STRING, Types.INT, Types.SQL_DATE), +      "a6, b6, c6, d6, e6, f6", +      Array(true, true, true, true, true, true), +      FlinkStatistic.builder().uniqueKeys(Set(Set("a6").asJava).asJava).build() +    ) + // HashJoin is disabled because its translateToPlanInternal method is not implemented yet + tEnv.getConfig.getConf.setString(TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin") + } + + @Test + def testSingleAggOnTable_SortAgg(): Unit = { + tEnv.getConfig.getConf.setString(TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashAgg") + testSingleAggOnTable() + checkResult("SELECT a6, b6, max(c6), count(d6), sum(e6) FROM T6 GROUP BY a6, b6", + (0 until 50000).map(i => row(i, 1L, if (i % 500 == 0) null else s"Hello$i", 1L, 10)) + ) + } + + @Test + def testSingleAggOnTable_HashAgg_WithLocalAgg(): Unit = { + tEnv.getConfig.getConf.setString(TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + tEnv.getConfig.getConf.setString( + PlannerConfigOptions.SQL_OPTIMIZER_AGG_PHASE_ENFORCER, "TWO_PHASE") + tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_HASH_AGG_TABLE_MEM, 2) // 1M + testSingleAggOnTable() + } + + @Test + def testSingleAggOnTable_HashAgg_WithoutLocalAgg(): Unit = { + tEnv.getConfig.getConf.setString(TableConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "SortAgg") + tEnv.getConfig.getConf.setString( + PlannerConfigOptions.SQL_OPTIMIZER_AGG_PHASE_ENFORCER, "ONE_PHASE") + tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_HASH_AGG_TABLE_MEM, 2) // 1M + testSingleAggOnTable() + } + + private def testSingleAggOnTable(): Unit = { + // group by fixed-length fields + checkResult("SELECT a1, b1, count(c1) FROM T1 GROUP BY a1, b1", + Seq(row(2, 1, 1), row(3, 2, 1), row(5, 2, 1), row(6, 3, 1))) + // group by string + checkResult("SELECT a1, c1, count(d1), avg(b1) FROM T1 GROUP BY a1, c1", + Seq(row(2, "A", 0, 1.0), row(3, "A", 1, 2.0), row(5, "B", 1, 2.0), row(6, "C", 1, 3.0))) + checkResult("SELECT c5, d5, avg(b5), avg(a5) FROM T5 WHERE d5 IS NOT NULL GROUP BY c5, d5", + Seq(row("B", "Hi", 2.0, 3.0), row("C", "Hello", null, 1.0), + row("D", "Hello world", 3.0, 4.0), row("E", "Hello world, how are you?", 1.0, 3.0), + row("I", "hahaha", 2.0, 7.0), row("J", "I am fine.", 1.0, 6.0))) + // group by string with null + checkResult("SELECT a1, d1, count(d1) FROM T1 GROUP BY a1, d1", + Seq(row(2, null, 0), row(3, "Hi", 1), row(5, "Hello", 1), row(6, "Hello world", 1))) + checkResult("SELECT c5, d5, avg(b5), avg(a5) FROM T5 GROUP BY c5, d5", + Seq(row("A", null, 1.0, 2.0), row("B", "Hi", 2.0, 3.0), row("C", "Hello", null, 1.0), + row("D", "Hello world", 3.0, 4.0), row("E", "Hello world, how are you?", 1.0, 3.0), + row("F", null, null, 5.0), row("I", "hahaha", 2.0, 7.0), row("J", "I am fine.", 1.0, 6.0))) + + checkResult("SELECT a3, b3, count(c3) FROM T3 GROUP BY a3, b3", + Seq(row(1, 10, 1), row(2, 20, 2), row(3, 10, 1), row(4, 20, 1), row(4, null, 1))) + checkResult("SELECT a2, b2, count(c2) FROM T2 GROUP BY a2, b2", + Seq(row(1, 1, 1), row(1, 2, 1), row(2, 3, 0), row(2, 4, 1))) + + // group by constants + checkResult("SELECT a1, b1, count(c1) FROM T1 GROUP BY a1, b1, 1, true", + 
Seq(row(2, 1, 1), row(3, 2, 1), row(5, 2, 1), row(6, 3, 1))) + checkResult("SELECT count(c1) FROM T1 GROUP BY 1, true", Seq(row(4))) + + // large data, for hash agg mode it will fallback + checkResult("SELECT a6, c6, avg(b6), count(d6), avg(e6) FROM T6 GROUP BY a6, c6", + (0 until 50000).map(i => row(i, if (i % 500 == 0) null else s"Hello$i", 1D, 1L, 10D)) + ) + checkResult("SELECT a6, d6, avg(b6), count(c6), avg(e6) FROM T6 GROUP BY a6, d6", + (0 until 50000).map(i => row(i, "Hello world", 1D, if (i % 500 == 0) 0L else 1L, 10D)) + ) + checkResult("SELECT a6, f6, avg(b6), count(c6), avg(e6) FROM T6 GROUP BY a6, f6", + (0 until 50000).map(i => row(i, new Date(i + 1531820000000L), 1D, + if (i % 500 == 0) 0L else 1L, 10D)) + ) + } + + @Test + def testMultiAggs(): Unit = { + checkResult("SELECT a1, b1, c1, d1, m, COUNT(*) FROM " + + "(SELECT a1, b1, c1, COUNT(d1) AS d1, MAX(d1) AS m FROM T1 GROUP BY a1, b1, c1) t " + + "GROUP BY a1, b1, c1, d1, m", + Seq(row(2, 1, "A", 0, null, 1), row(3, 2, "A", 1, "Hi", 1), + row(5, 2, "B", 1, "Hello", 1), row(6, 3, "C", 1, "Hello world", 1))) + + checkResult("SELECT a3, b3, c, s, COUNT(*) FROM " + + "(SELECT a3, b3, COUNT(d3) AS c, SUM(d3) AS s, MAX(d3) AS m FROM T3 GROUP BY a3, b3) t " + + "GROUP BY a3, b3, c, s", + Seq(row(1, 10, 1, 1, 1), row(2, 20, 2, 3, 1), row(3, 10, 1, 1, 1), + row(4, 20, 1, 2, 1), row(4, null, 1, 2, 1))) + } + + @Test + def testAggOnInnerJoin(): Unit = { + checkResult("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1, T2 WHERE a1 = b2) t GROUP BY a1, b1, a2, b2", + Seq(row(2, 1, 1, 2, 1), row(3, 2, 2, 3, 1))) + + checkResult("SELECT a2, b2, a3, b3, COUNT(c2) FROM " + + "(SELECT * FROM T2, T3 WHERE b2 = a3) t GROUP BY a2, b2, a3, b3", + Seq(row(1, 1, 1, 10, 1), row(1, 2, 2, 20, 2), row(2, 3, 3, 10, 0), + row(2, 4, 4, 20, 1), row(2, 4, 4, null, 1))) + + checkResult("SELECT a1, b1, a2, b2, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1, T2, T3 WHERE a1 = b2 AND a1 = a3) t GROUP BY a1, b1, a2, b2, a3, b3", + Seq(row(2, 1, 1, 2, 2, 20, 2), row(3, 2, 2, 3, 3, 10, 1))) + } + + @Test + def testAggOnLeftJoin(): Unit = { + checkResult("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 LEFT JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2", + Seq(row(2, 1, 1, 2, 1), row(3, 2, 2, 3, 1), + row(5, 2, null, null, 1), row(6, 3, null, null, 1))) + + checkResult("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 LEFT JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3", + Seq(row(2, 1, 2, 20, 2), row(3, 2, 3, 10, 1), + row(5, 2, null, null, 1), row(6, 3, null, null, 1))) + + checkResult("SELECT a3, b3, a1, b1, COUNT(c1) FROM " + + "(SELECT * FROM T3 LEFT JOIN T1 ON a1 = a3) t GROUP BY a3, b3, a1, b1", + Seq(row(1, 10, null, null, 0), row(2, 20, 2, 1, 2), row(3, 10, 3, 2, 1), + row(4, 20, null, null, 0), row(4, null, null, null, 0))) + } + + @Test + def testAggOnRightJoin(): Unit = { + checkResult("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 RIGHT JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2", + Seq(row(2, 1, 1, 2, 1), row(3, 2, 2, 3, 1), + row(null, null, 1, 1, 0), row(null, null, 2, 4, 0))) + + checkResult("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 RIGHT JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3", + Seq(row(2, 1, 2, 20, 2), row(3, 2, 3, 10, 1), row(null, null, 1, 10, 0), + row(null, null, 4, 20, 0), row(null, null, 4, null, 0))) + + checkResult("SELECT a3, b3, a1, b1, COUNT(c1) FROM " + + "(SELECT * FROM T3 RIGHT JOIN T1 ON a1 = a3) t GROUP BY a3, b3, a1, b1", + Seq(row(2, 
20, 2, 1, 2), row(3, 10, 3, 2, 1), + row(null, null, 5, 2, 1), row(null, null, 6, 3, 1))) + } + + @Test + def testAggOnFullJoin(): Unit = { + checkResult("SELECT a1, b1, a2, b2, COUNT(c1) FROM " + + "(SELECT * FROM T1 FULL OUTER JOIN T2 ON a1 = b2) t GROUP BY a1, b1, a2, b2", + Seq(row(2, 1, 1, 2, 1), row(3, 2, 2, 3, 1), row(5, 2, null, null, 1), + row(6, 3, null, null, 1), row(null, null, 1, 1, 0), row(null, null, 2, 4, 0))) + + checkResult("SELECT a1, b1, a3, b3, COUNT(c1) FROM " + + "(SELECT * FROM T1 FULL OUTER JOIN T3 ON a1 = a3) t GROUP BY a1, b1, a3, b3", + Seq(row(2, 1, 2, 20, 2), row(3, 2, 3, 10, 1), row(5, 2, null, null, 1), + row(6, 3, null, null, 1), row(null, null, 1, 10, 0), row(null, null, 4, 20, 0), + row(null, null, 4, null, 0))) + } + + @Test + def testAggOnOver(): Unit = { + checkResult("SELECT a1, b1, c, COUNT(d1) FROM " + + "(SELECT a1, b1, d1, COUNT(*) OVER (PARTITION BY c1) AS c FROM T1) t GROUP BY a1, b1, c", + Seq(row(2, 1, 2, 0), row(3, 2, 2, 1), row(5, 2, 1, 1), row(6, 3, 1, 1))) + } + + @Test + def testAggOnWindow(): Unit = { + checkResult("SELECT a4, b4, COUNT(c4) FROM T4 " + + "GROUP BY a4, b4, TUMBLE(d4, INTERVAL '15' MINUTE)", + Seq(row(1, 1, 1), row(2, 1, 1), row(3, 2, 1), row(4, 3, 1))) + + checkResult("SELECT a4, c4, COUNT(b4), AVG(b4) FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)", + Seq(row(1, "A", 1, 1.0), row(2, "B", 1, 1.0), row(3, "B", 1, 2.0), row(4, "C", 1, 3.0))) + + checkResult("SELECT a4, e4, s, avg(ab), count(cb) FROM " + + "(SELECT a4, e4, avg(b4) as ab, count(b4) AS cb, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, e4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, e4, s", + Seq(row(1, "Hi", UTCTimestamp("2018-06-01 10:00:00.0"), 1D, 1), + row(2, "Hello", UTCTimestamp("2018-06-01 10:00:00.0"), 1D, 1), + row(3, "Hello world", UTCTimestamp("2018-06-01 10:15:00.0"), 2D, 1), + row(4, "I am fine.", UTCTimestamp("2018-06-01 10:30:00.0"), 3D, 1))) + + checkResult("SELECT a4, c4, s, COUNT(b4) FROM " + + "(SELECT a4, c4, avg(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4, s", + Seq(row(1, "A", UTCTimestamp("2018-06-01 10:00:00.0"), 1), + row(2, "B", UTCTimestamp("2018-06-01 10:00:00.0"), 1), + row(3, "B", UTCTimestamp("2018-06-01 10:15:00.0"), 1), + row(4, "C", UTCTimestamp("2018-06-01 10:30:00.0"), 1))) + + checkResult("SELECT a4, c4, e, COUNT(b4) FROM " + + "(SELECT a4, c4, VAR_POP(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4, e", + Seq(row(1, "A", UTCTimestamp("2018-06-01 10:15:00.0"), 1), + row(2, "B", UTCTimestamp("2018-06-01 10:15:00.0"), 1), + row(3, "B", UTCTimestamp("2018-06-01 10:30:00.0"), 1), + row(4, "C", UTCTimestamp("2018-06-01 10:45:00.0"), 1))) + + checkResult("SELECT a4, b4, c4, COUNT(*) FROM " + + "(SELECT a4, c4, SUM(b4) AS b4, " + + "TUMBLE_START(d4, INTERVAL '15' MINUTE) AS s, " + + "TUMBLE_END(d4, INTERVAL '15' MINUTE) AS e FROM T4 " + + "GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, b4, c4", + Seq(row(1, 1, "A", 1), row(2, 1, "B", 1), row(3, 2, "B", 1), row(4, 3, "C", 1))) + } + + @Test + def testAggWithGroupingSets(): Unit = { + checkResult("SELECT a1, b1, c1, COUNT(d1) FROM T1 " + + "GROUP BY GROUPING SETS 
((a1, b1), (a1, c1))", + Seq(row(2, 1, null, 0), row(2, null, "A", 0), row(3, 2, null, 1), + row(3, null, "A", 1), row(5, 2, null, 1), row(5, null, "B", 1), + row(6, 3, null, 1), row(6, null, "C", 1))) + + checkResult("SELECT a1, c1, COUNT(d1) FROM T1 " + + "GROUP BY GROUPING SETS ((a1, c1), (a1), ())", + Seq(row(2, "A", 0), row(2, null, 0), row(3, "A", 1), row(3, null, 1), row(5, "B", 1), + row(5, null, 1), row(6, "C", 1), row(6, null, 1), row(null, null, 3))) + + checkResult("SELECT a1, b1, c1, COUNT(d1) FROM T1 " + + "GROUP BY GROUPING SETS ((a1, b1, c1), (a1, b1, d1))", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(3, 2, "A", 1), row(3, 2, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(6, 3, "C", 1), row(6, 3, null, 1))) + } + + @Test + def testAggWithRollup(): Unit = { + checkResult("SELECT a1, b1, c1, COUNT(d1) FROM T1 GROUP BY ROLLUP (a1, b1, c1)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, null, 0), row(3, 2, "A", 1), + row(3, 2, null, 1), row(3, null, null, 1), row(5, 2, "B", 1), row(5, 2, null, 1), + row(5, null, null, 1), row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, null, 1), + row(null, null, null, 3))) + } + + @Test + def testAggWithCube(): Unit = { + checkResult("SELECT a1, b1, c1, COUNT(d1) FROM T1 GROUP BY CUBE (a1, b1, c1)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, "A", 0), row(2, null, null, 0), + row(3, 2, "A", 1), row(3, 2, null, 1), row(3, null, "A", 1), row(3, null, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(5, null, "B", 1), row(5, null, null, 1), + row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, "C", 1), row(6, null, null, 1), + row(null, 1, "A", 0), row(null, 1, null, 0), row(null, 2, "A", 1), row(null, 2, "B", 1), + row(null, 2, null, 2), row(null, 3, "C", 1), row(null, 3, null, 1), row(null, null, "A", 1), + row(null, null, "B", 1), row(null, null, "C", 1), row(null, null, null, 3))) + } + + @Test + def testSingleDistinctAgg(): Unit = { + checkResult("SELECT a1, COUNT(DISTINCT c1) FROM T1 GROUP BY a1", + Seq(row(2, 1), row(3, 1), row(5, 1), row(6, 1))) + + checkResult("SELECT a1, b1, COUNT(DISTINCT c1) FROM T1 GROUP BY a1, b1", + Seq(row(2, 1, 1), row(3, 2, 1), row(5, 2, 1), row(6, 3, 1))) + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg(): Unit = { + checkResult("SELECT a1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a1, c1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1, c1", + Seq(row(2, "A", 1, 1), row(3, "A", 1, 2), row(5, "B", 1, 2), row(6, "C", 1, 3))) + + checkResult("SELECT a1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a1, d1, COUNT(DISTINCT c1), SUM(b1) FROM T1 GROUP BY a1, d1", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 1, 2), row(6, "Hello world", 1, 3))) + } + + @Test + def testMultiDistinctAggs(): Unit = { + checkResult("SELECT a1, COUNT(DISTINCT b1), SUM(DISTINCT b1) FROM T1 GROUP BY a1", Seq(row(2, + 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a1, d1, COUNT(DISTINCT c1), SUM(DISTINCT b1) FROM T1 GROUP BY a1, d1", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 1, 2), row(6, "Hello world", 1, 3))) + + checkResult( + "SELECT a1, SUM(DISTINCT b1), MAX(DISTINCT b1), MIN(DISTINCT c1) FROM T1 GROUP BY a1", + Seq(row(2, 1, 1, "A"), row(3, 2, 2, "A"), row(5, 2, 2, "B"), row(6, 3, 3, "C"))) + + checkResult( + "SELECT a1, d1, 
COUNT(DISTINCT c1), MAX(DISTINCT b1), SUM(b1) FROM T1 GROUP BY a1, d1", +      Seq(row(2, null, 1, 1, 1), row(3, "Hi", 1, 2, 2), +        row(5, "Hello", 1, 2, 2), row(6, "Hello world", 1, 3, 3))) + +    checkResult("SELECT a1, b1, COUNT(DISTINCT c1), COUNT(DISTINCT d1) FROM T1 GROUP BY a1, b1", +      Seq(row(2, 1, 1, 0), row(3, 2, 1, 1), row(5, 2, 1, 1), row(6, 3, 1, 1))) +  } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateRemoveITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateRemoveITCase.scala new file mode 100644 index 00000000000000..929f49b1fa89b0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/AggregateRemoveITCase.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.batch.sql.agg + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.Types +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.TestData._ + +import org.junit.{Before, Test} + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +class AggregateRemoveITCase extends BatchTestBase { + + @Before + def before(): Unit = { + registerCollection("T1", + Seq(row(2, 1, "A", null), + row(3, 2, "A", "Hi"), + row(5, 2, "B", "Hello"), + row(6, 3, "C", "Hello world")), + new RowTypeInfo(Types.INT, Types.INT, Types.STRING, Types.STRING), + "a, b, c, d", + Array(true, true, true, true), + FlinkStatistic.builder().uniqueKeys(Set(Set("a").asJava).asJava).build() + ) + + registerCollection("T2", smallData3, type3, "a, b, c", nullablesOfSmallData3, + FlinkStatistic.builder().uniqueKeys(Set(Set("a").asJava).asJava).build()) + registerCollection("T3", smallData5, type5, "a, b, c, d, e", nullablesOfSmallData5, + FlinkStatistic.builder().uniqueKeys(Set(Set("b").asJava).asJava).build()) + } + + @Test + def testSimple(): Unit = { + checkResult("SELECT a, b FROM T3 GROUP BY a, b", + Seq(row(1, 1), row(2, 2), row(2, 3))) + + checkResult("SELECT a, b + 1, c, s FROM (" + + "SELECT a, MIN(b) AS b, SUM(b) AS s, MAX(c) AS c FROM T3 GROUP BY a)", + Seq(row(1, 2, 0, 1), row(2, 3, 2, 5))) + + checkResult("SELECT a, SUM(b) AS s FROM T3 GROUP BY a", + Seq(row(1, 1), row(2, 5))) + + checkResult("SELECT MAX(a), SUM(b), MIN(c) FROM (VALUES (1, 2, 3)) T(a, b, c)", + Seq(row(1, 2, 3))) + + checkResult( + "SELECT a, b + 1, c, s FROM (" + + "SELECT a, MIN(b) AS b, SUM(b) AS s, MAX(c) AS c FROM T2 GROUP BY a)", + Seq( + row(1, 2L, "Hi", 
1L), + row(2, 3L, "Hello", 2L), + row(3, 3L, "Hello world", 2L) + )) + + checkResult( + "SELECT MAX(a), SUM(b), MIN(c) FROM (VALUES (1, 2, 3)) T(a, b, c)", + Seq(row(1, 2, 3)) + ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM T2 WHERE EXISTS (SELECT SUM(a) FROM T3 WHERE 1=2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + checkResult( + """ + |SELECT a, SUM(b), MAX(b) FROM + | (SELECT a, MAX(b) AS b FROM + | (VALUES (cast(null as BIGINT), cast(null as BIGINT))) T(a, b) GROUP BY a) t + | GROUP BY a + """.stripMargin, + Seq(row(null, null, null)) + ) + } + + @Test + def testWithGroupingSets(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T1 " + + "GROUP BY GROUPING SETS ((a, b), (a, c))", + Seq(row(2, 1, null, 0), row(2, null, "A", 0), row(3, 2, null, 1), + row(3, null, "A", 1), row(5, 2, null, 1), row(5, null, "B", 1), + row(6, 3, null, 1), row(6, null, "C", 1))) + + checkResult("SELECT a, c, COUNT(d) FROM T1 " + + "GROUP BY GROUPING SETS ((a, c), (a), ())", + Seq(row(2, "A", 0), row(2, null, 0), row(3, "A", 1), row(3, null, 1), row(5, "B", 1), + row(5, null, 1), row(6, "C", 1), row(6, null, 1), row(null, null, 3))) + + checkResult("SELECT a, b, c, COUNT(d) FROM T1 " + + "GROUP BY GROUPING SETS ((a, b, c), (a, b, d))", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(3, 2, "A", 1), row(3, 2, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(6, 3, "C", 1), row(6, 3, null, 1))) + } + + + @Test + def testWithRollup(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T1 GROUP BY ROLLUP (a, b, c)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, null, 0), row(3, 2, "A", 1), + row(3, 2, null, 1), row(3, null, null, 1), row(5, 2, "B", 1), row(5, 2, null, 1), + row(5, null, null, 1), row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, null, 1), + row(null, null, null, 3))) + } + + @Test + def testWithCube(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T1 GROUP BY CUBE (a, b, c)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, "A", 0), row(2, null, null, 0), + row(3, 2, "A", 1), row(3, 2, null, 1), row(3, null, "A", 1), row(3, null, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(5, null, "B", 1), row(5, null, null, 1), + row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, "C", 1), row(6, null, null, 1), + row(null, 1, "A", 0), row(null, 1, null, 0), row(null, 2, "A", 1), row(null, 2, "B", 1), + row(null, 2, null, 2), row(null, 3, "C", 1), row(null, 3, null, 1), row(null, null, "A", 1), + row(null, null, "B", 1), row(null, null, "C", 1), row(null, null, null, 3))) + + checkResult( + "SELECT b, c, e, SUM(a), MAX(d) FROM T3 GROUP BY CUBE (b, c, e)", + Seq( + row(null, null, null, 5, "Hallo Welt wie"), + row(null, null, 1, 3, "Hallo Welt wie"), + row(null, null, 2, 2, "Hallo Welt"), + row(null, 0, null, 1, "Hallo"), + row(null, 0, 1, 1, "Hallo"), + row(null, 1, null, 2, "Hallo Welt"), + row(null, 1, 2, 2, "Hallo Welt"), + row(null, 2, null, 2, "Hallo Welt wie"), + row(null, 2, 1, 2, "Hallo Welt wie"), + row(1, null, null, 1, "Hallo"), + row(1, null, 1, 1, "Hallo"), + row(1, 0, null, 1, "Hallo"), + row(1, 0, 1, 1, "Hallo"), + row(2, null, null, 2, "Hallo Welt"), + row(2, null, 2, 2, "Hallo Welt"), + row(2, 1, null, 2, "Hallo Welt"), + row(2, 1, 2, 2, "Hallo Welt"), + row(3, null, null, 2, "Hallo Welt wie"), + row(3, null, 1, 2, "Hallo Welt wie"), + row(3, 2, null, 2, "Hallo Welt wie"), + row(3, 2, 1, 2, 
"Hallo Welt wie") + )) + } + + @Test + def testSingleDistinctAgg(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT c) FROM T1 GROUP BY a", + Seq(row(2, 1), row(3, 1), row(5, 1), row(6, 1))) + + checkResult("SELECT a, b, COUNT(DISTINCT c) FROM T1 GROUP BY a, b", + Seq(row(2, 1, 1), row(3, 2, 1), row(5, 2, 1), row(6, 3, 1))) + + checkResult("SELECT a, b, COUNT(DISTINCT c), COUNT(DISTINCT d) FROM T1 GROUP BY a, b", + Seq(row(2, 1, 1, 0), row(3, 2, 1, 1), row(5, 2, 1, 1), row(6, 3, 1, 1))) + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT c), SUM(b) FROM T1 GROUP BY a", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, c, COUNT(DISTINCT c), SUM(b) FROM T1 GROUP BY a, c", + Seq(row(2, "A", 1, 1), row(3, "A", 1, 2), row(5, "B", 1, 2), row(6, "C", 1, 3))) + + checkResult("SELECT a, COUNT(DISTINCT c), SUM(b) FROM T1 GROUP BY a", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, d, COUNT(DISTINCT c), SUM(b) FROM T1 GROUP BY a, d", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 1, 2), row(6, "Hello world", 1, 3))) + } + + @Test + def testMultiDistinctAggs(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT b), SUM(DISTINCT b) FROM T1 GROUP BY a", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, d, COUNT(DISTINCT c), SUM(DISTINCT b) FROM T1 GROUP BY a, d", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 1, 2), row(6, "Hello world", 1, 3))) + + checkResult( + "SELECT a, SUM(DISTINCT b), MAX(DISTINCT b), MIN(DISTINCT c) FROM T1 GROUP BY a", + Seq(row(2, 1, 1, "A"), row(3, 2, 2, "A"), row(5, 2, 2, "B"), row(6, 3, 3, "C"))) + + checkResult( + "SELECT a, d, COUNT(DISTINCT c), MAX(DISTINCT b), SUM(b) FROM T1 GROUP BY a, d", + Seq(row(2, null, 1, 1, 1), row(3, "Hi", 1, 2, 2), + row(5, "Hello", 1, 2, 2), row(6, "Hello world", 1, 3, 3))) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/PruneAggregateCallITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/PruneAggregateCallITCase.scala new file mode 100644 index 00000000000000..9efecbc3b18e7d --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/agg/PruneAggregateCallITCase.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.batch.sql.agg + +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.TestData._ + +import org.junit.{Before, Test} + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +class PruneAggregateCallITCase extends BatchTestBase { + + @Before + def before(): Unit = { + registerCollection("MyTable", smallData3, type3, "a, b, c", nullablesOfSmallData3) + registerCollection("MyTable2", smallData5, type5, "a, b, c, d, e", nullablesOfSmallData5, + FlinkStatistic.builder().uniqueKeys(Set(Set("b").asJava).asJava).build()) + } + + @Test + def testNoneEmptyGroupKey(): Unit = { + checkResult( + "SELECT a FROM (SELECT b, MAX(a) AS a, COUNT(*), MAX(c) FROM MyTable GROUP BY b) t", + Seq(row(1), row(3)) + ) + checkResult( + """ + |SELECT c, a FROM + | (SELECT a, c, COUNT(b) as c, SUM(b) as s FROM MyTable GROUP BY a, c) t + |WHERE s > 1 + """.stripMargin, + Seq(row("Hello world", 3), row("Hello", 2)) + ) + checkResult( + "SELECT a, c FROM (SELECT a, b, SUM(c) as c, COUNT(d) as d FROM MyTable2 GROUP BY a, b) t", + Seq(row(1, 0), row(2, 1), row(2, 2))) + + checkResult( + "SELECT a FROM (SELECT a, b, SUM(c) as c, COUNT(d) as d FROM MyTable2 GROUP BY a, b) t", + Seq(row(1), row(2), row(2))) + } + + @Test + def testEmptyGroupKey(): Unit = { + checkResult( + "SELECT 1 FROM (SELECT SUM(a) FROM MyTable) t", + Seq(row(1)) + ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT COUNT(*) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT COUNT(*) FROM MyTable2 WHERE 1=2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + checkResult( + "SELECT 1 FROM (SELECT SUM(a), COUNT(*) FROM MyTable) t", + Seq(row(1)) + ) + + checkResult( + "SELECT 1 FROM (SELECT SUM(a), COUNT(*) FROM MyTable WHERE 1=2) t", + Seq(row(1)) + ) + + checkResult( + "SELECT 1 FROM (SELECT COUNT(*), SUM(a) FROM MyTable) t", + Seq(row(1)) + ) + + checkResult( + "SELECT 1 FROM (SELECT COUNT(*), SUM(a) FROM MyTable WHERE 1=2) t", + Seq(row(1)) + ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT SUM(a), COUNT(*) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT COUNT(*), SUM(a) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT SUM(a), COUNT(*) FROM MyTable2 WHERE 1=2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in BatchExecNestedLoopJoin + // checkResult( 
+ // "SELECT * FROM MyTable WHERE EXISTS (SELECT COUNT(*), SUM(a) FROM MyTable2 WHERE 1=2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/InnerJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/InnerJoinITCase.scala index d8f22fa3568969..cfc1828624f935 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/InnerJoinITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/InnerJoinITCase.scala @@ -78,10 +78,10 @@ class InnerJoinITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("myUpperCaseData", myUpperCaseData, INT_STRING, Array(true, false), "N, L") - registerCollection("myLowerCaseData", myLowerCaseData, INT_STRING, Array(true, false), "n, l") - registerCollection("myTestData1", myTestData1, INT_INT, Array(false, false), "a, b") - registerCollection("myTestData2", myTestData2, INT_INT, Array(false, false), "a, b") + registerCollection("myUpperCaseData", myUpperCaseData, INT_STRING, "N, L", Array(true, false)) + registerCollection("myLowerCaseData", myLowerCaseData, INT_STRING, "n, l", Array(true, false)) + registerCollection("myTestData1", myTestData1, INT_INT, "a, b", Array(false, false)) + registerCollection("myTestData2", myTestData2, INT_INT, "a, b", Array(false, false)) disableOtherJoinOpForJoin(tEnv, expectedJoinType) } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceRuleITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceITCase.scala similarity index 97% rename from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceRuleITCase.scala rename to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceITCase.scala index 15a2ae2a5b221f..ecb0b512b27e5f 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceRuleITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinConditionTypeCoerceITCase.scala @@ -27,7 +27,7 @@ import org.junit.{Before, Ignore, Test} // @RunWith(classOf[Parameterized]) TODO @Ignore // TODO support JoinConditionTypeCoerce -class JoinConditionTypeCoerceRuleITCase extends BatchTestBase { +class JoinConditionTypeCoerceITCase extends BatchTestBase { @Before def before(): Unit = { tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) @@ -35,14 +35,14 @@ class JoinConditionTypeCoerceRuleITCase extends BatchTestBase { "t1", numericData, numericType, - nullablesOfNumericData, - "a, b, c, d, e") + "a, b, c, d, e", + nullablesOfNumericData) registerCollection( "t2", numericData, numericType, - nullablesOfNumericData, - "a, b, c, d, e") + "a, b, c, d, e", + nullablesOfNumericData) // Disable NestedLoopJoin. 
JoinITCaseHelper.disableOtherJoinOpForJoin(tEnv, JoinType.SortMergeJoin) } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinITCase.scala index 908c0fd68fdfd9..d9ac9d0459a8b0 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/JoinITCase.scala @@ -44,14 +44,14 @@ class JoinITCase() extends BatchTestBase { @Before def before(): Unit = { - registerCollection("SmallTable3", smallData3, type3, nullablesOfSmallData3, "a, b, c") - registerCollection("Table3", data3, type3, nullablesOfData3, "a, b, c") - registerCollection("Table5", data5, type5, nullablesOfData5, "d, e, f, g, h") - registerCollection("NullTable3", nullData3, type3, nullablesOfNullData3, "a, b, c") - registerCollection("NullTable5", nullData5, type5, nullablesOfNullData5, "d, e, f, g, h") + registerCollection("SmallTable3", smallData3, type3, "a, b, c", nullablesOfSmallData3) + registerCollection("Table3", data3, type3, "a, b, c", nullablesOfData3) + registerCollection("Table5", data5, type5, "d, e, f, g, h", nullablesOfData5) + registerCollection("NullTable3", nullData3, type3, "a, b, c", nullablesOfNullData3) + registerCollection("NullTable5", nullData5, type5, "d, e, f, g, h", nullablesOfNullData5) registerCollection("l", data2_1, INT_DOUBLE, "a, b") registerCollection("r", data2_2, INT_DOUBLE, "c, d") - registerCollection("t", data2_3, INT_DOUBLE, nullablesOfData2_3, "c, d") + registerCollection("t", data2_3, INT_DOUBLE, "c, d", nullablesOfData2_3) JoinITCaseHelper.disableOtherJoinOpForJoin(tEnv, expectedJoinType) } @@ -96,11 +96,11 @@ class JoinITCase() extends BatchTestBase { registerCollection("PojoSmallTable3", smallData3, new RowTypeInfo(INT_TYPE_INFO, LONG_TYPE_INFO, new GenericTypeInfoWithoutComparator[String](classOf[String])), - nullablesOfSmallData3, "a, b, c") + "a, b, c", nullablesOfSmallData3) registerCollection("PojoTable5", data5, new RowTypeInfo(INT_TYPE_INFO, LONG_TYPE_INFO, INT_TYPE_INFO, new GenericTypeInfoWithoutComparator[String](classOf[String]), LONG_TYPE_INFO), - nullablesOfData5, "d, e, f, g, h") + "d, e, f, g, h", nullablesOfData5) checkResult( "SELECT c, g FROM (SELECT h, g, f, e, d FROM PojoSmallTable3, PojoTable5 WHERE b = e)," + @@ -667,7 +667,7 @@ class JoinITCase() extends BatchTestBase { )) registerCollection( - "NullT", Seq(row(null, null, "c")), type3, allNullablesOfNullData3, "a, b, c") + "NullT", Seq(row(null, null, "c")), type3, "a, b, c", allNullablesOfNullData3) checkResult( "SELECT T1.a, T1.b, T1.c FROM NullT T1, NullT T2 WHERE " + "(T1.a = T2.a OR (T1.a IS NULL AND T2.a IS NULL)) " + diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/OuterJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/OuterJoinITCase.scala index a16b88ef846bd0..6c0debf0bdd9b9 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/OuterJoinITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/join/OuterJoinITCase.scala @@ -60,9 +60,9 @@ class OuterJoinITCase extends BatchTestBase { @Before def before(): Unit = { 
tEnv.getConfig.getConf.setInteger(TableConfigOptions.SQL_RESOURCE_DEFAULT_PARALLELISM, 3) - registerCollection("uppercasedata", upperCaseData, INT_STRING, nullablesOfUpperCaseData, "N, L") - registerCollection("lowercasedata", lowerCaseData, INT_STRING, nullablesOfLowerCaseData, "n, l") - registerCollection("allnulls", allNulls, INT_ONLY, nullablesOfAllNulls, "a") + registerCollection("uppercasedata", upperCaseData, INT_STRING, "N, L", nullablesOfUpperCaseData) + registerCollection("lowercasedata", lowerCaseData, INT_STRING, "n, l", nullablesOfLowerCaseData) + registerCollection("allnulls", allNulls, INT_ONLY, "a", nullablesOfAllNulls) registerCollection("leftT", leftT, INT_DOUBLE, "a, b") registerCollection("rightT", rightT, INT_DOUBLE, "c, d") JoinITCaseHelper.disableOtherJoinOpForJoin(tEnv, expectedJoinType) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateITCase.scala index 160680bc212a92..0e207999d65141 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateITCase.scala @@ -31,10 +31,11 @@ import org.apache.flink.table.runtime.utils.StreamingWithAggTestBase.AggMode import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode import org.apache.flink.table.runtime.utils.UserDefinedFunctionTestUtils._ -import org.apache.flink.table.runtime.utils.{StreamTestData, StreamingWithAggTestBase, TestingRetractSink} +import org.apache.flink.table.runtime.utils.{StreamingWithAggTestBase, TestData, TestingRetractSink} import org.apache.flink.table.typeutils.BigDecimalTypeInfo import org.apache.flink.table.util.DateTimeTestUtil._ import org.apache.flink.types.Row + import org.junit.Assert.assertEquals import org.junit._ import org.junit.runner.RunWith @@ -159,7 +160,7 @@ class AggregateITCase( "FROM MyTable " + "GROUP BY b" - val t = failingDataSource(StreamTestData.get3TupleData).toTable(tEnv, 'a, 'b, 'c) + val t = failingDataSource(TestData.tupleData3).toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTable", t) val result = tEnv.sqlQuery(sqlQuery).toRetractStream[Row] @@ -543,7 +544,7 @@ class AggregateITCase( /** test unbounded groupBy (without window) **/ @Test def testUnboundedGroupBy(): Unit = { - val t = failingDataSource(StreamTestData.get3TupleData).toTable(tEnv, 'a, 'b, 'c) + val t = failingDataSource(TestData.tupleData3).toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTable", t) val sqlQuery = "SELECT b, COUNT(a) FROM MyTable GROUP BY b" @@ -609,7 +610,7 @@ class AggregateITCase( def testUnboundedGroupByCollect(): Unit = { val sqlQuery = "SELECT b, COLLECT(a) FROM MyTable GROUP BY b" - val t = failingDataSource(StreamTestData.get3TupleData).toTable(tEnv, 'a, 'b, 'c) + val t = failingDataSource(TestData.tupleData3).toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTable", t) val sink = new TestingRetractSink diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateRemoveITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateRemoveITCase.scala new file mode 100644 index 
00000000000000..5c7b59c45306fb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/AggregateRemoveITCase.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.StreamingWithAggTestBase.AggMode +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.{StreamTableEnvUtil, StreamingWithAggTestBase, TestData, TestingRetractSink} +import org.apache.flink.types.Row + +import org.junit.Assert.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ +import scala.collection.{Seq, mutable} + +@RunWith(classOf[Parameterized]) +class AggregateRemoveITCase( + aggMode: AggMode, + minibatch: MiniBatchMode, + backend: StateBackendMode) + extends StreamingWithAggTestBase(aggMode, minibatch, backend) { + + @Test + def testSimple(): Unit = { + checkResult("SELECT a, b FROM T GROUP BY a, b", + Seq(row(2, 1), row(3, 2), row(5, 2), row(6, 3))) + + checkResult("SELECT a, b + 1, c, s FROM (" + + "SELECT a, MIN(b) AS b, SUM(b) AS s, MAX(c) AS c FROM MyTable2 GROUP BY a)", + Seq(row(1, 2, 0, 1), row(2, 3, 2, 5))) + + checkResult("SELECT a, SUM(b) AS s FROM MyTable2 GROUP BY a", + Seq(row(1, 1), row(2, 5))) + + checkResult( + "SELECT a, b + 1, c, s FROM (" + + "SELECT a, MIN(b) AS b, SUM(b) AS s, MAX(c) AS c FROM MyTable GROUP BY a)", + Seq( + row(1, 2L, "Hi", 1L), + row(2, 3L, "Hello", 2L), + row(3, 3L, "Hello world", 2L) + )) + } + + @Test + def testWithGroupingSets(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T " + + "GROUP BY GROUPING SETS ((a, b), (a, c))", + Seq(row(2, 1, null, 0), row(2, null, "A", 0), row(3, 2, null, 1), + row(3, null, "A", 1), row(5, 2, null, 1), row(5, null, "B", 1), + row(6, 3, null, 1), row(6, null, "C", 1))) + + checkResult("SELECT a, c, COUNT(d) FROM T " + + "GROUP BY GROUPING SETS ((a, c), (a), ())", + Seq(row(2, "A", 0), row(2, null, 0), row(3, "A", 1), row(3, null, 1), row(5, "B", 1), + row(5, null, 1), row(6, "C", 1), row(6, null, 1), row(null, null, 3))) + + checkResult("SELECT a, b, c, COUNT(d) FROM T " + + "GROUP BY GROUPING SETS ((a, b, c), (a, b, d))", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(3, 2, "A", 1), row(3, 2, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(6, 
3, "C", 1), row(6, 3, null, 1))) + } + + @Test + def testWithRollup(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T GROUP BY ROLLUP (a, b, c)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, null, 0), row(3, 2, "A", 1), + row(3, 2, null, 1), row(3, null, null, 1), row(5, 2, "B", 1), row(5, 2, null, 1), + row(5, null, null, 1), row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, null, 1), + row(null, null, null, 3))) + } + + @Test + def testWithCube(): Unit = { + checkResult("SELECT a, b, c, COUNT(d) FROM T GROUP BY CUBE (a, b, c)", + Seq(row(2, 1, "A", 0), row(2, 1, null, 0), row(2, null, "A", 0), row(2, null, null, 0), + row(3, 2, "A", 1), row(3, 2, null, 1), row(3, null, "A", 1), row(3, null, null, 1), + row(5, 2, "B", 1), row(5, 2, null, 1), row(5, null, "B", 1), row(5, null, null, 1), + row(6, 3, "C", 1), row(6, 3, null, 1), row(6, null, "C", 1), row(6, null, null, 1), + row(null, 1, "A", 0), row(null, 1, null, 0), row(null, 2, "A", 1), row(null, 2, "B", 1), + row(null, 2, null, 2), row(null, 3, "C", 1), row(null, 3, null, 1), row(null, null, "A", 1), + row(null, null, "B", 1), row(null, null, "C", 1), row(null, null, null, 3))) + + checkResult( + "SELECT b, c, e, SUM(a), MAX(d) FROM MyTable2 GROUP BY CUBE (b, c, e)", + Seq( + row(null, null, null, 5, "Hallo Welt wie"), + row(null, null, 1, 3, "Hallo Welt wie"), + row(null, null, 2, 2, "Hallo Welt"), + row(null, 0, null, 1, "Hallo"), + row(null, 0, 1, 1, "Hallo"), + row(null, 1, null, 2, "Hallo Welt"), + row(null, 1, 2, 2, "Hallo Welt"), + row(null, 2, null, 2, "Hallo Welt wie"), + row(null, 2, 1, 2, "Hallo Welt wie"), + row(1, null, null, 1, "Hallo"), + row(1, null, 1, 1, "Hallo"), + row(1, 0, null, 1, "Hallo"), + row(1, 0, 1, 1, "Hallo"), + row(2, null, null, 2, "Hallo Welt"), + row(2, null, 2, 2, "Hallo Welt"), + row(2, 1, null, 2, "Hallo Welt"), + row(2, 1, 2, 2, "Hallo Welt"), + row(3, null, null, 2, "Hallo Welt wie"), + row(3, null, 1, 2, "Hallo Welt wie"), + row(3, 2, null, 2, "Hallo Welt wie"), + row(3, 2, 1, 2, "Hallo Welt wie") + )) + } + + @Test + def testSingleDistinctAgg(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT c) FROM T GROUP BY a", + Seq(row(2, 1), row(3, 1), row(5, 1), row(6, 1))) + + checkResult("SELECT a, b, COUNT(DISTINCT c) FROM T GROUP BY a, b", + Seq(row(2, 1, 1), row(3, 2, 1), row(5, 2, 1), row(6, 3, 1))) + + checkResult("SELECT a, b, COUNT(DISTINCT c), COUNT(DISTINCT d) FROM T GROUP BY a, b", + Seq(row(2, 1, 1, 0), row(3, 2, 1, 1), row(5, 2, 1, 1), row(6, 3, 1, 1))) + } + + @Test + def testSingleDistinctAgg_WithNonDistinctAgg(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT c), SUM(b) FROM T GROUP BY a", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, c, COUNT(DISTINCT c), SUM(b) FROM T GROUP BY a, c", + Seq(row(2, "A", 1, 1), row(3, "A", 1, 2), row(5, "B", 1, 2), row(6, "C", 1, 3))) + + checkResult("SELECT a, COUNT(DISTINCT c), SUM(b) FROM T GROUP BY a", + Seq(row(2, 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, d, COUNT(DISTINCT c), SUM(b) FROM T GROUP BY a, d", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 1, 2), row(6, "Hello world", 1, 3))) + } + + @Test + def testMultiDistinctAggs(): Unit = { + checkResult("SELECT a, COUNT(DISTINCT b), SUM(DISTINCT b) FROM T GROUP BY a", Seq(row(2, + 1, 1), row(3, 1, 2), row(5, 1, 2), row(6, 1, 3))) + + checkResult("SELECT a, d, COUNT(DISTINCT c), SUM(DISTINCT b) FROM T GROUP BY a, d", + Seq(row(2, null, 1, 1), row(3, "Hi", 1, 2), + row(5, "Hello", 
1, 2), row(6, "Hello world", 1, 3))) + + checkResult( + "SELECT a, SUM(DISTINCT b), MAX(DISTINCT b), MIN(DISTINCT c) FROM T GROUP BY a", + Seq(row(2, 1, 1, "A"), row(3, 2, 2, "A"), row(5, 2, 2, "B"), row(6, 3, 3, "C"))) + + checkResult( + "SELECT a, d, COUNT(DISTINCT c), MAX(DISTINCT b), SUM(b) FROM T GROUP BY a, d", + Seq(row(2, null, 1, 1, 1), row(3, "Hi", 1, 2, 2), + row(5, "Hello", 1, 2, 2), row(6, "Hello world", 1, 3, 3))) + } + + @Test + def testAggregateRemove(): Unit = { + val data = new mutable.MutableList[(Int, Int)] + data.+=((1, 1)) + data.+=((2, 2)) + data.+=((3, 3)) + data.+=((4, 2)) + data.+=((4, 4)) + data.+=((6, 2)) + + val t = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T1", t) + + val t1 = tEnv.sqlQuery( + """ + |select sum(b) from + | (select b from + | (select b, sum(a) from + | (select b, sum(a) as a from T1 group by b) t1 + | group by b) t2 + | ) t3 + """.stripMargin) + val sink = new TestingRetractSink + t1.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + val expected = List("10") + assertEquals(expected, sink.getRetractResults) + } + + private def checkResult(str: String, rows: Seq[Row]): Unit = { + super.before() + + val ds1 = env.fromCollection(Seq[(Int, Int, String, String)]( + (2, 1, "A", null), + (3, 2, "A", "Hi"), + (5, 2, "B", "Hello"), + (6, 3, "C", "Hello world"))) + StreamTableEnvUtil.registerDataStreamInternal[(Int, Int, String, String)]( + tEnv, + "T", + ds1.javaStream, + Some(Array("a", "b", "c", "d")), + Some(Array(true, true, true, true)), + Some(FlinkStatistic.builder().uniqueKeys(Set(Set("a").asJava).asJava).build()) + ) + + StreamTableEnvUtil.registerDataStreamInternal[(Int, Long, String)]( + tEnv, + "MyTable", + env.fromCollection(TestData.smallTupleData3).javaStream, + Some(Array("a", "b", "c")), + Some(Array(true, true, true)), + Some(FlinkStatistic.builder().uniqueKeys(Set(Set("a").asJava).asJava).build()) + ) + + StreamTableEnvUtil.registerDataStreamInternal[(Int, Long, Int, String, Long)]( + tEnv, + "MyTable2", + env.fromCollection(TestData.smallTupleData5).javaStream, + Some(Array("a", "b", "c", "d", "e")), + Some(Array(true, true, true, true, true)), + Some(FlinkStatistic.builder().uniqueKeys(Set(Set("b").asJava).asJava).build()) + ) + + val t = tEnv.sqlQuery(str) + val sink = new TestingRetractSink + env.setMaxParallelism(1) + env.setParallelism(1) + t.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + val expected = rows.map(_.toString) + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/CalcITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/CalcITCase.scala index d141c28ef88efc..1506b77fe194a1 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/CalcITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/CalcITCase.scala @@ -25,7 +25,7 @@ import org.apache.flink.api.scala.typeutils.Types import org.apache.flink.table.`type`.InternalTypes import org.apache.flink.table.api.scala._ import org.apache.flink.table.dataformat.{BaseRow, GenericRow} -import org.apache.flink.table.runtime.utils.{StreamTestData, StreamingTestBase, TestSinkUtil, TestingAppendBaseRowSink, TestingAppendSink, TestingAppendTableSink} +import org.apache.flink.table.runtime.utils.{StreamingTestBase, TestData, 
TestSinkUtil, TestingAppendBaseRowSink, TestingAppendSink, TestingAppendTableSink} import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.types.Row @@ -166,7 +166,7 @@ class CalcITCase extends StreamingTestBase { def testPrimitiveMapType(): Unit = { val sqlQuery = "SELECT MAP[b, 30, 10, a] FROM MyTableRow" - val t = env.fromCollection(StreamTestData.getSmall3TupleData) + val t = env.fromCollection(TestData.smallTupleData3) .toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTableRow", t) @@ -186,7 +186,7 @@ class CalcITCase extends StreamingTestBase { def testNonPrimitiveMapType(): Unit = { val sqlQuery = "SELECT MAP[a, c] FROM MyTableRow" - val t = env.fromCollection(StreamTestData.getSmall3TupleData) + val t = env.fromCollection(TestData.smallTupleData3) .toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTableRow", t) @@ -228,7 +228,7 @@ class CalcITCase extends StreamingTestBase { def testIn(): Unit = { val sqlQuery = "SELECT * FROM MyTable WHERE b in (1,3,4,5,6)" - val t = env.fromCollection(StreamTestData.get3TupleData) + val t = env.fromCollection(TestData.tupleData3) .toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTable", t) @@ -250,7 +250,7 @@ class CalcITCase extends StreamingTestBase { def testNotIn(): Unit = { val sqlQuery = "SELECT * FROM MyTable WHERE b not in (1,3,4,5,6)" - val t = env.fromCollection(StreamTestData.get3TupleData) + val t = env.fromCollection(TestData.tupleData3) .toTable(tEnv, 'a, 'b, 'c) tEnv.registerTable("MyTable", t) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala index 8abe855838d2d9..9f818b862995d2 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/DeduplicateITCase.scala @@ -21,8 +21,8 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode -import org.apache.flink.table.runtime.utils._ import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils._ import org.apache.flink.types.Row import org.junit.Assert._ @@ -36,7 +36,7 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) @Test def testFirstRowOnProctime(): Unit = { - val t = failingDataSource(StreamTestData.get3TupleData) + val t = failingDataSource(TestData.tupleData3) .toTable(tEnv, 'a, 'b, 'c, 'proctime) tEnv.registerTable("T", t) @@ -62,7 +62,7 @@ class DeduplicateITCase(miniBatch: MiniBatchMode, mode: StateBackendMode) @Test def testLastRowOnProctime(): Unit = { - val t = failingDataSource(StreamTestData.get3TupleData) + val t = failingDataSource(TestData.tupleData3) .toTable(tEnv, 'a, 'b, 'c, 'proctime) tEnv.registerTable("T", t) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala index 72ca720083a0de..4f57887996fd39 100644 --- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/OverWindowITCase.scala @@ -21,11 +21,12 @@ package org.apache.flink.table.runtime.stream.sql import org.apache.flink.api.common.time.Time import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ -import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeProcessOperator import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeProcessOperator import org.apache.flink.table.runtime.utils.UserDefinedFunctionTestUtils.{CountNullNonNull, CountPairs, LargerThanCount} -import org.apache.flink.table.runtime.utils.{StreamTestData, StreamingWithStateTestBase, TestingAppendSink} +import org.apache.flink.table.runtime.utils.{StreamingWithStateTestBase, TestData, TestingAppendSink} import org.apache.flink.types.Row + import org.junit.Assert._ import org.junit._ import org.junit.runner.RunWith @@ -49,7 +50,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeBoundedPartitionedRowsOver(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) @@ -85,7 +86,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeBoundedPartitionedRowsOverWithBultinProctime(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) @@ -121,7 +122,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeBoundedPartitionedRowsOverWithBuiltinProctime(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) @@ -157,7 +158,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeBoundedNonPartitionedRowsOver(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) @@ -856,7 +857,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeDistinctUnboundedPartitionedRowsOver(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) @@ -941,7 +942,7 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas @Test def testProcTimeDistinctBoundedPartitionedRowsOver(): Unit = { - val t = failingDataSource(StreamTestData.get5TupleData) + val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime) tEnv.registerTable("MyTable", t) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/PruneAggregateCallITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/PruneAggregateCallITCase.scala new file 
mode 100644 index 00000000000000..35d2365d393ca1 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/PruneAggregateCallITCase.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.stream.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.plan.stats.FlinkStatistic +import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.StreamingWithAggTestBase.AggMode +import org.apache.flink.table.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode +import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode +import org.apache.flink.table.runtime.utils.{StreamTableEnvUtil, StreamingWithAggTestBase, TestData, TestingRetractSink} +import org.apache.flink.types.Row + +import org.junit.Assert.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +@RunWith(classOf[Parameterized]) +class PruneAggregateCallITCase( + aggMode: AggMode, + minibatch: MiniBatchMode, + backend: StateBackendMode) + extends StreamingWithAggTestBase(aggMode, minibatch, backend) { + + @Test + def testNoneEmptyGroupKey(): Unit = { + checkResult( + "SELECT a FROM (SELECT b, MAX(a) AS a, COUNT(*), MAX(c) FROM MyTable GROUP BY b) t", + Seq(row(1), row(3)) + ) + checkResult( + """ + |SELECT c, a FROM + | (SELECT a, c, COUNT(b) as b, SUM(b) as s FROM MyTable GROUP BY a, c) t + |WHERE s > 1 + """.stripMargin, + Seq(row("Hello world", 3), row("Hello", 2)) + ) + } + + @Test + def testEmptyGroupKey(): Unit = { + checkResult( + "SELECT 1 FROM (SELECT SUM(a) FROM MyTable) t", + Seq(row(1)) + ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in StreamExecJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT COUNT(*) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + checkResult( + "SELECT 1 FROM (SELECT SUM(a), COUNT(*) FROM MyTable) t", + Seq(row(1)) + ) + + checkResult( + "SELECT 1 FROM (SELECT COUNT(*), SUM(a) FROM MyTable) t", + Seq(row(1)) + ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in StreamExecJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT SUM(a), COUNT(*) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + + // TODO enable this case after translateToPlanInternal method is implemented + // in StreamExecJoin + // checkResult( + // "SELECT * FROM MyTable WHERE EXISTS (SELECT 
COUNT(*), SUM(a) FROM MyTable2)", + // Seq(row(1, 1, "Hi"), row(2, 2, "Hello"), row(3, 2, "Hello world")) + // ) + } + + private def checkResult(str: String, rows: Seq[Row]): Unit = { + super.before() + StreamTableEnvUtil.registerDataStreamInternal[(Int, Long, String)]( + tEnv, + "MyTable", + failingDataSource(TestData.smallTupleData3).javaStream, + Some(Array("a", "b", "c")), + Some(Array(true, true, true)), + Some(FlinkStatistic.UNKNOWN) + ) + + StreamTableEnvUtil.registerDataStreamInternal[(Int, Long, Int, String, Long)]( + tEnv, + "MyTable2", + failingDataSource(TestData.smallTupleData5).javaStream, + Some(Array("a", "b", "c", "d", "e")), + Some(Array(true, true, true, true, true)), + Some(FlinkStatistic.builder().uniqueKeys(Set(Set("b").asJava).asJava).build()) + ) + + val t = tEnv.sqlQuery(str) + val sink = new TestingRetractSink + env.setMaxParallelism(1) + env.setParallelism(1) + t.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + val expected = rows.map(_.toString) + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalSortITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalSortITCase.scala index eccc2761456459..acee2ce5a0e9bf 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalSortITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalSortITCase.scala @@ -123,7 +123,7 @@ class TemporalSortITCase(mode: StateBackendMode) extends StreamingWithStateTestB @Test def testProcTimeOrderBy(): Unit = { - val t = failingDataSource(StreamTestData.get3TupleData) + val t = failingDataSource(TestData.tupleData3) .toTable(tEnv, 'a, 'b, 'c, 'proctime) tEnv.registerTable("T", t) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchScalaTableEnvUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchScalaTableEnvUtil.scala index b77056fd3becc9..9b405626b357e5 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchScalaTableEnvUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchScalaTableEnvUtil.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.runtime.utils import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.{BatchTableEnvironment, Table, TableEnvironment} +import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.runtime.utils.BatchTableEnvUtil.parseFieldNames import scala.reflect.ClassTag @@ -40,7 +41,7 @@ object BatchScalaTableEnvUtil { tableName: String, data: Iterable[T], fieldNames: String): Unit = { val typeInfo = implicitly[TypeInformation[T]] BatchTableEnvUtil.registerCollection( - tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), None) + tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), None, None) } /** @@ -51,15 +52,20 @@ object BatchScalaTableEnvUtil { * @param data The [[Iterable]] to be converted. * @param fieldNames field names, eg: "a, b, c" * @param fieldNullables The field isNullables attributes of data. + * @param statistic statistics of current Table * @tparam T The type of the [[Iterable]]. 
* @return The converted [[Table]]. */ - def registerCollection[T : ClassTag : TypeInformation]( - tEnv: BatchTableEnvironment, tableName: String, data: Iterable[T], - fieldNames: String, fieldNullables: Array[Boolean]): Unit = { + def registerCollection[T: ClassTag : TypeInformation]( + tEnv: BatchTableEnvironment, + tableName: String, + data: Iterable[T], + fieldNames: String, + fieldNullables: Array[Boolean], + statistic: Option[FlinkStatistic]): Unit = { val typeInfo = implicitly[TypeInformation[T]] - BatchTableEnvUtil.registerCollection( - tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), Option(fieldNullables)) + BatchTableEnvUtil.registerCollection(tEnv, tableName, data, typeInfo, + Some(parseFieldNames(fieldNames)), Option(fieldNullables), statistic) } /** @@ -91,7 +97,7 @@ object BatchScalaTableEnvUtil { def fromCollection[T: ClassTag : TypeInformation]( tEnv: BatchTableEnvironment, data: Iterable[T]): Table = { val typeInfo = implicitly[TypeInformation[T]] - BatchTableEnvUtil.fromCollection(tEnv, null, data, typeInfo, null) + BatchTableEnvUtil.fromCollection(tEnv, null, data, typeInfo, null, None) } /** diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala index be32bff9361ccb..fe9315549e4822 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala @@ -26,6 +26,7 @@ import org.apache.flink.api.java.io.CollectionInputFormat import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{BatchTableEnvironment, Table, TableEnvironment} import org.apache.flink.table.plan.schema.DataStreamTable +import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.sinks.CollectTableSink import org.apache.flink.util.AbstractID @@ -70,7 +71,7 @@ object BatchTableEnvUtil { tableName: String, data: Iterable[T], typeInfo: TypeInformation[T], fieldNames: String): Unit = { registerCollection( - tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), None) + tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), None, None) } /** @@ -82,6 +83,7 @@ object BatchTableEnvUtil { * @param typeInfo information of [[Iterable]]. * @param fieldNames field names, eg: "a, b, c" * @param fieldNullables The field isNullables attributes of data. + * @param statistic statistics of current Table * @tparam T The type of the [[Iterable]]. * @return The converted [[Table]]. */ @@ -91,9 +93,10 @@ object BatchTableEnvUtil { data: Iterable[T], typeInfo: TypeInformation[T], fieldNames: String, - fieldNullables: Array[Boolean]): Unit = { - registerCollection( - tEnv, tableName, data, typeInfo, Some(parseFieldNames(fieldNames)), Option(fieldNullables)) + fieldNullables: Array[Boolean], + statistic: Option[FlinkStatistic]): Unit = { + registerCollection(tEnv, tableName, data, typeInfo, + Some(parseFieldNames(fieldNames)), Option(fieldNullables), statistic) } /** @@ -105,6 +108,7 @@ object BatchTableEnvUtil { * @param typeInfo information of [[Iterable]]. * @param fieldNames field names. * @param fieldNullables The field isNullables attributes of data. + * @param statistic statistics of current Table * @tparam T The type of the [[Iterable]]. 
* @return The converted [[Table]]. */ @@ -115,13 +119,15 @@ object BatchTableEnvUtil { data: Iterable[T], typeInfo: TypeInformation[T], fieldNames: Option[Array[String]], - fieldNullables: Option[Array[Boolean]]): Unit = { + fieldNullables: Option[Array[Boolean]], + statistic: Option[FlinkStatistic]): Unit = { val boundedStream = tEnv.streamEnv.createInput(new CollectionInputFormat[T]( data.asJavaCollection, typeInfo.createSerializer(tEnv.streamEnv.getConfig)), typeInfo) boundedStream.forceNonParallel() - registerBoundedStreamInternal(tEnv, tableName, boundedStream, fieldNames, fieldNullables) + registerBoundedStreamInternal( + tEnv, tableName, boundedStream, fieldNames, fieldNullables, statistic) } /** @@ -137,12 +143,14 @@ object BatchTableEnvUtil { name: String, boundedStream: DataStream[T], fieldNames: Option[Array[String]], - fieldNullables: Option[Array[Boolean]]): Unit = { + fieldNullables: Option[Array[Boolean]], + statistic: Option[FlinkStatistic]): Unit = { val (typeFieldNames, fieldIdxs) = tEnv.getFieldInfo(boundedStream.getTransformation.getOutputType) val boundedStreamTable = new DataStreamTable[T]( boundedStream, fieldIdxs, fieldNames.getOrElse(typeFieldNames), fieldNullables) - tEnv.registerTableInternal(name, boundedStreamTable) + val withStatistic = boundedStreamTable.copy(statistic.getOrElse(FlinkStatistic.UNKNOWN)) + tEnv.registerTableInternal(name, withStatistic) } /** @@ -154,7 +162,8 @@ object BatchTableEnvUtil { tableName: String, data: Iterable[T], typeInfo: TypeInformation[T], - fieldNames: Array[String]): Table = { + fieldNames: Array[String], + statistic: Option[FlinkStatistic]): Table = { CollectionInputFormat.checkCollection(data.asJavaCollection, typeInfo.getTypeClass) val boundedStream = tEnv.streamEnv.createInput(new CollectionInputFormat[T]( data.asJavaCollection, @@ -162,7 +171,7 @@ object BatchTableEnvUtil { typeInfo) boundedStream.setParallelism(1) val name = if (tableName == null) tEnv.createUniqueTableName() else tableName - registerBoundedStreamInternal(tEnv, name, boundedStream, Option(fieldNames), None) + registerBoundedStreamInternal(tEnv, name, boundedStream, Option(fieldNames), None, statistic) tEnv.scan(name) } @@ -172,6 +181,6 @@ object BatchTableEnvUtil { */ def fromCollection[T](tEnv: BatchTableEnvironment, data: Iterable[T], typeInfo: TypeInformation[T], fields: String): Table = { - fromCollection(tEnv, null, data, typeInfo, parseFieldNames(fields)) + fromCollection(tEnv, null, data, typeInfo, parseFieldNames(fields), None) } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTestBase.scala index 0202793dbd055c..8d032cd73ecffa 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTestBase.scala @@ -29,6 +29,7 @@ import org.apache.flink.table.api.scala.{BatchTableEnvironment => ScalaBatchTabl import org.apache.flink.table.api.{SqlParserException, Table, TableConfig, TableConfigOptions, TableEnvironment, TableImpl} import org.apache.flink.table.dataformat.{BinaryRow, BinaryRowWriter} import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.plan.util.FlinkRelOptUtil import 
org.apache.flink.table.runtime.utils.BatchAbstractTestBase.DEFAULT_PARALLELISM import org.apache.flink.table.util.{BaseRowTestUtil, DiffRepository} @@ -388,18 +389,21 @@ class BatchTestBase extends BatchAbstractTestBase { tableName: String, data: Iterable[T], typeInfo: TypeInformation[T], - fieldNullables: Array[Boolean], - fields: String): Unit = { - BatchTableEnvUtil.registerCollection(tEnv, tableName, data, typeInfo, fields, fieldNullables) + fields: String, + fieldNullables: Array[Boolean]): Unit = { + BatchTableEnvUtil.registerCollection( + tEnv, tableName, data, typeInfo, fields, fieldNullables, None) } - def registerCollection( + def registerCollection[T]( tableName: String, - data: Iterable[Row], - typeInfo: TypeInformation[Row], + data: Iterable[T], + typeInfo: TypeInformation[T], fields: String, - fieldNullables: Array[Boolean]): Unit = { - BatchTableEnvUtil.registerCollection(tEnv, tableName, data, typeInfo, fields, fieldNullables) + fieldNullables: Array[Boolean], + statistic: FlinkStatistic): Unit = { + BatchTableEnvUtil.registerCollection( + tEnv, tableName, data, typeInfo, fields, fieldNullables, Some(statistic)) } def registerFunction[T: TypeInformation, ACC: TypeInformation]( diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTableEnvUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTableEnvUtil.scala new file mode 100644 index 00000000000000..97314097206e3e --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTableEnvUtil.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.utils + +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment} +import org.apache.flink.table.plan.schema.DataStreamTable +import org.apache.flink.table.plan.stats.FlinkStatistic + +object StreamTableEnvUtil { + + // TODO unify BatchTableEnvUtil and StreamTableEnvUtil + /** + * Registers a [[DataStream]] as a table under a given name in the [[TableEnvironment]]'s + * catalog. + * + * @param name The name under which the table is registered in the catalog. + * @param dataStream The [[DataStream]] to register as table in the catalog. + * @tparam T the type of the [[DataStream]]. 
+ */ + def registerDataStreamInternal[T]( + tEnv: StreamTableEnvironment, + name: String, + dataStream: DataStream[T], + fieldNames: Option[Array[String]], + fieldNullables: Option[Array[Boolean]], + statistic: Option[FlinkStatistic]): Unit = { + val (typeFieldNames, fieldIdxs) = + tEnv.getFieldInfo(dataStream.getTransformation.getOutputType) + val boundedStreamTable = new DataStreamTable[T]( + dataStream, fieldIdxs, fieldNames.getOrElse(typeFieldNames), fieldNullables) + val withStatistic = boundedStreamTable.copy(statistic.getOrElse(FlinkStatistic.UNKNOWN)) + tEnv.registerTableInternal(name, withStatistic) + } + +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestData.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestData.scala deleted file mode 100644 index f437ab8d19ce38..00000000000000 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/StreamTestData.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.runtime.utils - -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} - -import java.sql.{Date, Time, Timestamp} - -import scala.collection.mutable - -object StreamTestData { - - def getSmall3TupleData: Seq[(Int, Long, String)] = { - val data = new mutable.MutableList[(Int, Long, String)] - data.+=((1, 1L, "Hi")) - data.+=((2, 2L, "Hello")) - data.+=((3, 2L, "Hello world")) - data - } - - def get3TupleData: Seq[(Int, Long, String)] = { - val data = new mutable.MutableList[(Int, Long, String)] - data.+=((1, 1L, "Hi")) - data.+=((2, 2L, "Hello")) - data.+=((3, 2L, "Hello world")) - data.+=((4, 3L, "Hello world, how are you?")) - data.+=((5, 3L, "I am fine.")) - data.+=((6, 3L, "Luke Skywalker")) - data.+=((7, 4L, "Comment#1")) - data.+=((8, 4L, "Comment#2")) - data.+=((9, 4L, "Comment#3")) - data.+=((10, 4L, "Comment#4")) - data.+=((11, 5L, "Comment#5")) - data.+=((12, 5L, "Comment#6")) - data.+=((13, 5L, "Comment#7")) - data.+=((14, 5L, "Comment#8")) - data.+=((15, 5L, "Comment#9")) - data.+=((16, 6L, "Comment#10")) - data.+=((17, 6L, "Comment#11")) - data.+=((18, 6L, "Comment#12")) - data.+=((19, 6L, "Comment#13")) - data.+=((20, 6L, "Comment#14")) - data.+=((21, 6L, "Comment#15")) - data - } - - def get5TupleData: Seq[(Int, Long, Int, String, Long)] = { - val data = new mutable.MutableList[(Int, Long, Int, String, Long)] - data.+=((1, 1L, 0, "Hallo", 1L)) - data.+=((2, 2L, 1, "Hallo Welt", 2L)) - data.+=((2, 3L, 2, "Hallo Welt wie", 1L)) - data.+=((3, 4L, 3, "Hallo Welt wie gehts?", 2L)) - data.+=((3, 5L, 4, "ABC", 2L)) - data.+=((3, 6L, 5, "BCD", 3L)) - data.+=((4, 7L, 6, "CDE", 2L)) - data.+=((4, 8L, 7, "DEF", 1L)) - data.+=((4, 9L, 8, "EFG", 1L)) - data.+=((4, 10L, 9, "FGH", 2L)) - data.+=((5, 11L, 10, "GHI", 1L)) - data.+=((5, 12L, 11, "HIJ", 3L)) - data.+=((5, 13L, 12, "IJK", 3L)) - data.+=((5, 14L, 13, "JKL", 2L)) - data.+=((5, 15L, 14, "KLM", 2L)) - data - } - - def getSmall3TupleDataStream(env: StreamExecutionEnvironment): DataStream[(Int, Long, String)] = { - env.fromCollection(getSmall3TupleData) - } - - def get3TupleDataStream(env: StreamExecutionEnvironment): DataStream[(Int, Long, String)] = { - env.fromCollection(get3TupleData) - } - - def get5TupleDataStream(env: StreamExecutionEnvironment): - DataStream[(Int, Long, Int, String, Long)] = { - env.fromCollection(get5TupleData) - } - - - def getTimeZoneTestData(env: StreamExecutionEnvironment): - DataStream[(Int, Date, Time, Timestamp)] = { - val MiLLIS_PER_DAY = 24 * 3600 * 1000 - val MiLLIS_PER_HOUR = 3600 * 1000 - val data = new mutable.MutableList[(Int, Date, Time, Timestamp)] - data.+=((1, new Date(0), new Time(0), new Timestamp(1))) - data.+=((2, new Date(1*MiLLIS_PER_DAY), new Time(MiLLIS_PER_HOUR), new Timestamp(2))) - data.+=((3, new Date(2*MiLLIS_PER_DAY), new Time(2*MiLLIS_PER_HOUR), new Timestamp(3))) - - env.fromCollection(data) - } -} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestData.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestData.scala index 9409e73de71a23..8fb0a2b52607dd 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestData.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/TestData.scala @@ -29,7 +29,7 @@ import org.apache.flink.types.Row import java.math.{BigDecimal => 
JBigDecimal} import java.sql.Timestamp -import scala.collection.Seq +import scala.collection.{Seq, mutable} object TestData { @@ -72,23 +72,31 @@ object TestData { val nullablesOfNullData5 = Array(true, false, false, false, false) - lazy val smallData3 = Seq( - row(1, 1L, "Hi"), - row(2, 2L, "Hello"), - row(3, 2L, "Hello world") - ) + lazy val smallTupleData3: Seq[(Int, Long, String)] = { + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Hi")) + data.+=((2, 2L, "Hello")) + data.+=((3, 2L, "Hello world")) + data + } + + lazy val smallData3: Seq[Row] = smallTupleData3.map(d => row(d.productIterator.toList: _*)) val nullablesOfSmallData3 = Array(false, false, false) - lazy val smallData5 = Seq( - row(1, 1L, 0, "Hallo", 1L), - row(2, 2L, 1, "Hallo Welt", 2L), - row(2, 3L, 2, "Hallo Welt wie", 1L) - ) + lazy val smallTupleData5: Seq[(Int, Long, Int, String, Long)] = { + val data = new mutable.MutableList[(Int, Long, Int, String, Long)] + data.+=((1, 1L, 0, "Hallo", 1L)) + data.+=((2, 2L, 1, "Hallo Welt", 2L)) + data.+=((2, 3L, 2, "Hallo Welt wie", 1L)) + data + } + + lazy val smallData5: Seq[Row] = smallTupleData5.map(d => row(d.productIterator.toList: _*)) val nullablesOfSmallData5 = Array(false, false, false, false, false) - lazy val buildInData = Seq( + lazy val buildInData: Seq[Row] = Seq( row(false, 1.toByte, 2, 3L, 2.56, "abcd", "f%g", UTCDate("2017-12-12"), UTCTime("10:08:09"), UTCTimestamp("2017-11-11 20:32:19")), @@ -99,33 +107,37 @@ object TestData { UTCTime("10:08:09"), UTCTimestamp("2015-05-20 10:00:00.887")) ) - lazy val data3 = Seq( - row(1, 1L, "Hi"), - row(2, 2L, "Hello"), - row(3, 2L, "Hello world"), - row(4, 3L, "Hello world, how are you?"), - row(5, 3L, "I am fine."), - row(6, 3L, "Luke Skywalker"), - row(7, 4L, "Comment#1"), - row(8, 4L, "Comment#2"), - row(9, 4L, "Comment#3"), - row(10, 4L, "Comment#4"), - row(11, 5L, "Comment#5"), - row(12, 5L, "Comment#6"), - row(13, 5L, "Comment#7"), - row(14, 5L, "Comment#8"), - row(15, 5L, "Comment#9"), - row(16, 6L, "Comment#10"), - row(17, 6L, "Comment#11"), - row(18, 6L, "Comment#12"), - row(19, 6L, "Comment#13"), - row(20, 6L, "Comment#14"), - row(21, 6L, "Comment#15") - ) + lazy val tupleData3: Seq[(Int, Long, String)] = { + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Hi")) + data.+=((2, 2L, "Hello")) + data.+=((3, 2L, "Hello world")) + data.+=((4, 3L, "Hello world, how are you?")) + data.+=((5, 3L, "I am fine.")) + data.+=((6, 3L, "Luke Skywalker")) + data.+=((7, 4L, "Comment#1")) + data.+=((8, 4L, "Comment#2")) + data.+=((9, 4L, "Comment#3")) + data.+=((10, 4L, "Comment#4")) + data.+=((11, 5L, "Comment#5")) + data.+=((12, 5L, "Comment#6")) + data.+=((13, 5L, "Comment#7")) + data.+=((14, 5L, "Comment#8")) + data.+=((15, 5L, "Comment#9")) + data.+=((16, 6L, "Comment#10")) + data.+=((17, 6L, "Comment#11")) + data.+=((18, 6L, "Comment#12")) + data.+=((19, 6L, "Comment#13")) + data.+=((20, 6L, "Comment#14")) + data.+=((21, 6L, "Comment#15")) + data + } + + lazy val data3: Seq[Row] = tupleData3.map(d => row(d.productIterator.toList: _*)) val nullablesOfData3 = Array(false, false, false) - lazy val genericData3 = Seq( + lazy val genericData3: Seq[Row] = Seq( row(new JTuple2("1", 1), new JTuple2(1, 1), 1), row(new JTuple2("2", 1), new JTuple2(1, 1), 2), row(new JTuple2("1", 1), new JTuple2(1, 1), 1), @@ -134,7 +146,7 @@ object TestData { val nullablesOfData3WithTimestamp = Array(true, false, false, false) - lazy val data3WithTimestamp = Seq( + lazy val data3WithTimestamp: 
Seq[Row] = Seq( row(2, 2L, "Hello", new Timestamp(2000L)), row(1, 1L, "Hi", new Timestamp(1000L)), row(3, 2L, "Hello world", new Timestamp(3000L)), @@ -158,27 +170,31 @@ object TestData { row(21, 6L, "Comment#15", new Timestamp(21000L)) ) - lazy val data5 = Seq( - row(1, 1L, 0, "Hallo", 1L), - row(2, 2L, 1, "Hallo Welt", 2L), - row(2, 3L, 2, "Hallo Welt wie", 1L), - row(3, 4L, 3, "Hallo Welt wie gehts?", 2L), - row(3, 5L, 4, "ABC", 2L), - row(3, 6L, 5, "BCD", 3L), - row(4, 7L, 6, "CDE", 2L), - row(4, 8L, 7, "DEF", 1L), - row(4, 9L, 8, "EFG", 1L), - row(4, 10L, 9, "FGH", 2L), - row(5, 11L, 10, "GHI", 1L), - row(5, 12L, 11, "HIJ", 3L), - row(5, 13L, 12, "IJK", 3L), - row(5, 14L, 13, "JKL", 2L), - row(5, 15L, 14, "KLM", 2L) - ) + lazy val tupleData5: Seq[(Int, Long, Int, String, Long)] = { + val data = new mutable.MutableList[(Int, Long, Int, String, Long)] + data.+=((1, 1L, 0, "Hallo", 1L)) + data.+=((2, 2L, 1, "Hallo Welt", 2L)) + data.+=((2, 3L, 2, "Hallo Welt wie", 1L)) + data.+=((3, 4L, 3, "Hallo Welt wie gehts?", 2L)) + data.+=((3, 5L, 4, "ABC", 2L)) + data.+=((3, 6L, 5, "BCD", 3L)) + data.+=((4, 7L, 6, "CDE", 2L)) + data.+=((4, 8L, 7, "DEF", 1L)) + data.+=((4, 9L, 8, "EFG", 1L)) + data.+=((4, 10L, 9, "FGH", 2L)) + data.+=((5, 11L, 10, "GHI", 1L)) + data.+=((5, 12L, 11, "HIJ", 3L)) + data.+=((5, 13L, 12, "IJK", 3L)) + data.+=((5, 14L, 13, "JKL", 2L)) + data.+=((5, 15L, 14, "KLM", 2L)) + data + } + + lazy val data5: Seq[Row] = tupleData5.map(d => row(d.productIterator.toList: _*)) val nullablesOfData5 = Array(false, false, false, false, false) - lazy val data6 = Seq( + lazy val data6: Seq[Row] = Seq( row(1, 1.1, "a", UTCDate("2017-04-08"), UTCTime("12:00:59"), UTCTimestamp("2015-05-20 10:00:00")), row(2, 2.5, "abc", UTCDate("2017-04-09"), UTCTime("12:00:59"), @@ -213,7 +229,7 @@ object TestData { val nullablesOfData6 = Array(false, false, false, false, false, false) - lazy val duplicateData5 = Seq( + lazy val duplicateData5: Seq[Row] = Seq( row(1, 1L, 10, "Hallo", 1L), row(2, 2L, 11, "Hallo Welt", 2L), row(2, 3L, 12, "Hallo Welt wie", 1L), @@ -233,7 +249,7 @@ object TestData { val nullablesOfDuplicateData5 = Array(false, false, false, false, false) - lazy val numericData = Seq( + lazy val numericData: Seq[Row] = Seq( row(1, 1L, 1.0f, 1.0d, JBigDecimal.valueOf(1)), row(2, 2L, 2.0f, 2.0d, JBigDecimal.valueOf(2)), row(3, 3L, 3.0f, 3.0d, JBigDecimal.valueOf(3)) @@ -242,7 +258,7 @@ object TestData { val nullablesOfNumericData = Array(false, false, false, false, false) // person test data - lazy val personData = Seq( + lazy val personData: Seq[Row] = Seq( row(1, 23, "tom", 172, "m"), row(2, 21, "mary", 161, "f"), row(3, 18, "jack", 182, "m"), @@ -265,7 +281,7 @@ object TestData { val INT_ONLY = new RowTypeInfo(INT_TYPE_INFO) val INT_INT = new RowTypeInfo(INT_TYPE_INFO, INT_TYPE_INFO) - lazy val data2_1 = Seq( + lazy val data2_1: Seq[Row] = Seq( row(1, 2.0), row(1, 2.0), row(2, 1.0), @@ -276,7 +292,7 @@ object TestData { row(6, null) ) - lazy val data2_2 = Seq( + lazy val data2_2: Seq[Row] = Seq( row(2, 3.0), row(2, 3.0), row(3, 2.0), @@ -286,7 +302,7 @@ object TestData { row(6, null) ) - lazy val data2_3 = Seq( + lazy val data2_3: Seq[Row] = Seq( row(2, 3.0), row(2, 3.0), row(3, 2.0), @@ -323,7 +339,7 @@ object TestData { val nullablesOfIntIntData3 = Array(false, false) - lazy val upperCaseData = Seq( + lazy val upperCaseData: Seq[Row] = Seq( row(1, "A"), row(2, "B"), row(3, "C"), @@ -333,7 +349,7 @@ object TestData { val nullablesOfUpperCaseData = Array(false, false) - lazy val lowerCaseData 
= Seq( + lazy val lowerCaseData: Seq[Row] = Seq( row(1, "a"), row(2, "b"), row(3, "c"), @@ -341,7 +357,7 @@ object TestData { val nullablesOfLowerCaseData = Array(false, false) - lazy val allNulls = Seq( + lazy val allNulls: Seq[Row] = Seq( row(null), row(null), row(null), @@ -349,7 +365,7 @@ object TestData { val nullablesOfAllNulls = Array(true) - lazy val projectionTestData = Seq( + lazy val projectionTestData: Seq[Row] = Seq( row(1, 10, 100, "1", "10", "100", 1000, "1000"), row(2, 20, 200, "2", "20", "200", 2000, "2000"), row(3, 30, 300, "3", "30", "300", 3000, "3000")) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala index fedd7e8e68d6e7..be85b3c5bd283e 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/util/TableTestBase.scala @@ -144,16 +144,14 @@ abstract class TableTestUtil(test: TableTestBase) { * @param name table name * @param types field types * @param names field names - * @param tableStats table stats - * @param uniqueKeys unique keys + * @param statistic statistic of current table * @return returns the registered [[Table]]. */ def addTableSource( name: String, types: Array[TypeInformation[_]], names: Array[String], - tableStats: Option[TableStats] = None, - uniqueKeys: Option[JSet[_ <: JSet[String]]] = None): Table + statistic: FlinkStatistic = FlinkStatistic.UNKNOWN): Table /** * Create a [[DataStream]] with the given schema, @@ -478,15 +476,10 @@ case class StreamTableTestUtil(test: TableTestBase) extends TableTestUtil(test) name: String, types: Array[TypeInformation[_]], names: Array[String], - tableStats: Option[TableStats] = None, - uniqueKeys: Option[JSet[_ <: JSet[String]]] = None): Table = { + statistic: FlinkStatistic = FlinkStatistic.UNKNOWN): Table = { val tableEnv = getTableEnv val schema = new TableSchema(names, types) val tableSource = new TestTableSource(schema) - val statistic = FlinkStatistic.builder() - .tableStats(tableStats.orNull) - .uniqueKeys(uniqueKeys.orNull) - .build() val table = new StreamTableSourceTable[BaseRow](tableSource, statistic) tableEnv.registerTableInternal(name, table) tableEnv.scan(name) @@ -596,15 +589,10 @@ case class BatchTableTestUtil(test: TableTestBase) extends TableTestUtil(test) { name: String, types: Array[TypeInformation[_]], names: Array[String], - tableStats: Option[TableStats] = None, - uniqueKeys: Option[JSet[_ <: JSet[String]]] = None): Table = { + statistic: FlinkStatistic = FlinkStatistic.UNKNOWN): Table = { val tableEnv = getTableEnv val schema = new TableSchema(names, types) val tableSource = new TestTableSource(schema) - val statistic = FlinkStatistic.builder() - .tableStats(tableStats.orNull) - .uniqueKeys(uniqueKeys.orNull) - .build() val table = new BatchTableSourceTable[BaseRow](tableSource, statistic) tableEnv.registerTableInternal(name, table) tableEnv.scan(name) From 0c99d3dbb01f12de5a96065261d71b9138ba6d25 Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Thu, 30 May 2019 16:25:26 +0800 Subject: [PATCH 33/92] [hotfix] Fix the version number in NOTICE and pom in table-planner-blink --- flink-table/flink-table-planner-blink/pom.xml | 6 +++--- .../src/main/resources/META-INF/NOTICE | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flink-table/flink-table-planner-blink/pom.xml 
b/flink-table/flink-table-planner-blink/pom.xml index 5fdada9642d8f7..f960b474ddab62 100644 --- a/flink-table/flink-table-planner-blink/pom.xml +++ b/flink-table/flink-table-planner-blink/pom.xml @@ -131,11 +131,11 @@ under the License. Dependencies that are not needed for how we use Calcite right now. - "mvn dependency:tree" as of Calcite 1.18: + "mvn dependency:tree" as of Calcite 1.19: - [INFO] +- org.apache.calcite:calcite-core:jar:1.18.0:compile + [INFO] +- org.apache.calcite:calcite-core:jar:1.19.0:compile [INFO] | +- org.apache.calcite.avatica:avatica-core:jar:1.13.0:compile - [INFO] | +- org.apache.calcite:calcite-linq4j:jar:1.18.0:compile + [INFO] | +- org.apache.calcite:calcite-linq4j:jar:1.19.0:compile [INFO] | +- org.apache.commons:commons-lang3:jar:3.3.2:compile [INFO] | +- com.fasterxml.jackson.core:jackson-core:jar:2.9.6:compile [INFO] | +- com.fasterxml.jackson.core:jackson-annotations:jar:2.9.6:compile diff --git a/flink-table/flink-table-planner-blink/src/main/resources/META-INF/NOTICE b/flink-table/flink-table-planner-blink/src/main/resources/META-INF/NOTICE index f6a41ac04470aa..a87ebf703de14f 100644 --- a/flink-table/flink-table-planner-blink/src/main/resources/META-INF/NOTICE +++ b/flink-table/flink-table-planner-blink/src/main/resources/META-INF/NOTICE @@ -12,8 +12,8 @@ This project bundles the following dependencies under the Apache Software Licens - com.fasterxml.jackson.core:jackson-databind:2.9.6 - com.jayway.jsonpath:json-path:2.4.0 - joda-time:joda-time:2.5 -- org.apache.calcite:calcite-core:1.18.0 -- org.apache.calcite:calcite-linq4j:1.18.0 +- org.apache.calcite:calcite-core:1.19.0 +- org.apache.calcite:calcite-linq4j:1.19.0 - org.apache.calcite.avatica:avatica-core:1.13.0 This project bundles the following dependencies under the BSD license. From 809e40dadad68232c68ac34a3d56bfb69a9396d3 Mon Sep 17 00:00:00 2001 From: Andrey Zagrebin Date: Wed, 15 May 2019 13:12:26 +0200 Subject: [PATCH 34/92] [FLINK-12530][network] Move Task.inputGatesById to NetworkEnvironment Task.inputGatesById indexes SingleInputGates by id. The end user of this indexing is NetworkEnvironment for two cases: - SingleInputGate triggers producer partition readiness check and then the successful result of check is dispatched back to this SingleInputGate by id. We can just return a future from TaskActions.triggerPartitionProducerStateCheck. SingleInputGate could use the future to react with re-triggering of the partition request if the producer is ready. Then inputGatesById is not needed for dispatching. - TaskExecutor.updatePartitions uses inputGatesById to dispatch PartitionInfo update to the right SingleInputGate. If inputGatesById is moved to NetworkEnvironment, which should be a better place for gate management, and NetworkEnvironment.updatePartitionInfo is added then TaskExecutor.updatePartitions could directly call NetworkEnvironment.updatePartitionInfo. Additional refactoring: - TaskActions.triggerPartitionProducerStateCheck is separated into another interface PartitionProducerStateProvider. TaskActions is too broad interface used also for other purposes. Shuffle API needs only PartitionProducerStateProvider. - PartitionProducerStateProvider returns future with the ResponseHandle which contains the producer state and accepts callbacks to cancel or fail consumption as a result of state check. - Task.triggerPartitionProducerStateCheck is also refactored into a RemoteChannelStateChecker which becomes internal detail of NetworkEnvironment. 
RemoteChannelStateChecker accepts ResponseHandle, checks whether producer is ready for consumption or aborts consumption using ResponseHandle.cancelConsumption or ResponseHandle.failConsumption. --- .../io/network/NetworkEnvironment.java | 52 +++++- .../PartitionProducerStateProvider.java | 63 ++++++++ .../partition/consumer/InputGateID.java | 85 ++++++++++ .../consumer/RemoteChannelStateChecker.java | 125 +++++++++++++++ .../consumer/RemoteInputChannel.java | 2 +- .../partition/consumer/SingleInputGate.java | 58 ++++--- .../consumer/SingleInputGateFactory.java | 16 +- .../taskexecutor/JobManagerConnection.java | 1 - .../PartitionProducerStateChecker.java | 2 +- .../runtime/taskexecutor/TaskExecutor.java | 45 +++--- .../rpc/RpcPartitionStateChecker.java | 2 +- .../runtime/taskmanager/NoOpTaskActions.java | 10 -- .../flink/runtime/taskmanager/Task.java | 151 +++++------------- .../runtime/taskmanager/TaskActions.java | 17 -- .../partition/InputGateFairnessTest.java | 20 +-- .../consumer/SingleInputGateBuilder.java | 15 +- .../consumer/SingleInputGateTest.java | 69 +++++++- .../TaskExecutorSubmissionTest.java | 49 +++--- .../taskexecutor/TaskExecutorTest.java | 1 - .../TaskSubmissionTestEnvironment.java | 42 +++-- .../taskmanager/TaskAsyncCallTest.java | 2 +- .../flink/runtime/taskmanager/TaskTest.java | 68 ++++---- .../runtime/util/JvmExitOnFatalErrorTest.java | 2 +- .../StreamNetworkBenchmarkEnvironment.java | 7 +- .../tasks/InterruptSensitiveRestoreTest.java | 2 +- .../tasks/StreamTaskTerminationTest.java | 2 +- .../runtime/tasks/StreamTaskTest.java | 2 +- .../tasks/SynchronousCheckpointITCase.java | 2 +- .../tasks/TaskCheckpointingBehaviourTest.java | 2 +- 29 files changed, 621 insertions(+), 293 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionProducerStateProvider.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateID.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteChannelStateChecker.java rename flink-runtime/src/main/java/org/apache/flink/runtime/{io/network/netty => taskexecutor}/PartitionProducerStateChecker.java (97%) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 7ee2a205259b63..43969e2c14369c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.PartitionInfo; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.metrics.InputBufferPoolUsageGauge; @@ -37,13 +38,17 @@ import org.apache.flink.runtime.io.network.metrics.ResultPartitionMetrics; import org.apache.flink.runtime.io.network.netty.NettyConfig; import org.apache.flink.runtime.io.network.netty.NettyConnectionManager; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; import 
org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionFactory; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.InputGateID; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.taskexecutor.TaskExecutor; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; import org.apache.flink.runtime.taskmanager.TaskActions; @@ -54,6 +59,9 @@ import java.io.IOException; import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -84,6 +92,8 @@ public class NetworkEnvironment { private final ResultPartitionManager resultPartitionManager; + private final Map inputGatesById; + private final TaskEventPublisher taskEventPublisher; private final ResultPartitionFactory resultPartitionFactory; @@ -104,6 +114,7 @@ private NetworkEnvironment( this.networkBufferPool = networkBufferPool; this.connectionManager = connectionManager; this.resultPartitionManager = resultPartitionManager; + this.inputGatesById = new ConcurrentHashMap<>(); this.taskEventPublisher = taskEventPublisher; this.resultPartitionFactory = resultPartitionFactory; this.singleInputGateFactory = singleInputGateFactory; @@ -190,6 +201,11 @@ public NetworkEnvironmentConfiguration getConfiguration() { return config; } + @VisibleForTesting + public Optional getInputGate(InputGateID id) { + return Optional.ofNullable(inputGatesById.get(id)); + } + /** * Batch release intermediate result partitions. * @@ -231,8 +247,8 @@ public ResultPartition[] createResultPartitionWriters( public SingleInputGate[] createInputGates( String taskName, - JobID jobId, - TaskActions taskActions, + ExecutionAttemptID executionId, + PartitionProducerStateProvider partitionProducerStateProvider, Collection inputGateDeploymentDescriptors, MetricGroup parentGroup, MetricGroup inputGroup, @@ -245,13 +261,16 @@ public SingleInputGate[] createInputGates( SingleInputGate[] inputGates = new SingleInputGate[inputGateDeploymentDescriptors.size()]; int counter = 0; for (InputGateDeploymentDescriptor igdd : inputGateDeploymentDescriptors) { - inputGates[counter++] = singleInputGateFactory.create( + SingleInputGate inputGate = singleInputGateFactory.create( taskName, - jobId, igdd, - taskActions, + partitionProducerStateProvider, inputChannelMetrics, numBytesInCounter); + InputGateID id = new InputGateID(igdd.getConsumedResultId(), executionId); + inputGatesById.put(id, inputGate); + inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id)); + inputGates[counter++] = inputGate; } registerInputMetrics(inputGroup, buffersGroup, inputGates); @@ -275,6 +294,29 @@ private void registerInputMetrics(MetricGroup inputGroup, MetricGroup buffersGro buffersGroup.gauge(METRIC_INPUT_POOL_USAGE, new InputBufferPoolUsageGauge(inputGates)); } + /** + * Update consuming gate with newly available partition. 
+ * + * @param consumerID execution id of consumer to identify belonging to it gate. + * @param partitionInfo telling where the partition can be retrieved from + * @return {@code true} if the partition has been updated or {@code false} if the partition is not available anymore. + * @throws IOException IO problem by the update + * @throws InterruptedException potentially blocking operation was interrupted + * @throws IllegalStateException the input gate with the id from the partitionInfo is not found + */ + public boolean updatePartitionInfo( + ExecutionAttemptID consumerID, + PartitionInfo partitionInfo) throws IOException, InterruptedException { + IntermediateDataSetID intermediateResultPartitionID = partitionInfo.getIntermediateDataSetID(); + InputGateID id = new InputGateID(intermediateResultPartitionID, consumerID); + SingleInputGate inputGate = inputGatesById.get(id); + if (inputGate == null) { + return false; + } + inputGate.updateInputChannel(partitionInfo.getInputChannelDeploymentDescriptor()); + return true; + } + public void start() throws IOException { synchronized (lock) { Preconditions.checkState(!isShutdown, "The NetworkEnvironment has already been shut down."); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionProducerStateProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionProducerStateProvider.java new file mode 100644 index 00000000000000..8bbdaa53b16a82 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionProducerStateProvider.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.types.Either; + +import java.util.concurrent.CompletableFuture; + +/** + * Request execution state of partition producer, the response accepts state check callbacks. + */ +public interface PartitionProducerStateProvider { + /** + * Trigger the producer execution state request. + * + * @param intermediateDataSetId ID of the parent intermediate data set. + * @param resultPartitionId ID of the result partition to check. This + * identifies the producing execution and partition. + * @return a future with response handle. + */ + CompletableFuture requestPartitionProducerState( + IntermediateDataSetID intermediateDataSetId, + ResultPartitionID resultPartitionId); + + /** + * Result of state query, accepts state check callbacks. 
+ */ + interface ResponseHandle { + ExecutionState getConsumerExecutionState(); + + Either getProducerExecutionState(); + + /** + * Cancel the partition consumptions as a result of state check. + */ + void cancelConsumption(); + + /** + * Fail the partition consumptions as a result of state check. + * + * @param cause failure cause + */ + void failConsumption(Throwable cause); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateID.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateID.java new file mode 100644 index 00000000000000..ffa2542f2f7693 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateID.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + +import java.io.Serializable; + +/** + * Runtime identifier of a consumed {@link org.apache.flink.runtime.executiongraph.IntermediateResult}. + * + *
<p>
At runtime the {@link org.apache.flink.runtime.jobgraph.IntermediateDataSetID} is not enough to uniquely + * identify an input gate. It needs to be associated with the consuming task as well to ensure + * correct tracking of gates in shuffle implementation. + */ +public class InputGateID implements Serializable { + + private static final long serialVersionUID = 4613970383536333315L; + + /** + * The ID of the consumed intermediate result. Each input gate consumes partitions of the + * intermediate result specified by this ID. This ID also identifies the input gate at the + * consuming task. + */ + private final IntermediateDataSetID consumedResultID; + + /** + * The ID of the consumer. + * + *
<p>
The ID of {@link org.apache.flink.runtime.executiongraph.Execution} and + * its local {@link org.apache.flink.runtime.taskmanager.Task}. + */ + private final ExecutionAttemptID consumerID; + + public InputGateID(IntermediateDataSetID consumedResultID, ExecutionAttemptID consumerID) { + this.consumedResultID = consumedResultID; + this.consumerID = consumerID; + } + + public IntermediateDataSetID getConsumedResultID() { + return consumedResultID; + } + + public ExecutionAttemptID getConsumerID() { + return consumerID; + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj.getClass() == InputGateID.class) { + InputGateID o = (InputGateID) obj; + + return o.getConsumedResultID().equals(consumedResultID) && o.getConsumerID().equals(consumerID); + } + + return false; + } + + @Override + public int hashCode() { + return consumedResultID.hashCode() ^ consumerID.hashCode(); + } + + @Override + public String toString() { + return consumedResultID.toString() + "@" + consumerID.toString(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteChannelStateChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteChannelStateChecker.java new file mode 100644 index 00000000000000..69ee3fd167ef0a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteChannelStateChecker.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider.ResponseHandle; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException; +import org.apache.flink.types.Either; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.TimeoutException; + +/** + * Handles the response of {@link PartitionProducerStateProvider}. + * + *
<p>
The method {@code isProducerReadyOrAbortConsumption} determines + * whether the partition producer is in a producing state, ready for consumption. + * Otherwise it aborts the consumption. + */ +public class RemoteChannelStateChecker { + private static final Logger LOG = LoggerFactory.getLogger(RemoteChannelStateChecker.class); + + private final ResultPartitionID resultPartitionId; + + private final String taskNameWithSubtask; + + public RemoteChannelStateChecker(ResultPartitionID resultPartitionId, String taskNameWithSubtask) { + this.resultPartitionId = resultPartitionId; + this.taskNameWithSubtask = taskNameWithSubtask; + } + + public boolean isProducerReadyOrAbortConsumption(ResponseHandle responseHandle) { + Either result = responseHandle.getProducerExecutionState(); + if (responseHandle.getConsumerExecutionState() != ExecutionState.RUNNING) { + LOG.debug( + "Ignore a partition producer state notification for task {}, because it's not running.", + taskNameWithSubtask); + } + else if (result.isLeft() || result.right() instanceof TimeoutException) { + boolean isProducerConsumerReady = isProducerConsumerReady(responseHandle); + if (isProducerConsumerReady) { + return true; + } else { + abortConsumptionOrIgnoreCheckResult(responseHandle); + } + } else { + handleFailedCheckResult(responseHandle); + } + return false; + } + + private boolean isProducerConsumerReady(ResponseHandle responseHandle) { + ExecutionState producerState = getProducerState(responseHandle); + return producerState == ExecutionState.SCHEDULED || + producerState == ExecutionState.DEPLOYING || + producerState == ExecutionState.RUNNING || + producerState == ExecutionState.FINISHED; + } + + private void abortConsumptionOrIgnoreCheckResult(ResponseHandle responseHandle) { + ExecutionState producerState = getProducerState(responseHandle); + if (producerState == ExecutionState.CANCELING || + producerState == ExecutionState.CANCELED || + producerState == ExecutionState.FAILED) { + + // The producing execution has been canceled or failed. We + // don't need to re-trigger the request since it cannot + // succeed. + if (LOG.isDebugEnabled()) { + LOG.debug("Cancelling task {} after the producer of partition {} with attempt ID {} has entered state {}.", + taskNameWithSubtask, + resultPartitionId.getPartitionId(), + resultPartitionId.getProducerId(), + producerState); + } + + responseHandle.cancelConsumption(); + } else { + // Any other execution state is unexpected. Currently, only + // state CREATED is left out of the checked states. If we + // see a producer in this state, something went wrong with + // scheduling in topological order. + final String msg = String.format("Producer with attempt ID %s of partition %s in unexpected state %s.", + resultPartitionId.getProducerId(), + resultPartitionId.getPartitionId(), + producerState); + + responseHandle.failConsumption(new IllegalStateException(msg)); + } + } + + private static ExecutionState getProducerState(ResponseHandle responseHandle) { + Either result = responseHandle.getProducerExecutionState(); + return result.isLeft() ? result.left() : ExecutionState.RUNNING; + } + + private void handleFailedCheckResult(ResponseHandle responseHandle) { + Throwable throwable = responseHandle.getProducerExecutionState().right(); + if (throwable instanceof PartitionProducerDisposedException) { + String msg = String.format( + "Producer %s of partition %s disposed. 
Cancelling execution.", + resultPartitionId.getProducerId(), + resultPartitionId.getPartitionId()); + LOG.info(msg, throwable); + responseHandle.cancelConsumption(); + } else { + responseHandle.failConsumption(throwable); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 2d174eafb543ce..50bf1d07945dd3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -175,7 +175,7 @@ public void requestSubpartition(int subpartitionIndex) throws IOException, Inter /** * Retriggers a remote subpartition request. */ - void retriggerSubpartitionRequest(int subpartitionIndex) throws IOException, InterruptedException { + void retriggerSubpartitionRequest(int subpartitionIndex) throws IOException { checkState(partitionRequestClient != null, "Missing initial subpartition request."); if (increaseBackoff()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 63504bb7c57df2..5e5a722ffd51f2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.JobID; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; @@ -30,13 +29,13 @@ import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.SupplierWithException; import org.slf4j.Logger; @@ -107,9 +106,6 @@ public class SingleInputGate extends InputGate { /** The name of the owning task, for logging purposes. */ private final String owningTaskName; - /** The job ID of the owning task. */ - private final JobID jobId; - /** * The ID of the consumed intermediate result. Each input gate consumes partitions of the * intermediate result specified by this ID. This ID also identifies the input gate at the @@ -146,8 +142,8 @@ public class SingleInputGate extends InputGate { private final BitSet channelsWithEndOfPartitionEvents; - /** The partition state listener listening to failed partition requests. 
*/ - private final TaskActions taskActions; + /** The partition producer state listener. */ + private final PartitionProducerStateProvider partitionProducerStateProvider; /** * Buffer pool for incoming buffers. Incoming data from remote channels is copied to buffers @@ -162,9 +158,6 @@ public class SingleInputGate extends InputGate { /** Flag indicating whether partitions have been requested. */ private boolean requestedPartitionsFlag; - /** Flag indicating whether all resources have been released. */ - private volatile boolean isReleased; - private final List pendingEvents = new ArrayList<>(); private int numberOfUninitializedChannels; @@ -176,20 +169,20 @@ public class SingleInputGate extends InputGate { private final SupplierWithException bufferPoolFactory; + private final CompletableFuture closeFuture; + public SingleInputGate( String owningTaskName, - JobID jobId, IntermediateDataSetID consumedResultId, final ResultPartitionType consumedPartitionType, int consumedSubpartitionIndex, int numberOfInputChannels, - TaskActions taskActions, + PartitionProducerStateProvider partitionProducerStateProvider, Counter numBytesIn, boolean isCreditBased, SupplierWithException bufferPoolFactory) { this.owningTaskName = checkNotNull(owningTaskName); - this.jobId = checkNotNull(jobId); this.consumedResultId = checkNotNull(consumedResultId); this.consumedPartitionType = checkNotNull(consumedPartitionType); @@ -205,11 +198,13 @@ public SingleInputGate( this.channelsWithEndOfPartitionEvents = new BitSet(numberOfInputChannels); this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels); - this.taskActions = checkNotNull(taskActions); + this.partitionProducerStateProvider = checkNotNull(partitionProducerStateProvider); this.numBytesIn = checkNotNull(numBytesIn); this.isCreditBased = isCreditBased; + + this.closeFuture = new CompletableFuture<>(); } @Override @@ -289,6 +284,10 @@ public String getOwningTaskName() { return owningTaskName; } + public CompletableFuture getCloseFuture() { + return closeFuture; + } + // ------------------------------------------------------------------------ // Setup/Life-cycle // ------------------------------------------------------------------------ @@ -327,7 +326,7 @@ public void setInputChannel(IntermediateResultPartitionID partitionId, InputChan public void updateInputChannel(InputChannelDeploymentDescriptor icdd) throws IOException, InterruptedException { synchronized (requestLock) { - if (isReleased) { + if (closeFuture.isDone()) { // There was a race with a task failure/cancel return; } @@ -380,9 +379,9 @@ else if (partitionLocation.isRemote()) { /** * Retriggers a partition request. 
*/ - public void retriggerPartitionRequest(IntermediateResultPartitionID partitionId) throws IOException, InterruptedException { + public void retriggerPartitionRequest(IntermediateResultPartitionID partitionId) throws IOException { synchronized (requestLock) { - if (!isReleased) { + if (!closeFuture.isDone()) { final InputChannel ch = inputChannels.get(partitionId); checkNotNull(ch, "Unknown input channel with ID " + partitionId); @@ -419,7 +418,7 @@ Timer getRetriggerLocalRequestTimer() { public void close() throws IOException { boolean released = false; synchronized (requestLock) { - if (!isReleased) { + if (!closeFuture.isDone()) { try { LOG.debug("{}: Releasing {}.", owningTaskName, this); @@ -444,8 +443,8 @@ public void close() throws IOException { } } finally { - isReleased = true; released = true; + closeFuture.complete(null); } } } @@ -474,7 +473,7 @@ public boolean isFinished() { public void requestPartitions() throws IOException, InterruptedException { synchronized (requestLock) { if (!requestedPartitionsFlag) { - if (isReleased) { + if (closeFuture.isDone()) { throw new IllegalStateException("Already released."); } @@ -513,7 +512,7 @@ private Optional getNextBufferOrEvent(boolean blocking) throws IO return Optional.empty(); } - if (isReleased) { + if (closeFuture.isDone()) { throw new IllegalStateException("Released"); } @@ -623,7 +622,20 @@ void notifyChannelNonEmpty(InputChannel channel) { } void triggerPartitionStateCheck(ResultPartitionID partitionId) { - taskActions.triggerPartitionProducerStateCheck(jobId, consumedResultId, partitionId); + partitionProducerStateProvider.requestPartitionProducerState( + consumedResultId, + partitionId) + .thenAccept(responseHandle -> { + boolean isProducingState = new RemoteChannelStateChecker(partitionId, owningTaskName) + .isProducerReadyOrAbortConsumption(responseHandle); + if (isProducingState) { + try { + retriggerPartitionRequest(partitionId.getPartitionId()); + } catch (IOException t) { + responseHandle.failConsumption(t); + } + } + }); } private void queueChannel(InputChannel channel) { @@ -655,7 +667,7 @@ private void queueChannel(InputChannel channel) { private Optional getChannel(boolean blocking) throws InterruptedException { synchronized (inputChannelsWithData) { while (inputChannelsWithData.size() == 0) { - if (isReleased) { + if (closeFuture.isDone()) { throw new IllegalStateException("Released"); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java index cf2820dd34fcc7..fcc36659edd78f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.JobID; import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -30,12 +29,12 @@ import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics; +import 
org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.SupplierWithException; import org.slf4j.Logger; @@ -98,9 +97,8 @@ public SingleInputGateFactory( */ public SingleInputGate create( @Nonnull String owningTaskName, - @Nonnull JobID jobId, @Nonnull InputGateDeploymentDescriptor igdd, - @Nonnull TaskActions taskActions, + @Nonnull PartitionProducerStateProvider partitionProducerStateProvider, @Nonnull InputChannelMetrics metrics, @Nonnull Counter numBytesInCounter) { final IntermediateDataSetID consumedResultId = checkNotNull(igdd.getConsumedResultId()); @@ -112,8 +110,14 @@ public SingleInputGate create( final InputChannelDeploymentDescriptor[] icdd = checkNotNull(igdd.getInputChannelDeploymentDescriptors()); final SingleInputGate inputGate = new SingleInputGate( - owningTaskName, jobId, consumedResultId, consumedPartitionType, consumedSubpartitionIndex, - icdd.length, taskActions, numBytesInCounter, isCreditBased, + owningTaskName, + consumedResultId, + consumedPartitionType, + consumedSubpartitionIndex, + icdd.length, + partitionProducerStateProvider, + numBytesInCounter, + isCreditBased, createBufferPoolFactory(icdd.length, consumedPartitionType)); // Create the input channels. There is one input channel for each consumed partition. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java index 3c710b6664c44b..6713b4d9a8413e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/JobManagerConnection.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.clusterframework.types.ResourceID; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobmaster.JobMasterGateway; import org.apache.flink.runtime.jobmaster.JobMasterId; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionProducerStateChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/PartitionProducerStateChecker.java similarity index 97% rename from flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionProducerStateChecker.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/PartitionProducerStateChecker.java index b1ea68b8dd20c4..79922b6d910f45 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionProducerStateChecker.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/PartitionProducerStateChecker.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.flink.runtime.io.network.netty; +package org.apache.flink.runtime.taskexecutor; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.execution.ExecutionState; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java index dcfabbffaf4b1d..efc2b427c48a74 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java @@ -49,11 +49,8 @@ import org.apache.flink.runtime.instance.HardwareDescription; import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; @@ -80,7 +77,6 @@ import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.state.TaskStateManagerImpl; import org.apache.flink.runtime.taskexecutor.exceptions.CheckpointException; -import org.apache.flink.runtime.taskexecutor.exceptions.PartitionUpdateException; import org.apache.flink.runtime.taskexecutor.exceptions.RegistrationTimeoutException; import org.apache.flink.runtime.taskexecutor.exceptions.SlotAllocationException; import org.apache.flink.runtime.taskexecutor.exceptions.SlotOccupiedException; @@ -614,34 +610,29 @@ public CompletableFuture updatePartitions( if (task != null) { for (final PartitionInfo partitionInfo: partitionInfos) { - IntermediateDataSetID intermediateResultPartitionID = partitionInfo.getIntermediateDataSetID(); - - final SingleInputGate singleInputGate = task.getInputGateById(intermediateResultPartitionID); - - if (singleInputGate != null) { - // Run asynchronously because it might be blocking - getRpcService().execute( + // Run asynchronously because it might be blocking + FutureUtils.assertNoException( + CompletableFuture.runAsync( () -> { try { - singleInputGate.updateInputChannel(partitionInfo.getInputChannelDeploymentDescriptor()); - } catch (IOException | InterruptedException e) { - log.error("Could not update input data location for task {}. Trying to fail task.", task.getTaskInfo().getTaskName(), e); - - try { - task.failExternally(e); - } catch (RuntimeException re) { - // TODO: Check whether we need this or make exception in failExtenally checked - log.error("Failed canceling task with execution ID {} after task update failure.", executionAttemptID, re); + if (!networkEnvironment.updatePartitionInfo(executionAttemptID, partitionInfo)) { + log.debug( + "Discard update for input gate partition {} of result {} in task {}. " + + "The partition is no longer available.", + partitionInfo.getInputChannelDeploymentDescriptor().getConsumedPartitionId(), + partitionInfo.getIntermediateDataSetID(), + executionAttemptID); } + } catch (IOException | InterruptedException e) { + log.error( + "Could not update input data location for task {}. 
Trying to fail task.", + task.getTaskInfo().getTaskName(), + e); + task.failExternally(e); } - }); - } else { - return FutureUtils.completedExceptionally( - new PartitionUpdateException("No reader with ID " + intermediateResultPartitionID + - " for task " + executionAttemptID + " was found.")); - } + }, + getRpcService().getExecutor())); } - return CompletableFuture.completedFuture(Acknowledge.get()); } else { log.debug("Discard update for input partitions of task {}. Task is no longer running.", executionAttemptID); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java index f3eb717166a921..47fce1fb47bfe4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcPartitionStateChecker.java @@ -20,7 +20,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobmaster.JobMasterGateway; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NoOpTaskActions.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NoOpTaskActions.java index 5f314cc7a220fa..ff77eedb4b7949 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NoOpTaskActions.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NoOpTaskActions.java @@ -18,21 +18,11 @@ package org.apache.flink.runtime.taskmanager; -import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; - /** * A dummy implementation of the {@link TaskActions} which is mainly used for tests. 
*/ public class NoOpTaskActions implements TaskActions { - @Override - public void triggerPartitionProducerStateCheck( - JobID jobId, - IntermediateDataSetID intermediateDataSetId, - ResultPartitionID resultPartitionId) {} - @Override public void failExternally(Throwable cause) {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 6742346b28acb5..57d8c6bdc20a32 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotReadyException; import org.apache.flink.runtime.clusterframework.types.AllocationID; +import org.apache.flink.runtime.concurrent.FutureUtils; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.execution.CancelTaskException; @@ -51,7 +52,7 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; @@ -60,7 +61,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; -import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; import org.apache.flink.runtime.query.TaskKvStateRegistry; @@ -68,7 +68,9 @@ import org.apache.flink.runtime.state.TaskStateManager; import org.apache.flink.runtime.taskexecutor.GlobalAggregateManager; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.util.FatalExitExceptionHandler; +import org.apache.flink.types.Either; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FlinkException; import org.apache.flink.util.Preconditions; @@ -92,7 +94,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -120,7 +121,7 @@ * *
<p>
Each Task is run by one dedicated thread. */ -public class Task implements Runnable, TaskActions, CheckpointListener { +public class Task implements Runnable, TaskActions, PartitionProducerStateProvider, CheckpointListener { /** The class logger. */ private static final Logger LOG = LoggerFactory.getLogger(Task.class); @@ -193,8 +194,6 @@ public class Task implements Runnable, TaskActions, CheckpointListener { private final SingleInputGate[] inputGates; - private final Map inputGatesById; - /** Connection to the task manager. */ private final TaskManagerActions taskManagerActions; @@ -383,7 +382,7 @@ public Task( // consumed intermediate result partitions this.inputGates = networkEnvironment.createInputGates( taskNameWithSubtaskAndId, - jobId, + executionId, this, inputGateDeploymentDescriptors, metrics.getIOMetricGroup(), @@ -391,11 +390,6 @@ public Task( buffersGroup, metrics.getIOMetricGroup().getNumBytesInCounter()); - this.inputGatesById = new HashMap<>(); - for (SingleInputGate inputGate : inputGates) { - inputGatesById.put(inputGate.getConsumedResultId(), inputGate); - } - invokableHasBeenCanceled = new AtomicBoolean(false); // finally, create the executing thread, but do not start it @@ -434,10 +428,6 @@ public Configuration getTaskConfiguration() { return this.taskConfiguration; } - public SingleInputGate getInputGateById(IntermediateDataSetID id) { - return inputGatesById.get(id); - } - public AccumulatorRegistry getAccumulatorRegistry() { return accumulatorRegistry; } @@ -1068,44 +1058,18 @@ else if (current == ExecutionState.RUNNING) { // ------------------------------------------------------------------------ @Override - public void triggerPartitionProducerStateCheck( - JobID jobId, - final IntermediateDataSetID intermediateDataSetId, - final ResultPartitionID resultPartitionId) { - - CompletableFuture futurePartitionState = + public CompletableFuture requestPartitionProducerState( + final IntermediateDataSetID intermediateDataSetId, + final ResultPartitionID resultPartitionId) { + final CompletableFuture futurePartitionState = partitionProducerStateChecker.requestPartitionProducerState( jobId, intermediateDataSetId, resultPartitionId); - - futurePartitionState.whenCompleteAsync( - (ExecutionState executionState, Throwable throwable) -> { - try { - if (executionState != null) { - onPartitionStateUpdate( - intermediateDataSetId, - resultPartitionId, - executionState); - } else if (throwable instanceof TimeoutException) { - // our request timed out, assume we're still running and try again - onPartitionStateUpdate( - intermediateDataSetId, - resultPartitionId, - ExecutionState.RUNNING); - } else if (throwable instanceof PartitionProducerDisposedException) { - String msg = String.format("Producer %s of partition %s disposed. Cancelling execution.", - resultPartitionId.getProducerId(), resultPartitionId.getPartitionId()); - LOG.info(msg, throwable); - cancelExecution(); - } else { - failExternally(throwable); - } - } catch (IOException | InterruptedException e) { - failExternally(e); - } - }, - executor); + final CompletableFuture result = + futurePartitionState.handleAsync(PartitionProducerStateResponseHandle::new, executor); + FutureUtils.assertNoException(result); + return result; } // ------------------------------------------------------------------------ @@ -1215,64 +1179,6 @@ public void run() { // ------------------------------------------------------------------------ - /** - * Answer to a partition state check issued after a failed partition request. 
- */ - @VisibleForTesting - void onPartitionStateUpdate( - IntermediateDataSetID intermediateDataSetId, - ResultPartitionID resultPartitionId, - ExecutionState producerState) throws IOException, InterruptedException { - - if (executionState == ExecutionState.RUNNING) { - final SingleInputGate inputGate = inputGatesById.get(intermediateDataSetId); - - if (inputGate != null) { - if (producerState == ExecutionState.SCHEDULED - || producerState == ExecutionState.DEPLOYING - || producerState == ExecutionState.RUNNING - || producerState == ExecutionState.FINISHED) { - - // Retrigger the partition request - inputGate.retriggerPartitionRequest(resultPartitionId.getPartitionId()); - - } else if (producerState == ExecutionState.CANCELING - || producerState == ExecutionState.CANCELED - || producerState == ExecutionState.FAILED) { - - // The producing execution has been canceled or failed. We - // don't need to re-trigger the request since it cannot - // succeed. - if (LOG.isDebugEnabled()) { - LOG.debug("Cancelling task {} after the producer of partition {} with attempt ID {} has entered state {}.", - taskNameWithSubtask, - resultPartitionId.getPartitionId(), - resultPartitionId.getProducerId(), - producerState); - } - - cancelExecution(); - } else { - // Any other execution state is unexpected. Currently, only - // state CREATED is left out of the checked states. If we - // see a producer in this state, something went wrong with - // scheduling in topological order. - String msg = String.format("Producer with attempt ID %s of partition %s in unexpected state %s.", - resultPartitionId.getProducerId(), - resultPartitionId.getPartitionId(), - producerState); - - failExternally(new IllegalStateException(msg)); - } - } else { - failExternally(new IllegalStateException("Received partition producer state for " + - "unknown input gate " + intermediateDataSetId + ".")); - } - } else { - LOG.debug("Task {} ignored a partition producer state notification, because it's not running.", taskNameWithSubtask); - } - } - /** * Utility method to dispatch an asynchronous call on the invokable. * @@ -1352,6 +1258,35 @@ public String toString() { return String.format("%s (%s) [%s]", taskNameWithSubtask, executionId, executionState); } + @VisibleForTesting + class PartitionProducerStateResponseHandle implements ResponseHandle { + private final Either result; + + PartitionProducerStateResponseHandle(@Nullable ExecutionState producerState, @Nullable Throwable t) { + this.result = producerState != null ? Either.Left(producerState) : Either.Right(t); + } + + @Override + public ExecutionState getConsumerExecutionState() { + return executionState; + } + + @Override + public Either getProducerExecutionState() { + return result; + } + + @Override + public void cancelConsumption() { + cancelExecution(); + } + + @Override + public void failConsumption(Throwable cause) { + failExternally(cause); + } + } + /** * Instantiates the given task invokable class, passing the given environment (and possibly * the initial task state) to the task's constructor. 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/TaskActions.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/TaskActions.java index f7650d25750454..525d60fa9efae2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/TaskActions.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/TaskActions.java @@ -18,28 +18,11 @@ package org.apache.flink.runtime.taskmanager; -import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; - /** * Actions which can be performed on a {@link Task}. */ public interface TaskActions { - /** - * Check the execution state of the execution producing a result partition. - * - * @param jobId ID of the job the partition belongs to. - * @param intermediateDataSetId ID of the parent intermediate data set. - * @param resultPartitionId ID of the result partition to check. This - * identifies the producing execution and partition. - */ - void triggerPartitionProducerStateCheck( - JobID jobId, - IntermediateDataSetID intermediateDataSetId, - ResultPartitionID resultPartitionId); - /** * Fail the owning task with the given throwable. * diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index 14ebabcf154a52..d670d01743e09f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition; -import org.apache.flink.api.common.JobID; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; @@ -30,10 +29,9 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; import org.apache.flink.runtime.io.network.util.TestBufferFactory; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import org.apache.flink.runtime.taskmanager.NoOpTaskActions; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.SupplierWithException; import org.junit.Test; @@ -265,11 +263,9 @@ public void testFairConsumptionRemoteChannels() throws Exception { private SingleInputGate createFairnessVerifyingInputGate(int numberOfChannels) { return new FairnessVerifyingInputGate( "Test Task Name", - new JobID(), new IntermediateDataSetID(), 0, numberOfChannels, - new NoOpTaskActions(), true); } @@ -324,16 +320,20 @@ private static class FairnessVerifyingInputGate extends SingleInputGate { @SuppressWarnings("unchecked") public FairnessVerifyingInputGate( String owningTaskName, - JobID jobId, IntermediateDataSetID consumedResultId, int consumedSubpartitionIndex, int numberOfInputChannels, - TaskActions taskActions, boolean isCreditBased) { - super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED, - consumedSubpartitionIndex, numberOfInputChannels, taskActions, 
new SimpleCounter(), - isCreditBased, STUB_BUFFER_POOL_FACTORY); + super(owningTaskName, + consumedResultId, + ResultPartitionType.PIPELINED, + consumedSubpartitionIndex, + numberOfInputChannels, + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, + new SimpleCounter(), + isCreditBased, + STUB_BUFFER_POOL_FACTORY); try { Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData"); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index 6f27baf7a0579a..51eba307403441 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -18,26 +18,28 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.api.common.JobID; import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; +import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider.ResponseHandle; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; -import org.apache.flink.runtime.taskmanager.NoOpTaskActions; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.SupplierWithException; import java.io.IOException; +import java.util.concurrent.CompletableFuture; /** * Utility class to encapsulate the logic of building a {@link SingleInputGate} instance. 
*/ public class SingleInputGateBuilder { - private final JobID jobId = new JobID(); + private static final CompletableFuture NO_OP_PRODUCER_CHECKER_RESULT = new CompletableFuture<>(); + + public static final PartitionProducerStateProvider NO_OP_PRODUCER_CHECKER = (dsid, id) -> NO_OP_PRODUCER_CHECKER_RESULT; private final IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID(); @@ -47,7 +49,7 @@ public class SingleInputGateBuilder { private int numberOfChannels = 1; - private final TaskActions taskActions = new NoOpTaskActions(); + private final PartitionProducerStateProvider partitionProducerStateProvider = NO_OP_PRODUCER_CHECKER; private final Counter numBytesInCounter = new SimpleCounter(); @@ -92,12 +94,11 @@ public SingleInputGateBuilder setupBufferPoolFactory(NetworkEnvironment environm public SingleInputGate build() { return new SingleInputGate( "Single Input Gate", - jobId, intermediateDataSetID, partitionType, consumedSubpartitionIndex, numberOfChannels, - taskActions, + partitionProducerStateProvider, numBytesInCounter, isCreditBased, bufferPoolFactory); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index c159ce20d70c79..c87f99f4e932ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -18,9 +18,9 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.api.common.JobID; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; @@ -44,11 +44,12 @@ import org.apache.flink.runtime.io.network.util.TestTaskEvent; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; -import org.apache.flink.runtime.taskmanager.NoOpTaskActions; import org.junit.Test; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; @@ -344,9 +345,8 @@ public void testRequestBackoffConfiguration() throws Exception { netEnv.getNetworkBufferPool()) .create( "TestTask", - new JobID(), gateDesc, - new NoOpTaskActions(), + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, InputChannelTestUtils.newUnregisteredInputChannelMetrics(), new SimpleCounter()); @@ -546,8 +546,69 @@ public void testPartitionNotFoundExceptionWhileGetNextBuffer() throws Exception } } + @Test + public void testInputGateRemovalFromNetworkEnvironment() throws Exception { + NetworkEnvironment network = createNetworkEnvironment(); + + try { + int numberOfGates = 10; + Map createdInputGatesById = + createInputGateWithLocalChannels(network, numberOfGates, 1); + + assertEquals(numberOfGates, createdInputGatesById.size()); + + for (InputGateID id : createdInputGatesById.keySet()) { + assertThat(network.getInputGate(id).isPresent(), is(true)); + createdInputGatesById.get(id).close(); + 
assertThat(network.getInputGate(id).isPresent(), is(false)); + } + } finally { + network.shutdown(); + } + } + // --------------------------------------------------------------------------------------------- + private static Map createInputGateWithLocalChannels( + NetworkEnvironment network, + int numberOfGates, + @SuppressWarnings("SameParameterValue") int numberOfLocalChannels) { + InputChannelDeploymentDescriptor[] channelDescs = new InputChannelDeploymentDescriptor[numberOfLocalChannels]; + for (int i = 0; i < numberOfLocalChannels; i++) { + channelDescs[i] = new InputChannelDeploymentDescriptor( + new ResultPartitionID(), + ResultPartitionLocation.createLocal()); + } + + InputGateDeploymentDescriptor[] gateDescs = new InputGateDeploymentDescriptor[numberOfGates]; + IntermediateDataSetID[] ids = new IntermediateDataSetID[numberOfGates]; + for (int i = 0; i < numberOfGates; i++) { + ids[i] = new IntermediateDataSetID(); + gateDescs[i] = new InputGateDeploymentDescriptor( + ids[i], + ResultPartitionType.PIPELINED, + 0, + channelDescs); + } + + ExecutionAttemptID consumerID = new ExecutionAttemptID(); + SingleInputGate[] gates = network.createInputGates( + "", + consumerID, + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, + Arrays.asList(gateDescs), + new UnregisteredMetricsGroup(), + new UnregisteredMetricsGroup(), + new UnregisteredMetricsGroup(), + new SimpleCounter()); + Map inputGatesById = new HashMap<>(); + for (int i = 0; i < numberOfGates; i++) { + inputGatesById.put(new InputGateID(ids[i], consumerID), gates[i]); + } + + return inputGatesById; + } + private void addUnknownInputChannel( NetworkEnvironment network, SingleInputGate inputGate, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorSubmissionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorSubmissionTest.java index cde72594050c07..8a6f784d878635 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorSubmissionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorSubmissionTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.NetworkEnvironmentOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.clusterframework.types.AllocationID; @@ -38,7 +39,7 @@ import org.apache.flink.runtime.executiongraph.PartitionInfo; import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.io.network.ConnectionID; -import org.apache.flink.configuration.NetworkEnvironmentOptions; +import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; @@ -53,8 +54,8 @@ import org.apache.flink.runtime.jobmaster.utils.TestingJobMasterGatewayBuilder; import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.messages.StackTraceSampleResponse; -import org.apache.flink.runtime.taskexecutor.exceptions.PartitionUpdateException; import org.apache.flink.runtime.taskexecutor.slot.TaskSlotTable; +import org.apache.flink.runtime.taskmanager.Task; import 
org.apache.flink.runtime.testtasks.BlockingNoOpInvokable; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.NetUtils; @@ -65,6 +66,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; +import org.mockito.Mockito; import java.io.IOException; import java.net.InetSocketAddress; @@ -75,13 +77,15 @@ import java.util.List; import java.util.concurrent.CompletableFuture; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; /** * Tests for submission logic of the {@link TaskExecutor}. @@ -259,7 +263,7 @@ public void testRunJobWithForwardChannel() throws Exception { createTestTaskDeploymentDescriptor( "Sender", eid1, - TestingAbstractInvokables.Sender.class, + TestingAbstractInvokables.Sender.class, 1, Collections.singletonList(task1ResultPartitionDescriptor), Collections.emptyList()); @@ -294,7 +298,7 @@ public void testRunJobWithForwardChannel() throws Exception { .addTaskManagerActionListener(eid2, ExecutionState.FINISHED, task2FinishedFuture) .setJobMasterId(jobMasterId) .setJobMasterGateway(testingJobMasterGateway) - .setMockNetworkEnvironment(false) + .useRealNonMockNetworkEnvironment() .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); @@ -379,7 +383,7 @@ public void testCancellingDependentAndStateUpdateFails() throws Exception { .addTaskManagerActionListener(eid2, ExecutionState.CANCELED, task2CanceledFuture) .setJobMasterId(jobMasterId) .setJobMasterGateway(testingJobMasterGateway) - .setMockNetworkEnvironment(false) + .useRealNonMockNetworkEnvironment() .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); @@ -447,7 +451,7 @@ public void testRemotePartitionNotFound() throws Exception { .addTaskManagerActionListener(eid, ExecutionState.FAILED, taskFailedFuture) .setConfiguration(config) .setLocalCommunication(false) - .setMockNetworkEnvironment(false) + .useRealNonMockNetworkEnvironment() .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); @@ -462,7 +466,7 @@ public void testRemotePartitionNotFound() throws Exception { } /** - * Tests that the TaskManager sends proper exception back to the sender if the partition update fails. + * Tests that the TaskManager fails the task if the partition update fails. 
*/ @Test public void testUpdateTaskInputPartitionsFailure() throws Exception { @@ -471,11 +475,19 @@ public void testUpdateTaskInputPartitionsFailure() throws Exception { final TaskDeploymentDescriptor tdd = createTestTaskDeploymentDescriptor("test task", eid, BlockingNoOpInvokable.class); final CompletableFuture taskRunningFuture = new CompletableFuture<>(); + final CompletableFuture taskFailedFuture = new CompletableFuture<>(); + final PartitionInfo partitionUpdate = new PartitionInfo( + new IntermediateDataSetID(), + new InputChannelDeploymentDescriptor(new ResultPartitionID(), ResultPartitionLocation.createLocal())); + final NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class, Mockito.RETURNS_MOCKS); + doThrow(new IOException()).when(networkEnvironment).updatePartitionInfo(eid, partitionUpdate); try (TaskSubmissionTestEnvironment env = new TaskSubmissionTestEnvironment.Builder(jobId) + .setNetworkEnvironment(networkEnvironment) .setSlotSize(1) .addTaskManagerActionListener(eid, ExecutionState.RUNNING, taskRunningFuture) + .addTaskManagerActionListener(eid, ExecutionState.FAILED, taskFailedFuture) .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); @@ -486,17 +498,14 @@ public void testUpdateTaskInputPartitionsFailure() throws Exception { CompletableFuture updateFuture = tmGateway.updatePartitions( eid, - Collections.singletonList( - new PartitionInfo( - new IntermediateDataSetID(), - new InputChannelDeploymentDescriptor(new ResultPartitionID(), ResultPartitionLocation.createLocal()))), + Collections.singletonList(partitionUpdate), timeout); - try { - updateFuture.get(); - fail(); - } catch (Exception e) { - assertTrue(ExceptionUtils.findThrowable(e, PartitionUpdateException.class).isPresent()); - } + + updateFuture.get(); + taskFailedFuture.get(); + Task task = taskSlotTable.getTask(tdd.getExecutionAttemptId()); + assertThat(task.getExecutionState(), is(ExecutionState.FAILED)); + assertThat(task.getFailureCause(), instanceOf(IOException.class)); } } @@ -539,7 +548,7 @@ public void testLocalPartitionNotFound() throws Exception { .addTaskManagerActionListener(eid, ExecutionState.RUNNING, taskRunningFuture) .addTaskManagerActionListener(eid, ExecutionState.FAILED, taskFailedFuture) .setConfiguration(config) - .setMockNetworkEnvironment(false) + .useRealNonMockNetworkEnvironment() .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); @@ -611,7 +620,7 @@ public void testFailingScheduleOrUpdateConsumers() throws Exception { .addTaskManagerActionListener(eid, ExecutionState.RUNNING, taskRunningFuture) .setJobMasterId(jobMasterId) .setJobMasterGateway(testingJobMasterGateway) - .setMockNetworkEnvironment(false) + .useRealNonMockNetworkEnvironment() .build()) { TaskExecutorGateway tmGateway = env.getTaskExecutorGateway(); TaskSlotTable taskSlotTable = env.getTaskSlotTable(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java index 58f234cca1d23c..4d0b4d5a029623 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java @@ -59,7 +59,6 @@ import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import 
org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskSubmissionTestEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskSubmissionTestEnvironment.java index d0d142f3292dfa..f0364a6700fc9e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskSubmissionTestEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskSubmissionTestEnvironment.java @@ -37,7 +37,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.netty.NettyConfig; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.jobmaster.JobMasterGateway; import org.apache.flink.runtime.jobmaster.JobMasterId; import org.apache.flink.runtime.jobmaster.utils.TestingJobMasterGateway; @@ -59,6 +58,7 @@ import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.ConfigurationParserUtils; import org.apache.flink.runtime.util.TestingFatalErrorHandler; +import org.apache.flink.util.FlinkRuntimeException; import org.junit.rules.TemporaryFolder; import org.mockito.Mockito; @@ -70,6 +70,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static org.mockito.ArgumentMatchers.any; @@ -83,7 +84,7 @@ class TaskSubmissionTestEnvironment implements AutoCloseable { private final HeartbeatServices heartbeatServices = new HeartbeatServices(1000L, 1000L); - private final TestingRpcService testingRpcService = new TestingRpcService(); + private final TestingRpcService testingRpcService; private final BlobCacheService blobCacheService= new BlobCacheService(new Configuration(), new VoidBlobStore(), null); private final Time timeout = Time.milliseconds(10000L); private final TestingFatalErrorHandler testingFatalErrorHandler = new TestingFatalErrorHandler(); @@ -100,11 +101,11 @@ private TaskSubmissionTestEnvironment( JobID jobId, JobMasterId jobMasterId, int slotSize, - boolean mockNetworkEnvironment, TestingJobMasterGateway testingJobMasterGateway, Configuration configuration, - boolean localCommunication, - List>> taskManagerActionListeners) throws Exception { + List>> taskManagerActionListeners, + TestingRpcService testingRpcService, + NetworkEnvironment networkEnvironment) throws Exception { this.haServices = new TestingHighAvailabilityServices(); this.haServices.setResourceManagerLeaderRetriever(new SettableLeaderRetrievalService()); @@ -143,8 +144,7 @@ private TaskSubmissionTestEnvironment( taskManagerActions = testTaskManagerActions; } - final NetworkEnvironment networkEnvironment = createNetworkEnvironment(localCommunication, configuration, testingRpcService, mockNetworkEnvironment); - + this.testingRpcService = testingRpcService; final JobManagerConnection jobManagerConnection = createJobManagerConnection(jobId, jobMasterGateway, testingRpcService, taskManagerActions, timeout); final 
JobManagerTable jobManagerTable = new JobManagerTable(); jobManagerTable.put(jobId, jobManagerConnection); @@ -250,6 +250,7 @@ private static NetworkEnvironment createNetworkEnvironment( .setPartitionRequestMaxBackoff(configuration.getInteger(NetworkEnvironmentOptions.NETWORK_REQUEST_BACKOFF_MAX)) .setNettyConfig(localCommunication ? null : nettyConfig) .build(); + networkEnvironment.start(); } @@ -278,6 +279,8 @@ public static final class Builder { private TestingJobMasterGateway jobMasterGateway; private boolean localCommunication = true; private Configuration configuration = new Configuration(); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private Optional optionalNetworkEnvironment = Optional.empty(); private List>> taskManagerActionListeners = new ArrayList<>(); @@ -285,8 +288,15 @@ public Builder(JobID jobId) { this.jobId = jobId; } - public Builder setMockNetworkEnvironment(boolean mockNetworkEnvironment) { - this.mockNetworkEnvironment = mockNetworkEnvironment; + public Builder useRealNonMockNetworkEnvironment() { + this.optionalNetworkEnvironment = Optional.empty(); + this.mockNetworkEnvironment = false; + return this; + } + + public Builder setNetworkEnvironment(NetworkEnvironment optionalNetworkEnvironment) { + this.mockNetworkEnvironment = false; + this.optionalNetworkEnvironment = Optional.of(optionalNetworkEnvironment); return this; } @@ -321,15 +331,23 @@ public Builder addTaskManagerActionListener(ExecutionAttemptID eid, ExecutionSta } public TaskSubmissionTestEnvironment build() throws Exception { + final TestingRpcService testingRpcService = new TestingRpcService(); + final NetworkEnvironment network = optionalNetworkEnvironment.orElseGet(() -> { + try { + return createNetworkEnvironment(localCommunication, configuration, testingRpcService, mockNetworkEnvironment); + } catch (Exception e) { + throw new FlinkRuntimeException("Failed to build TaskSubmissionTestEnvironment", e); + } + }); return new TaskSubmissionTestEnvironment( jobId, jobMasterId, slotSize, - mockNetworkEnvironment, jobMasterGateway, configuration, - localCommunication, - taskManagerActionListeners); + taskManagerActionListeners, + testingRpcService, + network); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index 1f743e4242a659..01cd18477dc332 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -43,7 +43,7 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index 75c7c898b6d5e9..8ed009e549c528 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -51,11 +51,11 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.io.network.partition.consumer.RemoteChannelStateChecker; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; @@ -69,6 +69,7 @@ import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; @@ -98,18 +99,19 @@ import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -555,18 +557,14 @@ public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception { @Test public void testOnPartitionStateUpdate() throws Exception { - final IntermediateDataSetID resultId = new IntermediateDataSetID(); final ResultPartitionID partitionId = new ResultPartitionID(); - final SingleInputGate inputGate = mock(SingleInputGate.class); - when(inputGate.getConsumedResultId()).thenReturn(resultId); - final Task task = createTaskBuilder() .setInvokable(InvokableBlockingInInvoke.class) .build(); - // Set the mock input gate - setInputGate(task, inputGate); + RemoteChannelStateChecker checker = + new RemoteChannelStateChecker(partitionId, "test task"); // Expected task state for each producer state final Map expected = new HashMap<>(ExecutionState.values().length); @@ -585,17 +583,20 @@ public void testOnPartitionStateUpdate() throws Exception { expected.put(ExecutionState.CANCELING, ExecutionState.CANCELING); expected.put(ExecutionState.FAILED, ExecutionState.CANCELING); + int producingStateCounter = 0; for (ExecutionState state : ExecutionState.values()) { setState(task, ExecutionState.RUNNING); - task.onPartitionStateUpdate(resultId, partitionId, state); + if (checker.isProducerReadyOrAbortConsumption(task.new 
PartitionProducerStateResponseHandle(state, null))) { + producingStateCounter++; + } ExecutionState newTaskState = task.getExecutionState(); assertEquals(expected.get(state), newTaskState); } - verify(inputGate, times(4)).retriggerPartitionRequest(eq(partitionId.getPartitionId())); + assertEquals(4, producingStateCounter); } /** @@ -610,6 +611,11 @@ public void testTriggerPartitionStateUpdate() throws Exception { final ResultPartitionConsumableNotifier consumableNotifier = new NoOpResultPartitionConsumableNotifier(); + AtomicInteger callCount = new AtomicInteger(0); + + RemoteChannelStateChecker remoteChannelStateChecker = + new RemoteChannelStateChecker(partitionId, "test task"); + // Test all branches of trigger partition state check { // Reset latches @@ -622,11 +628,14 @@ public void testTriggerPartitionStateUpdate() throws Exception { .setPartitionProducerStateChecker(partitionChecker) .setExecutor(Executors.directExecutor()) .build(); + setState(task, ExecutionState.RUNNING); final CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); - task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); + task.requestPartitionProducerState(resultId, partitionId).thenAccept(checkResult -> + assertThat(remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult), is(false)) + ); promise.completeExceptionally(new PartitionProducerDisposedException(partitionId)); assertEquals(ExecutionState.CANCELING, task.getExecutionState()); @@ -643,11 +652,14 @@ public void testTriggerPartitionStateUpdate() throws Exception { .setPartitionProducerStateChecker(partitionChecker) .setExecutor(Executors.directExecutor()) .build(); + setState(task, ExecutionState.RUNNING); final CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); - task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); + task.requestPartitionProducerState(resultId, partitionId).thenAccept(checkResult -> + assertThat(remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult), is(false)) + ); promise.completeExceptionally(new RuntimeException("Any other exception")); @@ -655,6 +667,8 @@ public void testTriggerPartitionStateUpdate() throws Exception { } { + callCount.set(0); + // Reset latches setup(); @@ -667,25 +681,24 @@ public void testTriggerPartitionStateUpdate() throws Exception { .setExecutor(Executors.directExecutor()) .build(); - final SingleInputGate inputGate = mock(SingleInputGate.class); - when(inputGate.getConsumedResultId()).thenReturn(resultId); - try { task.startTaskThread(); awaitLatch.await(); - setInputGate(task, inputGate); - CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); - task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); + task.requestPartitionProducerState(resultId, partitionId).thenAccept(checkResult -> { + if (remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)) { + callCount.incrementAndGet(); + } + }); promise.completeExceptionally(new TimeoutException()); assertEquals(ExecutionState.RUNNING, task.getExecutionState()); - verify(inputGate, times(1)).retriggerPartitionRequest(eq(partitionId.getPartitionId())); + 
assertEquals(1, callCount.get()); } finally { task.getExecutingThread().interrupt(); task.getExecutingThread().join(); @@ -693,6 +706,8 @@ public void testTriggerPartitionStateUpdate() throws Exception { } { + callCount.set(0); + // Reset latches setup(); @@ -704,25 +719,24 @@ public void testTriggerPartitionStateUpdate() throws Exception { .setExecutor(Executors.directExecutor()) .build(); - final SingleInputGate inputGate = mock(SingleInputGate.class); - when(inputGate.getConsumedResultId()).thenReturn(resultId); - try { task.startTaskThread(); awaitLatch.await(); - setInputGate(task, inputGate); - CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); - task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); + task.requestPartitionProducerState(resultId, partitionId).thenAccept(checkResult -> { + if (remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)) { + callCount.incrementAndGet(); + } + }); promise.complete(ExecutionState.RUNNING); assertEquals(ExecutionState.RUNNING, task.getExecutionState()); - verify(inputGate, times(1)).retriggerPartitionRequest(eq(partitionId.getPartitionId())); + assertEquals(1, callCount.get()); } finally { task.getExecutingThread().interrupt(); task.getExecutingThread().join(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java index 1b2ec490f5197a..51930a0e756e82 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java @@ -46,7 +46,7 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java index af5922096139f9..fccb5db8ee4d76 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java @@ -44,11 +44,11 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory; import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import 
org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; -import org.apache.flink.runtime.taskmanager.NoOpTaskActions; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.ConfigurationParserUtils; @@ -166,7 +166,6 @@ public SerializingLongReceiver createReceiver() throws Exception { senderEnv.getConnectionManager().getDataPort()); InputGate receiverGate = createInputGate( - jobId, dataSetID, executionAttemptID, senderLocation, @@ -228,7 +227,6 @@ protected ResultPartitionWriter createResultPartition( } private InputGate createInputGate( - JobID jobId, IntermediateDataSetID dataSetID, ExecutionAttemptID executionAttemptID, final TaskManagerLocation senderLocation, @@ -258,9 +256,8 @@ private InputGate createInputGate( environment.getNetworkBufferPool()) .create( "receiving task[" + channel + "]", - jobId, gateDescriptor, - new NoOpTaskActions(), + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, InputChannelTestUtils.newUnregisteredInputChannelMetrics(), new SimpleCounter()); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index 24b1047aa0dbad..ce5a7e168094d7 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -46,7 +46,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -67,6 +66,7 @@ import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java index c079d15bce7a30..4cffdbd700372c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java @@ -46,7 +46,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -69,6 +68,7 @@ import 
org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index af779b674c049c..343957457ad941 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -52,7 +52,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; @@ -87,6 +86,7 @@ import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.NoOpTaskManagerActions; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java index 778a7d0e1c08c8..0416e5c84d918b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java @@ -45,7 +45,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -58,6 +57,7 @@ import org.apache.flink.runtime.state.CheckpointStorageLocationReference; import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java index 9418e142257de9..a6268cea8413ab 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java @@ -48,7 +48,6 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -73,6 +72,7 @@ import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.state.testutils.BackendForTestStream; import org.apache.flink.runtime.taskexecutor.KvStateService; +import org.apache.flink.runtime.taskexecutor.PartitionProducerStateChecker; import org.apache.flink.runtime.taskexecutor.TestGlobalAggregateManager; import org.apache.flink.runtime.taskmanager.CheckpointResponder; import org.apache.flink.runtime.taskmanager.Task; From 836fdfff0db64ff8241f38e8dd362dd50a9d1895 Mon Sep 17 00:00:00 2001 From: Wei Zhong Date: Fri, 24 May 2019 14:45:30 +0800 Subject: [PATCH 35/92] [FLINK-12440][python] Add all connector support align Java Table API. This closes #8531 --- .../flink-bin/bin/pyflink-gateway-server.sh | 2 +- flink-python/pyflink/table/__init__.py | 4 +- .../pyflink/table/table_descriptor.py | 485 +++++++++++++++++- .../pyflink/table/tests/test_descriptor.py | 407 ++++++++++++++- tools/travis_controller.sh | 2 + 5 files changed, 883 insertions(+), 17 deletions(-) diff --git a/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh b/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh index 026f813a053f80..9e41ad5e1c1d83 100644 --- a/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh +++ b/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh @@ -63,7 +63,7 @@ if [[ -n "$FLINK_TESTING" ]]; then else FLINK_TEST_CLASSPATH="$FLINK_TEST_CLASSPATH":"$testJarFile" fi - done < <(find "$FLINK_SOURCE_ROOT_DIR" ! -type d -name 'flink-*-tests.jar' -print0 | sort -z) + done < <(find "$FLINK_SOURCE_ROOT_DIR" ! 
-type d \( -name 'flink-*-tests.jar' -o -path "${FLINK_SOURCE_ROOT_DIR}/flink-connectors/flink-connector-elasticsearch-base/target/flink*.jar" -o -path "${FLINK_SOURCE_ROOT_DIR}/flink-connectors/flink-connector-kafka-base/target/flink*.jar" \) -print0 | sort -z) fi exec $JAVA_RUN $JVM_ARGS "${log_setting[@]}" -cp ${FLINK_CLASSPATH}:${TABLE_JAR_PATH}:${FLINK_TEST_CLASSPATH} ${DRIVER} ${ARGS[@]} diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py index 281647f3e92609..904264edc06b62 100644 --- a/flink-python/pyflink/table/__init__.py +++ b/flink-python/pyflink/table/__init__.py @@ -40,7 +40,7 @@ from pyflink.table.table_source import TableSource, CsvTableSource from pyflink.table.types import DataTypes, UserDefinedType, Row from pyflink.table.window import Tumble, Session, Slide, Over -from pyflink.table.table_descriptor import Rowtime, Schema, OldCsv, FileSystem +from pyflink.table.table_descriptor import Rowtime, Schema, OldCsv, FileSystem, Kafka, Elasticsearch __all__ = [ 'TableEnvironment', @@ -63,4 +63,6 @@ 'FileSystem', 'UserDefinedType', 'Row', + 'Kafka', + 'Elasticsearch' ] diff --git a/flink-python/pyflink/table/table_descriptor.py b/flink-python/pyflink/table/table_descriptor.py index 1dfbde315d3459..65161b4eb17b22 100644 --- a/flink-python/pyflink/table/table_descriptor.py +++ b/flink-python/pyflink/table/table_descriptor.py @@ -30,7 +30,9 @@ 'Rowtime', 'Schema', 'OldCsv', - 'FileSystem' + 'FileSystem', + 'Kafka', + 'Elasticsearch' ] @@ -256,7 +258,7 @@ class OldCsv(FormatDescriptor): format in the dedicated `flink-formats/flink-csv` module instead when writing to Kafka. Use the old one for stream/batch filesystem operations for now. - .. note:: + ..note:: Deprecated: use the RFC-compliant `Csv` format instead when writing to Kafka. """ @@ -373,6 +375,485 @@ def path(self, path_str): return self +class Kafka(ConnectorDescriptor): + """ + Connector descriptor for the Apache Kafka message queue. + """ + + def __init__(self): + gateway = get_gateway() + self._j_kafka = gateway.jvm.Kafka() + super(Kafka, self).__init__(self._j_kafka) + + def version(self, version): + """ + Sets the Kafka version to be used. + + :param version: Kafka version. E.g., "0.8", "0.11", etc. + :return: This object. + """ + if not isinstance(version, (str, unicode)): + version = str(version) + self._j_kafka = self._j_kafka.version(version) + return self + + def topic(self, topic): + """ + Sets the topic from which the table is read. + + :param topic: The topic from which the table is read. + :return: This object. + """ + self._j_kafka = self._j_kafka.topic(topic) + return self + + def properties(self, property_dict): + """ + Sets the configuration properties for the Kafka consumer. Resets previously set properties. + + :param property_dict: The dict object contains configuration properties for the Kafka + consumer. Both the keys and values should be strings. + :return: This object. + """ + gateway = get_gateway() + properties = gateway.jvm.java.util.Properties() + for key in property_dict: + properties.setProperty(key, property_dict[key]) + self._j_kafka = self._j_kafka.properties(properties) + return self + + def property(self, key, value): + """ + Adds a configuration properties for the Kafka consumer. + + :param key: Property key string for the Kafka consumer. + :param value: Property value string for the Kafka consumer. + :return: This object. 
+ """ + self._j_kafka = self._j_kafka.property(key, value) + return self + + def start_from_earliest(self): + """ + Specifies the consumer to start reading from the earliest offset for all partitions. + This lets the consumer ignore any committed group offsets in Zookeeper / Kafka brokers. + + This method does not affect where partitions are read from when the consumer is restored + from a checkpoint or savepoint. When the consumer is restored from a checkpoint or + savepoint, only the offsets in the restored state will be used. + + :return: This object. + """ + self._j_kafka = self._j_kafka.startFromEarliest() + return self + + def start_from_latest(self): + """ + Specifies the consumer to start reading from the latest offset for all partitions. + This lets the consumer ignore any committed group offsets in Zookeeper / Kafka brokers. + + This method does not affect where partitions are read from when the consumer is restored + from a checkpoint or savepoint. When the consumer is restored from a checkpoint or + savepoint, only the offsets in the restored state will be used. + + :return: This object. + """ + self._j_kafka = self._j_kafka.startFromLatest() + return self + + def start_from_group_offsets(self): + """ + Specifies the consumer to start reading from any committed group offsets found + in Zookeeper / Kafka brokers. The "group.id" property must be set in the configuration + properties. If no offset can be found for a partition, the behaviour in "auto.offset.reset" + set in the configuration properties will be used for the partition. + + This method does not affect where partitions are read from when the consumer is restored + from a checkpoint or savepoint. When the consumer is restored from a checkpoint or + savepoint, only the offsets in the restored state will be used. + + :return: This object. + """ + self._j_kafka = self._j_kafka.startFromGroupOffsets() + return self + + def start_from_specific_offsets(self, specific_offsets_dict): + """ + Specifies the consumer to start reading partitions from specific offsets, set independently + for each partition. The specified offset should be the offset of the next record that will + be read from partitions. This lets the consumer ignore any committed group offsets in + Zookeeper / Kafka brokers. + + If the provided map of offsets contains entries whose partition is not subscribed by the + consumer, the entry will be ignored. If the consumer subscribes to a partition that does + not exist in the provided map of offsets, the consumer will fallback to the default group + offset behaviour(see :func:`pyflink.table.table_descriptor.Kafka.start_from_group_offsets`) + for that particular partition. + + If the specified offset for a partition is invalid, or the behaviour for that partition is + defaulted to group offsets but still no group offset could be found for it, then the + "auto.offset.reset" behaviour set in the configuration properties will be used for the + partition. + + This method does not affect where partitions are read from when the consumer is restored + from a checkpoint or savepoint. When the consumer is restored from a checkpoint or + savepoint, only the offsets in the restored state will be used. + + :param specific_offsets_dict: Dict of specific_offsets that the key is int-type partition + id and value is int-type offset value. + :return: This object. 
+ """ + for key in specific_offsets_dict: + self.start_from_specific_offset(key, specific_offsets_dict[key]) + return self + + def start_from_specific_offset(self, partition, specific_offset): + """ + Configures to start reading partitions from specific offsets and specifies the given offset + for the given partition. + + see :func:`pyflink.table.table_descriptor.Kafka.start_from_specific_offsets` + + :param partition: + :param specific_offset: + :return: This object. + """ + self._j_kafka = self._j_kafka.startFromSpecificOffset(int(partition), int(specific_offset)) + return self + + def sink_partitioner_fixed(self): + """ + Configures how to partition records from Flink's partitions into Kafka's partitions. + + This strategy ensures that each Flink partition ends up in one Kafka partition. + + ..note:: + One Kafka partition can contain multiple Flink partitions. Examples: + + More Flink partitions than Kafka partitions. Some (or all) Kafka partitions contain + the output of more than one flink partition: + + | Flink Sinks --------- Kafka Partitions + | 1 ----------------> 1 + | 2 --------------/ + | 3 -------------/ + | 4 ------------/ + + Fewer Flink partitions than Kafka partitions: + + | Flink Sinks --------- Kafka Partitions + | 1 ----------------> 1 + | 2 ----------------> 2 + | 3 + | 4 + | 5 + + :return: This object. + """ + self._j_kafka = self._j_kafka.sinkPartitionerFixed() + return self + + def sink_partitioner_round_robin(self): + """ + Configures how to partition records from Flink's partitions into Kafka's partitions. + + This strategy ensures that records will be distributed to Kafka partitions in a + round-robin fashion. + + ..note:: + This strategy is useful to avoid an unbalanced partitioning. However, it will cause a + lot of network connections between all the Flink instances and all the Kafka brokers. + + :return: This object. + """ + self._j_kafka = self._j_kafka.sinkPartitionerRoundRobin() + return self + + def sink_partitioner_custom(self, partitioner_class_name): + """ + Configures how to partition records from Flink's partitions into Kafka's partitions. + + This strategy allows for a custom partitioner by providing an implementation + of ``FlinkKafkaPartitioner``. + + :param partitioner_class_name: The java canonical class name of the FlinkKafkaPartitioner. + The FlinkKafkaPartitioner must have a public no-argument + constructor and can be founded by in current Java + classloader. + :return: This object. + """ + gateway = get_gateway() + self._j_kafka = self._j_kafka.sinkPartitionerCustom( + gateway.jvm.Thread.currentThread().getContextClassLoader() + .loadClass(partitioner_class_name)) + return self + + +class Elasticsearch(ConnectorDescriptor): + """ + Connector descriptor for the Elasticsearch search engine. + """ + + def __init__(self): + gateway = get_gateway() + self._j_elasticsearch = gateway.jvm.Elasticsearch() + super(Elasticsearch, self).__init__(self._j_elasticsearch) + + def version(self, version): + """ + Sets the Elasticsearch version to be used. Required. + + :param version: Elasticsearch version. E.g., "6". + :return: This object. + """ + if not isinstance(version, (str, unicode)): + version = str(version) + self._j_elasticsearch = self._j_elasticsearch.version(version) + return self + + def host(self, hostname, port, protocol): + """ + Adds an Elasticsearch host to connect to. Required. + + Multiple hosts can be declared by calling this method multiple times. + + :param hostname: Connection hostname. + :param port: Connection port. 
+ :param protocol: Connection protocol; e.g. "http". + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.host(hostname, int(port), protocol) + return self + + def index(self, index): + """ + Declares the Elasticsearch index for every record. Required. + + :param index: Elasticsearch index. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.index(index) + return self + + def document_type(self, document_type): + """ + Declares the Elasticsearch document type for every record. Required. + + :param document_type: Elasticsearch document type. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.documentType(document_type) + return self + + def key_delimiter(self, key_delimiter): + """ + Sets a custom key delimiter in case the Elasticsearch ID needs to be constructed from + multiple fields. Optional. + + :param key_delimiter: Key delimiter; e.g., "$" would result in IDs "KEY1$KEY2$KEY3". + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.keyDelimiter(key_delimiter) + return self + + def key_null_literal(self, key_null_literal): + """ + Sets a custom representation for null fields in keys. Optional. + + :param key_null_literal: key null literal string; e.g. "N/A" would result in IDs + "KEY1_N/A_KEY3". + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.keyNullLiteral(key_null_literal) + return self + + def failure_handler_fail(self): + """ + Configures a failure handling strategy in case a request to Elasticsearch fails. + + This strategy throws an exception if a request fails and thus causes a job failure. + + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.failureHandlerFail() + return self + + def failure_handler_ignore(self): + """ + Configures a failure handling strategy in case a request to Elasticsearch fails. + + This strategy ignores failures and drops the request. + + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.failureHandlerIgnore() + return self + + def failure_handler_retry_rejected(self): + """ + Configures a failure handling strategy in case a request to Elasticsearch fails. + + This strategy re-adds requests that have failed due to queue capacity saturation. + + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.failureHandlerRetryRejected() + return self + + def failure_handler_custom(self, failure_handler_class_name): + """ + Configures a failure handling strategy in case a request to Elasticsearch fails. + + This strategy allows for custom failure handling using a ``ActionRequestFailureHandler``. + + :param failure_handler_class_name: + :return: This object. + """ + gateway = get_gateway() + self._j_elasticsearch = self._j_elasticsearch.failureHandlerCustom( + gateway.jvm.Thread.currentThread().getContextClassLoader() + .loadClass(failure_handler_class_name)) + return self + + def disable_flush_on_checkpoint(self): + """ + Disables flushing on checkpoint. When disabled, a sink will not wait for all pending action + requests to be acknowledged by Elasticsearch on checkpoints. + + ..note:: + If flushing on checkpoint is disabled, a Elasticsearch sink does NOT + provide any strong guarantees for at-least-once delivery of action requests. + + :return: This object. 
+ """ + self._j_elasticsearch = self._j_elasticsearch.disableFlushOnCheckpoint() + return self + + def bulk_flush_max_actions(self, max_actions_num): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets the maximum number of actions to buffer for each bulk request. + + :param max_actions_num: the maximum number of actions to buffer per bulk request. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushMaxActions(int(max_actions_num)) + return self + + def bulk_flush_max_size(self, max_size): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets the maximum size of buffered actions per bulk request (using the syntax of + MemorySize). + + :param max_size: The maximum size. E.g. "42 mb". only MB granularity is supported. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushMaxSize(max_size) + return self + + def bulk_flush_interval(self, interval): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets the bulk flush interval (in milliseconds). + + :param interval: Bulk flush interval (in milliseconds). + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushInterval(int(interval)) + return self + + def bulk_flush_backoff_constant(self): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets a constant backoff type to use when flushing bulk requests. + + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushBackoffConstant() + return self + + def bulk_flush_backoff_exponential(self): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets an exponential backoff type to use when flushing bulk requests. + + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushBackoffExponential() + return self + + def bulk_flush_backoff_max_retries(self, max_retries): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets the maximum number of retries for a backoff attempt when flushing bulk requests. + + Make sure to enable backoff by selecting a strategy ( + :func:`pyflink.table.table_descriptor.Elasticsearch.bulk_flush_backoff_constant` or + :func:`pyflink.table.table_descriptor.Elasticsearch.bulk_flush_backoff_exponential`). + + :param max_retries: The maximum number of retries. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushBackoffMaxRetries(int(max_retries)) + return self + + def bulk_flush_backoff_delay(self, delay): + """ + Configures how to buffer elements before sending them in bulk to the cluster for + efficiency. + + Sets the amount of delay between each backoff attempt when flushing bulk requests + (in milliseconds). + + Make sure to enable backoff by selecting a strategy ( + :func:`pyflink.table.table_descriptor.Elasticsearch.bulk_flush_backoff_constant` or + :func:`pyflink.table.table_descriptor.Elasticsearch.bulk_flush_backoff_exponential`). + + :param delay: Delay between each backoff attempt (in milliseconds). + :return: This object. 
+ """ + self._j_elasticsearch = self._j_elasticsearch.bulkFlushBackoffDelay(int(delay)) + return self + + def connection_max_retry_timeout(self, max_retry_timeout): + """ + Sets connection properties to be used during REST communication to Elasticsearch. + + Sets the maximum timeout (in milliseconds) in case of multiple retries of the same request. + + :param max_retry_timeout: Maximum timeout (in milliseconds). + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.connectionMaxRetryTimeout( + int(max_retry_timeout)) + return self + + def connection_path_prefix(self, path_prefix): + """ + Sets connection properties to be used during REST communication to Elasticsearch. + + Adds a path prefix to every REST communication. + + :param path_prefix: Prefix string to be added to every REST communication. + :return: This object. + """ + self._j_elasticsearch = self._j_elasticsearch.connectionPathPrefix(path_prefix) + return self + + class ConnectTableDescriptor(Descriptor): """ Common class for table's created with :class:`pyflink.table.TableEnvironment.connect`. diff --git a/flink-python/pyflink/table/tests/test_descriptor.py b/flink-python/pyflink/table/tests/test_descriptor.py index 4fcc35532ad964..c9370fa5e7c1f0 100644 --- a/flink-python/pyflink/table/tests/test_descriptor.py +++ b/flink-python/pyflink/table/tests/test_descriptor.py @@ -17,7 +17,8 @@ ################################################################################ import os -from pyflink.table.table_descriptor import (FileSystem, OldCsv, Rowtime, Schema) +from pyflink.table.table_descriptor import (FileSystem, OldCsv, Rowtime, Schema, Kafka, + Elasticsearch) from pyflink.table.table_sink import CsvTableSink from pyflink.table.types import DataTypes from pyflink.testing.test_case_utils import (PyFlinkTestCase, PyFlinkStreamTableTestCase, @@ -29,7 +30,7 @@ class FileSystemDescriptorTests(PyFlinkTestCase): def test_path(self): file_system = FileSystem() - file_system.path("/test.csv") + file_system = file_system.path("/test.csv") properties = file_system.to_properties() expected = {'connector.property-version': '1', @@ -38,12 +39,392 @@ def test_path(self): assert properties == expected +class KafkaDescriptorTests(PyFlinkTestCase): + + def test_version(self): + kafka = Kafka() + + kafka = kafka.version("0.11") + + properties = kafka.to_properties() + expected = {'connector.version': '0.11', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + def test_topic(self): + kafka = Kafka() + + kafka = kafka.topic("topic1") + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.topic': 'topic1', + 'connector.property-version': '1'} + assert properties == expected + + def test_properties(self): + kafka = Kafka() + + kafka = kafka.properties({"zookeeper.connect": "localhost:2181", + "bootstrap.servers": "localhost:9092"}) + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.properties.0.key': 'zookeeper.connect', + 'connector.properties.0.value': 'localhost:2181', + 'connector.properties.1.key': 'bootstrap.servers', + 'connector.properties.1.value': 'localhost:9092', + 'connector.property-version': '1'} + assert properties == expected + + def test_property(self): + kafka = Kafka() + + kafka = kafka.property("group.id", "testGroup") + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.properties.0.key': 'group.id', + 'connector.properties.0.value': 
'testGroup', + 'connector.property-version': '1'} + assert properties == expected + + def test_start_from_earliest(self): + kafka = Kafka() + + kafka = kafka.start_from_earliest() + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.startup-mode': 'earliest-offset', + 'connector.property-version': '1'} + assert properties == expected + + def test_start_from_latest(self): + kafka = Kafka() + + kafka = kafka.start_from_latest() + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.startup-mode': 'latest-offset', + 'connector.property-version': '1'} + assert properties == expected + + def test_start_from_group_offsets(self): + kafka = Kafka() + + kafka = kafka.start_from_group_offsets() + + properties = kafka.to_properties() + expected = {'connector.type': 'kafka', + 'connector.startup-mode': 'group-offsets', + 'connector.property-version': '1'} + assert properties == expected + + def test_start_from_specific_offsets(self): + kafka = Kafka() + + kafka = kafka.start_from_specific_offsets({1: 220, 3: 400}) + + properties = kafka.to_properties() + expected = {'connector.startup-mode': 'specific-offsets', + 'connector.specific-offsets.0.partition': '1', + 'connector.specific-offsets.0.offset': '220', + 'connector.specific-offsets.1.partition': '3', + 'connector.specific-offsets.1.offset': '400', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + def test_start_from_specific_offset(self): + kafka = Kafka() + + kafka = kafka.start_from_specific_offset(3, 300) + + properties = kafka.to_properties() + expected = {'connector.startup-mode': 'specific-offsets', + 'connector.specific-offsets.0.partition': '3', + 'connector.specific-offsets.0.offset': '300', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + def test_sink_partitioner_fixed(self): + kafka = Kafka() + + kafka = kafka.sink_partitioner_fixed() + + properties = kafka.to_properties() + expected = {'connector.sink-partitioner': 'fixed', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + def test_sink_partitioner_custom(self): + kafka = Kafka() + + kafka = kafka.sink_partitioner_custom( + "org.apache.flink.streaming.connectors.kafka.partitioner.FlinkFixedPartitioner") + + properties = kafka.to_properties() + expected = {'connector.sink-partitioner': 'custom', + 'connector.sink-partitioner-class': + 'org.apache.flink.streaming.connectors.kafka.partitioner.' 
+ 'FlinkFixedPartitioner', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + def test_sink_partitioner_round_robin(self): + kafka = Kafka() + + kafka = kafka.sink_partitioner_round_robin() + + properties = kafka.to_properties() + expected = {'connector.sink-partitioner': 'round-robin', + 'connector.type': 'kafka', + 'connector.property-version': '1'} + assert properties == expected + + +class ElasticsearchDescriptorTest(PyFlinkTestCase): + + def test_version(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.version("6") + + properties = elasticsearch.to_properties() + expected = {'connector.type': 'elasticsearch', + 'connector.version': '6', + 'connector.property-version': '1'} + assert properties == expected + + def test_host(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.host("localhost", 9200, "http") + + properties = elasticsearch.to_properties() + expected = {'connector.hosts.0.hostname': 'localhost', + 'connector.hosts.0.port': '9200', + 'connector.hosts.0.protocol': 'http', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_index(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.index("MyUsers") + + properties = elasticsearch.to_properties() + expected = {'connector.index': 'MyUsers', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_document_type(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.document_type("user") + + properties = elasticsearch.to_properties() + expected = {'connector.document-type': 'user', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_key_delimiter(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.key_delimiter("$") + + properties = elasticsearch.to_properties() + expected = {'connector.key-delimiter': '$', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_key_null_literal(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.key_null_literal("n/a") + + properties = elasticsearch.to_properties() + expected = {'connector.key-null-literal': 'n/a', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_failure_handler_fail(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.failure_handler_fail() + + properties = elasticsearch.to_properties() + expected = {'connector.failure-handler': 'fail', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_failure_handler_ignore(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.failure_handler_ignore() + + properties = elasticsearch.to_properties() + expected = {'connector.failure-handler': 'ignore', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_failure_handler_retry_rejected(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.failure_handler_retry_rejected() + + properties = elasticsearch.to_properties() + expected = {'connector.failure-handler': 'retry-rejected', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def 
test_failure_handler_custom(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.failure_handler_custom( + "org.apache.flink.streaming.connectors.elasticsearch.util.IgnoringFailureHandler") + + properties = elasticsearch.to_properties() + expected = {'connector.failure-handler': 'custom', + 'connector.failure-handler-class': + 'org.apache.flink.streaming.connectors.elasticsearch.util.' + 'IgnoringFailureHandler', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_disable_flush_on_checkpoint(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.disable_flush_on_checkpoint() + + properties = elasticsearch.to_properties() + expected = {'connector.flush-on-checkpoint': 'false', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_max_actions(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_max_actions(42) + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.max-actions': '42', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_max_size(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_max_size("42 mb") + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.max-size': '44040192 bytes', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + + assert properties == expected + + def test_bulk_flush_interval(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_interval(2000) + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.interval': '2000', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_backoff_exponential(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_backoff_exponential() + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.backoff.type': 'exponential', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_backoff_constant(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_backoff_constant() + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.backoff.type': 'constant', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_backoff_max_retries(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_backoff_max_retries(3) + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.backoff.max-retries': '3', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_bulk_flush_backoff_delay(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.bulk_flush_backoff_delay(30000) + + properties = elasticsearch.to_properties() + expected = {'connector.bulk-flush.backoff.delay': '30000', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_connection_max_retry_timeout(self): + elasticsearch = Elasticsearch() + + elasticsearch = 
elasticsearch.connection_max_retry_timeout(3000) + + properties = elasticsearch.to_properties() + expected = {'connector.connection-max-retry-timeout': '3000', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + def test_connection_path_prefix(self): + elasticsearch = Elasticsearch() + + elasticsearch = elasticsearch.connection_path_prefix("/v1") + + properties = elasticsearch.to_properties() + expected = {'connector.connection-path-prefix': '/v1', + 'connector.type': 'elasticsearch', + 'connector.property-version': '1'} + assert properties == expected + + class OldCsvDescriptorTests(PyFlinkTestCase): def test_field_delimiter(self): csv = OldCsv() - csv.field_delimiter("|") + csv = csv.field_delimiter("|") properties = csv.to_properties() expected = {'format.field-delimiter': '|', @@ -54,7 +435,7 @@ def test_field_delimiter(self): def test_line_delimiter(self): csv = OldCsv() - csv.line_delimiter(";") + csv = csv.line_delimiter(";") expected = {'format.type': 'csv', 'format.property-version': '1', @@ -66,7 +447,7 @@ def test_line_delimiter(self): def test_ignore_parse_errors(self): csv = OldCsv() - csv.ignore_parse_errors() + csv = csv.ignore_parse_errors() properties = csv.to_properties() expected = {'format.ignore-parse-errors': 'true', @@ -77,7 +458,7 @@ def test_ignore_parse_errors(self): def test_quote_character(self): csv = OldCsv() - csv.quote_character("*") + csv = csv.quote_character("*") properties = csv.to_properties() expected = {'format.quote-character': '*', @@ -88,7 +469,7 @@ def test_quote_character(self): def test_comment_prefix(self): csv = OldCsv() - csv.comment_prefix("#") + csv = csv.comment_prefix("#") properties = csv.to_properties() expected = {'format.comment-prefix': '#', @@ -99,7 +480,7 @@ def test_comment_prefix(self): def test_ignore_first_line(self): csv = OldCsv() - csv.ignore_first_line() + csv = csv.ignore_first_line() properties = csv.to_properties() expected = {'format.ignore-first-line': 'true', @@ -363,7 +744,7 @@ class AbstractTableDescriptorTests(object): def test_with_format(self): descriptor = self.t_env.connect(FileSystem()) - descriptor.with_format(OldCsv().field("a", "INT")) + descriptor = descriptor.with_format(OldCsv().field("a", "INT")) properties = descriptor.to_properties() @@ -378,7 +759,7 @@ def test_with_format(self): def test_with_schema(self): descriptor = self.t_env.connect(FileSystem()) - descriptor.with_format(OldCsv()).with_schema(Schema().field("a", "INT")) + descriptor = descriptor.with_format(OldCsv()).with_schema(Schema().field("a", "INT")) properties = descriptor.to_properties() expected = {'schema.0.name': 'a', @@ -505,7 +886,7 @@ class StreamTableDescriptorTests(PyFlinkStreamTableTestCase, AbstractTableDescri def test_in_append_mode(self): descriptor = self.t_env.connect(FileSystem()) - descriptor\ + descriptor = descriptor\ .with_format(OldCsv())\ .in_append_mode() @@ -520,7 +901,7 @@ def test_in_append_mode(self): def test_in_retract_mode(self): descriptor = self.t_env.connect(FileSystem()) - descriptor \ + descriptor = descriptor \ .with_format(OldCsv()) \ .in_retract_mode() @@ -535,7 +916,7 @@ def test_in_retract_mode(self): def test_in_upsert_mode(self): descriptor = self.t_env.connect(FileSystem()) - descriptor \ + descriptor = descriptor \ .with_format(OldCsv()) \ .in_upsert_mode() diff --git a/tools/travis_controller.sh b/tools/travis_controller.sh index 6741d6ab8c117d..f19fd8a2dbf95b 100755 --- a/tools/travis_controller.sh +++ b/tools/travis_controller.sh @@ 
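Editor's note on usage (not part of the patch): the Python Kafka and Elasticsearch wrappers above forward every call to the corresponding Java descriptor builders through the Py4J gateway (`self._j_elasticsearch.…`). A minimal Java-side sketch of the same configuration is shown below. It assumes the builder class lives at org.apache.flink.table.descriptors.Elasticsearch with a no-argument constructor and a toProperties() accessor, none of which are shown in this patch; the individual setter names (version, host, index, documentType, keyDelimiter, failureHandlerRetryRejected, bulkFlushMaxActions, ...) are exactly the ones invoked from Python above.

import org.apache.flink.table.descriptors.Elasticsearch;  // assumed location of the Java builder

import java.util.Map;

public class ElasticsearchDescriptorSketch {
    public static void main(String[] args) {
        // Chain the same setters that the Python wrapper delegates to via Py4J.
        Elasticsearch elasticsearch = new Elasticsearch()
            .version("6")
            .host("localhost", 9200, "http")
            .index("MyUsers")
            .documentType("user")
            .keyDelimiter("$")
            .failureHandlerRetryRejected()
            .bulkFlushMaxActions(42)
            .bulkFlushBackoffExponential()
            .bulkFlushBackoffMaxRetries(3);

        // toProperties() (assumed here) yields the same connector.* key/value pairs
        // that the Python tests above assert on, e.g. connector.type=elasticsearch.
        Map<String, String> properties = elasticsearch.toProperties();
        properties.forEach((key, value) -> System.out.println(key + "=" + value));
    }
}

The printed properties mirror the expected dictionaries in ElasticsearchDescriptorTest, which is how the Python layer verifies that each builder call was forwarded correctly.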
-161,6 +161,8 @@ if [ $STAGE == "$STAGE_COMPILE" ]; then find "$CACHE_FLINK_DIR" -maxdepth 8 -type f -name '*.jar' \ ! -path "$CACHE_FLINK_DIR/flink-dist/target/flink-*-bin/flink-*/lib/flink-dist*.jar" \ ! -path "$CACHE_FLINK_DIR/flink-dist/target/flink-*-bin/flink-*/opt/flink-table*.jar" \ + ! -path "$CACHE_FLINK_DIR/flink-connectors/flink-connector-elasticsearch-base/target/flink-*.jar" \ + ! -path "$CACHE_FLINK_DIR/flink-connectors/flink-connector-kafka-base/target/flink-*.jar" \ ! -path "$CACHE_FLINK_DIR/flink-table/flink-table-planner/target/flink-table-planner*tests.jar" | xargs rm -rf # .git directory From 3710bcbc195f9a6304bfb452a58bb4a96804cc24 Mon Sep 17 00:00:00 2001 From: Danny Chan Date: Fri, 31 May 2019 17:28:39 +0800 Subject: [PATCH 36/92] [FLINK-6962][table] Add sql parser module and support CREATE / DROP table This closes #8548 --- flink-table/flink-sql-parser/pom.xml | 297 +++++++++++ .../src/main/codegen/config.fmpp | 41 ++ .../src/main/codegen/data/Parser.tdd | 430 +++++++++++++++ .../codegen/includes/compoundIdentifier.ftl | 34 ++ .../src/main/codegen/includes/parserImpls.ftl | 289 ++++++++++ .../apache/flink/sql/parser/SqlProperty.java | 91 ++++ .../flink/sql/parser/ddl/ExtendedSqlType.java | 42 ++ .../flink/sql/parser/ddl/SqlArrayType.java | 49 ++ .../flink/sql/parser/ddl/SqlColumnType.java | 62 +++ .../flink/sql/parser/ddl/SqlCreateTable.java | 320 +++++++++++ .../flink/sql/parser/ddl/SqlDropTable.java | 87 +++ .../flink/sql/parser/ddl/SqlMapType.java | 57 ++ .../flink/sql/parser/ddl/SqlRowType.java | 78 +++ .../flink/sql/parser/ddl/SqlTableColumn.java | 97 ++++ .../sql/parser/error/SqlParseException.java | 60 +++ .../flink/sql/parser/utils/SqlTimeUnit.java | 49 ++ .../sql/parser/FlinkSqlParserImplTest.java | 498 ++++++++++++++++++ .../sql/parser/FlinkSqlUnParserTest.java | 42 ++ flink-table/flink-table-planner-blink/pom.xml | 6 + .../table/sqlexec/SqlExecutableStatement.java | 79 +++ flink-table/pom.xml | 1 + 21 files changed, 2709 insertions(+) create mode 100644 flink-table/flink-sql-parser/pom.xml create mode 100644 flink-table/flink-sql-parser/src/main/codegen/config.fmpp create mode 100644 flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd create mode 100644 flink-table/flink-sql-parser/src/main/codegen/includes/compoundIdentifier.ftl create mode 100644 flink-table/flink-sql-parser/src/main/codegen/includes/parserImpls.ftl create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/SqlProperty.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/ExtendedSqlType.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlArrayType.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlColumnType.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlCreateTable.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlDropTable.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlMapType.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlRowType.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlTableColumn.java create mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/error/SqlParseException.java create mode 100644 
flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/utils/SqlTimeUnit.java create mode 100644 flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java create mode 100644 flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlUnParserTest.java create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sqlexec/SqlExecutableStatement.java diff --git a/flink-table/flink-sql-parser/pom.xml b/flink-table/flink-sql-parser/pom.xml new file mode 100644 index 00000000000000..a4003cc34cf869 --- /dev/null +++ b/flink-table/flink-sql-parser/pom.xml @@ -0,0 +1,297 @@ + + + + + 4.0.0 + + + flink-table + org.apache.flink + 1.9-SNAPSHOT + + + flink-sql-parser + flink-sql-parser + + jar + + + + + 1.19.0 + + + + + org.apache.flink + flink-shaded-guava + + + + org.apache.calcite + calcite-core + + ${calcite.version} + + + + org.apache.calcite.avatica + avatica-metrics + + + com.google.protobuf + protobuf-java + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.apache.commons + commons-dbcp2 + + + com.esri.geometry + esri-geometry-api + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + com.yahoo.datasketches + sketches-core + + + net.hydromatic + aggdesigner-algorithm + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.jayway.jsonpath + json-path + + + joda-time + joda-time + + + org.apache.calcite + calcite-linq4j + + + org.codehaus.janino + janino + + + org.codehaus.janino + commons-compiler + + + com.google.code.findbugs + jsr305 + + + org.apache.commons + commons-lang3 + + + + + org.apache.calcite + calcite-core + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + test + test-jar + + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.8 + + + unpack-parser-template + initialize + + unpack + + + + + org.apache.calcite + calcite-core + jar + true + ${project.build.directory}/ + **/Parser.jj + + + + + + + + + maven-resources-plugin + + + copy-fmpp-resources + initialize + + copy-resources + + + ${project.build.directory}/codegen + + + src/main/codegen + false + + + + + + + + com.googlecode.fmpp-maven-plugin + fmpp-maven-plugin + 1.0 + + + org.freemarker + freemarker + 2.3.25-incubating + + + + + generate-fmpp-sources + generate-sources + + generate + + + ${project.build.directory}/codegen/config.fmpp + target/generated-sources + ${project.build.directory}/codegen/templates + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 1.5 + + + add-generated-sources + process-sources + + add-source + + + + ${project.build.directory}/generated-sources + + + + + + + org.codehaus.mojo + javacc-maven-plugin + 2.4 + + + generate-sources + javacc + + javacc + + + ${project.build.directory}/generated-sources/ + + **/Parser.jj + + + 2 + false + ${project.build.directory}/generated-sources/ + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + 1 + false + + + + + + diff --git a/flink-table/flink-sql-parser/src/main/codegen/config.fmpp b/flink-table/flink-sql-parser/src/main/codegen/config.fmpp new file mode 100644 index 00000000000000..1d5c8e79a511d7 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/codegen/config.fmpp @@ -0,0 +1,41 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is an FMPP (http://fmpp.sourceforge.net/) configuration file to +# allow clients to extend Calcite's SQL parser to support application specific +# SQL statements, literals or data types. +# +# Calcite's parser grammar file (Parser.jj) is written in javacc +# (https://javacc.org/) with Freemarker (http://freemarker.org/) variables +# to allow clients to: +# 1. have custom parser implementation class and package name. +# 2. insert new parser method implementations written in javacc to parse +# custom: +# a) SQL statements. +# b) literals. +# c) data types. +# 3. add new keywords to support custom SQL constructs added as part of (2). +# 4. add import statements needed by inserted custom parser implementations. +# +# Parser template file (Parser.jj) along with this file are packaged as +# part of the calcite-core-.jar under "codegen" directory. + +data: { + parser: tdd(../data/Parser.tdd) +} + +freemarkerLinks: { + includes: includes/ +} diff --git a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd new file mode 100644 index 00000000000000..c026d8a2f01552 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd @@ -0,0 +1,430 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{ + # Generated parser implementation package and class name. + package: "org.apache.flink.sql.parser.impl", + class: "FlinkSqlParserImpl", + + # List of additional classes and packages to import. + # Example. "org.apache.calcite.sql.*", "java.util.List". 
+ imports: [ + "org.apache.flink.sql.parser.ddl.SqlCreateTable", + "org.apache.flink.sql.parser.ddl.SqlDropTable" + "org.apache.flink.sql.parser.ddl.SqlCreateTable.TableCreationContext", + "org.apache.flink.sql.parser.ddl.SqlTableColumn", + "org.apache.flink.sql.parser.ddl.SqlArrayType", + "org.apache.flink.sql.parser.ddl.SqlMapType", + "org.apache.flink.sql.parser.ddl.SqlRowType", + "org.apache.flink.sql.parser.utils.SqlTimeUnit", + "org.apache.flink.sql.parser.SqlProperty", + "org.apache.calcite.sql.SqlDrop", + "org.apache.calcite.sql.SqlCreate", + "java.util.List", + "java.util.ArrayList" + ] + + # List of new keywords. Example: "DATABASES", "TABLES". If the keyword is not a reserved + # keyword, please also add it to 'nonReservedKeywords' section. + keywords: [ + "COMMENT", + "PARTITIONED", + "IF", + "WATERMARK", + "ASCENDING", + "FROM_SOURCE", + "BOUNDED", + "DELAY" + ] + + # List of keywords from "keywords" section that are not reserved. + nonReservedKeywords: [ + "A" + "ABSENT" + "ABSOLUTE" + "ACTION" + "ADA" + "ADD" + "ADMIN" + "AFTER" + "ALWAYS" + "APPLY" + "ASC" + "ASSERTION" + "ASSIGNMENT" + "ATTRIBUTE" + "ATTRIBUTES" + "BEFORE" + "BERNOULLI" + "BREADTH" + "C" + "CASCADE" + "CATALOG" + "CATALOG_NAME" + "CENTURY" + "CHAIN" + "CHARACTER_SET_CATALOG" + "CHARACTER_SET_NAME" + "CHARACTER_SET_SCHEMA" + "CHARACTERISTICS" + "CHARACTERS" + "CLASS_ORIGIN" + "COBOL" + "COLLATION" + "COLLATION_CATALOG" + "COLLATION_NAME" + "COLLATION_SCHEMA" + "COLUMN_NAME" + "COMMAND_FUNCTION" + "COMMAND_FUNCTION_CODE" + "COMMITTED" + "CONDITION_NUMBER" + "CONDITIONAL" + "CONNECTION" + "CONNECTION_NAME" + "CONSTRAINT_CATALOG" + "CONSTRAINT_NAME" + "CONSTRAINT_SCHEMA" + "CONSTRAINTS" + "CONSTRUCTOR" + "CONTINUE" + "CURSOR_NAME" + "DATA" + "DATABASE" + "DATETIME_INTERVAL_CODE" + "DATETIME_INTERVAL_PRECISION" + "DECADE" + "DEFAULTS" + "DEFERRABLE" + "DEFERRED" + "DEFINED" + "DEFINER" + "DEGREE" + "DEPTH" + "DERIVED" + "DESC" + "DESCRIPTION" + "DESCRIPTOR" + "DIAGNOSTICS" + "DISPATCH" + "DOMAIN" + "DOW" + "DOY" + "DYNAMIC_FUNCTION" + "DYNAMIC_FUNCTION_CODE" + "ENCODING" + "EPOCH" + "ERROR" + "EXCEPTION" + "EXCLUDE" + "EXCLUDING" + "FINAL" + "FIRST" + "FOLLOWING" + "FORMAT" + "FORTRAN" + "FOUND" + "FRAC_SECOND" + "G" + "GENERAL" + "GENERATED" + "GEOMETRY" + "GO" + "GOTO" + "GRANTED" + "HIERARCHY" + "IMMEDIATE" + "IMMEDIATELY" + "IMPLEMENTATION" + "INCLUDING" + "INCREMENT" + "INITIALLY" + "INPUT" + "INSTANCE" + "INSTANTIABLE" + "INVOKER" + "ISODOW" + "ISOYEAR" + "ISOLATION" + "JAVA" + "JSON" + "JSON_TYPE" + "JSON_DEPTH" + "JSON_PRETTY" + "K" + "KEY" + "KEY_MEMBER" + "KEY_TYPE" + "LABEL" + "LAST" + "LENGTH" + "LEVEL" + "LIBRARY" + "LOCATOR" + "M" + "MAP" + "MATCHED" + "MAXVALUE" + "MICROSECOND" + "MESSAGE_LENGTH" + "MESSAGE_OCTET_LENGTH" + "MESSAGE_TEXT" + "MILLISECOND" + "MILLENNIUM" + "MINVALUE" + "MORE_" + "MUMPS" + "NAME" + "NAMES" + "NANOSECOND" + "NESTING" + "NORMALIZED" + "NULLABLE" + "NULLS" + "NUMBER" + "OBJECT" + "OCTETS" + "OPTION" + "OPTIONS" + "ORDERING" + "ORDINALITY" + "OTHERS" + "OUTPUT" + "OVERRIDING" + "PAD" + "PARAMETER_MODE" + "PARAMETER_NAME" + "PARAMETER_ORDINAL_POSITION" + "PARAMETER_SPECIFIC_CATALOG" + "PARAMETER_SPECIFIC_NAME" + "PARAMETER_SPECIFIC_SCHEMA" + "PARTIAL" + "PASCAL" + "PASSING" + "PASSTHROUGH" + "PAST" + "PATH" + "PLACING" + "PLAN" + "PLI" + "PRECEDING" + "PRESERVE" + "PRIOR" + "PRIVILEGES" + "PUBLIC" + "QUARTER" + "READ" + "RELATIVE" + "REPEATABLE" + "REPLACE" + "RESTART" + "RESTRICT" + "RETURNED_CARDINALITY" + "RETURNED_LENGTH" + "RETURNED_OCTET_LENGTH" + 
"RETURNED_SQLSTATE" + "RETURNING" + "ROLE" + "ROUTINE" + "ROUTINE_CATALOG" + "ROUTINE_NAME" + "ROUTINE_SCHEMA" + "ROW_COUNT" + "SCALAR" + "SCALE" + "SCHEMA" + "SCHEMA_NAME" + "SCOPE_CATALOGS" + "SCOPE_NAME" + "SCOPE_SCHEMA" + "SECTION" + "SECURITY" + "SELF" + "SEQUENCE" + "SERIALIZABLE" + "SERVER" + "SERVER_NAME" + "SESSION" + "SETS" + "SIMPLE" + "SIZE" + "SOURCE" + "SPACE" + "SPECIFIC_NAME" + "SQL_BIGINT" + "SQL_BINARY" + "SQL_BIT" + "SQL_BLOB" + "SQL_BOOLEAN" + "SQL_CHAR" + "SQL_CLOB" + "SQL_DATE" + "SQL_DECIMAL" + "SQL_DOUBLE" + "SQL_FLOAT" + "SQL_INTEGER" + "SQL_INTERVAL_DAY" + "SQL_INTERVAL_DAY_TO_HOUR" + "SQL_INTERVAL_DAY_TO_MINUTE" + "SQL_INTERVAL_DAY_TO_SECOND" + "SQL_INTERVAL_HOUR" + "SQL_INTERVAL_HOUR_TO_MINUTE" + "SQL_INTERVAL_HOUR_TO_SECOND" + "SQL_INTERVAL_MINUTE" + "SQL_INTERVAL_MINUTE_TO_SECOND" + "SQL_INTERVAL_MONTH" + "SQL_INTERVAL_SECOND" + "SQL_INTERVAL_YEAR" + "SQL_INTERVAL_YEAR_TO_MONTH" + "SQL_LONGVARBINARY" + "SQL_LONGVARNCHAR" + "SQL_LONGVARCHAR" + "SQL_NCHAR" + "SQL_NCLOB" + "SQL_NUMERIC" + "SQL_NVARCHAR" + "SQL_REAL" + "SQL_SMALLINT" + "SQL_TIME" + "SQL_TIMESTAMP" + "SQL_TINYINT" + "SQL_TSI_DAY" + "SQL_TSI_FRAC_SECOND" + "SQL_TSI_HOUR" + "SQL_TSI_MICROSECOND" + "SQL_TSI_MINUTE" + "SQL_TSI_MONTH" + "SQL_TSI_QUARTER" + "SQL_TSI_SECOND" + "SQL_TSI_WEEK" + "SQL_TSI_YEAR" + "SQL_VARBINARY" + "SQL_VARCHAR" + "STATE" + "STATEMENT" + "STRUCTURE" + "STYLE" + "SUBCLASS_ORIGIN" + "SUBSTITUTE" + "TABLE_NAME" + "TEMPORARY" + "TIES" + "TIMESTAMPADD" + "TIMESTAMPDIFF" + "TOP_LEVEL_COUNT" + "TRANSACTION" + "TRANSACTIONS_ACTIVE" + "TRANSACTIONS_COMMITTED" + "TRANSACTIONS_ROLLED_BACK" + "TRANSFORM" + "TRANSFORMS" + "TRIGGER_CATALOG" + "TRIGGER_NAME" + "TRIGGER_SCHEMA" + "TYPE" + "UNBOUNDED" + "UNCOMMITTED" + "UNCONDITIONAL" + "UNDER" + "UNNAMED" + "USAGE" + "USER_DEFINED_TYPE_CATALOG" + "USER_DEFINED_TYPE_CODE" + "USER_DEFINED_TYPE_NAME" + "USER_DEFINED_TYPE_SCHEMA" + "UTF8" + "UTF16" + "UTF32" + "VERSION" + "VIEW" + "WEEK" + "WRAPPER" + "WORK" + "WRITE" + "XML" + "ZONE" + + # not in core, added in Flink + "PARTITIONED", + "IF", + "ASCENDING", + "FROM_SOURCE", + "BOUNDED", + "DELAY" + ] + + # List of methods for parsing custom SQL statements. + # Return type of method implementation should be 'SqlNode'. + # Example: SqlShowDatabases(), SqlShowTables(). + statementParserMethods: [ + ] + + # List of methods for parsing custom literals. + # Return type of method implementation should be "SqlNode". + # Example: ParseJsonLiteral(). + literalParserMethods: [ + ] + + # List of methods for parsing custom data types. + # Return type of method implementation should be "SqlIdentifier". + # Example: SqlParseTimeStampZ(). + dataTypeParserMethods: [ + "SqlArrayType()", + "SqlMapType()", + "SqlRowType()" + ] + + # List of methods for parsing builtin function calls. + # Return type of method implementation should be "SqlNode". + # Example: DateFunctionCall(). + builtinFunctionCallMethods: [ + ] + + # List of methods for parsing extensions to "ALTER " calls. + # Each must accept arguments "(SqlParserPos pos, String scope)". + # Example: "SqlUploadJarNode" + alterStatementParserMethods: [ + ] + + # List of methods for parsing extensions to "CREATE [OR REPLACE]" calls. + # Each must accept arguments "(SqlParserPos pos, boolean replace)". + createStatementParserMethods: [ + "SqlCreateTable" + ] + + # List of methods for parsing extensions to "DROP" calls. + # Each must accept arguments "(Span s)". 
+ dropStatementParserMethods: [ + "SqlDropTable" + ] + + # List of files in @includes directory that have parser method + # implementations for parsing custom SQL statements, literals or types + # given as part of "statementParserMethods", "literalParserMethods" or + # "dataTypeParserMethods". + implementationFiles: [ + "parserImpls.ftl" + ] + + # List of additional join types. Each is a method with no arguments. + # Example: LeftSemiJoin() + joinTypes: [ + ] + + includeCompoundIdentifier: true + includeBraces: true + includeAdditionalDeclarations: false +} diff --git a/flink-table/flink-sql-parser/src/main/codegen/includes/compoundIdentifier.ftl b/flink-table/flink-sql-parser/src/main/codegen/includes/compoundIdentifier.ftl new file mode 100644 index 00000000000000..70db3c2ee3bd45 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/codegen/includes/compoundIdentifier.ftl @@ -0,0 +1,34 @@ +<#-- +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to you under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +--> + +<#-- + Add implementations of additional parser statements, literals or + data types. + + Example of SqlShowTables() implementation: + SqlNode SqlShowTables() + { + ...local variables... + } + { + + ... + { + return SqlShowTables(...) + } + } +--> diff --git a/flink-table/flink-sql-parser/src/main/codegen/includes/parserImpls.ftl b/flink-table/flink-sql-parser/src/main/codegen/includes/parserImpls.ftl new file mode 100644 index 00000000000000..92607d5b30c140 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/codegen/includes/parserImpls.ftl @@ -0,0 +1,289 @@ +<#-- +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to you under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+--> + +void TableColumn(TableCreationContext context) : +{ +} +{ + ( + TableColumn2(context.columnList) + | + context.primaryKeyList = PrimaryKey() + | + UniqueKey(context.uniqueKeysList) + | + ComputedColumn(context) + ) +} + +void ComputedColumn(TableCreationContext context) : +{ + SqlNode identifier; + SqlNode expr; + boolean hidden = false; + SqlParserPos pos; +} +{ + identifier = SimpleIdentifier() {pos = getPos();} + + expr = Expression(ExprContext.ACCEPT_SUB_QUERY) { + expr = SqlStdOperatorTable.AS.createCall(Span.of(identifier, expr).pos(), expr, identifier); + context.columnList.add(expr); + } +} + +void TableColumn2(List list) : +{ + SqlParserPos pos; + SqlIdentifier name; + SqlDataTypeSpec type; + SqlCharStringLiteral comment = null; +} +{ + name = SimpleIdentifier() + type = DataType() + ( + { type = type.withNullable(true); } + | + { type = type.withNullable(false); } + | + { type = type.withNullable(true); } + ) + [ { + String p = SqlParserUtil.parseString(token.image); + comment = SqlLiteral.createCharString(p, getPos()); + }] + { + SqlTableColumn tableColumn = new SqlTableColumn(name, type, comment, getPos()); + list.add(tableColumn); + } +} + +SqlNodeList PrimaryKey() : +{ + List pkList = new ArrayList(); + + SqlParserPos pos; + SqlIdentifier columnName; +} +{ + { pos = getPos(); } + columnName = SimpleIdentifier() { pkList.add(columnName); } + ( columnName = SimpleIdentifier() { pkList.add(columnName); })* + + { + return new SqlNodeList(pkList, pos.plus(getPos())); + } +} + +void UniqueKey(List list) : +{ + List ukList = new ArrayList(); + SqlParserPos pos; + SqlIdentifier columnName; +} +{ + { pos = getPos(); } + columnName = SimpleIdentifier() { ukList.add(columnName); } + ( columnName = SimpleIdentifier() { ukList.add(columnName); })* + + { + SqlNodeList uk = new SqlNodeList(ukList, pos.plus(getPos())); + list.add(uk); + } +} + +SqlNode PropertyValue() : +{ + SqlIdentifier key; + SqlNode value; + SqlParserPos pos; +} +{ + key = CompoundIdentifier() + { pos = getPos(); } + value = StringLiteral() + { + return new SqlProperty(key, value, getPos()); + } +} + +SqlCreate SqlCreateTable(Span s, boolean replace) : +{ + final SqlParserPos startPos = s.pos(); + SqlIdentifier tableName; + SqlNodeList primaryKeyList = null; + List uniqueKeysList = null; + SqlNodeList columnList = SqlNodeList.EMPTY; + SqlCharStringLiteral comment = null; + + SqlNodeList propertyList = null; + SqlNodeList partitionColumns = null; + SqlParserPos pos = startPos; +} +{ + + + tableName = CompoundIdentifier() + [ + { pos = getPos(); TableCreationContext ctx = new TableCreationContext();} + TableColumn(ctx) + ( + TableColumn(ctx) + )* + { + pos = pos.plus(getPos()); + columnList = new SqlNodeList(ctx.columnList, pos); + primaryKeyList = ctx.primaryKeyList; + uniqueKeysList = ctx.uniqueKeysList; + } + + ] + [ { + String p = SqlParserUtil.parseString(token.image); + comment = SqlLiteral.createCharString(p, getPos()); + }] + [ + + { + SqlNode column; + List partitionKey = new ArrayList(); + pos = getPos(); + + } + + [ + column = SimpleIdentifier() + { + partitionKey.add(column); + } + ( + column = SimpleIdentifier() + { + partitionKey.add(column); + } + )* + ] + + { + partitionColumns = new SqlNodeList(partitionKey, pos.plus(getPos())); + } + ] + [ + + { + SqlNode property; + List proList = new ArrayList(); + pos = getPos(); + } + + [ + property = PropertyValue() + { + proList.add(property); + } + ( + property = PropertyValue() + { + proList.add(property); + } + )* + ] + + { propertyList = new 
SqlNodeList(proList, pos.plus(getPos())); } + ] + + { + return new SqlCreateTable(startPos.plus(getPos()), + tableName, + columnList, + primaryKeyList, + uniqueKeysList, + propertyList, + partitionColumns, + comment); + } +} + +SqlDrop SqlDropTable(Span s, boolean replace) : +{ + SqlIdentifier tableName = null; + boolean ifExists = false; +} +{ +
+ + [ { ifExists = true; } ] + + tableName = CompoundIdentifier() + + { + return new SqlDropTable(s.pos(), tableName, ifExists); + } +} + +SqlIdentifier SqlArrayType() : +{ + SqlParserPos pos; + SqlDataTypeSpec elementType; +} +{ + { pos = getPos(); } + elementType = DataType() + + { + return new SqlArrayType(pos, elementType); + } +} + +SqlIdentifier SqlMapType() : +{ + SqlParserPos pos; + SqlDataTypeSpec keyType; + SqlDataTypeSpec valType; +} +{ + { pos = getPos(); } + keyType = DataType() + valType = DataType() + + { + return new SqlMapType(pos, keyType, valType); + } +} + +SqlIdentifier SqlRowType() : +{ + SqlParserPos pos; + List fieldNames = new ArrayList(); + List fieldTypes = new ArrayList(); +} +{ + { pos = getPos(); SqlIdentifier fName; SqlDataTypeSpec fType;} + + fName = SimpleIdentifier() fType = DataType() + { fieldNames.add(fName); fieldTypes.add(fType); } + ( + + fName = SimpleIdentifier() fType = DataType() + { fieldNames.add(fName); fieldTypes.add(fType); } + )* + + { + return new SqlRowType(pos, fieldNames, fieldTypes); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/SqlProperty.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/SqlProperty.java new file mode 100644 index 00000000000000..dbb58e6db2cccd --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/SqlProperty.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.ImmutableNullableList; +import org.apache.calcite.util.NlsString; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * Properties of a DDL, consist of key value pairs. + */ +public class SqlProperty extends SqlCall { + + /** Use this operator only if you don't have a better one. 
*/ + protected static final SqlOperator OPERATOR = + new SqlSpecialOperator("Property", SqlKind.OTHER); + + private final SqlIdentifier key; + private final SqlNode value; + + public SqlProperty(SqlIdentifier key, SqlNode value, SqlParserPos pos) { + super(pos); + this.key = requireNonNull(key, "Property key is missing"); + this.value = requireNonNull(value, "Property value is missing"); + } + + public SqlIdentifier getKey() { + return key; + } + + public SqlNode getValue() { + return value; + } + + public String getKeyString() { + return key.toString(); + } + + public String getValueString() { + return ((NlsString) SqlLiteral.value(value)).getValue(); + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(key, value); + } + + @Override + public void unparse( + SqlWriter writer, + int leftPrec, + int rightPrec) { + key.unparse(writer, leftPrec, rightPrec); + writer.keyword("="); + value.unparse(writer, leftPrec, rightPrec); + } +} + +// End SqlProperty.java diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/ExtendedSqlType.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/ExtendedSqlType.java new file mode 100644 index 00000000000000..135dc4946d049f --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/ExtendedSqlType.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlWriter; + +/** An remark interface which should be inherited by supported sql types which are not supported + * by Calcite default parser. + * + *

Caution that the subclass must override the method + * {@link org.apache.calcite.sql.SqlNode#unparse(SqlWriter, int, int)}. + */ +public interface ExtendedSqlType { + + static void unparseType(SqlDataTypeSpec type, + SqlWriter writer, + int leftPrec, + int rightPrec) { + if (type.getTypeName() instanceof ExtendedSqlType) { + type.getTypeName().unparse(writer, leftPrec, rightPrec); + } else { + type.unparse(writer, leftPrec, rightPrec); + } + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlArrayType.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlArrayType.java new file mode 100644 index 00000000000000..7d43d4f5b66b88 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlArrayType.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Parse column of ArrayType. + */ +public class SqlArrayType extends SqlIdentifier implements ExtendedSqlType { + + private final SqlDataTypeSpec elementType; + + public SqlArrayType(SqlParserPos pos, SqlDataTypeSpec elementType) { + super(SqlTypeName.ARRAY.getName(), pos); + this.elementType = elementType; + } + + public SqlDataTypeSpec getElementType() { + return elementType; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("ARRAY<"); + ExtendedSqlType.unparseType(this.elementType, writer, leftPrec, rightPrec); + writer.keyword(">"); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlColumnType.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlColumnType.java new file mode 100644 index 00000000000000..3e494a72628179 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlColumnType.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +/** + * All supported data types in DDL. Used for Create Table DDL validation. + */ +public enum SqlColumnType { + BOOLEAN, + TINYINT, + SMALLINT, + INT, + INTEGER, + BIGINT, + REAL, + FLOAT, + DOUBLE, + DECIMAL, + DATE, + TIME, + TIMESTAMP, + VARCHAR, + VARBINARY, + ANY, + ARRAY, + MAP, + ROW, + UNSUPPORTED; + + /** Returns the column type with the string representation. **/ + public static SqlColumnType getType(String type) { + if (type == null) { + return UNSUPPORTED; + } + try { + return SqlColumnType.valueOf(type.toUpperCase()); + } catch (IllegalArgumentException var1) { + return UNSUPPORTED; + } + } + + /** Returns true if this type is unsupported. **/ + public boolean isUnsupported() { + return this.equals(UNSUPPORTED); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlCreateTable.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlCreateTable.java new file mode 100644 index 00000000000000..47773cb76fe8c1 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlCreateTable.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.flink.sql.parser.error.SqlParseException; + +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlCreate; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.dialect.AnsiSqlDialect; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.pretty.SqlPrettyWriter; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +/** + * CREATE TABLE DDL sql call. 
+ */ +public class SqlCreateTable extends SqlCreate { + + public static final SqlSpecialOperator OPERATOR = new SqlSpecialOperator("CREATE TABLE", SqlKind.CREATE_TABLE); + + private final SqlIdentifier tableName; + + private final SqlNodeList columnList; + + private final SqlNodeList propertyList; + + private final SqlNodeList primaryKeyList; + + private final List uniqueKeysList; + + private final SqlNodeList partitionKeyList; + + private final SqlCharStringLiteral comment; + + public SqlCreateTable( + SqlParserPos pos, + SqlIdentifier tableName, + SqlNodeList columnList, + SqlNodeList primaryKeyList, + List uniqueKeysList, + SqlNodeList propertyList, + SqlNodeList partitionKeyList, + SqlCharStringLiteral comment) { + super(OPERATOR, pos, false, false); + this.tableName = requireNonNull(tableName, "Table name is missing"); + this.columnList = requireNonNull(columnList, "Column list should not be null"); + this.primaryKeyList = primaryKeyList; + this.uniqueKeysList = uniqueKeysList; + this.propertyList = propertyList; + this.partitionKeyList = partitionKeyList; + this.comment = comment; + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public List getOperandList() { + return null; + } + + public SqlIdentifier getTableName() { + return tableName; + } + + public SqlNodeList getColumnList() { + return columnList; + } + + public SqlNodeList getPropertyList() { + return propertyList; + } + + public SqlNodeList getPartitionKeyList() { + return partitionKeyList; + } + + public SqlNodeList getPrimaryKeyList() { + return primaryKeyList; + } + + public List getUniqueKeysList() { + return uniqueKeysList; + } + + public SqlCharStringLiteral getComment() { + return comment; + } + + public void validate() throws SqlParseException { + Set columnNames = new HashSet<>(); + if (columnList != null) { + for (SqlNode column : columnList) { + String columnName = null; + if (column instanceof SqlTableColumn) { + SqlTableColumn tableColumn = (SqlTableColumn) column; + columnName = tableColumn.getName().getSimple(); + String typeName = tableColumn.getType().getTypeName().getSimple(); + if (SqlColumnType.getType(typeName).isUnsupported()) { + throw new SqlParseException( + column.getParserPosition(), + "Not support type [" + typeName + "], at " + column.getParserPosition()); + } + } else if (column instanceof SqlBasicCall) { + SqlBasicCall tableColumn = (SqlBasicCall) column; + columnName = tableColumn.getOperands()[1].toString(); + } + + if (!columnNames.add(columnName)) { + throw new SqlParseException( + column.getParserPosition(), + "Duplicate column name [" + columnName + "], at " + + column.getParserPosition()); + } + } + } + + if (this.primaryKeyList != null) { + for (SqlNode primaryKeyNode : this.primaryKeyList) { + String primaryKey = ((SqlIdentifier) primaryKeyNode).getSimple(); + if (!columnNames.contains(primaryKey)) { + throw new SqlParseException( + primaryKeyNode.getParserPosition(), + "Primary key [" + primaryKey + "] not defined in columns, at " + + primaryKeyNode.getParserPosition()); + } + } + } + + if (this.uniqueKeysList != null) { + for (SqlNodeList uniqueKeys: this.uniqueKeysList) { + for (SqlNode uniqueKeyNode : uniqueKeys) { + String uniqueKey = ((SqlIdentifier) uniqueKeyNode).getSimple(); + if (!columnNames.contains(uniqueKey)) { + throw new SqlParseException( + uniqueKeyNode.getParserPosition(), + "Unique key [" + uniqueKey + "] not defined in columns, at " + uniqueKeyNode.getParserPosition()); + } + } + } + } + + if (this.partitionKeyList != 
null) { + for (SqlNode partitionKeyNode : this.partitionKeyList.getList()) { + String partitionKey = ((SqlIdentifier) partitionKeyNode).getSimple(); + if (!columnNames.contains(partitionKey)) { + throw new SqlParseException( + partitionKeyNode.getParserPosition(), + "Partition column [" + partitionKey + "] not defined in columns, at " + + partitionKeyNode.getParserPosition()); + } + } + } + + } + + public boolean containsComputedColumn() { + for (SqlNode column : columnList) { + if (column instanceof SqlBasicCall) { + return true; + } + } + return false; + } + + /** + * Returns the projection format of the DDL columns(including computed columns). + * e.g. If we got a DDL: + *

+	 *   create table tbl1(
+	 *     col1 int,
+	 *     col2 varchar,
+	 *     col3 as to_timestamp(col2)
+	 *   ) with (
+	 *     connector = 'csv'
+	 *   )
+	 * </pre>
+ * we would return a query like: + * + * <p>

"col1, col2, to_timestamp(col2) as col3", caution that the "computed column" operands + * have been reversed. + */ + public String getColumnSqlString() { + SqlPrettyWriter writer = new SqlPrettyWriter(AnsiSqlDialect.DEFAULT); + writer.setAlwaysUseParentheses(true); + writer.setSelectListItemsOnSeparateLines(false); + writer.setIndentation(0); + writer.startList("", ""); + for (SqlNode column : columnList) { + writer.sep(","); + if (column instanceof SqlTableColumn) { + SqlTableColumn tableColumn = (SqlTableColumn) column; + tableColumn.getName().unparse(writer, 0, 0); + } else { + column.unparse(writer, 0, 0); + } + } + + return writer.toString(); + } + + @Override + public void unparse( + SqlWriter writer, + int leftPrec, + int rightPrec) { + writer.keyword("CREATE TABLE"); + tableName.unparse(writer, leftPrec, rightPrec); + SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.create("sds"), "(", ")"); + for (SqlNode column : columnList) { + printIndent(writer); + if (column instanceof SqlBasicCall) { + SqlCall call = (SqlCall) column; + SqlCall newCall = call.getOperator().createCall( + SqlParserPos.ZERO, + call.operand(1), + call.operand(0)); + newCall.unparse(writer, leftPrec, rightPrec); + } else { + column.unparse(writer, leftPrec, rightPrec); + } + } + if (primaryKeyList != null && primaryKeyList.size() > 0) { + printIndent(writer); + writer.keyword("PRIMARY KEY"); + SqlWriter.Frame keyFrame = writer.startList("(", ")"); + primaryKeyList.unparse(writer, leftPrec, rightPrec); + writer.endList(keyFrame); + } + if (uniqueKeysList != null && uniqueKeysList.size() > 0) { + printIndent(writer); + for (SqlNodeList uniqueKeyList : uniqueKeysList) { + writer.keyword("UNIQUE"); + SqlWriter.Frame keyFrame = writer.startList("(", ")"); + uniqueKeyList.unparse(writer, leftPrec, rightPrec); + writer.endList(keyFrame); + } + } + writer.newlineAndIndent(); + writer.endList(frame); + + if (comment != null) { + writer.newlineAndIndent(); + writer.keyword("COMMENT"); + comment.unparse(writer, leftPrec, rightPrec); + } + + if (this.partitionKeyList != null) { + writer.newlineAndIndent(); + writer.keyword("PARTITIONED BY"); + SqlWriter.Frame withFrame = writer.startList("(", ")"); + this.partitionKeyList.unparse(writer, leftPrec, rightPrec); + writer.endList(withFrame); + writer.newlineAndIndent(); + } + + if (propertyList != null) { + writer.keyword("WITH"); + SqlWriter.Frame withFrame = writer.startList("(", ")"); + for (SqlNode property : propertyList) { + printIndent(writer); + property.unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); + writer.endList(withFrame); + } + } + + private void printIndent(SqlWriter writer) { + writer.sep(",", false); + writer.newlineAndIndent(); + writer.print(" "); + } + + /** + * Table creation context. 
+ */ + public static class TableCreationContext { + public List columnList = new ArrayList<>(); + public SqlNodeList primaryKeyList; + public List uniqueKeysList = new ArrayList<>(); + } + + public String[] fullTableName() { + return tableName.names.toArray(new String[0]); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlDropTable.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlDropTable.java new file mode 100644 index 00000000000000..ee544c4b125904 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlDropTable.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlDrop; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.ImmutableNullableList; + +import java.util.List; + +/** + * DROP TABLE DDL sql call. 
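+ * <p>Example: {@code DROP TABLE IF EXISTS catalog1.db1.tbl1}.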
+ */ +public class SqlDropTable extends SqlDrop { + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("DROP TABLE", SqlKind.DROP_TABLE); + + private SqlIdentifier tableName; + private boolean ifExists; + + public SqlDropTable(SqlParserPos pos, SqlIdentifier tableName, boolean ifExists) { + super(OPERATOR, pos, ifExists); + this.tableName = tableName; + this.ifExists = ifExists; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(tableName); + } + + public SqlIdentifier getTableName() { + return tableName; + } + + public void setTableName(SqlIdentifier viewName) { + this.tableName = viewName; + } + + public boolean getIfExists() { + return this.ifExists; + } + + public void setIfExists(boolean ifExists) { + this.ifExists = ifExists; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("DROP"); + writer.keyword("TABLE"); + if (ifExists) { + writer.keyword("IF EXISTS"); + } + tableName.unparse(writer, leftPrec, rightPrec); + } + + public void validate() { + // no-op + } + + public String[] fullTableName() { + return tableName.names.toArray(new String[0]); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlMapType.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlMapType.java new file mode 100644 index 00000000000000..98f549ae73aa64 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlMapType.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; + +/** + * Extended Flink MapType. 
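+ * <p>For example, a column declared as {@code MAP<INT, VARCHAR>} is unparsed as + * {@code MAP< INTEGER, VARCHAR >}.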
+ */ +public class SqlMapType extends SqlIdentifier implements ExtendedSqlType { + + private final SqlDataTypeSpec keyType; + private final SqlDataTypeSpec valType; + + public SqlMapType(SqlParserPos pos, SqlDataTypeSpec keyType, SqlDataTypeSpec valType) { + super(SqlTypeName.MAP.getName(), pos); + this.keyType = keyType; + this.valType = valType; + } + + public SqlDataTypeSpec getKeyType() { + return keyType; + } + + public SqlDataTypeSpec getValType() { + return valType; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("MAP<"); + ExtendedSqlType.unparseType(keyType, writer, leftPrec, rightPrec); + writer.sep(","); + ExtendedSqlType.unparseType(valType, writer, leftPrec, rightPrec); + writer.keyword(">"); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlRowType.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlRowType.java new file mode 100644 index 00000000000000..39580538a75856 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlRowType.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Pair; + +import java.util.List; + +/** + * Parse column of Row type. 
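+ * <p>For example, a column declared as {@code ROW<cc0: int, cc1: float, cc2: varchar>} is unparsed + * as {@code ROW< `CC0` : INTEGER, `CC1` : FLOAT, `CC2` : VARCHAR >}.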
+ */ +public class SqlRowType extends SqlIdentifier implements ExtendedSqlType { + + private final List fieldNames; + private final List fieldTypes; + + public SqlRowType(SqlParserPos pos, + List fieldNames, + List fieldTypes) { + super(SqlTypeName.ROW.getName(), pos); + this.fieldNames = fieldNames; + this.fieldTypes = fieldTypes; + } + + public List getFieldNames() { + return fieldNames; + } + + public List getFieldTypes() { + return fieldTypes; + } + + public int getArity() { + return fieldNames.size(); + } + + public SqlIdentifier getFieldName(int i) { + return fieldNames.get(i); + } + + public SqlDataTypeSpec getFieldType(int i) { + return fieldTypes.get(i); + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("ROW"); + SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "<", ">"); + for (Pair p : Pair.zip(this.fieldNames, this.fieldTypes)) { + writer.sep(",", false); + p.left.unparse(writer, 0, 0); + writer.sep(":"); + ExtendedSqlType.unparseType(p.right, writer, leftPrec, rightPrec); + } + writer.endList(frame); + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlTableColumn.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlTableColumn.java new file mode 100644 index 00000000000000..bcd578f296c722 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/ddl/SqlTableColumn.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.ddl; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * Table column of a CREATE TABLE DDL. 
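+ * <p>A column definition consists of a name, a type and an optional comment, + * e.g. {@code a BIGINT COMMENT 'column comment'}.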
+ */ +public class SqlTableColumn extends SqlCall { + + private SqlIdentifier name; + private SqlDataTypeSpec type; + private SqlCharStringLiteral comment; + + public SqlTableColumn(SqlIdentifier name, + SqlDataTypeSpec type, + SqlCharStringLiteral comment, + SqlParserPos pos) { + super(pos); + this.name = requireNonNull(name, "Column name should not be null"); + this.type = requireNonNull(type, "Column type should not be null"); + this.comment = comment; + } + + @Override + public SqlOperator getOperator() { + return null; + } + + @Override + public List getOperandList() { + return null; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + this.name.unparse(writer, leftPrec, rightPrec); + writer.print(" "); + ExtendedSqlType.unparseType(type, writer, leftPrec, rightPrec); + if (this.comment != null) { + writer.print(" COMMENT "); + this.comment.unparse(writer, leftPrec, rightPrec); + } + } + + public SqlIdentifier getName() { + return name; + } + + public void setName(SqlIdentifier name) { + this.name = name; + } + + public SqlDataTypeSpec getType() { + return type; + } + + public void setType(SqlDataTypeSpec type) { + this.type = type; + } + + public SqlCharStringLiteral getComment() { + return comment; + } + + public void setComment(SqlCharStringLiteral comment) { + this.comment = comment; + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/error/SqlParseException.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/error/SqlParseException.java new file mode 100644 index 00000000000000..365de20e725e29 --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/error/SqlParseException.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.error; + +import org.apache.calcite.sql.parser.SqlParserPos; + +/** + * SQL parse Exception. This is a simpler version + * of Calcite {@link org.apache.calcite.sql.parser.SqlParseException} + * which is used for SqlNode validation. 
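+ * <p>For example, {@link org.apache.flink.sql.parser.ddl.SqlCreateTable#validate()} throws this + * exception when a primary key, unique key or partition column is not defined in the column list.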
+ */ +public class SqlParseException extends Exception { + + private SqlParserPos errorPosition; + + private String message; + + public SqlParseException(SqlParserPos errorPosition, String message) { + this.errorPosition = errorPosition; + this.message = message; + } + + public SqlParseException(SqlParserPos errorPosition, String message, Exception e) { + super(e); + this.errorPosition = errorPosition; + this.message = message; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public SqlParserPos getErrorPosition() { + return errorPosition; + } + + public void setErrorPosition(SqlParserPos errorPosition) { + this.errorPosition = errorPosition; + } +} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/utils/SqlTimeUnit.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/utils/SqlTimeUnit.java new file mode 100644 index 00000000000000..950399a104aabb --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/utils/SqlTimeUnit.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser.utils; + +import org.apache.calcite.sql.SqlWriter; + +/** SqlTimeUnit used for Flink DDL sql. **/ +public enum SqlTimeUnit { + DAY("DAY", 24 * 3600 * 1000), + HOUR("HOUR", 3600 * 1000), + MINUTE("MINUTE", 60 * 1000), + SECOND("SECOND", 1000), + MILLISECOND("MILLISECOND", 1); + + /** Unparsing keyword. */ + private String keyword; + /** Times used to transform this time unit to millisecond. **/ + private long timeToMillisecond; + + SqlTimeUnit(String keyword, long timeToMillisecond) { + this.keyword = keyword; + this.timeToMillisecond = timeToMillisecond; + } + + public long populateAsMillisecond(int timeInterval) { + return timeToMillisecond * timeInterval; + } + + public void unparse(SqlWriter writer) { + writer.keyword(keyword); + } + +} diff --git a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java new file mode 100644 index 00000000000000..d6d88be152d67a --- /dev/null +++ b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser; + +import org.apache.flink.sql.parser.ddl.SqlCreateTable; +import org.apache.flink.sql.parser.error.SqlParseException; +import org.apache.flink.sql.parser.impl.FlinkSqlParserImpl; + +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParserImplFactory; +import org.apache.calcite.sql.parser.SqlParserTest; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.junit.Ignore; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + + +/** FlinkSqlParserImpl tests. **/ +public class FlinkSqlParserImplTest extends SqlParserTest { + + protected SqlParserImplFactory parserImplFactory() { + return FlinkSqlParserImpl.FACTORY; + } + + @Test + public void testCreateTable() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " h varchar, \n" + + " g as 2 * (a + 1), \n" + + " ts as toTimestamp(b, 'yyyy-MM-dd HH:mm:ss'), \n" + + " b varchar,\n" + + " proc as PROCTIME(), \n" + + " PRIMARY KEY (a, b)\n" + + ")\n" + + "PARTITIONED BY (a, h)\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `H` VARCHAR,\n" + + " `G` AS (2 * (`A` + 1)),\n" + + " `TS` AS `TOTIMESTAMP`(`B`, 'yyyy-MM-dd HH:mm:ss'),\n" + + " `B` VARCHAR,\n" + + " `PROC` AS `PROCTIME`(),\n" + + " PRIMARY KEY (`A`, `B`)\n" + + ")\n" + + "PARTITIONED BY (`A`, `H`)\n" + + "WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Test + public void testCreateTableWithComment() { + check("CREATE TABLE tbl1 (\n" + + " a bigint comment 'test column comment AAA.',\n" + + " h varchar, \n" + + " g as 2 * (a + 1), \n" + + " ts as toTimestamp(b, 'yyyy-MM-dd HH:mm:ss'), \n" + + " b varchar,\n" + + " proc as PROCTIME(), \n" + + " PRIMARY KEY (a, b)\n" + + ")\n" + + "comment 'test table comment ABC.'\n" + + "PARTITIONED BY (a, h)\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT COMMENT 'test column comment AAA.',\n" + + " `H` VARCHAR,\n" + + " `G` AS (2 * (`A` + 1)),\n" + + " `TS` AS `TOTIMESTAMP`(`B`, 'yyyy-MM-dd HH:mm:ss'),\n" + + " `B` VARCHAR,\n" + + " `PROC` AS `PROCTIME`(),\n" + + " PRIMARY KEY (`A`, `B`)\n" + + ")\n" + + "COMMENT 'test table comment ABC.'\n" + + "PARTITIONED BY (`A`, `H`)\n" + + "WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Test + public void testCreateTableWithPrimaryKeyAndUniqueKey() { + check("CREATE TABLE tbl1 (\n" + + " a bigint comment 'test column comment AAA.',\n" + + " h varchar, \n" + + " g as 2 * (a + 1), \n" + + " ts as toTimestamp(b, 'yyyy-MM-dd HH:mm:ss'), \n" + + " b varchar,\n" + + " proc as PROCTIME(), \n" + + " PRIMARY KEY (a, b), \n" + + " UNIQUE (h, g)\n" + + ")\n" + + "comment 'test table comment ABC.'\n" + + "PARTITIONED BY (a, h)\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT COMMENT 'test column comment 
AAA.',\n" + + " `H` VARCHAR,\n" + + " `G` AS (2 * (`A` + 1)),\n" + + " `TS` AS `TOTIMESTAMP`(`B`, 'yyyy-MM-dd HH:mm:ss'),\n" + + " `B` VARCHAR,\n" + + " `PROC` AS `PROCTIME`(),\n" + + " PRIMARY KEY (`A`, `B`),\n" + + " UNIQUE (`H`, `G`)\n" + + ")\n" + + "COMMENT 'test table comment ABC.'\n" + + "PARTITIONED BY (`A`, `H`)\n" + + "WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithoutWatermarkFieldName() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK FOR a AS BOUNDED WITH DELAY 1000 MILLISECOND\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK FOR `A` AS BOUNDED WITH DELAY 1000 MILLISECOND\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkBoundedDelay() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS BOUNDED WITH DELAY 1000 DAY\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS BOUNDED WITH DELAY 1000 DAY\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkBoundedDelay1() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS BOUNDED WITH DELAY 1000 HOUR\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS BOUNDED WITH DELAY 1000 HOUR\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkBoundedDelay2() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS BOUNDED WITH DELAY 1000 MINUTE\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS BOUNDED WITH DELAY 1000 MINUTE\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkBoundedDelay3() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS BOUNDED WITH DELAY 1000 SECOND\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS BOUNDED WITH DELAY 1000 SECOND\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " 
`KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithNegativeWatermarkOffsetDelay() { + checkFails("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS BOUNDED WITH DELAY ^-^1000 SECOND\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "(?s).*Encountered \"-\" at line 5, column 44.\n" + + "Was expecting:\n" + + " ...\n" + + ".*"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkStrategyAscending() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS ASCENDING\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS ASCENDING\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Ignore // need to implement + @Test + public void testCreateTableWithWatermarkStrategyFromSource() { + check("CREATE TABLE tbl1 (\n" + + " a bigint,\n" + + " b varchar, \n" + + " c as 2 * (a + 1), \n" + + " WATERMARK wk FOR a AS FROM_SOURCE\n" + + ")\n" + + " with (\n" + + " connector = 'kafka', \n" + + " kafka.topic = 'log.test'\n" + + ")\n", + "CREATE TABLE `TBL1` (\n" + + " `A` BIGINT,\n" + + " `B` VARCHAR,\n" + + " `C` AS (2 * (`A` + 1)),\n" + + " WATERMARK `WK` FOR `A` AS FROM_SOURCE\n" + + ") WITH (\n" + + " `CONNECTOR` = 'kafka',\n" + + " `KAFKA`.`TOPIC` = 'log.test'\n" + + ")"); + } + + @Test + public void testCreateTableWithComplexType() { + check("CREATE TABLE tbl1 (\n" + + " a ARRAY, \n" + + " b MAP,\n" + + " c ROW,\n" + + " PRIMARY KEY (a, b) \n" + + ") with (\n" + + " x = 'y', \n" + + " asd = 'data'\n" + + ")\n", "CREATE TABLE `TBL1` (\n" + + " `A` ARRAY< BIGINT >,\n" + + " `B` MAP< INTEGER, VARCHAR >,\n" + + " `C` ROW< `CC0` : INTEGER, `CC1` : FLOAT, `CC2` : VARCHAR >,\n" + + " PRIMARY KEY (`A`, `B`)\n" + + ") WITH (\n" + + " `X` = 'y',\n" + + " `ASD` = 'data'\n" + + ")"); + } + + @Test + public void testCreateTableWithDecimalType() { + check("CREATE TABLE tbl1 (\n" + + " a decimal, \n" + + " b decimal(10, 0),\n" + + " c decimal(38, 38),\n" + + " PRIMARY KEY (a, b) \n" + + ") with (\n" + + " x = 'y', \n" + + " asd = 'data'\n" + + ")\n", "CREATE TABLE `TBL1` (\n" + + " `A` DECIMAL,\n" + + " `B` DECIMAL(10, 0),\n" + + " `C` DECIMAL(38, 38),\n" + + " PRIMARY KEY (`A`, `B`)\n" + + ") WITH (\n" + + " `X` = 'y',\n" + + " `ASD` = 'data'\n" + + ")"); + } + + @Test + public void testCreateTableWithNestedComplexType() { + check("CREATE TABLE tbl1 (\n" + + " a ARRAY>, \n" + + " b MAP, ARRAY>,\n" + + " c ROW, cc1: float, cc2: varchar>,\n" + + " PRIMARY KEY (a, b) \n" + + ") with (\n" + + " x = 'y', \n" + + " asd = 'data'\n" + + ")\n", "CREATE TABLE `TBL1` (\n" + + " `A` ARRAY< ARRAY< BIGINT > >,\n" + + " `B` MAP< MAP< INTEGER, VARCHAR >, ARRAY< VARCHAR > >,\n" + + " `C` ROW< `CC0` : ARRAY< INTEGER >, `CC1` : FLOAT, `CC2` : VARCHAR >,\n" + + " PRIMARY KEY (`A`, `B`)\n" + + ") WITH (\n" + + " `X` = 'y',\n" + + " `ASD` = 'data'\n" + + ")"); + } + + @Test + public void testInvalidComputedColumn() { + checkFails("CREATE TABLE sls_stream (\n" + + " a bigint, \n" + + " b varchar,\n" + + " ^toTimestamp^(b, 'yyyy-MM-dd HH:mm:ss'), \n" + + " PRIMARY KEY (a, b) \n" + + ") with (\n" + + " x 
= 'y', \n" + + " asd = 'data'\n" + + ")\n", "(?s).*Encountered \"toTimestamp \\(\" at line 4, column 3.\n" + + "Was expecting one of:\n" + + " \"CHARACTER\" ...\n" + + " \"CHAR\" ...\n" + + ".*"); + } + + @Test + public void testColumnSqlString() { + String sql = "CREATE TABLE sls_stream (\n" + + " a bigint, \n" + + " f as a + 1, \n" + + " b varchar,\n" + + " ts as toTimestamp(b, 'yyyy-MM-dd HH:mm:ss'), \n" + + " proc as PROCTIME(),\n" + + " c int,\n" + + " PRIMARY KEY (a, b) \n" + + ") with (\n" + + " x = 'y', \n" + + " asd = 'data'\n" + + ")\n"; + String expected = "`A`, (`A` + 1) AS `F`, `B`, " + + "`TOTIMESTAMP`(`B`, 'yyyy-MM-dd HH:mm:ss') AS `TS`, " + + "`PROCTIME`() AS `PROC`, `C`"; + sql(sql).node(new ValidationMatcher() + .expectColumnSql(expected)); + } + + @Test + public void testCreateInvalidPartitionedTable() { + String sql = "create table sls_stream1(\n" + + " a bigint,\n" + + " b VARCHAR,\n" + + " PRIMARY KEY(a, b)\n" + + ") PARTITIONED BY (\n" + + " c,\n" + + " d\n" + + ") with ( x = 'y', asd = 'dada')"; + sql(sql).node(new ValidationMatcher() + .fails("Partition column [C] not defined in columns, at line 6, column 3")); + + } + + @Test + public void testDropTable() { + String sql = "DROP table catalog1.db1.tbl1"; + check(sql, "DROP TABLE `CATALOG1`.`DB1`.`TBL1`"); + } + + @Test + public void testDropIfExists() { + String sql = "DROP table IF EXISTS catalog1.db1.tbl1"; + check(sql, "DROP TABLE IF EXISTS `CATALOG1`.`DB1`.`TBL1`"); + } + + /** Matcher that invokes the #validate() of the produced SqlNode. **/ + private static class ValidationMatcher extends BaseMatcher { + private String expectedColumnSql; + private String failMsg; + + public ValidationMatcher expectColumnSql(String s) { + this.expectedColumnSql = s; + return this; + } + + public ValidationMatcher fails(String failMsg) { + this.failMsg = failMsg; + return this; + } + + @Override + public void describeTo(Description description) { + description.appendText("test"); + } + + @Override + public boolean matches(Object item) { + if (item instanceof SqlCreateTable) { + SqlCreateTable createTable = (SqlCreateTable) item; + try { + createTable.validate(); + } catch (SqlParseException e) { + assertEquals(failMsg, e.getMessage()); + } + if (expectedColumnSql != null) { + assertEquals(expectedColumnSql, createTable.getColumnSqlString()); + } + return true; + } else { + return false; + } + } + } +} diff --git a/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlUnParserTest.java b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlUnParserTest.java new file mode 100644 index 00000000000000..ce3ac2d93e2cb1 --- /dev/null +++ b/flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlUnParserTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.sql.parser; + +/** + * Extension to {@link FlinkSqlParserImplTest} that ensures that every expression can + * un-parse successfully. + */ +public class FlinkSqlUnParserTest extends FlinkSqlParserImplTest { + //~ Constructors ----------------------------------------------------------- + + public FlinkSqlUnParserTest() { + } + + //~ Methods ---------------------------------------------------------------- + + @Override + protected boolean isUnparserTest() { + return true; + } + + @Override + protected Tester getTester() { + return new UnparsingTesterImpl(); + } +} diff --git a/flink-table/flink-table-planner-blink/pom.xml b/flink-table/flink-table-planner-blink/pom.xml index f960b474ddab62..7715207376fee8 100644 --- a/flink-table/flink-table-planner-blink/pom.xml +++ b/flink-table/flink-table-planner-blink/pom.xml @@ -100,6 +100,12 @@ under the License. ${project.version} + + org.apache.flink + flink-sql-parser + ${project.version} + + org.apache.flink flink-scala_${scala.binary.version} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sqlexec/SqlExecutableStatement.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sqlexec/SqlExecutableStatement.java new file mode 100644 index 00000000000000..7ac310687993bf --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/sqlexec/SqlExecutableStatement.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sqlexec; + +import org.apache.flink.sql.parser.ddl.SqlCreateTable; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.TableException; + +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.util.ReflectUtil; +import org.apache.calcite.util.ReflectiveVisitor; + +/** + * Mix-in tool class for {@code SqlNode} that allows DDL commands to be + * executed directly. + * + *

For every kind of {@link SqlNode}, there needs to be a corresponding + * #execute(type) method whose 'type' argument is the subclass + * type of the supported {@link SqlNode}. + */ +public class SqlExecutableStatement implements ReflectiveVisitor { + private TableEnvironment tableEnv; + + private final ReflectUtil.MethodDispatcher<Void> dispatcher = + ReflectUtil.createMethodDispatcher(Void.class, + this, + "execute", + SqlNode.class); + + //~ Constructors ----------------------------------------------------------- + + private SqlExecutableStatement(TableEnvironment tableEnvironment) { + this.tableEnv = tableEnvironment; + } + + /** + * This is the main entry point for executing all kinds of DDL/DML {@code SqlNode}s; each + * supported SqlNode has its implementation in an #execute(type) method whose 'type' argument + * is a subclass of {@code SqlNode}. + * + * <p>
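For example, a {@link SqlCreateTable} node is dispatched to the {@code #execute(SqlCreateTable)} + * method defined below. + * + * <p>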

Caution that the {@link #execute(SqlNode)} should never expect to be invoked. + * + * @param tableEnvironment TableEnvironment to interact with + * @param sqlNode SqlNode to execute on + */ + public static void executeSqlNode(TableEnvironment tableEnvironment, SqlNode sqlNode) { + SqlExecutableStatement statement = new SqlExecutableStatement(tableEnvironment); + statement.dispatcher.invoke(sqlNode); + } + + /** + * Execute the {@link SqlCreateTable} node. + */ + public void execute(SqlCreateTable sqlCreateTable) { + // need to implement. + } + + /** Fallback method to throw exception. */ + public void execute(SqlNode node) { + throw new TableException("Should not invoke to node type " + + node.getClass().getSimpleName()); + } +} diff --git a/flink-table/pom.xml b/flink-table/pom.xml index 5213ea54defda5..57918da755d73e 100644 --- a/flink-table/pom.xml +++ b/flink-table/pom.xml @@ -43,6 +43,7 @@ under the License. flink-table-runtime-blink flink-table-uber flink-sql-client + flink-sql-parser From 44c2d7632577b7eeb70ceae73d122be00a18de44 Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Tue, 7 May 2019 12:38:58 +0200 Subject: [PATCH 37/92] [FLINK-12431][table-api-java] Port utility methods for extracting fields information from TypeInformation --- .../flink/table/typeutils/FieldInfoUtils.java | 496 ++++++++++++++++++ .../BuiltInFunctionDefinitions.java | 4 + .../operations/CalculatedTableFactory.java | 6 +- .../flink/table/api/BatchTableEnvImpl.scala | 23 +- .../flink/table/api/StreamTableEnvImpl.scala | 19 +- .../apache/flink/table/api/TableEnvImpl.scala | 220 +------- .../table/api/java/StreamTableEnvImpl.scala | 21 +- .../utils/UserDefinedFunctionUtils.scala | 9 +- 8 files changed, 543 insertions(+), 255 deletions(-) create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java new file mode 100644 index 00000000000000..a0a3cc6e03c180 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.typeutils; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.expressions.ApiExpressionDefaultVisitor; +import org.apache.flink.table.expressions.BuiltInFunctionDefinitions; +import org.apache.flink.table.expressions.CallExpression; +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ExpressionUtils; +import org.apache.flink.table.expressions.UnresolvedReferenceExpression; +import org.apache.flink.types.Row; + +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static java.lang.String.format; +import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIME_ATTRIBUTES; + +/** + * Utility classes for extracting names and indices of fields from different {@link TypeInformation}s. + */ +public class FieldInfoUtils { + + private static final String ATOMIC_FIELD_NAME = "f0"; + + /** + * Describes extracted fields and corresponding indices from a {@link TypeInformation}. + */ + public static class FieldsInfo { + private final String[] fieldNames; + private final int[] indices; + + FieldsInfo(String[] fieldNames, int[] indices) { + this.fieldNames = fieldNames; + this.indices = indices; + } + + public String[] getFieldNames() { + return fieldNames; + } + + public int[] getIndices() { + return indices; + } + } + + /** + * Reference input fields by name: + * All fields in the schema definition are referenced by name + * (and possibly renamed using an alias (as). In this mode, fields can be reordered and + * projected out. Moreover, we can define proctime and rowtime attributes at arbitrary + * positions using arbitrary names (except those that exist in the result schema). This mode + * can be used for any input type, including POJOs. + * + *
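For example, a POJO input type with fields {@code name} and {@code age} can be mapped with the + * expressions {@code 'age as 'years, 'name}, which reorders the fields and renames {@code age} to + * {@code years}. + * + * <p>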

Reference input fields by position: + * In this mode, fields are simply renamed. Event-time attributes can + * replace the field on their position in the input data (if it is of correct type) or be + * appended at the end. Proctime attributes must be appended at the end. This mode can only be + * used if the input type has a defined field order (tuple, case class, Row) and no of fields + * references a field of the input type. + */ + public static boolean isReferenceByPosition(CompositeType ct, Expression[] fields) { + if (!(ct instanceof TupleTypeInfoBase)) { + return false; + } + + List inputNames = Arrays.asList(ct.getFieldNames()); + + // Use the by-position mode if no of the fields exists in the input. + // This prevents confusing cases like ('f2, 'f0, 'myName) for a Tuple3 where fields are renamed + // by position but the user might assume reordering instead of renaming. + return Arrays.stream(fields).allMatch(f -> { + if (f instanceof UnresolvedReferenceExpression) { + return !inputNames.contains(((UnresolvedReferenceExpression) f).getName()); + } + + return true; + }); + } + + /** + * Returns field names and field positions for a given {@link TypeInformation}. + * + * @param inputType The TypeInformation extract the field names and positions from. + * @param The type of the TypeInformation. + * @return A tuple of two arrays holding the field names and corresponding field positions. + */ + public static FieldsInfo getFieldsInfo(TypeInformation inputType) { + + if (inputType instanceof GenericTypeInfo && inputType.getTypeClass() == Row.class) { + throw new TableException( + "An input of GenericTypeInfo cannot be converted to Table. " + + "Please specify the type of the input with a RowTypeInfo."); + } else { + return new FieldsInfo(getFieldNames(inputType), getFieldIndices(inputType)); + } + } + + /** + * Returns field names and field positions for a given {@link TypeInformation} and array of + * {@link Expression}. It does not handle time attributes but considers them in indices. + * + * @param inputType The {@link TypeInformation} against which the {@link Expression}s are evaluated. + * @param exprs The expressions that define the field names. + * @param The type of the TypeInformation. + * @return A tuple of two arrays holding the field names and corresponding field positions. + */ + public static FieldsInfo getFieldsInfo(TypeInformation inputType, Expression[] exprs) { + validateInputTypeInfo(inputType); + + final Set fieldInfos; + if (inputType instanceof GenericTypeInfo && inputType.getTypeClass() == Row.class) { + throw new TableException( + "An input of GenericTypeInfo cannot be converted to Table. " + + "Please specify the type of the input with a RowTypeInfo."); + } else if (inputType instanceof TupleTypeInfoBase) { + fieldInfos = extractFieldInfosFromTupleType((CompositeType) inputType, exprs); + } else if (inputType instanceof PojoTypeInfo) { + fieldInfos = extractFieldInfosByNameReference((CompositeType) inputType, exprs); + } else { + fieldInfos = extractFieldInfoFromAtomicType(exprs); + } + + if (fieldInfos.stream().anyMatch(info -> info.getFieldName().equals("*"))) { + throw new TableException("Field name can not be '*'."); + } + + String[] fieldNames = fieldInfos.stream().map(FieldInfo::getFieldName).toArray(String[]::new); + int[] fieldIndices = fieldInfos.stream().mapToInt(FieldInfo::getIndex).toArray(); + return new FieldsInfo(fieldNames, fieldIndices); + } + + /** + * Returns field names for a given {@link TypeInformation}. 
+ * + * @param inputType The TypeInformation extract the field names. + * @param The type of the TypeInformation. + * @return An array holding the field names + */ + public static String[] getFieldNames(TypeInformation inputType) { + validateInputTypeInfo(inputType); + + final String[] fieldNames; + if (inputType instanceof CompositeType) { + fieldNames = ((CompositeType) inputType).getFieldNames(); + } else { + fieldNames = new String[]{ATOMIC_FIELD_NAME}; + } + + if (Arrays.asList(fieldNames).contains("*")) { + throw new TableException("Field name can not be '*'."); + } + + return fieldNames; + } + + /** + * Validate if class represented by the typeInfo is static and globally accessible. + * + * @param typeInfo type to check + * @throws TableException if type does not meet these criteria + */ + public static void validateInputTypeInfo(TypeInformation typeInfo) { + Class clazz = typeInfo.getTypeClass(); + if ((clazz.isMemberClass() && !Modifier.isStatic(clazz.getModifiers())) || + !Modifier.isPublic(clazz.getModifiers()) || + clazz.getCanonicalName() == null) { + throw new TableException(format( + "Class '%s' described in type information '%s' must be " + + "static and globally accessible.", clazz, typeInfo)); + } + } + + /** + * Returns field indexes for a given {@link TypeInformation}. + * + * @param inputType The TypeInformation extract the field positions from. + * @return An array holding the field positions + */ + public static int[] getFieldIndices(TypeInformation inputType) { + return IntStream.range(0, getFieldNames(inputType).length).toArray(); + } + + /** + * Returns field types for a given {@link TypeInformation}. + * + * @param inputType The TypeInformation to extract field types from. + * @return An array holding the field types. + */ + public static TypeInformation[] getFieldTypes(TypeInformation inputType) { + validateInputTypeInfo(inputType); + + final TypeInformation[] fieldTypes; + if (inputType instanceof CompositeType) { + int arity = inputType.getArity(); + CompositeType ct = (CompositeType) inputType; + fieldTypes = IntStream.range(0, arity).mapToObj(ct::getTypeAt).toArray(TypeInformation[]::new); + } else { + fieldTypes = new TypeInformation[]{inputType}; + } + + return fieldTypes; + } + + public static TableSchema calculateTableSchema( + TypeInformation typeInfo, + int[] fieldIndexes, + String[] fieldNames) { + + if (fieldIndexes.length != fieldNames.length) { + throw new TableException(String.format( + "Number of field names and field indexes must be equal.\n" + + "Number of names is %s, number of indexes is %s.\n" + + "List of column names: %s.\n" + + "List of column indexes: %s.", + fieldNames.length, + fieldIndexes.length, + String.join(", ", fieldNames), + Arrays.stream(fieldIndexes).mapToObj(Integer::toString).collect(Collectors.joining(", ")))); + } + + // check uniqueness of field names + Set duplicatedNames = findDuplicates(fieldNames); + if (duplicatedNames.size() != 0) { + + throw new TableException(String.format( + "Field names must be unique.\n" + + "List of duplicate fields: [%s].\n" + + "List of all fields: [%s].", + String.join(", ", duplicatedNames), + String.join(", ", fieldNames))); + } + + final TypeInformation[] types; + long fieldIndicesCount = Arrays.stream(fieldIndexes).filter(i -> i >= 0).count(); + if (typeInfo instanceof CompositeType) { + CompositeType ct = (CompositeType) typeInfo; + // it is ok to leave out fields + if (fieldIndicesCount > ct.getArity()) { + throw new TableException(String.format( + "Arity of type (%s) must not be 
greater than number of field names %s.", + Arrays.toString(ct.getFieldNames()), + Arrays.toString(fieldNames))); + } + + types = Arrays.stream(fieldIndexes) + .mapToObj(idx -> extractTimeMarkerType(idx).orElseGet(() -> ct.getTypeAt(idx))) + .toArray(TypeInformation[]::new); + } else { + if (fieldIndicesCount > 1) { + throw new TableException( + "Non-composite input type may have only a single field and its index must be 0."); + } + + types = Arrays.stream(fieldIndexes) + .mapToObj(idx -> extractTimeMarkerType(idx).orElse(typeInfo)) + .toArray(TypeInformation[]::new); + } + + return new TableSchema(fieldNames, types); + } + + private static Optional> extractTimeMarkerType(int idx) { + switch (idx) { + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER: + return Optional.of(TimeIndicatorTypeInfo.ROWTIME_INDICATOR); + case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER: + return Optional.of(TimeIndicatorTypeInfo.PROCTIME_INDICATOR); + case TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER: + case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER: + return Optional.of(Types.SQL_TIMESTAMP); + default: + return Optional.empty(); + } + } + + + + /* Utility methods */ + + private static Set extractFieldInfoFromAtomicType(Expression[] exprs) { + boolean referenced = false; + FieldInfo fieldInfo = null; + for (Expression expr : exprs) { + if (expr instanceof UnresolvedReferenceExpression) { + if (referenced) { + throw new TableException("Only the first field can reference an atomic type."); + } else { + referenced = true; + fieldInfo = new FieldInfo(((UnresolvedReferenceExpression) expr).getName(), 0); + } + } else if (!isTimeAttribute(expr)) { // IGNORE Time attributes + throw new TableException("Field reference expression expected."); + } + } + + if (fieldInfo != null) { + return Collections.singleton(fieldInfo); + } + + return Collections.emptySet(); + } + + private static Set extractFieldInfosByNameReference(CompositeType inputType, Expression[] exprs) { + ExprToFieldInfo exprToFieldInfo = new ExprToFieldInfo(inputType); + return Arrays.stream(exprs) + .map(expr -> expr.accept(exprToFieldInfo)) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + private static Set extractFieldInfosFromTupleType(CompositeType inputType, Expression[] exprs) { + boolean isRefByPos = isReferenceByPosition((CompositeType) inputType, exprs); + + if (isRefByPos) { + return IntStream.range(0, exprs.length) + .mapToObj(idx -> exprs[idx].accept(new IndexedExprToFieldInfo(idx))) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } else { + return extractFieldInfosByNameReference(inputType, exprs); + } + } + + private static class FieldInfo { + private final String fieldName; + private final int index; + + FieldInfo(String fieldName, int index) { + this.fieldName = fieldName; + this.index = index; + } + + public String getFieldName() { + return fieldName; + } + + public int getIndex() { + return index; + } + } + + private static class IndexedExprToFieldInfo extends ApiExpressionDefaultVisitor> { + + private final int index; + + private IndexedExprToFieldInfo(int index) { + this.index = index; + } + + @Override + public Optional visitUnresolvedReference(UnresolvedReferenceExpression unresolvedReference) { + String fieldName = unresolvedReference.getName(); + return Optional.of(new FieldInfo(fieldName, index)); + } + + @Override + public Optional visitCall(CallExpression call) { + if (call.getFunctionDefinition() == 
BuiltInFunctionDefinitions.AS) { + List children = call.getChildren(); + Expression origExpr = children.get(0); + String newName = ExpressionUtils.extractValue(children.get(1), Types.STRING) + .orElseThrow(() -> + new TableException("Alias expects string literal as new name. Got: " + children.get(1))); + + if (origExpr instanceof UnresolvedReferenceExpression) { + throw new TableException( + format("Alias '%s' is not allowed if other fields are referenced by position.", newName)); + } else if (isTimeAttribute(origExpr)) { + return Optional.empty(); + } + } else if (isTimeAttribute(call)) { + return Optional.empty(); + } + + return defaultMethod(call); + } + + @Override + protected Optional defaultMethod(Expression expression) { + throw new TableException("Field reference expression or alias on field expression expected."); + } + } + + private static class ExprToFieldInfo extends ApiExpressionDefaultVisitor> { + + private final CompositeType ct; + + private ExprToFieldInfo(CompositeType ct) { + this.ct = ct; + } + + @Override + public Optional visitUnresolvedReference(UnresolvedReferenceExpression unresolvedReference) { + String fieldName = unresolvedReference.getName(); + return referenceByName(fieldName, ct).map(idx -> new FieldInfo(fieldName, idx)); + } + + @Override + public Optional visitCall(CallExpression call) { + if (call.getFunctionDefinition() == BuiltInFunctionDefinitions.AS) { + List children = call.getChildren(); + Expression origExpr = children.get(0); + String newName = ExpressionUtils.extractValue(children.get(1), Types.STRING) + .orElseThrow(() -> + new TableException("Alias expects string literal as new name. Got: " + children.get(1))); + + if (origExpr instanceof UnresolvedReferenceExpression) { + return referenceByName(((UnresolvedReferenceExpression) origExpr).getName(), ct) + .map(idx -> new FieldInfo(newName, idx)); + } else if (isTimeAttribute(origExpr)) { + return Optional.empty(); + } + } else if (isTimeAttribute(call)) { + return Optional.empty(); + } + + return defaultMethod(call); + } + + @Override + protected Optional defaultMethod(Expression expression) { + throw new TableException("Field reference expression or alias on field expression expected."); + } + } + + private static boolean isTimeAttribute(Expression origExpr) { + return origExpr instanceof CallExpression && + TIME_ATTRIBUTES.contains(((CallExpression) origExpr).getFunctionDefinition()); + } + + private static Optional referenceByName(String name, CompositeType ct) { + int inputIdx = ct.getFieldIndex(name); + if (inputIdx < 0) { + throw new TableException(format( + "%s is not a field of type %s. 
Expected: %s}", + name, + ct, + String.join(", ", ct.getFieldNames()))); + } else { + return Optional.of(inputIdx); + } + } + + private static Set findDuplicates(T[] array) { + Set duplicates = new HashSet<>(); + Set seenElements = new HashSet<>(); + + for (T t : array) { + if (seenElements.contains(t)) { + duplicates.add(t); + } else { + seenElements.add(t); + } + } + + return duplicates; + } + + private FieldInfoUtils() { + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/BuiltInFunctionDefinitions.java index a326be8a673683..e74e1294ab3d98 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/BuiltInFunctionDefinitions.java @@ -343,6 +343,10 @@ public final class BuiltInFunctionDefinitions { WINDOW_START, WINDOW_END, PROCTIME, ROWTIME )); + public static final Set TIME_ATTRIBUTES = new HashSet<>(Arrays.asList( + PROCTIME, ROWTIME + )); + public static final List ORDERING = Arrays.asList(ORDER_ASC, ORDER_DESC); public static List getDefinitions() { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/CalculatedTableFactory.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/CalculatedTableFactory.java index 3f2953b4398700..860b55155d3fc0 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/CalculatedTableFactory.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/CalculatedTableFactory.java @@ -21,7 +21,6 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.table.api.TableEnvImpl$; import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.expressions.ApiExpressionDefaultVisitor; @@ -30,6 +29,7 @@ import org.apache.flink.table.expressions.ExpressionUtils; import org.apache.flink.table.expressions.FunctionDefinition; import org.apache.flink.table.expressions.TableFunctionDefinition; +import org.apache.flink.table.typeutils.FieldInfoUtils; import java.util.Collections; import java.util.List; @@ -103,7 +103,7 @@ private CalculatedTableOperation createFunctionCall( String[] fieldNames; if (aliasesSize == 0) { - fieldNames = TableEnvImpl$.MODULE$.getFieldNames(resultType); + fieldNames = FieldInfoUtils.getFieldNames(resultType); } else if (aliasesSize != callArity) { throw new ValidationException(String.format( "List of column aliases must have same degree as table; " + @@ -116,7 +116,7 @@ private CalculatedTableOperation createFunctionCall( fieldNames = aliases.toArray(new String[aliasesSize]); } - TypeInformation[] fieldTypes = TableEnvImpl$.MODULE$.getFieldTypes(resultType); + TypeInformation[] fieldTypes = FieldInfoUtils.getFieldTypes(resultType); return new CalculatedTableOperation( tableFunctionDefinition.getTableFunction(), diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala index 312aed9a8b7b88..d0ee44088d137a 100644 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala @@ -33,7 +33,8 @@ import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} import org.apache.flink.table.catalog.CatalogManager import org.apache.flink.table.descriptors.{BatchTableDescriptor, ConnectorDescriptor} import org.apache.flink.table.explain.PlanJsonParser -import org.apache.flink.table.expressions.{Expression, TimeAttribute} +import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIME_ATTRIBUTES +import org.apache.flink.table.expressions.{CallExpression, Expression} import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.DataSetRel import org.apache.flink.table.plan.rules.FlinkRuleSets @@ -41,6 +42,7 @@ import org.apache.flink.table.plan.schema._ import org.apache.flink.table.runtime.MapRunner import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{BatchTableSource, TableSource} +import org.apache.flink.table.typeutils.FieldInfoUtils.{getFieldsInfo, validateInputTypeInfo} import org.apache.flink.types.Row /** @@ -319,11 +321,11 @@ abstract class BatchTableEnvImpl( */ protected def registerDataSetInternal[T](name: String, dataSet: DataSet[T]): Unit = { - val (fieldNames, fieldIndexes) = getFieldInfo[T](dataSet.getType) + val fieldInfo = getFieldsInfo[T](dataSet.getType) val dataSetTable = new DataSetTable[T]( dataSet, - fieldIndexes, - fieldNames + fieldInfo.getIndices, + fieldInfo.getFieldNames ) registerTableInternal(name, dataSetTable) } @@ -341,18 +343,19 @@ abstract class BatchTableEnvImpl( name: String, dataSet: DataSet[T], fields: Array[Expression]): Unit = { val inputType = dataSet.getType - val bridgedFields = fields.map(expressionBridge.bridge).toArray[Expression] - val (fieldNames, fieldIndexes) = getFieldInfo[T]( + val fieldsInfo = getFieldsInfo[T]( inputType, - bridgedFields) + fields) - if (bridgedFields.exists(_.isInstanceOf[TimeAttribute])) { + if (fields.exists(f => + f.isInstanceOf[CallExpression] && + TIME_ATTRIBUTES.contains(f.asInstanceOf[CallExpression].getFunctionDefinition))) { throw new ValidationException( ".rowtime and .proctime time indicators are not allowed in a batch environment.") } - val dataSetTable = new DataSetTable[T](dataSet, fieldIndexes, fieldNames) + val dataSetTable = new DataSetTable[T](dataSet, fieldsInfo.getIndices, fieldsInfo.getFieldNames) registerTableInternal(name, dataSetTable) } @@ -416,7 +419,7 @@ abstract class BatchTableEnvImpl( logicalPlan: RelNode, logicalType: RelDataType, queryConfig: BatchQueryConfig)(implicit tpe: TypeInformation[A]): DataSet[A] = { - TableEnvImpl.validateType(tpe) + validateInputTypeInfo(tpe) logicalPlan match { case node: DataSetRel => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala index 62d44a28fb2802..40ecadacbfc58e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala @@ -52,6 +52,7 @@ import org.apache.flink.table.runtime.{CRowMapRunner, OutputRowtimeProcessFuncti import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{StreamTableSource, TableSource, 
TableSourceUtil} import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} +import org.apache.flink.table.typeutils.FieldInfoUtils.{getFieldsInfo, isReferenceByPosition} import _root_.scala.collection.JavaConverters._ @@ -443,11 +444,11 @@ abstract class StreamTableEnvImpl( name: String, dataStream: DataStream[T]): Unit = { - val (fieldNames, fieldIndexes) = getFieldInfo[T](dataStream.getType) + val fieldInfo = getFieldsInfo[T](dataStream.getType) val dataStreamTable = new DataStreamTable[T]( dataStream, - fieldIndexes, - fieldNames + fieldInfo.getIndices, + fieldInfo.getFieldNames ) registerTableInternal(name, dataStreamTable) } @@ -468,13 +469,12 @@ abstract class StreamTableEnvImpl( : Unit = { val streamType = dataStream.getType - val bridgedFields = fields.map(expressionBridge.bridge).toArray[Expression] // get field names and types for all non-replaced fields - val (fieldNames, fieldIndexes) = getFieldInfo[T](streamType, bridgedFields) + val fieldsInfo = getFieldsInfo[T](streamType, fields) // validate and extract time attributes - val (rowtime, proctime) = validateAndExtractTimeAttributes(streamType, bridgedFields) + val (rowtime, proctime) = validateAndExtractTimeAttributes(streamType, fields) // check if event-time is enabled if (rowtime.isDefined && execEnv.getStreamTimeCharacteristic != TimeCharacteristic.EventTime) { @@ -484,8 +484,8 @@ abstract class StreamTableEnvImpl( } // adjust field indexes and field names - val indexesWithIndicatorFields = adjustFieldIndexes(fieldIndexes, rowtime, proctime) - val namesWithIndicatorFields = adjustFieldNames(fieldNames, rowtime, proctime) + val indexesWithIndicatorFields = adjustFieldIndexes(fieldsInfo.getIndices, rowtime, proctime) + val namesWithIndicatorFields = adjustFieldNames(fieldsInfo.getFieldNames, rowtime, proctime) val dataStreamTable = new DataStreamTable[T]( dataStream, @@ -593,7 +593,8 @@ abstract class StreamTableEnvImpl( } } - exprs.zipWithIndex.foreach { + val bridgedFields = exprs.map(expressionBridge.bridge).toArray[Expression] + bridgedFields.zipWithIndex.foreach { case (RowtimeAttribute(UnresolvedFieldReference(name)), idx) => extractRowtime(idx, name, None) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala index 5ada6526bd78de..5dd46ccbd02588 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala @@ -37,9 +37,7 @@ import org.apache.calcite.tools._ import org.apache.flink.annotation.VisibleForTesting import org.apache.flink.api.common.functions.MapFunction import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.api.java.typeutils.{RowTypeInfo, _} -import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo +import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, TupleTypeInfoBase} import org.apache.flink.table.calcite._ import org.apache.flink.table.catalog._ import org.apache.flink.table.codegen.{FunctionCodeGenerator, GeneratedFunction} @@ -720,153 +718,6 @@ abstract class TableEnvImpl( planningConfigurationBuilder.createFlinkPlanner(currentCatalogName, currentDatabase) } - /** - * Reference input fields by name: - * All fields in the schema definition are referenced by name - * (and possibly 
renamed using an alias (as). In this mode, fields can be reordered and - * projected out. Moreover, we can define proctime and rowtime attributes at arbitrary - * positions using arbitrary names (except those that exist in the result schema). This mode - * can be used for any input type, including POJOs. - * - * Reference input fields by position: - * In this mode, fields are simply renamed. Event-time attributes can - * replace the field on their position in the input data (if it is of correct type) or be - * appended at the end. Proctime attributes must be appended at the end. This mode can only be - * used if the input type has a defined field order (tuple, case class, Row) and no of fields - * references a field of the input type. - */ - protected def isReferenceByPosition(ct: CompositeType[_], fields: Array[Expression]): Boolean = { - if (!ct.isInstanceOf[TupleTypeInfoBase[_]]) { - return false - } - - val inputNames = ct.getFieldNames - - // Use the by-position mode if no of the fields exists in the input. - // This prevents confusing cases like ('f2, 'f0, 'myName) for a Tuple3 where fields are renamed - // by position but the user might assume reordering instead of renaming. - fields.forall { - case UnresolvedFieldReference(name) => !inputNames.contains(name) - case _ => true - } - } - - /** - * Returns field names and field positions for a given [[TypeInformation]]. - * - * @param inputType The TypeInformation extract the field names and positions from. - * @tparam A The type of the TypeInformation. - * @return A tuple of two arrays holding the field names and corresponding field positions. - */ - protected[flink] def getFieldInfo[A](inputType: TypeInformation[A]): - (Array[String], Array[Int]) = { - - if (inputType.isInstanceOf[GenericTypeInfo[A]] && inputType.getTypeClass == classOf[Row]) { - throw new TableException( - "An input of GenericTypeInfo cannot be converted to Table. " + - "Please specify the type of the input with a RowTypeInfo.") - } else { - (TableEnvImpl.getFieldNames(inputType), TableEnvImpl.getFieldIndices(inputType)) - } - } - - /** - * Returns field names and field positions for a given [[TypeInformation]] and [[Array]] of - * [[Expression]]. It does not handle time attributes but considers them in indices. - * - * @param inputType The [[TypeInformation]] against which the [[Expression]]s are evaluated. - * @param exprs The expressions that define the field names. - * @tparam A The type of the TypeInformation. - * @return A tuple of two arrays holding the field names and corresponding field positions. - */ - protected def getFieldInfo[A]( - inputType: TypeInformation[A], - exprs: Array[Expression]) - : (Array[String], Array[Int]) = { - - TableEnvImpl.validateType(inputType) - - def referenceByName(name: String, ct: CompositeType[_]): Option[Int] = { - val inputIdx = ct.getFieldIndex(name) - if (inputIdx < 0) { - throw new TableException(s"$name is not a field of type $ct. " + - s"Expected: ${ct.getFieldNames.mkString(", ")}") - } else { - Some(inputIdx) - } - } - - val indexedNames: Array[(Int, String)] = inputType match { - - case g: GenericTypeInfo[A] if g.getTypeClass == classOf[Row] => - throw new TableException( - "An input of GenericTypeInfo cannot be converted to Table. 
" + - "Please specify the type of the input with a RowTypeInfo.") - - case t: TupleTypeInfoBase[A] if t.isInstanceOf[TupleTypeInfo[A]] || - t.isInstanceOf[CaseClassTypeInfo[A]] || t.isInstanceOf[RowTypeInfo] => - - // determine schema definition mode (by position or by name) - val isRefByPos = isReferenceByPosition(t, exprs) - - exprs.zipWithIndex flatMap { - case (UnresolvedFieldReference(name: String), idx) => - if (isRefByPos) { - Some((idx, name)) - } else { - referenceByName(name, t).map((_, name)) - } - case (Alias(UnresolvedFieldReference(origName), name: String, _), _) => - if (isRefByPos) { - throw new TableException( - s"Alias '$name' is not allowed if other fields are referenced by position.") - } else { - referenceByName(origName, t).map((_, name)) - } - case (_: TimeAttribute, _) | (Alias(_: TimeAttribute, _, _), _) => - None - case _ => throw new TableException( - "Field reference expression or alias on field expression expected.") - } - - case p: PojoTypeInfo[A] => - exprs flatMap { - case UnresolvedFieldReference(name: String) => - referenceByName(name, p).map((_, name)) - case Alias(UnresolvedFieldReference(origName), name: String, _) => - referenceByName(origName, p).map((_, name)) - case _: TimeAttribute | Alias(_: TimeAttribute, _, _) => - None - case _ => throw new TableException( - "Field reference expression or alias on field expression expected.") - } - - case _: TypeInformation[_] => // atomic or other custom type information - var referenced = false - exprs flatMap { - case _: TimeAttribute => - None - case UnresolvedFieldReference(_) if referenced => - // only accept the first field for an atomic type - throw new TableException("Only the first field can reference an atomic type.") - case UnresolvedFieldReference(name: String) => - referenced = true - // first field reference is mapped to atomic type - Some((0, name)) - case _ => throw new TableException( - "Field reference expression expected.") - } - } - - val (fieldIndexes, fieldNames) = indexedNames.unzip - - if (fieldNames.contains("*")) { - throw new TableException("Field name can not be '*'.") - } - - (fieldNames, fieldIndexes) - } - protected def generateRowConverterFunction[OUT]( inputTypeInfo: TypeInformation[Row], schema: RowSchema, @@ -980,72 +831,3 @@ abstract class TableEnvImpl( Some(generated) } } - -/** - * Object to instantiate a [[TableEnvImpl]] depending on the batch or stream execution environment. - */ -object TableEnvImpl { - - /** - * Returns field names for a given [[TypeInformation]]. - * - * @param inputType The TypeInformation extract the field names. - * @tparam A The type of the TypeInformation. 
- * @return An array holding the field names - */ - def getFieldNames[A](inputType: TypeInformation[A]): Array[String] = { - validateType(inputType) - - val fieldNames: Array[String] = inputType match { - case t: CompositeType[_] => t.getFieldNames - case _: TypeInformation[_] => Array("f0") - } - - if (fieldNames.contains("*")) { - throw new TableException("Field name can not be '*'.") - } - - fieldNames - } - - /** - * Validate if class represented by the typeInfo is static and globally accessible - * @param typeInfo type to check - * @throws TableException if type does not meet these criteria - */ - def validateType(typeInfo: TypeInformation[_]): Unit = { - val clazz = typeInfo.getTypeClass - if ((clazz.isMemberClass && !Modifier.isStatic(clazz.getModifiers)) || - !Modifier.isPublic(clazz.getModifiers) || - clazz.getCanonicalName == null) { - throw new TableException( - s"Class '$clazz' described in type information '$typeInfo' must be " + - s"static and globally accessible.") - } - } - - /** - * Returns field indexes for a given [[TypeInformation]]. - * - * @param inputType The TypeInformation extract the field positions from. - * @return An array holding the field positions - */ - def getFieldIndices(inputType: TypeInformation[_]): Array[Int] = { - getFieldNames(inputType).indices.toArray - } - - /** - * Returns field types for a given [[TypeInformation]]. - * - * @param inputType The TypeInformation to extract field types from. - * @return An array holding the field types. - */ - def getFieldTypes(inputType: TypeInformation[_]): Array[TypeInformation[_]] = { - validateType(inputType) - - inputType match { - case ct: CompositeType[_] => 0.until(ct.getArity).map(i => ct.getTypeAt(i)).toArray - case t: TypeInformation[_] => Array(t.asInstanceOf[TypeInformation[_]]) - } - } -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala index d87619582f5aaa..235370ee1b91e0 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala @@ -17,17 +17,18 @@ */ package org.apache.flink.table.api.java +import _root_.java.lang.{Boolean => JBool} + import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.table.api._ -import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, TableFunction, UserDefinedAggregateFunction} -import org.apache.flink.table.expressions.ExpressionParser +import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor} import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment -import _root_.java.lang.{Boolean => JBool} - +import org.apache.flink.table.api._ import org.apache.flink.table.catalog.CatalogManager +import org.apache.flink.table.expressions.ExpressionParser +import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, TableFunction, UserDefinedAggregateFunction} +import org.apache.flink.table.typeutils.FieldInfoUtils import _root_.scala.collection.JavaConverters._ @@ -95,7 +96,7 @@ class StreamTableEnvImpl( clazz: Class[T], queryConfig: 
StreamQueryConfig): DataStream[T] = { val typeInfo = TypeExtractor.createTypeInfo(clazz) - TableEnvImpl.validateType(typeInfo) + FieldInfoUtils.validateInputTypeInfo(typeInfo) translate[T](table, queryConfig, updatesAsRetraction = false, withChangeFlag = false)(typeInfo) } @@ -103,7 +104,7 @@ class StreamTableEnvImpl( table: Table, typeInfo: TypeInformation[T], queryConfig: StreamQueryConfig): DataStream[T] = { - TableEnvImpl.validateType(typeInfo) + FieldInfoUtils.validateInputTypeInfo(typeInfo) translate[T](table, queryConfig, updatesAsRetraction = false, withChangeFlag = false)(typeInfo) } @@ -127,7 +128,7 @@ class StreamTableEnvImpl( queryConfig: StreamQueryConfig): DataStream[JTuple2[JBool, T]] = { val typeInfo = TypeExtractor.createTypeInfo(clazz) - TableEnvImpl.validateType(typeInfo) + FieldInfoUtils.validateInputTypeInfo(typeInfo) val resultType = new TupleTypeInfo[JTuple2[JBool, T]](Types.BOOLEAN, typeInfo) translate[JTuple2[JBool, T]]( table, @@ -141,7 +142,7 @@ class StreamTableEnvImpl( typeInfo: TypeInformation[T], queryConfig: StreamQueryConfig): DataStream[JTuple2[JBool, T]] = { - TableEnvImpl.validateType(typeInfo) + FieldInfoUtils.validateInputTypeInfo(typeInfo) val resultTypeInfo = new TupleTypeInfo[JTuple2[JBool, T]]( Types.BOOLEAN, typeInfo diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index f8443b2598bad2..777aa8e3795699 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -34,11 +34,12 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor} import org.apache.flink.table.api.dataview._ -import org.apache.flink.table.api.{TableEnvImpl, TableException, ValidationException} +import org.apache.flink.table.api.{TableException, ValidationException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.dataview._ import org.apache.flink.table.functions._ import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl +import org.apache.flink.table.typeutils.FieldInfoUtils import org.apache.flink.util.InstantiationUtil import scala.collection.mutable @@ -699,9 +700,9 @@ object UserDefinedFunctionUtils { def getFieldInfo(inputType: TypeInformation[_]) : (Array[String], Array[Int], Array[TypeInformation[_]]) = { - (TableEnvImpl.getFieldNames(inputType), - TableEnvImpl.getFieldIndices(inputType), - TableEnvImpl.getFieldTypes(inputType)) + (FieldInfoUtils.getFieldNames(inputType), + FieldInfoUtils.getFieldIndices(inputType), + FieldInfoUtils.getFieldTypes(inputType)) } /** From 3934a7f3e7abce4f2ef25391bf62a5754fcfdbcf Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Thu, 23 May 2019 11:06:49 +0200 Subject: [PATCH 38/92] [hotfix][table-planner] Removed TableOperationConverterSupplier. Rather than passing TableOperationConverterSupplier, we just create FlinkRelBuilder whenever we need to convert from TableOperation to RelNode. 
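For illustration, a minimal Java sketch of the conversion pattern this hotfix moves to, assuming the four-argument FlinkRelBuilder constructor and the tableOperation(...) builder method that the rewritten rule below relies on; the helper class, its method name and its parameters are illustrative only and not part of this patch, while the imports follow the packages used elsewhere in these commits.

import org.apache.calcite.plan.Context;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptSchema;
import org.apache.calcite.rel.RelNode;
import org.apache.flink.table.calcite.FlinkRelBuilder;
import org.apache.flink.table.expressions.ExpressionBridge;
import org.apache.flink.table.expressions.PlannerExpression;
import org.apache.flink.table.operations.TableOperation;

/** Illustrative helper (not part of this patch): converts a TableOperation with a locally created FlinkRelBuilder. */
final class TableOperationToRelNodeSketch {

	private TableOperationToRelNodeSketch() {
	}

	static RelNode toRelNode(
			Context plannerContext,
			RelOptCluster cluster,
			RelOptSchema relOptSchema,
			ExpressionBridge<PlannerExpression> expressionBridge,
			TableOperation operation) {

		// Build the FlinkRelBuilder on demand instead of unwrapping a
		// ToRelConverterSupplier from the planner context.
		FlinkRelBuilder relBuilder = new FlinkRelBuilder(plannerContext, cluster, relOptSchema, expressionBridge);

		// tableOperation(...) converts the operation tree; build() returns the resulting RelNode.
		return relBuilder.tableOperation(operation).build();
	}
}

With this pattern, the only thing the rule still needs from the planner context is the ExpressionBridge, as the rewritten LogicalCorrelateToTemporalTableJoinRule below shows.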
--- .../table/plan/TableOperationConverter.java | 26 +++---------------- .../planner/PlanningConfigurationBuilder.java | 11 +++----- ...icalCorrelateToTemporalTableJoinRule.scala | 16 +++++++----- 3 files changed, 17 insertions(+), 36 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java index 2a32fe333e06cb..565bca7a205589 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java @@ -64,7 +64,6 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.logical.LogicalTableFunctionScan; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilder.AggCall; import org.apache.calcite.tools.RelBuilder.GroupKey; @@ -89,23 +88,7 @@ @Internal public class TableOperationConverter extends TableOperationDefaultVisitor { - /** - * Supplier for {@link TableOperationConverter} that can wrap given {@link RelBuilder}. - */ - @Internal - public static class ToRelConverterSupplier { - private final ExpressionBridge expressionBridge; - - public ToRelConverterSupplier(ExpressionBridge expressionBridge) { - this.expressionBridge = expressionBridge; - } - - public TableOperationConverter get(RelBuilder relBuilder) { - return new TableOperationConverter(relBuilder, expressionBridge); - } - } - - private final RelBuilder relBuilder; + private final FlinkRelBuilder relBuilder; private final SingleRelVisitor singleRelVisitor = new SingleRelVisitor(); private final ExpressionBridge expressionBridge; private final AggregateVisitor aggregateVisitor = new AggregateVisitor(); @@ -113,7 +96,7 @@ public TableOperationConverter get(RelBuilder relBuilder) { private final JoinExpressionVisitor joinExpressionVisitor = new JoinExpressionVisitor(); public TableOperationConverter( - RelBuilder relBuilder, + FlinkRelBuilder relBuilder, ExpressionBridge expressionBridge) { this.relBuilder = relBuilder; this.expressionBridge = expressionBridge; @@ -148,7 +131,6 @@ public RelNode visitAggregate(AggregateTableOperation aggregate) { @Override public RelNode visitWindowAggregate(WindowAggregateTableOperation windowAggregate) { - FlinkRelBuilder flinkRelBuilder = (FlinkRelBuilder) relBuilder; List aggregations = windowAggregate.getAggregateExpressions() .stream() .map(this::getAggCall) @@ -161,7 +143,7 @@ public RelNode visitWindowAggregate(WindowAggregateTableOperation windowAggregat .collect(toList()); GroupKey groupKey = relBuilder.groupKey(groupings); LogicalWindow logicalWindow = toLogicalWindow(windowAggregate.getGroupWindow()); - return flinkRelBuilder.windowAggregate(logicalWindow, groupKey, windowProperties, aggregations).build(); + return relBuilder.windowAggregate(logicalWindow, groupKey, windowProperties, aggregations).build(); } /** @@ -237,7 +219,7 @@ public RelNode visitCalculatedTable(CalculatedTableOperation calculatedTa fieldNames); TableFunction tableFunction = calculatedTable.getTableFunction(); - FlinkTypeFactory typeFactory = (FlinkTypeFactory) relBuilder.getTypeFactory(); + FlinkTypeFactory typeFactory = relBuilder.getTypeFactory(); TableSqlFunction sqlFunction = new TableSqlFunction( tableFunction.functionIdentifier(), tableFunction.toString(), diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/PlanningConfigurationBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/PlanningConfigurationBuilder.java index 7c2d79a0740b26..09a55f5cd54ac5 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/PlanningConfigurationBuilder.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/PlanningConfigurationBuilder.java @@ -31,7 +31,6 @@ import org.apache.flink.table.codegen.ExpressionReducer; import org.apache.flink.table.expressions.ExpressionBridge; import org.apache.flink.table.expressions.PlannerExpression; -import org.apache.flink.table.plan.TableOperationConverter; import org.apache.flink.table.plan.cost.DataSetCostFactory; import org.apache.flink.table.util.JavaScalaConversionUtil; import org.apache.flink.table.validate.FunctionCatalog; @@ -84,10 +83,9 @@ public PlanningConfigurationBuilder( this.tableConfig = tableConfig; this.functionCatalog = functionCatalog; - // create context instances with Flink type factory - this.context = Contexts.of( - new TableOperationConverter.ToRelConverterSupplier(expressionBridge) - ); + // the converter is needed when calling temporal table functions from SQL, because + // they reference a history table represented with a tree of table operations + this.context = Contexts.of(expressionBridge); this.planner = new VolcanoPlanner(costFactory, context); planner.setExecutor(new ExpressionReducer(tableConfig)); @@ -193,9 +191,6 @@ private FrameworkConfig createFrameworkConfig() { getSqlToRelConverterConfig( calciteConfig(tableConfig), expressionBridge)) - // the converter is needed when calling temporal table functions from SQL, because - // they reference a history table represented with a tree of table operations - .context(context) // set the executor to evaluate constant expressions .executor(new ExpressionReducer(tableConfig)) .build(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala index f4a669968d65d5..0b07f47119a073 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala @@ -34,6 +34,7 @@ import org.apache.flink.table.plan.TableOperationConverter import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin import org.apache.flink.table.plan.util.RexDefaultVisitor import org.apache.flink.util.Preconditions.checkState +import org.apache.flink.table.calcite.FlinkRelBuilder class LogicalCorrelateToTemporalTableJoinRule extends RelOptRule( @@ -82,14 +83,17 @@ class LogicalCorrelateToTemporalTableJoinRule // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin val underlyingHistoryTable: TableOperation = rightTemporalTableFunction .getUnderlyingHistoryTable - val relBuilder = this.relBuilderFactory.create( - cluster, - leftNode.getTable.getRelOptSchema) val rexBuilder = cluster.getRexBuilder - val converter = call.getPlanner.getContext - .unwrap(classOf[TableOperationConverter.ToRelConverterSupplier]).get(relBuilder) - val rightNode: RelNode = 
underlyingHistoryTable.accept(converter) + val expressionBridge = call.getPlanner.getContext + .unwrap(classOf[ExpressionBridge[PlannerExpression]]) + + val relBuilder = new FlinkRelBuilder(call.getPlanner.getContext, + cluster, + leftNode.getTable.getRelOptSchema, + expressionBridge) + + val rightNode: RelNode = relBuilder.tableOperation(underlyingHistoryTable).build() val rightTimeIndicatorExpression = createRightExpression( rexBuilder, From bfd53e9d3f5a221bca8ca82e5f2ab5399d3fa0fd Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Tue, 21 May 2019 17:15:02 +0800 Subject: [PATCH 39/92] [FLINK-12571][network] Make NetworkEnvironment#start() return the binded data port NetworkEnvironment#getConnectionManager is currently used for getting binded data port from ConnectionManager. Considering the general shuffle service architecture, the internal ConnectionManager in NetworkEnvironment should not be exposed to outsides. We could make ShuffleService#start return the binded data port directly if exists, then for other cases it could return a default int value which seems no harm. This closes #8496. --- .../runtime/io/network/ConnectionManager.java | 9 ++++++--- .../runtime/io/network/LocalConnectionManager.java | 8 ++------ .../runtime/io/network/NetworkEnvironment.java | 10 ++++++++-- .../io/network/netty/NettyConnectionManager.java | 14 +++----------- .../runtime/io/network/netty/NettyServer.java | 8 +++----- .../runtime/taskexecutor/TaskManagerServices.java | 4 ++-- .../io/network/TestingConnectionManager.java | 9 +++------ .../network/partition/InputChannelTestUtils.java | 8 ++------ .../StreamNetworkBenchmarkEnvironment.java | 6 ++++-- 9 files changed, 33 insertions(+), 43 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/ConnectionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/ConnectionManager.java index c342750fb9fcd4..32ec02b5f25aa3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/ConnectionManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/ConnectionManager.java @@ -26,7 +26,12 @@ */ public interface ConnectionManager { - void start() throws IOException; + /** + * Starts the internal related components for network connection and communication. + * + * @return a port to connect to the task executor for shuffle data exchange, -1 if only local connection is possible. + */ + int start() throws IOException; /** * Creates a {@link PartitionRequestClient} instance for the given {@link ConnectionID}. 
@@ -40,8 +45,6 @@ public interface ConnectionManager { int getNumberOfActiveConnections(); - int getDataPort(); - void shutdown() throws IOException; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/LocalConnectionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/LocalConnectionManager.java index 319a9eaf3d6a36..5613d19c3ee726 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/LocalConnectionManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/LocalConnectionManager.java @@ -25,7 +25,8 @@ public class LocalConnectionManager implements ConnectionManager { @Override - public void start() { + public int start() { + return -1; } @Override @@ -41,11 +42,6 @@ public int getNumberOfActiveConnections() { return 0; } - @Override - public int getDataPort() { - return -1; - } - @Override public void shutdown() {} } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 43969e2c14369c..561909c1a595f9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -187,6 +187,7 @@ public ResultPartitionManager getResultPartitionManager() { return resultPartitionManager; } + @VisibleForTesting public ConnectionManager getConnectionManager() { return connectionManager; } @@ -317,7 +318,12 @@ public boolean updatePartitionInfo( return true; } - public void start() throws IOException { + /* + * Starts the internal related components for network connection and communication. + * + * @return a port to connect to the task executor for shuffle data exchange, -1 if only local connection is possible. 
+ */ + public int start() throws IOException { synchronized (lock) { Preconditions.checkState(!isShutdown, "The NetworkEnvironment has already been shut down."); @@ -325,7 +331,7 @@ public void start() throws IOException { try { LOG.debug("Starting network connection manager"); - connectionManager.start(); + return connectionManager.start(); } catch (IOException t) { throw new IOException("Failed to instantiate network connection manager.", t); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java index ef3db13bee8d74..3e6a932f8d1ee8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyConnectionManager.java @@ -56,9 +56,10 @@ public NettyConnectionManager( } @Override - public void start() throws IOException { + public int start() throws IOException { client.init(nettyProtocol, bufferPool); - server.init(nettyProtocol, bufferPool); + + return server.init(nettyProtocol, bufferPool); } @Override @@ -77,15 +78,6 @@ public int getNumberOfActiveConnections() { return partitionRequestClientFactory.getNumberOfActiveClients(); } - @Override - public int getDataPort() { - if (server != null && server.getLocalAddress() != null) { - return server.getLocalAddress().getPort(); - } else { - return -1; - } - } - @Override public void shutdown() { client.shutdown(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java index f818ff64016ff0..8bbda5bdd81fb3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyServer.java @@ -64,7 +64,7 @@ class NettyServer { localAddress = null; } - void init(final NettyProtocol protocol, NettyBufferPool nettyBufferPool) throws IOException { + int init(final NettyProtocol protocol, NettyBufferPool nettyBufferPool) throws IOException { checkState(bootstrap == null, "Netty server has already been initialized."); final long start = System.nanoTime(); @@ -164,6 +164,8 @@ public void initChannel(SocketChannel channel) throws Exception { final long duration = (System.nanoTime() - start) / 1_000_000; LOG.info("Successful initialization (took {} ms). 
Listening on SocketAddress {}.", duration, localAddress); + + return localAddress.getPort(); } NettyConfig getConfig() { @@ -174,10 +176,6 @@ ServerBootstrap getBootstrap() { return bootstrap; } - public InetSocketAddress getLocalAddress() { - return localAddress; - } - void shutdown() { final long start = System.nanoTime(); if (bindFuture != null) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java index e19e8fc60b0630..6e9b864affc8b8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java @@ -248,7 +248,7 @@ public static TaskManagerServices fromConfiguration( final NetworkEnvironment network = NetworkEnvironment.create( taskManagerServicesConfiguration.getNetworkConfig(), taskEventDispatcher, taskManagerMetricGroup, ioManager); - network.start(); + int dataPort = network.start(); final KvStateService kvStateService = KvStateService.fromConfiguration(taskManagerServicesConfiguration); kvStateService.start(); @@ -256,7 +256,7 @@ public static TaskManagerServices fromConfiguration( final TaskManagerLocation taskManagerLocation = new TaskManagerLocation( resourceID, taskManagerServicesConfiguration.getTaskManagerAddress(), - network.getConnectionManager().getDataPort()); + dataPort); // this call has to happen strictly after the network stack has been initialized final MemoryManager memoryManager = createMemoryManager(taskManagerServicesConfiguration, freeHeapMemoryWithDefrag, maxJvmHeapMemory); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TestingConnectionManager.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TestingConnectionManager.java index c23b3c2cb72997..19203fb945df69 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TestingConnectionManager.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TestingConnectionManager.java @@ -27,7 +27,9 @@ public class TestingConnectionManager implements ConnectionManager { @Override - public void start() {} + public int start() { + return -1; + } @Override public PartitionRequestClient createPartitionRequestClient(ConnectionID connectionId) throws IOException { @@ -42,11 +44,6 @@ public int getNumberOfActiveConnections() { return 0; } - @Override - public int getDataPort() { - return -1; - } - @Override public void shutdown() {} } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java index 4ff472e9f645a2..16d6cabfee339e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java @@ -140,7 +140,8 @@ public static RemoteInputChannel createRemoteInputChannel( public static ConnectionManager mockConnectionManagerWithPartitionRequestClient(PartitionRequestClient client) { return new ConnectionManager() { @Override - public void start() { + public int start() { + return -1; } @Override @@ -157,11 +158,6 @@ public int getNumberOfActiveConnections() { return 0; } - @Override - public int getDataPort() { - return 0; - } - @Override public void 
shutdown() { } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java index fccb5db8ee4d76..4b28961231f59f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java @@ -89,6 +89,8 @@ public class StreamNetworkBenchmarkEnvironment { protected ResultPartitionID[] partitionIds; + private int dataPort; + public void setUp( int writers, int channels, @@ -141,7 +143,7 @@ public void setUp( ioManager = new IOManagerAsync(); senderEnv = createNettyNetworkEnvironment(senderBufferPoolSize, config); - senderEnv.start(); + this.dataPort = senderEnv.start(); if (localMode && senderBufferPoolSize == receiverBufferPoolSize) { receiverEnv = senderEnv; } @@ -163,7 +165,7 @@ public SerializingLongReceiver createReceiver() throws Exception { TaskManagerLocation senderLocation = new TaskManagerLocation( ResourceID.generate(), LOCAL_ADDRESS, - senderEnv.getConnectionManager().getDataPort()); + dataPort); InputGate receiverGate = createInputGate( dataSetID, From d1da5e153f3583eb57aca90282d8f21910c7b0e1 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Wed, 29 May 2019 10:15:19 +0800 Subject: [PATCH 40/92] [hotfix][network] Remove legacy private TaskEventDispatcher from NetworkEnvironment --- .../apache/flink/runtime/io/network/NetworkEnvironment.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 561909c1a595f9..b90e9038eae2fa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -94,8 +94,6 @@ public class NetworkEnvironment { private final Map inputGatesById; - private final TaskEventPublisher taskEventPublisher; - private final ResultPartitionFactory resultPartitionFactory; private final SingleInputGateFactory singleInputGateFactory; @@ -107,7 +105,6 @@ private NetworkEnvironment( NetworkBufferPool networkBufferPool, ConnectionManager connectionManager, ResultPartitionManager resultPartitionManager, - TaskEventPublisher taskEventPublisher, ResultPartitionFactory resultPartitionFactory, SingleInputGateFactory singleInputGateFactory) { this.config = config; @@ -115,7 +112,6 @@ private NetworkEnvironment( this.connectionManager = connectionManager; this.resultPartitionManager = resultPartitionManager; this.inputGatesById = new ConcurrentHashMap<>(); - this.taskEventPublisher = taskEventPublisher; this.resultPartitionFactory = resultPartitionFactory; this.singleInputGateFactory = singleInputGateFactory; this.isShutdown = false; @@ -164,7 +160,6 @@ public static NetworkEnvironment create( networkBufferPool, connectionManager, resultPartitionManager, - taskEventPublisher, resultPartitionFactory, singleInputGateFactory); } From 8b7577cfafa067ca719aac64695c2b5acc8d56f2 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Tue, 28 May 2019 16:09:56 -0700 Subject: [PATCH 41/92] [FLINK-12679][sql-client] Support 'default-database' config for catalog entries in SQL CLI yaml file --- 
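For illustration, a minimal Java sketch of a catalog descriptor that uses the new default-database support; the class name, the property-version value and the empty set of extra properties are assumptions made for the example, while the three-argument constructor and the toCatalogProperties() hook follow the CatalogDescriptor changes in this patch.

import java.util.Collections;
import java.util.Map;

import org.apache.flink.table.descriptors.CatalogDescriptor;

/** Illustrative descriptor (not part of this patch): declares a default database for a test catalog. */
public class TestCatalogDescriptorSketch extends CatalogDescriptor {

	public TestCatalogDescriptorSketch(String defaultDatabase) {
		// type, property-version, default-database (the new three-argument constructor)
		super("DependencyTest", 1, defaultDatabase);
	}

	@Override
	public Map<String, String> toCatalogProperties() {
		// No additional catalog-specific properties in this sketch.
		return Collections.emptyMap();
	}
}

The resulting properties then carry type, property-version and default-database, which is what the new 'default-database' key for catalog entries in the SQL CLI YAML maps to (see the catalog2 entry in test-sql-client-catalogs.yaml below).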
.../client/gateway/local/DependencyTest.java | 20 ++++++++++++---- .../gateway/local/ExecutionContextTest.java | 4 ++-- .../resources/test-sql-client-catalogs.yaml | 7 ++++-- .../table/descriptors/CatalogDescriptor.java | 24 +++++++++++++++++++ .../CatalogDescriptorValidator.java | 6 +++++ 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java index 109246c8909608..7730b0dab35164 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java @@ -41,7 +41,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_DEFAULT_DATABASE; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; import static org.junit.Assert.assertEquals; @@ -125,6 +127,7 @@ public Map requiredContext() { public List supportedProperties() { final List properties = new ArrayList<>(); properties.add(TEST_PROPERTY); + properties.add(CATALOG_DEFAULT_DATABASE); return properties; } @@ -132,7 +135,14 @@ public List supportedProperties() { public Catalog createCatalog(String name, Map properties) { final DescriptorProperties params = new DescriptorProperties(true); params.putProperties(properties); - return new TestCatalog(name); + + final Optional defaultDatabase = params.getOptionalString(CATALOG_DEFAULT_DATABASE); + + if (defaultDatabase.isPresent()) { + return new TestCatalog(name, defaultDatabase.get()); + } else { + return new TestCatalog(name); + } } } @@ -141,10 +151,12 @@ public Catalog createCatalog(String name, Map properties) { */ public static class TestCatalog extends GenericInMemoryCatalog { - private static final String TEST_DATABASE_NAME = "mydatabase"; - public TestCatalog(String name) { - super(name, TEST_DATABASE_NAME); + super(name); + } + + public TestCatalog(String name, String defaultDatabase) { + super(name, defaultDatabase); } } } diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java index 21cf7d04fd56d6..3783f37b284ccd 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java @@ -73,12 +73,12 @@ public void testExecutionConfig() throws Exception { @Test public void testCatalogs() throws Exception { - final String catalogName = "catalog1"; + final String catalogName = "catalog2"; final ExecutionContext context = createCatalogExecutionContext(); final TableEnvironment tableEnv = context.createEnvironmentInstance().getTableEnvironment(); assertEquals(tableEnv.getCurrentCatalog(), catalogName); - assertEquals(tableEnv.getCurrentDatabase(), "mydatabase"); + assertEquals(tableEnv.getCurrentDatabase(), "test-default-database"); } @Test diff --git a/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml 
b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml index e915930812baa0..324ae38e1f0949 100644 --- a/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml +++ b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml @@ -112,8 +112,8 @@ execution: max-failures-per-interval: 10 failure-rate-interval: 99000 delay: 1000 - current-catalog: catalog1 - current-database: mydatabase + current-catalog: catalog2 + current-database: test-default-database deployment: response-timeout: 5000 @@ -121,3 +121,6 @@ deployment: catalogs: - name: catalog1 type: DependencyTest + - name: catalog2 + type: DependencyTest + default-database: test-default-database diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java index 18b433ec4da43c..753e63c0fb8580 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java @@ -22,6 +22,7 @@ import java.util.Map; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_DEFAULT_DATABASE; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; @@ -35,6 +36,8 @@ public abstract class CatalogDescriptor extends DescriptorBase { private final int propertyVersion; + private final String defaultDatabase; + /** * Constructs a {@link CatalogDescriptor}. * @@ -42,8 +45,20 @@ public abstract class CatalogDescriptor extends DescriptorBase { * @param propertyVersion property version for backwards compatibility */ public CatalogDescriptor(String type, int propertyVersion) { + this(type, propertyVersion, null); + } + + /** + * Constructs a {@link CatalogDescriptor}. + * + * @param type string that identifies this catalog + * @param propertyVersion property version for backwards compatibility + * @param defaultDatabase default database of the catalog + */ + public CatalogDescriptor(String type, int propertyVersion, String defaultDatabase) { this.type = type; this.propertyVersion = propertyVersion; + this.defaultDatabase = defaultDatabase; } @Override @@ -51,10 +66,19 @@ public final Map toProperties() { final DescriptorProperties properties = new DescriptorProperties(); properties.putString(CATALOG_TYPE, type); properties.putLong(CATALOG_PROPERTY_VERSION, propertyVersion); + + if (defaultDatabase != null) { + properties.putString(CATALOG_DEFAULT_DATABASE, defaultDatabase); + } + properties.putProperties(toCatalogProperties()); return properties.asMap(); } + public String getDefaultDatabase() { + return defaultDatabase; + } + /** * Converts this descriptor into a set of catalog properties. 
*/ diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java index 723dcb013a0758..a907ac79bec5eb 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptorValidator.java @@ -37,9 +37,15 @@ public abstract class CatalogDescriptorValidator implements DescriptorValidator */ public static final String CATALOG_PROPERTY_VERSION = "property-version"; + /** + * Key for describing the default database of the catalog. + */ + public static final String CATALOG_DEFAULT_DATABASE = "default-database"; + @Override public void validate(DescriptorProperties properties) { properties.validateString(CATALOG_TYPE, false, 1); properties.validateInt(CATALOG_PROPERTY_VERSION, true, 0); + properties.validateString(CATALOG_DEFAULT_DATABASE, true, 1); } } From ba648a57c3bf65efb657095a5c682e2a852d1bdf Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Wed, 29 May 2019 12:07:54 -0700 Subject: [PATCH 42/92] [FLINK-12678][table] Add AbstractCatalog to manage the common catalog name and default database name for catalogs --- .../flink/table/catalog/hive/HiveCatalog.java | 79 ++++++++--------- flink-table/flink-table-api-java/pom.xml | 5 ++ .../table/catalog/GenericInMemoryCatalog.java | 85 ++++++++----------- .../flink/table/catalog/AbstractCatalog.java | 51 +++++++++++ 4 files changed, 128 insertions(+), 92 deletions(-) create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index 159499ca136b96..4081296c7e9624 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -20,9 +20,9 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.catalog.AbstractCatalog; import org.apache.flink.table.catalog.AbstractCatalogTable; import org.apache.flink.table.catalog.AbstractCatalogView; -import org.apache.flink.table.catalog.Catalog; import org.apache.flink.table.catalog.CatalogBaseTable; import org.apache.flink.table.catalog.CatalogDatabase; import org.apache.flink.table.catalog.CatalogFunction; @@ -90,7 +90,7 @@ /** * A catalog implementation for Hive. */ -public class HiveCatalog implements Catalog { +public class HiveCatalog extends AbstractCatalog { private static final Logger LOG = LoggerFactory.getLogger(HiveCatalog.class); private static final String DEFAULT_DB = "default"; private static final StorageFormatFactory storageFormatFactory = new StorageFormatFactory(); @@ -106,10 +106,8 @@ public class HiveCatalog implements Catalog { // because Hive's Function object doesn't have properties or other place to store the flag for Flink functions. 
private static final String FLINK_FUNCTION_PREFIX = "flink:"; - protected final String catalogName; protected final HiveConf hiveConf; - private final String defaultDatabase; protected IMetaStoreClient client; public HiveCatalog(String catalogName, String hivemetastoreURI) { @@ -121,10 +119,7 @@ public HiveCatalog(String catalogName, HiveConf hiveConf) { } public HiveCatalog(String catalogName, String defaultDatabase, HiveConf hiveConf) { - checkArgument(!StringUtils.isNullOrWhitespaceOnly(catalogName), "catalogName cannot be null or empty"); - checkArgument(!StringUtils.isNullOrWhitespaceOnly(defaultDatabase), "defaultDatabase cannot be null or empty"); - this.catalogName = catalogName; - this.defaultDatabase = defaultDatabase; + super(catalogName, defaultDatabase); this.hiveConf = checkNotNull(hiveConf, "hiveConf cannot be null"); LOG.info("Created HiveCatalog '{}'", catalogName); @@ -158,9 +153,9 @@ public void open() throws CatalogException { LOG.info("Connected to Hive metastore"); } - if (!databaseExists(defaultDatabase)) { + if (!databaseExists(getDefaultDatabase())) { throw new CatalogException(String.format("Configured default database %s doesn't exist in catalog %s.", - defaultDatabase, catalogName)); + getDefaultDatabase(), getCatalogName())); } } @@ -175,10 +170,6 @@ public void close() throws CatalogException { // ------ databases ------ - public String getDefaultDatabase() throws CatalogException { - return defaultDatabase; - } - @Override public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistException, CatalogException { Database hiveDatabase = getHiveDatabase(databaseName); @@ -201,7 +192,7 @@ public void createDatabase(String databaseName, CatalogDatabase database, boolea client.createDatabase(hiveDatabase); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new DatabaseAlreadyExistException(catalogName, hiveDatabase.getName()); + throw new DatabaseAlreadyExistException(getCatalogName(), hiveDatabase.getName()); } } catch (TException e) { throw new CatalogException(String.format("Failed to create database %s", hiveDatabase.getName()), e); @@ -268,7 +259,7 @@ public List listDatabases() throws CatalogException { return client.getAllDatabases(); } catch (TException e) { throw new CatalogException( - String.format("Failed to list all databases in %s", catalogName), e); + String.format("Failed to list all databases in %s", getCatalogName()), e); } } @@ -291,10 +282,10 @@ public void dropDatabase(String name, boolean ignoreIfNotExists) throws Database client.dropDatabase(name, true, ignoreIfNotExists); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(catalogName, name); + throw new DatabaseNotExistException(getCatalogName(), name); } } catch (InvalidOperationException e) { - throw new DatabaseNotEmptyException(catalogName, name); + throw new DatabaseNotEmptyException(getCatalogName(), name); } catch (TException e) { throw new CatalogException(String.format("Failed to drop database %s", name), e); } @@ -304,10 +295,10 @@ private Database getHiveDatabase(String databaseName) throws DatabaseNotExistExc try { return client.getDatabase(databaseName); } catch (NoSuchObjectException e) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } catch (TException e) { throw new CatalogException( - String.format("Failed to get database %s from %s", databaseName, catalogName), e); + String.format("Failed to get 
database %s from %s", databaseName, getCatalogName()), e); } } @@ -328,7 +319,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig checkNotNull(table, "table cannot be null"); if (!databaseExists(tablePath.getDatabaseName())) { - throw new DatabaseNotExistException(catalogName, tablePath.getDatabaseName()); + throw new DatabaseNotExistException(getCatalogName(), tablePath.getDatabaseName()); } Table hiveTable = instantiateHiveTable(tablePath, table); @@ -337,7 +328,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig client.createTable(hiveTable); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new TableAlreadyExistException(catalogName, tablePath); + throw new TableAlreadyExistException(getCatalogName(), tablePath); } } catch (TException e) { throw new CatalogException(String.format("Failed to create table %s", tablePath.getFullName()), e); @@ -358,14 +349,14 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor // alter_table() doesn't throw a clear exception when new table already exists. // Thus, check the table existence explicitly if (tableExists(newPath)) { - throw new TableAlreadyExistException(catalogName, newPath); + throw new TableAlreadyExistException(getCatalogName(), newPath); } else { Table table = getHiveTable(tablePath); table.setTableName(newTableName); client.alter_table(tablePath.getDatabaseName(), tablePath.getObjectName(), table); } } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } catch (TException e) { throw new CatalogException( @@ -426,7 +417,7 @@ public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists) throws Ta ignoreIfNotExists); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } catch (TException e) { throw new CatalogException( @@ -441,7 +432,7 @@ public List listTables(String databaseName) throws DatabaseNotExistExcep try { return client.getAllTables(databaseName); } catch (UnknownDBException e) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } catch (TException e) { throw new CatalogException( String.format("Failed to list tables in database %s", databaseName), e); @@ -458,7 +449,7 @@ public List listViews(String databaseName) throws DatabaseNotExistExcept null, // table pattern TableType.VIRTUAL_VIEW); } catch (UnknownDBException e) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } catch (TException e) { throw new CatalogException( String.format("Failed to list views in database %s", databaseName), e); @@ -484,7 +475,7 @@ Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { try { return client.getTable(tablePath.getDatabaseName(), tablePath.getObjectName()); } catch (NoSuchObjectException e) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } catch (TException e) { throw new CatalogException( String.format("Failed to get table %s from Hive metastore", tablePath.getFullName()), e); @@ -661,7 +652,7 @@ public void createPartition(ObjectPath tablePath, CatalogPartitionSpec partition 
client.add_partition(instantiateHivePartition(hiveTable, partitionSpec, partition)); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new PartitionAlreadyExistsException(catalogName, tablePath, partitionSpec); + throw new PartitionAlreadyExistsException(getCatalogName(), tablePath, partitionSpec); } } catch (TException e) { throw new CatalogException( @@ -681,10 +672,10 @@ public void dropPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSp getOrderedFullPartitionValues(partitionSpec, getFieldNames(hiveTable.getPartitionKeys()), tablePath), true); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec, e); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); } } catch (MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec, e); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to drop partition %s of table %s", partitionSpec, tablePath)); @@ -741,7 +732,7 @@ public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec Partition hivePartition = getHivePartition(tablePath, partitionSpec); return instantiateCatalogPartition(hivePartition); } catch (NoSuchObjectException | MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec, e); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to get partition %s of table %s", partitionSpec, tablePath), e); @@ -769,7 +760,7 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS if (ignoreIfNotExists) { return; } - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } Partition newHivePartition = instantiateHivePartition(hiveTable, partitionSpec, newPartition); if (newHivePartition.getSd().getLocation() == null) { @@ -782,10 +773,10 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS ); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec, e); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); } } catch (InvalidOperationException | MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec, e); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to alter existing partition with new partition %s of table %s", @@ -812,7 +803,7 @@ private Partition instantiateHivePartition(Table hiveTable, CatalogPartitionSpec // validate partition values for (int i = 0; i < partCols.size(); i++) { if (StringUtils.isNullOrWhitespaceOnly(partValues.get(i))) { - throw new PartitionSpecInvalidException(catalogName, partCols, + throw new PartitionSpecInvalidException(getCatalogName(), partCols, new ObjectPath(hiveTable.getDbName(), hiveTable.getTableName()), partitionSpec); } } @@ -836,7 +827,7 @@ private static 
CatalogPartition instantiateCatalogPartition(Partition hivePartit private void ensurePartitionedTable(ObjectPath tablePath, Table hiveTable) throws TableNotPartitionedException { if (hiveTable.getPartitionKeysSize() == 0) { - throw new TableNotPartitionedException(catalogName, tablePath); + throw new TableNotPartitionedException(getCatalogName(), tablePath); } } @@ -879,13 +870,13 @@ private List getOrderedFullPartitionValues(CatalogPartitionSpec partitio throws PartitionSpecInvalidException { Map spec = partitionSpec.getPartitionSpec(); if (spec.size() != partitionKeys.size()) { - throw new PartitionSpecInvalidException(catalogName, partitionKeys, tablePath, partitionSpec); + throw new PartitionSpecInvalidException(getCatalogName(), partitionKeys, tablePath, partitionSpec); } List values = new ArrayList<>(spec.size()); for (String key : partitionKeys) { if (!spec.containsKey(key)) { - throw new PartitionSpecInvalidException(catalogName, partitionKeys, tablePath, partitionSpec); + throw new PartitionSpecInvalidException(getCatalogName(), partitionKeys, tablePath, partitionSpec); } else { values.add(spec.get(key)); } @@ -927,10 +918,10 @@ public void createFunction(ObjectPath functionPath, CatalogFunction function, bo try { client.createFunction(hiveFunction); } catch (NoSuchObjectException e) { - throw new DatabaseNotExistException(catalogName, functionPath.getDatabaseName(), e); + throw new DatabaseNotExistException(getCatalogName(), functionPath.getDatabaseName(), e); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new FunctionAlreadyExistException(catalogName, functionPath, e); + throw new FunctionAlreadyExistException(getCatalogName(), functionPath, e); } } catch (TException e) { throw new CatalogException( @@ -986,7 +977,7 @@ public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists) client.dropFunction(functionPath.getDatabaseName(), functionPath.getObjectName()); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new FunctionNotExistException(catalogName, functionPath, e); + throw new FunctionNotExistException(getCatalogName(), functionPath, e); } } catch (TException e) { throw new CatalogException( @@ -1001,7 +992,7 @@ public List listFunctions(String databaseName) throws DatabaseNotExistEx // client.getFunctions() returns empty list when the database doesn't exist // thus we need to explicitly check whether the database exists or not if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } try { @@ -1033,7 +1024,7 @@ public CatalogFunction getFunction(ObjectPath functionPath) throws FunctionNotEx return new HiveCatalogFunction(function.getClassName()); } } catch (NoSuchObjectException e) { - throw new FunctionNotExistException(catalogName, functionPath, e); + throw new FunctionNotExistException(getCatalogName(), functionPath, e); } catch (TException e) { throw new CatalogException( String.format("Failed to get function %s", functionPath.getFullName()), e); diff --git a/flink-table/flink-table-api-java/pom.xml b/flink-table/flink-table-api-java/pom.xml index f8a8fd7c31cbcc..d4eefd34d243fb 100644 --- a/flink-table/flink-table-api-java/pom.xml +++ b/flink-table/flink-table-api-java/pom.xml @@ -52,5 +52,10 @@ under the License. 
test-jar test + + + org.apache.flink + flink-test-utils-junit + diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java index d46a4238e4a8b3..6d028f07bcb542 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java @@ -47,15 +47,12 @@ /** * A generic catalog implementation that holds all meta objects in memory. */ -public class GenericInMemoryCatalog implements Catalog { +public class GenericInMemoryCatalog extends AbstractCatalog { public static final String FLINK_IS_GENERIC_KEY = "is_generic"; public static final String FLINK_IS_GENERIC_VALUE = "true"; private static final String DEFAULT_DB = "default"; - private final String defaultDatabase; - - private final String catalogName; private final Map databases; private final Map tables; private final Map functions; @@ -71,11 +68,8 @@ public GenericInMemoryCatalog(String name) { } public GenericInMemoryCatalog(String name, String defaultDatabase) { - checkArgument(!StringUtils.isNullOrWhitespaceOnly(name), "name cannot be null or empty"); - checkArgument(!StringUtils.isNullOrWhitespaceOnly(defaultDatabase), "defaultDatabase cannot be null or empty"); + super(name, defaultDatabase); - this.catalogName = name; - this.defaultDatabase = defaultDatabase; this.databases = new LinkedHashMap<>(); this.databases.put(defaultDatabase, new GenericCatalogDatabase(new HashMap<>(), "")); this.tables = new LinkedHashMap<>(); @@ -98,11 +92,6 @@ public void close() { // ------ databases ------ - @Override - public String getDefaultDatabase() { - return defaultDatabase; - } - @Override public void createDatabase(String databaseName, CatalogDatabase db, boolean ignoreIfExists) throws DatabaseAlreadyExistException { @@ -111,7 +100,7 @@ public void createDatabase(String databaseName, CatalogDatabase db, boolean igno if (databaseExists(databaseName)) { if (!ignoreIfExists) { - throw new DatabaseAlreadyExistException(catalogName, databaseName); + throw new DatabaseAlreadyExistException(getCatalogName(), databaseName); } } else { databases.put(databaseName, db.copy()); @@ -129,10 +118,10 @@ public void dropDatabase(String databaseName, boolean ignoreIfNotExists) if (isDatabaseEmpty(databaseName)) { databases.remove(databaseName); } else { - throw new DatabaseNotEmptyException(catalogName, databaseName); + throw new DatabaseNotEmptyException(getCatalogName(), databaseName); } } else if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } } @@ -161,7 +150,7 @@ public void alterDatabase(String databaseName, CatalogDatabase newDatabase, bool databases.put(databaseName, newDatabase.copy()); } else if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } } @@ -175,7 +164,7 @@ public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistE checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName)); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } else { return 
databases.get(databaseName).copy(); } @@ -197,12 +186,12 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig checkNotNull(table); if (!databaseExists(tablePath.getDatabaseName())) { - throw new DatabaseNotExistException(catalogName, tablePath.getDatabaseName()); + throw new DatabaseNotExistException(getCatalogName(), tablePath.getDatabaseName()); } if (tableExists(tablePath)) { if (!ignoreIfExists) { - throw new TableAlreadyExistException(catalogName, tablePath); + throw new TableAlreadyExistException(getCatalogName(), tablePath); } } else { tables.put(tablePath, table.copy()); @@ -232,7 +221,7 @@ public void alterTable(ObjectPath tablePath, CatalogBaseTable newTable, boolean tables.put(tablePath, newTable.copy()); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -251,7 +240,7 @@ public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists) throws Ta partitionStats.remove(tablePath); partitionColumnStats.remove(tablePath); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -265,7 +254,7 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor ObjectPath newPath = new ObjectPath(tablePath.getDatabaseName(), newTableName); if (tableExists(newPath)) { - throw new TableAlreadyExistException(catalogName, newPath); + throw new TableAlreadyExistException(getCatalogName(), newPath); } else { tables.put(newPath, tables.remove(tablePath)); @@ -295,7 +284,7 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor } } } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -304,7 +293,7 @@ public List listTables(String databaseName) throws DatabaseNotExistExcep checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } return tables.keySet().stream() @@ -317,7 +306,7 @@ public List listViews(String databaseName) throws DatabaseNotExistExcept checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } return tables.keySet().stream() @@ -331,7 +320,7 @@ public CatalogBaseTable getTable(ObjectPath tablePath) throws TableNotExistExcep checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } else { return tables.get(tablePath).copy(); } @@ -346,7 +335,7 @@ public boolean tableExists(ObjectPath tablePath) { private void ensureTableExists(ObjectPath tablePath) throws TableNotExistException { if (!tableExists(tablePath)) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -359,12 +348,12 @@ public void createFunction(ObjectPath functionPath, CatalogFunction function, bo checkNotNull(function); if 
(!databaseExists(functionPath.getDatabaseName())) { - throw new DatabaseNotExistException(catalogName, functionPath.getDatabaseName()); + throw new DatabaseNotExistException(getCatalogName(), functionPath.getDatabaseName()); } if (functionExists(functionPath)) { if (!ignoreIfExists) { - throw new FunctionAlreadyExistException(catalogName, functionPath); + throw new FunctionAlreadyExistException(getCatalogName(), functionPath); } } else { functions.put(functionPath, function.copy()); @@ -389,7 +378,7 @@ public void alterFunction(ObjectPath functionPath, CatalogFunction newFunction, functions.put(functionPath, newFunction.copy()); } else if (!ignoreIfNotExists) { - throw new FunctionNotExistException(catalogName, functionPath); + throw new FunctionNotExistException(getCatalogName(), functionPath); } } @@ -400,7 +389,7 @@ public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists) thr if (functionExists(functionPath)) { functions.remove(functionPath); } else if (!ignoreIfNotExists) { - throw new FunctionNotExistException(catalogName, functionPath); + throw new FunctionNotExistException(getCatalogName(), functionPath); } } @@ -409,7 +398,7 @@ public List listFunctions(String databaseName) throws DatabaseNotExistEx checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(catalogName, databaseName); + throw new DatabaseNotExistException(getCatalogName(), databaseName); } return functions.keySet().stream() @@ -422,7 +411,7 @@ public CatalogFunction getFunction(ObjectPath functionPath) throws FunctionNotEx checkNotNull(functionPath); if (!functionExists(functionPath)) { - throw new FunctionNotExistException(catalogName, functionPath); + throw new FunctionNotExistException(getCatalogName(), functionPath); } else { return functions.get(functionPath).copy(); } @@ -449,7 +438,7 @@ public void createPartition(ObjectPath tablePath, CatalogPartitionSpec partition if (partitionExists(tablePath, partitionSpec)) { if (!ignoreIfExists) { - throw new PartitionAlreadyExistsException(catalogName, tablePath, partitionSpec); + throw new PartitionAlreadyExistsException(getCatalogName(), tablePath, partitionSpec); } } @@ -467,7 +456,7 @@ public void dropPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSp partitionStats.get(tablePath).remove(partitionSpec); partitionColumnStats.get(tablePath).remove(partitionSpec); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } } @@ -490,7 +479,7 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS partitions.get(tablePath).put(partitionSpec, newPartition.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } } @@ -532,7 +521,7 @@ public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } return partitions.get(tablePath).get(partitionSpec).copy(); @@ -550,7 +539,7 @@ public boolean partitionExists(ObjectPath tablePath, CatalogPartitionSpec 
partit private void ensureFullPartitionSpec(ObjectPath tablePath, CatalogPartitionSpec partitionSpec) throws TableNotExistException, PartitionSpecInvalidException { if (!isFullPartitionSpec(tablePath, partitionSpec)) { - throw new PartitionSpecInvalidException(catalogName, ((CatalogTable) getTable(tablePath)).getPartitionKeys(), + throw new PartitionSpecInvalidException(getCatalogName(), ((CatalogTable) getTable(tablePath)).getPartitionKeys(), tablePath, partitionSpec); } } @@ -575,7 +564,7 @@ private boolean isFullPartitionSpec(ObjectPath tablePath, CatalogPartitionSpec p private void ensurePartitionedTable(ObjectPath tablePath) throws TableNotPartitionedException { if (!isPartitionedTable(tablePath)) { - throw new TableNotPartitionedException(catalogName, tablePath); + throw new TableNotPartitionedException(getCatalogName(), tablePath); } } @@ -601,7 +590,7 @@ public CatalogTableStatistics getTableStatistics(ObjectPath tablePath) throws Ta checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } CatalogTableStatistics result = tableStats.get(tablePath); @@ -613,7 +602,7 @@ public CatalogColumnStatistics getTableColumnStatistics(ObjectPath tablePath) th checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } CatalogColumnStatistics result = tableColumnStats.get(tablePath); @@ -627,7 +616,7 @@ public CatalogTableStatistics getPartitionStatistics(ObjectPath tablePath, Catal checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } CatalogTableStatistics result = partitionStats.get(tablePath).get(partitionSpec); @@ -641,7 +630,7 @@ public CatalogColumnStatistics getPartitionColumnStatistics(ObjectPath tablePath checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } CatalogColumnStatistics result = partitionColumnStats.get(tablePath).get(partitionSpec); @@ -657,7 +646,7 @@ public void alterTableStatistics(ObjectPath tablePath, CatalogTableStatistics ta if (tableExists(tablePath)) { tableStats.put(tablePath, tableStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -670,7 +659,7 @@ public void alterTableColumnStatistics(ObjectPath tablePath, CatalogColumnStatis if (tableExists(tablePath)) { tableColumnStats.put(tablePath, columnStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(catalogName, tablePath); + throw new TableNotExistException(getCatalogName(), tablePath); } } @@ -684,7 +673,7 @@ public void alterPartitionStatistics(ObjectPath tablePath, CatalogPartitionSpec if (partitionExists(tablePath, partitionSpec)) { partitionStats.get(tablePath).put(partitionSpec, partitionStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } } @@ -698,7 +687,7 @@ 
public void alterPartitionColumnStatistics(ObjectPath tablePath, CatalogPartitio if (partitionExists(tablePath, partitionSpec)) { partitionColumnStats.get(tablePath).put(partitionSpec, columnStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(catalogName, tablePath, partitionSpec); + throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java new file mode 100644 index 00000000000000..8d4d95749446e3 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.util.StringUtils; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * Abstract base class for catalog implementations. It holds the catalog name and the default database name, + * and validates that neither of them is null or empty. + */ +@PublicEvolving +public abstract class AbstractCatalog implements Catalog { + private final String catalogName; + private final String defaultDatabase; + + public AbstractCatalog(String catalogName, String defaultDatabase) { + checkArgument(!StringUtils.isNullOrWhitespaceOnly(catalogName), "catalogName cannot be null or empty"); + checkArgument(!StringUtils.isNullOrWhitespaceOnly(defaultDatabase), "defaultDatabase cannot be null or empty"); + + this.catalogName = catalogName; + this.defaultDatabase = defaultDatabase; + } + + public String getCatalogName() { + return catalogName; + } + + @Override + public String getDefaultDatabase() { + return defaultDatabase; + } +} From c691c1381fe486068a3beb6bab38f3a29b1cc255 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Wed, 29 May 2019 12:20:18 -0700 Subject: [PATCH 43/92] [FLINK-12676][table][sql-client] Add descriptor, validator, and factory of GenericInMemoryCatalog for table discovery service This closes #8567.
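
A rough usage sketch (not part of this patch) of how the new pieces are meant to fit together: the descriptor serializes the catalog definition into string properties, the validator checks them inside the factory, and the factory turns validated properties into a catalog instance. The toProperties() call is assumed from the common Descriptor contract rather than shown in this patch, and the catalog/database names below are the ones used in the test YAML.

import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.catalog.GenericInMemoryCatalogFactory;
import org.apache.flink.table.descriptors.GenericInMemoryCatalogDescriptor;

import java.util.Map;

public class InMemoryCatalogWiringSketch {
	public static void main(String[] args) {
		// Descriptor -> properties: catalog type "generic_in_memory", property version,
		// and the configured default database (assumed to come from toProperties()).
		Map<String, String> properties =
			new GenericInMemoryCatalogDescriptor("mydatabase").toProperties();

		// Factory: validates the properties and instantiates the catalog.
		Catalog catalog =
			new GenericInMemoryCatalogFactory().createCatalog("inmemorycatalog", properties);

		System.out.println(catalog.getDefaultDatabase()); // expected to print "mydatabase"
	}
}

This roughly mirrors what the SQL Client does when it discovers catalog factories through the META-INF/services entry added in this patch and builds catalogs from the "catalogs" section of its YAML configuration.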
--- .../flink/table/catalog/hive/HiveCatalog.java | 62 +++++++-------- .../client/gateway/local/DependencyTest.java | 11 +-- .../gateway/local/ExecutionContextTest.java | 4 +- .../resources/test-sql-client-catalogs.yaml | 9 ++- .../table/catalog/GenericInMemoryCatalog.java | 72 +++++++++--------- .../GenericInMemoryCatalogFactory.java | 76 +++++++++++++++++++ .../GenericInMemoryCatalogDescriptor.java | 44 +++++++++++ .../GenericInMemoryCatalogValidator.java | 32 ++++++++ ....apache.flink.table.factories.TableFactory | 16 ++++ .../GenericInMemoryCatalogFactoryTest.java | 67 ++++++++++++++++ .../GenericInMemoryCatalogDescriptorTest.java | 66 ++++++++++++++++ .../flink/table/catalog/AbstractCatalog.java | 11 ++- .../table/descriptors/CatalogDescriptor.java | 4 + 13 files changed, 386 insertions(+), 88 deletions(-) create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactory.java create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogDescriptor.java create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogValidator.java create mode 100644 flink-table/flink-table-api-java/src/main/resources/META-INF/services/org.apache.flink.table.factories.TableFactory create mode 100644 flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactoryTest.java create mode 100644 flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/descriptor/GenericInMemoryCatalogDescriptorTest.java diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index 4081296c7e9624..8022c453084949 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -155,7 +155,7 @@ public void open() throws CatalogException { if (!databaseExists(getDefaultDatabase())) { throw new CatalogException(String.format("Configured default database %s doesn't exist in catalog %s.", - getDefaultDatabase(), getCatalogName())); + getDefaultDatabase(), getName())); } } @@ -192,7 +192,7 @@ public void createDatabase(String databaseName, CatalogDatabase database, boolea client.createDatabase(hiveDatabase); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new DatabaseAlreadyExistException(getCatalogName(), hiveDatabase.getName()); + throw new DatabaseAlreadyExistException(getName(), hiveDatabase.getName()); } } catch (TException e) { throw new CatalogException(String.format("Failed to create database %s", hiveDatabase.getName()), e); @@ -259,7 +259,7 @@ public List listDatabases() throws CatalogException { return client.getAllDatabases(); } catch (TException e) { throw new CatalogException( - String.format("Failed to list all databases in %s", getCatalogName()), e); + String.format("Failed to list all databases in %s", getName()), e); } } @@ -282,10 +282,10 @@ public void dropDatabase(String name, boolean ignoreIfNotExists) throws Database client.dropDatabase(name, true, ignoreIfNotExists); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(getCatalogName(), name); + throw new 
DatabaseNotExistException(getName(), name); } } catch (InvalidOperationException e) { - throw new DatabaseNotEmptyException(getCatalogName(), name); + throw new DatabaseNotEmptyException(getName(), name); } catch (TException e) { throw new CatalogException(String.format("Failed to drop database %s", name), e); } @@ -295,10 +295,10 @@ private Database getHiveDatabase(String databaseName) throws DatabaseNotExistExc try { return client.getDatabase(databaseName); } catch (NoSuchObjectException e) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } catch (TException e) { throw new CatalogException( - String.format("Failed to get database %s from %s", databaseName, getCatalogName()), e); + String.format("Failed to get database %s from %s", databaseName, getName()), e); } } @@ -319,7 +319,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig checkNotNull(table, "table cannot be null"); if (!databaseExists(tablePath.getDatabaseName())) { - throw new DatabaseNotExistException(getCatalogName(), tablePath.getDatabaseName()); + throw new DatabaseNotExistException(getName(), tablePath.getDatabaseName()); } Table hiveTable = instantiateHiveTable(tablePath, table); @@ -328,7 +328,7 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig client.createTable(hiveTable); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new TableAlreadyExistException(getCatalogName(), tablePath); + throw new TableAlreadyExistException(getName(), tablePath); } } catch (TException e) { throw new CatalogException(String.format("Failed to create table %s", tablePath.getFullName()), e); @@ -349,14 +349,14 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor // alter_table() doesn't throw a clear exception when new table already exists. 
// Thus, check the table existence explicitly if (tableExists(newPath)) { - throw new TableAlreadyExistException(getCatalogName(), newPath); + throw new TableAlreadyExistException(getName(), newPath); } else { Table table = getHiveTable(tablePath); table.setTableName(newTableName); client.alter_table(tablePath.getDatabaseName(), tablePath.getObjectName(), table); } } else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } catch (TException e) { throw new CatalogException( @@ -417,7 +417,7 @@ public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists) throws Ta ignoreIfNotExists); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } catch (TException e) { throw new CatalogException( @@ -432,7 +432,7 @@ public List listTables(String databaseName) throws DatabaseNotExistExcep try { return client.getAllTables(databaseName); } catch (UnknownDBException e) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } catch (TException e) { throw new CatalogException( String.format("Failed to list tables in database %s", databaseName), e); @@ -449,7 +449,7 @@ public List listViews(String databaseName) throws DatabaseNotExistExcept null, // table pattern TableType.VIRTUAL_VIEW); } catch (UnknownDBException e) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } catch (TException e) { throw new CatalogException( String.format("Failed to list views in database %s", databaseName), e); @@ -475,7 +475,7 @@ Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { try { return client.getTable(tablePath.getDatabaseName(), tablePath.getObjectName()); } catch (NoSuchObjectException e) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } catch (TException e) { throw new CatalogException( String.format("Failed to get table %s from Hive metastore", tablePath.getFullName()), e); @@ -652,7 +652,7 @@ public void createPartition(ObjectPath tablePath, CatalogPartitionSpec partition client.add_partition(instantiateHivePartition(hiveTable, partitionSpec, partition)); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new PartitionAlreadyExistsException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionAlreadyExistsException(getName(), tablePath, partitionSpec); } } catch (TException e) { throw new CatalogException( @@ -672,10 +672,10 @@ public void dropPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSp getOrderedFullPartitionValues(partitionSpec, getFieldNames(hiveTable.getPartitionKeys()), tablePath), true); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec, e); } } catch (MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to drop partition %s 
of table %s", partitionSpec, tablePath)); @@ -732,7 +732,7 @@ public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec Partition hivePartition = getHivePartition(tablePath, partitionSpec); return instantiateCatalogPartition(hivePartition); } catch (NoSuchObjectException | MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to get partition %s of table %s", partitionSpec, tablePath), e); @@ -760,7 +760,7 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS if (ignoreIfNotExists) { return; } - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } Partition newHivePartition = instantiateHivePartition(hiveTable, partitionSpec, newPartition); if (newHivePartition.getSd().getLocation() == null) { @@ -773,10 +773,10 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS ); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec, e); } } catch (InvalidOperationException | MetaException | TableNotExistException | PartitionSpecInvalidException e) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec, e); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec, e); } catch (TException e) { throw new CatalogException( String.format("Failed to alter existing partition with new partition %s of table %s", @@ -803,7 +803,7 @@ private Partition instantiateHivePartition(Table hiveTable, CatalogPartitionSpec // validate partition values for (int i = 0; i < partCols.size(); i++) { if (StringUtils.isNullOrWhitespaceOnly(partValues.get(i))) { - throw new PartitionSpecInvalidException(getCatalogName(), partCols, + throw new PartitionSpecInvalidException(getName(), partCols, new ObjectPath(hiveTable.getDbName(), hiveTable.getTableName()), partitionSpec); } } @@ -827,7 +827,7 @@ private static CatalogPartition instantiateCatalogPartition(Partition hivePartit private void ensurePartitionedTable(ObjectPath tablePath, Table hiveTable) throws TableNotPartitionedException { if (hiveTable.getPartitionKeysSize() == 0) { - throw new TableNotPartitionedException(getCatalogName(), tablePath); + throw new TableNotPartitionedException(getName(), tablePath); } } @@ -870,13 +870,13 @@ private List getOrderedFullPartitionValues(CatalogPartitionSpec partitio throws PartitionSpecInvalidException { Map spec = partitionSpec.getPartitionSpec(); if (spec.size() != partitionKeys.size()) { - throw new PartitionSpecInvalidException(getCatalogName(), partitionKeys, tablePath, partitionSpec); + throw new PartitionSpecInvalidException(getName(), partitionKeys, tablePath, partitionSpec); } List values = new ArrayList<>(spec.size()); for (String key : partitionKeys) { if (!spec.containsKey(key)) { - throw new PartitionSpecInvalidException(getCatalogName(), partitionKeys, tablePath, partitionSpec); + throw new PartitionSpecInvalidException(getName(), partitionKeys, tablePath, partitionSpec); } else { values.add(spec.get(key)); } @@ -918,10 +918,10 @@ public void 
createFunction(ObjectPath functionPath, CatalogFunction function, bo try { client.createFunction(hiveFunction); } catch (NoSuchObjectException e) { - throw new DatabaseNotExistException(getCatalogName(), functionPath.getDatabaseName(), e); + throw new DatabaseNotExistException(getName(), functionPath.getDatabaseName(), e); } catch (AlreadyExistsException e) { if (!ignoreIfExists) { - throw new FunctionAlreadyExistException(getCatalogName(), functionPath, e); + throw new FunctionAlreadyExistException(getName(), functionPath, e); } } catch (TException e) { throw new CatalogException( @@ -977,7 +977,7 @@ public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists) client.dropFunction(functionPath.getDatabaseName(), functionPath.getObjectName()); } catch (NoSuchObjectException e) { if (!ignoreIfNotExists) { - throw new FunctionNotExistException(getCatalogName(), functionPath, e); + throw new FunctionNotExistException(getName(), functionPath, e); } } catch (TException e) { throw new CatalogException( @@ -992,7 +992,7 @@ public List listFunctions(String databaseName) throws DatabaseNotExistEx // client.getFunctions() returns empty list when the database doesn't exist // thus we need to explicitly check whether the database exists or not if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } try { @@ -1024,7 +1024,7 @@ public CatalogFunction getFunction(ObjectPath functionPath) throws FunctionNotEx return new HiveCatalogFunction(function.getClassName()); } } catch (NoSuchObjectException e) { - throw new FunctionNotExistException(getCatalogName(), functionPath, e); + throw new FunctionNotExistException(getName(), functionPath, e); } catch (TException e) { throw new CatalogException( String.format("Failed to get function %s", functionPath.getFullName()), e); diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java index 7730b0dab35164..c3d0ba0735ec7f 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/DependencyTest.java @@ -138,11 +138,7 @@ public Catalog createCatalog(String name, Map properties) { final Optional defaultDatabase = params.getOptionalString(CATALOG_DEFAULT_DATABASE); - if (defaultDatabase.isPresent()) { - return new TestCatalog(name, defaultDatabase.get()); - } else { - return new TestCatalog(name); - } + return new TestCatalog(name, defaultDatabase.orElse(GenericInMemoryCatalog.DEFAULT_DB)); } } @@ -150,11 +146,6 @@ public Catalog createCatalog(String name, Map properties) { * Test catalog. 
*/ public static class TestCatalog extends GenericInMemoryCatalog { - - public TestCatalog(String name) { - super(name); - } - public TestCatalog(String name, String defaultDatabase) { super(name, defaultDatabase); } diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java index 3783f37b284ccd..5137f6316a77b0 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java @@ -73,12 +73,12 @@ public void testExecutionConfig() throws Exception { @Test public void testCatalogs() throws Exception { - final String catalogName = "catalog2"; + final String catalogName = "inmemorycatalog"; final ExecutionContext context = createCatalogExecutionContext(); final TableEnvironment tableEnv = context.createEnvironmentInstance().getTableEnvironment(); assertEquals(tableEnv.getCurrentCatalog(), catalogName); - assertEquals(tableEnv.getCurrentDatabase(), "test-default-database"); + assertEquals(tableEnv.getCurrentDatabase(), "mydatabase"); } @Test diff --git a/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml index 324ae38e1f0949..a1b2dcbc6341d2 100644 --- a/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml +++ b/flink-table/flink-sql-client/src/test/resources/test-sql-client-catalogs.yaml @@ -112,8 +112,8 @@ execution: max-failures-per-interval: 10 failure-rate-interval: 99000 delay: 1000 - current-catalog: catalog2 - current-database: test-default-database + current-catalog: inmemorycatalog + current-database: mydatabase deployment: response-timeout: 5000 @@ -123,4 +123,7 @@ catalogs: type: DependencyTest - name: catalog2 type: DependencyTest - default-database: test-default-database + default-database: mydatabase + - name: inmemorycatalog + type: generic_in_memory + default-database: mydatabase diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java index 6d028f07bcb542..e625072f0511df 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalog.java @@ -51,7 +51,7 @@ public class GenericInMemoryCatalog extends AbstractCatalog { public static final String FLINK_IS_GENERIC_KEY = "is_generic"; public static final String FLINK_IS_GENERIC_VALUE = "true"; - private static final String DEFAULT_DB = "default"; + public static final String DEFAULT_DB = "default"; private final Map databases; private final Map tables; @@ -100,7 +100,7 @@ public void createDatabase(String databaseName, CatalogDatabase db, boolean igno if (databaseExists(databaseName)) { if (!ignoreIfExists) { - throw new DatabaseAlreadyExistException(getCatalogName(), databaseName); + throw new DatabaseAlreadyExistException(getName(), databaseName); } } else { databases.put(databaseName, db.copy()); @@ -118,10 +118,10 @@ public void dropDatabase(String databaseName, boolean ignoreIfNotExists) if 
(isDatabaseEmpty(databaseName)) { databases.remove(databaseName); } else { - throw new DatabaseNotEmptyException(getCatalogName(), databaseName); + throw new DatabaseNotEmptyException(getName(), databaseName); } } else if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } } @@ -150,7 +150,7 @@ public void alterDatabase(String databaseName, CatalogDatabase newDatabase, bool databases.put(databaseName, newDatabase.copy()); } else if (!ignoreIfNotExists) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } } @@ -164,7 +164,7 @@ public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistE checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName)); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } else { return databases.get(databaseName).copy(); } @@ -186,12 +186,12 @@ public void createTable(ObjectPath tablePath, CatalogBaseTable table, boolean ig checkNotNull(table); if (!databaseExists(tablePath.getDatabaseName())) { - throw new DatabaseNotExistException(getCatalogName(), tablePath.getDatabaseName()); + throw new DatabaseNotExistException(getName(), tablePath.getDatabaseName()); } if (tableExists(tablePath)) { if (!ignoreIfExists) { - throw new TableAlreadyExistException(getCatalogName(), tablePath); + throw new TableAlreadyExistException(getName(), tablePath); } } else { tables.put(tablePath, table.copy()); @@ -221,7 +221,7 @@ public void alterTable(ObjectPath tablePath, CatalogBaseTable newTable, boolean tables.put(tablePath, newTable.copy()); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -240,7 +240,7 @@ public void dropTable(ObjectPath tablePath, boolean ignoreIfNotExists) throws Ta partitionStats.remove(tablePath); partitionColumnStats.remove(tablePath); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -254,7 +254,7 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor ObjectPath newPath = new ObjectPath(tablePath.getDatabaseName(), newTableName); if (tableExists(newPath)) { - throw new TableAlreadyExistException(getCatalogName(), newPath); + throw new TableAlreadyExistException(getName(), newPath); } else { tables.put(newPath, tables.remove(tablePath)); @@ -284,7 +284,7 @@ public void renameTable(ObjectPath tablePath, String newTableName, boolean ignor } } } else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -293,7 +293,7 @@ public List listTables(String databaseName) throws DatabaseNotExistExcep checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } return tables.keySet().stream() @@ -306,7 +306,7 @@ public List listViews(String databaseName) throws DatabaseNotExistExcept checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be 
null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } return tables.keySet().stream() @@ -320,7 +320,7 @@ public CatalogBaseTable getTable(ObjectPath tablePath) throws TableNotExistExcep checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } else { return tables.get(tablePath).copy(); } @@ -335,7 +335,7 @@ public boolean tableExists(ObjectPath tablePath) { private void ensureTableExists(ObjectPath tablePath) throws TableNotExistException { if (!tableExists(tablePath)) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -348,12 +348,12 @@ public void createFunction(ObjectPath functionPath, CatalogFunction function, bo checkNotNull(function); if (!databaseExists(functionPath.getDatabaseName())) { - throw new DatabaseNotExistException(getCatalogName(), functionPath.getDatabaseName()); + throw new DatabaseNotExistException(getName(), functionPath.getDatabaseName()); } if (functionExists(functionPath)) { if (!ignoreIfExists) { - throw new FunctionAlreadyExistException(getCatalogName(), functionPath); + throw new FunctionAlreadyExistException(getName(), functionPath); } } else { functions.put(functionPath, function.copy()); @@ -378,7 +378,7 @@ public void alterFunction(ObjectPath functionPath, CatalogFunction newFunction, functions.put(functionPath, newFunction.copy()); } else if (!ignoreIfNotExists) { - throw new FunctionNotExistException(getCatalogName(), functionPath); + throw new FunctionNotExistException(getName(), functionPath); } } @@ -389,7 +389,7 @@ public void dropFunction(ObjectPath functionPath, boolean ignoreIfNotExists) thr if (functionExists(functionPath)) { functions.remove(functionPath); } else if (!ignoreIfNotExists) { - throw new FunctionNotExistException(getCatalogName(), functionPath); + throw new FunctionNotExistException(getName(), functionPath); } } @@ -398,7 +398,7 @@ public List listFunctions(String databaseName) throws DatabaseNotExistEx checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); if (!databaseExists(databaseName)) { - throw new DatabaseNotExistException(getCatalogName(), databaseName); + throw new DatabaseNotExistException(getName(), databaseName); } return functions.keySet().stream() @@ -411,7 +411,7 @@ public CatalogFunction getFunction(ObjectPath functionPath) throws FunctionNotEx checkNotNull(functionPath); if (!functionExists(functionPath)) { - throw new FunctionNotExistException(getCatalogName(), functionPath); + throw new FunctionNotExistException(getName(), functionPath); } else { return functions.get(functionPath).copy(); } @@ -438,7 +438,7 @@ public void createPartition(ObjectPath tablePath, CatalogPartitionSpec partition if (partitionExists(tablePath, partitionSpec)) { if (!ignoreIfExists) { - throw new PartitionAlreadyExistsException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionAlreadyExistsException(getName(), tablePath, partitionSpec); } } @@ -456,7 +456,7 @@ public void dropPartition(ObjectPath tablePath, CatalogPartitionSpec partitionSp partitionStats.get(tablePath).remove(partitionSpec); partitionColumnStats.get(tablePath).remove(partitionSpec); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), 
tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } } @@ -479,7 +479,7 @@ public void alterPartition(ObjectPath tablePath, CatalogPartitionSpec partitionS partitions.get(tablePath).put(partitionSpec, newPartition.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } } @@ -521,7 +521,7 @@ public CatalogPartition getPartition(ObjectPath tablePath, CatalogPartitionSpec checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } return partitions.get(tablePath).get(partitionSpec).copy(); @@ -539,7 +539,7 @@ public boolean partitionExists(ObjectPath tablePath, CatalogPartitionSpec partit private void ensureFullPartitionSpec(ObjectPath tablePath, CatalogPartitionSpec partitionSpec) throws TableNotExistException, PartitionSpecInvalidException { if (!isFullPartitionSpec(tablePath, partitionSpec)) { - throw new PartitionSpecInvalidException(getCatalogName(), ((CatalogTable) getTable(tablePath)).getPartitionKeys(), + throw new PartitionSpecInvalidException(getName(), ((CatalogTable) getTable(tablePath)).getPartitionKeys(), tablePath, partitionSpec); } } @@ -564,7 +564,7 @@ private boolean isFullPartitionSpec(ObjectPath tablePath, CatalogPartitionSpec p private void ensurePartitionedTable(ObjectPath tablePath) throws TableNotPartitionedException { if (!isPartitionedTable(tablePath)) { - throw new TableNotPartitionedException(getCatalogName(), tablePath); + throw new TableNotPartitionedException(getName(), tablePath); } } @@ -590,7 +590,7 @@ public CatalogTableStatistics getTableStatistics(ObjectPath tablePath) throws Ta checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } CatalogTableStatistics result = tableStats.get(tablePath); @@ -602,7 +602,7 @@ public CatalogColumnStatistics getTableColumnStatistics(ObjectPath tablePath) th checkNotNull(tablePath); if (!tableExists(tablePath)) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } CatalogColumnStatistics result = tableColumnStats.get(tablePath); @@ -616,7 +616,7 @@ public CatalogTableStatistics getPartitionStatistics(ObjectPath tablePath, Catal checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } CatalogTableStatistics result = partitionStats.get(tablePath).get(partitionSpec); @@ -630,7 +630,7 @@ public CatalogColumnStatistics getPartitionColumnStatistics(ObjectPath tablePath checkNotNull(partitionSpec); if (!partitionExists(tablePath, partitionSpec)) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } CatalogColumnStatistics result = partitionColumnStats.get(tablePath).get(partitionSpec); @@ -646,7 +646,7 @@ public void alterTableStatistics(ObjectPath tablePath, CatalogTableStatistics ta if (tableExists(tablePath)) { tableStats.put(tablePath, tableStatistics.copy()); } 
else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -659,7 +659,7 @@ public void alterTableColumnStatistics(ObjectPath tablePath, CatalogColumnStatis if (tableExists(tablePath)) { tableColumnStats.put(tablePath, columnStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new TableNotExistException(getCatalogName(), tablePath); + throw new TableNotExistException(getName(), tablePath); } } @@ -673,7 +673,7 @@ public void alterPartitionStatistics(ObjectPath tablePath, CatalogPartitionSpec if (partitionExists(tablePath, partitionSpec)) { partitionStats.get(tablePath).put(partitionSpec, partitionStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } } @@ -687,7 +687,7 @@ public void alterPartitionColumnStatistics(ObjectPath tablePath, CatalogPartitio if (partitionExists(tablePath, partitionSpec)) { partitionColumnStats.get(tablePath).put(partitionSpec, columnStatistics.copy()); } else if (!ignoreIfNotExists) { - throw new PartitionNotExistException(getCatalogName(), tablePath, partitionSpec); + throw new PartitionNotExistException(getName(), tablePath, partitionSpec); } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactory.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactory.java new file mode 100644 index 00000000000000..ce153938cf4161 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactory.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.table.descriptors.DescriptorProperties; +import org.apache.flink.table.descriptors.GenericInMemoryCatalogValidator; +import org.apache.flink.table.factories.CatalogFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_DEFAULT_DATABASE; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; +import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; +import static org.apache.flink.table.descriptors.GenericInMemoryCatalogValidator.CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY; + +/** + * Catalog factory for {@link GenericInMemoryCatalog}. 
+ */ +public class GenericInMemoryCatalogFactory implements CatalogFactory { + + @Override + public Map requiredContext() { + Map context = new HashMap<>(); + context.put(CATALOG_TYPE, CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY); // generic_in_memory + context.put(CATALOG_PROPERTY_VERSION, "1"); // backwards compatibility + return context; + } + + @Override + public List supportedProperties() { + List properties = new ArrayList<>(); + + // default database + properties.add(CATALOG_DEFAULT_DATABASE); + + return properties; + } + + @Override + public Catalog createCatalog(String name, Map properties) { + final DescriptorProperties descriptorProperties = getValidatedProperties(properties); + + final Optional defaultDatabase = descriptorProperties.getOptionalString(CATALOG_DEFAULT_DATABASE); + + return new GenericInMemoryCatalog(name, defaultDatabase.orElse(GenericInMemoryCatalog.DEFAULT_DB)); + } + + private static DescriptorProperties getValidatedProperties(Map properties) { + final DescriptorProperties descriptorProperties = new DescriptorProperties(true); + descriptorProperties.putProperties(properties); + + new GenericInMemoryCatalogValidator().validate(descriptorProperties); + + return descriptorProperties; + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogDescriptor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogDescriptor.java new file mode 100644 index 00000000000000..03b06eb16013f3 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogDescriptor.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptors; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.descriptors.GenericInMemoryCatalogValidator.CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY; + +/** + * Catalog descriptor for the generic in memory catalog. 
+ */ +public class GenericInMemoryCatalogDescriptor extends CatalogDescriptor { + + public GenericInMemoryCatalogDescriptor() { + super(CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY, 1); + } + + public GenericInMemoryCatalogDescriptor(String defaultDatabase) { + super(CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY, 1, defaultDatabase); + } + + @Override + protected Map toCatalogProperties() { + return Collections.unmodifiableMap(new HashMap<>()); + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogValidator.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogValidator.java new file mode 100644 index 00000000000000..22a54ff2bca39f --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/descriptors/GenericInMemoryCatalogValidator.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptors; + +/** + * Validator for {@link GenericInMemoryCatalogDescriptor}. + */ +public class GenericInMemoryCatalogValidator extends CatalogDescriptorValidator { + public static final String CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY = "generic_in_memory"; + + @Override + public void validate(DescriptorProperties properties) { + super.validate(properties); + properties.validateValue(CATALOG_TYPE, CATALOG_TYPE_VALUE_GENERIC_IN_MEMORY, false); + } +} diff --git a/flink-table/flink-table-api-java/src/main/resources/META-INF/services/org.apache.flink.table.factories.TableFactory b/flink-table/flink-table-api-java/src/main/resources/META-INF/services/org.apache.flink.table.factories.TableFactory new file mode 100644 index 00000000000000..81b51d431f1fef --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/resources/META-INF/services/org.apache.flink.table.factories.TableFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +org.apache.flink.table.catalog.GenericInMemoryCatalogFactory diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactoryTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactoryTest.java new file mode 100644 index 00000000000000..eb3d61512c375c --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/GenericInMemoryCatalogFactoryTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.table.descriptors.CatalogDescriptor; +import org.apache.flink.table.descriptors.GenericInMemoryCatalogDescriptor; +import org.apache.flink.table.factories.CatalogFactory; +import org.apache.flink.table.factories.TableFactoryService; +import org.apache.flink.util.TestLogger; + +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** + * Test for {@link GenericInMemoryCatalog} created by {@link GenericInMemoryCatalogFactory}. 
+ */ +public class GenericInMemoryCatalogFactoryTest extends TestLogger { + + @Test + public void test() throws Exception { + final String catalogName = "mycatalog"; + final String databaseName = "mydatabase"; + + final GenericInMemoryCatalog expectedCatalog = new GenericInMemoryCatalog(catalogName, databaseName); + + final CatalogDescriptor catalogDescriptor = new GenericInMemoryCatalogDescriptor(databaseName); + + final Map properties = catalogDescriptor.toProperties(); + + final Catalog actualCatalog = TableFactoryService.find(CatalogFactory.class, properties) + .createCatalog(catalogName, properties); + + checkEquals(expectedCatalog, (GenericInMemoryCatalog) actualCatalog); + } + + private static void checkEquals(GenericInMemoryCatalog c1, GenericInMemoryCatalog c2) throws Exception { + // Only assert a few selected properties for now + assertEquals(c1.getName(), c2.getName()); + assertEquals(c1.getDefaultDatabase(), c2.getDefaultDatabase()); + assertEquals(c1.listDatabases(), c2.listDatabases()); + + final String database = c1.getDefaultDatabase(); + + assertEquals(c1.listTables(database), c2.listTables(database)); + assertEquals(c1.listViews(database), c2.listViews(database)); + assertEquals(c1.listFunctions(database), c2.listFunctions(database)); + } +} diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/descriptor/GenericInMemoryCatalogDescriptorTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/descriptor/GenericInMemoryCatalogDescriptorTest.java new file mode 100644 index 00000000000000..2b29eb947593bf --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/descriptor/GenericInMemoryCatalogDescriptorTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.descriptor; + +import org.apache.flink.table.descriptors.Descriptor; +import org.apache.flink.table.descriptors.DescriptorTestBase; +import org.apache.flink.table.descriptors.DescriptorValidator; +import org.apache.flink.table.descriptors.GenericInMemoryCatalogDescriptor; +import org.apache.flink.table.descriptors.GenericInMemoryCatalogValidator; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests for the {@link GenericInMemoryCatalogDescriptor} descriptor. 
+ */ +public class GenericInMemoryCatalogDescriptorTest extends DescriptorTestBase { + + private static final String TEST_DATABASE = "test"; + + @Override + protected List descriptors() { + final Descriptor withoutDefaultDB = new GenericInMemoryCatalogDescriptor(); + + final Descriptor withDefaultDB = new GenericInMemoryCatalogDescriptor(TEST_DATABASE); + + return Arrays.asList(withoutDefaultDB, withDefaultDB); + } + + @Override + protected List> properties() { + final Map props1 = new HashMap<>(); + props1.put("type", "generic_in_memory"); + props1.put("property-version", "1"); + + final Map props2 = new HashMap<>(); + props2.put("type", "generic_in_memory"); + props2.put("property-version", "1"); + props2.put("default-database", TEST_DATABASE); + + return Arrays.asList(props1, props2); + } + + @Override + protected DescriptorValidator validator() { + return new GenericInMemoryCatalogValidator(); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java index 8d4d95749446e3..7f5707eacc3013 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/AbstractCatalog.java @@ -24,23 +24,22 @@ import static org.apache.flink.util.Preconditions.checkArgument; /** - * This interface is responsible for reading and writing metadata such as database/table/views/UDFs - * from a registered catalog. It connects a registered catalog and Flink's Table API. + * Abstract class for catalogs. */ @PublicEvolving public abstract class AbstractCatalog implements Catalog { private final String catalogName; private final String defaultDatabase; - public AbstractCatalog(String catalogName, String defaultDatabase) { - checkArgument(!StringUtils.isNullOrWhitespaceOnly(catalogName), "catalogName cannot be null or empty"); + public AbstractCatalog(String name, String defaultDatabase) { + checkArgument(!StringUtils.isNullOrWhitespaceOnly(name), "name cannot be null or empty"); checkArgument(!StringUtils.isNullOrWhitespaceOnly(defaultDatabase), "defaultDatabase cannot be null or empty"); - this.catalogName = catalogName; + this.catalogName = name; this.defaultDatabase = defaultDatabase; } - public String getCatalogName() { + public String getName() { return catalogName; } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java index 753e63c0fb8580..0827f20080e690 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/CatalogDescriptor.java @@ -19,12 +19,14 @@ package org.apache.flink.table.descriptors; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.util.StringUtils; import java.util.Map; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_DEFAULT_DATABASE; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; import static org.apache.flink.table.descriptors.CatalogDescriptorValidator.CATALOG_TYPE; +import static org.apache.flink.util.Preconditions.checkArgument; /** * Describes a catalog of tables, views, and functions. 
@@ -56,6 +58,8 @@ public CatalogDescriptor(String type, int propertyVersion) { * @param defaultDatabase default database of the catalog */ public CatalogDescriptor(String type, int propertyVersion, String defaultDatabase) { + checkArgument(!StringUtils.isNullOrWhitespaceOnly(type), "type cannot be null or empty"); + this.type = type; this.propertyVersion = propertyVersion; this.defaultDatabase = defaultDatabase; From 038ab385c6f9af129b5eda7fe05d8b39d6122077 Mon Sep 17 00:00:00 2001 From: Rui Li Date: Wed, 29 May 2019 18:25:25 +0800 Subject: [PATCH 44/92] [FLINK-12649][hive] Add a shim layer to support multiple versions of Hive Metastore To add shim layer for HMS client, in order to support different versions of HMS. This closes #8564. --- flink-connectors/flink-connector-hive/pom.xml | 21 +- .../flink/table/catalog/hive/HiveCatalog.java | 28 +-- .../hive/HiveMetastoreClientFactory.java | 34 +++ .../hive/HiveMetastoreClientWrapper.java | 226 ++++++++++++++++++ .../flink/table/catalog/hive/HiveShim.java | 65 +++++ .../table/catalog/hive/HiveShimLoader.java | 57 +++++ .../flink/table/catalog/hive/HiveShimV1.java | 82 +++++++ .../flink/table/catalog/hive/HiveShimV2.java | 73 ++++++ 8 files changed, 562 insertions(+), 24 deletions(-) create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index 25e475efeee3f4..5c205d8e31aab6 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -83,6 +83,7 @@ under the License. org.apache.hive hive-metastore ${hive.version} + provided org.apache.hive @@ -385,7 +386,6 @@ under the License. commons-beanutils:commons-beanutils com.fasterxml.jackson.core:* com.jolbox:bonecp - org.apache.hive:* org.apache.thrift:libthrift org.datanucleus:* org.antlr:antlr-runtime @@ -423,4 +423,23 @@ under the License. 
+ + + + + hive-1.2.1 + + 1.2.1 + 2.6.0 + + + + javax.jdo + jdo-api + 3.0.1 + provided + + + + diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index 8022c453084949..562c1983e84b68 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -53,10 +53,7 @@ import org.apache.flink.util.StringUtils; import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.HiveMetaStoreClient; -import org.apache.hadoop.hive.metastore.IMetaStoreClient; import org.apache.hadoop.hive.metastore.MetaStoreUtils; -import org.apache.hadoop.hive.metastore.RetryingMetaStoreClient; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.AlreadyExistsException; import org.apache.hadoop.hive.metastore.api.Database; @@ -108,7 +105,7 @@ public class HiveCatalog extends AbstractCatalog { protected final HiveConf hiveConf; - protected IMetaStoreClient client; + protected HiveMetastoreClientWrapper client; public HiveCatalog(String catalogName, String hivemetastoreURI) { this(catalogName, DEFAULT_DB, getHiveConf(hivemetastoreURI)); @@ -133,23 +130,10 @@ private static HiveConf getHiveConf(String hiveMetastoreURI) { return hiveConf; } - private static IMetaStoreClient getMetastoreClient(HiveConf hiveConf) { - try { - return RetryingMetaStoreClient.getProxy( - hiveConf, - null, - null, - HiveMetaStoreClient.class.getName(), - true); - } catch (MetaException e) { - throw new CatalogException("Failed to create Hive metastore client", e); - } - } - @Override public void open() throws CatalogException { if (client == null) { - client = getMetastoreClient(hiveConf); + client = HiveMetastoreClientFactory.create(hiveConf); LOG.info("Connected to Hive metastore"); } @@ -444,10 +428,7 @@ public List listViews(String databaseName) throws DatabaseNotExistExcept checkArgument(!StringUtils.isNullOrWhitespaceOnly(databaseName), "databaseName cannot be null or empty"); try { - return client.getTables( - databaseName, - null, // table pattern - TableType.VIRTUAL_VIEW); + return client.getViews(databaseName); } catch (UnknownDBException e) { throw new DatabaseNotExistException(getName(), databaseName); } catch (TException e) { @@ -996,7 +977,8 @@ public List listFunctions(String databaseName) throws DatabaseNotExistEx } try { - return client.getFunctions(databaseName, null); + // hive-1.x requires the pattern not being null, so pass a pattern that matches any name + return client.getFunctions(databaseName, ".*"); } catch (TException e) { throw new CatalogException( String.format("Failed to list functions in database %s", databaseName), e); diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java new file mode 100644 index 00000000000000..be46552fb49d1b --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.hadoop.hive.conf.HiveConf; + +/** + * Factory to create Hive metastore client. + */ +public class HiveMetastoreClientFactory { + + private HiveMetastoreClientFactory() { + } + + public static HiveMetastoreClientWrapper create(HiveConf hiveConf) { + return new HiveMetastoreClientWrapper(hiveConf); + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java new file mode 100644 index 00000000000000..43937aba0c3564 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.util.Preconditions; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.metastore.api.AlreadyExistsException; +import org.apache.hadoop.hive.metastore.api.ColumnStatistics; +import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; +import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.Function; +import org.apache.hadoop.hive.metastore.api.InvalidInputException; +import org.apache.hadoop.hive.metastore.api.InvalidObjectException; +import org.apache.hadoop.hive.metastore.api.InvalidOperationException; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; +import org.apache.hadoop.hive.metastore.api.Partition; +import org.apache.hadoop.hive.metastore.api.Table; +import org.apache.hadoop.hive.metastore.api.UnknownDBException; +import org.apache.thrift.TException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; + +/** + * Wrapper class for Hive Metastore Client, which embeds a HiveShim layer to handle different Hive versions. + * Methods provided mostly conforms to IMetaStoreClient interfaces except those that require shims. + */ +@Internal +public class HiveMetastoreClientWrapper implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(HiveMetastoreClientWrapper.class); + + private final IMetaStoreClient client; + private final HiveConf hiveConf; + + public HiveMetastoreClientWrapper(HiveConf hiveConf) { + this.hiveConf = Preconditions.checkNotNull(hiveConf, "HiveConf cannot be null"); + client = createMetastoreClient(); + } + + @Override + public void close() { + client.close(); + } + + public List getDatabases(String pattern) throws MetaException, TException { + return client.getDatabases(pattern); + } + + public List getAllDatabases() throws MetaException, TException { + return client.getAllDatabases(); + } + + public List getAllTables(String databaseName) throws MetaException, TException, UnknownDBException { + return client.getAllTables(databaseName); + } + + public void dropTable(String databaseName, String tableName) + throws MetaException, TException, NoSuchObjectException { + client.dropTable(databaseName, tableName); + } + + public void dropTable(String dbName, String tableName, boolean deleteData, boolean ignoreUnknownTable) + throws MetaException, NoSuchObjectException, TException { + client.dropTable(dbName, tableName, deleteData, ignoreUnknownTable); + } + + public boolean tableExists(String databaseName, String tableName) + throws MetaException, TException, UnknownDBException { + return client.tableExists(databaseName, tableName); + } + + public Database getDatabase(String name) throws NoSuchObjectException, MetaException, TException { + return client.getDatabase(name); + } + + public Table getTable(String databaseName, String tableName) + throws MetaException, NoSuchObjectException, TException { + return client.getTable(databaseName, tableName); + } + + public Partition add_partition(Partition partition) + throws InvalidObjectException, AlreadyExistsException, MetaException, TException { + return client.add_partition(partition); + } + + public int add_partitions(List partitionList) + throws InvalidObjectException, AlreadyExistsException, MetaException, 
TException { + return client.add_partitions(partitionList); + } + + public Partition getPartition(String databaseName, String tableName, List list) + throws NoSuchObjectException, MetaException, TException { + return client.getPartition(databaseName, tableName, list); + } + + public List listPartitionNames(String databaseName, String tableName, short maxPartitions) + throws MetaException, TException { + return client.listPartitionNames(databaseName, tableName, maxPartitions); + } + + public List listPartitionNames(String databaseName, String tableName, List partitionValues, + short maxPartitions) throws MetaException, TException, NoSuchObjectException { + return client.listPartitionNames(databaseName, tableName, partitionValues, maxPartitions); + } + + public void createTable(Table table) + throws AlreadyExistsException, InvalidObjectException, MetaException, NoSuchObjectException, TException { + client.createTable(table); + } + + public void alter_table(String databaseName, String tableName, Table table) + throws InvalidOperationException, MetaException, TException { + client.alter_table(databaseName, tableName, table); + } + + public void createDatabase(Database database) + throws InvalidObjectException, AlreadyExistsException, MetaException, TException { + client.createDatabase(database); + } + + public void dropDatabase(String name, boolean deleteData, boolean ignoreIfNotExists) + throws NoSuchObjectException, InvalidOperationException, MetaException, TException { + client.dropDatabase(name, deleteData, ignoreIfNotExists); + } + + public void alterDatabase(String name, Database database) throws NoSuchObjectException, MetaException, TException { + client.alterDatabase(name, database); + } + + public boolean dropPartition(String databaseName, String tableName, List partitionValues, boolean deleteData) + throws NoSuchObjectException, MetaException, TException { + return client.dropPartition(databaseName, tableName, partitionValues, deleteData); + } + + public void alter_partition(String databaseName, String tableName, Partition partition) + throws InvalidOperationException, MetaException, TException { + client.alter_partition(databaseName, tableName, partition); + } + + public void renamePartition(String databaseName, String tableName, List partitionValues, Partition partition) + throws InvalidOperationException, MetaException, TException { + client.renamePartition(databaseName, tableName, partitionValues, partition); + } + + public void createFunction(Function function) throws InvalidObjectException, MetaException, TException { + client.createFunction(function); + } + + public void alterFunction(String databaseName, String functionName, Function function) + throws InvalidObjectException, MetaException, TException { + client.alterFunction(databaseName, functionName, function); + } + + public void dropFunction(String databaseName, String functionName) + throws MetaException, NoSuchObjectException, InvalidObjectException, InvalidInputException, TException { + client.dropFunction(databaseName, functionName); + } + + public List getFunctions(String databaseName, String pattern) throws MetaException, TException { + return client.getFunctions(databaseName, pattern); + } + + List getTableColumnStatistics(String databaseName, String tableName, List columnNames) + throws NoSuchObjectException, MetaException, TException { + return client.getTableColumnStatistics(databaseName, tableName, columnNames); + } + + Map> getPartitionColumnStatistics(String dbName, String tableName, + List partNames, List 
colNames) + throws NoSuchObjectException, MetaException, TException { + return client.getPartitionColumnStatistics(dbName, tableName, partNames, colNames); + } + + public boolean updateTableColumnStatistics(ColumnStatistics columnStatistics) + throws NoSuchObjectException, InvalidObjectException, MetaException, TException, InvalidInputException { + return client.updateTableColumnStatistics(columnStatistics); + } + + public List listPartitions(String dbName, String tblName, List partVals, short max) throws TException { + return client.listPartitions(dbName, tblName, partVals, max); + } + + public List listPartitions(String dbName, String tblName, short max) throws TException { + return client.listPartitions(dbName, tblName, max); + } + + //-------- Start of shimmed methods ---------- + + public List getViews(String databaseName) throws UnknownDBException, TException { + HiveShim hiveShim = HiveShimLoader.loadHiveShim(); + return hiveShim.getViews(client, databaseName); + } + + private IMetaStoreClient createMetastoreClient() { + HiveShim hiveShim = HiveShimLoader.loadHiveShim(); + return hiveShim.getHiveMetastoreClient(hiveConf); + } + + public Function getFunction(String databaseName, String functionName) throws MetaException, TException { + HiveShim hiveShim = HiveShimLoader.loadHiveShim(); + return hiveShim.getFunction(client, databaseName, functionName); + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java new file mode 100644 index 00000000000000..422dfe6307e36a --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.metastore.api.Function; +import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; +import org.apache.hadoop.hive.metastore.api.UnknownDBException; +import org.apache.thrift.TException; + +import java.util.List; + +/** + * A shim layer to support different versions of HMS. + */ +public interface HiveShim { + + /** + * Create a Hive Metastore client based on the given HiveConf object. + * + * @param hiveConf HiveConf instance + * @return an IMetaStoreClient instance + */ + IMetaStoreClient getHiveMetastoreClient(HiveConf hiveConf); + + /** + * Get a list of views in the given database from the given Hive Metastore client. 
+ * + * @param client Hive Metastore client + * @param databaseName the name of the database + * @return A list of names of the views + * @throws UnknownDBException if the database doesn't exist + * @throws TException for any other generic exceptions caused by Thrift + */ + List getViews(IMetaStoreClient client, String databaseName) throws UnknownDBException, TException; + + /** + * Gets a function from a database with the given HMS client. + * + * @param client the Hive Metastore client + * @param dbName name of the database + * @param functionName name of the function + * @return the Function under the specified name + * @throws NoSuchObjectException if the function doesn't exist + * @throws TException for any other generic exceptions caused by Thrift + */ + Function getFunction(IMetaStoreClient client, String dbName, String functionName) throws NoSuchObjectException, TException; +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java new file mode 100644 index 00000000000000..ef90547e09c221 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.flink.table.catalog.exceptions.CatalogException; + +import org.apache.hive.common.util.HiveVersionInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A loader to load HiveShim. 
+ */ +public class HiveShimLoader { + + private static final String HIVE_V1_VERSION_NAME = "1.2.1"; + private static final String HIVE_V2_VERSION_NAME = "2.3.4"; + + private static final Map hiveShims = new ConcurrentHashMap<>(2); + + private static final Logger LOG = LoggerFactory.getLogger(HiveShimLoader.class); + + private HiveShimLoader() { + } + + public static HiveShim loadHiveShim() { + String version = HiveVersionInfo.getVersion(); + return hiveShims.computeIfAbsent(version, (v) -> { + if (v.startsWith(HIVE_V1_VERSION_NAME)) { + return new HiveShimV1(); + } + if (v.startsWith(HIVE_V2_VERSION_NAME)) { + return new HiveShimV2(); + } + throw new CatalogException("Unsupported Hive version " + v); + }); + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java new file mode 100644 index 00000000000000..1048f52316f760 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.flink.table.catalog.exceptions.CatalogException; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.metastore.RetryingMetaStoreClient; +import org.apache.hadoop.hive.metastore.api.Function; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; +import org.apache.hadoop.hive.metastore.api.Table; +import org.apache.hadoop.hive.metastore.api.UnknownDBException; +import org.apache.thrift.TException; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +/** + * Shim for Hive version 1.x. + */ +public class HiveShimV1 implements HiveShim { + + @Override + public IMetaStoreClient getHiveMetastoreClient(HiveConf hiveConf) { + try { + Method method = RetryingMetaStoreClient.class.getMethod("getProxy", HiveConf.class); + // getProxy is a static method + return (IMetaStoreClient) method.invoke(null, (hiveConf)); + } catch (Exception ex) { + throw new CatalogException("Failed to create Hive Metastore client", ex); + } + } + + @Override + // 1.x client doesn't support filtering tables by type, so here we need to get all tables and filter by ourselves + public List getViews(IMetaStoreClient client, String databaseName) throws UnknownDBException, TException { + // We don't have to use reflection here because client.getAllTables(String) is supposed to be there for + // all versions. 
+ List tableNames = client.getAllTables(databaseName); + List views = new ArrayList<>(); + for (String name : tableNames) { + Table table = client.getTable(databaseName, name); + String viewDef = table.getViewOriginalText(); + if (viewDef != null && !viewDef.isEmpty()) { + views.add(table.getTableName()); + } + } + return views; + } + + @Override + public Function getFunction(IMetaStoreClient client, String dbName, String functionName) throws NoSuchObjectException, TException { + try { + // hive-1.x doesn't throw NoSuchObjectException if function doesn't exist, instead it throws a MetaException + return client.getFunction(dbName, functionName); + } catch (MetaException e) { + if (e.getCause() instanceof NoSuchObjectException) { + throw (NoSuchObjectException) e.getCause(); + } + throw e; + } + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java new file mode 100644 index 00000000000000..fefb48f2a7871e --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog.hive; + +import org.apache.flink.table.catalog.exceptions.CatalogException; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.metastore.RetryingMetaStoreClient; +import org.apache.hadoop.hive.metastore.TableType; +import org.apache.hadoop.hive.metastore.api.Function; +import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; +import org.apache.hadoop.hive.metastore.api.UnknownDBException; +import org.apache.thrift.TException; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; + +/** + * Shim for Hive version 2.x. 
+ */ +public class HiveShimV2 implements HiveShim { + + @Override + public IMetaStoreClient getHiveMetastoreClient(HiveConf hiveConf) { + try { + Method method = RetryingMetaStoreClient.class.getMethod("getProxy", HiveConf.class, Boolean.TYPE); + // getProxy is a static method + return (IMetaStoreClient) method.invoke(null, hiveConf, true); + } catch (Exception ex) { + throw new CatalogException("Failed to create Hive Metastore client", ex); + } + } + + @Override + public List getViews(IMetaStoreClient client, String databaseName) throws UnknownDBException, TException { + try { + Method method = client.getClass().getMethod("getTables", String.class, String.class, TableType.class); + return (List) method.invoke(client, databaseName, null, TableType.VIRTUAL_VIEW); + } catch (InvocationTargetException ite) { + Throwable targetEx = ite.getTargetException(); + if (targetEx instanceof TException) { + throw (TException) targetEx; + } else { + throw new CatalogException(String.format("Failed to get views for %s", databaseName), targetEx); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new CatalogException(String.format("Failed to get views for %s", databaseName), e); + } + } + + @Override + public Function getFunction(IMetaStoreClient client, String dbName, String functionName) throws NoSuchObjectException, TException { + return client.getFunction(dbName, functionName); + } +}
From c53c446486d58e3db149a9ea6fe1984227e415b2 Mon Sep 17 00:00:00 2001 From: zhijiang Date: Sat, 1 Jun 2019 04:05:25 +0800 Subject: [PATCH 45/92] [FLINK-12564][network] Remove ResultPartitionWriter#getBufferProvider()
* [FLINK-12564][network] Refactor the method of getBufferProvider to getBufferBuilder in ResultPartitionWriter ResultPartitionWriter#getBufferProvider is not general enough for all writer implementations. The key point is to request a BufferBuilder from the BufferProvider, so this method is refactored into getBufferBuilder directly. As a result, the internal components of a ResultPartitionWriter instance are no longer exposed to the outside. A short illustrative sketch of the new contract follows below.
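For illustration only (this sketch is not part of the patch, and the class name is hypothetical), a writer implementation satisfies the new contract by handing out a BufferBuilder from its internal pool instead of exposing the pool itself, mirroring ResultPartition#getBufferBuilder in the diff below:

import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.BufferPool;

import java.io.IOException;

// Hypothetical helper used only to illustrate the refactored contract.
class ExampleBufferBuilderProvider {

	private final BufferPool bufferPool;

	ExampleBufferBuilderProvider(BufferPool bufferPool) {
		this.bufferPool = bufferPool;
	}

	// The BufferPool stays private; callers only receive the BufferBuilder they requested,
	// which is the encapsulation the refactoring above aims for.
	BufferBuilder getBufferBuilder() throws IOException, InterruptedException {
		return bufferPool.requestBufferBuilderBlocking();
	}
}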
* [fixup] Remove getBufferProvider from ResultPartition --- .../io/network/api/writer/RecordWriter.java | 2 +- .../api/writer/ResultPartitionWriter.java | 9 ++++++--- .../io/network/partition/ResultPartition.java | 14 +++++++------ ...stractCollectingResultPartitionWriter.java | 11 +++++----- .../network/api/writer/RecordWriterTest.java | 20 +++++++++---------- .../PartialConsumePipelinedResultTest.java | 2 +- .../consumer/LocalInputChannelTest.java | 2 +- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java index 1743576688eb6a..cc40df064ee785 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java @@ -253,7 +253,7 @@ private BufferBuilder getBufferBuilder(int targetChannel) throws IOException, In private BufferBuilder requestNewBufferBuilder(int targetChannel) throws IOException, InterruptedException { checkState(!bufferBuilders[targetChannel].isPresent() || bufferBuilders[targetChannel].get().isFinished()); - BufferBuilder bufferBuilder = targetPartition.getBufferProvider().requestBufferBuilderBlocking(); + BufferBuilder bufferBuilder = targetPartition.getBufferBuilder(); bufferBuilders[targetChannel] = Optional.of(bufferBuilder); targetPartition.addBufferConsumer(bufferBuilder.createBufferConsumer(), targetChannel); return bufferBuilder; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java index 153b8800fa5870..cc1e49abb5240f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java @@ -18,8 +18,8 @@ package org.apache.flink.runtime.io.network.api.writer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import javax.annotation.Nullable; @@ -36,14 +36,17 @@ public interface ResultPartitionWriter extends AutoCloseable { */ void setup() throws IOException; - BufferProvider getBufferProvider(); - ResultPartitionID getPartitionId(); int getNumberOfSubpartitions(); int getNumTargetKeyGroups(); + /** + * Requests a {@link BufferBuilder} from this partition for writing data. + */ + BufferBuilder getBufferBuilder() throws IOException, InterruptedException; + /** * Adds the bufferConsumer to the subpartition with the given index. 
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java index 15f15e9640e60a..fef0278e9b48a6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java @@ -22,10 +22,10 @@ import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferPoolOwner; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.jobgraph.DistributionPattern; @@ -188,11 +188,6 @@ public int getNumberOfSubpartitions() { return subpartitions.length; } - @Override - public BufferProvider getBufferProvider() { - return bufferPool; - } - public BufferPool getBufferPool() { return bufferPool; } @@ -218,6 +213,13 @@ public ResultPartitionType getPartitionType() { // ------------------------------------------------------------------------ + @Override + public BufferBuilder getBufferBuilder() throws IOException, InterruptedException { + checkInProduceState(); + + return bufferPool.requestBufferBuilderBlocking(); + } + @Override public void addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException { checkNotNull(bufferConsumer); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java index 8ae8f5e8a93570..8633fe317f3896 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.api.writer; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; @@ -48,11 +49,6 @@ public AbstractCollectingResultPartitionWriter(BufferProvider bufferProvider) { public void setup() { } - @Override - public BufferProvider getBufferProvider() { - return bufferProvider; - } - @Override public ResultPartitionID getPartitionId() { return new ResultPartitionID(); @@ -68,6 +64,11 @@ public int getNumTargetKeyGroups() { return 1; } + @Override + public BufferBuilder getBufferBuilder() throws IOException, InterruptedException { + return bufferProvider.requestBufferBuilderBlocking(); + } + @Override public synchronized void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { 
checkState(targetChannel < getNumberOfSubpartitions()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 35487b86b63694..f8c6fdd1871679 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -479,11 +479,6 @@ private CollectingPartitionWriter(Queue[] queues, BufferProvider public void setup() { } - @Override - public BufferProvider getBufferProvider() { - return bufferProvider; - } - @Override public ResultPartitionID getPartitionId() { return partitionId; @@ -499,6 +494,11 @@ public int getNumTargetKeyGroups() { return 1; } + @Override + public BufferBuilder getBufferBuilder() throws IOException, InterruptedException { + return bufferProvider.requestBufferBuilderBlocking(); + } + @Override public void addBufferConsumer(BufferConsumer buffer, int targetChannel) throws IOException { queues[targetChannel].add(buffer); @@ -554,11 +554,6 @@ private RecyclingPartitionWriter(BufferProvider bufferProvider) { public void setup() { } - @Override - public BufferProvider getBufferProvider() { - return bufferProvider; - } - @Override public ResultPartitionID getPartitionId() { return partitionId; @@ -574,6 +569,11 @@ public int getNumTargetKeyGroups() { return 1; } + @Override + public BufferBuilder getBufferBuilder() throws IOException, InterruptedException { + return bufferProvider.requestBufferBuilderBlocking(); + } + @Override public void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { bufferConsumer.close(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java index 7cb1abd0458f2d..004ad08d6be512 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java @@ -119,7 +119,7 @@ public void invoke() throws Exception { final ResultPartitionWriter writer = getEnvironment().getWriter(0); for (int i = 0; i < 8; i++) { - final BufferBuilder bufferBuilder = writer.getBufferProvider().requestBufferBuilderBlocking(); + final BufferBuilder bufferBuilder = writer.getBufferBuilder(); writer.addBufferConsumer(bufferBuilder.createBufferConsumer(), 0); Thread.sleep(50); bufferBuilder.finish(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index a3bc696dba0b4e..74c4968ccaac7d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -131,7 +131,7 @@ public void testConcurrentConsumeMultiplePartitions() throws Exception { false, new TestPartitionProducerBufferSource( parallelism, - partition.getBufferProvider(), + partition.getBufferPool(), numberOfBuffersPerChannel) ); } From bc16485cc89fbe5b0dd1534737d0b5cd1ced885b Mon 
Sep 17 00:00:00 2001 From: azagrebin Date: Fri, 31 May 2019 22:24:54 +0200 Subject: [PATCH 46/92] [FLINK-12642][network][metrics] Fix In/OutputBufferPoolUsageGauge failure with NPE The result partition metrics are initialised before `ResultPartition#setup` is called. If a reporter tries to access an In/OutputBufferPoolUsageGauge in between, it will fail with a `NullPointerException` since the `BufferPool` of the partition is still `null`. Currently, the quick fix is to return zero metrics until the `BufferPool` is initialised. Once we have single-threaded access from `Task#run`, partition/gate creation and setup can be merged, and this workaround will no longer be necessary. --- .../io/network/metrics/InputBufferPoolUsageGauge.java | 8 ++++++-- .../io/network/metrics/OutputBufferPoolUsageGauge.java | 9 +++++++-- 2 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/InputBufferPoolUsageGauge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/InputBufferPoolUsageGauge.java index 992f5611203846..c7a6d4e76a9c9f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/InputBufferPoolUsageGauge.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/InputBufferPoolUsageGauge.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.metrics; import org.apache.flink.metrics.Gauge; +import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; /** @@ -38,8 +39,11 @@ public Float getValue() { int bufferPoolSize = 0; for (SingleInputGate inputGate : inputGates) { - usedBuffers += inputGate.getBufferPool().bestEffortGetNumOfUsedBuffers(); - bufferPoolSize += inputGate.getBufferPool().getNumBuffers(); + BufferPool bufferPool = inputGate.getBufferPool(); + if (bufferPool != null) { + usedBuffers += bufferPool.bestEffortGetNumOfUsedBuffers(); + bufferPoolSize += bufferPool.getNumBuffers(); + } } if (bufferPoolSize != 0) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/OutputBufferPoolUsageGauge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/OutputBufferPoolUsageGauge.java index 9aad92ce628407..b8f771ba823267 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/OutputBufferPoolUsageGauge.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/metrics/OutputBufferPoolUsageGauge.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.metrics; import org.apache.flink.metrics.Gauge; +import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.partition.ResultPartition; /** @@ -38,8 +39,12 @@ public Float getValue() { int bufferPoolSize = 0; for (ResultPartition resultPartition : resultPartitions) { - usedBuffers += resultPartition.getBufferPool().bestEffortGetNumOfUsedBuffers(); - bufferPoolSize += resultPartition.getBufferPool().getNumBuffers(); + BufferPool bufferPool = resultPartition.getBufferPool(); + + if (bufferPool != null) { + usedBuffers += bufferPool.bestEffortGetNumOfUsedBuffers(); + bufferPoolSize += bufferPool.getNumBuffers(); + } } if (bufferPoolSize != 0) {
From 5bab9cd1866c39f16b01d143eb481e013d8f8ef2 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Tue, 28 May 2019 10:09:08 -0700 Subject: [PATCH 47/92] [hotfix][hive] make hive-metastore dependency scope to 'provided' as
pre-requisite to support multiple hive versions This closes #8560. --- .../src/main/resources/META-INF/NOTICE | 29 -------------- .../resources/META-INF/licenses/LICENSE.antlr | 38 ------------------- 2 files changed, 67 deletions(-) delete mode 100644 flink-connectors/flink-connector-hive/src/main/resources/META-INF/NOTICE delete mode 100644 flink-connectors/flink-connector-hive/src/main/resources/META-INF/licenses/LICENSE.antlr diff --git a/flink-connectors/flink-connector-hive/src/main/resources/META-INF/NOTICE b/flink-connectors/flink-connector-hive/src/main/resources/META-INF/NOTICE deleted file mode 100644 index b67affe024a75a..00000000000000 --- a/flink-connectors/flink-connector-hive/src/main/resources/META-INF/NOTICE +++ /dev/null @@ -1,29 +0,0 @@ -flink-connector-hive -Copyright 2014-2019 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -This project bundles the following dependencies under the Apache Software License 2.0. (http://www.apache.org/licenses/LICENSE-2.0.txt) - -- commons-dbcp:commons-dbcp:1.4 -- commons-pool:commons-pool:1.5.4 -- com.fasterxml.jackson.core:jackson-annotations:2.6.0 -- com.fasterxml.jackson.core:jackson-core:2.6.5 -- com.fasterxml.jackson.core:jackson-databind:2.6.5 -- com.jolbox:bonecp:0.8.0.RELEASE -- org.apache.hive:hive-common:2.3.4 -- org.apache.hive:hive-metastore:2.3.4 -- org.apache.hive:hive-serde:2.3.4 -- org.apache.hive:hive-service-rpc:2.3.4 -- org.apache.hive:hive-storage-api:2.4.0 -- org.apache.thrift:libthrift:0.9.3 -- org.datanucleus:datanucleus-api-jdo:4.2.4 -- org.datanucleus:datanucleus-core:4.1.17 -- org.datanucleus:datanucleus-rdbms:4.1.19 -- org.datanucleus:javax.jdo:3.2.0-m3 - -This project bundles the following dependencies under the BSD license. -See bundled license files for details. - -- org.antlr:antlr-runtime:3.5.2 diff --git a/flink-connectors/flink-connector-hive/src/main/resources/META-INF/licenses/LICENSE.antlr b/flink-connectors/flink-connector-hive/src/main/resources/META-INF/licenses/LICENSE.antlr deleted file mode 100644 index 0af2cce61d76e8..00000000000000 --- a/flink-connectors/flink-connector-hive/src/main/resources/META-INF/licenses/LICENSE.antlr +++ /dev/null @@ -1,38 +0,0 @@ -(BSD License: http://www.opensource.org/licenses/bsd-license) - -Copyright (c) 2012 Terence Parr and Sam Harwell -All rights reserved. - -Redistribution and use in source and binary forms, with or -without modification, are permitted provided that the -following conditions are met: - -* Redistributions of source code must retain the above - copyright notice, this list of conditions and the - following disclaimer. - -* Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the - following disclaimer in the documentation and/or other - materials provided with the distribution. - -* Neither the name of the Webbit nor the names of - its contributors may be used to endorse or promote products - derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE -GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR -BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - From 9fa4ece22ae04af9061de60b1d569b823039047c Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Fri, 31 May 2019 14:51:20 -0700 Subject: [PATCH 48/92] [hotfix][hive] remove jdo-api dependency from profile hive-1.2.1 in flink-connector-hive --- flink-connectors/flink-connector-hive/pom.xml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index 5c205d8e31aab6..cd245d961abaf9 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -432,14 +432,6 @@ under the License. 1.2.1 2.6.0 - - - javax.jdo - jdo-api - 3.0.1 - provided - - From 8dd07bf62a56f84700e4c36f6537e1f8570ec11d Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Fri, 31 May 2019 14:59:54 -0700 Subject: [PATCH 49/92] [hotfix][hive] move hive metastore client and shim related classes to package 'org.apache.flink.table.catalog.hive.client' --- .../java/org/apache/flink/table/catalog/hive/HiveCatalog.java | 2 ++ .../catalog/hive/{ => client}/HiveMetastoreClientFactory.java | 2 +- .../catalog/hive/{ => client}/HiveMetastoreClientWrapper.java | 2 +- .../apache/flink/table/catalog/hive/{ => client}/HiveShim.java | 2 +- .../flink/table/catalog/hive/{ => client}/HiveShimLoader.java | 2 +- .../flink/table/catalog/hive/{ => client}/HiveShimV1.java | 2 +- .../flink/table/catalog/hive/{ => client}/HiveShimV2.java | 2 +- 7 files changed, 8 insertions(+), 6 deletions(-) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveMetastoreClientFactory.java (95%) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveMetastoreClientWrapper.java (99%) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveShim.java (97%) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveShimLoader.java (97%) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveShimV1.java (98%) rename flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/{ => client}/HiveShimV2.java (98%) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index 562c1983e84b68..b35ea648eaf2d6 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -47,6 +47,8 @@ import org.apache.flink.table.catalog.exceptions.TableAlreadyExistException; import org.apache.flink.table.catalog.exceptions.TableNotExistException; import 
org.apache.flink.table.catalog.exceptions.TableNotPartitionedException; +import org.apache.flink.table.catalog.hive.client.HiveMetastoreClientFactory; +import org.apache.flink.table.catalog.hive.client.HiveMetastoreClientWrapper; import org.apache.flink.table.catalog.hive.util.HiveTableUtil; import org.apache.flink.table.catalog.stats.CatalogColumnStatistics; import org.apache.flink.table.catalog.stats.CatalogTableStatistics; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientFactory.java similarity index 95% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientFactory.java index be46552fb49d1b..0601744443b9ed 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientFactory.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientFactory.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.hadoop.hive.conf.HiveConf; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientWrapper.java similarity index 99% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientWrapper.java index 43937aba0c3564..a72cea157ff4fd 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveMetastoreClientWrapper.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveMetastoreClientWrapper.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.flink.annotation.Internal; import org.apache.flink.util.Preconditions; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShim.java similarity index 97% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShim.java index 422dfe6307e36a..2e1c19522c61e0 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShim.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShim.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.IMetaStoreClient; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimLoader.java similarity index 97% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimLoader.java index ef90547e09c221..93be53cf6169e1 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimLoader.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimLoader.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.flink.table.catalog.exceptions.CatalogException; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV1.java similarity index 98% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV1.java index 1048f52316f760..d22ff39be44618 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV1.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV1.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.flink.table.catalog.exceptions.CatalogException; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV2.java similarity index 98% rename from flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java rename to flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV2.java index fefb48f2a7871e..5fbdd9acb60c31 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveShimV2.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/client/HiveShimV2.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.flink.table.catalog.hive; +package org.apache.flink.table.catalog.hive.client; import org.apache.flink.table.catalog.exceptions.CatalogException; From 323531abbfb7ae94af8939e67bc1895eb716b2d3 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Fri, 31 May 2019 20:11:36 -0700 Subject: [PATCH 50/92] [hotfix] remove extra spaces in project's pom.xml --- pom.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index ba3fe72f4efe9a..0334b334e8f991 100644 --- a/pom.xml +++ b/pom.xml @@ -315,7 +315,7 @@ under the License. 3.4.0 - + org.apache.avro avro @@ -1615,7 +1615,7 @@ under the License. true - + org.apache.maven.plugins maven-javadoc-plugin @@ -1643,7 +1643,7 @@ under the License. pl.project13.maven git-commit-id-plugin - + 2.1.10 From b48f17b6c10ec98dd759c737cb93f7baf97b0366 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Thu, 30 May 2019 15:38:26 +0800 Subject: [PATCH 51/92] [FLINK-12688] [state] Make serializer lazy initialization thread safe in StateDescriptor This closes #8570. --- .../api/common/state/StateDescriptor.java | 47 ++++++++++++++----- .../api/common/state/StateDescriptorTest.java | 32 +++++++++++++ 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java index 422d77f9f5221b..d92d16672a42e4 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; @@ -28,6 +29,9 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -37,6 +41,7 @@ import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; +import java.util.concurrent.atomic.AtomicReference; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -52,6 +57,7 @@ */ @PublicEvolving public abstract class StateDescriptor implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(StateDescriptor.class); /** * An enumeration of the types of supported states. Used to identify the state type @@ -82,8 +88,7 @@ public enum Type { /** The serializer for the type. May be eagerly initialized in the constructor, * or lazily once the {@link #initializeSerializerUnlessSet(ExecutionConfig)} method * is called. */ - @Nullable - protected TypeSerializer serializer; + private final AtomicReference> serializerAtomicReference = new AtomicReference<>(); /** The type information describing the value type. Only used to if the serializer * is created lazily. 
*/ @@ -114,7 +119,7 @@ public enum Type { */ protected StateDescriptor(String name, TypeSerializer serializer, @Nullable T defaultValue) { this.name = checkNotNull(name, "name must not be null"); - this.serializer = checkNotNull(serializer, "serializer must not be null"); + this.serializerAtomicReference.set(checkNotNull(serializer, "serializer must not be null")); this.defaultValue = defaultValue; } @@ -175,6 +180,7 @@ public String getName() { */ public T getDefaultValue() { if (defaultValue != null) { + TypeSerializer serializer = serializerAtomicReference.get(); if (serializer != null) { return serializer.copy(defaultValue); } else { @@ -191,6 +197,7 @@ public T getDefaultValue() { * calling {@link #initializeSerializerUnlessSet(ExecutionConfig)}. */ public TypeSerializer getSerializer() { + TypeSerializer serializer = serializerAtomicReference.get(); if (serializer != null) { return serializer.duplicate(); } else { @@ -198,6 +205,16 @@ public TypeSerializer getSerializer() { } } + @VisibleForTesting + final TypeSerializer getOriginalSerializer() { + TypeSerializer serializer = serializerAtomicReference.get(); + if (serializer != null) { + return serializer; + } else { + throw new IllegalStateException("Serializer not yet initialized."); + } + } + /** * Sets the name for queries of state created from this descriptor. * @@ -272,7 +289,7 @@ public StateTtlConfig getTtlConfig() { * @return True if the serializers have been initialized, false otherwise. */ public boolean isSerializerInitialized() { - return serializer != null; + return serializerAtomicReference.get() != null; } /** @@ -281,14 +298,14 @@ public boolean isSerializerInitialized() { * @param executionConfig The execution config to use when creating the serializer. */ public void initializeSerializerUnlessSet(ExecutionConfig executionConfig) { - if (serializer == null) { + if (serializerAtomicReference.get() == null) { checkState(typeInfo != null, "no serializer and no type info"); - - // instantiate the serializer - serializer = typeInfo.createSerializer(executionConfig); - - // we can drop the type info now, no longer needed - typeInfo = null; + // try to instantiate and set the serializer + TypeSerializer serializer = typeInfo.createSerializer(executionConfig); + // use cas to assure the singleton + if (!serializerAtomicReference.compareAndSet(null, serializer)) { + LOG.debug("Someone else beat us at initializing the serializer."); + } } } @@ -320,7 +337,7 @@ public String toString() { return getClass().getSimpleName() + "{name=" + name + ", defaultValue=" + defaultValue + - ", serializer=" + serializer + + ", serializer=" + serializerAtomicReference.get() + (isQueryable() ? 
", queryableStateName=" + queryableStateName + "" : "") + '}'; } @@ -340,6 +357,9 @@ private void writeObject(final ObjectOutputStream out) throws IOException { // we don't have a default value out.writeBoolean(false); } else { + TypeSerializer serializer = serializerAtomicReference.get(); + checkNotNull(serializer, "Serializer not initialized."); + // we have a default value out.writeBoolean(true); @@ -370,6 +390,9 @@ private void readObject(final ObjectInputStream in) throws IOException, ClassNot // read the default value field boolean hasDefaultValue = in.readBoolean(); if (hasDefaultValue) { + TypeSerializer serializer = serializerAtomicReference.get(); + checkNotNull(serializer, "Serializer not initialized."); + int size = in.readInt(); byte[] buffer = new byte[size]; diff --git a/flink-core/src/test/java/org/apache/flink/api/common/state/StateDescriptorTest.java b/flink-core/src/test/java/org/apache/flink/api/common/state/StateDescriptorTest.java index 4346163481beba..63fc5087e2433b 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/state/StateDescriptorTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/state/StateDescriptorTest.java @@ -22,13 +22,17 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.core.fs.Path; +import org.apache.flink.core.testutils.CheckedThread; import org.apache.flink.core.testutils.CommonTestUtils; import org.junit.Test; import java.io.File; +import java.util.ArrayList; +import java.util.concurrent.ConcurrentHashMap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -200,6 +204,34 @@ public void testEqualsSameNameAndTypeDifferentClass() throws Exception { assertNotEquals(descr1, descr2); } + @Test + public void testSerializerLazyInitializeInParallel() throws Exception { + final String name = "testSerializerLazyInitializeInParallel"; + // use PojoTypeInfo which will create a new serializer when createSerializer is invoked. 
+ final TestStateDescriptor desc = + new TestStateDescriptor<>(name, new PojoTypeInfo<>(String.class, new ArrayList<>())); + final int threadNumber = 20; + final ArrayList threads = new ArrayList<>(threadNumber); + final ExecutionConfig executionConfig = new ExecutionConfig(); + final ConcurrentHashMap> serializers = new ConcurrentHashMap<>(); + for (int i = 0; i < threadNumber; i++) { + threads.add(new CheckedThread() { + @Override + public void go() { + desc.initializeSerializerUnlessSet(executionConfig); + TypeSerializer serializer = desc.getOriginalSerializer(); + serializers.put(System.identityHashCode(serializer), serializer); + } + }); + } + threads.forEach(Thread::start); + for (CheckedThread t : threads) { + t.sync(); + } + assertEquals("Should use only one serializer but actually: " + serializers, 1, serializers.size()); + threads.clear(); + } + // ------------------------------------------------------------------------ // Mock implementations and test types // ------------------------------------------------------------------------ From f27c40dce0198a21cb933855ed2c7e60d2641550 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Mon, 20 May 2019 16:47:15 +0800 Subject: [PATCH 52/92] [FLINK-12409][python] Adds from_elements in TableEnvironment Note: Currently only flink planner is supported. The blink planner will be supported after the planner discovery is supported which is part of the work of FLIP32. This closes #8474 --- NOTICE-binary | 39 +++ flink-clients/pom.xml | 7 - .../apache/flink/client/cli/CliFrontend.java | 3 +- .../flink/client/program/PackagedProgram.java | 42 ++- .../flink/configuration/ConfigConstants.java | 3 + flink-dist/pom.xml | 17 +- flink-dist/src/main/assemblies/opt.xml | 8 + .../flink-bin/bin/pyflink-gateway-server.sh | 3 +- flink-dist/src/main/flink-bin/bin/pyflink.sh | 2 +- flink-dist/src/main/resources/META-INF/NOTICE | 1 - flink-python/pom.xml | 76 ++++- flink-python/pyflink/java_gateway.py | 4 +- flink-python/pyflink/serializers.py | 211 ++++++++++++ .../table/examples/batch/word_count.py | 4 +- .../pyflink/table/table_environment.py | 100 +++++- .../pyflink/table/tests/test_aggregate.py | 15 +- flink-python/pyflink/table/tests/test_calc.py | 94 +++--- .../table/tests/test_column_operation.py | 50 +-- .../pyflink/table/tests/test_distinct.py | 15 +- flink-python/pyflink/table/tests/test_join.py | 128 ++------ .../pyflink/table/tests/test_print_schema.py | 15 +- .../pyflink/table/tests/test_set_operation.py | 110 ++----- flink-python/pyflink/table/tests/test_sort.py | 14 +- .../table/tests/test_table_environment_api.py | 51 ++- .../pyflink/table/tests/test_window.py | 56 +--- .../python/bridge/PythonBridgeUtils.java | 147 +++++++++ .../bridge/pickle/ArrayConstructor.java | 62 ++++ .../bridge/pickle/ByteArrayConstructor.java | 35 ++ .../flink/python/client}/PythonDriver.java | 6 +- .../flink/python/client/PythonEnvUtils.java | 9 +- .../python/client}/PythonGatewayServer.java | 2 +- .../src/main/resources/META-INF/NOTICE | 15 + .../resources/META-INF/licenses/LICENSE.py4j | 0 .../META-INF/licenses/LICENSE.pyrolite | 21 ++ .../python/client}/PythonDriverTest.java | 2 +- .../python/client/PythonEnvUtilsTest.java | 10 +- .../flink/table/python/PythonTableUtils.scala | 302 ++++++++++++++++++ .../python/ExamplePointUserDefinedType.java | 87 +++++ licenses-binary/LICENSE.py4j | 26 ++ licenses-binary/LICENSE.pyrolite | 21 ++ licenses/LICENSE.py4j | 26 ++ tools/travis_controller.sh | 1 + 42 files changed, 1423 insertions(+), 417 deletions(-) create mode 100644 
flink-python/pyflink/serializers.py create mode 100644 flink-python/src/main/java/org/apache/flink/python/bridge/PythonBridgeUtils.java create mode 100644 flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ArrayConstructor.java create mode 100644 flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ByteArrayConstructor.java rename {flink-clients/src/main/java/org/apache/flink/client/python => flink-python/src/main/java/org/apache/flink/python/client}/PythonDriver.java (96%) rename flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java => flink-python/src/main/java/org/apache/flink/python/client/PythonEnvUtils.java (96%) rename {flink-clients/src/main/java/org/apache/flink/client/python => flink-python/src/main/java/org/apache/flink/python/client}/PythonGatewayServer.java (98%) create mode 100644 flink-python/src/main/resources/META-INF/NOTICE rename {flink-dist => flink-python}/src/main/resources/META-INF/licenses/LICENSE.py4j (100%) create mode 100644 flink-python/src/main/resources/META-INF/licenses/LICENSE.pyrolite rename {flink-clients/src/test/java/org/apache/flink/client/python => flink-python/src/test/java/org/apache/flink/python/client}/PythonDriverTest.java (98%) rename flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java => flink-python/src/test/java/org/apache/flink/python/client/PythonEnvUtilsTest.java (93%) create mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/python/PythonTableUtils.scala create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/types/python/ExamplePointUserDefinedType.java create mode 100644 licenses-binary/LICENSE.py4j create mode 100644 licenses-binary/LICENSE.pyrolite create mode 100644 licenses/LICENSE.py4j diff --git a/NOTICE-binary b/NOTICE-binary index a49077725d3044..97b8e7c598f6fe 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -5866,3 +5866,42 @@ This project bundles the following dependencies under the Apache Software Licens Apache HttpClient Copyright 1999-2017 The Apache Software Foundation + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for Apache Flink +// ------------------------------------------------------------------ + +Apache Flink +Copyright 2006-2019 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +flink-python +Copyright 2014-2019 The Apache Software Foundation + +This project bundles the following dependencies under the BSD license. +See bundled license files for details + +- net.sf.py4j:py4j:0.10.8.1 + +This project bundles the following dependencies under the MIT license. (https://opensource.org/licenses/MIT) +See bundled license files for details. + +- net.razorvine:pyrolite:4.13 + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for Apache Flink +// ------------------------------------------------------------------ + +Apache Flink +Copyright 2006-2019 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
+ + +flink-python +Copyright 2014-2019 The Apache Software Foundation diff --git a/flink-clients/pom.xml b/flink-clients/pom.xml index 514b6cd4969299..6476798d46a489 100644 --- a/flink-clients/pom.xml +++ b/flink-clients/pom.xml @@ -68,13 +68,6 @@ under the License. commons-cli - - - net.sf.py4j - py4j - ${py4j.version} - - diff --git a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java index fe641fe78134ad..06dd7616797dc2 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java +++ b/flink-clients/src/main/java/org/apache/flink/client/cli/CliFrontend.java @@ -32,7 +32,6 @@ import org.apache.flink.client.program.ProgramInvocationException; import org.apache.flink.client.program.ProgramMissingJobException; import org.apache.flink.client.program.ProgramParametrizationException; -import org.apache.flink.client.python.PythonDriver; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.CoreOptions; @@ -783,7 +782,7 @@ PackagedProgram buildProgram(ProgramOptions options) throws FileNotFoundExceptio jarFile = getJarFile(jarFilePath); } // The entry point class of python job is PythonDriver - entryPointClass = PythonDriver.class.getCanonicalName(); + entryPointClass = "org.apache.flink.python.client.PythonDriver"; } else { if (jarFilePath == null) { throw new IllegalArgumentException("Java program should be specified a JAR file."); diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java b/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java index 77b5d295159d64..cc5c960c4707a7 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java +++ b/flink-clients/src/main/java/org/apache/flink/client/program/PackagedProgram.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.Plan; import org.apache.flink.api.common.Program; import org.apache.flink.api.common.ProgramDescription; -import org.apache.flink.client.python.PythonDriver; +import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.optimizer.Optimizer; import org.apache.flink.optimizer.dag.DataSinkNode; import org.apache.flink.optimizer.plandump.PlanJSONDumpGenerator; @@ -44,6 +44,12 @@ import java.net.MalformedURLException; import java.net.URISyntaxException; import java.net.URL; +import java.nio.file.FileSystems; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; @@ -176,7 +182,7 @@ public PackagedProgram(File jarFile, @Nullable String entryPointClassName, Strin */ public PackagedProgram(File jarFile, List classpaths, @Nullable String entryPointClassName, String... args) throws ProgramInvocationException { // Whether the job is a Python job. - isPython = entryPointClassName != null && entryPointClassName.equals(PythonDriver.class.getCanonicalName()); + isPython = entryPointClassName != null && entryPointClassName.equals("org.apache.flink.python.client.PythonDriver"); URL jarFileUrl = null; if (jarFile != null) { @@ -242,7 +248,7 @@ public PackagedProgram(Class entryPointClass, String... 
args) throws ProgramI // load the entry point class this.mainClass = entryPointClass; - isPython = entryPointClass == PythonDriver.class; + isPython = entryPointClass.getCanonicalName().equals("org.apache.flink.python.client.PythonDriver"); // if the entry point is a program, instantiate the class and get the plan if (Program.class.isAssignableFrom(this.mainClass)) { @@ -470,6 +476,36 @@ public List getAllLibraries() { } } + if (isPython) { + String flinkOptPath = System.getenv(ConfigConstants.ENV_FLINK_OPT_DIR); + final List pythonJarPath = new ArrayList<>(); + try { + Files.walkFileTree(FileSystems.getDefault().getPath(flinkOptPath), new SimpleFileVisitor() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { + FileVisitResult result = super.visitFile(file, attrs); + if (file.getFileName().toString().startsWith("flink-python-")) { + pythonJarPath.add(file); + } + return result; + } + }); + } catch (IOException e) { + throw new RuntimeException( + "Exception encountered during finding the flink-python jar. This should not happen.", e); + } + + if (pythonJarPath.size() != 1) { + throw new RuntimeException("Found " + pythonJarPath.size() + " flink-python jar."); + } + + try { + libs.add(pythonJarPath.get(0).toUri().toURL()); + } catch (MalformedURLException e) { + throw new RuntimeException("URL is invalid. This should not happen.", e); + } + } + return libs; } diff --git a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java index 56a264c04f9c34..b1b97136af8f08 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/ConfigConstants.java @@ -2008,6 +2008,9 @@ public final class ConfigConstants { /** The environment variable name which contains the location of the lib folder. */ public static final String ENV_FLINK_LIB_DIR = "FLINK_LIB_DIR"; + /** The environment variable name which contains the location of the opt directory. */ + public static final String ENV_FLINK_OPT_DIR = "FLINK_OPT_DIR"; + /** The environment variable name which contains the location of the bin directory. */ public static final String ENV_FLINK_BIN_DIR = "FLINK_BIN_DIR"; diff --git a/flink-dist/pom.xml b/flink-dist/pom.xml index 1350f10958894f..343cc58d7a7ba0 100644 --- a/flink-dist/pom.xml +++ b/flink-dist/pom.xml @@ -361,6 +361,14 @@ under the License. ${project.version} provided + + + org.apache.flink + flink-python + java-binding + ${project.version} + provided + @@ -568,15 +576,6 @@ under the License. 
Apache Flink - - - py4j - org.apache.flink.api.python.py4j - - py4j.* - - - diff --git a/flink-dist/src/main/assemblies/opt.xml b/flink-dist/src/main/assemblies/opt.xml index e28acd8c962329..7f872d49c73d3a 100644 --- a/flink-dist/src/main/assemblies/opt.xml +++ b/flink-dist/src/main/assemblies/opt.xml @@ -184,6 +184,14 @@ flink-streaming-python_${scala.binary.version}-${project.version}.jar 0644 + + + + ../flink-python/target/flink-python-${project.version}-java-binding.jar + opt + flink-python-${project.version}-java-binding.jar + 0644 + diff --git a/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh b/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh index 9e41ad5e1c1d83..078c7575a55358 100644 --- a/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh +++ b/flink-dist/src/main/flink-bin/bin/pyflink-gateway-server.sh @@ -51,6 +51,7 @@ log=$FLINK_LOG_DIR/flink-$FLINK_IDENT_STRING-python-$HOSTNAME.log log_setting=(-Dlog.file="$log" -Dlog4j.configuration=file:"$FLINK_CONF_DIR"/log4j-cli.properties -Dlogback.configurationFile=file:"$FLINK_CONF_DIR"/logback.xml) TABLE_JAR_PATH=`echo "$FLINK_ROOT_DIR"/opt/flink-table*.jar` +PYTHON_JAR_PATH=`echo "$FLINK_ROOT_DIR"/opt/flink-python*java-binding.jar` FLINK_TEST_CLASSPATH="" if [[ -n "$FLINK_TESTING" ]]; then @@ -66,4 +67,4 @@ if [[ -n "$FLINK_TESTING" ]]; then done < <(find "$FLINK_SOURCE_ROOT_DIR" ! -type d \( -name 'flink-*-tests.jar' -o -path "${FLINK_SOURCE_ROOT_DIR}/flink-connectors/flink-connector-elasticsearch-base/target/flink*.jar" -o -path "${FLINK_SOURCE_ROOT_DIR}/flink-connectors/flink-connector-kafka-base/target/flink*.jar" \) -print0 | sort -z) fi -exec $JAVA_RUN $JVM_ARGS "${log_setting[@]}" -cp ${FLINK_CLASSPATH}:${TABLE_JAR_PATH}:${FLINK_TEST_CLASSPATH} ${DRIVER} ${ARGS[@]} +exec $JAVA_RUN $JVM_ARGS "${log_setting[@]}" -cp ${FLINK_CLASSPATH}:${TABLE_JAR_PATH}:${PYTHON_JAR_PATH}:${FLINK_TEST_CLASSPATH} ${DRIVER} ${ARGS[@]} diff --git a/flink-dist/src/main/flink-bin/bin/pyflink.sh b/flink-dist/src/main/flink-bin/bin/pyflink.sh index 9c998cd445855e..46470e87d8f4cf 100644 --- a/flink-dist/src/main/flink-bin/bin/pyflink.sh +++ b/flink-dist/src/main/flink-bin/bin/pyflink.sh @@ -22,4 +22,4 @@ bin=`cd "$bin"; pwd` . "$bin"/config.sh -"$FLINK_BIN_DIR"/flink run -v "$FLINK_ROOT_DIR"/opt/flink-python*.jar "$@" +"$FLINK_BIN_DIR"/flink run -v "$FLINK_ROOT_DIR"/opt/flink-python_*.jar "$@" diff --git a/flink-dist/src/main/resources/META-INF/NOTICE b/flink-dist/src/main/resources/META-INF/NOTICE index 7013a2c0040b63..59adff8a0fdb52 100644 --- a/flink-dist/src/main/resources/META-INF/NOTICE +++ b/flink-dist/src/main/resources/META-INF/NOTICE @@ -34,7 +34,6 @@ See bundled license files for details. - com.esotericsoftware.kryo:kryo:2.24.0 - com.esotericsoftware.minlog:minlog:1.2 - org.clapper:grizzled-slf4j_2.11:1.3.2 -- net.sf.py4j:py4j:0.10.8.1 The following dependencies all share the same BSD license which you find under licenses/LICENSE.scala. diff --git a/flink-python/pom.xml b/flink-python/pom.xml index 712d7db9a3f5d2..86d3d7b6e26130 100644 --- a/flink-python/pom.xml +++ b/flink-python/pom.xml @@ -32,7 +32,50 @@ under the License. 
flink-python flink-python - pom + jar + + + + + + + org.apache.flink + flink-core + ${project.version} + provided + + + org.apache.flink + flink-java + ${project.version} + provided + + + org.apache.flink + flink-streaming-java_${scala.binary.version} + ${project.version} + provided + + + + + + net.sf.py4j + py4j + ${py4j.version} + + + net.razorvine + pyrolite + 4.13 + + + net.razorvine + serpent + + + + @@ -79,6 +122,37 @@ under the License. + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + + false + true + java-binding + + + net.razorvine:* + net.sf.py4j:* + + + + + py4j + org.apache.flink.python.shaded.py4j + + + net.razorvine + org.apache.flink.python.shaded.net.razorvine + + + + + + diff --git a/flink-python/pyflink/java_gateway.py b/flink-python/pyflink/java_gateway.py index b218d231662704..c6954ed86948e6 100644 --- a/flink-python/pyflink/java_gateway.py +++ b/flink-python/pyflink/java_gateway.py @@ -64,7 +64,7 @@ def launch_gateway(): raise Exception("Windows system is not supported currently.") script = "./bin/pyflink-gateway-server.sh" command = [os.path.join(FLINK_HOME, script)] - command += ['-c', 'org.apache.flink.client.python.PythonGatewayServer'] + command += ['-c', 'org.apache.flink.python.client.PythonGatewayServer'] # Create a temporary directory where the gateway server should write the connection information. conn_info_dir = tempfile.mkdtemp() @@ -114,6 +114,8 @@ def import_flink_view(gateway): java_import(gateway.jvm, "org.apache.flink.table.descriptors.*") java_import(gateway.jvm, "org.apache.flink.table.sources.*") java_import(gateway.jvm, "org.apache.flink.table.sinks.*") + java_import(gateway.jvm, "org.apache.flink.table.python.*") + java_import(gateway.jvm, "org.apache.flink.python.bridge.*") java_import(gateway.jvm, "org.apache.flink.api.common.typeinfo.TypeInformation") java_import(gateway.jvm, "org.apache.flink.api.common.typeinfo.Types") java_import(gateway.jvm, "org.apache.flink.api.java.ExecutionEnvironment") diff --git a/flink-python/pyflink/serializers.py b/flink-python/pyflink/serializers.py new file mode 100644 index 00000000000000..7d8e04a3b2a541 --- /dev/null +++ b/flink-python/pyflink/serializers.py @@ -0,0 +1,211 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +import sys +import struct +from abc import ABCMeta, abstractmethod + +if sys.version < '3': + import cPickle as pickle + protocol = 2 + from itertools import imap as map, chain +else: + import pickle + protocol = 3 + xrange = range + + +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + NULL = -2 + + +class Serializer(object): + + __metaclass__ = ABCMeta + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + @abstractmethod + def dump_to_stream(self, iterator, stream): + """ + Serializes an iterator of objects to the output stream. + """ + pass + + @abstractmethod + def load_from_stream(self, stream): + """ + Returns an iterator of deserialized objects from the input stream. + """ + pass + + def _load_from_stream_without_unbatching(self, stream): + """ + Returns an iterator of deserialized batches (iterable) of objects from the input stream. + If the serializer does not operate on batches the default implementation returns an + iterator of single element lists. + """ + return map(lambda x: [x], self.load_from_stream(stream)) + + +class VarLengthDataSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where length is a 32-bit integer and data is length bytes. + """ + + def dump_to_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_from_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self.dumps(obj) + if serialized is None: + raise ValueError("Serialized value should not be None") + if len(serialized) > (1 << 31): + raise ValueError("Can not serialize object larger than 2G") + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError + elif length == SpecialLengths.NULL: + return None + obj = stream.read(length) + if len(obj) < length: + raise EOFError + return self.loads(obj) + + @abstractmethod + def dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + pass + + @abstractmethod + def loads(self, obj): + """ + Deserialize an object from a byte array. + """ + pass + + +class PickleSerializer(VarLengthDataSerializer): + """ + Serializes objects using Python's pickle serializer: + + http://docs.python.org/3/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. 
+ """ + + def dumps(self, obj): + return pickle.dumps(obj, protocol) + + if sys.version >= '3': + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) + else: + def loads(self, obj, encoding=None): + return pickle.loads(obj) + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 + + def __init__(self, serializer, batch_size=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batch_size = batch_size + + def __repr__(self): + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batch_size) + + def _batched(self, iterator): + if self.batch_size == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"): + n = len(iterator) + for i in xrange(0, n, self.batch_size): + yield iterator[i: i + self.batch_size] + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batch_size: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_to_stream(self, iterator, stream): + self.serializer.dump_to_stream(self._batched(iterator), stream) + + def load_from_stream(self, stream): + return chain.from_iterable(self._load_from_stream_without_unbatching(stream)) + + def _load_from_stream_without_unbatching(self, stream): + return self.serializer.load_from_stream(stream) + + +def read_int(stream): + length = stream.read(4) + if not length: + raise EOFError + return struct.unpack("!i", length)[0] + + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) diff --git a/flink-python/pyflink/table/examples/batch/word_count.py b/flink-python/pyflink/table/examples/batch/word_count.py index a324af4747c1d5..a2ae23c4a6b1c8 100644 --- a/flink-python/pyflink/table/examples/batch/word_count.py +++ b/flink-python/pyflink/table/examples/batch/word_count.py @@ -46,11 +46,11 @@ def word_count(): f.flush() f.close() - t_config = TableConfig.Builder().as_batch_execution().set_parallelism(1).build() + t_config = TableConfig.Builder().as_batch_execution().build() t_env = TableEnvironment.create(t_config) field_names = ["word", "cout"] - field_types = [DataTypes.STRING, DataTypes.LONG] + field_types = [DataTypes.STRING(), DataTypes.BIGINT()] # register Orders table in table environment t_env.register_table_source( diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py index 45e268847f7471..8bb2f638b81268 100644 --- a/flink-python/pyflink/table/table_environment.py +++ b/flink-python/pyflink/table/table_environment.py @@ -15,9 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################ - +import os +import tempfile from abc import ABCMeta, abstractmethod +from pyflink.serializers import BatchedSerializer, PickleSerializer from pyflink.table.query_config import StreamQueryConfig, BatchQueryConfig, QueryConfig from pyflink.table.table_config import TableConfig from pyflink.table.table_descriptor import (StreamTableDescriptor, ConnectorDescriptor, @@ -25,7 +27,8 @@ from pyflink.java_gateway import get_gateway from pyflink.table import Table -from pyflink.table.types import _to_java_type +from pyflink.table.types import _to_java_type, _create_type_verifier, RowType, DataType, \ + _infer_schema_from_data, _create_converter from pyflink.util import utils __all__ = [ @@ -42,8 +45,9 @@ class TableEnvironment(object): __metaclass__ = ABCMeta - def __init__(self, j_tenv): + def __init__(self, j_tenv, serializer=PickleSerializer()): self._j_tenv = j_tenv + self._serializer = serializer def from_table_source(self, table_source): """ @@ -379,6 +383,82 @@ def create(cls, table_config): return t_env + def from_elements(self, elements, schema=None, verify_schema=True): + """ + Creates a table from a collection of elements. + + :param elements: The elements to create a table from. + :param schema: The schema of the table. + :param verify_schema: Whether to verify the elements against the schema. + :return: A Table. + """ + + # verifies the elements against the specified schema + if isinstance(schema, RowType): + verify_func = _create_type_verifier(schema) if verify_schema else lambda _: True + + def verify_obj(obj): + verify_func(obj) + return obj + elif isinstance(schema, DataType): + data_type = schema + schema = RowType().add("value", schema) + + verify_func = _create_type_verifier( + data_type, name="field value") if verify_schema else lambda _: True + + def verify_obj(obj): + verify_func(obj) + return obj + else: + def verify_obj(obj): + return obj + + if "__len__" not in dir(elements): + elements = list(elements) + + # infers the schema if not specified + if schema is None or isinstance(schema, (list, tuple)): + schema = _infer_schema_from_data(elements, names=schema) + converter = _create_converter(schema) + elements = map(converter, elements) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + schema.fields[i].name = name + schema.names[i] = name + + elif not isinstance(schema, RowType): + raise TypeError( + "schema should be RowType, list, tuple or None, but got: %s" % schema) + + # converts python data to sql data + elements = [schema.to_sql_type(element) for element in elements] + return self._from_elements(map(verify_obj, elements), schema) + + def _from_elements(self, elements, schema): + """ + Creates a table from a collection of elements. + + :param elements: The elements to create a table from. + :return: A table. 
+ """ + + # serializes to a file, and we read the file in java + temp_file = tempfile.NamedTemporaryFile(delete=False, dir=tempfile.mkdtemp()) + serializer = BatchedSerializer(self._serializer) + try: + try: + serializer.dump_to_stream(elements, temp_file) + finally: + temp_file.close() + return self._from_file(temp_file.name, schema) + finally: + os.unlink(temp_file.name) + + @abstractmethod + def _from_file(self, filename, schema): + pass + class StreamTableEnvironment(TableEnvironment): @@ -386,6 +466,13 @@ def __init__(self, j_tenv): self._j_tenv = j_tenv super(StreamTableEnvironment, self).__init__(j_tenv) + def _from_file(self, filename, schema): + gateway = get_gateway() + jds = gateway.jvm.PythonBridgeUtils.createDataStreamFromFile( + self._j_tenv.execEnv(), filename, True) + return Table(gateway.jvm.PythonTableUtils.fromDataStream( + self._j_tenv, jds, _to_java_type(schema))) + def get_config(self): """ Returns the table config to define the runtime behavior of the Table API. @@ -444,6 +531,13 @@ def __init__(self, j_tenv): self._j_tenv = j_tenv super(BatchTableEnvironment, self).__init__(j_tenv) + def _from_file(self, filename, schema): + gateway = get_gateway() + jds = gateway.jvm.PythonBridgeUtils.createDataSetFromFile( + self._j_tenv.execEnv(), filename, True) + return Table(gateway.jvm.PythonTableUtils.fromDataSet( + self._j_tenv, jds, _to_java_type(schema))) + def get_config(self): """ Returns the table config to define the runtime behavior of the Table API. diff --git a/flink-python/pyflink/table/tests/test_aggregate.py b/flink-python/pyflink/table/tests/test_aggregate.py index b0bdf82ca8bd09..503319cd64695f 100644 --- a/flink-python/pyflink/table/tests/test_aggregate.py +++ b/flink-python/pyflink/table/tests/test_aggregate.py @@ -16,8 +16,6 @@ # limitations under the License. ################################################################################ -import os - from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,21 +24,16 @@ class StreamTableAggregateTests(PyFlinkStreamTableTestCase): def test_group_by(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello'), (2, 'Hello', 'Hello')], + ['a', 'b', 'c']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source.group_by("c").select("a.sum, c as b") + result = t.group_by("c").select("a.sum, c as b") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() diff --git a/flink-python/pyflink/table/tests/test_calc.py b/flink-python/pyflink/table/tests/test_calc.py index 0822c912046cde..dba7792445c8a6 100644 --- a/flink-python/pyflink/table/tests/test_calc.py +++ b/flink-python/pyflink/table/tests/test_calc.py @@ -16,9 +16,13 @@ # limitations under the License. 
################################################################################ -import os +import array +import datetime +from decimal import Decimal -from pyflink.table import CsvTableSource, DataTypes +from pyflink.table import DataTypes, Row +from pyflink.table.tests.test_types import ExamplePoint, PythonOnlyPoint, ExamplePointUDT, \ + PythonOnlyUDT from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,25 +30,16 @@ class StreamTableCalcTests(PyFlinkStreamTableTestCase): def test_select(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - with open(source_path, 'w') as f: - lines = '1,hi,hello\n' + '2,hi,hello\n' - f.write(lines) - f.close() + t_env = self.t_env + t = t_env.from_elements([(1, 'hi', 'hello'), (2, 'hi', 'hello')], ['a', 'b', 'c']) field_names = ["a", "b", "c"] field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] - t_env = self.t_env - # register Orders table in table environment - t_env.register_table_source( - "Orders", - CsvTableSource(source_path, field_names, field_types)) t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - t_env.scan("Orders") \ - .select("a + 1, b, c") \ - .insert_into("Results") + t.select("a + 1, b, c") \ + .insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -52,40 +47,32 @@ def test_select(self): self.assert_equals(actual, expected) def test_alias(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') + t_env = self.t_env + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hi', 'Hello')], ['a', 'b', 'c']) field_names = ["a", "b", "c"] field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.alias("d, e, f").select("d, e, f") + result = t.alias("d, e, f").select("d, e, f") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() - expected = ['1,Hi,Hello', '2,Hello,Hello'] + expected = ['1,Hi,Hello', '2,Hi,Hello'] self.assert_equals(actual, expected) def test_where(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') + t_env = self.t_env + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) field_names = ["a", "b", "c"] field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.where("a > 1 && b = 'Hello'") + result = t.where("a > 1 && b = 'Hello'") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -94,19 +81,15 @@ def test_where(self): self.assert_equals(actual, expected) def test_filter(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') + t_env = self.t_env + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) field_names = ["a", "b", "c"] 
field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.filter("a > 1 && b = 'Hello'") + result = t.filter("a > 1 && b = 'Hello'") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -114,6 +97,43 @@ def test_filter(self): expected = ['2,Hello,Hello'] self.assert_equals(actual, expected) + def test_from_element(self): + t_env = self.t_env + a = array.array('b') + a.fromstring('ABCD') + t = t_env.from_elements( + [(1, 1.0, "hi", "hello", datetime.date(1970, 1, 2), datetime.time(1, 0, 0), + datetime.datetime(1970, 1, 2, 0, 0), array.array("d", [1]), ["abc"], + [datetime.date(1970, 1, 2)], Decimal(1), Row("a", "b")(1, 2.0), + {"key": 1.0}, a, ExamplePoint(1.0, 2.0), + PythonOnlyPoint(3.0, 4.0))]) + field_names = ["a", "b", "c", "d", "e", "f", "g", "h", + "i", "j", "k", "l", "m", "n", "o", "p"] + field_types = [DataTypes.BIGINT(), DataTypes.DOUBLE(), DataTypes.STRING(), + DataTypes.STRING(), DataTypes.DATE(), + DataTypes.TIME(), + DataTypes.TIMESTAMP(), + DataTypes.ARRAY(DataTypes.DOUBLE()), + DataTypes.ARRAY(DataTypes.STRING()), + DataTypes.ARRAY(DataTypes.DATE()), + DataTypes.DECIMAL(), + DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT()), + DataTypes.FIELD("b", DataTypes.DOUBLE())]), + DataTypes.MAP(DataTypes.VARCHAR(), DataTypes.DOUBLE()), + DataTypes.VARBINARY(), ExamplePointUDT(), + PythonOnlyUDT()] + t_env.register_table_sink( + "Results", + field_names, field_types, source_sink_utils.TestAppendSink()) + + t.insert_into("Results") + t_env.execute() + actual = source_sink_utils.results() + + expected = ['1,1.0,hi,hello,1970-01-02,01:00:00,1970-01-02 00:00:00.0,[1.0],[abc],' + '[1970-01-02],1,1,2.0,{key=1.0},[65, 66, 67, 68],[1.0, 2.0],[3.0, 4.0]'] + self.assert_equals(actual, expected) + if __name__ == '__main__': import unittest diff --git a/flink-python/pyflink/table/tests/test_column_operation.py b/flink-python/pyflink/table/tests/test_column_operation.py index 5faff12c3a95ac..41d0768951c422 100644 --- a/flink-python/pyflink/table/tests/test_column_operation.py +++ b/flink-python/pyflink/table/tests/test_column_operation.py @@ -16,8 +16,6 @@ # limitations under the License. 
################################################################################ -import os - from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,20 +24,15 @@ class StreamTableColumnsOperationTests(PyFlinkStreamTableTestCase): def test_add_columns(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") - field_types = [DataTypes.INT(), DataTypes.INT(), DataTypes.INT()] + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) + field_names = ["a", "b", "c"] + field_types = [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.select("a").add_columns("a + 1 as b, a + 2 as c") + result = t.select("a").add_columns("a + 1 as b, a + 2 as c") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -48,21 +41,15 @@ def test_add_columns(self): self.assert_equals(actual, expected) def test_add_or_replace_columns(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") - field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.INT()] + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) + field_names = ["b", "a"] + field_types = [DataTypes.BIGINT(), DataTypes.BIGINT()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.select("a").add_or_replace_columns("a + 1 as b, a + 2 as a") + result = t.select("a").add_or_replace_columns("a + 1 as b, a + 2 as a") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -71,20 +58,15 @@ def test_add_or_replace_columns(self): self.assert_equals(actual, expected) def test_rename_columns(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) field_names = ["d", "e", "f"] + field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.select("a, b, c").rename_columns("a as d, c as f, b as e").select("d, e, f") + result = t.select("a, b, c").rename_columns("a as d, c as f, b as e").select("d, e, f") result.insert_into("Results") t_env.execute() actual = 
source_sink_utils.results() @@ -93,21 +75,15 @@ def test_rename_columns(self): self.assert_equals(actual, expected) def test_drop_columns(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello')], ['a', 'b', 'c']) field_names = ["b"] field_types = [DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source.select("a, b, c").drop_columns("a, c").select("b") + result = t.select("a, b, c").drop_columns("a, c").select("b") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() diff --git a/flink-python/pyflink/table/tests/test_distinct.py b/flink-python/pyflink/table/tests/test_distinct.py index fd32a2859f3a0d..4cfb310cc32c2b 100644 --- a/flink-python/pyflink/table/tests/test_distinct.py +++ b/flink-python/pyflink/table/tests/test_distinct.py @@ -16,8 +16,6 @@ # limitations under the License. ################################################################################ -import os - from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,21 +24,16 @@ class StreamTableDistinctTests(PyFlinkStreamTableTestCase): def test_distinct(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello", "Hello")], + ['a', 'b', 'c']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source.distinct().select("a, c as b") + result = t.distinct().select("a, c as b") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() diff --git a/flink-python/pyflink/table/tests/test_join.py b/flink-python/pyflink/table/tests/test_join.py index 4666b97ebb00c4..827ea90d08e376 100644 --- a/flink-python/pyflink/table/tests/test_join.py +++ b/flink-python/pyflink/table/tests/test_join.py @@ -16,8 +16,6 @@ # limitations under the License. 
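The join tests that follow build both inputs with from_elements and write to a retract sink, since a streaming join may withdraw rows it emitted earlier. A minimal sketch of that two-input shape, with made-up data and names, might look like this:

    # Two-input sketch: both sides come from from_elements, and results go
    # through TestRetractSink so retractions are applied before asserting.
    from pyflink.table.types import DataTypes
    from pyflink.testing import source_sink_utils
    from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase


    class JoinPatternSketch(PyFlinkStreamTableTestCase):

        def test_sketch(self):
            t_env = self.t_env
            t1 = t_env.from_elements([(1, 'Hi'), (2, 'Hello')], ['a', 'b'])
            t2 = t_env.from_elements([(2, 'Flink')], ['d', 'e'])
            t_env.register_table_sink(
                "Results",
                ["a", "b"], [DataTypes.BIGINT(), DataTypes.STRING()],
                source_sink_utils.TestRetractSink())
            # "b + e" concatenates the two string columns.
            t1.join(t2, "a = d").select("a, b + e").insert_into("Results")
            t_env.execute()
            self.assert_equals(source_sink_utils.results(), ['2,HelloFlink'])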
################################################################################ -import os - from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,28 +24,17 @@ class StreamTableJoinTests(PyFlinkStreamTableTestCase): def test_join_without_where(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (3, "Flink")], ['d', 'e']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.join(source2, "a = d").select("a, b + e") + result = t1.join(t2, "a = d").select("a, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -56,28 +43,17 @@ def test_join_without_where(self): self.assert_equals(actual, expected) def test_join_with_where(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (3, "Flink")], ['d', 'e']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.join(source2).where("a = d").select("a, b + e") + result = t1.join(t2).where("a = d").select("a, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -86,28 +62,17 @@ def test_join_with_where(self): self.assert_equals(actual, expected) def test_left_outer_join_without_where(self): - source_path = os.path.join(self.tempdir + 
'/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (3, "Flink")], ['d', 'e']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.left_outer_join(source2, "a = d").select("a, b + e") + result = t1.left_outer_join(t2, "a = d").select("a, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -116,28 +81,17 @@ def test_left_outer_join_without_where(self): self.assert_equals(actual, expected) def test_left_outer_join_with_where(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (3, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (3, "Flink")], ['d', 'e']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.left_outer_join(source2).where("a = d").select("a, b + e") + result = t1.left_outer_join(t2).where("a = d").select("a, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -146,28 +100,17 @@ def test_left_outer_join_with_where(self): self.assert_equals(actual, expected) def test_right_outer_join(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') 
- field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (4, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (4, "Flink")], ['d', 'e']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.right_outer_join(source2, "a = d").select("d, b + e") + result = t1.right_outer_join(t2, "a = d").select("d, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -176,28 +119,17 @@ def test_right_outer_join(self): self.assert_equals(actual, expected) def test_full_outer_join(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - field_names2 = ["d", "e"] - field_types2 = [DataTypes.INT(), DataTypes.STRING()] - data2 = [(2, "Flink"), (3, "Python"), (4, "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types2, field_names2) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Flink"), (3, "Python"), (4, "Flink")], ['d', 'e']) field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source1.full_outer_join(source2, "a = d").select("a, d, b + e") + result = t1.full_outer_join(t2, "a = d").select("a, d, b + e") result.insert_into("Results") t_env.execute() actual = source_sink_utils.results() diff --git a/flink-python/pyflink/table/tests/test_print_schema.py b/flink-python/pyflink/table/tests/test_print_schema.py index 4a6309ff3c8df8..53404ae444becb 100644 --- a/flink-python/pyflink/table/tests/test_print_schema.py +++ b/flink-python/pyflink/table/tests/test_print_schema.py @@ -16,8 +16,6 @@ # limitations under the License. 
################################################################################ -import os - from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase @@ -26,21 +24,16 @@ class StreamTableSchemaTests(PyFlinkStreamTableTestCase): def test_print_schema(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hello'), (2, 'Hello', 'Hello')], + ['a', 'b', 'c']) field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestRetractSink()) - result = source.group_by("c").select("a.sum, c as b") + result = t.group_by("c").select("a.sum, c as b") result.print_schema() diff --git a/flink-python/pyflink/table/tests/test_set_operation.py b/flink-python/pyflink/table/tests/test_set_operation.py index 62e4dfaff744f6..e46201d5b246c0 100644 --- a/flink-python/pyflink/table/tests/test_set_operation.py +++ b/flink-python/pyflink/table/tests/test_set_operation.py @@ -15,7 +15,6 @@ # # See the License for the specific language governing permissions and # # limitations under the License. ################################################################################ -import os from pyflink.table.types import DataTypes from pyflink.testing import source_sink_utils @@ -26,26 +25,18 @@ class StreamTableSetOperationTests(PyFlinkStreamTableTestCase): def test_union_all(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(2, "Hi", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")], + ['a', 'b', 'c']) + t2 = t_env.from_elements([(2, "Hi", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")], + ['a', 'b', 'c']) field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = source1.union_all(source2) + result = t1.union_all(t2) result.insert_into("Results") t_env.execute() @@ -61,43 +52,27 @@ def test_union_all(self): class BatchTableSetOperationTests(PyFlinkBatchTableTestCase): + data1 = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")] + data2 = [(3, "Hello", "Hello"), 
(3, "Hello", "Python"), (4, "Hi", "Flink")] + schema = ["a", "b", "c"] + def test_minus(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(3, "Hello", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements(self.data1, self.schema) + t2 = t_env.from_elements(self.data2, self.schema) - result = source1.minus(source2) + result = t1.minus(t2) actual = self.collect(result) expected = ['1,Hi,Hello'] self.assert_equals(actual, expected) def test_minus_all(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(3, "Hello", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements(self.data1, self.schema) + t2 = t_env.from_elements(self.data2, self.schema) - result = source1.minus_all(source2) + result = t1.minus_all(t2) actual = self.collect(result) expected = ['1,Hi,Hello', @@ -105,70 +80,39 @@ def test_minus_all(self): self.assert_equals(actual, expected) def test_union(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (3, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(2, "Hi", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements(self.data1, self.schema) + t2 = t_env.from_elements(self.data2, self.schema) - result = source1.union(source2) + result = t1.union(t2) actual = self.collect(result) expected = ['1,Hi,Hello', - '2,Hi,Hello', '3,Hello,Hello', '3,Hello,Python', '4,Hi,Flink'] self.assert_equals(actual, expected) def test_intersect(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (2, "Hi", "Hello")] - csv_source = 
self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(2, "Hi", "Hello"), (2, "Hi", "Hello"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements(self.data1, self.schema) + t2 = t_env.from_elements(self.data2, self.schema) - result = source1.intersect(source2) + result = t1.intersect(t2) actual = self.collect(result) - expected = ['2,Hi,Hello'] + expected = ['3,Hello,Hello'] self.assert_equals(actual, expected) def test_intersect_all(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello"), (2, "Hi", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - source_path2 = os.path.join(self.tempdir + '/streaming2.csv') - data2 = [(2, "Hi", "Hello"), (2, "Hi", "Hello"), (4, "Hi", "Flink")] - csv_source2 = self.prepare_csv_source(source_path2, data2, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source1", csv_source) - t_env.register_table_source("Source2", csv_source2) - source1 = t_env.scan("Source1") - source2 = t_env.scan("Source2") + t1 = t_env.from_elements(self.data1, self.schema) + t2 = t_env.from_elements(self.data2, self.schema) - result = source1.intersect_all(source2) + result = t1.intersect_all(t2) actual = self.collect(result) - expected = ['2,Hi,Hello', '2,Hi,Hello'] + expected = ['3,Hello,Hello'] self.assert_equals(actual, expected) diff --git a/flink-python/pyflink/table/tests/test_sort.py b/flink-python/pyflink/table/tests/test_sort.py index f0a5df539c148b..236eb71f503bd8 100644 --- a/flink-python/pyflink/table/tests/test_sort.py +++ b/flink-python/pyflink/table/tests/test_sort.py @@ -16,25 +16,17 @@ # limitations under the License. 
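On the batch side, the set-operation tests above share their inputs as class attributes and read results back with collect() instead of registering a sink. A condensed sketch (test name made up, data reused from the hunks above) is:

    # Batch sketch: shared class-level data plus collect(), no sink registration.
    from pyflink.testing.test_case_utils import PyFlinkBatchTableTestCase


    class BatchSetOpSketch(PyFlinkBatchTableTestCase):

        data1 = [(1, "Hi", "Hello"), (1, "Hi", "Hello"), (3, "Hello", "Hello")]
        data2 = [(3, "Hello", "Hello"), (3, "Hello", "Python"), (4, "Hi", "Flink")]
        schema = ["a", "b", "c"]

        def test_sketch(self):
            t1 = self.t_env.from_elements(self.data1, self.schema)
            t2 = self.t_env.from_elements(self.data2, self.schema)
            # minus follows EXCEPT semantics and de-duplicates, so the doubled
            # (1, Hi, Hello) row comes back once; minus_all keeps both copies.
            self.assert_equals(self.collect(t1.minus(t2)), ['1,Hi,Hello'])
            self.assert_equals(self.collect(t1.minus_all(t2)),
                               ['1,Hi,Hello', '1,Hi,Hello'])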
################################################################################ -import os - -from pyflink.table.types import DataTypes from pyflink.testing.test_case_utils import PyFlinkBatchTableTestCase class BatchTableSortTests(PyFlinkBatchTableTestCase): def test_order_by_offset_fetch(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b"] - field_types = [DataTypes.INT(), DataTypes.STRING()] - data = [(1, "Hello"), (2, "Hello"), (3, "Flink"), (4, "Python")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = t_env.from_elements([(1, "Hello"), (2, "Hello"), (3, "Flink"), (4, "Python")], + ["a", "b"]) - result = source.order_by("a.desc").offset(2).fetch(2).select("a, b") + result = t.order_by("a.desc").offset(2).fetch(2).select("a, b") actual = self.collect(result) expected = ['2,Hello', '1,Hello'] diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py index 4b2e7c3d0dc6bb..aa51f368e4c091 100644 --- a/flink-python/pyflink/table/tests/test_table_environment_api.py +++ b/flink-python/pyflink/table/tests/test_table_environment_api.py @@ -23,7 +23,7 @@ from pyflink.table.table_environment import TableEnvironment from pyflink.table.table_config import TableConfig from pyflink.table.table_sink import CsvTableSink -from pyflink.table.types import DataTypes +from pyflink.table.types import DataTypes, RowType from pyflink.testing import source_sink_utils from pyflink.testing.test_case_utils import PyFlinkBatchTableTestCase, PyFlinkStreamTableTestCase @@ -31,21 +31,15 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase): def test_register_scan(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hello", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] + field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] t_env.register_table_sink( "Results", field_names, field_types, source_sink_utils.TestAppendSink()) - result = t_env.scan("Source") - result.insert_into("Results") + t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])\ + .insert_into("Results") t_env.execute() actual = source_sink_utils.results() @@ -53,18 +47,14 @@ def test_register_scan(self): self.assert_equals(actual, expected) def test_register_table_source_sink(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - - t_env.register_table_source("Orders", csv_source) + field_names = ["a", "b", "c"] + field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] t_env.register_table_sink( "Sinks", field_names, field_types, source_sink_utils.TestAppendSink()) - t_env.scan("Orders").insert_into("Sinks") + + t_env.from_elements([(1, "Hi", "Hello")], ["a", "b", 
"c"]).insert_into("Sinks") t_env.execute() actual = source_sink_utils.results() @@ -72,18 +62,15 @@ def test_register_table_source_sink(self): self.assert_equals(actual, expected) def test_from_table_source(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [(1, "Hi", "Hello"), (2, "Hi", "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env + field_names = ["a", "b", "c"] + field_types = [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.STRING()] t_env.register_table_sink( "Sinks", field_names, field_types, source_sink_utils.TestAppendSink()) - source = t_env.from_table_source(csv_source) - source.insert_into("Sinks") + t_env.from_elements([(1, "Hi", "Hello"), (2, "Hi", "Hello")], ["a", "b", "c"])\ + .insert_into("Sinks") t_env.execute() actual = source_sink_utils.results() @@ -111,15 +98,13 @@ def test_list_tables(self): self.assert_equals(actual, expected) def test_explain(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.INT(), DataTypes.STRING(), DataTypes.STRING()] - data = [] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) + schema = RowType()\ + .add('a', DataTypes.INT())\ + .add('b', DataTypes.STRING())\ + .add('c', DataTypes.STRING()) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") - result = source.alias("a, b, c").select("1 + a, b, c") + t = t_env.from_elements([], schema) + result = t.select("1 + a, b, c") actual = t_env.explain(result) diff --git a/flink-python/pyflink/table/tests/test_window.py b/flink-python/pyflink/table/tests/test_window.py index b37f895a96eec0..5fc149cb3b2a9e 100644 --- a/flink-python/pyflink/table/tests/test_window.py +++ b/flink-python/pyflink/table/tests/test_window.py @@ -16,30 +16,22 @@ # limitations under the License. 
################################################################################ -import os - from py4j.protocol import Py4JJavaError from pyflink.table.window import Session, Slide, Tumble from pyflink.table import Over -from pyflink.table.types import DataTypes from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase, PyFlinkBatchTableTestCase class StreamTableWindowTests(PyFlinkStreamTableTestCase): def test_over_window(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()] - data = [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8, "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + data = [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8, "Hello")] + t = t_env.from_elements(data, ['a', 'b', 'c']) - result = source.over_window(Over.partition_by("c").order_by("a") - .preceding("2.rows").following("current_row").alias("w")) + result = t.over_window(Over.partition_by("c").order_by("a") + .preceding("2.rows").following("current_row").alias("w")) self.assertRaisesRegexp( Py4JJavaError, "Ordering must be defined on a time attribute", @@ -49,16 +41,10 @@ def test_over_window(self): class BatchTableWindowTests(PyFlinkBatchTableTestCase): def test_tumble_window(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()] - data = [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8, "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") - - result = source.window(Tumble.over("2.rows").on("a").alias("w"))\ + t = self.t_env.from_elements( + [(1, 1, "Hello"), (2, 2, "Hello"), (3, 4, "Hello"), (4, 8, "Hello")], + ["a", "b", "c"]) + result = t.window(Tumble.over("2.rows").on("a").alias("w"))\ .group_by("w, c").select("b.sum") actual = self.collect(result) @@ -66,16 +52,11 @@ def test_tumble_window(self): self.assert_equals(actual, expected) def test_slide_window(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()] - data = [(1000, 1, "Hello"), (2000, 2, "Hello"), (3000, 4, "Hello"), (4000, 8, "Hello")] - csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = self.t_env.from_elements( + [(1000, 1, "Hello"), (2000, 2, "Hello"), (3000, 4, "Hello"), (4000, 8, "Hello")], + ["a", "b", "c"]) - result = source.window(Slide.over("2.seconds").every("1.seconds").on("a").alias("w"))\ + result = t.window(Slide.over("2.seconds").every("1.seconds").on("a").alias("w"))\ .group_by("w, c").select("b.sum") actual = self.collect(result) @@ -83,16 +64,11 @@ def test_slide_window(self): self.assert_equals(actual, expected) def test_session_window(self): - source_path = os.path.join(self.tempdir + '/streaming.csv') - field_names = ["a", "b", "c"] - field_types = [DataTypes.BIGINT(), DataTypes.INT(), DataTypes.STRING()] - data = [(1000, 1, "Hello"), (2000, 2, "Hello"), (4000, 4, "Hello"), (5000, 8, "Hello")] - 
csv_source = self.prepare_csv_source(source_path, data, field_types, field_names) - t_env = self.t_env - t_env.register_table_source("Source", csv_source) - source = t_env.scan("Source") + t = self.t_env.from_elements( + [(1000, 1, "Hello"), (2000, 2, "Hello"), (4000, 4, "Hello"), (5000, 8, "Hello")], + ["a", "b", "c"]) - result = source.window(Session.with_gap("1.seconds").on("a").alias("w"))\ + result = t.window(Session.with_gap("1.seconds").on("a").alias("w"))\ .group_by("w, c").select("b.sum") actual = self.collect(result) diff --git a/flink-python/src/main/java/org/apache/flink/python/bridge/PythonBridgeUtils.java b/flink-python/src/main/java/org/apache/flink/python/bridge/PythonBridgeUtils.java new file mode 100644 index 00000000000000..1c5c6e86fe7007 --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/python/bridge/PythonBridgeUtils.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.python.bridge; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.python.bridge.pickle.ArrayConstructor; +import org.apache.flink.python.bridge.pickle.ByteArrayConstructor; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.Collector; + +import net.razorvine.pickle.Unpickler; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +/** + * Utility class that contains helper methods to create a DataStream/DataSet from + * a file which contains Python objects. + */ +public final class PythonBridgeUtils { + + /** + * Creates a DataStream from a file which contains serialized python objects. + */ + public static DataStream createDataStreamFromFile( + final StreamExecutionEnvironment streamExecutionEnvironment, + final String fileName, + final boolean batched) throws IOException { + return streamExecutionEnvironment + .fromCollection(readPythonObjects(fileName)) + .flatMap(new PythonFlatMapFunction(batched)) + .returns(Types.GENERIC(Object[].class)); + } + + /** + * Creates a DataSet from a file which contains serialized python objects. 
+ */ + public static DataSet createDataSetFromFile( + final ExecutionEnvironment executionEnvironment, + final String fileName, + final boolean batched) throws IOException { + return executionEnvironment + .fromCollection(readPythonObjects(fileName)) + .flatMap(new PythonFlatMapFunction(batched)) + .returns(Types.GENERIC(Object[].class)); + } + + private static List readPythonObjects(final String fileName) throws IOException { + List objs = new LinkedList<>(); + try (DataInputStream din = new DataInputStream(new FileInputStream(fileName))) { + try { + while (true) { + final int length = din.readInt(); + byte[] obj = new byte[length]; + din.readFully(obj); + objs.add(obj); + } + } catch (EOFException eof) { + // expected + } + } + return objs; + } + + private static final class PythonFlatMapFunction extends RichFlatMapFunction { + + private static final long serialVersionUID = 1L; + + private final boolean batched; + private transient Unpickler unpickle; + + PythonFlatMapFunction(boolean batched) { + this.batched = batched; + initialize(); + } + + @Override + public void open(Configuration parameters) { + this.unpickle = new Unpickler(); + } + + @Override + public void flatMap(byte[] value, Collector out) throws Exception { + Object obj = unpickle.loads(value); + if (batched) { + if (obj instanceof Object[]) { + for (int i = 0; i < ((Object[]) obj).length; i++) { + collect(out, ((Object[]) obj)[i]); + } + } else { + for (Object o : (ArrayList) obj) { + collect(out, o); + } + } + } else { + collect(out, obj); + } + } + + private void collect(Collector out, Object obj) { + if (obj.getClass().isArray()) { + out.collect((Object[]) obj); + } else { + out.collect(((ArrayList) obj).toArray(new Object[0])); + } + } + } + + private static boolean initialized = false; + private static void initialize() { + synchronized (PythonBridgeUtils.class) { + if (!initialized) { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()); + Unpickler.registerConstructor("__builtin__", "bytearray", new ByteArrayConstructor()); + Unpickler.registerConstructor("builtins", "bytearray", new ByteArrayConstructor()); + Unpickler.registerConstructor("__builtin__", "bytes", new ByteArrayConstructor()); + Unpickler.registerConstructor("_codecs", "encode", new ByteArrayConstructor()); + initialized = true; + } + } + } +} diff --git a/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ArrayConstructor.java b/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ArrayConstructor.java new file mode 100644 index 00000000000000..c2644b72bdb980 --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ArrayConstructor.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.python.bridge.pickle; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; + +/** + * Creates arrays of objects. Returns a primitive type array such as int[] if + * the objects are ints, etc. Returns an ArrayList if it needs to + * contain arbitrary objects (such as lists). + */ +public final class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { + + @Override + public Object construct(Object[] args) { + if (args.length == 2 && args[1] instanceof String) { + char typecode = ((String) args[0]).charAt(0); + // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly + byte[] data = ((String) args[1]).getBytes(StandardCharsets.ISO_8859_1); + if (typecode == 'c') { + // It seems like the pickle of pypy uses the similar protocol to Python 2.6, which uses + // a string for array data instead of list as Python 2.7, and handles an array of + // typecode 'c' as 1-byte character. + char[] result = new char[data.length]; + int i = 0; + while (i < data.length) { + result[i] = (char) data[i]; + i += 1; + } + return result; + } + } else if (args.length == 2 && args[0] == "l") { + // On Python 2, an array of typecode 'l' should be handled as long rather than int. + ArrayList values = (ArrayList) args[1]; + long[] result = new long[values.size()]; + int i = 0; + while (i < values.size()) { + result[i] = ((Number) values.get(i)).longValue(); + i += 1; + } + return result; + } + + return super.construct(args); + } +} diff --git a/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ByteArrayConstructor.java b/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ByteArrayConstructor.java new file mode 100644 index 00000000000000..ce8ebd69f633d0 --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/python/bridge/pickle/ByteArrayConstructor.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.python.bridge.pickle; + +import org.apache.commons.lang3.ArrayUtils; + +/** + * Creates byte arrays (byte[]). Deal with an empty byte array pickled by Python 3. 
+ */ +public final class ByteArrayConstructor extends net.razorvine.pickle.objects.ByteArrayConstructor { + + @Override + public Object construct(Object[] args) { + if (args.length == 0) { + return ArrayUtils.EMPTY_BYTE_ARRAY; + } else { + return super.construct(args); + } + } +} diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java b/flink-python/src/main/java/org/apache/flink/python/client/PythonDriver.java similarity index 96% rename from flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java rename to flink-python/src/main/java/org/apache/flink/python/client/PythonDriver.java index e43a24eec98ea0..66040a395d1024 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/python/PythonDriver.java +++ b/flink-python/src/main/java/org/apache/flink/python/client/PythonDriver.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.client.python; +package org.apache.flink.python.client; import org.apache.flink.core.fs.Path; @@ -59,11 +59,11 @@ public static void main(String[] args) { List commands = constructPythonCommands(filePathMap, parsedArgs); try { // prepare the exec environment of python progress. - PythonUtil.PythonEnvironment pythonEnv = PythonUtil.preparePythonEnvironment(filePathMap); + PythonEnvUtils.PythonEnvironment pythonEnv = PythonEnvUtils.preparePythonEnvironment(filePathMap); // set env variable PYFLINK_GATEWAY_PORT for connecting of python gateway in python progress. pythonEnv.systemEnv.put("PYFLINK_GATEWAY_PORT", String.valueOf(gatewayServer.getListeningPort())); // start the python process. - Process pythonProcess = PythonUtil.startPythonProcess(pythonEnv, commands); + Process pythonProcess = PythonEnvUtils.startPythonProcess(pythonEnv, commands); int exitCode = pythonProcess.waitFor(); if (exitCode != 0) { throw new RuntimeException("Python process exits with code: " + exitCode); diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java b/flink-python/src/main/java/org/apache/flink/python/client/PythonEnvUtils.java similarity index 96% rename from flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java rename to flink-python/src/main/java/org/apache/flink/python/client/PythonEnvUtils.java index 9fecd499966731..d733d5085a1ce1 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/python/PythonUtil.java +++ b/flink-python/src/main/java/org/apache/flink/python/client/PythonEnvUtils.java @@ -16,8 +16,9 @@ * limitations under the License. */ -package org.apache.flink.client.python; +package org.apache.flink.python.client; +import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; import org.apache.flink.util.FileUtils; @@ -41,10 +42,10 @@ /** * The util class help to prepare Python env and run the python process. 
*/ -public final class PythonUtil { - private static final Logger LOG = LoggerFactory.getLogger(PythonUtil.class); +public final class PythonEnvUtils { + private static final Logger LOG = LoggerFactory.getLogger(PythonEnvUtils.class); - private static final String FLINK_OPT_DIR = System.getenv("FLINK_OPT_DIR"); + private static final String FLINK_OPT_DIR = System.getenv(ConfigConstants.ENV_FLINK_OPT_DIR); private static final String FLINK_OPT_DIR_PYTHON = FLINK_OPT_DIR + File.separator + "python"; diff --git a/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java b/flink-python/src/main/java/org/apache/flink/python/client/PythonGatewayServer.java similarity index 98% rename from flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java rename to flink-python/src/main/java/org/apache/flink/python/client/PythonGatewayServer.java index 64f2ef1d382819..0d2c447b0e0333 100644 --- a/flink-clients/src/main/java/org/apache/flink/client/python/PythonGatewayServer.java +++ b/flink-python/src/main/java/org/apache/flink/python/client/PythonGatewayServer.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.client.python; +package org.apache.flink.python.client; import py4j.GatewayServer; diff --git a/flink-python/src/main/resources/META-INF/NOTICE b/flink-python/src/main/resources/META-INF/NOTICE new file mode 100644 index 00000000000000..d4eeef5a4639ad --- /dev/null +++ b/flink-python/src/main/resources/META-INF/NOTICE @@ -0,0 +1,15 @@ +flink-python +Copyright 2014-2019 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +This project bundles the following dependencies under the BSD license. +See bundled license files for details + +- net.sf.py4j:py4j:0.10.8.1 + +This project bundles the following dependencies under the MIT license. (https://opensource.org/licenses/MIT) +See bundled license files for details. + +- net.razorvine:pyrolite:4.13 diff --git a/flink-dist/src/main/resources/META-INF/licenses/LICENSE.py4j b/flink-python/src/main/resources/META-INF/licenses/LICENSE.py4j similarity index 100% rename from flink-dist/src/main/resources/META-INF/licenses/LICENSE.py4j rename to flink-python/src/main/resources/META-INF/licenses/LICENSE.py4j diff --git a/flink-python/src/main/resources/META-INF/licenses/LICENSE.pyrolite b/flink-python/src/main/resources/META-INF/licenses/LICENSE.pyrolite new file mode 100644 index 00000000000000..ad923a6ea4c9ca --- /dev/null +++ b/flink-python/src/main/resources/META-INF/licenses/LICENSE.pyrolite @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) by Irmen de Jong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java b/flink-python/src/test/java/org/apache/flink/python/client/PythonDriverTest.java similarity index 98% rename from flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java rename to flink-python/src/test/java/org/apache/flink/python/client/PythonDriverTest.java index 0b6f570e75a447..77955aca8a6695 100644 --- a/flink-clients/src/test/java/org/apache/flink/client/python/PythonDriverTest.java +++ b/flink-python/src/test/java/org/apache/flink/python/client/PythonDriverTest.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.client.python; +package org.apache.flink.python.client; import org.apache.flink.core.fs.Path; diff --git a/flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java b/flink-python/src/test/java/org/apache/flink/python/client/PythonEnvUtilsTest.java similarity index 93% rename from flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java rename to flink-python/src/test/java/org/apache/flink/python/client/PythonEnvUtilsTest.java index 4b14cede4e310c..22d1b35417544c 100644 --- a/flink-clients/src/test/java/org/apache/flink/client/python/PythonUtilTest.java +++ b/flink-python/src/test/java/org/apache/flink/python/client/PythonEnvUtilsTest.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.client.python; +package org.apache.flink.python.client; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; @@ -35,9 +35,9 @@ import java.util.UUID; /** - * Tests for the {@link PythonUtil}. + * Tests for the {@link PythonEnvUtils}. */ -public class PythonUtilTest { +public class PythonEnvUtilsTest { private Path sourceTmpDirPath; private Path targetTmpDirPath; private FileSystem sourceFs; @@ -70,7 +70,7 @@ public void prepareTestEnvironment() { @Test public void testStartPythonProcess() { - PythonUtil.PythonEnvironment pythonEnv = new PythonUtil.PythonEnvironment(); + PythonEnvUtils.PythonEnvironment pythonEnv = new PythonEnvUtils.PythonEnvironment(); pythonEnv.workingDirectory = targetTmpDirPath.toString(); pythonEnv.pythonPath = targetTmpDirPath.toString(); List commands = new ArrayList<>(); @@ -91,7 +91,7 @@ public void testStartPythonProcess() { Path result = new Path(targetTmpDirPath, "word_count_result.txt"); commands.add(pyFile.getName()); commands.add(result.getName()); - Process pythonProcess = PythonUtil.startPythonProcess(pythonEnv, commands); + Process pythonProcess = PythonEnvUtils.startPythonProcess(pythonEnv, commands); int exitCode = pythonProcess.waitFor(); if (exitCode != 0) { throw new RuntimeException("Python process exits with code: " + exitCode); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/python/PythonTableUtils.scala new file mode 100644 index 00000000000000..70a5111c07a26d --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/python/PythonTableUtils.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.python + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Time, Timestamp} +import java.util.function.BiConsumer + +import org.apache.flink.api.common.functions.MapFunction +import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation} +import org.apache.flink.api.java.DataSet +import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.{Table, Types} +import org.apache.flink.table.api.java.{BatchTableEnvironment, StreamTableEnvironment} +import org.apache.flink.types.Row + +object PythonTableUtils { + + /** + * Converts the given [[DataStream]] into a [[Table]]. + * + * The schema of the [[Table]] is derived from the specified schemaString. + * + * @param tableEnv The table environment. + * @param dataStream The [[DataStream]] to be converted. + * @param dataType The type information of the table. + * @return The converted [[Table]]. + */ + def fromDataStream( + tableEnv: StreamTableEnvironment, + dataStream: DataStream[Array[Object]], + dataType: TypeInformation[Row]): Table = { + val convertedDataStream = dataStream.map( + new MapFunction[Array[Object], Row] { + override def map(value: Array[Object]): Row = + convertTo(dataType).apply(value).asInstanceOf[Row] + }).returns(dataType.asInstanceOf[TypeInformation[Row]]) + + tableEnv.fromDataStream(convertedDataStream) + } + + /** + * Converts the given [[DataSet]] into a [[Table]]. + * + * The schema of the [[Table]] is derived from the specified schemaString. + * + * @param tableEnv The table environment. + * @param dataSet The [[DataSet]] to be converted. + * @param dataType The type information of the table. + * @return The converted [[Table]]. + */ + def fromDataSet( + tableEnv: BatchTableEnvironment, + dataSet: DataSet[Array[Object]], + dataType: TypeInformation[Row]): Table = { + val convertedDataSet = dataSet.map( + new MapFunction[Array[Object], Row] { + override def map(value: Array[Object]): Row = + convertTo(dataType).apply(value).asInstanceOf[Row] + }).returns(dataType.asInstanceOf[TypeInformation[Row]]) + + tableEnv.fromDataSet(convertedDataSet) + } + + /** + * Creates a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected because Python doesn't enforce the type. 
+ */ + private def convertTo(dataType: TypeInformation[_]): Any => Any = dataType match { + case Types.BOOLEAN => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } + + case Types.BYTE => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } + + case Types.SHORT => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } + + case Types.INT => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } + + case Types.LONG => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } + + case Types.FLOAT => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } + + case Types.DOUBLE => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } + + case Types.DECIMAL => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => c + } + + case Types.SQL_DATE => (obj: Any) => nullSafeConvert(obj) { + case c: Int => new Date(c * 86400000) + } + + case Types.SQL_TIME => (obj: Any) => nullSafeConvert(obj) { + case c: Long => new Time(c / 1000) + case c: Int => new Time(c.toLong / 1000) + } + + case Types.SQL_TIMESTAMP => (obj: Any) => nullSafeConvert(obj) { + case c: Long => new Timestamp(c / 1000) + case c: Int => new Timestamp(c.toLong / 1000) + } + + case Types.STRING => (obj: Any) => nullSafeConvert(obj) { + case _ => obj.toString + } + + case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO => + (obj: Any) => + nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } + + case _: PrimitiveArrayTypeInfo[_] | + _: BasicArrayTypeInfo[_, _] | + _: ObjectArrayTypeInfo[_, _] => + val elementType = dataType match { + case p: PrimitiveArrayTypeInfo[_] => + p.getComponentType + case b: BasicArrayTypeInfo[_, _] => + b.getComponentInfo + case o: ObjectArrayTypeInfo[_, _] => + o.getComponentInfo + } + val elementFromJava = convertTo(elementType) + + (obj: Any) => nullSafeConvert(obj) { + case c: java.util.List[_] => + createArray(elementType, + c.size(), + i => elementFromJava(c.get(i))) + case c if c.getClass.isArray => + createArray(elementType, + c.asInstanceOf[Array[_]].length, + i => elementFromJava(c.asInstanceOf[Array[_]](i))) + } + + case m: MapTypeInfo[_, _] => + val keyFromJava = convertTo(m.getKeyTypeInfo) + val valueFromJava = convertTo(m.getValueTypeInfo) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + val map = new java.util.HashMap[Any, Any] + javaMap.forEach(new BiConsumer[Any, Any] { + override def accept(k: Any, v: Any): Unit = + map.put(keyFromJava(k), valueFromJava(v)) + }) + map + } + + case rowType: RowTypeInfo => + val fieldsFromJava = rowType.getFieldTypes.map(f => convertTo(f)) + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val r = c.asInstanceOf[Array[_]] + if (r.length != rowType.getFieldTypes.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${rowType.getFieldTypes.length} fields are required while ${r.length} " + + s"values are provided." 
+ ) + } + + val row = new Row(r.length) + var i = 0 + while (i < r.length) { + row.setField(i, fieldsFromJava(i)(r(i))) + i += 1 + } + row + } + + // UserDefinedType + case _ => (obj: Any) => obj + } + + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + _: Any => null + }) + } + } + + private def createArray( + elementType: TypeInformation[_], + length: Int, + getElement: Int => Any): Array[_] = { + elementType match { + case BasicTypeInfo.BOOLEAN_TYPE_INFO => + val array = new Array[Boolean](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Boolean] + } + array + + case BasicTypeInfo.BYTE_TYPE_INFO => + val array = new Array[Byte](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Byte] + } + array + + case BasicTypeInfo.SHORT_TYPE_INFO => + val array = new Array[Short](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Short] + } + array + + case BasicTypeInfo.INT_TYPE_INFO => + val array = new Array[Int](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Int] + } + array + + case BasicTypeInfo.LONG_TYPE_INFO => + val array = new Array[Long](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Long] + } + array + + case BasicTypeInfo.FLOAT_TYPE_INFO => + val array = new Array[Float](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Float] + } + array + + case BasicTypeInfo.DOUBLE_TYPE_INFO => + val array = new Array[Double](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Double] + } + array + + case BasicTypeInfo.STRING_TYPE_INFO => + val array = new Array[java.lang.String](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[java.lang.String] + } + array + + case _ => + val array = new Array[Object](length) + for (i <- 0 until length) { + array(i) = getElement(i).asInstanceOf[Object] + } + array + } + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/types/python/ExamplePointUserDefinedType.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/types/python/ExamplePointUserDefinedType.java new file mode 100644 index 00000000000000..cd49dbd238dc5a --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/types/python/ExamplePointUserDefinedType.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.types.python; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; + +import java.util.List; + +/** + * Type information for Python class ExamplePoint. + */ +public final class ExamplePointUserDefinedType extends TypeInformation> { + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class> getTypeClass() { + return (Class>) (Class) List.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + public TypeSerializer> createSerializer(ExecutionConfig config) { + return new ListSerializer<>(DoubleSerializer.INSTANCE); + } + + @Override + public String toString() { + return getClass().getCanonicalName(); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof ExamplePointUserDefinedType; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof ExamplePointUserDefinedType; + } +} diff --git a/licenses-binary/LICENSE.py4j b/licenses-binary/LICENSE.py4j new file mode 100644 index 00000000000000..0f45e3e464c1ea --- /dev/null +++ b/licenses-binary/LICENSE.py4j @@ -0,0 +1,26 @@ +Copyright (c) 2009-2018, Barthelemy Dagenais and individual contributors. All +rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
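The converters built by convertTo in the PythonTableUtils.scala file added above deliberately degrade to null rather than fail when the Python side hands over a value of an unexpected type. A minimal, self-contained Scala sketch of that null-safe pattern follows; the object and value names are illustrative only and are not part of this patch:

object NullSafeConversionSketch {
  // Mirror of the nullSafeConvert helper: apply the partial function when it
  // matches the input, otherwise fall back to null instead of throwing.
  def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any =
    if (input == null) null else f.applyOrElse(input, (_: Any) => null)

  // A converter in the style of the Types.INT branch of convertTo.
  val toInt: Any => Any = (obj: Any) => nullSafeConvert(obj) {
    case c: Byte  => c.toInt
    case c: Short => c.toInt
    case c: Int   => c
    case c: Long  => c.toInt
  }

  def main(args: Array[String]): Unit = {
    println(toInt(42L))    // prints 42
    println(toInt("oops")) // prints null: the unexpected type is tolerated
  }
}

The same applyOrElse fallback is what lets a whole Row conversion proceed even when an individual field arrives with an unexpected type: that field simply becomes null.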
diff --git a/licenses-binary/LICENSE.pyrolite b/licenses-binary/LICENSE.pyrolite new file mode 100644 index 00000000000000..ad923a6ea4c9ca --- /dev/null +++ b/licenses-binary/LICENSE.pyrolite @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) by Irmen de Jong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE.py4j b/licenses/LICENSE.py4j new file mode 100644 index 00000000000000..0f45e3e464c1ea --- /dev/null +++ b/licenses/LICENSE.py4j @@ -0,0 +1,26 @@ +Copyright (c) 2009-2018, Barthelemy Dagenais and individual contributors. All +rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tools/travis_controller.sh b/tools/travis_controller.sh index f19fd8a2dbf95b..c13d72680b2ec0 100755 --- a/tools/travis_controller.sh +++ b/tools/travis_controller.sh @@ -161,6 +161,7 @@ if [ $STAGE == "$STAGE_COMPILE" ]; then find "$CACHE_FLINK_DIR" -maxdepth 8 -type f -name '*.jar' \ ! -path "$CACHE_FLINK_DIR/flink-dist/target/flink-*-bin/flink-*/lib/flink-dist*.jar" \ ! -path "$CACHE_FLINK_DIR/flink-dist/target/flink-*-bin/flink-*/opt/flink-table*.jar" \ + ! 
-path "$CACHE_FLINK_DIR/flink-dist/target/flink-*-bin/flink-*/opt/flink-python*java-binding.jar" \ ! -path "$CACHE_FLINK_DIR/flink-connectors/flink-connector-elasticsearch-base/target/flink-*.jar" \ ! -path "$CACHE_FLINK_DIR/flink-connectors/flink-connector-kafka-base/target/flink-*.jar" \ ! -path "$CACHE_FLINK_DIR/flink-table/flink-table-planner/target/flink-table-planner*tests.jar" | xargs rm -rf From e521191b6bb224e13d10ecd4c31eea1807d636d9 Mon Sep 17 00:00:00 2001 From: Gary Yao Date: Sat, 1 Jun 2019 21:52:23 +0200 Subject: [PATCH 53/92] [FLINK-12689][dist] Add flink-azure-fs-hadoop dependency to flink-dist This closes #8590. --- flink-dist/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/flink-dist/pom.xml b/flink-dist/pom.xml index 343cc58d7a7ba0..59c2b77ebaed9f 100644 --- a/flink-dist/pom.xml +++ b/flink-dist/pom.xml @@ -327,6 +327,13 @@ under the License. provided + + org.apache.flink + flink-azure-fs-hadoop + ${project.version} + provided + + org.apache.flink flink-s3-fs-hadoop From 86758504ead2b3ad0ccfc81f5977d2a3f8cab913 Mon Sep 17 00:00:00 2001 From: hehuiyuan <471627698@qq.com> Date: Mon, 3 Jun 2019 17:39:36 +0800 Subject: [PATCH 54/92] [hotfix][docs] Remove space in variable name --- docs/dev/stream/state/state.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/stream/state/state.md b/docs/dev/stream/state/state.md index a128296425e4b7..393a702c11b27c 100644 --- a/docs/dev/stream/state/state.md +++ b/docs/dev/stream/state/state.md @@ -424,7 +424,7 @@ import org.apache.flink.api.common.state.StateTtlConfig;
{% highlight scala %} import org.apache.flink.api.common.state.StateTtlConfig -val ttlConfig = StateTtlCon fig +val ttlConfig = StateTtlConfig .newBuilder(Time.seconds(1)) .cleanupIncrementally(10, true) .build From c6eb3f5eb4d32fad4fb973a7e2da0bbc6a7dfc4a Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Fri, 31 May 2019 12:36:54 +0800 Subject: [PATCH 55/92] [FLINK-12687][table-runtime-blink] ByteHashSet is always in dense mode This closes #8579 --- .../runtime/util/collections/ByteHashSet.java | 122 ++---------------- 1 file changed, 11 insertions(+), 111 deletions(-) diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java index bc84943ff1cb4b..9364f280e5bf65 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/util/collections/ByteHashSet.java @@ -17,135 +17,35 @@ package org.apache.flink.table.runtime.util.collections; -import org.apache.flink.table.util.MurmurHashUtil; - /** * Byte hash set. */ -public class ByteHashSet extends OptimizableHashSet { - - private byte[] key; +public class ByteHashSet { - private byte min = Byte.MAX_VALUE; - private byte max = Byte.MIN_VALUE; + protected boolean containsNull; - public ByteHashSet(final int expected, final float f) { - super(expected, f); - this.key = new byte[this.n + 1]; - } - - public ByteHashSet(final int expected) { - this(expected, DEFAULT_LOAD_FACTOR); - } + protected boolean[] used; public ByteHashSet() { - this(DEFAULT_INITIAL_SIZE, DEFAULT_LOAD_FACTOR); + used = new boolean[Byte.MAX_VALUE - Byte.MIN_VALUE + 1]; } public boolean add(final byte k) { - if (k == 0) { - if (this.containsZero) { - return false; - } - - this.containsZero = true; - } else { - byte[] key = this.key; - int pos; - byte curr; - if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) != 0) { - if (curr == k) { - return false; - } - - while ((curr = key[pos = pos + 1 & this.mask]) != 0) { - if (curr == k) { - return false; - } - } - } - - key[pos] = k; - } - - if (this.size++ >= this.maxFill) { - this.rehash(OptimizableHashSet.arraySize(this.size + 1, this.f)); - } + return !used[k - Byte.MIN_VALUE] && (used[k - Byte.MIN_VALUE] = true); + } - if (k < min) { - min = k; - } - if (k > max) { - max = k; - } - return true; + public void addNull() { + this.containsNull = true; } public boolean contains(final byte k) { - if (isDense) { - return k >= min && k <= max && used[k - min]; - } else { - if (k == 0) { - return this.containsZero; - } else { - byte[] key = this.key; - byte curr; - int pos; - if ((curr = key[pos = MurmurHashUtil.fmix(k) & this.mask]) == 0) { - return false; - } else if (k == curr) { - return true; - } else { - while ((curr = key[pos = pos + 1 & this.mask]) != 0) { - if (k == curr) { - return true; - } - } - - return false; - } - } - } + return used[k - Byte.MIN_VALUE]; } - private void rehash(final int newN) { - byte[] key = this.key; - int mask = newN - 1; - byte[] newKey = new byte[newN + 1]; - int i = this.n; - - int pos; - for (int j = this.realSize(); j-- != 0; newKey[pos] = key[i]) { - do { - --i; - } while(key[i] == 0); - - if (newKey[pos = MurmurHashUtil.fmix(key[i]) & mask] != 0) { - while (newKey[pos = pos + 1 & mask] != 0) {} - } - } - - this.n = newN; - this.mask = mask; - 
this.maxFill = OptimizableHashSet.maxFill(this.n, this.f); - this.key = newKey; + public boolean containsNull() { + return containsNull; } - @Override public void optimize() { - int range = max - min; - if (range >= 0 && (range < key.length || range < OptimizableHashSet.DENSE_THRESHOLD)) { - this.used = new boolean[max - min + 1]; - for (byte v : key) { - if (v != 0) { - used[v - min] = true; - } - } - if (containsZero) { - used[-min] = true; - } - isDense = true; - key = null; - } } } From f8b78f1c390c5cdcae1d3b36739c7e693da45fe6 Mon Sep 17 00:00:00 2001 From: Bo WANG Date: Thu, 9 May 2019 11:05:30 +0800 Subject: [PATCH 56/92] [FLINK-12229][runtime] Add LazyFromSourcesSchedulingStrategy The LazyFromSourcesSchedulingStrategy encapsulates Flink's basic batch scheduling strategy. It starts scheduling the source vertices and schedules consumers as soon as their input result partitions become consumable. For pipelined result partitions this is the case once a record has been produced whereas for blocking result partitions all producers need to finish. This closes #8309. --- .../DefaultSchedulingExecutionVertex.java | 12 +- ...utionGraphToSchedulingTopologyAdapter.java | 3 +- .../InputDependencyConstraintChecker.java | 180 +++++++++ .../LazyFromSourcesSchedulingStrategy.java | 180 +++++++++ .../strategy/SchedulingExecutionVertex.java | 8 + .../DefaultSchedulingExecutionVertexTest.java | 7 +- .../DefaultSchedulingResultPartitionTest.java | 4 +- ...nGraphToSchedulingTopologyAdapterTest.java | 1 + .../strategy/EagerSchedulingStrategyTest.java | 9 +- .../InputDependencyConstraintCheckerTest.java | 289 +++++++++++++ ...LazyFromSourcesSchedulingStrategyTest.java | 382 ++++++++++++++++++ .../scheduler/strategy/StrategyTestUtil.java | 38 ++ .../strategy/TestingSchedulerOperations.java | 6 +- .../TestingSchedulingExecutionVertex.java | 45 ++- .../TestingSchedulingResultPartition.java | 120 ++++++ .../strategy/TestingSchedulingTopology.java | 192 ++++++++- 16 files changed, 1458 insertions(+), 18 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintChecker.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategy.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintCheckerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategyTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/StrategyTestUtil.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java index 4b13d70c97d5b9..92d80474654399 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertex.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler.adapter; +import org.apache.flink.api.common.InputDependencyConstraint; import org.apache.flink.runtime.execution.ExecutionState; import 
org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; @@ -44,14 +45,18 @@ class DefaultSchedulingExecutionVertex implements SchedulingExecutionVertex { private final Supplier stateSupplier; + private final InputDependencyConstraint inputDependencyConstraint; + DefaultSchedulingExecutionVertex( ExecutionVertexID executionVertexId, List producedPartitions, - Supplier stateSupplier) { + Supplier stateSupplier, + InputDependencyConstraint constraint) { this.executionVertexId = checkNotNull(executionVertexId); this.consumedPartitions = new ArrayList<>(); this.stateSupplier = checkNotNull(stateSupplier); this.producedPartitions = checkNotNull(producedPartitions); + this.inputDependencyConstraint = checkNotNull(constraint); } @Override @@ -74,6 +79,11 @@ public Collection getProducedResultPartitions() { return Collections.unmodifiableCollection(producedPartitions); } + @Override + public InputDependencyConstraint getInputDependencyConstraint() { + return inputDependencyConstraint; + } + void addConsumedPartition(X partition) { consumedPartitions.add(partition); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java index abf94697fbeee8..9b377d82b18eb5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapter.java @@ -109,7 +109,8 @@ private static DefaultSchedulingExecutionVertex generateSchedulingExecutionVerte DefaultSchedulingExecutionVertex schedulingVertex = new DefaultSchedulingExecutionVertex( new ExecutionVertexID(vertex.getJobvertexId(), vertex.getParallelSubtaskIndex()), producedPartitions, - new ExecutionStateSupplier(vertex)); + new ExecutionStateSupplier(vertex), + vertex.getInputDependencyConstraint()); producedPartitions.forEach(partition -> partition.setProducer(schedulingVertex)); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintChecker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintChecker.java new file mode 100644 index 00000000000000..ec0bf6fd172e7c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintChecker.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.flink.api.common.InputDependencyConstraint.ALL; +import static org.apache.flink.api.common.InputDependencyConstraint.ANY; +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.PRODUCING; + +/** + * A wrapper class for {@link InputDependencyConstraint} checker. + */ +public class InputDependencyConstraintChecker { + private final SchedulingIntermediateDataSetManager intermediateDataSetManager = + new SchedulingIntermediateDataSetManager(); + + public boolean check(final SchedulingExecutionVertex schedulingExecutionVertex) { + final InputDependencyConstraint inputConstraint = schedulingExecutionVertex.getInputDependencyConstraint(); + if (schedulingExecutionVertex.getConsumedResultPartitions().isEmpty() || ALL.equals(inputConstraint)) { + return checkAll(schedulingExecutionVertex); + } else if (ANY.equals(inputConstraint)) { + return checkAny(schedulingExecutionVertex); + } else { + throw new IllegalArgumentException(); + } + } + + List markSchedulingResultPartitionFinished(SchedulingResultPartition srp) { + return intermediateDataSetManager.markSchedulingResultPartitionFinished(srp); + } + + void resetSchedulingResultPartition(SchedulingResultPartition srp) { + intermediateDataSetManager.resetSchedulingResultPartition(srp); + } + + void addSchedulingResultPartition(SchedulingResultPartition srp) { + intermediateDataSetManager.addSchedulingResultPartition(srp); + } + + private boolean checkAll(final SchedulingExecutionVertex schedulingExecutionVertex) { + return schedulingExecutionVertex.getConsumedResultPartitions() + .stream() + .allMatch(this::partitionConsumable); + } + + private boolean checkAny(final SchedulingExecutionVertex schedulingExecutionVertex) { + return schedulingExecutionVertex.getConsumedResultPartitions() + .stream() + .anyMatch(this::partitionConsumable); + } + + private boolean partitionConsumable(SchedulingResultPartition partition) { + if (BLOCKING.equals(partition.getPartitionType())) { + return intermediateDataSetManager.allPartitionsFinished(partition); + } else { + SchedulingResultPartition.ResultPartitionState state = partition.getState(); + return PRODUCING.equals(state) || DONE.equals(state); + } + } + + private static class SchedulingIntermediateDataSetManager { + + private final Map intermediateDataSets = new HashMap<>(); + + List markSchedulingResultPartitionFinished(SchedulingResultPartition srp) { + SchedulingIntermediateDataSet intermediateDataSet = getSchedulingIntermediateDataSet(srp.getResultId()); + if (intermediateDataSet.markPartitionFinished(srp.getId())) { + return intermediateDataSet.getSchedulingResultPartitions(); + } + return Collections.emptyList(); + } + + void resetSchedulingResultPartition(SchedulingResultPartition srp) { + SchedulingIntermediateDataSet sid = 
getSchedulingIntermediateDataSet(srp.getResultId()); + sid.resetPartition(srp.getId()); + } + + void addSchedulingResultPartition(SchedulingResultPartition srp) { + SchedulingIntermediateDataSet sid = getOrCreateSchedulingIntermediateDataSetIfAbsent(srp.getResultId()); + sid.addSchedulingResultPartition(srp); + } + + boolean allPartitionsFinished(SchedulingResultPartition srp) { + SchedulingIntermediateDataSet sid = getSchedulingIntermediateDataSet(srp.getResultId()); + return sid.allPartitionsFinished(); + } + + private SchedulingIntermediateDataSet getSchedulingIntermediateDataSet( + final IntermediateDataSetID intermediateDataSetId) { + return getSchedulingIntermediateDataSetInternal(intermediateDataSetId, false); + } + + private SchedulingIntermediateDataSet getOrCreateSchedulingIntermediateDataSetIfAbsent( + final IntermediateDataSetID intermediateDataSetId) { + return getSchedulingIntermediateDataSetInternal(intermediateDataSetId, true); + } + + private SchedulingIntermediateDataSet getSchedulingIntermediateDataSetInternal( + final IntermediateDataSetID intermediateDataSetId, + boolean createIfAbsent) { + + return intermediateDataSets.computeIfAbsent( + intermediateDataSetId, + (key) -> { + if (createIfAbsent) { + return new SchedulingIntermediateDataSet(); + } else { + throw new IllegalArgumentException("can not find data set for " + intermediateDataSetId); + } + }); + } + } + + /** + * Representation of {@link IntermediateDataSet}. + */ + private static class SchedulingIntermediateDataSet { + + private final List partitions; + + private final Set producingPartitionIds; + + SchedulingIntermediateDataSet() { + partitions = new ArrayList<>(); + producingPartitionIds = new HashSet<>(); + } + + boolean markPartitionFinished(IntermediateResultPartitionID partitionId) { + producingPartitionIds.remove(partitionId); + return producingPartitionIds.isEmpty(); + } + + void resetPartition(IntermediateResultPartitionID partitionId) { + producingPartitionIds.add(partitionId); + } + + boolean allPartitionsFinished() { + return producingPartitionIds.isEmpty(); + } + + void addSchedulingResultPartition(SchedulingResultPartition partition) { + partitions.add(partition); + producingPartitionIds.add(partition.getId()); + } + + List getSchedulingResultPartitions() { + return Collections.unmodifiableList(partitions); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategy.java new file mode 100644 index 00000000000000..b5a09d4529f423 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategy.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.scheduler.DeploymentOption; +import org.apache.flink.runtime.scheduler.ExecutionVertexDeploymentOption; +import org.apache.flink.runtime.scheduler.SchedulerOperations; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.apache.flink.runtime.execution.ExecutionState.CREATED; +import static org.apache.flink.runtime.execution.ExecutionState.FINISHED; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * {@link SchedulingStrategy} instance for batch job which schedule vertices when input data are ready. + */ +public class LazyFromSourcesSchedulingStrategy implements SchedulingStrategy { + + private static final Predicate IS_IN_CREATED_EXECUTION_STATE = schedulingExecutionVertex -> CREATED == schedulingExecutionVertex.getState(); + + private final SchedulerOperations schedulerOperations; + + private final SchedulingTopology schedulingTopology; + + private final Map deploymentOptions; + + private final InputDependencyConstraintChecker inputConstraintChecker; + + public LazyFromSourcesSchedulingStrategy( + SchedulerOperations schedulerOperations, + SchedulingTopology schedulingTopology) { + + this.schedulerOperations = checkNotNull(schedulerOperations); + this.schedulingTopology = checkNotNull(schedulingTopology); + this.deploymentOptions = new HashMap<>(); + this.inputConstraintChecker = new InputDependencyConstraintChecker(); + } + + @Override + public void startScheduling() { + final DeploymentOption updateOption = new DeploymentOption(true); + final DeploymentOption nonUpdateOption = new DeploymentOption(false); + + for (SchedulingExecutionVertex schedulingVertex : schedulingTopology.getVertices()) { + DeploymentOption option = nonUpdateOption; + for (SchedulingResultPartition srp : schedulingVertex.getProducedResultPartitions()) { + if (srp.getPartitionType().isPipelined()) { + option = updateOption; + } + inputConstraintChecker.addSchedulingResultPartition(srp); + } + deploymentOptions.put(schedulingVertex.getId(), option); + } + + allocateSlotsAndDeployExecutionVertexIds(getAllVerticesFromTopology()); + } + + @Override + public void restartTasks(Set verticesToRestart) { + // increase counter of the dataset first + verticesToRestart + .stream() + .map(this::getSchedulingVertex) + .flatMap(vertex -> vertex.getProducedResultPartitions().stream()) + .forEach(inputConstraintChecker::resetSchedulingResultPartition); + + allocateSlotsAndDeployExecutionVertexIds(verticesToRestart); + } + + @Override + public void onExecutionStateChange(ExecutionVertexID executionVertexId, ExecutionState executionState) { + if (!FINISHED.equals(executionState)) { + return; + } + + final Set verticesToSchedule = getSchedulingVertex(executionVertexId) + .getProducedResultPartitions() + .stream() + .flatMap(partition -> inputConstraintChecker.markSchedulingResultPartitionFinished(partition).stream()) + .flatMap(partition -> partition.getConsumers().stream()) + .collect(Collectors.toSet()); + + 
allocateSlotsAndDeployExecutionVertices(verticesToSchedule); + } + + @Override + public void onPartitionConsumable(ExecutionVertexID executionVertexId, ResultPartitionID resultPartitionId) { + final SchedulingResultPartition resultPartition = schedulingTopology + .getResultPartition(resultPartitionId.getPartitionId()) + .orElseThrow(() -> new IllegalStateException("can not find scheduling result partition for " + + resultPartitionId)); + + if (!resultPartition.getPartitionType().isPipelined()) { + return; + } + + final SchedulingExecutionVertex producerVertex = getSchedulingVertex(executionVertexId); + if (!producerVertex.getProducedResultPartitions().contains(resultPartition)) { + throw new IllegalStateException("partition " + resultPartitionId + + " is not the produced partition of " + executionVertexId); + } + + allocateSlotsAndDeployExecutionVertices(resultPartition.getConsumers()); + } + + private SchedulingExecutionVertex getSchedulingVertex(final ExecutionVertexID executionVertexId) { + return schedulingTopology.getVertex(executionVertexId) + .orElseThrow(() -> new IllegalStateException("can not find scheduling vertex for " + executionVertexId)); + } + + private void allocateSlotsAndDeployExecutionVertexIds(Set verticesToSchedule) { + allocateSlotsAndDeployExecutionVertices( + verticesToSchedule + .stream() + .map(this::getSchedulingVertex) + .collect(Collectors.toList())); + } + + private void allocateSlotsAndDeployExecutionVertices(final Collection schedulingExecutionVertices) { + schedulerOperations.allocateSlotsAndDeploy( + schedulingExecutionVertices + .stream() + .filter(isInputConstraintSatisfied().and(IS_IN_CREATED_EXECUTION_STATE)) + .map(SchedulingExecutionVertex::getId) + .map(executionVertexID -> new ExecutionVertexDeploymentOption( + executionVertexID, + deploymentOptions.get(executionVertexID))) + .collect(Collectors.toSet())); + } + + private Predicate isInputConstraintSatisfied() { + return inputConstraintChecker::check; + } + + private Set getAllVerticesFromTopology() { + return StreamSupport + .stream(schedulingTopology.getVertices().spliterator(), false) + .map(SchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + } + + /** + * The factory for creating {@link LazyFromSourcesSchedulingStrategy}. + */ + public static class Factory implements SchedulingStrategyFactory { + @Override + public SchedulingStrategy createInstance( + SchedulerOperations schedulerOperations, + SchedulingTopology schedulingTopology, + JobGraph jobGraph) { + return new LazyFromSourcesSchedulingStrategy(schedulerOperations, schedulingTopology); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingExecutionVertex.java index b9b227118fc924..b4e1c2f06ec69d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingExecutionVertex.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.scheduler.strategy; +import org.apache.flink.api.common.InputDependencyConstraint; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.ExecutionVertex; @@ -55,4 +56,11 @@ public interface SchedulingExecutionVertex { * @return collection of output edges */ Collection getProducedResultPartitions(); + + /** + * Get {@link InputDependencyConstraint}. 
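+ * With {@link InputDependencyConstraint#ANY} the vertex becomes schedulable as soon as any of its
+ * consumed results is consumable; with {@link InputDependencyConstraint#ALL} all of them must be.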
+ * + * @return input dependency constraint + */ + InputDependencyConstraint getInputDependencyConstraint(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java index e9af7c6980c3ab..639548cec4fb37 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingExecutionVertexTest.java @@ -32,6 +32,7 @@ import java.util.Collections; import java.util.function.Supplier; +import static org.apache.flink.api.common.InputDependencyConstraint.ANY; import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; import static org.junit.Assert.assertEquals; @@ -60,12 +61,14 @@ public void setUp() throws Exception { producerVertex = new DefaultSchedulingExecutionVertex( new ExecutionVertexID(new JobVertexID(), 0), Collections.singletonList(schedulingResultPartition), - stateSupplier); + stateSupplier, + ANY); schedulingResultPartition.setProducer(producerVertex); consumerVertex = new DefaultSchedulingExecutionVertex( new ExecutionVertexID(new JobVertexID(), 0), Collections.emptyList(), - stateSupplier); + stateSupplier, + ANY); consumerVertex.addConsumedPartition(schedulingResultPartition); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java index d114b2ea28f07b..93ab29f961cbce 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultSchedulingResultPartitionTest.java @@ -32,6 +32,7 @@ import java.util.Collections; import java.util.function.Supplier; +import static org.apache.flink.api.common.InputDependencyConstraint.ANY; import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.EMPTY; @@ -60,7 +61,8 @@ public void setUp() { DefaultSchedulingExecutionVertex producerVertex = new DefaultSchedulingExecutionVertex( new ExecutionVertexID(new JobVertexID(), 0), Collections.singletonList(resultPartition), - stateProvider); + stateProvider, + ANY); resultPartition.setProducer(producerVertex); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java index 0861f13d64c23a..b20a2f4e5f855b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/ExecutionGraphToSchedulingTopologyAdapterTest.java @@ -183,5 +183,6 @@ private static void assertVertexEquals( assertEquals( new ExecutionVertexID(originalVertex.getJobvertexId(), originalVertex.getParallelSubtaskIndex()), adaptedVertex.getId()); + 
assertEquals(originalVertex.getInputDependencyConstraint(), adaptedVertex.getInputDependencyConstraint()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/EagerSchedulingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/EagerSchedulingStrategyTest.java index 1808bd8954f110..364267c52efc25 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/EagerSchedulingStrategyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/EagerSchedulingStrategyTest.java @@ -29,8 +29,8 @@ import java.util.Collection; import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; +import static org.apache.flink.runtime.scheduler.strategy.StrategyTestUtil.getExecutionVertexIdsFromDeployOptions; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; @@ -111,11 +111,4 @@ public void testRestartTasks() { Collection scheduledVertices2 = testingSchedulerOperations.getScheduledVertices().get(1); assertThat(getExecutionVertexIdsFromDeployOptions(scheduledVertices2), containsInAnyOrder(verticesToRestart2.toArray())); } - - private static Collection getExecutionVertexIdsFromDeployOptions( - Collection deploymentOptions) { - return deploymentOptions.stream() - .map(ExecutionVertexDeploymentOption::getExecutionVertexId) - .collect(Collectors.toList()); - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintCheckerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintCheckerTest.java new file mode 100644 index 00000000000000..a45961f27ad2f6 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/InputDependencyConstraintCheckerTest.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.util.TestLogger; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.api.common.InputDependencyConstraint.ALL; +import static org.apache.flink.api.common.InputDependencyConstraint.ANY; +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING; +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.PIPELINED; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.EMPTY; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.PRODUCING; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@link InputDependencyConstraintChecker}. + */ +public class InputDependencyConstraintCheckerTest extends TestLogger { + + @Test + public void testCheckInputVertex() { + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex().finish(); + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(Collections.emptyList()); + + assertTrue(inputChecker.check(vertex)); + } + + @Test + public void testCheckEmptyPipelinedInput() { + final List partitions = addResultPartition() + .withPartitionType(PIPELINED) + .withPartitionState(EMPTY) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + assertFalse(inputChecker.check(vertex)); + } + + @Test + public void testCheckProducingPipelinedInput() { + final List partitions = addResultPartition() + .withPartitionType(PIPELINED) + .withPartitionState(PRODUCING) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + assertTrue(inputChecker.check(vertex)); + } + + @Test + public void testCheckDoneBlockingInput() { + final List partitions = addResultPartition() + .withPartitionCntPerDataSet(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + for (TestingSchedulingResultPartition srp : partitions) { + inputChecker.markSchedulingResultPartitionFinished(srp); + } + + assertTrue(inputChecker.check(vertex)); + } + + @Test + public void testCheckPartialDoneBlockingInput() { + final List partitions = addResultPartition() + .withPartitionCntPerDataSet(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = 
createInputDependencyConstraintChecker(partitions); + + inputChecker.markSchedulingResultPartitionFinished(partitions.get(0)); + + assertFalse(inputChecker.check(vertex)); + } + + @Test + public void testCheckResetBlockingInput() { + final List partitions = addResultPartition() + .withPartitionCntPerDataSet(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + for (TestingSchedulingResultPartition srp : partitions) { + inputChecker.markSchedulingResultPartitionFinished(srp); + } + + for (TestingSchedulingResultPartition srp : partitions) { + inputChecker.resetSchedulingResultPartition(srp); + } + + assertFalse(inputChecker.check(vertex)); + } + + @Test + public void testCheckAnyBlockingInput() { + final List partitions = addResultPartition() + .withDataSetCnt(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + inputChecker.markSchedulingResultPartitionFinished(partitions.get(0)); + + assertTrue(inputChecker.check(vertex)); + } + + @Test + public void testCheckAllBlockingInput() { + final List partitions = addResultPartition() + .withDataSetCnt(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withInputDependencyConstraint(ALL) + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + for (TestingSchedulingResultPartition srp : partitions) { + inputChecker.markSchedulingResultPartitionFinished(srp); + } + + assertTrue(inputChecker.check(vertex)); + } + + @Test + public void testCheckAllPartialDatasetBlockingInput() { + final List partitions = addResultPartition() + .withDataSetCnt(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withInputDependencyConstraint(ALL) + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + inputChecker.markSchedulingResultPartitionFinished(partitions.get(0)); + assertFalse(inputChecker.check(vertex)); + } + + @Test + public void testCheckAllPartialPartitionBlockingInput() { + final List partitions = addResultPartition() + .withDataSetCnt(2) + .withPartitionCntPerDataSet(2) + .finish(); + final TestingSchedulingExecutionVertex vertex = addSchedulingExecutionVertex() + .withInputDependencyConstraint(ALL) + .withConsumedPartitions(partitions) + .finish(); + + final InputDependencyConstraintChecker inputChecker = createInputDependencyConstraintChecker(partitions); + + for (int idx = 0; idx < 3; idx++) { + inputChecker.markSchedulingResultPartitionFinished(partitions.get(idx)); + } + + assertFalse(inputChecker.check(vertex)); + } + + private static TestingSchedulingExecutionVertexBuilder addSchedulingExecutionVertex() { + return new TestingSchedulingExecutionVertexBuilder(); + } + + private static class TestingSchedulingExecutionVertexBuilder { + private static final JobVertexID jobVertexId = new JobVertexID(); + private InputDependencyConstraint inputDependencyConstraint = ANY; + private List partitions = Collections.emptyList(); + + 
TestingSchedulingExecutionVertexBuilder withInputDependencyConstraint(InputDependencyConstraint constraint) { + this.inputDependencyConstraint = constraint; + return this; + } + + TestingSchedulingExecutionVertexBuilder withConsumedPartitions(List partitions) { + this.partitions = partitions; + return this; + } + + TestingSchedulingExecutionVertex finish() { + return new TestingSchedulingExecutionVertex(jobVertexId, 0, inputDependencyConstraint, partitions); + } + } + + private static TestingSchedulingResultPartitionBuilder addResultPartition() { + return new TestingSchedulingResultPartitionBuilder(); + } + + private static InputDependencyConstraintChecker createInputDependencyConstraintChecker( + List partitions) { + + InputDependencyConstraintChecker inputChecker = new InputDependencyConstraintChecker(); + for (SchedulingResultPartition partition : partitions) { + inputChecker.addSchedulingResultPartition(partition); + } + return inputChecker; + } + + private static class TestingSchedulingResultPartitionBuilder { + private int dataSetCnt = 1; + private int partitionCntPerDataSet = 1; + private ResultPartitionType partitionType = BLOCKING; + private SchedulingResultPartition.ResultPartitionState partitionState = DONE; + + TestingSchedulingResultPartitionBuilder withDataSetCnt(int dataSetCnt) { + this.dataSetCnt = dataSetCnt; + return this; + } + + TestingSchedulingResultPartitionBuilder withPartitionCntPerDataSet(int partitionCnt) { + this.partitionCntPerDataSet = partitionCnt; + return this; + } + + TestingSchedulingResultPartitionBuilder withPartitionType(ResultPartitionType type) { + this.partitionType = type; + return this; + } + + TestingSchedulingResultPartitionBuilder withPartitionState(SchedulingResultPartition.ResultPartitionState state) { + this.partitionState = state; + return this; + } + + List finish() { + List partitions = new ArrayList<>(dataSetCnt * partitionCntPerDataSet); + for (int dataSetIdx = 0; dataSetIdx < dataSetCnt; dataSetIdx++) { + IntermediateDataSetID dataSetId = new IntermediateDataSetID(); + for (int partitionIdx = 0; partitionIdx < partitionCntPerDataSet; partitionIdx++) { + partitions.add(new TestingSchedulingResultPartition(dataSetId, partitionType, partitionState)); + } + } + + return partitions; + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategyTest.java new file mode 100644 index 00000000000000..05e3ab09940b1b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/LazyFromSourcesSchedulingStrategyTest.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.util.TestLogger; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.api.common.InputDependencyConstraint.ALL; +import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.PIPELINED; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.EMPTY; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.PRODUCING; +import static org.apache.flink.runtime.scheduler.strategy.StrategyTestUtil.getExecutionVertexIdsFromDeployOptions; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; + +/** + * Unit tests for {@link LazyFromSourcesSchedulingStrategy}. + */ +public class LazyFromSourcesSchedulingStrategyTest extends TestLogger { + + private TestingSchedulerOperations testingSchedulerOperation = new TestingSchedulerOperations(); + + /** + * Tests that when start scheduling lazy from sources scheduling strategy will start input vertices in scheduling topology. + */ + @Test + public void testStartScheduling() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices().finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices().finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).finish(); + + startScheduling(testingSchedulingTopology); + + assertThat(testingSchedulerOperation, hasScheduledVertices(producers)); + } + + /** + * Tests that when restart tasks will only schedule input ready vertices in given ones. + */ + @Test + public void testRestartBlockingTasks() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices().finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices().finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).finish(); + + LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + Set verticesToRestart = producers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + verticesToRestart.addAll(consumers.stream().map( + TestingSchedulingExecutionVertex::getId).collect(Collectors.toSet())); + + schedulingStrategy.restartTasks(verticesToRestart); + assertThat(testingSchedulerOperation, hasScheduledVertices(producers)); + } + + /** + * Tests that when restart tasks will schedule input consumable vertices in given ones. 
+ */ + @Test + public void testRestartConsumableBlockingTasks() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices().finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices().finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).finish(); + + LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + Set verticesToRestart = consumers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + + for (TestingSchedulingExecutionVertex producer : producers) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + schedulingStrategy.restartTasks(verticesToRestart); + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when all the input partitions are ready will start available downstream {@link ResultPartitionType#BLOCKING} vertices. + * vertex#0 vertex#1 + * \ / + * \ / + * \ / + * (BLOCKING, ALL) + * vertex#2 + */ + @Test + public void testRestartBlockingALLExecutionStateChange() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers1 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List producers2 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).withInputDependencyConstraint(ALL).finish(); + testingSchedulingTopology.connectPointwise(producers1, consumers).finish(); + testingSchedulingTopology.connectPointwise(producers2, consumers).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + for (TestingSchedulingExecutionVertex producer : producers1) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + for (TestingSchedulingExecutionVertex producer : producers2) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + Set verticesToRestart = consumers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + + schedulingStrategy.restartTasks(verticesToRestart); + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when any input dataset finishes will start available downstream {@link ResultPartitionType#BLOCKING} vertices. 
+ * vertex#0 vertex#1 + * \ / + * \ / + * \ / + * (BLOCKING, ANY) + * vertex#2 + */ + @Test + public void testRestartBlockingANYExecutionStateChange() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers1 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List producers2 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + testingSchedulingTopology.connectPointwise(producers1, consumers).finish(); + testingSchedulingTopology.connectPointwise(producers2, consumers).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + for (TestingSchedulingExecutionVertex producer : producers1) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + Set verticesToRestart = consumers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + + schedulingStrategy.restartTasks(verticesToRestart); + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when restart {@link ResultPartitionType#PIPELINED} tasks with {@link SchedulingResultPartition.ResultPartitionState#PRODUCING} will be scheduled. + */ + @Test + public void testRestartProducingPipelinedTasks() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).withResultPartitionState(PRODUCING) + .withResultPartitionType(PIPELINED).finish(); + + LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + Set verticesToRestart = producers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + verticesToRestart.addAll(consumers.stream().map( + TestingSchedulingExecutionVertex::getId).collect(Collectors.toSet())); + + schedulingStrategy.restartTasks(verticesToRestart); + List toScheduleVertices = new ArrayList<>(producers.size() + consumers.size()); + toScheduleVertices.addAll(consumers); + toScheduleVertices.addAll(producers); + assertThat(testingSchedulerOperation, hasScheduledVertices(toScheduleVertices)); + } + + /** + * Tests that when restart {@link ResultPartitionType#PIPELINED} tasks with {@link SchedulingResultPartition.ResultPartitionState#EMPTY} will not be scheduled. 
+ */ + @Test + public void testRestartEmptyPipelinedTasks() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).withResultPartitionState(EMPTY) + .withResultPartitionType(PIPELINED).finish(); + + LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + Set verticesToRestart = producers.stream().map(TestingSchedulingExecutionVertex::getId) + .collect(Collectors.toSet()); + verticesToRestart.addAll(consumers.stream().map( + TestingSchedulingExecutionVertex::getId).collect(Collectors.toSet())); + + schedulingStrategy.restartTasks(verticesToRestart); + assertThat(testingSchedulerOperation, hasScheduledVertices(producers)); + } + + /** + * Tests that when partition consumable notified will start available {@link ResultPartitionType#PIPELINED} downstream vertices. + */ + @Test + public void testPipelinedPartitionConsumable() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + testingSchedulingTopology.connectAllToAll(producers, consumers).withResultPartitionType(PIPELINED).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + final TestingSchedulingExecutionVertex producer1 = producers.get(0); + final SchedulingResultPartition partition1 = producer1.getProducedResultPartitions().iterator().next(); + + schedulingStrategy.onExecutionStateChange(producer1.getId(), ExecutionState.RUNNING); + schedulingStrategy.onPartitionConsumable(producer1.getId(), new ResultPartitionID(partition1.getId(), new ExecutionAttemptID())); + + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when partition consumable notified will start available {@link ResultPartitionType#BLOCKING} downstream vertices. + */ + @Test + public void testBlockingPointwiseExecutionStateChange() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).withInputDependencyConstraint(ALL).finish(); + testingSchedulingTopology.connectPointwise(producers, consumers).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + for (TestingSchedulingExecutionVertex producer : producers) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when all the input partitions are ready will start available downstream {@link ResultPartitionType#BLOCKING} vertices. 
+ * vertex#0 vertex#1 + * \ / + * \ / + * \ / + * (BLOCKING, ALL) + * vertex#2 + */ + @Test + public void testBlockingALLExecutionStateChange() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers1 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List producers2 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).withInputDependencyConstraint(ALL).finish(); + testingSchedulingTopology.connectPointwise(producers1, consumers).finish(); + testingSchedulingTopology.connectPointwise(producers2, consumers).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + for (TestingSchedulingExecutionVertex producer : producers1) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + for (TestingSchedulingExecutionVertex producer : producers2) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + /** + * Tests that when any input dataset finishes will start available downstream {@link ResultPartitionType#BLOCKING} vertices. + * vertex#0 vertex#1 + * \ / + * \ / + * \ / + * (BLOCKING, ANY) + * vertex#2 + */ + @Test + public void testBlockingANYExecutionStateChange() { + final TestingSchedulingTopology testingSchedulingTopology = new TestingSchedulingTopology(); + + final List producers1 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List producers2 = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + final List consumers = testingSchedulingTopology.addExecutionVertices() + .withParallelism(2).finish(); + testingSchedulingTopology.connectPointwise(producers1, consumers).finish(); + testingSchedulingTopology.connectPointwise(producers2, consumers).finish(); + + final LazyFromSourcesSchedulingStrategy schedulingStrategy = startScheduling(testingSchedulingTopology); + + for (TestingSchedulingExecutionVertex producer : producers1) { + schedulingStrategy.onExecutionStateChange(producer.getId(), ExecutionState.FINISHED); + } + + assertThat(testingSchedulerOperation, hasScheduledVertices(consumers)); + } + + private static Matcher hasScheduledVertices(final List consumers) { + + final Matcher> vertexIdMatcher = containsInAnyOrder(consumers.stream() + .map(SchedulingExecutionVertex::getId) + .toArray(ExecutionVertexID[]::new)); + + return new TypeSafeDiagnosingMatcher() { + + @Override + protected boolean matchesSafely(final TestingSchedulerOperations item, final Description mismatchDescription) { + final boolean matches = vertexIdMatcher.matches(getExecutionVertexIdsFromDeployOptions(item.getLatestScheduledVertices())); + if (!matches) { + vertexIdMatcher.describeMismatch(item.getLatestScheduledVertices(), mismatchDescription); + } + return matches; + } + + @Override + public void describeTo(final Description description) { + description.appendText("to be scheduled vertex id is ").appendDescriptionOf(vertexIdMatcher); + } + }; + } + + private LazyFromSourcesSchedulingStrategy startScheduling(TestingSchedulingTopology testingSchedulingTopology) { + LazyFromSourcesSchedulingStrategy schedulingStrategy = new LazyFromSourcesSchedulingStrategy( + testingSchedulerOperation, + 
testingSchedulingTopology); + schedulingStrategy.startScheduling(); + return schedulingStrategy; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/StrategyTestUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/StrategyTestUtil.java new file mode 100644 index 00000000000000..119afe2cc3d2d5 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/StrategyTestUtil.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.runtime.scheduler.ExecutionVertexDeploymentOption; + +import java.util.Collection; +import java.util.stream.Collectors; + +/** + * Strategy test utilities. + */ +public class StrategyTestUtil { + + static Collection getExecutionVertexIdsFromDeployOptions( + Collection deploymentOptions) { + + return deploymentOptions.stream() + .map(ExecutionVertexDeploymentOption::getExecutionVertexId) + .collect(Collectors.toList()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulerOperations.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulerOperations.java index 3a67b3980b62f1..9edba3bfc2abed 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulerOperations.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulerOperations.java @@ -38,7 +38,11 @@ public void allocateSlotsAndDeploy(Collection e scheduledVertices.add(executionVertexDeploymentOptions); } - public List> getScheduledVertices() { + List> getScheduledVertices() { return Collections.unmodifiableList(scheduledVertices); } + + Collection getLatestScheduledVertices() { + return Collections.unmodifiableCollection(scheduledVertices.get(scheduledVertices.size() - 1)); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java index 8681fe10172bac..17f15c5a9d3326 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java @@ -18,12 +18,16 @@ package org.apache.flink.runtime.scheduler.strategy; +import org.apache.flink.api.common.InputDependencyConstraint; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.jobgraph.JobVertexID; +import java.util.ArrayList; import java.util.Collection; import 
java.util.Collections; +import static org.apache.flink.util.Preconditions.checkNotNull; + /** * A simple scheduling execution vertex for testing purposes. */ @@ -31,8 +35,32 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert private final ExecutionVertexID executionVertexId; + private final Collection consumedPartitions; + + private final Collection producedPartitions; + + private InputDependencyConstraint inputDependencyConstraint; + public TestingSchedulingExecutionVertex(JobVertexID jobVertexId, int subtaskIndex) { + this(jobVertexId, subtaskIndex, InputDependencyConstraint.ANY); + } + + public TestingSchedulingExecutionVertex(JobVertexID jobVertexId, int subtaskIndex, + InputDependencyConstraint constraint) { + + this(jobVertexId, subtaskIndex, constraint, new ArrayList<>()); + } + + public TestingSchedulingExecutionVertex( + JobVertexID jobVertexId, + int subtaskIndex, + InputDependencyConstraint constraint, + Collection consumedPartitions) { + this.executionVertexId = new ExecutionVertexID(jobVertexId, subtaskIndex); + this.inputDependencyConstraint = constraint; + this.consumedPartitions = checkNotNull(consumedPartitions); + this.producedPartitions = new ArrayList<>(); } @Override @@ -47,11 +75,24 @@ public ExecutionState getState() { @Override public Collection getConsumedResultPartitions() { - return Collections.emptyList(); + return Collections.unmodifiableCollection(consumedPartitions); } @Override public Collection getProducedResultPartitions() { - return Collections.emptyList(); + return Collections.unmodifiableCollection(producedPartitions); + } + + @Override + public InputDependencyConstraint getInputDependencyConstraint() { + return inputDependencyConstraint; + } + + void addConsumedPartition(TestingSchedulingResultPartition partition) { + consumedPartitions.add(partition); + } + + void addProducedPartition(SchedulingResultPartition partition) { + producedPartitions.add(partition); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java new file mode 100644 index 00000000000000..cb603195dea43b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.strategy; + +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A simple implementation of {@link SchedulingResultPartition} for testing. + */ +public class TestingSchedulingResultPartition implements SchedulingResultPartition { + private final IntermediateDataSetID intermediateDataSetID; + + private final IntermediateResultPartitionID intermediateResultPartitionID; + + private final ResultPartitionType partitionType; + + private SchedulingExecutionVertex producer; + + private Collection consumers; + + private ResultPartitionState state; + + TestingSchedulingResultPartition(IntermediateDataSetID dataSetID, ResultPartitionType type, ResultPartitionState state) { + this.intermediateDataSetID = dataSetID; + this.partitionType = type; + this.state = state; + this.intermediateResultPartitionID = new IntermediateResultPartitionID(); + this.consumers = new ArrayList<>(); + } + + @Override + public IntermediateResultPartitionID getId() { + return intermediateResultPartitionID; + } + + @Override + public IntermediateDataSetID getResultId() { + return intermediateDataSetID; + } + + @Override + public ResultPartitionType getPartitionType() { + return partitionType; + } + + @Override + public ResultPartitionState getState() { + return state; + } + + @Override + public SchedulingExecutionVertex getProducer() { + return producer; + } + + @Override + public Collection getConsumers() { + return Collections.unmodifiableCollection(consumers); + } + + void addConsumer(SchedulingExecutionVertex consumer) { + this.consumers.add(consumer); + } + + void setProducer(TestingSchedulingExecutionVertex producer) { + this.producer = checkNotNull(producer); + } + + /** + * Builder for {@link TestingSchedulingResultPartition}. 
+ */ + public static final class Builder { + private IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); + private ResultPartitionType resultPartitionType = ResultPartitionType.BLOCKING; + private ResultPartitionState resultPartitionState = ResultPartitionState.DONE; + + Builder withIntermediateDataSetID(IntermediateDataSetID intermediateDataSetId) { + this.intermediateDataSetId = intermediateDataSetId; + return this; + } + + Builder withResultPartitionState(ResultPartitionState state) { + this.resultPartitionState = state; + return this; + } + + Builder withResultPartitionType(ResultPartitionType type) { + this.resultPartitionType = type; + return this; + } + + TestingSchedulingResultPartition build() { + return new TestingSchedulingResultPartition(intermediateDataSetId, resultPartitionType, resultPartitionState); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java index e827f78654cdc5..2a84ba38aeb99b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java @@ -18,14 +18,24 @@ package org.apache.flink.runtime.scheduler.strategy; +import org.apache.flink.api.common.InputDependencyConstraint; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.util.Preconditions; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +import static org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition.ResultPartitionState.DONE; + /** * A simple scheduling topology for testing purposes. 
*/ @@ -48,10 +58,10 @@ public Optional getVertex(ExecutionVertexID execution @Override public Optional getResultPartition( IntermediateResultPartitionID intermediateResultPartitionId) { - return Optional.ofNullable(schedulingResultPartitions.get(intermediateResultPartitionId)); + return Optional.of(schedulingResultPartitions.get(intermediateResultPartitionId)); } - public void addSchedulingExecutionVertex(SchedulingExecutionVertex schedulingExecutionVertex) { + void addSchedulingExecutionVertex(SchedulingExecutionVertex schedulingExecutionVertex) { schedulingExecutionVertices.put(schedulingExecutionVertex.getId(), schedulingExecutionVertex); addSchedulingResultPartitions(schedulingExecutionVertex.getConsumedResultPartitions()); addSchedulingResultPartitions(schedulingExecutionVertex.getProducedResultPartitions()); @@ -62,4 +72,182 @@ private void addSchedulingResultPartitions(final Collection vertices) { + for (TestingSchedulingExecutionVertex vertex : vertices) { + addSchedulingExecutionVertex(vertex); + } + } + + SchedulingExecutionVerticesBuilder addExecutionVertices() { + return new SchedulingExecutionVerticesBuilder(); + } + + ProducerConsumerConnectionBuilder connectPointwise( + final List producers, + final List consumers) { + + return new ProducerConsumerPointwiseConnectionBuilder(producers, consumers); + } + + ProducerConsumerConnectionBuilder connectAllToAll( + final List producers, + final List consumers) { + + return new ProducerConsumerAllToAllConnectionBuilder(producers, consumers); + } + + /** + * Builder for {@link TestingSchedulingResultPartition}. + */ + public abstract class ProducerConsumerConnectionBuilder { + + protected final List producers; + + protected final List consumers; + + protected ResultPartitionType resultPartitionType = ResultPartitionType.BLOCKING; + + protected SchedulingResultPartition.ResultPartitionState resultPartitionState = DONE; + + protected ProducerConsumerConnectionBuilder( + final List producers, + final List consumers) { + this.producers = producers; + this.consumers = consumers; + } + + ProducerConsumerConnectionBuilder withResultPartitionType(final ResultPartitionType resultPartitionType) { + this.resultPartitionType = resultPartitionType; + return this; + } + + ProducerConsumerConnectionBuilder withResultPartitionState(final SchedulingResultPartition.ResultPartitionState state) { + this.resultPartitionState = state; + return this; + } + + public List finish() { + final List resultPartitions = connect(); + + TestingSchedulingTopology.this.addSchedulingExecutionVertices(producers); + TestingSchedulingTopology.this.addSchedulingExecutionVertices(consumers); + + return resultPartitions; + } + + TestingSchedulingResultPartition.Builder initTestingSchedulingResultPartitionBuilder() { + return new TestingSchedulingResultPartition.Builder() + .withResultPartitionType(resultPartitionType); + } + + protected abstract List connect(); + + } + + /** + * Builder for {@link TestingSchedulingResultPartition} of {@link DistributionPattern#POINTWISE}. 
+ */ + private class ProducerConsumerPointwiseConnectionBuilder extends ProducerConsumerConnectionBuilder { + + private ProducerConsumerPointwiseConnectionBuilder( + final List producers, + final List consumers) { + super(producers, consumers); + // currently we only support one to one + Preconditions.checkState(producers.size() == consumers.size()); + } + + @Override + protected List connect() { + final List resultPartitions = new ArrayList<>(); + final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); + + for (int idx = 0; idx < producers.size(); idx++) { + final TestingSchedulingExecutionVertex producer = producers.get(idx); + final TestingSchedulingExecutionVertex consumer = consumers.get(idx); + + final TestingSchedulingResultPartition resultPartition = initTestingSchedulingResultPartitionBuilder() + .withIntermediateDataSetID(intermediateDataSetId) + .withResultPartitionState(resultPartitionState) + .build(); + resultPartition.setProducer(producer); + producer.addProducedPartition(resultPartition); + consumer.addConsumedPartition(resultPartition); + resultPartition.addConsumer(consumer); + resultPartitions.add(resultPartition); + } + + return resultPartitions; + } + } + + /** + * Builder for {@link TestingSchedulingResultPartition} of {@link DistributionPattern#ALL_TO_ALL}. + */ + private class ProducerConsumerAllToAllConnectionBuilder extends ProducerConsumerConnectionBuilder { + + private ProducerConsumerAllToAllConnectionBuilder( + final List producers, + final List consumers) { + super(producers, consumers); + } + + @Override + protected List connect() { + final List resultPartitions = new ArrayList<>(); + final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); + + for (TestingSchedulingExecutionVertex producer : producers) { + + final TestingSchedulingResultPartition resultPartition = initTestingSchedulingResultPartitionBuilder() + .withIntermediateDataSetID(intermediateDataSetId) + .withResultPartitionState(resultPartitionState) + .build(); + resultPartition.setProducer(producer); + producer.addProducedPartition(resultPartition); + + for (TestingSchedulingExecutionVertex consumer : consumers) { + consumer.addConsumedPartition(resultPartition); + resultPartition.addConsumer(consumer); + resultPartitions.add(resultPartition); + } + } + + return resultPartitions; + } + } + + /** + * Builder for {@link TestingSchedulingExecutionVertex}. 
+ */ + public class SchedulingExecutionVerticesBuilder { + + private final JobVertexID jobVertexId = new JobVertexID(); + + private int parallelism = 1; + + private InputDependencyConstraint inputDependencyConstraint = InputDependencyConstraint.ANY; + + SchedulingExecutionVerticesBuilder withParallelism(final int parallelism) { + this.parallelism = parallelism; + return this; + } + + SchedulingExecutionVerticesBuilder withInputDependencyConstraint(final InputDependencyConstraint inputDependencyConstraint) { + this.inputDependencyConstraint = inputDependencyConstraint; + return this; + } + + public List finish() { + final List vertices = new ArrayList<>(); + for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) { + vertices.add(new TestingSchedulingExecutionVertex(jobVertexId, subtaskIndex, inputDependencyConstraint)); + } + + TestingSchedulingTopology.this.addSchedulingExecutionVertices(vertices); + + return vertices; + } + } } From de31f49b4b835b71cdf99f93c89e33a84113b272 Mon Sep 17 00:00:00 2001 From: sunhaibotb Date: Tue, 21 May 2019 11:26:20 +0800 Subject: [PATCH 57/92] [FLINK-12547][blob] Add connection and socket timeouts for the blob client This closes #8484. --- .../generated/blob_server_configuration.html | 10 +++ .../configuration/BlobServerOptions.java | 16 +++++ .../apache/flink/runtime/blob/BlobClient.java | 8 +-- .../flink/runtime/blob/BlobClientSslTest.java | 6 +- .../flink/runtime/blob/BlobClientTest.java | 63 ++++++++++++++++++- 5 files changed, 93 insertions(+), 10 deletions(-) diff --git a/docs/_includes/generated/blob_server_configuration.html b/docs/_includes/generated/blob_server_configuration.html index 4cb1744b82d0ee..36b69c6f223ed4 100644 --- a/docs/_includes/generated/blob_server_configuration.html +++ b/docs/_includes/generated/blob_server_configuration.html @@ -7,6 +7,16 @@

+ + + + + + + + + + diff --git a/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java index 20a068a60ee0e5..42d2cd86f08320 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/BlobServerOptions.java @@ -102,4 +102,20 @@ public class BlobServerOptions { public static final ConfigOption OFFLOAD_MINSIZE = key("blob.offload.minsize") .defaultValue(1_024 * 1_024) // 1MiB by default .withDescription("The minimum size for messages to be offloaded to the BlobServer."); + + /** + * The socket timeout in milliseconds for the blob client. + */ + public static final ConfigOption SO_TIMEOUT = + key("blob.client.socket.timeout") + .defaultValue(300_000) + .withDescription("The socket timeout in milliseconds for the blob client."); + + /** + * The connection timeout in milliseconds for the blob client. + */ + public static final ConfigOption CONNECT_TIMEOUT = + key("blob.client.connect.timeout") + .defaultValue(0) + .withDescription("The connection timeout in milliseconds for the blob client."); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java index 01e307eb5b1c56..b76ba5d7bc67c1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobClient.java @@ -86,14 +86,14 @@ public BlobClient(InetSocketAddress serverAddress, Configuration clientConfig) t if (SSLUtils.isInternalSSLEnabled(clientConfig) && clientConfig.getBoolean(BlobServerOptions.SSL_ENABLED)) { LOG.info("Using ssl connection to the blob server"); - socket = SSLUtils.createSSLClientSocketFactory(clientConfig).createSocket( - serverAddress.getAddress(), - serverAddress.getPort()); + socket = SSLUtils.createSSLClientSocketFactory(clientConfig).createSocket(); } else { socket = new Socket(); - socket.connect(serverAddress); } + + socket.connect(serverAddress, clientConfig.getInteger(BlobServerOptions.CONNECT_TIMEOUT)); + socket.setSoTimeout(clientConfig.getInteger(BlobServerOptions.SO_TIMEOUT)); } catch (Exception e) { BlobUtils.closeSilently(socket, LOG); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java index 531f2148f23737..a46a6df8994a5a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java @@ -36,7 +36,7 @@ public class BlobClientSslTest extends BlobClientTest { /** The instance of the SSL BLOB server used during the tests. */ - private static BlobServer blobSslServer; + private static TestBlobServer blobSslServer; /** Instance of a non-SSL BLOB server with SSL-enabled security options. 
*/ private static BlobServer blobNonSslServer; @@ -58,7 +58,7 @@ public static void startSSLServer() throws IOException { Configuration config = SSLUtilsTest.createInternalSslConfigWithKeyAndTrustStores(); config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporarySslFolder.newFolder().getAbsolutePath()); - blobSslServer = new BlobServer(config, new VoidBlobStore()); + blobSslServer = new TestBlobServer(config, new VoidBlobStore()); blobSslServer.start(); sslClientConfig = config; @@ -93,7 +93,7 @@ protected Configuration getBlobClientConfig() { return sslClientConfig; } - protected BlobServer getBlobServer() { + protected TestBlobServer getBlobServer() { return blobSslServer; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java index c083d08ca195ae..c28b9a5b37215b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientTest.java @@ -22,6 +22,7 @@ import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.TestLogger; import org.junit.AfterClass; @@ -48,6 +49,8 @@ import static org.apache.flink.runtime.blob.BlobKey.BlobType.PERMANENT_BLOB; import static org.apache.flink.runtime.blob.BlobKey.BlobType.TRANSIENT_BLOB; import static org.apache.flink.runtime.blob.BlobKeyTest.verifyKeyDifferentHashEquals; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -62,7 +65,7 @@ public class BlobClientTest extends TestLogger { private static final int TEST_BUFFER_SIZE = 17 * 1000; /** The instance of the (non-ssl) BLOB server used during the tests. */ - static BlobServer blobServer; + static TestBlobServer blobServer; /** The blob service (non-ssl) client configuration. */ static Configuration clientConfig; @@ -79,7 +82,7 @@ public static void startServer() throws IOException { config.setString(BlobServerOptions.STORAGE_DIRECTORY, temporaryFolder.newFolder().getAbsolutePath()); - blobServer = new BlobServer(config, new VoidBlobStore()); + blobServer = new TestBlobServer(config, new VoidBlobStore()); blobServer.start(); clientConfig = new Configuration(); @@ -318,7 +321,7 @@ protected Configuration getBlobClientConfig() { return clientConfig; } - protected BlobServer getBlobServer() { + protected TestBlobServer getBlobServer() { return blobServer; } @@ -487,4 +490,58 @@ private static void uploadJarFile( validateGetAndClose(blobClient.getInternal(jobId, blobKeys.get(0)), testFile); } } + + + /** + * Tests the socket operation timeout. 
+ */ + @Test + public void testSocketTimeout() { + Configuration clientConfig = getBlobClientConfig(); + int oldSoTimeout = clientConfig.getInteger(BlobServerOptions.SO_TIMEOUT); + + clientConfig.setInteger(BlobServerOptions.SO_TIMEOUT, 50); + getBlobServer().setBlockingMillis(10_000); + + try { + InetSocketAddress serverAddress = new InetSocketAddress("localhost", getBlobServer().getPort()); + + try (BlobClient client = new BlobClient(serverAddress, clientConfig)) { + client.getInternal(new JobID(), BlobKey.createKey(TRANSIENT_BLOB)); + + fail("Should throw an exception."); + } catch (Throwable t) { + assertThat(ExceptionUtils.findThrowable(t, java.net.SocketTimeoutException.class).isPresent(), is(true)); + } + } finally { + clientConfig.setInteger(BlobServerOptions.SO_TIMEOUT, oldSoTimeout); + getBlobServer().setBlockingMillis(0); + } + } + + static class TestBlobServer extends BlobServer { + + private volatile long blockingMillis = 0; + + TestBlobServer(Configuration config, BlobStore blobStore) throws IOException { + super(config, blobStore); + } + + @Override + void getFileInternal(@Nullable JobID jobId, BlobKey blobKey, File localFile) throws IOException { + if (blockingMillis > 0) { + try { + Thread.sleep(blockingMillis); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + super.getFileInternal(jobId, blobKey, localFile); + } + + void setBlockingMillis(long millis) { + this.blockingMillis = millis; + } + } } From 35f3996bac3862ae22c6606489c0a0fd489a99b1 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Mon, 3 Jun 2019 10:12:20 -0700 Subject: [PATCH 58/92] [hotfix][hive] include jdo-api for hive-metastore to enable tests for profile hive-1.2.1 --- flink-connectors/flink-connector-hive/pom.xml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index cd245d961abaf9..ae3284c2ea3b27 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -117,10 +117,6 @@ under the License. com.zaxxer HikariCP - - javax.jdo - jdo-api - co.cask.tephra tephra-api From 1e0a16662aceca4f151b3326f0676a3583b85aa0 Mon Sep 17 00:00:00 2001 From: Rui Li Date: Fri, 24 May 2019 18:51:59 +0800 Subject: [PATCH 59/92] [FLINK-12568][hive] Implement OutputFormat to write Hive tables This closes #8536. 
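A hedged usage sketch of the new output format introduced by this change, to make the intent concrete: it assumes a non-partitioned Hive table, and that the JobConf, the table Properties, and a StorageDescriptor carrying the table's storage format but pointing at a staging directory have already been obtained via HiveCatalog and the Hive metastore. The class name HiveWriteSketch and the mydb/mytable identifiers are illustrative only, not part of this patch.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.batch.connectors.hive.HiveTableOutputFormat;
import org.apache.flink.batch.connectors.hive.HiveTablePartition;
import org.apache.flink.types.Row;

import org.apache.hadoop.hive.metastore.api.StorageDescriptor;
import org.apache.hadoop.mapred.JobConf;

import java.util.Arrays;
import java.util.Collections;
import java.util.Properties;

class HiveWriteSketch {

	// Writes two rows into mydb.mytable. The StorageDescriptor should carry the table's
	// serde/output-format information but point at a staging location, because
	// finalizeGlobal() moves the produced files into the table directory and then
	// deletes the staging directory.
	static void writeToHive(JobConf jobConf, StorageDescriptor stagingSd, Properties tblProps) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		RowTypeInfo rowTypeInfo = new RowTypeInfo(
			new TypeInformation<?>[]{Types.INT, Types.STRING},
			new String[]{"id", "name"});

		// Non-partitioned table: empty partition column list and no partition spec.
		HiveTablePartition tablePartition = new HiveTablePartition(stagingSd, null);
		HiveTableOutputFormat outputFormat = new HiveTableOutputFormat(
			jobConf, "mydb", "mytable", Collections.emptyList(),
			rowTypeInfo, tablePartition, tblProps, false /* do not overwrite existing data */);

		Row row1 = new Row(2);
		row1.setField(0, 1);
		row1.setField(1, "a");
		Row row2 = new Row(2);
		row2.setField(0, 2);
		row2.setField(1, "b");

		env.fromCollection(Arrays.asList(row1, row2), rowTypeInfo)
			.output(outputFormat);
		env.execute();
	}
}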
--- flink-connectors/flink-connector-hive/pom.xml | 7 + .../hive/HiveTableOutputFormat.java | 408 ++++++++++++++++++ .../connectors/hive/HiveTablePartition.java | 49 +++ .../batch/connectors/hive/HiveTableUtil.java | 59 +++ .../flink/table/catalog/hive/HiveCatalog.java | 2 +- .../table/catalog/hive/util/HiveTypeUtil.java | 31 +- .../hive/HiveTableOutputFormatTest.java | 137 ++++++ .../table/catalog/hive/HiveTestUtils.java | 6 +- 8 files changed, 686 insertions(+), 13 deletions(-) create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormat.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableUtil.java create mode 100644 flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormatTest.java diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index ae3284c2ea3b27..545c964cdeb69b 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -58,6 +58,13 @@ under the License. provided + + org.apache.flink + flink-hadoop-compatibility_${scala.binary.version} + ${project.version} + provided + + diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormat.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormat.java new file mode 100644 index 00000000000000..324c6c6533bf8c --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormat.java @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.common.io.FinalizeOnMaster; +import org.apache.flink.api.common.io.InitializeOnMaster; +import org.apache.flink.api.java.hadoop.common.HadoopInputFormatCommonBase; +import org.apache.flink.api.java.hadoop.common.HadoopOutputFormatCommonBase; +import org.apache.flink.api.java.hadoop.mapreduce.utils.HadoopUtils; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.catalog.exceptions.CatalogException; +import org.apache.flink.table.catalog.hive.client.HiveMetastoreClientFactory; +import org.apache.flink.table.catalog.hive.client.HiveMetastoreClientWrapper; +import org.apache.flink.types.Row; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.StringUtils; + +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.common.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.hive.metastore.api.Table; +import org.apache.hadoop.hive.ql.exec.FileSinkOperator; +import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.mapred.FileOutputFormat; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.JobContext; +import org.apache.hadoop.mapred.JobContextImpl; +import org.apache.hadoop.mapred.JobID; +import org.apache.hadoop.mapred.OutputCommitter; +import org.apache.hadoop.mapred.OutputFormat; +import org.apache.hadoop.mapred.Reporter; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapred.TaskAttemptContext; +import org.apache.hadoop.mapred.TaskAttemptContextImpl; +import org.apache.hadoop.mapred.TaskAttemptID; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.util.ReflectionUtils; +import org.apache.thrift.TException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.apache.hadoop.mapreduce.lib.output.FileOutputFormat.OUTDIR; + +/** + * HiveTableOutputFormat used to write data to hive table, including non-partition and partitioned table. 
+ */ +public class HiveTableOutputFormat extends HadoopOutputFormatCommonBase implements InitializeOnMaster, + FinalizeOnMaster { + + private static final Logger LOG = LoggerFactory.getLogger(HiveTableOutputFormat.class); + + private static final long serialVersionUID = 5167529504848109023L; + + private transient JobConf jobConf; + private transient String dbName; + private transient String tableName; + private transient List partitionCols; + private transient RowTypeInfo rowTypeInfo; + private transient HiveTablePartition hiveTablePartition; + private transient Properties tblProperties; + private transient boolean overwrite; + private transient boolean isPartitioned; + private transient boolean isDynamicPartition; + // number of non-partitioning columns + private transient int numNonPartitionCols; + + private transient AbstractSerDe serializer; + //StructObjectInspector represents the hive row structure. + private transient StructObjectInspector rowObjectInspector; + private transient Class outputClass; + private transient TaskAttemptContext context; + + // Maps a partition dir name to the corresponding writer. Used for dynamic partitioning. + private transient Map partitionToWriter; + // Writer for non-partitioned and static partitioned table + private transient HivePartitionWriter staticWriter; + + public HiveTableOutputFormat(JobConf jobConf, String dbName, String tableName, List partitionCols, + RowTypeInfo rowTypeInfo, HiveTablePartition hiveTablePartition, + Properties tblProperties, boolean overwrite) { + super(jobConf.getCredentials()); + + Preconditions.checkArgument(!StringUtils.isNullOrWhitespaceOnly(dbName), "DB name is empty"); + Preconditions.checkArgument(!StringUtils.isNullOrWhitespaceOnly(tableName), "Table name is empty"); + Preconditions.checkNotNull(rowTypeInfo, "RowTypeInfo cannot be null"); + Preconditions.checkNotNull(hiveTablePartition, "HiveTablePartition cannot be null"); + Preconditions.checkNotNull(tblProperties, "Table properties cannot be null"); + + HadoopUtils.mergeHadoopConf(jobConf); + this.jobConf = jobConf; + this.dbName = dbName; + this.tableName = tableName; + this.partitionCols = partitionCols; + this.rowTypeInfo = rowTypeInfo; + this.hiveTablePartition = hiveTablePartition; + this.tblProperties = tblProperties; + this.overwrite = overwrite; + isPartitioned = partitionCols != null && !partitionCols.isEmpty(); + isDynamicPartition = isPartitioned && partitionCols.size() > hiveTablePartition.getPartitionSpec().size(); + } + + // Custom serialization methods + + private void writeObject(ObjectOutputStream out) throws IOException { + super.write(out); + jobConf.write(out); + out.writeObject(isPartitioned); + out.writeObject(isDynamicPartition); + out.writeObject(overwrite); + out.writeObject(rowTypeInfo); + out.writeObject(hiveTablePartition); + out.writeObject(partitionCols); + out.writeObject(dbName); + out.writeObject(tableName); + out.writeObject(tblProperties); + } + + @SuppressWarnings("unchecked") + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + super.read(in); + if (jobConf == null) { + jobConf = new JobConf(); + } + jobConf.readFields(in); + jobConf.getCredentials().addAll(this.credentials); + Credentials currentUserCreds = HadoopInputFormatCommonBase.getCredentialsFromUGI(UserGroupInformation.getCurrentUser()); + if (currentUserCreds != null) { + jobConf.getCredentials().addAll(currentUserCreds); + } + isPartitioned = (boolean) in.readObject(); + isDynamicPartition = (boolean) in.readObject(); + 
overwrite = (boolean) in.readObject(); + rowTypeInfo = (RowTypeInfo) in.readObject(); + hiveTablePartition = (HiveTablePartition) in.readObject(); + partitionCols = (List) in.readObject(); + dbName = (String) in.readObject(); + tableName = (String) in.readObject(); + partitionToWriter = new HashMap<>(); + tblProperties = (Properties) in.readObject(); + } + + @Override + public void finalizeGlobal(int parallelism) throws IOException { + StorageDescriptor jobSD = hiveTablePartition.getStorageDescriptor(); + Path stagingDir = new Path(jobSD.getLocation()); + FileSystem fs = stagingDir.getFileSystem(jobConf); + try (HiveMetastoreClientWrapper client = HiveMetastoreClientFactory.create(new HiveConf(jobConf, HiveConf.class))) { + Table table = client.getTable(dbName, tableName); + if (!isDynamicPartition) { + commitJob(stagingDir.toString()); + } + if (isPartitioned) { + // TODO: to be implemented + } else { + moveFiles(stagingDir, new Path(table.getSd().getLocation())); + } + } catch (TException e) { + throw new CatalogException("Failed to query Hive metaStore", e); + } finally { + fs.delete(stagingDir, true); + } + } + + @Override + public void initializeGlobal(int parallelism) throws IOException { + } + + @Override + public void configure(Configuration parameters) { + // since our writers are transient, we don't need to do anything here + } + + @Override + public void open(int taskNumber, int numTasks) throws IOException { + try { + StorageDescriptor sd = hiveTablePartition.getStorageDescriptor(); + serializer = (AbstractSerDe) Class.forName(sd.getSerdeInfo().getSerializationLib()).newInstance(); + ReflectionUtils.setConf(serializer, jobConf); + // TODO: support partition properties, for now assume they're same as table properties + SerDeUtils.initializeSerDe(serializer, jobConf, tblProperties, null); + outputClass = serializer.getSerializedClass(); + } catch (IllegalAccessException | SerDeException | InstantiationException | ClassNotFoundException e) { + throw new FlinkRuntimeException("Error initializing Hive serializer", e); + } + + TaskAttemptID taskAttemptID = TaskAttemptID.forName("attempt__0000_r_" + + String.format("%" + (6 - Integer.toString(taskNumber).length()) + "s", " ").replace(" ", "0") + + taskNumber + "_0"); + + this.jobConf.set("mapred.task.id", taskAttemptID.toString()); + this.jobConf.setInt("mapred.task.partition", taskNumber); + // for hadoop 2.2 + this.jobConf.set("mapreduce.task.attempt.id", taskAttemptID.toString()); + this.jobConf.setInt("mapreduce.task.partition", taskNumber); + + this.context = new TaskAttemptContextImpl(this.jobConf, taskAttemptID); + + if (!isDynamicPartition) { + staticWriter = writerForLocation(hiveTablePartition.getStorageDescriptor().getLocation()); + } + + List objectInspectors = new ArrayList<>(); + for (int i = 0; i < rowTypeInfo.getArity() - partitionCols.size(); i++) { + objectInspectors.add(HiveTableUtil.getObjectInspector(rowTypeInfo.getTypeAt(i))); + } + + if (!isPartitioned) { + rowObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList(rowTypeInfo.getFieldNames()), + objectInspectors); + numNonPartitionCols = rowTypeInfo.getArity(); + } else { + rowObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList(rowTypeInfo.getFieldNames()).subList(0, rowTypeInfo.getArity() - partitionCols.size()), + objectInspectors); + numNonPartitionCols = rowTypeInfo.getArity() - partitionCols.size(); + } + } + + @Override + public void writeRecord(Row record) throws IOException { + 
try { + HivePartitionWriter partitionWriter = staticWriter; + if (isDynamicPartition) { + // TODO: to be implemented + } + partitionWriter.recordWriter.write(serializer.serialize(getConvertedRow(record), rowObjectInspector)); + } catch (IOException | SerDeException e) { + throw new IOException("Could not write Record.", e); + } + } + + // moves all files under srcDir into destDir + private void moveFiles(Path srcDir, Path destDir) throws IOException { + if (!srcDir.equals(destDir)) { + // TODO: src and dest may be on different FS + FileSystem fs = destDir.getFileSystem(jobConf); + Preconditions.checkState(fs.exists(destDir) || fs.mkdirs(destDir), "Failed to create dest path " + destDir); + if (overwrite) { + // delete existing files for overwrite + // TODO: support setting auto-purge? + final boolean purge = true; + // Note we assume the srcDir is a hidden dir, otherwise it will be deleted if it's a sub-dir of destDir + FileStatus[] existingFiles = fs.listStatus(destDir, FileUtils.HIDDEN_FILES_PATH_FILTER); + if (existingFiles != null) { + for (FileStatus existingFile : existingFiles) { + Preconditions.checkState(FileUtils.moveToTrash(fs, existingFile.getPath(), jobConf, purge), + "Failed to overwrite existing file " + existingFile); + } + } + } + FileStatus[] srcFiles = fs.listStatus(srcDir, FileUtils.HIDDEN_FILES_PATH_FILTER); + for (FileStatus srcFile : srcFiles) { + Path srcPath = srcFile.getPath(); + Path destPath = new Path(destDir, srcPath.getName()); + int count = 1; + while (!fs.rename(srcPath, destPath)) { + String name = srcPath.getName() + "_copy_" + count; + destPath = new Path(destDir, name); + count++; + } + } + } + } + + private void commitJob(String location) throws IOException { + jobConf.set(OUTDIR, location); + JobContext jobContext = new JobContextImpl(this.jobConf, new JobID()); + OutputCommitter outputCommitter = this.jobConf.getOutputCommitter(); + // finalize HDFS output format + outputCommitter.commitJob(jobContext); + } + + // converts a Row to a list so that Hive can serialize it + private Object getConvertedRow(Row record) { + List res = new ArrayList<>(numNonPartitionCols); + for (int i = 0; i < numNonPartitionCols; i++) { + res.add(record.getField(i)); + } + return res; + } + + @Override + public void close() throws IOException { + for (HivePartitionWriter partitionWriter : getPartitionWriters()) { + // TODO: need a way to decide whether to abort + partitionWriter.recordWriter.close(false); + if (partitionWriter.outputCommitter.needsTaskCommit(context)) { + partitionWriter.outputCommitter.commitTask(context); + } + } + } + + // get all partition writers we've created + private List getPartitionWriters() { + if (isDynamicPartition) { + return new ArrayList<>(partitionToWriter.values()); + } else { + return Collections.singletonList(staticWriter); + } + } + + private HivePartitionWriter writerForLocation(String location) throws IOException { + JobConf clonedConf = new JobConf(jobConf); + clonedConf.set(OUTDIR, location); + OutputFormat outputFormat; + try { + StorageDescriptor sd = hiveTablePartition.getStorageDescriptor(); + Class outputFormatClz = Class.forName(sd.getOutputFormat(), true, + Thread.currentThread().getContextClassLoader()); + outputFormatClz = HiveFileFormatUtils.getOutputFormatSubstitute(outputFormatClz); + outputFormat = (OutputFormat) outputFormatClz.newInstance(); + } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { + throw new FlinkRuntimeException("Unable to instantiate the hadoop output format", e); + 
} + ReflectionUtils.setConf(outputFormat, clonedConf); + OutputCommitter outputCommitter = clonedConf.getOutputCommitter(); + JobContext jobContext = new JobContextImpl(clonedConf, new JobID()); + outputCommitter.setupJob(jobContext); + final boolean isCompressed = clonedConf.getBoolean(HiveConf.ConfVars.COMPRESSRESULT.varname, false); + if (isCompressed) { + String codecStr = clonedConf.get(HiveConf.ConfVars.COMPRESSINTERMEDIATECODEC.varname); + if (!StringUtils.isNullOrWhitespaceOnly(codecStr)) { + try { + Class codec = + (Class) Class.forName(codecStr, true, + Thread.currentThread().getContextClassLoader()); + FileOutputFormat.setOutputCompressorClass(clonedConf, codec); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + String typeStr = clonedConf.get(HiveConf.ConfVars.COMPRESSINTERMEDIATETYPE.varname); + if (!StringUtils.isNullOrWhitespaceOnly(typeStr)) { + SequenceFile.CompressionType style = SequenceFile.CompressionType.valueOf(typeStr); + SequenceFileOutputFormat.setOutputCompressionType(clonedConf, style); + } + } + String taskPartition = String.valueOf(clonedConf.getInt("mapreduce.task.partition", -1)); + Path taskPath = FileOutputFormat.getTaskOutputPath(clonedConf, taskPartition); + FileSinkOperator.RecordWriter recordWriter; + try { + recordWriter = HiveFileFormatUtils.getRecordWriter(clonedConf, outputFormat, + outputClass, isCompressed, tblProperties, taskPath, Reporter.NULL); + } catch (HiveException e) { + throw new IOException(e); + } + return new HivePartitionWriter(clonedConf, outputFormat, recordWriter, outputCommitter); + } + + private static class HivePartitionWriter { + private final JobConf jobConf; + private final OutputFormat outputFormat; + private final FileSinkOperator.RecordWriter recordWriter; + private final OutputCommitter outputCommitter; + + HivePartitionWriter(JobConf jobConf, OutputFormat outputFormat, FileSinkOperator.RecordWriter recordWriter, + OutputCommitter outputCommitter) { + this.jobConf = jobConf; + this.outputFormat = outputFormat; + this.recordWriter = recordWriter; + this.outputCommitter = outputCommitter; + } + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java new file mode 100644 index 00000000000000..21aeb16ca92c72 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; + +import java.io.Serializable; +import java.util.Map; + +/** + * A class that describes a partition of a Hive table. And it represents the whole table if table is not partitioned. + * Please note that the class is serializable because all its member variables are serializable. + */ +public class HiveTablePartition implements Serializable { + + private final StorageDescriptor storageDescriptor; + + // Partition spec for the partition. Should be null if the table is not partitioned. + private final Map partitionSpec; + + public HiveTablePartition(StorageDescriptor storageDescriptor, Map partitionSpec) { + this.storageDescriptor = storageDescriptor; + this.partitionSpec = partitionSpec; + } + + public StorageDescriptor getStorageDescriptor() { + return storageDescriptor; + } + + public Map getPartitionSpec() { + return partitionSpec; + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableUtil.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableUtil.java new file mode 100644 index 00000000000000..10ac1d4d170451 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableUtil.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.catalog.hive.util.HiveTypeUtil; + +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +import java.io.IOException; + +/** + * Util class for accessing Hive tables. + */ +public class HiveTableUtil { + + private HiveTableUtil() { + } + + /** + * Get Hive {@link ObjectInspector} for a Flink {@link TypeInformation}. + */ + public static ObjectInspector getObjectInspector(TypeInformation flinkType) throws IOException { + return getObjectInspector(HiveTypeUtil.toHiveTypeInfo(flinkType)); + } + + // TODO: reuse Hive's TypeInfoUtils? 
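+ // A minimal usage sketch (illustrative only, not part of any production path): for a primitive
+ // Flink type this util resolves a Java ObjectInspector that Hive's serde can work with, e.g.
+ //   ObjectInspector oi = HiveTableUtil.getObjectInspector(BasicTypeInfo.STRING_TYPE_INFO);
+ // Non-primitive (complex) types are not supported yet and fail with the IOException below.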
+ private static ObjectInspector getObjectInspector(TypeInfo type) throws IOException { + switch (type.getCategory()) { + + case PRIMITIVE: + PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) type; + return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(primitiveType); + + // TODO: support complex types + default: + throw new IOException("Unsupported Hive type category " + type.getCategory()); + } + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java index b35ea648eaf2d6..0489757eec77a7 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveCatalog.java @@ -454,7 +454,7 @@ public boolean tableExists(ObjectPath tablePath) throws CatalogException { } @VisibleForTesting - Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { + public Table getHiveTable(ObjectPath tablePath) throws TableNotExistException { try { return client.getTable(tablePath.getDatabaseName(), tablePath.getObjectName()); } catch (NoSuchObjectException e) { diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTypeUtil.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTypeUtil.java index 5a944156244b31..dc61a202c2e95d 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTypeUtil.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTypeUtil.java @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; /** * Utils to convert data types between Flink and Hive. 
@@ -49,27 +50,27 @@ private HiveTypeUtil() { * @return the corresponding Hive data type */ public static String toHiveType(TypeInformation type) { - if (type == BasicTypeInfo.BOOLEAN_TYPE_INFO) { + if (type.equals(BasicTypeInfo.BOOLEAN_TYPE_INFO)) { return serdeConstants.BOOLEAN_TYPE_NAME; - } else if (type == BasicTypeInfo.BYTE_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.BYTE_TYPE_INFO)) { return serdeConstants.TINYINT_TYPE_NAME; - } else if (type == BasicTypeInfo.SHORT_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.SHORT_TYPE_INFO)) { return serdeConstants.SMALLINT_TYPE_NAME; - } else if (type == BasicTypeInfo.INT_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.INT_TYPE_INFO)) { return serdeConstants.INT_TYPE_NAME; - } else if (type == BasicTypeInfo.LONG_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.LONG_TYPE_INFO)) { return serdeConstants.BIGINT_TYPE_NAME; - } else if (type == BasicTypeInfo.FLOAT_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.FLOAT_TYPE_INFO)) { return serdeConstants.FLOAT_TYPE_NAME; - } else if (type == BasicTypeInfo.DOUBLE_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.DOUBLE_TYPE_INFO)) { return serdeConstants.DOUBLE_TYPE_NAME; - } else if (type == BasicTypeInfo.STRING_TYPE_INFO) { + } else if (type.equals(BasicTypeInfo.STRING_TYPE_INFO)) { return serdeConstants.STRING_TYPE_NAME; - } else if (type == SqlTimeTypeInfo.DATE) { + } else if (type.equals(SqlTimeTypeInfo.DATE)) { return serdeConstants.DATE_TYPE_NAME; - } else if (type == PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO) { + } else if (type.equals(PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)) { return serdeConstants.BINARY_TYPE_NAME; - } else if (type == SqlTimeTypeInfo.TIMESTAMP) { + } else if (type.equals(SqlTimeTypeInfo.TIMESTAMP)) { return serdeConstants.TIMESTAMP_TYPE_NAME; } else { throw new UnsupportedOperationException( @@ -132,4 +133,12 @@ private static TypeInformation toFlinkPrimitiveType(PrimitiveTypeInfo hiveType) String.format("Flink doesn't support Hive primitive type %s yet", hiveType)); } } + + /** + * Converts a Flink {@link TypeInformation} to corresponding Hive {@link TypeInfo}. + */ + public static TypeInfo toHiveTypeInfo(TypeInformation flinkType) { + // TODO: support complex data types + return TypeInfoFactory.getPrimitiveTypeInfo(toHiveType(flinkType)); + } } diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormatTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormatTest.java new file mode 100644 index 00000000000000..ff443643eb815e --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveTableOutputFormatTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.catalog.ObjectPath; +import org.apache.flink.table.catalog.hive.HiveCatalog; +import org.apache.flink.table.catalog.hive.HiveCatalogTable; +import org.apache.flink.table.catalog.hive.HiveTestUtils; +import org.apache.flink.types.Row; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.MetaStoreUtils; +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.hive.metastore.api.Table; +import org.apache.hadoop.mapred.JobConf; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests {@link HiveTableOutputFormatTest}. + */ +public class HiveTableOutputFormatTest { + + private static HiveCatalog hiveCatalog; + private static HiveConf hiveConf; + + @BeforeClass + public static void createCatalog() throws IOException { + hiveConf = HiveTestUtils.getHiveConf(); + hiveCatalog = HiveTestUtils.createHiveCatalog(hiveConf); + hiveCatalog.open(); + } + + @AfterClass + public static void closeCatalog() { + if (hiveCatalog != null) { + hiveCatalog.close(); + } + } + + @Test + public void testInsertIntoNonPartitionTable() throws Exception { + final String dbName = "default"; + final String tblName = "dest"; + ObjectPath tablePath = new ObjectPath(dbName, tblName); + TableSchema tableSchema = new TableSchema( + new String[]{"i", "l", "d", "s"}, + new TypeInformation[]{ + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO} + ); + HiveCatalogTable catalogTable = new HiveCatalogTable(tableSchema, new HashMap<>(), ""); + hiveCatalog.createTable(tablePath, catalogTable, false); + + Table hiveTable = hiveCatalog.getHiveTable(tablePath); + RowTypeInfo rowTypeInfo = new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()); + StorageDescriptor jobSD = hiveTable.getSd().deepCopy(); + jobSD.setLocation(hiveTable.getSd().getLocation() + "/.staging"); + HiveTablePartition hiveTablePartition = new HiveTablePartition(jobSD, null); + JobConf jobConf = new JobConf(hiveConf); + HiveTableOutputFormat outputFormat = new HiveTableOutputFormat(jobConf, dbName, tblName, + Collections.emptyList(), rowTypeInfo, hiveTablePartition, MetaStoreUtils.getTableMetadata(hiveTable), false); + outputFormat.open(0, 1); + List toWrite = generateRecords(); + for (Row row : toWrite) { + outputFormat.writeRecord(row); + } + outputFormat.close(); + outputFormat.finalizeGlobal(1); + + // verify written data + Path outputFile = new Path(hiveTable.getSd().getLocation(), "0"); + FileSystem fs = outputFile.getFileSystem(jobConf); + assertTrue(fs.exists(outputFile)); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(outputFile)))) { + 
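// The rows were written through the table's default text serde, which separates fields with '\u0001' (Ctrl-A); + // replacing that delimiter with ',' below lets each output line be compared against Row#toString(), which joins fields with ','. +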
int numWritten = 0; + String line = reader.readLine(); + while (line != null) { + assertEquals(toWrite.get(numWritten++).toString(), line.replaceAll("\u0001", ",")); + line = reader.readLine(); + } + reader.close(); + assertEquals(toWrite.size(), numWritten); + } + } + + private List generateRecords() { + int numRecords = 5; + int arity = 4; + List res = new ArrayList<>(numRecords); + for (int i = 0; i < numRecords; i++) { + Row row = new Row(arity); + row.setField(0, i); + row.setField(1, (long) i); + row.setField(2, Double.valueOf(String.valueOf(String.format("%d.%d", i, i)))); + row.setField(3, String.valueOf((char) ('a' + i))); + res.add(row); + } + return res; + } +} diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveTestUtils.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveTestUtils.java index 0d4367aa4d4b4b..4eea7d817b74d8 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveTestUtils.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/catalog/hive/HiveTestUtils.java @@ -40,7 +40,11 @@ public static HiveCatalog createHiveCatalog() throws IOException { return new HiveCatalog(CatalogTestBase.TEST_CATALOG_NAME, getHiveConf()); } - private static HiveConf getHiveConf() throws IOException { + public static HiveCatalog createHiveCatalog(HiveConf hiveConf) { + return new HiveCatalog(CatalogTestBase.TEST_CATALOG_NAME, hiveConf); + } + + public static HiveConf getHiveConf() throws IOException { ClassLoader classLoader = new HiveTestUtils().getClass().getClassLoader(); HiveConf.setHiveSiteLocation(classLoader.getResource(HIVE_SITE_XML)); From 7682248f6dca971eae1d076dbf0282358331a0e7 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Mon, 3 Jun 2019 10:59:21 -0700 Subject: [PATCH 60/92] [FLINK-12712][table] deprecate ExternalCatalog and its subclasses and impls This PR deprecates `ExternalCatalog` and its subclasses, implementations, and related util classes. This closes #8600. --- .../java/org/apache/flink/table/catalog/ExternalCatalog.java | 5 +++-- .../org/apache/flink/table/catalog/ExternalCatalogTable.java | 5 +++-- .../main/scala/org/apache/flink/table/api/exceptions.scala | 2 ++ .../org/apache/flink/table/catalog/CrudExternalCatalog.scala | 3 +++ .../apache/flink/table/catalog/ExternalCatalogSchema.scala | 3 +++ .../flink/table/catalog/ExternalCatalogTableBuilder.scala | 1 + .../apache/flink/table/catalog/InMemoryExternalCatalog.scala | 3 +++ .../apache/flink/table/catalog/CatalogStructureBuilder.java | 1 + 8 files changed, 19 insertions(+), 4 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalog.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalog.java index 43f111d9af65c8..2e3ef9fa99dfbc 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalog.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalog.java @@ -18,7 +18,6 @@ package org.apache.flink.table.catalog; -import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.api.CatalogNotExistException; import org.apache.flink.table.api.TableNotExistException; @@ -30,8 +29,10 @@ * *

It provides information about catalogs, databases and tables such as names, schema, * statistics, and access information. + * + * @deprecated use {@link Catalog} instead. */ -@PublicEvolving +@Deprecated public interface ExternalCatalog { /** diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalogTable.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalogTable.java index 7e3cf020dd07a8..be24af4c54df61 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalogTable.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/catalog/ExternalCatalogTable.java @@ -18,7 +18,6 @@ package org.apache.flink.table.catalog; -import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.descriptors.DescriptorProperties; import org.apache.flink.table.descriptors.TableDescriptor; import org.apache.flink.table.factories.TableFactory; @@ -39,8 +38,10 @@ *

See also {@link TableFactory} for more information about how to target suitable factories. * *

Use {@code ExternalCatalogTableBuilder} to integrate with the normalized descriptor-based API. + * + * @deprecated use {@link CatalogTable} instead. */ -@PublicEvolving +@Deprecated public class ExternalCatalogTable extends TableDescriptor { /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/exceptions.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/exceptions.scala index 6617554301966e..465b5b32fe8bce 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/exceptions.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/exceptions.scala @@ -65,6 +65,7 @@ case class CatalogAlreadyExistException( * @param catalogName external catalog name * @param cause the cause */ +@deprecated case class ExternalCatalogNotExistException( catalogName: String, cause: Throwable) @@ -79,6 +80,7 @@ case class ExternalCatalogNotExistException( * @param catalogName external catalog name * @param cause the cause */ +@deprecated case class ExternalCatalogAlreadyExistException( catalogName: String, cause: Throwable) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/CrudExternalCatalog.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/CrudExternalCatalog.scala index ed7bcbbda8fd20..7bc12dae1e2955 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/CrudExternalCatalog.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/CrudExternalCatalog.scala @@ -22,7 +22,10 @@ import org.apache.flink.table.api._ /** * The CrudExternalCatalog provides methods to create, drop, and alter (sub-)catalogs or tables. + * + * @deprecated use [[Catalog]] instead. */ +@deprecated trait CrudExternalCatalog extends ExternalCatalog { /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala index b2ce188f97f580..501772766dfa42 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala @@ -37,7 +37,10 @@ import scala.collection.JavaConverters._ * * @param catalogIdentifier external catalog name * @param catalog external catalog + * + * @deprecated use [[CatalogCalciteSchema]] instead. 
*/ +@deprecated class ExternalCatalogSchema( isBatch: Boolean, catalogIdentifier: String, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogTableBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogTableBuilder.scala index 4ffa9c261e565f..ae5c677c141ddc 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogTableBuilder.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogTableBuilder.scala @@ -53,6 +53,7 @@ import org.apache.flink.table.descriptors._ * * @param connectorDescriptor Connector descriptor describing the external system */ +@deprecated class ExternalCatalogTableBuilder(private val connectorDescriptor: ConnectorDescriptor) extends TableDescriptor with SchematicDescriptor[ExternalCatalogTableBuilder] diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/InMemoryExternalCatalog.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/InMemoryExternalCatalog.scala index c0cfbc35fbd7e4..4c40578dccd488 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/InMemoryExternalCatalog.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/InMemoryExternalCatalog.scala @@ -31,7 +31,10 @@ import scala.collection.JavaConverters._ * @param name The name of the catalog * * It could be used for testing or developing instead of used in production environment. + * + * @deprecated use [[GenericInMemoryCatalog]] instead. */ +@deprecated class InMemoryExternalCatalog(name: String) extends CrudExternalCatalog { private val databases = new mutable.HashMap[String, ExternalCatalog] diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java index cdcea1b5c6597f..eb247a418ec4fc 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java @@ -236,6 +236,7 @@ public ExternalTestTable buildExternalTable(String path) { * Marker interface to make {@link ExternalCatalogBuilder#extCatalog(String, ExternalCatalogEntry...)} * accept both {@link ExternalCatalogBuilder} and {@link TableBuilder}. */ + @Deprecated public interface ExternalCatalogEntry { } From 61d8916a35470d6c122d9b78c1f3ae7aa9996949 Mon Sep 17 00:00:00 2001 From: "bowen.li" Date: Mon, 3 Jun 2019 11:08:29 -0700 Subject: [PATCH 61/92] [FLINK-12713][table] deprecate descriptor, validator, and factory of ExternalCatalog This closes #8601. 
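A rough migration sketch (illustrative only; the catalog name "myCatalog" is made up, and the snippet assumes the unified Catalog API that the deprecation notes point to, e.g. GenericInMemoryCatalog and TableEnvironment#registerCatalog): instead of wiring an external catalog through the descriptor/validator/factory stack, a catalog instance can be registered on the TableEnvironment directly:

    Catalog catalog = new GenericInMemoryCatalog("myCatalog");
    tableEnv.registerCatalog("myCatalog", catalog);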
--- .../org/apache/flink/table/factories/TableFactoryUtil.java | 1 + .../flink/table/descriptors/ExternalCatalogDescriptor.java | 6 +++--- .../descriptors/ExternalCatalogDescriptorValidator.java | 3 +++ .../flink/table/factories/ExternalCatalogFactory.java | 5 +++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/factories/TableFactoryUtil.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/factories/TableFactoryUtil.java index 57c4af093bc9ca..aeafd449f90cb3 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/factories/TableFactoryUtil.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/factories/TableFactoryUtil.java @@ -34,6 +34,7 @@ public class TableFactoryUtil { /** * Returns an external catalog. */ + @Deprecated public static ExternalCatalog findAndCreateExternalCatalog(Descriptor descriptor) { Map properties = descriptor.toProperties(); return TableFactoryService diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptor.java index ff8f64cd13d559..13bfd4af65cc37 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptor.java @@ -18,8 +18,6 @@ package org.apache.flink.table.descriptors; -import org.apache.flink.annotation.PublicEvolving; - import java.util.Map; import static org.apache.flink.table.descriptors.ExternalCatalogDescriptorValidator.CATALOG_PROPERTY_VERSION; @@ -27,8 +25,10 @@ /** * Describes an external catalog of tables, views, and functions. + * + * @deprecated use {@link CatalogDescriptor} instead. */ -@PublicEvolving +@Deprecated public abstract class ExternalCatalogDescriptor extends DescriptorBase implements Descriptor { private final String type; diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptorValidator.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptorValidator.java index 2482f877636637..5b42b7783a7f1c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptorValidator.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/descriptors/ExternalCatalogDescriptorValidator.java @@ -22,7 +22,10 @@ /** * Validator for {@link ExternalCatalogDescriptor}. + * + * @deprecated use {@link CatalogDescriptorValidator} instead. 
*/ +@Deprecated @Internal public abstract class ExternalCatalogDescriptorValidator implements DescriptorValidator { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/ExternalCatalogFactory.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/ExternalCatalogFactory.java index 86e327a30bd25f..6eef294b5298e6 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/ExternalCatalogFactory.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/ExternalCatalogFactory.java @@ -18,7 +18,6 @@ package org.apache.flink.table.factories; -import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.catalog.ExternalCatalog; import java.util.Map; @@ -26,8 +25,10 @@ /** * A factory to create configured external catalog instances based on string-based properties. See * also {@link TableFactory} for more information. + * + * @deprecated use {@link CatalogFactory} instead. */ -@PublicEvolving +@Deprecated public interface ExternalCatalogFactory extends TableFactory { /** From 0f4171956c8005535243781316defb4ef2fbe258 Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Thu, 23 May 2019 11:09:59 +0200 Subject: [PATCH 62/92] [FLINK-12601][table] Wrapping DataStream & DataSet as TableOperations. We do not store the DataStream & DataSet as Calcite's Tables anymore. We treat them as inline operations. When converting from TableOperations to RelNodes we directly create a special kind of DataStream/SetScan that does not access the catalog. This closes #8521 --- .../operations/DataSetTableOperation.java | 84 +++++++++++ .../operations/DataStreamTableOperation.java | 85 +++++++++++ .../catalog/TableOperationCatalogView.java | 66 +++++++++ .../operations/AggregateTableOperation.java | 13 +- .../operations/CalculatedTableOperation.java | 13 +- .../operations/CatalogTableOperation.java | 13 +- .../operations/DistinctTableOperation.java | 7 +- .../operations/FilterTableOperation.java | 12 +- .../table/operations/JoinTableOperation.java | 14 +- .../operations/ProjectTableOperation.java | 12 +- .../table/operations/SetTableOperation.java | 25 +++- .../table/operations/SortTableOperation.java | 14 +- .../table/operations/TableOperation.java | 71 ++++++++- .../operations/TableOperationUtils.java} | 35 +++-- .../WindowAggregateTableOperation.java | 41 +++++- .../table/operations/TableOperationTest.java | 115 +++++++++++++++ .../table/catalog/DatabaseCalciteSchema.java | 2 + .../TableOperationCatalogViewTable.java | 73 ++++++++++ .../table/plan/TableOperationConverter.java | 38 +++++ .../flink/table/api/BatchTableEnvImpl.scala | 67 +++------ .../flink/table/api/StreamTableEnvImpl.scala | 98 ++++++------- .../apache/flink/table/api/TableEnvImpl.scala | 113 +++++++-------- .../table/api/java/BatchTableEnvImpl.scala | 20 +-- .../table/api/java/StreamTableEnvImpl.scala | 26 ++-- .../table/api/scala/BatchTableEnvImpl.scala | 18 +-- .../table/api/scala/StreamTableEnvImpl.scala | 34 ++--- .../flink/table/calcite/FlinkRelBuilder.scala | 10 ++ .../table/calcite/FlinkTypeFactory.scala | 12 +- .../operations/PlannerTableOperation.java | 7 +- .../table/plan/nodes/dataset/BatchScan.scala | 12 +- .../plan/nodes/dataset/DataSetScan.scala | 36 +++-- .../nodes/datastream/DataStreamScan.scala | 38 +++-- ...an.scala => FlinkLogicalDataSetScan.scala} | 61 ++++---- .../logical/FlinkLogicalDataStreamScan.scala | 61 ++++++++ .../table/plan/rules/FlinkRuleSets.scala | 6 +- 
.../plan/rules/dataSet/DataSetScanRule.scala | 31 ++-- .../rules/datastream/DataStreamScanRule.scala | 30 ++-- ...icalCorrelateToTemporalTableJoinRule.scala | 9 +- .../table/plan/schema/DataStreamTable.scala | 29 ---- .../flink/table/plan/schema/InlineTable.scala | 117 --------------- .../flink/table/plan/schema/RelTable.scala | 46 ------ .../table/plan/stats/FlinkStatistic.scala | 3 +- .../flink/table/api/TableSourceTest.scala | 25 ++-- .../api/batch/BatchTableEnvironmentTest.scala | 6 +- .../flink/table/api/batch/ExplainTest.scala | 68 ++++++--- .../table/api/batch/sql/AggregateTest.scala | 20 +-- .../flink/table/api/batch/sql/CalcTest.scala | 12 +- .../table/api/batch/sql/CorrelateTest.scala | 44 +++--- .../api/batch/sql/DistinctAggregateTest.scala | 50 +++---- .../table/api/batch/sql/GroupWindowTest.scala | 40 +++--- .../api/batch/sql/GroupingSetsTest.scala | 24 ++-- .../flink/table/api/batch/sql/JoinTest.scala | 72 +++++----- .../api/batch/sql/SetOperatorsTest.scala | 36 ++--- .../api/batch/sql/SingleRowJoinTest.scala | 64 ++++----- .../table/api/batch/table/AggregateTest.scala | 8 +- .../table/api/batch/table/CalcTest.scala | 28 ++-- .../api/batch/table/ColumnFunctionsTest.scala | 2 +- .../table/api/batch/table/CorrelateTest.scala | 12 +- .../api/batch/table/GroupWindowTest.scala | 38 ++--- .../table/api/batch/table/JoinTest.scala | 40 +++--- .../api/batch/table/SetOperatorsTest.scala | 30 ++-- .../CorrelateStringExpressionTest.scala | 16 ++- .../table/stringexpr/SetOperatorsTest.scala | 6 +- .../flink/table/api/stream/ExplainTest.scala | 29 +++- .../stream/StreamTableEnvironmentTest.scala | 6 +- .../table/api/stream/sql/AggregateTest.scala | 4 +- .../table/api/stream/sql/CorrelateTest.scala | 48 +++---- .../stream/sql/DistinctAggregateTest.scala | 14 +- .../api/stream/sql/GroupWindowTest.scala | 16 +-- .../flink/table/api/stream/sql/JoinTest.scala | 136 +++++++++--------- .../api/stream/sql/MatchRecognizeTest.scala | 9 +- .../table/api/stream/sql/OverWindowTest.scala | 40 +++--- .../api/stream/sql/SetOperatorsTest.scala | 36 ++--- .../flink/table/api/stream/sql/SortTest.scala | 6 +- .../stream/sql/TemporalTableJoinTest.scala | 122 ++++++++++++++-- .../table/api/stream/sql/UnionTest.scala | 12 +- .../api/stream/table/AggregateTest.scala | 30 ++-- .../table/api/stream/table/CalcTest.scala | 33 ++--- .../stream/table/ColumnFunctionsTest.scala | 32 ++--- .../api/stream/table/CorrelateTest.scala | 25 ++-- .../table/GroupWindowTableAggregateTest.scala | 38 ++--- .../api/stream/table/GroupWindowTest.scala | 54 +++---- .../table/api/stream/table/JoinTest.scala | 68 ++++----- .../api/stream/table/OverWindowTest.scala | 34 ++--- .../api/stream/table/SetOperatorsTest.scala | 22 +-- .../api/stream/table/TableAggregateTest.scala | 12 +- .../api/stream/table/TableSourceTest.scala | 29 ++-- .../stream/table/TemporalTableJoinTest.scala | 103 +++++++------ .../CorrelateStringExpressionTest.scala | 23 ++- .../match/PatternTranslatorTestBase.scala | 1 + .../plan/ExpressionReductionRulesTest.scala | 51 +++---- .../table/plan/NormalizationRulesTest.scala | 12 +- .../table/plan/QueryDecorrelationTest.scala | 32 +++-- .../plan/TimeIndicatorConversionTest.scala | 44 +++--- .../stream/table/TableSourceITCase.scala | 25 ++++ .../flink/table/utils/TableTestBase.scala | 35 ++--- .../src/test/scala/resources/testFilter0.out | 4 +- .../src/test/scala/resources/testFilter1.out | 4 +- .../scala/resources/testFilterStream0.out | 4 +- .../src/test/scala/resources/testJoin0.out | 8 +- 
.../src/test/scala/resources/testJoin1.out | 8 +- .../src/test/scala/resources/testUnion0.out | 8 +- .../src/test/scala/resources/testUnion1.out | 8 +- .../test/scala/resources/testUnionStream0.out | 8 +- 104 files changed, 2103 insertions(+), 1388 deletions(-) create mode 100644 flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataSetTableOperation.java create mode 100644 flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataStreamTableOperation.java create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java rename flink-table/{flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataSetTable.scala => flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationUtils.java} (51%) create mode 100644 flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/operations/TableOperationTest.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogViewTable.java rename flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/{FlinkLogicalNativeTableScan.scala => FlinkLogicalDataSetScan.scala} (52%) create mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataStreamScan.scala delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/InlineTable.scala delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/RelTable.scala diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataSetTableOperation.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataSetTableOperation.java new file mode 100644 index 00000000000000..85e1a2644d5722 --- /dev/null +++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataSetTableOperation.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.operations; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.table.api.TableSchema; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Describes a relational operation that reads from a {@link DataSet}. + * + *

This operation may expose only part, or change the order of the fields available in a + * {@link org.apache.flink.api.common.typeutils.CompositeType} of the underlying {@link DataSet}. + * The {@link DataSetTableOperation#getFieldIndices()} describes the mapping between fields of the + * {@link TableSchema} to the {@link org.apache.flink.api.common.typeutils.CompositeType}. + */ +@Internal +public class DataSetTableOperation extends TableOperation { + + private final DataSet dataSet; + private final int[] fieldIndices; + private final TableSchema tableSchema; + + public DataSetTableOperation( + DataSet dataSet, + int[] fieldIndices, + TableSchema tableSchema) { + this.dataSet = dataSet; + this.tableSchema = tableSchema; + this.fieldIndices = fieldIndices; + } + + public DataSet getDataSet() { + return dataSet; + } + + public int[] getFieldIndices() { + return fieldIndices; + } + + @Override + public TableSchema getTableSchema() { + return tableSchema; + } + + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("fields", tableSchema.getFieldNames()); + + return formatWithChildren("DataSet", args); + } + + @Override + public List getChildren() { + return Collections.emptyList(); + } + + @Override + public T accept(TableOperationVisitor visitor) { + return visitor.visitOther(this); + } +} diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataStreamTableOperation.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataStreamTableOperation.java new file mode 100644 index 00000000000000..2e5020841419de --- /dev/null +++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/operations/DataStreamTableOperation.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.operations; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.TableSchema; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Describes a relational operation that reads from a {@link DataStream}. + * + *

This operation may expose only part, or change the order of the fields available in a + * {@link org.apache.flink.api.common.typeutils.CompositeType} of the underlying {@link DataStream}. + * The {@link DataStreamTableOperation#getFieldIndices()} describes the mapping between fields of the + * {@link TableSchema} to the {@link org.apache.flink.api.common.typeutils.CompositeType}. + */ +@Internal +public class DataStreamTableOperation extends TableOperation { + + private final DataStream dataStream; + private final int[] fieldIndices; + private final TableSchema tableSchema; + + public DataStreamTableOperation( + DataStream dataStream, + int[] fieldIndices, + TableSchema tableSchema) { + this.dataStream = dataStream; + this.tableSchema = tableSchema; + this.fieldIndices = fieldIndices; + } + + public DataStream getDataStream() { + return dataStream; + } + + public int[] getFieldIndices() { + return fieldIndices; + } + + @Override + public TableSchema getTableSchema() { + return tableSchema; + } + + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("id", dataStream.getId()); + args.put("fields", tableSchema.getFieldNames()); + + return formatWithChildren("DataStream", args); + } + + @Override + public List getChildren() { + return Collections.emptyList(); + } + + @Override + public T accept(TableOperationVisitor visitor) { + return visitor.visitOther(this); + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java new file mode 100644 index 00000000000000..0e9aafe8d2ac90 --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.operations.TableOperation; + +import java.util.HashMap; +import java.util.Optional; + +/** + * A view created from {@link TableOperation} via operations on {@link org.apache.flink.table.api.Table}. 
+ */ +@Internal +public class TableOperationCatalogView extends AbstractCatalogView { + private final TableOperation tableOperation; + + public TableOperationCatalogView(TableOperation tableOperation) { + this(tableOperation, ""); + } + + public TableOperationCatalogView(TableOperation tableOperation, String comment) { + super( + tableOperation.asSummaryString(), + tableOperation.asSummaryString(), + tableOperation.getTableSchema(), + new HashMap<>(), + comment); + this.tableOperation = tableOperation; + } + + public TableOperation getTableOperation() { + return tableOperation; + } + + @Override + public TableOperationCatalogView copy() { + return new TableOperationCatalogView(this.tableOperation, getComment()); + } + + @Override + public Optional getDescription() { + return Optional.of(getComment()); + } + + @Override + public Optional getDetailedDescription() { + return getDescription(); + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/AggregateTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/AggregateTableOperation.java index a2b24783f35d2a..c3fa6e2ddf9aa9 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/AggregateTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/AggregateTableOperation.java @@ -23,14 +23,16 @@ import org.apache.flink.table.expressions.Expression; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Relational operation that performs computations on top of subsets of input rows grouped by * key. */ @Internal -public class AggregateTableOperation implements TableOperation { +public class AggregateTableOperation extends TableOperation { private final List groupingExpressions; private final List aggregateExpressions; @@ -53,6 +55,15 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("group", groupingExpressions); + args.put("agg", aggregateExpressions); + + return formatWithChildren("Aggregate", args); + } + public List getGroupingExpressions() { return groupingExpressions; } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CalculatedTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CalculatedTableOperation.java index 58097b932cff10..f030214a891e15 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CalculatedTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CalculatedTableOperation.java @@ -25,13 +25,15 @@ import org.apache.flink.table.functions.TableFunction; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Describes a relational operation that was created from applying a {@link TableFunction}. 
*/ @Internal -public class CalculatedTableOperation implements TableOperation { +public class CalculatedTableOperation extends TableOperation { private final TableFunction tableFunction; private final List parameters; @@ -66,6 +68,15 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("function", tableFunction); + args.put("parameters", parameters); + + return formatWithChildren("CalculatedTable", args); + } + @Override public List getChildren() { return Collections.emptyList(); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CatalogTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CatalogTableOperation.java index 3d13d093e7128b..daad3b16133a9a 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CatalogTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/CatalogTableOperation.java @@ -22,13 +22,15 @@ import org.apache.flink.table.api.TableSchema; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Describes a relational operation that was created from a lookup to a catalog. */ @Internal -public class CatalogTableOperation implements TableOperation { +public class CatalogTableOperation extends TableOperation { private final List tablePath; private final TableSchema tableSchema; @@ -47,6 +49,15 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("path", tablePath); + args.put("fields", tableSchema.getFieldNames()); + + return formatWithChildren("CatalogTable", args); + } + @Override public List getChildren() { return Collections.emptyList(); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/DistinctTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/DistinctTableOperation.java index fbb1716467b65f..5453affd49ff61 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/DistinctTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/DistinctTableOperation.java @@ -28,7 +28,7 @@ * Removes duplicated rows of underlying relational operation. 
*/ @Internal -public class DistinctTableOperation implements TableOperation { +public class DistinctTableOperation extends TableOperation { private final TableOperation child; @@ -41,6 +41,11 @@ public TableSchema getTableSchema() { return child.getTableSchema(); } + @Override + public String asSummaryString() { + return formatWithChildren("Distinct", Collections.emptyMap()); + } + @Override public List getChildren() { return Collections.singletonList(child); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/FilterTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/FilterTableOperation.java index 651071a9300574..5605d24422e982 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/FilterTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/FilterTableOperation.java @@ -23,13 +23,15 @@ import org.apache.flink.table.expressions.Expression; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Filters out rows of underlying relational operation that do not match given condition. */ @Internal -public class FilterTableOperation implements TableOperation { +public class FilterTableOperation extends TableOperation { private final Expression condition; private final TableOperation child; @@ -48,6 +50,14 @@ public TableSchema getTableSchema() { return child.getTableSchema(); } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("condition", condition); + + return formatWithChildren("Filter", args); + } + @Override public List getChildren() { return Collections.singletonList(child); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/JoinTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/JoinTableOperation.java index d4808a7a4bb59e..6f077be2e48e53 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/JoinTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/JoinTableOperation.java @@ -24,13 +24,15 @@ import org.apache.flink.table.expressions.Expression; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Table operation that joins two relational operations based on given condition. 
*/ @Internal -public class JoinTableOperation implements TableOperation { +public class JoinTableOperation extends TableOperation { private final TableOperation left; private final TableOperation right; private final JoinType joinType; @@ -105,6 +107,16 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("joinType", joinType); + args.put("condition", condition); + args.put("correlated", correlated); + + return formatWithChildren("Join", args); + } + @Override public List getChildren() { return Arrays.asList(left, right); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/ProjectTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/ProjectTableOperation.java index 6b9daffc4057f9..5d97f455f5cf89 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/ProjectTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/ProjectTableOperation.java @@ -23,14 +23,16 @@ import org.apache.flink.table.expressions.Expression; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Table operation that computes new table using given {@link Expression}s * from its input relational operation. */ @Internal -public class ProjectTableOperation implements TableOperation { +public class ProjectTableOperation extends TableOperation { private final List projectList; private final TableOperation child; @@ -54,6 +56,14 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("projections", projectList); + + return formatWithChildren("Project", args); + } + @Override public List getChildren() { return Collections.singletonList(child); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SetTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SetTableOperation.java index d96d7ef1245db8..1bad51721f1410 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SetTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SetTableOperation.java @@ -22,14 +22,16 @@ import org.apache.flink.table.api.TableSchema; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * A set operation on two relations. It provides a way to union, intersect or subtract underlying * data sets/streams. Both relations must have equal schemas. 
*/ @Internal -public class SetTableOperation implements TableOperation { +public class SetTableOperation extends TableOperation { private final TableOperation leftOperation; private final TableOperation rightOperation; @@ -67,6 +69,27 @@ public TableSchema getTableSchema() { return leftOperation.getTableSchema(); } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("all", all); + + return formatWithChildren(typeToString(), args); + } + + private String typeToString() { + switch (type) { + case INTERSECT: + return "Intersect"; + case MINUS: + return "Minus"; + case UNION: + return "Union"; + default: + throw new IllegalStateException("Unknown set operation type: " + type); + } + } + @Override public T accept(TableOperationVisitor visitor) { return visitor.visitSetOperation(this); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SortTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SortTableOperation.java index f3965cb3ba7d4b..7293678338b140 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SortTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/SortTableOperation.java @@ -23,14 +23,16 @@ import org.apache.flink.table.expressions.Expression; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; /** * Expresses sort operation of rows of the underlying relational operation with given order. * It also allows specifying offset and number of rows to fetch from the sorted data set/stream. */ @Internal -public class SortTableOperation implements TableOperation { +public class SortTableOperation extends TableOperation { private final List order; private final TableOperation child; @@ -71,6 +73,16 @@ public TableSchema getTableSchema() { return child.getTableSchema(); } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("order", order); + args.put("offset", offset); + args.put("fetch", fetch); + + return formatWithChildren("Sort", args); + } + @Override public List getChildren() { return Collections.singletonList(child); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperation.java index 3cf6c261abc6c6..2cfd3135504028 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperation.java @@ -21,21 +21,84 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.util.StringUtils; +import java.util.Arrays; +import java.util.Collection; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; /** * Base class for representing the operation structure behind a user-facing {@link Table} API. */ @Internal -public interface TableOperation { +public abstract class TableOperation { /** * Resolved schema of this operation. */ - TableSchema getTableSchema(); + public abstract TableSchema getTableSchema(); - List getChildren(); + /** + * Returns a string that summarizes this operation for printing to a console. 
An implementation might + * skip very specific properties. + * + *

Use {@link #asSerializableString()} for a operation string that fully serializes + * this instance. + * + * @return summary string of this operation for debugging purposes + */ + public abstract String asSummaryString(); + + /** + * Returns a string that fully serializes this instance. The serialized string can be used for storing + * the query in e.g. a {@link org.apache.flink.table.catalog.Catalog} as a view. + * + * @return detailed string for persisting in a catalog + */ + public String asSerializableString() { + throw new UnsupportedOperationException("TableOperations are not string serializable for now."); + } + + public abstract List getChildren(); + + public T accept(TableOperationVisitor visitor) { + return visitor.visitOther(this); + } + + protected final String formatWithChildren(String operationName, Map parameters) { + String description = parameters.entrySet() + .stream() + .map(entry -> formatParameter(entry.getKey(), entry.getValue())) + .collect(Collectors.joining(", ")); + + final StringBuilder stringBuilder = new StringBuilder(); + + stringBuilder.append(operationName).append(":"); + + if (!StringUtils.isNullOrWhitespaceOnly(description)) { + stringBuilder.append(" (").append(description).append(")"); + } + + String childrenDescription = getChildren().stream() + .map(child -> TableOperationUtils.indent(child.asSummaryString())) + .collect(Collectors.joining()); + + return stringBuilder.append(childrenDescription).toString(); + } - T accept(TableOperationVisitor visitor); + private String formatParameter(String name, Object value) { + final StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(name); + stringBuilder.append(": "); + if (value.getClass().isArray()) { + stringBuilder.append(Arrays.toString((Object[]) value)); + } else if (value instanceof Collection) { + stringBuilder.append(value); + } else { + stringBuilder.append("[").append(value).append("]"); + } + return stringBuilder.toString(); + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataSetTable.scala b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationUtils.java similarity index 51% rename from flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataSetTable.scala rename to flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationUtils.java index 42274984c8d3ac..d27a224001ddcc 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataSetTable.scala +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationUtils.java @@ -16,14 +16,31 @@ * limitations under the License. */ -package org.apache.flink.table.plan.schema +package org.apache.flink.table.operations; -import org.apache.flink.api.java.DataSet -import org.apache.flink.table.plan.stats.{FlinkStatistic, TableStats} +import org.apache.flink.annotation.Internal; -class DataSetTable[T]( - val dataSet: DataSet[T], - override val fieldIndexes: Array[Int], - override val fieldNames: Array[String], - override val statistic: FlinkStatistic = FlinkStatistic.of(new TableStats(1000L))) - extends InlineTable[T](dataSet.getType, fieldIndexes, fieldNames, statistic) +/** + * Helper methods for {@link TableOperation}s. 
+ */ +@Internal +public class TableOperationUtils { + + private static final String OPERATION_INDENT = " "; + + /** + * Increases indentation for description of string of child {@link TableOperation}. + * The input can already contain indentation. This will increase all the indentations + * by one level. + * + * @param item result of {@link TableOperation#asSummaryString()} + * @return string with increased indentation + */ + static String indent(String item) { + return "\n" + OPERATION_INDENT + + item.replace("\n" + OPERATION_INDENT, "\n" + OPERATION_INDENT + OPERATION_INDENT); + } + + private TableOperationUtils() { + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/WindowAggregateTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/WindowAggregateTableOperation.java index 95c1d33ad03dfb..cbe2f4e957a85d 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/WindowAggregateTableOperation.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/WindowAggregateTableOperation.java @@ -28,7 +28,9 @@ import javax.annotation.Nullable; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import static org.apache.flink.table.operations.WindowAggregateTableOperation.ResolvedGroupWindow.WindowType.SESSION; @@ -42,7 +44,7 @@ * key and group window. It differs from {@link AggregateTableOperation} by the group window. */ @Internal -public class WindowAggregateTableOperation implements TableOperation { +public class WindowAggregateTableOperation extends TableOperation { private final List groupingExpressions; private final List aggregateExpressions; @@ -71,6 +73,17 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + Map args = new LinkedHashMap<>(); + args.put("group", groupingExpressions); + args.put("agg", aggregateExpressions); + args.put("windowProperties", windowPropertiesExpressions); + args.put("window", groupWindow.asSummaryString()); + + return formatWithChildren("WindowAggregate", args); + } + public List getGroupingExpressions() { return groupingExpressions; } @@ -124,7 +137,8 @@ public enum WindowType { */ private ResolvedGroupWindow( WindowType type, - String alias, FieldReferenceExpression timeAttribute, + String alias, + FieldReferenceExpression timeAttribute, @Nullable ValueLiteralExpression size, @Nullable ValueLiteralExpression slide, @Nullable ValueLiteralExpression gap) { @@ -202,5 +216,28 @@ public Optional getSize() { public Optional getGap() { return Optional.of(gap); } + + public String asSummaryString() { + switch (type) { + case SLIDE: + return String.format( + "SlideWindow(field: [%s], slide: [%s], size: [%s])", + timeAttribute, + slide, + size); + case SESSION: + return String.format( + "SessionWindow(field: [%s], gap: [%s])", + timeAttribute, + gap); + case TUMBLE: + return String.format( + "TumbleWindow(field: [%s], size: [%s])", + timeAttribute, + size); + default: + throw new IllegalStateException("Unknown window type: " + type); + } + } } } diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/operations/TableOperationTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/operations/TableOperationTest.java new file mode 100644 index 00000000000000..deb1273d0739d6 --- /dev/null +++ 
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/operations/TableOperationTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.operations; + +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.expressions.ApiExpressionUtils; +import org.apache.flink.table.expressions.BuiltInFunctionDefinitions; +import org.apache.flink.table.expressions.CallExpression; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.typeutils.TimeIntervalTypeInfo; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for describing {@link TableOperation}s. + */ +public class TableOperationTest { + + @Test + public void testSummaryString() { + TableSchema schema = TableSchema.builder().field("a", DataTypes.INT()).build(); + + ProjectTableOperation tableOperation = new ProjectTableOperation( + Collections.singletonList(new FieldReferenceExpression("a", Types.INT, 0, 0)), + new CatalogTableOperation( + Arrays.asList("cat1", "db1", "tab1"), + schema), schema); + + SetTableOperation unionTableOperation = new SetTableOperation( + tableOperation, + tableOperation, + SetTableOperation.SetTableOperationType.UNION, + true); + + assertEquals("Union: (all: [true])\n" + + " Project: (projections: [a])\n" + + " CatalogTable: (path: [cat1, db1, tab1], fields: [a])\n" + + " Project: (projections: [a])\n" + + " CatalogTable: (path: [cat1, db1, tab1], fields: [a])", + unionTableOperation.asSummaryString()); + } + + @Test + public void testWindowAggregationSummaryString() { + TableSchema schema = TableSchema.builder().field("a", DataTypes.INT()).build(); + FieldReferenceExpression field = new FieldReferenceExpression("a", Types.INT, 0, 0); + WindowAggregateTableOperation tableOperation = new WindowAggregateTableOperation( + Collections.singletonList(field), + Collections.singletonList(new CallExpression(BuiltInFunctionDefinitions.SUM, + Collections.singletonList(field))), + Collections.emptyList(), + WindowAggregateTableOperation.ResolvedGroupWindow.sessionWindow("w", field, ApiExpressionUtils.valueLiteral( + 10, + TimeIntervalTypeInfo.INTERVAL_MILLIS)), + new CatalogTableOperation( + Arrays.asList("cat1", "db1", "tab1"), + schema), + schema + ); + + DistinctTableOperation distinctTableOperation = new DistinctTableOperation(tableOperation); + + assertEquals( + "Distinct:\n" + + " WindowAggregate: (group: [a], agg: [sum(a)], windowProperties: []," + + " window: [SessionWindow(field: [a], gap: [10.millis])])\n" + + " CatalogTable: (path: [cat1, 
db1, tab1], fields: [a])", + distinctTableOperation.asSummaryString()); + } + + @Test + public void testIndentation() { + + String input = + "firstLevel\n" + + " secondLevel0\n" + + " thirdLevel0\n" + + " secondLevel1\n" + + " thirdLevel1"; + + String indentedInput = TableOperationUtils.indent(input); + + assertEquals( + "\n" + + " firstLevel\n" + + " secondLevel0\n" + + " thirdLevel0\n" + + " secondLevel1\n" + + " thirdLevel1", + indentedInput); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java index cb341ed3bd4c59..7d93b3a6bac1bf 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java @@ -67,6 +67,8 @@ public Table getTable(String tableName) { if (table instanceof CalciteCatalogTable) { return ((CalciteCatalogTable) table).getTable(); + } else if (table instanceof TableOperationCatalogView) { + return TableOperationCatalogViewTable.createCalciteTable(((TableOperationCatalogView) table)); } else { throw new TableException("Unsupported table type: " + table); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogViewTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogViewTable.java new file mode 100644 index 00000000000000..baf2efca036725 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogViewTable.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.calcite.FlinkRelBuilder; +import org.apache.flink.table.calcite.FlinkTypeFactory; + +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.calcite.schema.TranslatableTable; +import org.apache.calcite.schema.impl.AbstractTable; + +/** + * A bridge between a Flink's specific {@link TableOperationCatalogView} and a Calcite's + * {@link org.apache.calcite.schema.Table}. It implements {@link TranslatableTable} interface. This enables + * direct translation from {@link org.apache.flink.table.operations.TableOperation} to {@link RelNode}. 
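As a usage-level sketch of the view path introduced here (the environment setup, table names, and the StreamTableEnvironment.create factory are assumptions for the example, not part of this patch): once a Table is registered it is stored as a TableOperationCatalogView, and a later SQL query over it is resolved through DatabaseCalciteSchema, which hands Calcite a TableOperationCatalogViewTable whose toRel(...) expands the stored operation tree.

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.java.StreamTableEnvironment;

public class ViewTranslationSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); // assumed factory

        // Registering a Table now stores a TableOperationCatalogView in the catalog.
        Table source = tEnv.fromDataStream(env.fromElements(1, 2, 3), "a");
        tEnv.registerTable("MyView", source.select("a"));

        // Resolving "MyView" goes through DatabaseCalciteSchema, which returns a
        // TableOperationCatalogViewTable; its toRel(...) expands the stored operation
        // tree and adds a cast to the declared TableSchema when the row types differ.
        Table result = tEnv.sqlQuery("SELECT a FROM MyView");
        System.out.println(result.getSchema());
    }
}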
+ * + *
<p>
NOTE: Due to legacy inconsistency in null handling in the {@link TableSchema} the translation might introduce + * additional cast to comply with manifested schema in + * {@link TableOperationCatalogViewTable#getRowType(RelDataTypeFactory)}. + */ +@Internal +public class TableOperationCatalogViewTable extends AbstractTable implements TranslatableTable { + private final TableOperationCatalogView catalogView; + private final RelProtoDataType rowType; + + public static TableOperationCatalogViewTable createCalciteTable(TableOperationCatalogView catalogView) { + return new TableOperationCatalogViewTable(catalogView, typeFactory -> { + TableSchema tableSchema = catalogView.getSchema(); + return ((FlinkTypeFactory) typeFactory).buildLogicalRowType(tableSchema); + }); + } + + private TableOperationCatalogViewTable(TableOperationCatalogView catalogView, RelProtoDataType rowType) { + this.catalogView = catalogView; + this.rowType = rowType; + } + + @Override + public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable relOptTable) { + FlinkRelBuilder relBuilder = FlinkRelBuilder.of(context.getCluster(), relOptTable); + + RelNode relNode = relBuilder.tableOperation(catalogView.getTableOperation()).build(); + return RelOptUtil.createCastRel(relNode, rowType.apply(relBuilder.getTypeFactory()), false); + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return rowType.apply(typeFactory); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java index 565bca7a205589..f3aaac8ba67175 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java @@ -40,6 +40,8 @@ import org.apache.flink.table.operations.AggregateTableOperation; import org.apache.flink.table.operations.CalculatedTableOperation; import org.apache.flink.table.operations.CatalogTableOperation; +import org.apache.flink.table.operations.DataSetTableOperation; +import org.apache.flink.table.operations.DataStreamTableOperation; import org.apache.flink.table.operations.DistinctTableOperation; import org.apache.flink.table.operations.FilterTableOperation; import org.apache.flink.table.operations.JoinTableOperation; @@ -57,12 +59,17 @@ import org.apache.flink.table.plan.logical.SessionGroupWindow; import org.apache.flink.table.plan.logical.SlidingGroupWindow; import org.apache.flink.table.plan.logical.TumblingGroupWindow; +import org.apache.flink.table.plan.nodes.FlinkConventions; +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataSetScan; +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataStreamScan; import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl; +import org.apache.flink.table.plan.schema.RowSchema; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.logical.LogicalTableFunctionScan; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexNode; import org.apache.calcite.tools.RelBuilder.AggCall; import org.apache.calcite.tools.RelBuilder.GroupKey; @@ -248,11 +255,42 @@ public RelNode visitCatalogTable(CatalogTableOperation catalogTable) { public RelNode visitOther(TableOperation other) { if (other 
instanceof PlannerTableOperation) { return ((PlannerTableOperation) other).getCalciteTree(); + } else if (other instanceof DataStreamTableOperation) { + return convertToDataStreamScan((DataStreamTableOperation) other); + } else if (other instanceof DataSetTableOperation) { + return convertToDataSetScan((DataSetTableOperation) other); } throw new TableException("Unknown table operation: " + other); } + private RelNode convertToDataStreamScan(DataStreamTableOperation tableOperation) { + RelDataType logicalRowType = relBuilder.getTypeFactory() + .buildLogicalRowType(tableOperation.getTableSchema()); + RowSchema rowSchema = new RowSchema(logicalRowType); + + return new FlinkLogicalDataStreamScan( + relBuilder.getCluster(), + relBuilder.getCluster().traitSet().replace(FlinkConventions.LOGICAL()), + relBuilder.getRelOptSchema(), + tableOperation.getDataStream(), + tableOperation.getFieldIndices(), + rowSchema); + } + + private RelNode convertToDataSetScan(DataSetTableOperation tableOperation) { + RelDataType logicalRowType = relBuilder.getTypeFactory() + .buildLogicalRowType(tableOperation.getTableSchema()); + + return new FlinkLogicalDataSetScan( + relBuilder.getCluster(), + relBuilder.getCluster().traitSet().replace(FlinkConventions.LOGICAL()), + relBuilder.getRelOptSchema(), + tableOperation.getDataSet(), + tableOperation.getFieldIndices(), + logicalRowType); + } + private RexNode convertToRexNode(Expression expression) { return expressionBridge.bridge(expression).toRexNode(relBuilder); } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala index d0ee44088d137a..6dad03aa8307f7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala @@ -35,6 +35,7 @@ import org.apache.flink.table.descriptors.{BatchTableDescriptor, ConnectorDescri import org.apache.flink.table.explain.PlanJsonParser import org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIME_ATTRIBUTES import org.apache.flink.table.expressions.{CallExpression, Expression} +import org.apache.flink.table.operations.DataSetTableOperation import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.DataSetRel import org.apache.flink.table.plan.rules.FlinkRuleSets @@ -42,7 +43,7 @@ import org.apache.flink.table.plan.schema._ import org.apache.flink.table.runtime.MapRunner import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{BatchTableSource, TableSource} -import org.apache.flink.table.typeutils.FieldInfoUtils.{getFieldsInfo, validateInputTypeInfo} +import org.apache.flink.table.typeutils.FieldInfoUtils.{calculateTableSchema, getFieldsInfo, validateInputTypeInfo} import org.apache.flink.types.Row /** @@ -116,7 +117,7 @@ abstract class BatchTableEnvImpl( val enrichedTable = new TableSourceSinkTable( Some(new BatchTableSourceTable(batchTableSource)), table.tableSinkTable) - replaceRegisteredTable(name, enrichedTable) + replaceRegisteredTableSourceSinkInternal(name, enrichedTable) } // no table is registered @@ -124,7 +125,7 @@ abstract class BatchTableEnvImpl( val newTable = new TableSourceSinkTable( Some(new BatchTableSourceTable(batchTableSource)), None) - registerTableInternal(name, newTable) + registerTableSourceSinkInternal(name, newTable) } // not a 
batch table source @@ -196,7 +197,7 @@ abstract class BatchTableEnvImpl( val enrichedTable = new TableSourceSinkTable( table.tableSourceTable, Some(new TableSinkTable(configuredSink))) - replaceRegisteredTable(name, enrichedTable) + replaceRegisteredTableSourceSinkInternal(name, enrichedTable) } // no table is registered @@ -204,7 +205,7 @@ abstract class BatchTableEnvImpl( val newTable = new TableSourceSinkTable( None, Some(new TableSinkTable(configuredSink))) - registerTableInternal(name, newTable) + registerTableSourceSinkInternal(name, newTable) } // not a batch table sink @@ -312,51 +313,27 @@ abstract class BatchTableEnvImpl( def explain(table: Table): String = explain(table: Table, extended = false) - /** - * Registers a [[DataSet]] as a table under a given name in the [[TableEnvImpl]]'s catalog. - * - * @param name The name under which the table is registered in the catalog. - * @param dataSet The [[DataSet]] to register as table in the catalog. - * @tparam T the type of the [[DataSet]]. - */ - protected def registerDataSetInternal[T](name: String, dataSet: DataSet[T]): Unit = { - - val fieldInfo = getFieldsInfo[T](dataSet.getType) - val dataSetTable = new DataSetTable[T]( - dataSet, - fieldInfo.getIndices, - fieldInfo.getFieldNames - ) - registerTableInternal(name, dataSetTable) - } - - /** - * Registers a [[DataSet]] as a table under a given name with field names as specified by - * field expressions in the [[TableEnvImpl]]'s catalog. - * - * @param name The name under which the table is registered in the catalog. - * @param dataSet The [[DataSet]] to register as table in the catalog. - * @param fields The field expressions to define the field names of the table. - * @tparam T The type of the [[DataSet]]. - */ - protected def registerDataSetInternal[T]( - name: String, dataSet: DataSet[T], fields: Array[Expression]): Unit = { - + protected def asDataSetTableOperation[T](dataSet: DataSet[T], fields: Option[Array[Expression]]) + : DataSetTableOperation[T] = { val inputType = dataSet.getType - val fieldsInfo = getFieldsInfo[T]( - inputType, - fields) + val fieldsInfo = fields match { + case Some(f) => + if (f.exists(f => + f.isInstanceOf[CallExpression] && + TIME_ATTRIBUTES.contains(f.asInstanceOf[CallExpression].getFunctionDefinition))) { + throw new ValidationException( + ".rowtime and .proctime time indicators are not allowed in a batch environment.") + } - if (fields.exists(f => - f.isInstanceOf[CallExpression] && - TIME_ATTRIBUTES.contains(f.asInstanceOf[CallExpression].getFunctionDefinition))) { - throw new ValidationException( - ".rowtime and .proctime time indicators are not allowed in a batch environment.") + getFieldsInfo[T](inputType, f) + case None => getFieldsInfo[T](inputType) } - val dataSetTable = new DataSetTable[T](dataSet, fieldsInfo.getIndices, fieldsInfo.getFieldNames) - registerTableInternal(name, dataSetTable) + val tableOperation = new DataSetTableOperation[T](dataSet, + fieldsInfo.getIndices, + calculateTableSchema(inputType, fieldsInfo.getIndices, fieldsInfo.getFieldNames)) + tableOperation } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala index 40ecadacbfc58e..5c96e7a609eb98 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala @@ 
-41,6 +41,7 @@ import org.apache.flink.table.catalog.CatalogManager import org.apache.flink.table.descriptors.{ConnectorDescriptor, StreamTableDescriptor} import org.apache.flink.table.explain.PlanJsonParser import org.apache.flink.table.expressions._ +import org.apache.flink.table.operations.DataStreamTableOperation import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.datastream.{DataStreamRel, UpdateAsRetractionTrait} import org.apache.flink.table.plan.rules.FlinkRuleSets @@ -51,8 +52,8 @@ import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.table.runtime.{CRowMapRunner, OutputRowtimeProcessFunction} import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{StreamTableSource, TableSource, TableSourceUtil} +import org.apache.flink.table.typeutils.FieldInfoUtils.{calculateTableSchema, getFieldsInfo, isReferenceByPosition} import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} -import org.apache.flink.table.typeutils.FieldInfoUtils.{getFieldsInfo, isReferenceByPosition} import _root_.scala.collection.JavaConverters._ @@ -136,7 +137,7 @@ abstract class StreamTableEnvImpl( val enrichedTable = new TableSourceSinkTable( Some(new StreamTableSourceTable(streamTableSource)), table.tableSinkTable) - replaceRegisteredTable(name, enrichedTable) + replaceRegisteredTableSourceSinkInternal(name, enrichedTable) } // no table is registered @@ -144,7 +145,7 @@ abstract class StreamTableEnvImpl( val newTable = new TableSourceSinkTable( Some(new StreamTableSourceTable(streamTableSource)), None) - registerTableInternal(name, newTable) + registerTableSourceSinkInternal(name, newTable) } // not a stream table source @@ -215,7 +216,7 @@ abstract class StreamTableEnvImpl( val enrichedTable = new TableSourceSinkTable( table.tableSourceTable, Some(new TableSinkTable(configuredSink))) - replaceRegisteredTable(name, enrichedTable) + replaceRegisteredTableSourceSinkInternal(name, enrichedTable) } // no table is registered @@ -223,7 +224,7 @@ abstract class StreamTableEnvImpl( val newTable = new TableSourceSinkTable( None, Some(new TableSinkTable(configuredSink))) - registerTableInternal(name, newTable) + registerTableSourceSinkInternal(name, newTable) } // not a stream table sink @@ -432,67 +433,48 @@ abstract class StreamTableEnvImpl( } } - /** - * Registers a [[DataStream]] as a table under a given name in the [[TableEnvImpl]]'s - * catalog. - * - * @param name The name under which the table is registered in the catalog. - * @param dataStream The [[DataStream]] to register as table in the catalog. - * @tparam T the type of the [[DataStream]]. - */ - protected def registerDataStreamInternal[T]( - name: String, - dataStream: DataStream[T]): Unit = { - - val fieldInfo = getFieldsInfo[T](dataStream.getType) - val dataStreamTable = new DataStreamTable[T]( - dataStream, - fieldInfo.getIndices, - fieldInfo.getFieldNames - ) - registerTableInternal(name, dataStreamTable) - } - - /** - * Registers a [[DataStream]] as a table under a given name with field names as specified by - * field expressions in the [[TableEnvImpl]]'s catalog. - * - * @param name The name under which the table is registered in the catalog. - * @param dataStream The [[DataStream]] to register as table in the catalog. - * @param fields The field expressions to define the field names of the table. - * @tparam T The type of the [[DataStream]]. 
- */ - protected def registerDataStreamInternal[T]( - name: String, + protected def asTableOperation[T]( dataStream: DataStream[T], - fields: Array[Expression]) - : Unit = { - + fields: Option[Array[Expression]]) + : DataStreamTableOperation[T] = { val streamType = dataStream.getType // get field names and types for all non-replaced fields - val fieldsInfo = getFieldsInfo[T](streamType, fields) - - // validate and extract time attributes - val (rowtime, proctime) = validateAndExtractTimeAttributes(streamType, fields) + val (indices, names) = fields match { + case Some(f) => + // validate and extract time attributes + val fieldsInfo = getFieldsInfo[T](streamType, f) + val (rowtime, proctime) = validateAndExtractTimeAttributes(streamType, f) + + // check if event-time is enabled + if (rowtime.isDefined && + execEnv.getStreamTimeCharacteristic != TimeCharacteristic.EventTime) { + throw new TableException( + s"A rowtime attribute requires an EventTime time characteristic in stream environment" + + s". But is: ${execEnv.getStreamTimeCharacteristic}") + } - // check if event-time is enabled - if (rowtime.isDefined && execEnv.getStreamTimeCharacteristic != TimeCharacteristic.EventTime) { - throw new TableException( - s"A rowtime attribute requires an EventTime time characteristic in stream environment. " + - s"But is: ${execEnv.getStreamTimeCharacteristic}") + // adjust field indexes and field names + val indexesWithIndicatorFields = adjustFieldIndexes( + fieldsInfo.getIndices, + rowtime, + proctime) + val namesWithIndicatorFields = adjustFieldNames( + fieldsInfo.getFieldNames, + rowtime, + proctime) + + (indexesWithIndicatorFields, namesWithIndicatorFields) + case None => + val fieldsInfo = getFieldsInfo[T](streamType) + (fieldsInfo.getIndices, fieldsInfo.getFieldNames) } - // adjust field indexes and field names - val indexesWithIndicatorFields = adjustFieldIndexes(fieldsInfo.getIndices, rowtime, proctime) - val namesWithIndicatorFields = adjustFieldNames(fieldsInfo.getFieldNames, rowtime, proctime) - - val dataStreamTable = new DataStreamTable[T]( + val dataStreamTable = new DataStreamTableOperation( dataStream, - indexesWithIndicatorFields, - namesWithIndicatorFields - ) - registerTableInternal(name, dataStreamTable) + indices, + calculateTableSchema(streamType, indices, names)) + dataStreamTable } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala index 5dd46ccbd02588..7f329dba4361fd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.api -import _root_.java.lang.reflect.Modifier import _root_.java.util.Optional import _root_.java.util.concurrent.atomic.AtomicInteger @@ -30,7 +29,6 @@ import org.apache.calcite.plan._ import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgram, HepProgramBuilder} import org.apache.calcite.rel.RelNode import org.apache.calcite.schema.SchemaPlus -import org.apache.calcite.schema.impl.AbstractTable import org.apache.calcite.sql._ import org.apache.calcite.sql.parser.SqlParser import org.apache.calcite.tools._ @@ -47,7 +45,7 @@ import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, Tabl import org.apache.flink.table.operations.{CatalogTableOperation, OperationTreeBuilder, 
PlannerTableOperation} import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.rules.FlinkRuleSets -import org.apache.flink.table.plan.schema.{RelTable, RowSchema, TableSourceSinkTable} +import org.apache.flink.table.plan.schema.{RowSchema, TableSourceSinkTable} import org.apache.flink.table.planner.PlanningConfigurationBuilder import org.apache.flink.table.sinks.TableSink import org.apache.flink.table.sources.TableSource @@ -326,12 +324,6 @@ abstract class TableEnvImpl( output } - override def fromTableSource(source: TableSource[_]): Table = { - val name = createUniqueTableName() - registerTableSourceInternal(name, source) - scan(name) - } - override def registerExternalCatalog(name: String, externalCatalog: ExternalCatalog): Unit = { catalogManager.registerExternalCatalog(name, externalCatalog) } @@ -404,24 +396,6 @@ abstract class TableEnvImpl( planningConfigurationBuilder.getTypeFactory) } - override def registerTable(name: String, table: Table): Unit = { - - // check that table belongs to this table environment - if (table.asInstanceOf[TableImpl].tableEnv != this) { - throw new TableException( - "Only tables that belong to this TableEnvironment can be registered.") - } - - checkValidTableName(name) - val tableTable = new RelTable(table.asInstanceOf[TableImpl].getRelNode) - registerTableInternal(name, tableTable) - } - - override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = { - checkValidTableName(name) - registerTableSourceInternal(name, tableSource) - } - override def registerCatalog(catalogName: String, catalog: Catalog): Unit = { catalogManager.registerCatalog(catalogName, catalog) } @@ -446,24 +420,57 @@ abstract class TableEnvImpl( catalogManager.setCurrentDatabase(databaseName) } - /** - * Registers an internal [[TableSource]] in this [[TableEnvironment]]'s catalog without - * name checking. Registered tables can be referenced in SQL queries. - * - * @param name The name under which the [[TableSource]] is registered. - * @param tableSource The [[TableSource]] to register. - */ + override def registerTable(name: String, table: Table): Unit = { + + // check that table belongs to this table environment + if (table.asInstanceOf[TableImpl].tableEnv != this) { + throw new TableException( + "Only tables that belong to this TableEnvironment can be registered.") + } + + checkValidTableName(name) + + val tableTable = new TableOperationCatalogView(table.getTableOperation) + registerTableInternal(name, tableTable) + } + + override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = { + registerTableSourceInternal(name, tableSource) + } + + override def fromTableSource(source: TableSource[_]): Table = { + val name = createUniqueTableName() + registerTableSourceInternal(name, source) + scan(name) + } + + protected def registerTableInternal(name: String, table: CatalogBaseTable): Unit = { + val path = new ObjectPath(defaultDatabaseName, name) + JavaScalaConversionUtil.toScala(catalogManager.getCatalog(defaultCatalogName)) match { + case Some(catalog) => + catalog.createTable( + path, + table, + false) + case None => throw new TableException("The default catalog does not exist.") + } + } + protected def registerTableSourceInternal(name: String, tableSource: TableSource[_]): Unit - /** - * Replaces a registered Table with another Table under the same name. - * We use this method to replace a [[org.apache.flink.table.plan.schema.DataStreamTable]] - * with a [[org.apache.calcite.schema.TranslatableTable]]. 
- * - * @param name Name of the table to replace. - * @param table The table that replaces the previous table. - */ - protected def replaceRegisteredTable(name: String, table: AbstractTable): Unit = { + protected def registerTableSourceSinkInternal[T1, T2]( + name: String, + table: TableSourceSinkTable[T1, T2]) + : Unit = { + registerTableInternal( + name, + new CalciteCatalogTable(table, planningConfigurationBuilder.getTypeFactory)) + } + + protected def replaceRegisteredTableSourceSinkInternal[T1, T2]( + name: String, + table: TableSourceSinkTable[T1, T2]) + : Unit = { val path = new ObjectPath(defaultDatabaseName, name) JavaScalaConversionUtil.toScala(catalogManager.getCatalog(defaultCatalogName)) match { case Some(catalog) => @@ -617,26 +624,6 @@ abstract class TableEnvImpl( } } - /** - * Registers a Calcite [[AbstractTable]] in the TableEnvironment's default catalog. - * - * @param name The name under which the table will be registered. - * @param table The table to register in the catalog - * @throws TableException if another table is registered under the provided name. - */ - @throws[TableException] - protected def registerTableInternal(name: String, table: AbstractTable): Unit = { - val path = new ObjectPath(defaultDatabaseName, name) - JavaScalaConversionUtil.toScala(catalogManager.getCatalog(defaultCatalogName)) match { - case Some(catalog) => - catalog.createTable( - path, - new CalciteCatalogTable(table, planningConfigurationBuilder.getTypeFactory), - false) - case None => throw new TableException("The default catalog does not exist.") - } - } - /** Returns a unique table name according to the internal naming pattern. */ protected def createUniqueTableName(): String diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvImpl.scala index ddd2fd9c194d4e..9d2de41e6a86c9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvImpl.scala @@ -45,10 +45,7 @@ class BatchTableEnvImpl( with org.apache.flink.table.api.java.BatchTableEnvironment { override def fromDataSet[T](dataSet: DataSet[T]): Table = { - - val name = createUniqueTableName() - registerDataSetInternal(name, dataSet) - scan(name) + new TableImpl(this, asDataSetTableOperation(dataSet, None)) } override def fromDataSet[T](dataSet: DataSet[T], fields: String): Table = { @@ -56,24 +53,15 @@ class BatchTableEnvImpl( .parseExpressionList(fields).asScala .toArray - val name = createUniqueTableName() - registerDataSetInternal(name, dataSet, exprs) - scan(name) + new TableImpl(this, asDataSetTableOperation(dataSet, Some(exprs))) } override def registerDataSet[T](name: String, dataSet: DataSet[T]): Unit = { - - checkValidTableName(name) - registerDataSetInternal(name, dataSet) + registerTable(name, fromDataSet(dataSet)) } override def registerDataSet[T](name: String, dataSet: DataSet[T], fields: String): Unit = { - val exprs = ExpressionParser - .parseExpressionList(fields).asScala - .toArray - - checkValidTableName(name) - registerDataSetInternal(name, dataSet, exprs) + registerTable(name, fromDataSet(dataSet, fields)) } override def toDataSet[T](table: Table, clazz: Class[T]): DataSet[T] = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala index 235370ee1b91e0..108c201398f822 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/java/StreamTableEnvImpl.scala @@ -50,10 +50,7 @@ class StreamTableEnvImpl( with org.apache.flink.table.api.java.StreamTableEnvironment { override def fromDataStream[T](dataStream: DataStream[T]): Table = { - - val name = createUniqueTableName() - registerDataStreamInternal(name, dataStream) - scan(name) + new TableImpl(this, asTableOperation(dataStream, None)) } override def fromDataStream[T](dataStream: DataStream[T], fields: String): Table = { @@ -61,26 +58,19 @@ class StreamTableEnvImpl( .parseExpressionList(fields).asScala .toArray - val name = createUniqueTableName() - registerDataStreamInternal(name, dataStream, exprs) - scan(name) + new TableImpl(this, asTableOperation(dataStream, Some(exprs))) } override def registerDataStream[T](name: String, dataStream: DataStream[T]): Unit = { - - checkValidTableName(name) - registerDataStreamInternal(name, dataStream) + registerTable(name, fromDataStream(dataStream)) } override def registerDataStream[T]( - name: String, dataStream: DataStream[T], fields: String): Unit = { - - val exprs = ExpressionParser - .parseExpressionList(fields).asScala - .toArray - - checkValidTableName(name) - registerDataStreamInternal(name, dataStream, exprs) + name: String, + dataStream: DataStream[T], + fields: String) + : Unit = { + registerTable(name, fromDataStream(dataStream, fields)) } override def toAppendStream[T](table: Table, clazz: Class[T]): DataStream[T] = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvImpl.scala index 3c2e5899b32797..ca74712ea2f772 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/BatchTableEnvImpl.scala @@ -44,29 +44,19 @@ class BatchTableEnvImpl( with org.apache.flink.table.api.scala.BatchTableEnvironment { override def fromDataSet[T](dataSet: DataSet[T]): Table = { - - val name = createUniqueTableName() - registerDataSetInternal(name, dataSet.javaSet) - scan(name) + new TableImpl(this, asDataSetTableOperation(dataSet.javaSet, None)) } override def fromDataSet[T](dataSet: DataSet[T], fields: Expression*): Table = { - - val name = createUniqueTableName() - registerDataSetInternal(name, dataSet.javaSet, fields.toArray) - scan(name) + new TableImpl(this, asDataSetTableOperation(dataSet.javaSet, Some(fields.toArray))) } override def registerDataSet[T](name: String, dataSet: DataSet[T]): Unit = { - - checkValidTableName(name) - registerDataSetInternal(name, dataSet.javaSet) + registerTable(name, fromDataSet(dataSet)) } override def registerDataSet[T](name: String, dataSet: DataSet[T], fields: Expression*): Unit = { - - checkValidTableName(name) - registerDataSetInternal(name, dataSet.javaSet, fields.toArray) + registerTable(name, fromDataSet(dataSet, fields: _*)) } override def toDataSet[T: TypeInformation](table: Table): DataSet[T] = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala index b1805cfac71ae5..45bbec2711e177 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/scala/StreamTableEnvImpl.scala @@ -17,14 +17,13 @@ */ package org.apache.flink.table.api.scala -import org.apache.flink.api.scala._ import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.api.{StreamQueryConfig, Table, TableConfig, TableEnvImpl} +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment, asScalaStream} +import org.apache.flink.table.api.{StreamQueryConfig, Table, TableConfig, TableImpl} +import org.apache.flink.table.catalog.CatalogManager import org.apache.flink.table.expressions.Expression import org.apache.flink.table.functions.{AggregateFunction, TableAggregateFunction, TableFunction} -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} -import org.apache.flink.streaming.api.scala.asScalaStream -import org.apache.flink.table.catalog.CatalogManager /** * The implementation for a Scala [[StreamTableEnvironment]]. @@ -43,30 +42,25 @@ class StreamTableEnvImpl( with org.apache.flink.table.api.scala.StreamTableEnvironment { override def fromDataStream[T](dataStream: DataStream[T]): Table = { - - val name = createUniqueTableName() - registerDataStreamInternal(name, dataStream.javaStream) - scan(name) + val tableOperation = asTableOperation(dataStream.javaStream, None) + new TableImpl(this, tableOperation) } override def fromDataStream[T](dataStream: DataStream[T], fields: Expression*): Table = { - - val name = createUniqueTableName() - registerDataStreamInternal(name, dataStream.javaStream, fields.toArray) - scan(name) + val tableOperation = asTableOperation(dataStream.javaStream, Some(fields.toArray)) + new TableImpl(this, tableOperation) } override def registerDataStream[T](name: String, dataStream: DataStream[T]): Unit = { - - checkValidTableName(name) - registerDataStreamInternal(name, dataStream.javaStream) + registerTable(name, fromDataStream(dataStream)) } override def registerDataStream[T]( - name: String, dataStream: DataStream[T], fields: Expression*): Unit = { - - checkValidTableName(name) - registerDataStreamInternal(name, dataStream.javaStream, fields.toArray) + name: String, + dataStream: DataStream[T], + fields: Expression*) + : Unit = { + registerTable(name, fromDataStream(dataStream, fields: _*)) } override def toAppendStream[T: TypeInformation](table: Table): DataStream[T] = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala index 704d4814e31978..54962f59ef3772 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala @@ -124,4 +124,14 @@ object FlinkRelBuilder { */ case class NamedWindowProperty(name: String, property: WindowProperty) + def of(cluster: RelOptCluster, relTable: RelOptTable): FlinkRelBuilder = { + val clusterContext = cluster.getPlanner.getContext + + new FlinkRelBuilder( + clusterContext, + cluster, + relTable.getRelOptSchema, + 
clusterContext.unwrap(classOf[ExpressionBridge[PlannerExpression]])) + } + } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala index c48967b8ad0b18..f84a908e304feb 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkTypeFactory.scala @@ -34,7 +34,7 @@ import org.apache.flink.api.common.typeinfo._ import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.ValueTypeInfo._ import org.apache.flink.api.java.typeutils.{MapTypeInfo, MultisetTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} -import org.apache.flink.table.api.TableException +import org.apache.flink.table.api.{TableException, TableSchema} import org.apache.flink.table.calcite.FlinkTypeFactory.typeInfoToSqlTypeName import org.apache.flink.table.plan.schema._ import org.apache.flink.table.typeutils.TypeCheckUtils.isSimple @@ -178,6 +178,16 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp canonize(relType) } + /** + * Creates a struct type with the input fieldNames and input fieldTypes using FlinkTypeFactory + * + * @param tableSchema schema to convert to Calcite's specific one + * @return a struct type with the input fieldNames, input fieldTypes, and system fields + */ + def buildLogicalRowType(tableSchema: TableSchema): RelDataType = { + buildLogicalRowType(tableSchema.getFieldNames, tableSchema.getFieldTypes) + } + /** * Creates a struct type with the input fieldNames and input fieldTypes using FlinkTypeFactory * diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/PlannerTableOperation.java b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/PlannerTableOperation.java index 7c5019dbe15e1c..690e70f2b0eb46 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/PlannerTableOperation.java +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/PlannerTableOperation.java @@ -33,7 +33,7 @@ * Wrapper for valid logical plans generated by Planner. 
*/ @Internal -public class PlannerTableOperation implements TableOperation { +public class PlannerTableOperation extends TableOperation { private final RelNode calciteTree; private final TableSchema tableSchema; @@ -59,6 +59,11 @@ public TableSchema getTableSchema() { return tableSchema; } + @Override + public String asSummaryString() { + return formatWithChildren("PlannerNode", Collections.emptyMap()); + } + @Override public List getChildren() { return Collections.emptyList(); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala index 85152d881c6f71..71536d9e5e9d36 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala @@ -32,9 +32,9 @@ import org.apache.flink.types.Row trait BatchScan extends CommonScan[Row] with DataSetRel { - protected def convertToInternalRow( + protected def convertToInternalRow[T]( schema: RowSchema, - input: DataSet[Any], + input: DataSet[T], fieldIdxs: Array[Int], config: TableConfig, rowtimeExpression: Option[RexNode]): DataSet[Row] = { @@ -58,7 +58,7 @@ trait BatchScan extends CommonScan[Row] with DataSetRel { fieldIdxs, rowtimeExpression) - val runner = new MapRunner[Any, Row]( + val runner = new MapRunner[T, Row]( function.name, function.code, function.returnType) @@ -75,12 +75,12 @@ trait BatchScan extends CommonScan[Row] with DataSetRel { private def generateConversionMapper( config: TableConfig, - inputType: TypeInformation[Any], + inputType: TypeInformation[_], outputType: TypeInformation[Row], conversionOperatorName: String, fieldNames: Seq[String], inputFieldMapping: Array[Int], - rowtimeExpression: Option[RexNode]): GeneratedFunction[MapFunction[Any, Row], Row] = { + rowtimeExpression: Option[RexNode]): GeneratedFunction[MapFunction[_, Row], Row] = { val generator = new FunctionCodeGenerator( config, @@ -102,7 +102,7 @@ trait BatchScan extends CommonScan[Row] with DataSetRel { generator.generateFunction( "DataSetSourceConversion", - classOf[MapFunction[Any, Row]], + classOf[MapFunction[_, Row]], body, outputType) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetScan.scala index 07c50595da4d65..160f27c79d26f0 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetScan.scala @@ -19,32 +19,45 @@ package org.apache.flink.table.plan.nodes.dataset import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode +import org.apache.calcite.prepare.RelOptTableImpl import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.TableScan import org.apache.calcite.rel.metadata.RelMetadataQuery +import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvImpl} -import org.apache.flink.table.plan.schema.{DataSetTable, RowSchema} +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.types.Row +import scala.collection.JavaConverters._ + /** * Flink 
RelNode which matches along with DataSource. * It ensures that types without deterministic field order (e.g. POJOs) are not part of * the plan translation. + * + * This may read only part, or change the order of the fields available in a + * [[org.apache.flink.api.common.typeutils.CompositeType]] of the underlying [[DataSet]]. + * The fieldIdxs describe the indices of the fields in the + * [[org.apache.flink.api.common.typeinfo.TypeInformation]] */ class DataSetScan( cluster: RelOptCluster, traitSet: RelTraitSet, - table: RelOptTable, + catalog: RelOptSchema, + inputDataSet: DataSet[_], + fieldIdxs: Array[Int], rowRelDataType: RelDataType) - extends TableScan(cluster, traitSet, table) + extends TableScan( + cluster, + traitSet, + RelOptTableImpl.create(catalog, rowRelDataType, List[String]().asJava, null)) with BatchScan { - val dataSetTable: DataSetTable[Any] = getTable.unwrap(classOf[DataSetTable[Any]]) - override def deriveRowType(): RelDataType = rowRelDataType + override def estimateRowCount(mq: RelMetadataQuery): Double = 1000L + override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val rowCnt = metadata.getRowCount(this) planner.getCostFactory.makeCost(rowCnt, rowCnt, 0) @@ -54,7 +67,9 @@ class DataSetScan( new DataSetScan( cluster, traitSet, - getTable, + catalog, + inputDataSet, + fieldIdxs, getRowType ) } @@ -63,10 +78,11 @@ class DataSetScan( tableEnv: BatchTableEnvImpl, queryConfig: BatchQueryConfig): DataSet[Row] = { val schema = new RowSchema(rowRelDataType) - val inputDataSet: DataSet[Any] = dataSetTable.dataSet - val fieldIdxs = dataSetTable.fieldIndexes val config = tableEnv.getConfig - convertToInternalRow(schema, inputDataSet, fieldIdxs, config, None) + convertToInternalRow(schema, inputDataSet.asInstanceOf[DataSet[Any]], fieldIdxs, config, None) } + override def explainTerms(pw: RelWriter): RelWriter = pw + .item("ref", System.identityHashCode(inputDataSet)) + .item("fields", String.join(", ", rowRelDataType.getFieldNames)) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala index 658744455a3562..9fc6ef7e207c4f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamScan.scala @@ -19,40 +19,52 @@ package org.apache.flink.table.plan.nodes.datastream import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode +import org.apache.calcite.prepare.RelOptTableImpl import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.TableScan +import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.rex.RexNode import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvImpl} import org.apache.flink.table.expressions.Cast import org.apache.flink.table.plan.schema.RowSchema -import org.apache.flink.table.plan.schema.DataStreamTable import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo +import scala.collection.JavaConverters._ + /** * Flink RelNode which matches along with DataStreamSource. * It ensures that types without deterministic field order (e.g. POJOs) are not part of * the plan translation. 
+ *
+ * This may read only part, or change the order of the fields available in a
+ * [[org.apache.flink.api.common.typeutils.CompositeType]] of the underlying [[DataStream]].
+ * The fieldIdxs describe the indices of the fields in the
+ * [[org.apache.flink.api.common.typeinfo.TypeInformation]]
  */
 class DataStreamScan(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
-    table: RelOptTable,
+    catalog: RelOptSchema,
+    dataStream: DataStream[_],
+    fieldIdxs: Array[Int],
     schema: RowSchema)
-  extends TableScan(cluster, traitSet, table)
+  extends TableScan(
+    cluster,
+    traitSet,
+    RelOptTableImpl.create(catalog, schema.relDataType, List[String]().asJava, null))
   with StreamScan {
 
-  val dataStreamTable: DataStreamTable[Any] = getTable.unwrap(classOf[DataStreamTable[Any]])
-
   override def deriveRowType(): RelDataType = schema.relDataType
 
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new DataStreamScan(
       cluster,
       traitSet,
-      getTable,
+      catalog,
+      dataStream,
+      fieldIdxs,
       schema
     )
   }
 
@@ -62,8 +74,6 @@ class DataStreamScan(
       queryConfig: StreamQueryConfig): DataStream[CRow] = {
 
     val config = tableEnv.getConfig
-    val inputDataStream: DataStream[Any] = dataStreamTable.dataStream
-    val fieldIdxs = dataStreamTable.fieldIndexes
 
     // get expression to extract timestamp
     val rowtimeExpr: Option[RexNode] =
@@ -79,7 +89,15 @@ class DataStreamScan(
     }
 
     // convert DataStream
-    convertToInternalRow(schema, inputDataStream, fieldIdxs, config, rowtimeExpr)
+    convertToInternalRow(
+      schema,
+      dataStream.asInstanceOf[DataStream[Any]],
+      fieldIdxs,
+      config,
+      rowtimeExpr)
   }
+
+  override def explainTerms(pw: RelWriter): RelWriter =
+    pw.item("id", s"${dataStream.getId}")
+      .item("fields", s"${String.join(", ", schema.relDataType.getFieldNames)}")
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalNativeTableScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataSetScan.scala
similarity index 52%
rename from flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalNativeTableScan.scala
rename to flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataSetScan.scala
index fe4b2b69b5e71b..16af21d96c102e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalNativeTableScan.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataSetScan.scala
@@ -21,56 +21,43 @@ package org.apache.flink.table.plan.nodes.logical
 import java.util
 
 import org.apache.calcite.plan._
-import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.prepare.RelOptTableImpl
+import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.TableScan
-import org.apache.calcite.rel.logical.LogicalTableScan
 import org.apache.calcite.rel.metadata.RelMetadataQuery
-import org.apache.flink.table.plan.nodes.FlinkConventions
-import org.apache.flink.table.plan.schema.{DataSetTable, DataStreamTable}
+import org.apache.calcite.rel.{RelNode, RelWriter}
+import org.apache.flink.api.java.DataSet
 
-class FlinkLogicalNativeTableScan (
+import scala.collection.JavaConverters._
+
+class FlinkLogicalDataSetScan(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
-    table: RelOptTable)
-  extends TableScan(cluster, traitSet, table)
+    val catalog: RelOptSchema,
+    val dataSet: DataSet[_],
+    val fieldIdxs: Array[Int],
+    val schema: RelDataType)
+  extends TableScan(
+    cluster,
+    traitSet,
+    RelOptTableImpl.create(catalog, schema, List[String]().asJava, null))
   with FlinkLogicalRel {
+
+  override def estimateRowCount(mq: RelMetadataQuery): Double = 1000L
+
+  override def deriveRowType(): RelDataType = schema
 
   override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
-    new FlinkLogicalNativeTableScan(cluster, traitSet, getTable)
+    new FlinkLogicalDataSetScan(cluster, traitSet, catalog, dataSet, fieldIdxs, schema)
   }
 
   override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
     val rowCnt = metadata.getRowCount(this)
     planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
   }
-}
-
-class FlinkLogicalNativeTableScanConverter
-  extends ConverterRule(
-    classOf[LogicalTableScan],
-    Convention.NONE,
-    FlinkConventions.LOGICAL,
-    "FlinkLogicalNativeTableScanConverter") {
-
-  override def matches(call: RelOptRuleCall): Boolean = {
-    val scan = call.rel[TableScan](0)
-    val dataSetTable = scan.getTable.unwrap(classOf[DataSetTable[_]])
-    val dataStreamTable = scan.getTable.unwrap(classOf[DataStreamTable[_]])
-    dataSetTable != null || dataStreamTable != null
-  }
-
-  def convert(rel: RelNode): RelNode = {
-    val scan = rel.asInstanceOf[TableScan]
-    val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL)
-    new FlinkLogicalNativeTableScan(
-      rel.getCluster,
-      traitSet,
-      scan.getTable
-    )
-  }
-}
 
-object FlinkLogicalNativeTableScan {
-  val CONVERTER = new FlinkLogicalNativeTableScanConverter
+  override def explainTerms(pw: RelWriter): RelWriter = pw
+    .item("ref", System.identityHashCode(dataSet))
+    .item("fields", String.join(", ", schema.getFieldNames))
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataStreamScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataStreamScan.scala
new file mode 100644
index 00000000000000..4d7d815daf3a6e
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalDataStreamScan.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.logical
+
+import java.util
+
+import org.apache.calcite.plan._
+import org.apache.calcite.prepare.RelOptTableImpl
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.TableScan
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter}
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.plan.schema.RowSchema
+
+import scala.collection.JavaConverters._
+
+class FlinkLogicalDataStreamScan(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    val catalog: RelOptSchema,
+    val dataStream: DataStream[_],
+    val fieldIdxs: Array[Int],
+    val schema: RowSchema)
+  extends TableScan(
+    cluster,
+    traitSet,
+    RelOptTableImpl.create(catalog, schema.relDataType, List[String]().asJava, null))
+  with FlinkLogicalRel {
+
+  override def deriveRowType(): RelDataType = schema.relDataType
+
+  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
+    new FlinkLogicalDataStreamScan(cluster, traitSet, catalog, dataStream, fieldIdxs, schema)
+  }
+
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val rowCnt = metadata.getRowCount(this)
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter =
+    pw.item("id", s"${dataStream.getId}")
+      .item("fields", s"${String.join(", ", schema.relDataType.getFieldNames)}")
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 6dfa88ab54c7a9..b7701cdde07517 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -21,12 +21,11 @@ package org.apache.flink.table.plan.rules
 import org.apache.calcite.rel.core.RelFactories
 import org.apache.calcite.rel.rules._
 import org.apache.calcite.tools.{RuleSet, RuleSets}
-import org.apache.flink.table.plan.nodes.logical
+import org.apache.flink.table.plan.nodes.logical._
 import org.apache.flink.table.plan.rules.common._
-import org.apache.flink.table.plan.rules.logical.{ExtendedAggregateExtractProjectRule, _}
 import org.apache.flink.table.plan.rules.dataSet._
 import org.apache.flink.table.plan.rules.datastream._
-import org.apache.flink.table.plan.nodes.logical._
+import org.apache.flink.table.plan.rules.logical.{ExtendedAggregateExtractProjectRule, _}
 
 object FlinkRuleSets {
 
@@ -139,7 +138,6 @@ object FlinkRuleSets {
     FlinkLogicalValues.CONVERTER,
     FlinkLogicalTableSourceScan.CONVERTER,
     FlinkLogicalTableFunctionScan.CONVERTER,
-    FlinkLogicalNativeTableScan.CONVERTER,
     FlinkLogicalMatch.CONVERTER,
     FlinkLogicalTableAggregate.CONVERTER,
     FlinkLogicalWindowTableAggregate.CONVERTER
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetScanRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetScanRule.scala
index 4198819fc23235..b3aea298b2e3cf 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetScanRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetScanRule.scala
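For orientation (not part of the patch itself): both new logical scan nodes above wrap an anonymous RelOptTable built with RelOptTableImpl.create from the catalog reader and the row type, so the scanned DataStream or DataSet no longer needs to be registered as a named DataStreamTable/DataSetTable. A minimal sketch of how such a scan could be instantiated, assuming a cluster and a catalog (RelOptSchema) are already available; the helper name createDataStreamScan is illustrative only:

object ScanCreationSketch {
  import org.apache.calcite.plan.{RelOptCluster, RelOptSchema, RelTraitSet}
  import org.apache.flink.streaming.api.datastream.DataStream
  import org.apache.flink.table.plan.nodes.FlinkConventions
  import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataStreamScan
  import org.apache.flink.table.plan.schema.RowSchema

  // Hypothetical helper, not part of the patch: builds the scan in the LOGICAL
  // convention so that DataStreamScanRule (below) can later convert it into a
  // physical DataStreamScan.
  def createDataStreamScan(
      cluster: RelOptCluster,
      catalog: RelOptSchema,
      stream: DataStream[_],
      fieldIdxs: Array[Int],
      rowSchema: RowSchema): FlinkLogicalDataStreamScan = {
    val traits: RelTraitSet = cluster.traitSetOf(FlinkConventions.LOGICAL)
    new FlinkLogicalDataStreamScan(cluster, traits, catalog, stream, fieldIdxs, rowSchema)
  }
}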
@@ -18,48 +18,35 @@ package org.apache.flink.table.plan.rules.dataSet -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.plan.RelTraitSet import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.DataSetScan -import org.apache.flink.table.plan.schema.DataSetTable -import org.apache.flink.table.plan.nodes.logical.FlinkLogicalNativeTableScan +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataSetScan class DataSetScanRule extends ConverterRule( - classOf[FlinkLogicalNativeTableScan], + classOf[FlinkLogicalDataSetScan], FlinkConventions.LOGICAL, FlinkConventions.DATASET, "DataSetScanRule") { - /** - * If the input is not a DataSetTable, we want the TableScanRule to match instead - */ - override def matches(call: RelOptRuleCall): Boolean = { - val scan: FlinkLogicalNativeTableScan = call.rel(0).asInstanceOf[FlinkLogicalNativeTableScan] - val dataSetTable = scan.getTable.unwrap(classOf[DataSetTable[Any]]) - dataSetTable match { - case _: DataSetTable[Any] => - true - case _ => - false - } - } - def convert(rel: RelNode): RelNode = { - val scan: FlinkLogicalNativeTableScan = rel.asInstanceOf[FlinkLogicalNativeTableScan] + val scan: FlinkLogicalDataSetScan = rel.asInstanceOf[FlinkLogicalDataSetScan] val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASET) new DataSetScan( rel.getCluster, traitSet, - scan.getTable, - rel.getRowType + scan.catalog, + scan.dataSet, + scan.fieldIdxs, + scan.schema ) } } object DataSetScanRule { - val INSTANCE: RelOptRule = new DataSetScanRule + val INSTANCE = new DataSetScanRule } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamScanRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamScanRule.scala index d8dda80bf9853f..44f97fe047e35b 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamScanRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamScanRule.scala @@ -18,46 +18,38 @@ package org.apache.flink.table.plan.rules.datastream -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.plan.RelTraitSet import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.datastream.DataStreamScan -import org.apache.flink.table.plan.schema.{DataStreamTable, RowSchema} -import org.apache.flink.table.plan.nodes.logical.FlinkLogicalNativeTableScan +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataStreamScan class DataStreamScanRule extends ConverterRule( - classOf[FlinkLogicalNativeTableScan], + classOf[FlinkLogicalDataStreamScan], FlinkConventions.LOGICAL, FlinkConventions.DATASTREAM, "DataStreamScanRule") { - override def matches(call: RelOptRuleCall): Boolean = { - val scan: FlinkLogicalNativeTableScan = call.rel(0).asInstanceOf[FlinkLogicalNativeTableScan] - val dataSetTable = scan.getTable.unwrap(classOf[DataStreamTable[Any]]) - dataSetTable match { - case _: DataStreamTable[Any] => - true - case _ => - false - } - } - def convert(rel: RelNode): RelNode = { - val scan: FlinkLogicalNativeTableScan = 
rel.asInstanceOf[FlinkLogicalNativeTableScan] + val scan: FlinkLogicalDataStreamScan = rel.asInstanceOf[FlinkLogicalDataStreamScan] val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) new DataStreamScan( rel.getCluster, traitSet, - scan.getTable, - new RowSchema(rel.getRowType) + scan.catalog, + scan.dataStream, + scan.fieldIdxs, + scan.schema ) } } object DataStreamScanRule { - val INSTANCE: RelOptRule = new DataStreamScanRule + val INSTANCE = new DataStreamScanRule } + + diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala index 0b07f47119a073..bfdc29fec22a78 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala @@ -85,14 +85,7 @@ class LogicalCorrelateToTemporalTableJoinRule .getUnderlyingHistoryTable val rexBuilder = cluster.getRexBuilder - val expressionBridge = call.getPlanner.getContext - .unwrap(classOf[ExpressionBridge[PlannerExpression]]) - - val relBuilder = new FlinkRelBuilder(call.getPlanner.getContext, - cluster, - leftNode.getTable.getRelOptSchema, - expressionBridge) - + val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable) val rightNode: RelNode = relBuilder.tableOperation(underlyingHistoryTable).build() val rightTimeIndicatorExpression = createRightExpression( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala deleted file mode 100644 index 6de962cc3eda9a..00000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/DataStreamTable.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.plan.schema - -import org.apache.flink.streaming.api.datastream.DataStream -import org.apache.flink.table.plan.stats.FlinkStatistic - -class DataStreamTable[T]( - val dataStream: DataStream[T], - override val fieldIndexes: Array[Int], - override val fieldNames: Array[String], - override val statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) - extends InlineTable[T](dataStream.getType, fieldIndexes, fieldNames, statistic) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/InlineTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/InlineTable.scala deleted file mode 100644 index 84f1d11cdd9025..00000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/InlineTable.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.schema - -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.calcite.schema.Statistic -import org.apache.calcite.schema.impl.AbstractTable -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.table.api.{TableException, Types} -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo - -abstract class InlineTable[T]( - val typeInfo: TypeInformation[T], - val fieldIndexes: Array[Int], - val fieldNames: Array[String], - val statistic: FlinkStatistic) - extends AbstractTable { - - if (fieldIndexes.length != fieldNames.length) { - throw new TableException( - s"Number of field names and field indexes must be equal.\n" + - s"Number of names is ${fieldNames.length}, number of indexes is ${fieldIndexes.length}.\n" + - s"List of column names: ${fieldNames.mkString("[", ", ", "]")}.\n" + - s"List of column indexes: ${fieldIndexes.mkString("[", ", ", "]")}.") - } - - // check uniqueness of field names - if (fieldNames.length != fieldNames.toSet.size) { - val duplicateFields = fieldNames - // count occurrences of field names - .groupBy(identity).mapValues(_.length) - // filter for occurrences > 1 and map to field name - .filter(g => g._2 > 1).keys - - throw new TableException( - s"Field names must be unique.\n" + - s"List of duplicate fields: ${duplicateFields.mkString("[", ", ", "]")}.\n" + - s"List of all fields: ${fieldNames.mkString("[", ", ", "]")}.") - } - - val fieldTypes: Array[TypeInformation[_]] = - typeInfo match { - - case ct: CompositeType[_] => - // it is ok to leave out fields - if (fieldIndexes.count(_ >= 0) > ct.getArity) { - throw new 
TableException( - s"Arity of type (" + ct.getFieldNames.deep + ") " + - "must not be greater than number of field names " + fieldNames.deep + ".") - } - fieldIndexes.map { - case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER => - TimeIndicatorTypeInfo.ROWTIME_INDICATOR - case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER => - TimeIndicatorTypeInfo.PROCTIME_INDICATOR - case TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER => - Types.SQL_TIMESTAMP - case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER => - Types.SQL_TIMESTAMP - case i => ct.getTypeAt(i).asInstanceOf[TypeInformation[_]]} - - case t: TypeInformation[_] => - var cnt = 0 - val types = fieldIndexes.map { - case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER => - TimeIndicatorTypeInfo.ROWTIME_INDICATOR - case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER => - TimeIndicatorTypeInfo.PROCTIME_INDICATOR - case TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER => - Types.SQL_TIMESTAMP - case TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER => - Types.SQL_TIMESTAMP - case _ => - cnt += 1 - t.asInstanceOf[TypeInformation[_]] - } - // ensure that the atomic type is matched at most once. - if (cnt > 1) { - throw new TableException( - "Non-composite input type may have only a single field and its index must be 0.") - } else { - types - } - } - - override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] - flinkTypeFactory.buildLogicalRowType(fieldNames, fieldTypes) - } - - /** - * Returns statistics of current table - * - * @return statistics of current table - */ - override def getStatistic: Statistic = statistic - -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/RelTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/RelTable.scala deleted file mode 100644 index 30052a8404cb17..00000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/RelTable.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.schema - -import org.apache.calcite.plan.RelOptTable -import org.apache.calcite.plan.RelOptTable.ToRelContext -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.calcite.schema.Schema.TableType -import org.apache.calcite.schema.impl.AbstractTable -import org.apache.calcite.schema.TranslatableTable - -/** - * A [[org.apache.calcite.schema.Table]] implementation for registering - * Table API Tables in the Calcite schema to be used by Flink SQL. - * It implements [[TranslatableTable]] so that its logical scan - * can be converted to a relational expression. 
- * - * @see [[DataSetTable]] - */ -class RelTable(relNode: RelNode) extends AbstractTable with TranslatableTable { - - override def getJdbcTableType: TableType = ??? - - override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = relNode.getRowType - - override def toRel(context: ToRelContext, relOptTable: RelOptTable): RelNode = { - relNode - } -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/stats/FlinkStatistic.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/stats/FlinkStatistic.scala index 5469f943382e95..1acc91f0e1ff89 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/stats/FlinkStatistic.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/stats/FlinkStatistic.scala @@ -26,10 +26,9 @@ import org.apache.calcite.rel.{RelCollation, RelDistribution, RelReferentialCons import org.apache.calcite.schema.Statistic import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.table.plan.schema.TableSourceTable -import org.apache.flink.table.plan.schema.InlineTable /** - * The class provides statistics for a [[InlineTable]] or [[TableSourceTable]]. + * The class provides statistics for a [[TableSourceTable]]. * * @param tableStats The table statistics. */ diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala index ee2e749c2461ce..389ac22279a2aa 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableSourceTest.scala @@ -137,7 +137,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - s"BatchTableSourceScan(table=[[$tableName]], " + + s"BatchTableSourceScan(table=[[default_catalog, default_database, $tableName]], " + s"fields=[], " + s"source=[CsvTableSource(read fields: first)])", term("select", "1 AS _c0") @@ -161,7 +161,8 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - "BatchTableSourceScan(table=[[filterableTable]], fields=[price, id, amount])", + "BatchTableSourceScan(table=[[default_catalog, default_database, filterableTable]], " + + "fields=[price, id, amount])", term("select", "price", "id", "amount"), term("where", "<(*(price, 2), 32)") ) @@ -486,13 +487,13 @@ class TableSourceTest extends TableTestBase { } def batchSourceTableNode(sourceName: String, fields: Array[String]): String = { - s"BatchTableSourceScan(table=[[$sourceName]], " + + s"BatchTableSourceScan(table=[[default_catalog, default_database, $sourceName]], " + s"fields=[${fields.mkString(", ")}], " + s"source=[CsvTableSource(read fields: ${fields.mkString(", ")})])" } def streamSourceTableNode(sourceName: String, fields: Array[String] ): String = { - s"StreamTableSourceScan(table=[[$sourceName]], " + + s"StreamTableSourceScan(table=[[default_catalog, default_database, $sourceName]], " + s"fields=[${fields.mkString(", ")}], " + s"source=[CsvTableSource(read fields: ${fields.mkString(", ")})])" } @@ -500,17 +501,25 @@ class TableSourceTest extends TableTestBase { def batchFilterableSourceTableNode( sourceName: String, fields: Array[String], - exp: String): String = { + exp: String) + : String = { "BatchTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], 
source=[filter=[$exp]])" + s"table=[[default_catalog, default_database, $sourceName]], fields=[${ + fields + .mkString(", ") + }], source=[filter=[$exp]])" } def streamFilterableSourceTableNode( sourceName: String, fields: Array[String], - exp: String): String = { + exp: String) + : String = { "StreamTableSourceScan(" + - s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], source=[filter=[$exp]])" + s"table=[[default_catalog, default_database, $sourceName]], fields=[${ + fields + .mkString(", ") + }], source=[filter=[$exp]])" } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/BatchTableEnvironmentTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/BatchTableEnvironmentTest.scala index dde5569533a123..c8e5aff57c8233 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/BatchTableEnvironmentTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/BatchTableEnvironmentTest.scala @@ -35,7 +35,7 @@ class BatchTableEnvironmentTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a, b, c"), term("where", ">(b, 12)")) @@ -49,8 +49,8 @@ class BatchTableEnvironmentTest extends TableTestBase { "DataSetJoin", binaryNode( "DataSetCalc", - batchTableNode(0), - batchTableNode(1), + batchTableNode(table), + batchTableNode(table2), term("select", "c")), term("where", "=(c, d)"), term("join", "c, d, e, f"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala index d1fa36bb11e232..a3f43ad5ac9b3c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/ExplainTest.scala @@ -19,7 +19,9 @@ package org.apache.flink.table.api.batch import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Table import org.apache.flink.table.api.scala._ +import org.apache.flink.table.utils.TableTestUtil.batchTableNode import org.apache.flink.test.util.MultipleProgramsTestBase import org.junit.Assert.assertEquals import org.junit._ @@ -34,14 +36,15 @@ class ExplainTest val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = BatchTableEnvironment.create(env) - val table = env.fromElements((1, "hello")) - .toTable(tEnv, 'a, 'b) - .filter("a % 2 = 0") + val scan = env.fromElements((1, "hello")).toTable(tEnv, 'a, 'b) + val table = scan.filter("a % 2 = 0") val result = tEnv.explain(table).replaceAll("\\r\\n", "\n") val source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testFilter0.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testFilter0.out").mkString + + val expected = replaceString(source, scan) + assertEquals(expected, result) } @Test @@ -49,15 +52,16 @@ class ExplainTest val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = BatchTableEnvironment.create(env) - val table = env.fromElements((1, "hello")) - .toTable(tEnv, 'a, 'b) - .filter("a % 2 = 0") + val scan = env.fromElements((1, "hello")).toTable(tEnv, 'a, 'b) + val table = scan.filter("a % 2 = 0") val result = tEnv.asInstanceOf[BatchTableEnvImpl] .explain(table, extended = true).replaceAll("\\r\\n", "\n") val 
source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testFilter1.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testFilter1.out").mkString + + val expected = replaceString(source, scan) + assertEquals(expected, result) } @Test @@ -71,8 +75,10 @@ class ExplainTest val result = tEnv.explain(table).replaceAll("\\r\\n", "\n") val source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testJoin0.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testJoin0.out").mkString + + val expected = replaceString(source, table1, table2) + assertEquals(expected, result) } @Test @@ -87,8 +93,10 @@ class ExplainTest val result = tEnv.asInstanceOf[BatchTableEnvImpl] .explain(table, extended = true).replaceAll("\\r\\n", "\n") val source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testJoin1.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testJoin1.out").mkString + + val expected = replaceString(source, table1, table2) + assertEquals(expected, result) } @Test @@ -102,8 +110,10 @@ class ExplainTest val result = tEnv.explain(table).replaceAll("\\r\\n", "\n") val source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testUnion0.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testUnion0.out").mkString + + val expected = replaceString(source, table1, table2) + assertEquals(expected, result) } @Test @@ -118,7 +128,29 @@ class ExplainTest val result = tEnv.asInstanceOf[BatchTableEnvImpl] .explain(table, extended = true).replaceAll("\\r\\n", "\n") val source = scala.io.Source.fromFile(testFilePath + - "../../src/test/scala/resources/testUnion1.out").mkString.replaceAll("\\r\\n", "\n") - assertEquals(source, result) + "../../src/test/scala/resources/testUnion1.out").mkString + + val expected = replaceString(source, table1, table2) + assertEquals(expected, result) + } + + + def replaceString(s: String, t1: Table, t2: Table): String = { + replaceSourceNode(replaceSourceNode(replaceString(s), t1, 0), t2, 1) + } + + def replaceString(s: String, t: Table): String = { + replaceSourceNode(replaceString(s), t, 0) + } + + private def replaceSourceNode(s: String, t: Table, idx: Int) = { + s.replace( + s"%logicalSourceNode$idx%", batchTableNode(t) + .replace("DataSetScan", "FlinkLogicalDataSetScan")) + .replace(s"%sourceNode$idx%", batchTableNode(t)) + } + + def replaceString(s: String) = { + s.replaceAll("\\r\\n", "\n") } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala index 82d336d2a6a3a2..07440fd49d99f4 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala @@ -32,13 +32,13 @@ class AggregateTest extends TableTestBase { @Test def testAggregate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable" val aggregate = unaryNode( "DataSetAggregate", - batchTableNode(0), + 
batchTableNode(table), term("select", "AVG(a) AS EXPR$0", "SUM(b) AS EXPR$1", @@ -50,13 +50,13 @@ class AggregateTest extends TableTestBase { @Test def testAggregateWithFilter(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1" val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "CAST(1) AS a", "b", "c"), term("where", "=(a, 1)") ) @@ -75,13 +75,13 @@ class AggregateTest extends TableTestBase { @Test def testAggregateWithFilterOnNestedFields(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, (Int, Long))]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, (Int, Long))]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT avg(a), sum(b), count(c), sum(c._1) FROM MyTable WHERE a = 1" val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "CAST(1) AS a", "b", "c", "c._1 AS $f3"), term("where", "=(a, 1)") ) @@ -102,13 +102,13 @@ class AggregateTest extends TableTestBase { @Test def testGroupAggregate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable GROUP BY a" val aggregate = unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "a"), term("select", "a", @@ -130,13 +130,13 @@ class AggregateTest extends TableTestBase { @Test def testGroupAggregateWithFilter(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1 GROUP BY a" val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select","CAST(1) AS a", "b", "c") , term("where","=(a, 1)") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CalcTest.scala index 382aeddfe91bfd..5f423ad981ad3a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CalcTest.scala @@ -29,11 +29,11 @@ class CalcTest extends TableTestBase { @Test def testMultipleFlattening(): Unit = { val util = batchTestUtil() - util.addTable[((Int, Long), (String, Boolean), String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[((Int, Long), (String, Boolean), String)]("MyTable", 'a, 'b, 'c) val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a._1 AS _1", "a._2 AS _2", @@ -51,12 +51,12 @@ class CalcTest extends TableTestBase { @Test def testIn(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val resultStr = (1 to 30).mkString(", ") val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "c"), term("where", s"IN(b, $resultStr)") ) @@ -69,12 +69,12 @@ class CalcTest extends TableTestBase { @Test def testNotIn(): Unit = { val util = 
batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val resultStr = (1 to 30).mkString(", ") val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "c"), term("where", s"NOT IN(b, $resultStr)") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala index 733082588e890e..c776e1b84b9b47 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/CorrelateTest.scala @@ -31,7 +31,7 @@ class CorrelateTest extends TableTestBase { def testCrossJoin(): Unit = { val util = batchTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)" @@ -40,7 +40,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1($cor0.c)"), term("correlate", s"table(func1($$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -61,7 +61,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1($cor0.c, '$')"), term("correlate", s"table(func1($$cor0.c, '$$'))"), term("select", "a", "b", "c", "f0"), @@ -79,7 +79,7 @@ class CorrelateTest extends TableTestBase { def testLeftOuterJoinWithLiteralTrue(): Unit = { val util = batchTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE" @@ -88,7 +88,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1($cor0.c)"), term("correlate", s"table(func1($$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -106,8 +106,8 @@ class CorrelateTest extends TableTestBase { def testLeftOuterJoinAsSubQuery(): Unit = { val util = batchTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) - util.addTable[(Int, Long, String)]("MyTable2", 'a2, 'b2, 'c2) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table1 = util.addTable[(Int, Long, String)]("MyTable2", 'a2, 'b2, 'c2) util.addFunction("func1", func1) val sqlQuery = @@ -120,12 +120,12 @@ class CorrelateTest extends TableTestBase { val expected = binaryNode( "DataSetJoin", - batchTableNode(1), + batchTableNode(table1), unaryNode( "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1($cor0.c)"), term("correlate", "table(func1($cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -146,7 +146,7 @@ class CorrelateTest extends TableTestBase { def testCustomType(): Unit = { val util = batchTestUtil() val func2 = new TableFunc2 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val 
table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func2", func2) val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)" @@ -155,7 +155,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func2($cor0.c)"), term("correlate", s"table(func2($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1"), @@ -173,7 +173,7 @@ class CorrelateTest extends TableTestBase { @Test def testHierarchyType(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val function = new HierarchyTableFunction util.addFunction("hierarchy", function) @@ -183,7 +183,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "hierarchy($cor0.c)"), term("correlate", s"table(hierarchy($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1", "f2"), @@ -201,7 +201,7 @@ class CorrelateTest extends TableTestBase { @Test def testPojoType(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val function = new PojoTableFunc util.addFunction("pojo", function) @@ -211,7 +211,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "pojo($cor0.c)"), term("correlate", s"table(pojo($$cor0.c))"), term("select", "a", "b", "c", "age", "name"), @@ -230,7 +230,7 @@ class CorrelateTest extends TableTestBase { def testFilter(): Unit = { val util = batchTestUtil() val func2 = new TableFunc2 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func2", func2) val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " + @@ -240,7 +240,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func2($cor0.c)"), term("correlate", s"table(func2($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1"), @@ -260,7 +260,7 @@ class CorrelateTest extends TableTestBase { def testScalarFunction(): Unit = { val util = batchTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)" @@ -269,7 +269,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1(SUBSTRING($cor0.c, 2))"), term("correlate", s"table(func1(SUBSTRING($$cor0.c, 2)))"), term("select", "a", "b", "c", "f0"), @@ -287,7 +287,7 @@ class CorrelateTest extends TableTestBase { def testTableFunctionWithVariableArguments(): Unit = { val util = batchTestUtil() val func1 = new JavaVarsArgTableFunc0 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) var sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1('hello', 'world', 
c)) AS T(s)" @@ -296,7 +296,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func1('hello', 'world', $cor0.c)"), term("correlate", s"table(func1('hello', 'world', $$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -319,7 +319,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", "func2('hello', 'world', $cor0.c)"), term("correlate", s"table(func2('hello', 'world', $$cor0.c))"), term("select", "a", "b", "c", "f0"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala index 7796e22e33131c..ff5e560b5a68c9 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala @@ -29,7 +29,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testSingleDistinctAggregate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT COUNT(DISTINCT a) FROM MyTable" @@ -39,7 +39,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("distinct", "a") @@ -53,7 +53,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testMultiDistinctAggregateOnSameColumn(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT a), MAX(DISTINCT a) FROM MyTable" @@ -63,7 +63,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("distinct", "a") @@ -77,7 +77,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testSingleDistinctAggregateAndOneOrMultiNonDistinctAggregate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) // case 0x00: DISTINCT on COUNT and Non-DISTINCT on others val sqlQuery0 = "SELECT COUNT(DISTINCT a), SUM(b) FROM MyTable" @@ -88,7 +88,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("groupBy", "a"), @@ -108,7 +108,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), @@ -123,7 +123,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testMultiDistinctAggregateOnDifferentColumn(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b) FROM MyTable" @@ -135,7 +135,7 @@ class 
DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("distinct", "a") @@ -148,7 +148,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b") ), term("distinct", "b") @@ -166,7 +166,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testMultiDistinctAndNonDistinctAggregateOnDifferentColumn(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT COUNT(DISTINCT a), SUM(DISTINCT b), COUNT(c) FROM MyTable" @@ -180,7 +180,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "c") ), term("select", "COUNT(c) AS EXPR$2") @@ -191,7 +191,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("distinct", "a") @@ -208,7 +208,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b") ), term("distinct", "b") @@ -228,7 +228,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testSingleDistinctAggregateWithGrouping(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT a, COUNT(a), SUM(DISTINCT b) FROM MyTable GROUP BY a" @@ -238,7 +238,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("groupBy", "a", "b"), @@ -254,7 +254,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testSingleDistinctAggregateWithGroupingAndCountStar(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b) FROM MyTable GROUP BY a" @@ -264,7 +264,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("groupBy", "a", "b"), @@ -280,7 +280,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def testTwoDistinctAggregateWithGroupingAndCountStar(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT b) FROM MyTable GROUP BY a" @@ -292,7 +292,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("groupBy", "a"), @@ -304,7 +304,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("distinct", "a, b") @@ -325,7 +325,7 @@ class DistinctAggregateTest extends TableTestBase { @Test def 
testTwoDifferentDistinctAggregateWithGroupingAndCountStar(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT a, COUNT(*), SUM(DISTINCT b), COUNT(DISTINCT c) FROM MyTable GROUP BY a" @@ -341,7 +341,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("groupBy", "a"), @@ -353,7 +353,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), term("distinct", "a, b") @@ -373,7 +373,7 @@ class DistinctAggregateTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "c") ), term("distinct", "a, c") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala index 7f48442ca24605..e3ce577677248e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala @@ -32,7 +32,7 @@ class GroupWindowTest extends TableTestBase { @Test def testNonPartitionedTumbleWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sqlQuery = "SELECT SUM(a) AS sumA, COUNT(b) AS cntB FROM T GROUP BY TUMBLE(ts, INTERVAL '2' HOUR)" @@ -42,7 +42,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts, a, b") ), term("window", "TumblingGroupWindow('w$, 'ts, 7200000.millis)"), @@ -55,7 +55,7 @@ class GroupWindowTest extends TableTestBase { @Test def testPartitionedTumbleWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sqlQuery = "SELECT " + @@ -73,7 +73,7 @@ class GroupWindowTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "c"), term("window", "TumblingGroupWindow('w$, 'ts, 240000.millis)"), term("select", "c, SUM(a) AS sumA, MIN(b) AS minB, " + @@ -89,7 +89,7 @@ class GroupWindowTest extends TableTestBase { @Test def testTumbleWindowWithUdAgg() = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val weightedAvg = new WeightedAvgWithMerge util.tableEnv.registerFunction("weightedAvg", weightedAvg) @@ -103,7 +103,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts, b, a") ), term("window", "TumblingGroupWindow('w$, 'ts, 240000.millis)"), @@ -116,7 +116,7 @@ class GroupWindowTest extends TableTestBase { @Test def testNonPartitionedHopWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, 
Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sqlQuery = "SELECT SUM(a) AS sumA, COUNT(b) AS cntB " + @@ -128,7 +128,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts, a, b") ), term("window", "SlidingGroupWindow('w$, 'ts, 5400000.millis, 900000.millis)"), @@ -141,7 +141,7 @@ class GroupWindowTest extends TableTestBase { @Test def testPartitionedHopWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Long, Timestamp)]("T", 'a, 'b, 'c, 'd, 'ts) + val table = util.addTable[(Int, Long, String, Long, Timestamp)]("T", 'a, 'b, 'c, 'd, 'ts) val sqlQuery = "SELECT " + @@ -159,7 +159,7 @@ class GroupWindowTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "c, d"), term("window", "SlidingGroupWindow('w$, 'ts, 10800000.millis, 3600000.millis)"), term("select", "c, d, SUM(a) AS sumA, AVG(b) AS avgB, " + @@ -175,7 +175,7 @@ class GroupWindowTest extends TableTestBase { @Test def testNonPartitionedSessionWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sqlQuery = "SELECT COUNT(*) AS cnt FROM T GROUP BY SESSION(ts, INTERVAL '30' MINUTE)" @@ -185,7 +185,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts") ), term("window", "SessionGroupWindow('w$, 'ts, 1800000.millis)"), @@ -198,7 +198,7 @@ class GroupWindowTest extends TableTestBase { @Test def testPartitionedSessionWindow(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Int, Timestamp)]("T", 'a, 'b, 'c, 'd, 'ts) + val table = util.addTable[(Int, Long, String, Int, Timestamp)]("T", 'a, 'b, 'c, 'd, 'ts) val sqlQuery = "SELECT " + @@ -216,7 +216,7 @@ class GroupWindowTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "c, d"), term("window", "SessionGroupWindow('w$, 'ts, 43200000.millis)"), term("select", "c, d, SUM(a) AS sumA, MIN(b) AS minB, " + @@ -232,7 +232,7 @@ class GroupWindowTest extends TableTestBase { @Test def testWindowEndOnly(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sqlQuery = "SELECT " + @@ -247,7 +247,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts, c") ), term("groupBy", "c"), @@ -263,7 +263,7 @@ class GroupWindowTest extends TableTestBase { @Test def testExpressionOnWindowHavingFunction() = { val util = batchTestUtil() - util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) + val table = util.addTable[(Int, Long, String, Timestamp)]("T", 'a, 'b, 'c, 'ts) val sql = "SELECT " + @@ -282,7 +282,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "ts, a") ), term("window", "SlidingGroupWindow('w$, 'ts, 60000.millis, 900000.millis)"), @@ -305,7 +305,7 
@@ class GroupWindowTest extends TableTestBase { @Test def testDecomposableAggFunctions() = { val util = batchTestUtil() - util.addTable[(Int, String, Long, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime) + val table = util.addTable[(Int, String, Long, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime) val sql = "SELECT " + @@ -322,7 +322,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "rowtime", "c", "*(c, c) AS $f2") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala index cfdce10dc572f4..d110125941fe05 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala @@ -29,7 +29,7 @@ class GroupingSetsTest extends TableTestBase { @Test def testGroupingSets(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g FROM MyTable " + "GROUP BY GROUPING SETS (b, c)" @@ -42,7 +42,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b", "a") ), term("groupBy", "b"), @@ -56,7 +56,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "c", "a") ), term("groupBy", "c"), @@ -74,7 +74,7 @@ class GroupingSetsTest extends TableTestBase { @Test def testCube(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " + "GROUPING(b) as gb, GROUPING(c) as gc, " + @@ -87,7 +87,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "b", "c"), term("select", "b", "c", "AVG(a) AS a") ), @@ -101,7 +101,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b", "a") ), term("groupBy", "b"), @@ -117,7 +117,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "c", "a") ), term("groupBy", "c"), @@ -133,7 +133,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("select", "AVG(a) AS a") @@ -169,7 +169,7 @@ class GroupingSetsTest extends TableTestBase { @Test def testRollup(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " + "GROUPING(b) as gb, GROUPING(c) as gc, " + @@ -181,7 +181,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetCalc", unaryNode( 
"DataSetAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "b", "c"), term("select", "b", "c", "AVG(a) AS a") ), @@ -195,7 +195,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b", "a") ), term("groupBy", "b"), @@ -211,7 +211,7 @@ class GroupingSetsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), term("select", "AVG(a) AS a") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/JoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/JoinTest.scala index a3a597f79e32f8..8574fcfcc8b52e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/JoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/JoinTest.scala @@ -29,8 +29,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z" val result = util.tableEnv.sqlQuery(query) @@ -41,12 +41,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -62,8 +62,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiAndLocalPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z AND b < 2" val result = util.tableEnv.sqlQuery(query) @@ -74,12 +74,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "<(b, 2) AS $f3") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "y", "z") ), term("where", "AND(=(a, z), $f3)"), @@ -95,8 +95,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiAndNonEquiPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z AND b < x" val result = util.tableEnv.sqlQuery(query) @@ -107,10 +107,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(table1), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "LeftOuterJoin") @@ -124,8 +124,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiPred(): 
Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t RIGHT OUTER JOIN s ON a = z" val result = util.tableEnv.sqlQuery(query) @@ -136,12 +136,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -157,8 +157,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiAndLocalPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, x FROM t RIGHT OUTER JOIN s ON a = z AND x < 2" val result = util.tableEnv.sqlQuery(query) @@ -169,12 +169,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "x", "z", "<(x, 2) AS $f3") ), term("where", "AND(=(a, z), $f3)"), @@ -190,8 +190,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiAndNonEquiPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t RIGHT OUTER JOIN s ON a = z AND b < x" val result = util.tableEnv.sqlQuery(query) @@ -202,10 +202,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(table1), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "RightOuterJoin") @@ -219,8 +219,8 @@ class JoinTest extends TableTestBase { @Test def testFullOuterJoinEquiPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t FULL OUTER JOIN s ON a = z" val result = util.tableEnv.sqlQuery(query) @@ -231,12 +231,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -252,8 +252,8 @@ class JoinTest extends TableTestBase { @Test def testFullOuterJoinEquiAndLocalPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t FULL 
OUTER JOIN s ON a = z AND b < 2 AND z > 5" val result = util.tableEnv.sqlQuery(query) @@ -264,12 +264,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "<(b, 2) AS $f3") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "y", "z", ">(z, 5) AS $f3") ), term("where", "AND(=(a, z), $f3, $f30)"), @@ -285,8 +285,8 @@ class JoinTest extends TableTestBase { @Test def testFullOuterJoinEquiAndNonEquiPred(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val table = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val table1 = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t FULL OUTER JOIN s ON a = z AND b < x" val result = util.tableEnv.sqlQuery(query) @@ -297,10 +297,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(table1), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "FullOuterJoin") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala index 16c3174bbd3d19..44879e87377a7e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala @@ -36,8 +36,8 @@ class SetOperatorsTest extends TableTestBase { val expected = binaryNode( "DataSetMinus", - batchTableNode(0), - batchTableNode(0), + batchTableNode(t), + batchTableNode(t), term("minus", "a", "b", "c") ) @@ -49,21 +49,21 @@ class SetOperatorsTest extends TableTestBase { @Test def testExists(): Unit = { val util = batchTestUtil() - util.addTable[(Long, Int, String)]("A", 'a_long, 'a_int, 'a_string) - util.addTable[(Long, Int, String)]("B", 'b_long, 'b_int, 'b_string) + val table = util.addTable[(Long, Int, String)]("A", 'a_long, 'a_int, 'a_string) + val table1 = util.addTable[(Long, Int, String)]("B", 'b_long, 'b_int, 'b_string) val expected = unaryNode( "DataSetCalc", binaryNode( "DataSetJoin", - batchTableNode(0), + batchTableNode(table), unaryNode( "DataSetCalc", unaryNode( "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "b_long AS b_long3", "true AS $f0"), term("where", "IS NOT NULL(b_long)") ), @@ -88,7 +88,7 @@ class SetOperatorsTest extends TableTestBase { @Test def testNotIn(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("A", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("A", 'a, 'b, 'c) val expected = unaryNode( "DataSetCalc", @@ -98,12 +98,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetCalc", binaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), unaryNode( "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b"), term("where", "OR(=(b, 6), =(b, 1))") ), @@ -119,7 +119,7 @@ class SetOperatorsTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", 
"b", "true AS $f1"), term("where", "OR(=(b, 6), =(b, 1))") ), @@ -143,11 +143,11 @@ class SetOperatorsTest extends TableTestBase { @Test def testInWithFields(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, Int, String, Long)]("A", 'a, 'b, 'c, 'd, 'e) + val table = util.addTable[(Int, Long, Int, String, Long)]("A", 'a, 'b, 'c, 'd, 'e) val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "c", "d", "e"), term("where", "OR(=(a, c), =(a, CAST(b)), =(a, 5))") ) @@ -176,18 +176,18 @@ class SetOperatorsTest extends TableTestBase { @Test def testUnionNullableTypes(): Unit = { val util = batchTestUtil() - util.addTable[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) + val table = util.addTable[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) val expected = binaryNode( "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "CASE(>(c, 0), b, null) AS EXPR$0") ), term("all", "true"), @@ -206,18 +206,18 @@ class SetOperatorsTest extends TableTestBase { val typeInfo = Types.ROW( new GenericTypeInfo(classOf[NonPojo]), new GenericTypeInfo(classOf[NonPojo])) - util.addJavaTable(typeInfo, "A", "a, b") + val table = util.addJavaTable(typeInfo, "A", "a, b") val expected = binaryNode( "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a") ), unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "b") ), term("all", "true"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala index 37dc3e6d2714f5..0bd6b20557f747 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala @@ -29,7 +29,7 @@ class SingleRowJoinTest extends TableTestBase { @Test def testSingleRowCrossJoin(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) + val table = util.addTable[(Int, Int)]("A", 'a1, 'a2) val query = "SELECT a1, asum " + @@ -40,14 +40,14 @@ class SingleRowJoinTest extends TableTestBase { "DataSetSingleRowJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a1") ), unaryNode( "DataSetCalc", unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(table), term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") ), term("select", "+($f0, $f1) AS asum") @@ -63,7 +63,7 @@ class SingleRowJoinTest extends TableTestBase { @Test def testSingleRowEquiJoin(): Unit = { val util = batchTestUtil() - util.addTable[(Int, String)]("A", 'a1, 'a2) + val table = util.addTable[(Int, String)]("A", 'a1, 'a2) val query = "SELECT a1, a2 " + @@ -75,12 +75,12 @@ class SingleRowJoinTest extends TableTestBase { "DataSetCalc", binaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), unaryNode( "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a1") ), term("select", "COUNT(a1) AS cnt") @@ -98,7 +98,7 @@ class SingleRowJoinTest extends TableTestBase { @Test def testSingleRowNotEquiJoin(): Unit = { val util = batchTestUtil() - 
util.addTable[(Int, String)]("A", 'a1, 'a2) + val table = util.addTable[(Int, String)]("A", 'a1, 'a2) val query = "SELECT a1, a2 " + @@ -110,12 +110,12 @@ class SingleRowJoinTest extends TableTestBase { "DataSetCalc", binaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), unaryNode( "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a1") ), term("select", "COUNT(a1) AS cnt") @@ -133,8 +133,8 @@ class SingleRowJoinTest extends TableTestBase { @Test def testSingleRowJoinWithComplexPredicate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long)]("A", 'a1, 'a2) - util.addTable[(Int, Long)]("B", 'b1, 'b2) + val table = util.addTable[(Int, Long)]("A", 'a1, 'a2) + val table1 = util.addTable[(Int, Long)]("B", 'b1, 'b2) val query = "SELECT a1, a2, b1, b2 " + @@ -143,10 +143,10 @@ class SingleRowJoinTest extends TableTestBase { val expected = binaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), unaryNode( "DataSetAggregate", - batchTableNode(1), + batchTableNode(table1), term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") ), term("where", "AND(<(a1, b1)", "=(a2, b2))"), @@ -160,8 +160,8 @@ class SingleRowJoinTest extends TableTestBase { @Test def testRightSingleLeftJoinEqualPredicate(): Unit = { val util = batchTestUtil() - util.addTable[(Long, Int)]("A", 'a1, 'a2) - util.addTable[(Int, Int)]("B", 'b1, 'b2) + val table = util.addTable[(Long, Int)]("A", 'a1, 'a2) + val table1 = util.addTable[(Int, Int)]("B", 'b1, 'b2) val queryLeftJoin = "SELECT a2 " + @@ -175,7 +175,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), term("where", "=(a1, cnt)"), term("join", "a1", "a2", "cnt"), term("joinType", "NestedLoopLeftJoin") @@ -186,7 +186,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "")), term("select", "COUNT(*) AS cnt") ) @@ -197,8 +197,8 @@ class SingleRowJoinTest extends TableTestBase { @Test def testRightSingleLeftJoinNotEqualPredicate(): Unit = { val util = batchTestUtil() - util.addTable[(Long, Int)]("A", 'a1, 'a2) - util.addTable[(Int, Int)]("B", 'b1, 'b2) + val table = util.addTable[(Long, Int)]("A", 'a1, 'a2) + val table1 = util.addTable[(Int, Int)]("B", 'b1, 'b2) val queryLeftJoin = "SELECT a2 " + @@ -212,7 +212,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetSingleRowJoin", - batchTableNode(0), + batchTableNode(table), term("where", ">(a1, cnt)"), term("join", "a1", "a2", "cnt"), term("joinType", "NestedLoopLeftJoin") @@ -223,7 +223,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "")), term("select", "COUNT(*) AS cnt") ) @@ -234,8 +234,8 @@ class SingleRowJoinTest extends TableTestBase { @Test def testLeftSingleRightJoinEqualPredicate(): Unit = { val util = batchTestUtil() - util.addTable[(Long, Long)]("A", 'a1, 'a2) - util.addTable[(Long, Long)]("B", 'b1, 'b2) + val table = util.addTable[(Long, Long)]("A", 'a1, 'a2) + val table1 = util.addTable[(Long, Long)]("B", 'b1, 'b2) val queryRightJoin = "SELECT a1 " + @@ -259,11 +259,11 @@ class SingleRowJoinTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "")), term("select", "COUNT(*) AS 
cnt") ) + "\n" + - batchTableNode(0) + batchTableNode(table) util.verifySql(queryRightJoin, expected) } @@ -271,8 +271,8 @@ class SingleRowJoinTest extends TableTestBase { @Test def testLeftSingleRightJoinNotEqualPredicate(): Unit = { val util = batchTestUtil() - util.addTable[(Long, Long)]("A", 'a1, 'a2) - util.addTable[(Long, Long)]("B", 'b1, 'b2) + val table = util.addTable[(Long, Long)]("A", 'a1, 'a2) + val table1 = util.addTable[(Long, Long)]("B", 'b1, 'b2) val queryRightJoin = "SELECT a1 " + @@ -297,11 +297,11 @@ class SingleRowJoinTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "")), term("select", "COUNT(*) AS cnt") ) + "\n" + - batchTableNode(0) + batchTableNode(table) util.verifySql(queryRightJoin, expected) } @@ -309,7 +309,7 @@ class SingleRowJoinTest extends TableTestBase { @Test def testSingleRowJoinInnerJoin(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) + val table = util.addTable[(Int, Int)]("A", 'a1, 'a2) val query = "SELECT a2, sum(a1) " + "FROM A " + @@ -323,7 +323,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetSingleRowJoin", unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "a2"), term("select", "a2", "SUM(a1) AS EXPR$1") ), @@ -339,7 +339,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a1") ), term("select", "SUM(a1) AS $f0") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala index df65481e9e87dd..25a98d8212df60 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala @@ -41,7 +41,7 @@ class AggregateTest extends TableTestBase { val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "b", "c"), term("where", "=(a, 1)") ) @@ -68,7 +68,7 @@ class AggregateTest extends TableTestBase { val expected = unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "AVG(a) AS TMP_0", "SUM(b) AS TMP_1", @@ -87,7 +87,7 @@ class AggregateTest extends TableTestBase { val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), // ReduceExpressionsRule will add cast for Project node by force // if the input of the Project node has constant expression. term("select", "CAST(1) AS a", "b", "c"), @@ -116,7 +116,7 @@ class AggregateTest extends TableTestBase { val calcNode = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), // ReduceExpressionsRule will add cast for Project node by force // if the input of the Project node has constant expression. 
term("select", "CAST(1) AS a", "b", "c", "c._1 AS $f3"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala index c38dd486c94e71..ea9880c095469b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala @@ -38,7 +38,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a._1 AS a$_1", "a._2 AS a$_2", @@ -61,7 +61,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a._1 AS a$_1", "a._2 AS a$_2", @@ -85,7 +85,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "giveMeCaseClass$().my AS _c0", "giveMeCaseClass$().clazz AS _c1", @@ -110,7 +110,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "b") ) @@ -124,7 +124,7 @@ class CalcTest extends TableTestBase { val resultTable1 = sourceTable.select('*) val resultTable2 = sourceTable.select('a, 'b, 'c, 'd) - val expected = batchTableNode(0) + val expected = batchTableNode(sourceTable) util.verifyTable(resultTable1, expected) util.verifyTable(resultTable2, expected) @@ -140,7 +140,7 @@ class CalcTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "b") ), term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1") @@ -160,7 +160,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "MyHashCode$(c) AS _c0", "b") ) @@ -179,7 +179,7 @@ class CalcTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "c") ), term("distinct", "a", "c") @@ -200,7 +200,7 @@ class CalcTest extends TableTestBase { "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "c") ), term("distinct", "a", "c") @@ -222,7 +222,7 @@ class CalcTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "c") ), term("groupBy", "c"), @@ -247,7 +247,7 @@ class CalcTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), // As stated in https://issues.apache.org/jira/browse/CALCITE-1584 // Calcite planner doesn't promise to retain field names. term("select", "a", "UPPER(c) AS k") @@ -274,7 +274,7 @@ class CalcTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), // As stated in https://issues.apache.org/jira/browse/CALCITE-1584 // Calcite planner doesn't promise to retain field names. 
term("select", "a", "MyHashCode$(c) AS k") @@ -301,7 +301,7 @@ class CalcTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetAggregate", - batchTableNode(0), + batchTableNode(sourceTable), term("groupBy", "word"), term("select", "word", "SUM(frequency) AS TMP_0") ), @@ -323,7 +323,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(sourceTable), term("select", "a", "b"), term("where", "AND(AND(>(a, 0), <(b, 2)), =(MOD(a, 2), 1))") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/ColumnFunctionsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/ColumnFunctionsTest.scala index 786ee8472e955c..f325f0c7779d04 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/ColumnFunctionsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/ColumnFunctionsTest.scala @@ -51,7 +51,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataSetSort", - batchTableNode(0), + batchTableNode(t), term("orderBy", "a ASC", "b ASC", "c ASC") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala index 91fa483328fe87..9ec49d1a135560 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala @@ -45,7 +45,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "s"), @@ -66,7 +66,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", s"${function.functionIdentifier}($$2, '$$')"), term("correlate", s"table(${function.getClass.getSimpleName}(c, '$$'))"), term("select", "a", "b", "c", "s"), @@ -92,7 +92,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "s"), @@ -119,7 +119,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "s"), @@ -149,7 +149,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(sourceTable), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "d", "e"), @@ -196,7 +196,7 @@ class CorrelateTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetCorrelate", - batchTableNode(0), + batchTableNode(sourceTable), term("invocation", 
s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "d", "e"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala index d040f8296b1dee..1720a66ff8ffd0 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/GroupWindowTest.scala @@ -46,7 +46,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'long, 2)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -69,7 +69,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), term("select", "string", "myWeightedAvg(long, int) AS TMP_0") @@ -90,7 +90,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -113,7 +113,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "long", "int") ), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), @@ -137,7 +137,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "long", "int") ), term("window", "TumblingGroupWindow('w, 'long, 2)"), @@ -159,7 +159,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'ts, 7200000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -181,7 +181,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'ts, 7200000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -207,7 +207,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'long, 8.millis, 10.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -228,7 +228,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'long, 2, 1)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -251,7 +251,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'long, 8.millis, 10.millis)"), term("select", "string", 
"myWeightedAvg(long, int) AS TMP_0") @@ -274,7 +274,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "long", "int") ), term("window", "SlidingGroupWindow('w, 'long, 8.millis, 10.millis)"), @@ -298,7 +298,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "long", "int") ), term("window", "SlidingGroupWindow('w, 'long, 2, 1)"), @@ -320,7 +320,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'ts, 3600000.millis, 600000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -342,7 +342,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'ts, 3600000.millis, 600000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -368,7 +368,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'long, 7.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -391,7 +391,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'long, 7.millis)"), term("select", "string", "myWeightedAvg(long, int) AS TMP_0") @@ -412,7 +412,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'ts, 1800000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -434,7 +434,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataSetWindowAggregate", - batchTableNode(0), + batchTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'ts, 1800000.millis)"), term("select", "string", "COUNT(int) AS TMP_0", @@ -461,7 +461,7 @@ class GroupWindowTest extends TableTestBase { "DataSetWindowAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "rowtime", "c", "*(c, c) AS $f2") ), term("window", "TumblingGroupWindow('w, 'rowtime, 900000.millis)"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/JoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/JoinTest.scala index ce6225280839cc..1dcb80382ef796 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/JoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/JoinTest.scala @@ -42,12 +42,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -74,12 +74,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", 
- batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "y", "z") ), term("where", "AND(=(a, z), <(b, 2))"), @@ -106,10 +106,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(s), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "LeftOuterJoin") @@ -134,12 +134,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -166,12 +166,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "x", "z") ), term("where", "AND(=(a, z), <(x, 2))"), @@ -198,10 +198,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(s), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "RightOuterJoin") @@ -226,12 +226,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -258,12 +258,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(s), term("select", "y", "z") ), term("where", "AND(=(a, z), <(b, 2))"), @@ -290,10 +290,10 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a", "b") ), - batchTableNode(1), + batchTableNode(s), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "FullOuterJoin") @@ -321,12 +321,12 @@ class JoinTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t1), term("select", "b", "c") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(t2), term("select", "e", "f") ), term("where", "=(b, e)"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/SetOperatorsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/SetOperatorsTest.scala index f0f1ca3f984d4a..c09303343e7c8f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/SetOperatorsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/SetOperatorsTest.scala @@ -43,12 +43,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetCalc", binaryNode( "DataSetJoin", - batchTableNode(0), + batchTableNode(t), unaryNode( "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a AS a1"), term("where", "=(b, 'two')") ), @@ -73,7 +73,7 @@ class SetOperatorsTest extends 
TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "IN(b, 1972-02-22 07:12:00.333) AS b2") ) @@ -93,12 +93,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a") ), unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "CASE(>(c, 0), b, null) AS _c0") ), term("all", "true"), @@ -122,12 +122,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a") ), unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "b") ), term("all", "true"), @@ -156,13 +156,13 @@ class SetOperatorsTest extends TableTestBase { "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(left), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(right), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), @@ -197,13 +197,13 @@ class SetOperatorsTest extends TableTestBase { "DataSetMinus", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(left), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(right), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), @@ -232,12 +232,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetUnion", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(left), term("select", "b", "c") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(right), term("select", "b", "c") ), term("all", "true"), @@ -262,12 +262,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetMinus", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(left), term("select", "b", "c") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(right), term("select", "b", "c") ), term("minus", "b", "c") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/CorrelateStringExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/CorrelateStringExpressionTest.scala index 4d4abaaa476035..c5bdf3c9da5ce1 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/CorrelateStringExpressionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/CorrelateStringExpressionTest.scala @@ -19,13 +19,14 @@ package org.apache.flink.table.api.batch.table.stringexpr import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.{DataSet => JDataSet} import org.apache.flink.api.scala._ +import org.apache.flink.table.api.Types import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.{Table, Types} import org.apache.flink.table.utils.{PojoTableFunc, TableFunc2, _} -import org.apache.flink.table.utils._ import org.apache.flink.types.Row import org.junit.Test +import org.mockito.Mockito.{mock, when} class CorrelateStringExpressionTest extends TableTestBase { @@ -34,8 +35,15 @@ class CorrelateStringExpressionTest extends TableTestBase { val util = batchTestUtil() val typeInfo = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING): _*) - val sTab = util.addTable[(Int, Long, String)]("Table1", 'a, 'b, 'c) - val jTab = 
util.addJavaTable[Row](typeInfo, "Table2", "a, b, c") + + val jDs = mock(classOf[JDataSet[Row]]) + when(jDs.getType).thenReturn(typeInfo) + + val sDs = mock(classOf[DataSet[Row]]) + when(sDs.javaSet).thenReturn(jDs) + + val jTab = util.javaTableEnv.fromDataSet(jDs, "a, b, c") + val sTab = util.tableEnv.fromDataSet(sDs, 'a, 'b, 'c) // test cross join val func1 = new TableFunc1 diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/SetOperatorsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/SetOperatorsTest.scala index f134dd913fcf00..c75318789e6055 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/SetOperatorsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/SetOperatorsTest.scala @@ -40,12 +40,12 @@ class SetOperatorsTest extends TableTestBase { "DataSetCalc", binaryNode( "DataSetJoin", - batchTableNode(0), + batchTableNode(t), unaryNode( "DataSetDistinct", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "a AS a1"), term("where", "=(b, 'two')") ), @@ -70,7 +70,7 @@ class SetOperatorsTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(t), term("select", "IN(b, CAST('1972-02-22 07:12:00.333')) AS b2") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala index 9c3defea2bb27c..ec3d0d72ff4ad7 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/ExplainTest.scala @@ -20,7 +20,9 @@ package org.apache.flink.table.api.stream import org.apache.flink.api.scala._ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.Table import org.apache.flink.table.api.scala._ +import org.apache.flink.table.utils.TableTestUtil.streamTableNode import org.apache.flink.test.util.AbstractTestBase import org.junit.Assert.assertEquals import org.junit._ @@ -34,15 +36,14 @@ class ExplainTest extends AbstractTestBase { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = StreamTableEnvironment.create(env) - val table = env.fromElements((1, "hello")) - .toTable(tEnv, 'a, 'b) - .filter("a % 2 = 0") + val scan = env.fromElements((1, "hello")).toTable(tEnv, 'a, 'b) + val table = scan.filter("a % 2 = 0") val result = replaceString(tEnv.explain(table)) val source = scala.io.Source.fromFile(testFilePath + "../../src/test/scala/resources/testFilterStream0.out").mkString - val expect = replaceString(source) + val expect = replaceString(source, scan) assertEquals(expect, result) } @@ -59,11 +60,27 @@ class ExplainTest extends AbstractTestBase { val source = scala.io.Source.fromFile(testFilePath + "../../src/test/scala/resources/testUnionStream0.out").mkString - val expect = replaceString(source) + val expect = replaceString(source, table1, table2) assertEquals(expect, result) } - def replaceString(s: String): String = { + def replaceString(s: String, t1: Table, t2: Table): String = { + replaceSourceNode(replaceSourceNode(replaceString(s), t1, 0), t2, 1) + } + + def replaceString(s: String, t: Table): String = { + 
replaceSourceNode(replaceString(s), t, 0) + } + + private def replaceSourceNode(s: String, t: Table, idx: Int) = { + replaceString(s) + .replace( + s"%logicalSourceNode$idx%", streamTableNode(t) + .replace("DataStreamScan", "FlinkLogicalDataStreamScan")) + .replace(s"%sourceNode$idx%", streamTableNode(t)) + } + + def replaceString(s: String) = { /* Stage {id} is ignored, because id keeps incrementing in test class * while StreamExecutionEnvironment is up */ diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala index 9f192de8c751cd..b4d3726aff793b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala @@ -47,7 +47,7 @@ class StreamTableEnvironmentTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a, b, c"), term("where", ">(b, 12)")) @@ -60,8 +60,8 @@ class StreamTableEnvironmentTest extends TableTestBase { val expected2 = binaryNode( "DataStreamUnion", - streamTableNode(1), - streamTableNode(0), + streamTableNode(table2), + streamTableNode(table), term("all", "true"), term("union all", "d, e, f")) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala index 9cd5fd62426a52..1a455876cd9dc3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala @@ -35,7 +35,7 @@ import org.junit.Test class AggregateTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]( + private val table = streamUtil.addTable[(Int, String, Long)]( "MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime) @Test @@ -49,7 +49,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b", "a") ), term("groupBy", "b"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala index f179ae6cfac320..89ccb763fdac0c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/CorrelateTest.scala @@ -33,7 +33,7 @@ class CorrelateTest extends TableTestBase { def testCrossJoin(): Unit = { val util = streamTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)" @@ -42,7 +42,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), 
term("invocation", "func1($cor0.c)"), term("correlate", s"table(func1($$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -63,7 +63,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func1($cor0.c, '$')"), term("correlate", s"table(func1($$cor0.c, '$$'))"), term("select", "a", "b", "c", "f0"), @@ -81,7 +81,7 @@ class CorrelateTest extends TableTestBase { def testLeftOuterJoinWithLiteralTrue(): Unit = { val util = streamTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE" @@ -90,7 +90,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func1($cor0.c)"), term("correlate", s"table(func1($$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -108,8 +108,8 @@ class CorrelateTest extends TableTestBase { def testLeftOuterJoinAsSubQuery(): Unit = { val util = streamTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) - util.addTable[(Int, Long, String)]("MyTable2", 'a2, 'b2, 'c2) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table1 = util.addTable[(Int, Long, String)]("MyTable2", 'a2, 'b2, 'c2) util.addFunction("func1", func1) val sqlQuery = @@ -122,12 +122,12 @@ class CorrelateTest extends TableTestBase { val expected = binaryNode( "DataStreamJoin", - streamTableNode(1), + streamTableNode(table1), unaryNode( "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func1($cor0.c)"), term("correlate", "table(func1($cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -148,7 +148,7 @@ class CorrelateTest extends TableTestBase { def testCustomType(): Unit = { val util = streamTestUtil() val func2 = new TableFunc2 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func2", func2) val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)" @@ -157,7 +157,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func2($cor0.c)"), term("correlate", s"table(func2($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1"), @@ -175,7 +175,7 @@ class CorrelateTest extends TableTestBase { @Test def testHierarchyType(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val function = new HierarchyTableFunction util.addFunction("hierarchy", function) @@ -185,7 +185,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "hierarchy($cor0.c)"), term("correlate", s"table(hierarchy($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1", "f2"), @@ -203,7 +203,7 @@ class CorrelateTest extends TableTestBase { @Test def testPojoType(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = 
util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val function = new PojoTableFunc util.addFunction("pojo", function) @@ -213,7 +213,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "pojo($cor0.c)"), term("correlate", s"table(pojo($$cor0.c))"), term("select", "a", "b", "c", "age", "name"), @@ -232,7 +232,7 @@ class CorrelateTest extends TableTestBase { def testRowType(): Unit = { val util = streamTestUtil() val rowType = Types.ROW(Types.INT, Types.BOOLEAN, Types.ROW(Types.INT, Types.INT, Types.INT)) - util.addTable[Row]("MyTable", 'a, 'b, 'c)(rowType) + val table = util.addTable[Row]("MyTable", 'a, 'b, 'c)(rowType) val function = new TableFunc5 util.addFunction("tableFunc5", function) @@ -242,7 +242,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "tableFunc5($cor0.c)"), term("correlate", "table(tableFunc5($cor0.c))"), term("select", "a", "b", "c", "f0", "f1", "f2"), @@ -265,7 +265,7 @@ class CorrelateTest extends TableTestBase { def testFilter(): Unit = { val util = streamTestUtil() val func2 = new TableFunc2 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func2", func2) val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " + @@ -275,7 +275,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func2($cor0.c)"), term("correlate", s"table(func2($$cor0.c))"), term("select", "a", "b", "c", "f0", "f1"), @@ -295,7 +295,7 @@ class CorrelateTest extends TableTestBase { def testScalarFunction(): Unit = { val util = streamTestUtil() val func1 = new TableFunc1 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)" @@ -304,7 +304,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func1(SUBSTRING($cor0.c, 2))"), term("correlate", s"table(func1(SUBSTRING($$cor0.c, 2)))"), term("select", "a", "b", "c", "f0"), @@ -322,7 +322,7 @@ class CorrelateTest extends TableTestBase { def testTableFunctionWithVariableArguments(): Unit = { val util = streamTestUtil() val func1 = new JavaVarsArgTableFunc0 - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) util.addFunction("func1", func1) var sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1('hello', 'world', c)) AS T(s)" @@ -331,7 +331,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func1('hello', 'world', $cor0.c)"), term("correlate", s"table(func1('hello', 'world', $$cor0.c))"), term("select", "a", "b", "c", "f0"), @@ -354,7 +354,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", "func2('hello', 'world', $cor0.c)"), term("correlate", 
s"table(func2('hello', 'world', $$cor0.c))"), term("select", "a", "b", "c", "f0"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/DistinctAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/DistinctAggregateTest.scala index 20e5bebd131767..03335319121b97 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/DistinctAggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/DistinctAggregateTest.scala @@ -26,7 +26,7 @@ import org.junit.{Ignore, Test} class DistinctAggregateTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]( + private val table = streamUtil.addTable[(Int, String, Long)]( "MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime) @@ -40,7 +40,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a, b, c") ), term("groupBy", "a, b, c"), @@ -61,7 +61,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a") ), term("groupBy", "a"), @@ -82,7 +82,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "c", "a", "b") ), term("groupBy", "c"), @@ -103,7 +103,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime", "a") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), @@ -125,7 +125,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime", "a") ), term("window", "SlidingGroupWindow('w$, 'rowtime, 3600000.millis, 900000.millis)"), @@ -148,7 +148,7 @@ class DistinctAggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "rowtime", "c") ), term("groupBy", "a"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala index b2e61392420065..42e2a769265adb 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala @@ -27,7 +27,7 @@ import org.junit.Test class GroupWindowTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]( + private val table = streamUtil.addTable[(Int, String, Long)]( "MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime) @Test @@ -48,7 +48,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime", "c", "a") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 
900000.millis)"), @@ -82,7 +82,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "proctime", "c", "a") ), term("window", "SlidingGroupWindow('w$, 'proctime, 3600000.millis, 900000.millis)"), @@ -117,7 +117,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "proctime", "c", "a") ), term("window", "SessionGroupWindow('w$, 'proctime, 900000.millis)"), @@ -149,7 +149,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), @@ -185,7 +185,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime, a") ), term("window", "SlidingGroupWindow('w$, 'rowtime, 60000.millis, 900000.millis)"), @@ -233,7 +233,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime, a") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 2.millis)"), @@ -277,7 +277,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime", "c", "*(c, c) AS $f2") ), term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala index be3da6d55b2130..f9a1621c74c8ea 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/JoinTest.scala @@ -33,8 +33,18 @@ import org.junit.Test */ class JoinTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c.rowtime, 'proctime.proctime) - streamUtil.addTable[(Int, String, Long)]("MyTable2", 'a, 'b, 'c.rowtime, 'proctime.proctime) + private val table = streamUtil.addTable[(Int, String, Long)]( + "MyTable", + 'a, + 'b, + 'c.rowtime, + 'proctime.proctime) + private val table1 = streamUtil.addTable[(Int, String, Long)]( + "MyTable2", + 'a, + 'b, + 'c.rowtime, + 'proctime.proctime) // Tests for inner join @Test @@ -55,12 +65,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "proctime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", @@ -93,12 +103,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -131,12 +141,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", 
unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "proctime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", @@ -169,12 +179,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -203,11 +213,11 @@ class JoinTest extends TableTestBase { unaryNode("DataStreamCalc", binaryNode("DataStreamWindowJoin", unaryNode("DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "proctime") ), unaryNode("DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", "AND(=(a, a0), =(PROCTIME(proctime), PROCTIME(proctime0)))"), @@ -233,11 +243,11 @@ class JoinTest extends TableTestBase { unaryNode("DataStreamCalc", binaryNode("DataStreamWindowJoin", unaryNode("DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode("DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", "AND(=(a, a0), =(CAST(c), CAST(c0)))"), @@ -254,13 +264,11 @@ class JoinTest extends TableTestBase { val streamUtil: StreamTableTestUtil = streamTestUtil() val t1 = streamUtil.addTable[(Int, Long, String)]("Table1", 'a, 'b, 'c, 'proctime.proctime) - .select('a, 'b, 'c, 'proctime, nullOf(Types.LONG) as 'nullField) - val t2 = streamUtil.addTable[(Int, Long, String)]("Table2", 'a, 'b, 'c, 'proctime.proctime) - .select('a, 'b, 'c, 'proctime, 12L as 'nullField) - streamUtil.tableEnv.registerTable("T1", t1) - streamUtil.tableEnv.registerTable("T2", t2) + streamUtil.tableEnv + .registerTable("T1", t1.select('a, 'b, 'c, 'proctime, nullOf(Types.LONG) as 'nullField)) + streamUtil.tableEnv.registerTable("T2", t2.select('a, 'b, 'c, 'proctime, 12L as 'nullField)) val sqlQuery = """ @@ -275,12 +283,12 @@ class JoinTest extends TableTestBase { unaryNode("DataStreamCalc", binaryNode("DataStreamWindowJoin", unaryNode("DataStreamCalc", - streamTableNode(0), + streamTableNode(t1), term("select", "a", "c", "proctime", "null AS nullField") ), unaryNode("DataStreamCalc", - streamTableNode(1), - term("select", "a", "c", "proctime", "12 AS nullField") + streamTableNode(t2), + term("select", "a", "c", "proctime", "CAST(12) AS nullField") ), term("where", "AND(=(a, a0), =(nullField, nullField0), >=(PROCTIME(proctime), " + "-(PROCTIME(proctime0), 5000)), <=(PROCTIME(proctime), +(PROCTIME(proctime0), 5000)))"), @@ -313,12 +321,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -358,12 +366,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -401,12 +409,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), 
term("select", "a", "proctime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", @@ -439,12 +447,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -478,12 +486,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "proctime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", @@ -516,12 +524,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -555,12 +563,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "proctime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "proctime") ), term("where", @@ -593,12 +601,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -633,12 +641,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "a", "b", "c") ), term("where", @@ -791,8 +799,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z" val result = util.tableEnv.sqlQuery(query) @@ -803,12 +811,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -824,8 +832,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiAndLocalPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z AND b < 2" val result = util.tableEnv.sqlQuery(query) @@ -836,12 +844,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", 
"b", "<(b, 2) AS $f3") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "y", "z") ), term("where", "AND(=(a, z), $f3)"), @@ -857,8 +865,8 @@ class JoinTest extends TableTestBase { @Test def testLeftOuterJoinEquiAndNonEquiPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t LEFT OUTER JOIN s ON a = z AND b < x" val result = util.tableEnv.sqlQuery(query) @@ -869,10 +877,10 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b") ), - streamTableNode(1), + streamTableNode(right), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "LeftOuterJoin") @@ -886,8 +894,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t RIGHT OUTER JOIN s ON a = z" val result = util.tableEnv.sqlQuery(query) @@ -898,12 +906,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -919,8 +927,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiAndLocalPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, x FROM t RIGHT OUTER JOIN s ON a = z AND x < 2" val result = util.tableEnv.sqlQuery(query) @@ -931,12 +939,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "x", "z", "<(x, 2) AS $f3") ), term("where", "AND(=(a, z), $f3)"), @@ -952,8 +960,8 @@ class JoinTest extends TableTestBase { @Test def testRightOuterJoinEquiAndNonEquiPred(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) - util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) + val left = util.addTable[(Int, Long, String)]("t", 'a, 'b, 'c) + val right = util.addTable[(Long, String, Int)]("s", 'x, 'y, 'z) val query = "SELECT b, y FROM t RIGHT OUTER JOIN s ON a = z AND b < x" val result = util.tableEnv.sqlQuery(query) @@ -964,10 +972,10 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b") ), - streamTableNode(1), + streamTableNode(right), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "RightOuterJoin") diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/MatchRecognizeTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/MatchRecognizeTest.scala index 98a8d089d29089..9088bd0da21c9f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/MatchRecognizeTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/MatchRecognizeTest.scala @@ -26,7 +26,12 @@ import org.junit.Test class MatchRecognizeTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c.rowtime, 'proctime.proctime) + private val table = streamUtil.addTable[(Int, String, Long)]( + "MyTable", + 'a, + 'b, + 'c.rowtime, + 'proctime.proctime) @Test def testSimpleWithDefaults(): Unit = { @@ -45,7 +50,7 @@ class MatchRecognizeTest extends TableTestBase { val expected = unaryNode( "DataStreamMatch", - streamTableNode(0), + streamTableNode(table), term("orderBy", "proctime ASC"), term("measures", "FINAL(A.a) AS aa"), term("rowsPerMatch", "ONE ROW PER MATCH"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala index 9dbc2f24f43694..0efcb1e04ee043 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/OverWindowTest.scala @@ -26,7 +26,7 @@ import org.junit.Test class OverWindowTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]( + private val table = streamUtil.addTable[(Int, String, Long)]( "MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime) @@ -62,7 +62,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "proctime") ), term("partitionBy", "b"), @@ -103,7 +103,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "c"), @@ -143,7 +143,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "c"), @@ -180,7 +180,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "a"), @@ -225,7 +225,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -260,7 +260,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -304,7 +304,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - 
streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "c"), @@ -337,7 +337,7 @@ class OverWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamOverAggregate", - streamTableNode(0), + streamTableNode(table), term("partitionBy", "c"), term("orderBy", "proctime"), term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), @@ -371,7 +371,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -403,7 +403,7 @@ class OverWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamOverAggregate", - streamTableNode(0), + streamTableNode(table), term("orderBy", "proctime"), term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), term("select", "a", "b", "c", "proctime", "rowtime", "COUNT(a) AS w0$o0") @@ -428,7 +428,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -456,7 +456,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -484,7 +484,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -511,7 +511,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -545,7 +545,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -573,7 +573,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -614,7 +614,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -641,7 +641,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -690,7 +690,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "a"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SetOperatorsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SetOperatorsTest.scala index e5ee7e9df5095e..fcce9a7ddfcfb3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SetOperatorsTest.scala +++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SetOperatorsTest.scala @@ -28,12 +28,12 @@ class SetOperatorsTest extends TableTestBase { @Test def testInOnLiterals(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val resultStr = (1 to 30).mkString(", ") val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c"), term("where", s"IN(b, $resultStr)") ) @@ -46,12 +46,12 @@ class SetOperatorsTest extends TableTestBase { @Test def testNotInOnLiterals(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val resultStr = (1 to 30).mkString(", ") val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c"), term("where", s"NOT IN(b, $resultStr)") ) @@ -64,8 +64,8 @@ class SetOperatorsTest extends TableTestBase { @Test def testInUncorrelated(): Unit = { val streamUtil = streamTestUtil() - streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) - streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) + val table = streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) + val table1 = streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) val sqlQuery = s""" @@ -78,12 +78,12 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(table), unaryNode( "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "x") ), term("groupBy", "x"), @@ -102,8 +102,8 @@ class SetOperatorsTest extends TableTestBase { @Test def testInUncorrelatedWithConditionAndAgg(): Unit = { val streamUtil = streamTestUtil() - streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) - streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) + val table = streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) + val table1 = streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) val sqlQuery = s""" @@ -116,7 +116,7 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(table), unaryNode( "DataStreamGroupAggregate", unaryNode( @@ -125,7 +125,7 @@ class SetOperatorsTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), term("select", "x", "y"), term("where", "LIKE(y, '%Hanoi%')") ), @@ -150,9 +150,9 @@ class SetOperatorsTest extends TableTestBase { @Test def testInWithMultiUncorrelatedCondition(): Unit = { val streamUtil = streamTestUtil() - streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) - streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) - streamUtil.addTable[(Long, Int)]("tableC", 'w, 'z) + val table = streamUtil.addTable[(Int, Long, String)]("tableA", 'a, 'b, 'c) + val table1 = streamUtil.addTable[(Int, String)]("tableB", 'x, 'y) + val table2 = streamUtil.addTable[(Long, Int)]("tableC", 'w, 'z) val sqlQuery = s""" @@ -170,12 +170,12 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(table), unaryNode( "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(table1), 
term("select", "x") ), term("groupBy", "x"), @@ -191,7 +191,7 @@ class SetOperatorsTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(2), + streamTableNode(table2), term("select", "w") ), term("groupBy", "w"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala index 087ae7d37098e9..b8ad376c7e7d7f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/SortTest.scala @@ -27,7 +27,7 @@ import org.junit.Test class SortTest extends TableTestBase { private val streamUtil: StreamTableTestUtil = streamTestUtil() - streamUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c, + private val table = streamUtil.addTable[(Int, String, Long)]("MyTable", 'a, 'b, 'c, 'proctime.proctime, 'rowtime.rowtime) @Test @@ -39,7 +39,7 @@ class SortTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode("DataStreamSort", - streamTableNode(0), + streamTableNode(table), term("orderBy", "proctime ASC", "c ASC")), term("select", "a", "PROCTIME(proctime) AS proctime", "c")) @@ -55,7 +55,7 @@ class SortTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode("DataStreamSort", - streamTableNode(0), + streamTableNode(table), term("orderBy", "rowtime ASC, c ASC")), term("select", "a", "rowtime", "c")) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala index 27c40bbbef24ee..c758660fda9ac5 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala @@ -22,7 +22,8 @@ import java.sql.Timestamp import org.apache.flink.api.scala._ import org.apache.flink.table.api.TableException import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.stream.table.TemporalTableJoinTest._ +import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin.TEMPORAL_JOIN_CONDITION +import org.apache.flink.table.utils.TableTestUtil.{binaryNode, streamTableNode, term, unaryNode} import org.apache.flink.table.utils._ import org.hamcrest.Matchers.startsWith import org.junit.Test @@ -57,7 +58,7 @@ class TemporalTableJoinTest extends TableTestBase { "o_amount * rate as rate " + "FROM Orders AS o, " + "LATERAL TABLE (Rates(o.o_rowtime)) AS r " + - "WHERE currency = o_currency"; + "WHERE currency = o_currency" util.verifySql(sqlQuery, getExpectedSimpleJoinPlan()) } @@ -68,7 +69,7 @@ class TemporalTableJoinTest extends TableTestBase { "o_amount * rate as rate " + "FROM ProctimeOrders AS o, " + "LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r " + - "WHERE currency = o_currency"; + "WHERE currency = o_currency" util.verifySql(sqlQuery, getExpectedSimpleProctimeJoinPlan()) } @@ -81,8 +82,8 @@ class TemporalTableJoinTest extends TableTestBase { @Test def testComplexJoin(): Unit = { val util = streamTestUtil() - util.addTable[(String, Int)]("Table3", 't3_comment, 't3_secondary_key) - util.addTable[(Timestamp, String, Long, String, Int)]( + val thirdTable = util.addTable[(String, 
Int)]("Table3", 't3_comment, 't3_secondary_key) + val orders = util.addTable[(Timestamp, String, Long, String, Int)]( "Orders", 'o_rowtime.rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key) val ratesHistory = util.addTable[(Timestamp, String, String, Int, Int)]( @@ -101,9 +102,49 @@ class TemporalTableJoinTest extends TableTestBase { "LATERAL TABLE (Rates(o_rowtime)) AS r " + "WHERE currency = o_currency OR secondary_key = o_secondary_key), " + "Table3 " + - "WHERE t3_secondary_key = secondary_key"; - - util.verifySql(sqlQuery, getExpectedComplexJoinPlan()) + "WHERE t3_secondary_key = secondary_key" + + util.verifySql(sqlQuery, binaryNode( + "DataStreamJoin", + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamTemporalTableJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(orders), + term("select", "o_rowtime, o_amount, o_currency, o_secondary_key") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(ratesHistory), + term("select", "rowtime, currency, rate, secondary_key"), + term("where", ">(rate, 110)") + ), + term( + "where", + "AND(" + + s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + + "OR(=(currency, o_currency), =(secondary_key, o_secondary_key)))"), + term( + "join", + "o_rowtime", + "o_amount", + "o_currency", + "o_secondary_key", + "rowtime", + "currency", + "rate", + "secondary_key"), + term("joinType", "InnerJoin") + ), + term("select", "*(o_amount, rate) AS rate", "secondary_key") + ), + streamTableNode(thirdTable), + term("where", "=(t3_secondary_key, secondary_key)"), + term("join", "rate, secondary_key, t3_comment, t3_secondary_key"), + term("joinType", "InnerJoin") + )) } @Test @@ -115,7 +156,7 @@ class TemporalTableJoinTest extends TableTestBase { "o_amount * rate as rate " + "FROM Orders AS o, " + "LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123')) AS r " + - "WHERE currency = o_currency"; + "WHERE currency = o_currency" util.printSql(sqlQuery) } @@ -125,8 +166,69 @@ class TemporalTableJoinTest extends TableTestBase { expectedException.expect(classOf[TableException]) expectedException.expectMessage(startsWith("Cannot generate a valid execution plan")) - val sqlQuery = "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))"; + val sqlQuery = "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))" util.printSql(sqlQuery) } + + def getExpectedSimpleJoinPlan(): String = { + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamTemporalTableJoin", + streamTableNode(orders), + streamTableNode(ratesHistory), + term("where", + "AND(" + + s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + + "=(currency, o_currency))"), + term("join", "o_amount", "o_currency", "o_rowtime", "currency", "rate", "rowtime"), + term("joinType", "InnerJoin") + ), + term("select", "*(o_amount, rate) AS rate") + ) + } + + def getExpectedSimpleProctimeJoinPlan(): String = { + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamTemporalTableJoin", + streamTableNode(proctimeOrders), + unaryNode( + "DataStreamCalc", + streamTableNode(proctimeRatesHistory), + term("select", "currency, rate")), + term("where", + "AND(" + + s"${TEMPORAL_JOIN_CONDITION.getName}(o_proctime, currency), " + + "=(currency, o_currency))"), + term("join", "o_amount", "o_currency", "o_proctime", "currency", "rate"), + term("joinType", "InnerJoin") + ), + term("select", "*(o_amount, rate) AS rate") + ) + } + + def getExpectedTemporalTableFunctionOnTopOfQueryPlan(): String = { + unaryNode( + "DataStreamCalc", 
+ binaryNode( + "DataStreamTemporalTableJoin", + streamTableNode(orders), + unaryNode( + "DataStreamCalc", + streamTableNode(ratesHistory), + term("select", "currency", "*(rate, 2) AS rate", "rowtime"), + term("where", ">(rate, 100)")), + term("where", + "AND(" + + s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + + "=(currency, o_currency))"), + term("join", "o_amount", "o_currency", "o_rowtime", "currency", "rate", "rowtime"), + term("joinType", "InnerJoin") + ), + term("select", "*(o_amount, rate) AS rate") + ) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/UnionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/UnionTest.scala index b8bd2223f621b2..306a3164a37c95 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/UnionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/UnionTest.scala @@ -32,18 +32,18 @@ class UnionTest extends TableTestBase { @Test def testUnionAllNullableCompositeType() = { val streamUtil = streamTestUtil() - streamUtil.addTable[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) + val table = streamUtil.addTable[((Int, String), (Int, String), Int)]("A", 'a, 'b, 'c) val expected = binaryNode( "DataStreamUnion", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a") ), unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "CASE(>(c, 0), b, null) AS EXPR$0") ), term("all", "true"), @@ -62,18 +62,18 @@ class UnionTest extends TableTestBase { val typeInfo = Types.ROW( new GenericTypeInfo(classOf[NonPojo]), new GenericTypeInfo(classOf[NonPojo])) - streamUtil.addJavaTable(typeInfo, "A", "a, b") + val table = streamUtil.addJavaTable(typeInfo, "A", "a, b") val expected = binaryNode( "DataStreamUnion", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a") ), unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b") ), term("all", "true"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala index 626306f4497db0..70652af50524c8 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/AggregateTest.scala @@ -43,7 +43,7 @@ class AggregateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "b"), term("select", "b", "SUM(DISTINCT a) AS TMP_0", "COUNT(DISTINCT c) AS TMP_1") ), @@ -67,7 +67,7 @@ class AggregateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "c"), term( "select", @@ -96,7 +96,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), @@ -124,7 +124,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "4 
AS four", "b") ), term("groupBy", "a", "four"), @@ -152,7 +152,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b", "4 AS four", "a") ), term("groupBy", "b", "four"), @@ -180,7 +180,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "MOD(b, 3) AS d", "c") ), term("groupBy", "d"), @@ -206,7 +206,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b"), term("where", "=(b, 2)") ), @@ -230,7 +230,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b", "CAST(a) AS a0") ), term("groupBy", "b"), @@ -254,7 +254,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "rowtime") ), term("window", "TumblingGroupWindow('w, 'rowtime, 900000.millis)"), @@ -278,7 +278,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "rowtime") ), term("window", "SlidingGroupWindow('w, 'rowtime, 3600000.millis, 900000.millis)"), @@ -303,7 +303,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("groupBy", "a"), @@ -333,7 +333,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), @@ -363,7 +363,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), @@ -390,7 +390,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), @@ -418,7 +418,7 @@ class AggregateTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("groupBy", "b"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala index 3d068d65664cf0..6c7487c0d917f3 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala @@ -48,7 +48,7 @@ class CalcTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "a", "rowtime", "UPPER(c) AS $f5") ), term("window", "TumblingGroupWindow('w, 'rowtime, 5.millis)"), @@ -74,7 +74,7 @@ class CalcTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - 
streamTableNode(0), + streamTableNode(sourceTable), term("select", "a", "b", "rowtime", "UPPER(c) AS $f5") ), term("groupBy", "b"), @@ -98,7 +98,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "a", "b"), term("where", "AND(AND(>(a, 0), <(b, 2)), =(MOD(a, 2), 1))") ) @@ -115,7 +115,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "a", "b", "c"), term("where", s"AND(IN(b, ${(1 to 30).mkString(", ")}), =(c, 'xx'))") ) @@ -132,7 +132,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "a", "b", "c"), term("where", s"OR(NOT IN(b, ${(1 to 30).mkString(", ")}), <>(c, 'xx'))") ) @@ -153,7 +153,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term( "select", "a", "b", "c", "CONCAT(c, '_kid_last') AS kid", "+(a, 2) AS _c4, b AS b2", "'literal_value' AS _c6") @@ -169,7 +169,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "a AS a2", "b AS b2") ) util.verifyTable(resultTable, expected) @@ -183,7 +183,7 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "c") ) util.verifyTable(resultTable, expected) @@ -193,12 +193,12 @@ class CalcTest extends TableTestBase { def testSimpleMap(): Unit = { val util = streamTestUtil() - val resultTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) - .map(Func23('a, 'b, 'c)) + val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val resultTable = sourceTable.map(Func23('a, 'b, 'c)) val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "Func23$(a, b, c).f0 AS _c0, Func23$(a, b, c).f1 AS _c1, " + "Func23$(a, b, c).f2 AS _c2, Func23$(a, b, c).f3 AS _c3") ) @@ -210,12 +210,12 @@ class CalcTest extends TableTestBase { def testScalarResult(): Unit = { val util = streamTestUtil() - val resultTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) - .map(Func1('a)) + val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val resultTable = sourceTable.map(Func1('a)) val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "Func1$(a) AS _c0") ) @@ -226,13 +226,14 @@ class CalcTest extends TableTestBase { def testMultiMap(): Unit = { val util = streamTestUtil() - val resultTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val resultTable = sourceTable .map(Func23('a, 'b, 'c)) .map(Func24('_c0, '_c1, '_c2, '_c3)) val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(sourceTable), term("select", "Func24$(Func23$(a, b, c).f0, Func23$(a, b, c).f1, " + "Func23$(a, b, c).f2, Func23$(a, b, c).f3).f0 AS _c0, " + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/ColumnFunctionsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/ColumnFunctionsTest.scala index 
bfde6391a579f2..cc41f4a6ccb55a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/ColumnFunctionsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/ColumnFunctionsTest.scala @@ -33,7 +33,7 @@ import org.junit.Test */ class ColumnFunctionsTest extends TableTestBase { - val util = new StreamTableTestUtil() + val util = streamTestUtil() private def verifyAll(tab1: Table, tab2: Table, expected: String): Unit = { util.verifyTable(tab1, expected) @@ -52,7 +52,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "TestFunc$(double, long) AS _c0") ) @@ -69,7 +69,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "b", "c", "a", "e", "f", "d") ) @@ -86,7 +86,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b", "c", "f") ) @@ -108,7 +108,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "e", "f") ) @@ -125,7 +125,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "CONCAT(string1, string2) AS _c0") ) @@ -143,8 +143,8 @@ class ColumnFunctionsTest extends TableTestBase { val expected = binaryNode( "DataStreamJoin", - streamTableNode(0), - streamTableNode(1), + streamTableNode(t1), + streamTableNode(t2), term("where", "=(int1, int2)"), term("join", "int1", "long1", "string1", "int2", "long2", "string2"), term("joinType", "InnerJoin") @@ -165,7 +165,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(t), term("invocation", "org$apache$flink$table$utils$TableFunc0$497a630d2a145bca99673bcd05a53d2b($2)"), term("correlate", "table(TableFunc0(string))"), @@ -190,7 +190,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "int", "long", "string1", "string2"), term("where", "=(CONCAT(string1, string2), 'a')") ) @@ -215,7 +215,7 @@ class ColumnFunctionsTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b", "c") ), term("groupBy", "a", "b"), @@ -246,7 +246,7 @@ class ColumnFunctionsTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "CAST(b) AS b", "c") ), term("groupBy", "a", "b"), @@ -295,7 +295,7 @@ class ColumnFunctionsTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamOverAggregate", - streamTableNode(0), + streamTableNode(table), term("partitionBy", "c"), term("orderBy", "proctime"), term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), @@ -318,7 +318,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b", "c", "TestFunc$(a, b) AS d") ) @@ -335,7 +335,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - 
streamTableNode(0), + streamTableNode(t), term("select", "a AS d", "b") ) @@ -352,7 +352,7 @@ class ColumnFunctionsTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "c") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala index 1d3ae67ebdf2d2..31a197dff181b8 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala @@ -45,7 +45,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "s"), @@ -66,7 +66,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2, '$$')"), term("correlate", s"table(${function.getClass.getSimpleName}(c, '$$'))"), term("select", "a", "b", "c", "s"), @@ -92,7 +92,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "s"), @@ -120,7 +120,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}(${scalarFunc.functionIdentifier}($$2))"), term("correlate", s"table(${function.getClass.getSimpleName}(Func13(c)))"), @@ -146,7 +146,7 @@ class CorrelateTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", "table(HierarchyTableFunction(c))"), term("select", "a", "b", "c", "name", "adult", "len"), @@ -169,7 +169,7 @@ class CorrelateTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "age", "name"), @@ -197,7 +197,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "name", "len"), @@ -223,7 +223,7 @@ class CorrelateTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(table), term("invocation", s"${function.functionIdentifier}(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"), term("correlate", s"table(${function.getClass.getSimpleName}(SUBSTRING(c, 2, CHAR_LENGTH(c))))"), @@ -253,7 +253,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - 
streamTableNode(0), + streamTableNode(sourceTable), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "d", "e"), @@ -299,7 +299,7 @@ class CorrelateTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(sourceTable), term("invocation", s"${function.functionIdentifier}($$2)"), term("correlate", s"table(${function.getClass.getSimpleName}(c))"), term("select", "a", "b", "c", "d", "e"), @@ -320,14 +320,15 @@ class CorrelateTest extends TableTestBase { val util = streamTestUtil() val func2 = new TableFunc2 - val resultTable = util.addTable[(Int, Long, String)]("MyTable", 'f1, 'f2, 'f3) + val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'f1, 'f2, 'f3) + val resultTable = sourceTable .flatMap(func2('f3)) val expected = unaryNode( "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(sourceTable), term("invocation", s"${func2.functionIdentifier}($$2)"), term("correlate", "table(TableFunc2(f3))"), term("select", "f1", "f2", "f3", "f0", "f1_0"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTableAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTableAggregateTest.scala index 37796c038a533b..fc09465f5a171c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTableAggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTableAggregateTest.scala @@ -53,7 +53,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "e") ), term("groupBy", "c"), @@ -86,7 +86,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "e", "MOD(b, 5) AS bb") ), term("groupBy", "bb"), @@ -111,7 +111,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "e") ), term("groupBy", "c"), @@ -135,7 +135,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), @@ -161,7 +161,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "e") ), term("groupBy", "c"), @@ -187,7 +187,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "e") ), term("groupBy", "c"), @@ -211,7 +211,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), @@ -235,7 
+235,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), @@ -259,7 +259,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "e") ), term("window", "TumblingGroupWindow('w, 'e, 50.millis)"), @@ -282,7 +282,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "e") ), term("window", "TumblingGroupWindow('w, 'e, 2)"), @@ -305,7 +305,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "d") ), term("window", "TumblingGroupWindow('w, 'd, 5.millis)"), @@ -329,7 +329,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "e") ), term("window", "SlidingGroupWindow('w, 'e, 50.millis, 50.millis)"), @@ -352,7 +352,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "e") ), term("window", "SlidingGroupWindow('w, 'e, 2, 1)"), @@ -375,7 +375,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "d") ), term("window", "SlidingGroupWindow('w, 'd, 8.millis, 10.millis)"), @@ -398,7 +398,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "d") ), term("window", "SlidingGroupWindow('w, 'd, 8.millis, 10.millis)"), @@ -421,7 +421,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "d") ), term("window", "SessionGroupWindow('w, 'd, 7.millis)"), @@ -446,7 +446,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), @@ -475,7 +475,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), @@ -504,7 +504,7 @@ class GroupWindowTableAggregateTest extends TableTestBase { "DataStreamGroupWindowTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "d") ), term("groupBy", "c"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala 
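Note on the recurring test change in the hunks above and below: each touched planner test now captures the Table returned by addTable in a val and passes it to streamTableNode(table), replacing the old positional streamTableNode(0) / streamTableNode(1) form. The snippet below is a minimal, hypothetical sketch of that pattern and is not part of the applied diff; it assumes the usual planner test setup seen in these files, i.e. a class extending TableTestBase with org.apache.flink.api.scala._, org.apache.flink.table.api.scala._ and org.apache.flink.table.utils.TableTestUtil.{streamTableNode, term, unaryNode} in scope.

    // Hypothetical, condensed test body mirroring the hunks in this patch.
    val util = streamTestUtil()
    // Capture the Table returned by addTable instead of addressing it by registration index.
    val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
    val resultTable = sourceTable.select('a, 'b)
    val expected = unaryNode(
      "DataStreamCalc",
      streamTableNode(sourceTable), // previously: streamTableNode(0)
      term("select", "a", "b")
    )
    util.verifyTable(resultTable, expected)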
index d1aade26025856..4cef48897d7d74 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/GroupWindowTest.scala @@ -49,7 +49,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -78,7 +78,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -103,7 +103,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -126,7 +126,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -149,7 +149,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'rowtime, 5.millis)"), term("select", "string", "myWeightedAvg(long, int) AS TMP_0") @@ -172,7 +172,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -197,7 +197,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -222,7 +222,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "rowtime") ), term("groupBy", "string"), @@ -246,7 +246,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'long, 8.millis, 10.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -269,7 +269,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "SlidingGroupWindow('w, 'rowtime, 8.millis, 10.millis)"), term("select", "string", "myWeightedAvg(long, int) AS TMP_0") @@ -290,7 +290,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'long, 7.millis)"), term("select", "string", "COUNT(int) AS TMP_0") @@ -313,7 +313,7 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), 
term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'rowtime, 7.millis)"), term("select", "string", "myWeightedAvg(long, int) AS TMP_0") @@ -336,7 +336,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "proctime") ), term("groupBy", "string"), @@ -361,7 +361,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "proctime") ), term("window", "TumblingGroupWindow('w, 'proctime, 2)"), @@ -385,7 +385,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "rowtime") ), term("window", "TumblingGroupWindow('w, 'rowtime, 5.millis)"), @@ -410,7 +410,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "long") ), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), @@ -434,7 +434,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "proctime") ), term("window", "SlidingGroupWindow('w, 'proctime, 50.millis, 50.millis)"), @@ -458,7 +458,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "proctime") ), term("window", "SlidingGroupWindow('w, 'proctime, 2, 1)"), @@ -482,7 +482,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "rowtime") ), term("window", "SlidingGroupWindow('w, 'rowtime, 8.millis, 10.millis)"), @@ -507,7 +507,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "long", "int") ), term("window", "SlidingGroupWindow('w, 'long, 8.millis, 10.millis)"), @@ -531,7 +531,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "long", "int") ), term("window", "SessionGroupWindow('w, 'long, 7.millis)"), @@ -555,7 +555,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "rowtime") ), term("groupBy", "string"), @@ -593,7 +593,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string, int2, int3"), term("window", "SlidingGroupWindow('w, 'proctime, 2, 1)"), term( @@ -623,7 +623,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "int", "string", "rowtime") ), term("groupBy", "string"), @@ -652,7 +652,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + 
streamTableNode(table), term("groupBy", "string"), term("window", "SessionGroupWindow('w, 'long, 3.millis)"), term("select", @@ -682,7 +682,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(table), term("groupBy", "string"), term("window", "TumblingGroupWindow('w, 'long, 5.millis)"), term("select", @@ -721,7 +721,7 @@ class GroupWindowTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "rowtime", "c", "*(c, c) AS $f2") ), term("window", "TumblingGroupWindow('w, 'rowtime, 900000.millis)"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/JoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/JoinTest.scala index 138497cada3dd8..ab60d00d546019 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/JoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/JoinTest.scala @@ -49,12 +49,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lrtime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rrtime") ), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000))," + @@ -84,12 +84,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lptime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rptime") ), term("where", "AND(=(a, d), >=(PROCTIME(lptime), -(PROCTIME(rptime), 1000)), " + @@ -119,12 +119,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lptime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rptime") ), term("where", "AND(=(a, d), =(PROCTIME(lptime), PROCTIME(rptime)))"), @@ -151,8 +151,8 @@ class JoinTest extends TableTestBase { val expected = binaryNode( "DataStreamWindowJoin", - streamTableNode(0), - streamTableNode(1), + streamTableNode(left), + streamTableNode(right), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000)), " + "<(CAST(lrtime), CAST(rrtime)), >(CAST(lrtime), f))"), @@ -182,12 +182,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lrtime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rrtime") ), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000))," + @@ -217,12 +217,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lptime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rptime") ), term("where", "AND(=(a, d), >=(PROCTIME(lptime), -(PROCTIME(rptime), 1000)), " + @@ -255,12 +255,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - 
streamTableNode(0), + streamTableNode(left), term("select", "a", "lrtime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rrtime") ), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000))," + @@ -290,12 +290,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lptime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rptime") ), term("where", "AND(=(a, d), >=(PROCTIME(lptime), -(PROCTIME(rptime), 1000)), " + @@ -328,12 +328,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lrtime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rrtime") ), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000))," + @@ -363,12 +363,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lptime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rptime") ), term("where", "AND(=(a, d), >=(PROCTIME(lptime), -(PROCTIME(rptime), 1000)), " + @@ -399,12 +399,12 @@ class JoinTest extends TableTestBase { "DataStreamWindowJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "lrtime") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "d", "e", "rrtime") ), term("where", "AND(=(a, d), >=(CAST(lrtime), -(CAST(rrtime), 300000))," + @@ -433,12 +433,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(s), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -465,12 +465,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(s), term("select", "y", "z") ), term("where", "AND(=(a, z), <(b, 2))"), @@ -497,10 +497,10 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b") ), - streamTableNode(1), + streamTableNode(s), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "LeftOuterJoin") @@ -525,12 +525,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(s), term("select", "y", "z") ), term("where", "=(a, z)"), @@ -557,12 +557,12 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "a", "b") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(s), term("select", "x", "z") ), term("where", "AND(=(a, z), <(x, 2))"), @@ -589,10 +589,10 @@ class JoinTest extends TableTestBase { "DataStreamJoin", unaryNode( "DataStreamCalc", - streamTableNode(0), + 
streamTableNode(t), term("select", "a", "b") ), - streamTableNode(1), + streamTableNode(s), term("where", "AND(=(a, z), <(b, x))"), term("join", "a", "b", "x", "y", "z"), term("joinType", "RightOuterJoin") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala index c41112411d2665..9e47d2a8e851ea 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/OverWindowTest.scala @@ -52,7 +52,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), // RexSimplify didn't simplify "CAST(1):BIGINT NOT NULL", see [CALCITE-2862] term("select", "a", "b", "c", "proctime", "1 AS $4") ), @@ -89,7 +89,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "proctime") ), term("partitionBy", "b"), @@ -118,7 +118,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "a"), @@ -151,7 +151,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -177,7 +177,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -211,7 +211,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "c"), @@ -253,7 +253,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("partitionBy", "c"), @@ -283,7 +283,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -322,7 +322,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "proctime") ), term("orderBy", "proctime"), @@ -351,7 +351,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c", "rowtime") ), term("partitionBy", "b"), @@ -383,7 +383,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "a"), @@ -417,7 +417,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), 
term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -443,7 +443,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -478,7 +478,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -521,7 +521,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("partitionBy", "c"), @@ -551,7 +551,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), @@ -590,7 +590,7 @@ class OverWindowTest extends TableTestBase { "DataStreamOverAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c", "rowtime") ), term("orderBy", "rowtime"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala index dfbaf40d524f80..656147e1f34bf2 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/SetOperatorsTest.scala @@ -45,13 +45,13 @@ class SetOperatorsTest extends TableTestBase { "DataStreamUnion", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "a", "b", "c"), term("where", ">(a, 0)") ), @@ -81,12 +81,12 @@ class SetOperatorsTest extends TableTestBase { "DataStreamUnion", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(left), term("select", "b", "c") ), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(right), term("select", "b", "c") ), term("all", "true"), @@ -109,12 +109,12 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(tableA), unaryNode( "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(tableB), term("select", "x") ), term("groupBy", "x"), @@ -144,7 +144,7 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(tableA), unaryNode( "DataStreamGroupAggregate", unaryNode( @@ -153,7 +153,7 @@ class SetOperatorsTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(tableB), term("select", "x", "y"), term("where", "LIKE(y, '%Hanoi%')") ), @@ -194,12 +194,12 @@ class SetOperatorsTest extends TableTestBase { "DataStreamCalc", binaryNode( "DataStreamJoin", - streamTableNode(0), + streamTableNode(tableA), unaryNode( "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(tableB), term("select", "x") ), term("groupBy", "x"), @@ -215,7 +215,7 @@ 
class SetOperatorsTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(2), + streamTableNode(tableC), term("select", "w") ), term("groupBy", "w"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala index 938d2c5c56a14f..5c44eab180cbd4 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableAggregateTest.scala @@ -49,7 +49,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "MOD(b, 5) AS bb") ), term("groupBy", "bb"), @@ -74,7 +74,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b") ), term("select", "EmptyTableAggFunc(a, b) AS (f0, f1)") @@ -96,7 +96,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "CAST(d) AS d", "PROCTIME(e) AS e") ), term("select", "EmptyTableAggFunc(d, e) AS (f0, f1)") @@ -116,7 +116,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b") ), term("select", "EmptyTableAggFunc(b) AS (f0, f1)") @@ -138,7 +138,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "b") ), term("select", "EmptyTableAggFunc(b) AS (f0, f1)") @@ -167,7 +167,7 @@ class TableAggregateTest extends TableTestBase { "DataStreamGroupTableAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "c") ), term("groupBy", "c"), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala index c6ba3545b613b6..1b84081ddccfab 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TableSourceTest.scala @@ -47,7 +47,8 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("rowTimeT").select("rowtime, id, name, val") - val expected = "StreamTableSourceScan(table=[[rowTimeT]], fields=[rowtime, id, name, val], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, rowTimeT]], " + + "fields=[rowtime, id, name, val], " + "source=[TestTableSourceWithTime(id, rowtime, val, name)])" util.verifyTable(t, expected) } @@ -70,7 +71,8 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("rowTimeT").select("rowtime, id, name, val") - val expected = "StreamTableSourceScan(table=[[rowTimeT]], fields=[rowtime, id, name, val], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, rowTimeT]], " + + 
"fields=[rowtime, id, name, val], " + "source=[TestTableSourceWithTime(id, rowtime, val, name)])" util.verifyTable(t, expected) } @@ -104,7 +106,8 @@ class TableSourceTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - "StreamTableSourceScan(table=[[rowTimeT]], fields=[rowtime, val, name], " + + "StreamTableSourceScan(table=[[default_catalog, default_database, rowTimeT]], " + + "fields=[rowtime, val, name], " + "source=[TestTableSourceWithTime(id, rowtime, val, name)])", term("select", "rowtime", "val", "name"), term("where", ">(val, 100)") @@ -138,7 +141,8 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - "StreamTableSourceScan(table=[[procTimeT]], fields=[id, proctime, val, name], " + + "StreamTableSourceScan(table=[[default_catalog, default_database, procTimeT]], " + + "fields=[id, proctime, val, name], " + "source=[TestTableSourceWithTime(id, proctime, val, name)])", term("select", "PROCTIME(proctime) AS proctime", "id", "name", "val") ) @@ -170,7 +174,8 @@ class TableSourceTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamOverAggregate", - "StreamTableSourceScan(table=[[procTimeT]], fields=[id, proctime, val, name], " + + "StreamTableSourceScan(table=[[default_catalog, default_database, procTimeT]], " + + "fields=[id, proctime, val, name], " + "source=[TestTableSourceWithTime(id, proctime, val, name)])", term("partitionBy", "id"), term("orderBy", "proctime"), @@ -200,7 +205,7 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("T").select('name, 'val, 'id) - val expected = "StreamTableSourceScan(table=[[T]], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[name, val, id], " + "source=[TestSource(physical fields: name, val, id)])" util.verifyTable(t, expected) @@ -225,7 +230,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - "StreamTableSourceScan(table=[[T]], " + + "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[ptime, name, val, id], " + "source=[TestSource(physical fields: name, val, id)])", term("select", "PROCTIME(ptime) AS ptime", "name", "val", "id") @@ -249,7 +254,7 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("T").select('name, 'val, 'rtime, 'id) - val expected = "StreamTableSourceScan(table=[[T]], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[name, val, rtime, id], " + "source=[TestSource(physical fields: name, val, rtime, id)])" util.verifyTable(t, expected) @@ -271,7 +276,7 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("T").select('ptime) - val expected = "StreamTableSourceScan(table=[[T]], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[ptime], " + "source=[TestSource(physical fields: )])" util.verifyTable(t, expected) @@ -293,7 +298,7 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("T").select('rtime) - val expected = "StreamTableSourceScan(table=[[T]], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[rtime], " + "source=[TestSource(physical fields: rtime)])" util.verifyTable(t, expected) @@ -317,7 +322,7 @@ class TableSourceTest extends TableTestBase { val t = util.tableEnv.scan("T").select('name, 'rtime, 'val) - val expected = 
"StreamTableSourceScan(table=[[T]], " + + val expected = "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[name, rtime, val], " + "source=[TestSource(physical fields: remapped-p-name, remapped-p-rtime, remapped-p-val)])" util.verifyTable(t, expected) @@ -364,7 +369,7 @@ class TableSourceTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - "StreamTableSourceScan(table=[[T]], " + + "StreamTableSourceScan(table=[[default_catalog, default_database, T]], " + "fields=[id, deepNested, nested], " + "source=[TestSource(read nested fields: " + "id.*, deepNested.nested2.num, deepNested.nested2.flag, " + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala index 37ded5e2e3fd78..7665192102a5ac 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala @@ -22,17 +22,15 @@ import java.sql.Timestamp import org.apache.flink.api.scala._ import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.stream.table.TemporalTableJoinTest._ import org.apache.flink.table.api.{TableSchema, Types, ValidationException} import org.apache.flink.table.expressions.{Expression, FieldReferenceExpression} import org.apache.flink.table.functions.{TemporalTableFunction, TemporalTableFunctionImpl} import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin._ -import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo.{PROCTIME_INDICATOR, ROWTIME_INDICATOR} import org.apache.flink.table.utils.TableTestUtil._ import org.apache.flink.table.utils._ import org.hamcrest.Matchers.{equalTo, startsWith} -import org.junit.Assert.{assertArrayEquals, assertEquals, assertThat, assertTrue} +import org.junit.Assert.{assertEquals, assertThat} import org.junit.Test class TemporalTableJoinTest extends TableTestBase { @@ -106,7 +104,47 @@ class TemporalTableJoinTest extends TableTestBase { .select('o_amount * 'rate, 'secondary_key).as('rate, 'secondary_key) .join(thirdTable, 't3_secondary_key === 'secondary_key) - util.verifyTable(result, getExpectedComplexJoinPlan()) + util.verifyTable(result, binaryNode( + "DataStreamJoin", + unaryNode( + "DataStreamCalc", + binaryNode( + "DataStreamTemporalTableJoin", + unaryNode( + "DataStreamCalc", + streamTableNode(orders), + term("select", "o_rowtime, o_amount, o_currency, o_secondary_key") + ), + unaryNode( + "DataStreamCalc", + streamTableNode(ratesHistory), + term("select", "rowtime, currency, rate, secondary_key"), + term("where", ">(rate, 110)") + ), + term( + "where", + "AND(" + + s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + + "OR(=(currency, o_currency), =(secondary_key, o_secondary_key)))"), + term( + "join", + "o_rowtime", + "o_amount", + "o_currency", + "o_secondary_key", + "rowtime", + "currency", + "rate", + "secondary_key"), + term("joinType", "InnerJoin") + ), + term("select", "*(o_amount, rate) AS rate", "secondary_key") + ), + streamTableNode(thirdTable), + term("where", "=(t3_secondary_key, secondary_key)"), + term("join", "rate, secondary_key, t3_comment, t3_secondary_key"), + term("joinType", "InnerJoin") + )) } @Test @@ -176,16 +214,14 @@ 
class TemporalTableJoinTest extends TableTestBase { expectedSchema.toRowType, rates.getResultType) } -} -object TemporalTableJoinTest { def getExpectedSimpleJoinPlan(): String = { unaryNode( "DataStreamCalc", binaryNode( "DataStreamTemporalTableJoin", - streamTableNode(0), - streamTableNode(1), + streamTableNode(orders), + streamTableNode(ratesHistory), term("where", "AND(" + s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + @@ -202,10 +238,10 @@ object TemporalTableJoinTest { "DataStreamCalc", binaryNode( "DataStreamTemporalTableJoin", - streamTableNode(2), + streamTableNode(proctimeOrders), unaryNode( "DataStreamCalc", - streamTableNode(3), + streamTableNode(proctimeRatesHistory), term("select", "currency, rate")), term("where", "AND(" + @@ -218,57 +254,15 @@ object TemporalTableJoinTest { ) } - def getExpectedComplexJoinPlan(): String = { - binaryNode( - "DataStreamJoin", - unaryNode( - "DataStreamCalc", - binaryNode( - "DataStreamTemporalTableJoin", - unaryNode( - "DataStreamCalc", - streamTableNode(1), - term("select", "o_rowtime, o_amount, o_currency, o_secondary_key") - ), - unaryNode( - "DataStreamCalc", - streamTableNode(2), - term("select", "rowtime, currency, rate, secondary_key"), - term("where", ">(rate, 110)") - ), - term("where", - "AND(" + - s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " + - "OR(=(currency, o_currency), =(secondary_key, o_secondary_key)))"), - term("join", - "o_rowtime", - "o_amount", - "o_currency", - "o_secondary_key", - "rowtime", - "currency", - "rate", - "secondary_key"), - term("joinType", "InnerJoin") - ), - term("select", "*(o_amount, rate) AS rate", "secondary_key") - ), - streamTableNode(0), - term("where", "=(t3_secondary_key, secondary_key)"), - term("join", "rate, secondary_key, t3_comment, t3_secondary_key"), - term("joinType", "InnerJoin") - ) - } - def getExpectedTemporalTableFunctionOnTopOfQueryPlan(): String = { unaryNode( "DataStreamCalc", binaryNode( "DataStreamTemporalTableJoin", - streamTableNode(0), + streamTableNode(orders), unaryNode( "DataStreamCalc", - streamTableNode(1), + streamTableNode(ratesHistory), term("select", "currency", "*(rate, 2) AS rate", "rowtime"), term("where", ">(rate, 100)")), term("where", @@ -282,3 +276,4 @@ object TemporalTableJoinTest { ) } } + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/CorrelateStringExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/CorrelateStringExpressionTest.scala index 993ac5c1c251f7..e2068e7a31fd8a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/CorrelateStringExpressionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/table/stringexpr/CorrelateStringExpressionTest.scala @@ -25,6 +25,9 @@ import org.apache.flink.table.api.scala._ import org.apache.flink.table.utils._ import org.apache.flink.types.Row import org.junit.Test +import org.mockito.Mockito.{mock, when} +import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream} +import org.apache.flink.streaming.api.scala.DataStream class CorrelateStringExpressionTest extends TableTestBase { @@ -32,9 +35,15 @@ class CorrelateStringExpressionTest extends TableTestBase { def testCorrelateJoinsWithJoinLateral(): Unit = { val util = streamTestUtil() - val sTab = util.addTable[(Int, Long, String)]('a, 'b, 'c) val typeInfo = new 
RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING): _*) - val jTab = util.addJavaTable[Row](typeInfo,"MyTab","a, b, c") + val jDs = mock(classOf[JDataStream[Row]]) + when(jDs.getType).thenReturn(typeInfo) + + val sDs = mock(classOf[DataStream[Row]]) + when(sDs.javaStream).thenReturn(jDs) + + val jTab = util.javaTableEnv.fromDataStream(jDs, "a, b, c") + val sTab = util.tableEnv.fromDataStream(sDs, 'a, 'b, 'c) // test cross join val func1 = new TableFunc1 @@ -95,9 +104,15 @@ class CorrelateStringExpressionTest extends TableTestBase { def testFlatMap(): Unit = { val util = streamTestUtil() - val sTab = util.addTable[(Int, Long, String)]('a, 'b, 'c) val typeInfo = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING): _*) - val jTab = util.addJavaTable[Row](typeInfo,"MyTab","a, b, c") + val jDs = mock(classOf[JDataStream[Row]]) + when(jDs.getType).thenReturn(typeInfo) + + val sDs = mock(classOf[DataStream[Row]]) + when(sDs.javaStream).thenReturn(jDs) + + val jTab = util.javaTableEnv.fromDataStream(jDs, "a, b, c") + val sTab = util.tableEnv.fromDataStream(sDs, 'a, 'b, 'c) // test flatMap val func1 = new TableFunc1 diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/match/PatternTranslatorTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/match/PatternTranslatorTestBase.scala index 106fc7917f7d77..464c3b4a67d94a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/match/PatternTranslatorTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/match/PatternTranslatorTestBase.scala @@ -54,6 +54,7 @@ abstract class PatternTranslatorTestBase extends TestLogger{ val jDataStreamMock = mock(classOf[JDataStream[Row]]) when(dataStreamMock.javaStream).thenReturn(jDataStreamMock) when(jDataStreamMock.getType).thenReturn(typeInfo) + when(jDataStreamMock.getId).thenReturn(0) val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = StreamTableEnvironment.create(env).asInstanceOf[StreamTableEnvImpl] diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala index ce4de146c3bc35..5c0d7297360e31 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/ExpressionReductionRulesTest.scala @@ -31,7 +31,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceCalcExpressionForBatchSQL(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "(3+4)+a, " + @@ -51,7 +51,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "+(7, a) AS EXPR$0", "+(b, 3) AS EXPR$1", @@ -76,7 +76,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceProjectExpressionForBatchSQL(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "(3+4)+a, " + @@ -96,7 +96,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val 
expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "+(7, a) AS EXPR$0", "+(b, 3) AS EXPR$1", @@ -120,7 +120,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceFilterExpressionForBatchSQL(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "*" + @@ -128,7 +128,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "c"), term("where", ">(a, 8)") ) @@ -155,7 +155,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "13 AS _c0", "'b' AS _c1", @@ -191,7 +191,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "13 AS _c0", "'b' AS _c1", @@ -218,7 +218,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "a", "b", "c"), term("where", ">(a, 8)") ) @@ -229,7 +229,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceCalcExpressionForStreamSQL(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "(3+4)+a, " + @@ -249,7 +249,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "+(7, a) AS EXPR$0", "+(b, 3) AS EXPR$1", @@ -274,7 +274,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceProjectExpressionForStreamSQL(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "(3+4)+a, " + @@ -294,7 +294,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "+(7, a) AS EXPR$0", "+(b, 3) AS EXPR$1", @@ -318,7 +318,7 @@ class ExpressionReductionRulesTest extends TableTestBase { @Test def testReduceFilterExpressionForStreamSQL(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "*" + @@ -326,7 +326,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c"), term("where", ">(a, 8)") ) @@ -353,7 +353,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "13 AS _c0", "'b' AS _c1", @@ -389,7 +389,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "13 AS _c0", "'b' AS _c1", @@ -416,7 +416,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = 
unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c"), term("where", ">(a, 8)") ) @@ -428,7 +428,7 @@ class ExpressionReductionRulesTest extends TableTestBase { def testNestedTablesReductionStream(): Unit = { val util = streamTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val newTable = util.tableEnv.sqlQuery("SELECT 1 + 1 + a AS a FROM MyTable") @@ -437,7 +437,10 @@ class ExpressionReductionRulesTest extends TableTestBase { val sqlQuery = "SELECT a FROM NewTable" // 1+1 should be normalized to 2 - val expected = unaryNode("DataStreamCalc", streamTableNode(0), term("select", "+(2, a) AS a")) + val expected = unaryNode( + "DataStreamCalc", + streamTableNode(table), + term("select", "+(2, a) AS a")) util.verifySql(sqlQuery, expected) } @@ -446,7 +449,7 @@ class ExpressionReductionRulesTest extends TableTestBase { def testNestedTablesReductionBatch(): Unit = { val util = batchTestUtil() - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val table = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val newTable = util.tableEnv.sqlQuery("SELECT 1 + 1 + a AS a FROM MyTable") @@ -455,7 +458,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val sqlQuery = "SELECT a FROM NewTable" // 1+1 should be normalized to 2 - val expected = unaryNode("DataSetCalc", batchTableNode(0), term("select", "+(2, a) AS a")) + val expected = unaryNode("DataSetCalc", batchTableNode(table), term("select", "+(2, a) AS a")) util.verifySql(sqlQuery, expected) } @@ -472,7 +475,7 @@ class ExpressionReductionRulesTest extends TableTestBase { .where("d.isNull") .select('a, 'b, 'c) - val expected: String = streamTableNode(0) + val expected: String = streamTableNode(table) util.verifyTable(result, expected) } @@ -489,7 +492,7 @@ class ExpressionReductionRulesTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(table), term("select", "a", "b", "c"), term("where", s"IS NULL(NonDeterministicNullFunc$$())") ) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/NormalizationRulesTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/NormalizationRulesTest.scala index 8278a6921c3549..910b7194bc2a35 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/NormalizationRulesTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/NormalizationRulesTest.scala @@ -42,18 +42,20 @@ class NormalizationRulesTest extends TableTestBase { .build() util.tableEnv.getConfig.setPlannerConfig(cc) - util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "COUNT(DISTINCT a)" + "FROM MyTable group by b" + val streamNode = batchTableNode(t).replace("DataSetScan", "FlinkLogicalDataSetScan") + // expect double aggregate val expected = unaryNode("LogicalProject", unaryNode("LogicalAggregate", unaryNode("LogicalAggregate", unaryNode("LogicalProject", - values("LogicalTableScan", term("table", "[_DataSetTable_0]")), + streamNode, term("b", "$1"), term("a", "$0")), term("group", "{0, 1}")), term("group", "{0}"), term("EXPR$0", "COUNT($1)") @@ -76,19 +78,21 @@ class NormalizationRulesTest extends TableTestBase { .build() util.tableEnv.getConfig.setPlannerConfig(cc) - util.addTable[(Int, Long, 
String)]("MyTable", 'a, 'b, 'c) + val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) val sqlQuery = "SELECT " + "COUNT(DISTINCT a)" + "FROM MyTable group by b" + val streamNode = streamTableNode(t).replace("DataStreamScan", "FlinkLogicalDataStreamScan") + // expect double aggregate val expected = unaryNode( "LogicalProject", unaryNode("LogicalAggregate", unaryNode("LogicalAggregate", unaryNode("LogicalProject", - values("LogicalTableScan", term("table", "[_DataStreamTable_0]")), + streamNode, term("b", "$1"), term("a", "$0")), term("group", "{0, 1}")), term("group", "{0}"), term("EXPR$0", "COUNT($1)") diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala index 4245f2cddc2756..97eb7c7e49f57e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala @@ -29,8 +29,14 @@ class QueryDecorrelationTest extends TableTestBase { @Test def testCorrelationScalarAggAndFilter(): Unit = { val util = batchTestUtil() - util.addTable[(Int, String, String, Int, Int)]("emp", 'empno, 'ename, 'job, 'salary, 'deptno) - util.addTable[(Int, String)]("dept", 'deptno, 'name) + val table = util.addTable[(Int, String, String, Int, Int)]( + "emp", + 'empno, + 'ename, + 'job, + 'salary, + 'deptno) + val table1 = util.addTable[(Int, String)]("dept", 'deptno, 'name) val sql = "SELECT e1.empno\n" + "FROM emp e1, dept d1 where e1.deptno = d1.deptno\n" + @@ -47,13 +53,13 @@ class QueryDecorrelationTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "empno", "salary", "deptno"), term("where", "<(deptno, 10)") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "deptno"), term("where", "<(deptno, 15)") ), @@ -67,7 +73,7 @@ class QueryDecorrelationTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "empno", "salary"), term("where", "IS NOT NULL(empno)") ), @@ -87,8 +93,14 @@ class QueryDecorrelationTest extends TableTestBase { @Test def testDecorrelateWithMultiAggregate(): Unit = { val util = batchTestUtil() - util.addTable[(Int, String, String, Int, Int)]("emp", 'empno, 'ename, 'job, 'salary, 'deptno) - util.addTable[(Int, String)]("dept", 'deptno, 'name) + val table = util.addTable[(Int, String, String, Int, Int)]( + "emp", + 'empno, + 'ename, + 'job, + 'salary, + 'deptno) + val table1 = util.addTable[(Int, String)]("dept", 'deptno, 'name) val sql = "select sum(e1.empno) from emp e1, dept d1 " + "where e1.deptno = d1.deptno " + @@ -108,12 +120,12 @@ class QueryDecorrelationTest extends TableTestBase { "DataSetJoin", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "empno", "salary", "deptno") ), unaryNode( "DataSetCalc", - batchTableNode(1), + batchTableNode(table1), term("select", "deptno") ), term("where", "=(deptno, deptno0)"), @@ -126,7 +138,7 @@ class QueryDecorrelationTest extends TableTestBase { "DataSetAggregate", unaryNode( "DataSetCalc", - batchTableNode(0), + batchTableNode(table), term("select", "deptno", "salary"), term("where", "IS NOT NULL(deptno)") ), diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala index be93f7219f7a0f..6b7744d763d184 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala @@ -47,7 +47,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "FLOOR(CAST(rowtime)", "FLAG(DAY)) AS rowtime"), term("where", ">(long, 0)") ) @@ -64,7 +64,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "rowtime", "long", "int", "PROCTIME(proctime) AS proctime") ) @@ -83,7 +83,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "rowtime"), term("where", ">(CAST(rowtime), 1990-12-02 12:11:11)") ) @@ -106,7 +106,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "CAST(rowtime) AS rowtime", "long") ), term("groupBy", "rowtime"), @@ -133,7 +133,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "CAST(rowtime) AS rowtime", "long") ), term("groupBy", "long"), @@ -157,7 +157,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamCorrelate", - streamTableNode(0), + streamTableNode(t), term("invocation", s"${func.functionIdentifier}(CAST($$0):TIMESTAMP(3) NOT NULL, PROCTIME($$3), '')"), term("correlate", s"table(TableFunc(CAST(rowtime), PROCTIME(proctime), ''))"), @@ -186,7 +186,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(t), term("groupBy", "long"), term("window", "TumblingGroupWindow('w, 'rowtime, 100.millis)"), term("select", "long", "SUM(int) AS TMP_1", "end('w) AS TMP_0") @@ -208,12 +208,12 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamUnion", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "rowtime") ), unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "rowtime") ), term("all", "true"), @@ -244,7 +244,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(t), term("groupBy", "long"), term("window", "TumblingGroupWindow('w, 'rowtime, 100.millis)"), term("select", "long", "SUM(int) AS TMP_1", "rowtime('w) AS TMP_0") @@ -264,7 +264,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testGroupingOnProctime(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Int)]("MyTable" , 'long, 'int, 'proctime.proctime) + val t = util.addTable[(Long, Int)]("MyTable" , 'long, 'int, 'proctime.proctime) val result = util.tableEnv.sqlQuery("SELECT COUNT(long) FROM MyTable GROUP BY proctime") @@ -274,7 +274,7 @@ class 
TimeIndicatorConversionTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "PROCTIME(proctime) AS proctime", "long") ), term("groupBy", "proctime"), @@ -289,7 +289,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testAggregationOnProctime(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Int)]("MyTable" , 'long, 'int, 'proctime.proctime) + val t = util.addTable[(Long, Int)]("MyTable" , 'long, 'int, 'proctime.proctime) val result = util.tableEnv.sqlQuery("SELECT MIN(proctime) FROM MyTable GROUP BY long") @@ -299,7 +299,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamGroupAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "long", "PROCTIME(proctime) AS proctime") ), term("groupBy", "long"), @@ -314,7 +314,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testWindowSql(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int) + val t = util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int) val result = util.tableEnv.sqlQuery( "SELECT TUMBLE_END(rowtime, INTERVAL '0.1' SECOND) AS `rowtime`, `long`, " + @@ -325,7 +325,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamGroupWindowAggregate", - streamTableNode(0), + streamTableNode(t), term("groupBy", "long"), term("window", "TumblingGroupWindow('w$, 'rowtime, 100.millis)"), term("select", @@ -345,7 +345,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testWindowWithAggregationOnRowtime(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int) + val t = util.addTable[(Long, Long, Int)]("MyTable", 'rowtime.rowtime, 'long, 'int) val result = util.tableEnv.sqlQuery("SELECT MIN(rowtime), long FROM MyTable " + "GROUP BY long, TUMBLE(rowtime, INTERVAL '0.1' SECOND)") @@ -356,7 +356,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", - streamTableNode(0), + streamTableNode(t), term("select", "long", "rowtime", "CAST(rowtime) AS rowtime0") ), term("groupBy", "long"), @@ -483,7 +483,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testMatchRecognizeRowtimeMaterialization(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Long, Int)]( + val t = util.addTable[(Long, Long, Int)]( "RowtimeTicker", 'rowtime.rowtime, 'symbol, @@ -511,7 +511,7 @@ class TimeIndicatorConversionTest extends TableTestBase { val expected = unaryNode( "DataStreamMatch", - streamTableNode(0), + streamTableNode(t), term("partitionBy", "symbol"), term("orderBy", "rowtime ASC"), term("measures", @@ -531,7 +531,7 @@ class TimeIndicatorConversionTest extends TableTestBase { @Test def testMatchRecognizeProctimeMaterialization(): Unit = { val util = streamTestUtil() - util.addTable[(Long, Long, Int)]( + val t = util.addTable[(Long, Long, Int)]( "ProctimeTicker", 'rowtime.rowtime, 'symbol, @@ -562,7 +562,7 @@ class TimeIndicatorConversionTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamMatch", - streamTableNode(0), + streamTableNode(t), term("partitionBy", "symbol"), term("orderBy", "rowtime ASC"), term("measures", diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala index eb571d504a5271..37e340af220d49 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableSourceITCase.scala @@ -74,6 +74,31 @@ class TableSourceITCase extends AbstractTestBase { // info. } + @Test + def testUnregisteredCsvTableSource(): Unit = { + + val csvTable = CommonTestData.getCsvTableSource + StreamITCase.testResults = mutable.MutableList() + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = StreamTableEnvironment.create(env) + + tEnv.fromTableSource(csvTable) + .where('id > 4) + .select('last, 'score * 2) + .toAppendStream[Row] + .addSink(new StreamITCase.StringSink[Row]) + + env.execute() + + val expected = Seq( + "Williams,69.0", + "Miller,13.56", + "Smith,180.2", + "Williams,4.68") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + @Test def testCsvTableSource(): Unit = { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala index 1d6b330252915d..db0ec531ff4cbf 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala @@ -28,20 +28,19 @@ import org.apache.flink.streaming.api.environment.LocalStreamEnvironment import org.apache.flink.streaming.api.functions.source.SourceFunction import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.java.{BatchTableEnvImpl => JavaBatchTableEnvImpl, StreamTableEnvImpl => JavaStreamTableEnvImpl} -import org.apache.flink.table.api.scala.{BatchTableEnvironment => ScalaBatchTableEnv, StreamTableEnvironment => ScalaStreamTableEnv} -import org.apache.flink.table.api.scala.{BatchTableEnvImpl => ScalaBatchTableEnvImpl, StreamTableEnvImpl => ScalaStreamTableEnvImpl} -import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.{BatchTableEnvImpl => ScalaBatchTableEnvImpl, _} import org.apache.flink.table.api.{Table, TableConfig, TableImpl, TableSchema} -import org.apache.flink.table.catalog.{CatalogManager, GenericCatalogDatabase, GenericInMemoryCatalog} +import org.apache.flink.table.catalog.{CatalogManager, GenericInMemoryCatalog} import org.apache.flink.table.expressions.Expression import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction} -import org.apache.flink.table.utils.TableTestUtil.{createCatalogManager, extractBuiltinPath} +import org.apache.flink.table.operations.{CatalogTableOperation, DataSetTableOperation, DataStreamTableOperation} +import org.apache.flink.table.utils.TableTestUtil.createCatalogManager import org.junit.Assert.assertEquals -import org.junit.{ComparisonFailure, Rule} import org.junit.rules.ExpectedException +import org.junit.{ComparisonFailure, Rule} import org.mockito.Mockito.{mock, when} -import util.control.Breaks._ +import scala.util.control.Breaks._ /** * Test base for testing Table API / SQL plans. 
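The TableTestUtil hunks that follow replace the index-based plan helpers (batchTableNode(idx) / streamTableNode(idx)) with overloads that take the Table itself and derive the scan identifier from the backing DataSet or DataStream. As a minimal sketch of that naming pattern only — StreamTableLike and DataStreamLike are hypothetical stand-ins, not Flink types, so the snippet compiles without the planner on the classpath:

object ScanNodeStringSketch {
  // Hypothetical stand-ins for Flink's Table and DataStream types; they exist only
  // so this sketch is self-contained.
  trait DataStreamLike { def getId: Int }
  trait StreamTableLike {
    def dataStream: DataStreamLike
    def fieldNames: Seq[String]
  }

  // Mirrors the shape of the new streamTableNode(table) helper in the hunk below:
  // the expected plan string is built from the backing stream's id and the table's
  // field names, rather than from a positional _DataStreamTable_<idx> name.
  def streamTableNode(table: StreamTableLike): String =
    s"DataStreamScan(id=[${table.dataStream.getId}], " +
      s"fields=[${table.fieldNames.mkString(", ")}])"

  def main(args: Array[String]): Unit = {
    val table = new StreamTableLike {
      override val dataStream: DataStreamLike = new DataStreamLike { override def getId: Int = 1 }
      override def fieldNames: Seq[String] = Seq("a", "b", "c")
    }
    println(streamTableNode(table)) // prints: DataStreamScan(id=[1], fields=[a, b, c])
  }
}

The batch-side helper in the same hunk follows the same idea but keys the scan on System.identityHashCode of the underlying DataSet (ref=[...]) instead of a stream id.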
@@ -108,16 +107,8 @@ abstract class TableTestUtil(verifyCatalogPath: Boolean = false) { // depends on the native machine (Little/Big Endian) val actualNoCharset = actual.replace("_UTF-16LE'", "'").replace("_UTF-16BE'", "'") - // majority of tests did not assume existence of Catalog API. - // this enables disabling catalog path verification - val actualWithAdjustedPath = if (!verifyCatalogPath) { - actualNoCharset.replaceAll("default_catalog, default_database, ", "") - } else { - actualNoCharset - } - val expectedLines = expected.split("\n").map(_.trim) - val actualLines = actualWithAdjustedPath.split("\n").map(_.trim) + val actualLines = actualNoCharset.split("\n").map(_.trim) val expectedMessage = expectedLines.mkString("\n") val actualMessage = actualLines.mkString("\n") @@ -210,12 +201,16 @@ object TableTestUtil { term("tuples", "[" + listValues.mkString(", ") + "]") } - def batchTableNode(idx: Int): String = { - s"DataSetScan(table=[[_DataSetTable_$idx]])" + def batchTableNode(table: Table): String = { + val dataSetTable = table.getTableOperation.asInstanceOf[DataSetTableOperation[_]] + s"DataSetScan(ref=[${System.identityHashCode(dataSetTable.getDataSet)}], " + + s"fields=[${dataSetTable.getTableSchema.getFieldNames.mkString(", ")}])" } - def streamTableNode(idx: Int): String = { - s"DataStreamScan(table=[[_DataStreamTable_$idx]])" + def streamTableNode(table: Table): String = { + val dataStreamTable = table.getTableOperation.asInstanceOf[DataStreamTableOperation[_]] + s"DataStreamScan(id=[${dataStreamTable.getDataStream.getId}], " + + s"fields=[${dataStreamTable.getTableSchema.getFieldNames.mkString(", ")}])" } } diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testFilter0.out b/flink-table/flink-table-planner/src/test/scala/resources/testFilter0.out index c5a4c9fd755e97..81d2b6b4562d4c 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testFilter0.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testFilter0.out @@ -1,10 +1,10 @@ == Abstract Syntax Tree == LogicalFilter(condition=[=(MOD($0, 2), 0)]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) + %logicalSourceNode0% == Optimized Logical Plan == DataSetCalc(select=[a, b], where=[=(MOD(a, 2), 0)]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) + %sourceNode0% == Physical Execution Plan == Stage 3 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testFilter1.out b/flink-table/flink-table-planner/src/test/scala/resources/testFilter1.out index 88381f45d7a94f..9eb5c7aae4c6cb 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testFilter1.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testFilter1.out @@ -1,10 +1,10 @@ == Abstract Syntax Tree == LogicalFilter(condition=[=(MOD($0, 2), 0)]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) + %logicalSourceNode0% == Optimized Logical Plan == DataSetCalc(select=[a, b], where=[=(MOD(a, 2), 0)]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) + %sourceNode0% == Physical Execution Plan == Stage 3 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testFilterStream0.out b/flink-table/flink-table-planner/src/test/scala/resources/testFilterStream0.out index 3a1088a65702c8..cfedce0ae7f707 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testFilterStream0.out +++ 
b/flink-table/flink-table-planner/src/test/scala/resources/testFilterStream0.out @@ -1,10 +1,10 @@ == Abstract Syntax Tree == LogicalFilter(condition=[=(MOD($0, 2), 0)]) - LogicalTableScan(table=[[default_catalog, default_database, _DataStreamTable_0]]) + %logicalSourceNode0% == Optimized Logical Plan == DataStreamCalc(select=[a, b], where=[=(MOD(a, 2), 0)]) - DataStreamScan(table=[[default_catalog, default_database, _DataStreamTable_0]]) + %sourceNode0% == Physical Execution Plan == Stage 1 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testJoin0.out b/flink-table/flink-table-planner/src/test/scala/resources/testJoin0.out index 0d995ac6f07513..986d21a5832cb1 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testJoin0.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testJoin0.out @@ -2,14 +2,14 @@ LogicalProject(a=[$0], c=[$2]) LogicalFilter(condition=[=($1, $3)]) LogicalJoin(condition=[true], joinType=[inner]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %logicalSourceNode0% + %logicalSourceNode1% == Optimized Logical Plan == DataSetCalc(select=[a, c]) DataSetJoin(where=[=(b, d)], join=[a, b, c, d], joinType=[InnerJoin]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %sourceNode0% + %sourceNode1% == Physical Execution Plan == Stage 4 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testJoin1.out b/flink-table/flink-table-planner/src/test/scala/resources/testJoin1.out index cd9597547d574d..0db76f94d0f2e7 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testJoin1.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testJoin1.out @@ -2,14 +2,14 @@ LogicalProject(a=[$0], c=[$2]) LogicalFilter(condition=[=($1, $3)]) LogicalJoin(condition=[true], joinType=[inner]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %logicalSourceNode0% + %logicalSourceNode1% == Optimized Logical Plan == DataSetCalc(select=[a, c]) DataSetJoin(where=[=(b, d)], join=[a, b, c, d], joinType=[InnerJoin]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %sourceNode0% + %sourceNode1% == Physical Execution Plan == Stage 4 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testUnion0.out b/flink-table/flink-table-planner/src/test/scala/resources/testUnion0.out index 541efd708a1a5a..dbee6093bb9d2a 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testUnion0.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testUnion0.out @@ -1,12 +1,12 @@ == Abstract Syntax Tree == LogicalUnion(all=[true]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %logicalSourceNode0% + %logicalSourceNode1% == Optimized Logical Plan == DataSetUnion(all=[true], union=[count, word]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %sourceNode0% + %sourceNode1% == Physical Execution Plan == 
Stage 3 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testUnion1.out b/flink-table/flink-table-planner/src/test/scala/resources/testUnion1.out index 63d5865d761f99..cd12e1d4598c93 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testUnion1.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testUnion1.out @@ -1,12 +1,12 @@ == Abstract Syntax Tree == LogicalUnion(all=[true]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - LogicalTableScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %logicalSourceNode0% + %logicalSourceNode1% == Optimized Logical Plan == DataSetUnion(all=[true], union=[count, word]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_0]]) - DataSetScan(table=[[default_catalog, default_database, _DataSetTable_1]]) + %sourceNode0% + %sourceNode1% == Physical Execution Plan == Stage 3 : Data Source diff --git a/flink-table/flink-table-planner/src/test/scala/resources/testUnionStream0.out b/flink-table/flink-table-planner/src/test/scala/resources/testUnionStream0.out index 4affe46b366c0e..e8fb70b23fe736 100644 --- a/flink-table/flink-table-planner/src/test/scala/resources/testUnionStream0.out +++ b/flink-table/flink-table-planner/src/test/scala/resources/testUnionStream0.out @@ -1,12 +1,12 @@ == Abstract Syntax Tree == LogicalUnion(all=[true]) - LogicalTableScan(table=[[default_catalog, default_database, _DataStreamTable_0]]) - LogicalTableScan(table=[[default_catalog, default_database, _DataStreamTable_1]]) + %logicalSourceNode0% + %logicalSourceNode1% == Optimized Logical Plan == DataStreamUnion(all=[true], union all=[count, word]) - DataStreamScan(table=[[default_catalog, default_database, _DataStreamTable_0]]) - DataStreamScan(table=[[default_catalog, default_database, _DataStreamTable_1]]) + %sourceNode0% + %sourceNode1% == Physical Execution Plan == Stage 1 : Data Source From 99d4fd39f2b44630130d47b16ebcd69923f09e9a Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Fri, 24 May 2019 12:55:48 +0200 Subject: [PATCH 63/92] [hotfix][table-api] Port TableSource specific descriptors --- .../flink/table/typeutils/FieldInfoUtils.java | 23 +++-- .../sources/DefinedProctimeAttribute.java} | 38 ++++---- .../sources/DefinedRowtimeAttributes.java} | 41 ++++---- .../sources/RowtimeAttributeDescriptor.java | 77 +++++++++++++++ .../table/sources/definedTimeAttributes.scala | 95 ------------------- 5 files changed, 133 insertions(+), 141 deletions(-) rename flink-table/{flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala => flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedProctimeAttribute.java} (51%) rename flink-table/{flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/BatchTableSourceTable.scala => flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedRowtimeAttributes.java} (50%) create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/RowtimeAttributeDescriptor.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/definedTimeAttributes.scala diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java index a0a3cc6e03c180..97c341c05799b7 100644 --- 
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/typeutils/FieldInfoUtils.java @@ -49,7 +49,7 @@ import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.TIME_ATTRIBUTES; /** - * Utility classes for extracting names and indices of fields from different {@link TypeInformation}s. + * Utility methods for extracting names and indices of fields from different {@link TypeInformation}s. */ public class FieldInfoUtils { @@ -234,8 +234,19 @@ public static TypeInformation[] getFieldTypes(TypeInformation inputType) { return fieldTypes; } - public static TableSchema calculateTableSchema( - TypeInformation typeInfo, + /** + * Derives {@link TableSchema} out of a {@link TypeInformation}. It is complementary to other + * methods in this class. This also performs translation from time indicator markers such as + * {@link TimeIndicatorTypeInfo#ROWTIME_STREAM_MARKER} etc. to a corresponding + * {@link TimeIndicatorTypeInfo}. + * + * @param typeInfo input type info to calculate fields type infos from + * @param fieldIndexes indices within the typeInfo of the resulting Table schema + * @param fieldNames names of the fields of the resulting schema + * @return calculates resulting schema + */ + public static TableSchema calculateTableSchema( + TypeInformation typeInfo, int[] fieldIndexes, String[] fieldNames) { @@ -292,6 +303,8 @@ public static TableSchema calculateTableSchema( return new TableSchema(fieldNames, types); } + /* Utility methods */ + private static Optional> extractTimeMarkerType(int idx) { switch (idx) { case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER: @@ -306,10 +319,6 @@ private static Optional> extractTimeMarkerType(int idx) { } } - - - /* Utility methods */ - private static Set extractFieldInfoFromAtomicType(Expression[] exprs) { boolean referenced = false; FieldInfo fieldInfo = null; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedProctimeAttribute.java similarity index 51% rename from flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala rename to flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedProctimeAttribute.java index 38b7df851e5420..803ca6ab1c88c3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/StreamTableSourceTable.scala +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedProctimeAttribute.java @@ -16,27 +16,27 @@ * limitations under the License. 
*/ -package org.apache.flink.table.plan.schema +package org.apache.flink.table.sources; -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.sources.{StreamTableSource, TableSourceUtil} +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.table.api.TableSchema; -class StreamTableSourceTable[T]( - tableSource: StreamTableSource[T], - statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) - extends TableSourceTable[T]( - tableSource, - statistic) { +import javax.annotation.Nullable; - TableSourceUtil.validateTableSource(tableSource) +/** + * Extends a {@link TableSource} to specify a processing time attribute. + */ +@PublicEvolving +public interface DefinedProctimeAttribute { - def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - TableSourceUtil.getRelDataType( - tableSource, - None, - streaming = true, - typeFactory.asInstanceOf[FlinkTypeFactory]) - } + /** + * Returns the name of a processing time attribute or null if no processing time attribute is + * present. + * + *

The referenced attribute must be present in the {@link TableSchema} of the {@link TableSource} and of + * type {@link Types#SQL_TIMESTAMP}. + */ + @Nullable + String getProctimeAttribute(); } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/BatchTableSourceTable.scala b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedRowtimeAttributes.java similarity index 50% rename from flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/BatchTableSourceTable.scala rename to flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedRowtimeAttributes.java index 14e8cb129b6f3d..5238084ccde1b1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/BatchTableSourceTable.scala +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedRowtimeAttributes.java @@ -16,28 +16,29 @@ * limitations under the License. */ -package org.apache.flink.table.plan.schema +package org.apache.flink.table.sources; -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.sources.{BatchTableSource, TableSourceUtil} +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.table.api.TableSchema; -class BatchTableSourceTable[T]( - tableSource: BatchTableSource[T], - statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) - extends TableSourceTable[T]( - tableSource, - statistic) { +import java.util.List; - TableSourceUtil.validateTableSource(tableSource) +/** + * Extends a {@link TableSource} to specify rowtime attributes via a + * {@link RowtimeAttributeDescriptor}. + */ +@PublicEvolving +public interface DefinedRowtimeAttributes { - override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - TableSourceUtil.getRelDataType( - tableSource, - None, - streaming = false, - typeFactory.asInstanceOf[FlinkTypeFactory]) - } + /** + * Returns a list of {@link RowtimeAttributeDescriptor} for all rowtime + * attributes of the table. + * + *

All referenced attributes must be present in the {@link TableSchema} + * of the {@link TableSource} and of type {@link Types#SQL_TIMESTAMP}. + * + * @return A list of {@link RowtimeAttributeDescriptor}. + */ + List getRowtimeAttributeDescriptors(); } - diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/RowtimeAttributeDescriptor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/RowtimeAttributeDescriptor.java new file mode 100644 index 00000000000000..6851dd2e6d8a93 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/RowtimeAttributeDescriptor.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources; + +import org.apache.flink.table.sources.tsextractors.TimestampExtractor; +import org.apache.flink.table.sources.wmstrategies.WatermarkStrategy; + +import java.util.Objects; + +/** + * Describes a rowtime attribute of a {@link TableSource}. + */ +public final class RowtimeAttributeDescriptor { + + private final String attributeName; + private final TimestampExtractor timestampExtractor; + private final WatermarkStrategy watermarkStrategy; + + public RowtimeAttributeDescriptor( + String attributeName, + TimestampExtractor timestampExtractor, + WatermarkStrategy watermarkStrategy) { + this.attributeName = attributeName; + this.timestampExtractor = timestampExtractor; + this.watermarkStrategy = watermarkStrategy; + } + + /** Returns the name of the rowtime attribute. */ + public String getAttributeName() { + return attributeName; + } + + /** Returns the [[TimestampExtractor]] for the attribute. */ + public TimestampExtractor getTimestampExtractor() { + return timestampExtractor; + } + + /** Returns the [[WatermarkStrategy]] for the attribute. 
*/ + public WatermarkStrategy getWatermarkStrategy() { + return watermarkStrategy; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RowtimeAttributeDescriptor that = (RowtimeAttributeDescriptor) o; + return Objects.equals(attributeName, that.attributeName) && + Objects.equals(timestampExtractor, that.timestampExtractor) && + Objects.equals(watermarkStrategy, that.watermarkStrategy); + } + + @Override + public int hashCode() { + return Objects.hash(attributeName, timestampExtractor, watermarkStrategy); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/definedTimeAttributes.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/definedTimeAttributes.scala deleted file mode 100644 index b144312caa5bd6..00000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/definedTimeAttributes.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.sources - -import java.util -import java.util.Objects -import javax.annotation.Nullable - -import org.apache.flink.table.api.TableSchema -import org.apache.flink.table.api.Types -import org.apache.flink.table.sources.tsextractors.TimestampExtractor -import org.apache.flink.table.sources.wmstrategies.WatermarkStrategy - -/** - * Extends a [[TableSource]] to specify a processing time attribute. - */ -trait DefinedProctimeAttribute { - - /** - * Returns the name of a processing time attribute or null if no processing time attribute is - * present. - * - * The referenced attribute must be present in the [[TableSchema]] of the [[TableSource]] and of - * type [[Types.SQL_TIMESTAMP]]. - */ - @Nullable - def getProctimeAttribute: String -} - -/** - * Extends a [[TableSource]] to specify rowtime attributes via a - * [[RowtimeAttributeDescriptor]]. - */ -trait DefinedRowtimeAttributes { - - /** - * Returns a list of [[RowtimeAttributeDescriptor]] for all rowtime attributes of the table. - * - * All referenced attributes must be present in the [[TableSchema]] of the [[TableSource]] and of - * type [[Types.SQL_TIMESTAMP]]. - * - * @return A list of [[RowtimeAttributeDescriptor]]. - */ - def getRowtimeAttributeDescriptors: util.List[RowtimeAttributeDescriptor] -} - -/** - * Describes a rowtime attribute of a [[TableSource]]. - * - * @param attributeName The name of the rowtime attribute. - * @param timestampExtractor The timestamp extractor to derive the values of the attribute. - * @param watermarkStrategy The watermark strategy associated with the attribute. 
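As a usage illustration for the two descriptor interfaces ported above, here is a minimal, hypothetical sketch. The class name OrderSource, the field names ts/rowtime/proctime, and the choice of ExistingField and AscendingTimestamps are assumptions for the example only; a real implementation would also implement StreamTableSource and expose the attributes in its TableSchema as SQL_TIMESTAMP fields.

    import java.util.Collections;
    import java.util.List;

    import org.apache.flink.table.sources.DefinedProctimeAttribute;
    import org.apache.flink.table.sources.DefinedRowtimeAttributes;
    import org.apache.flink.table.sources.RowtimeAttributeDescriptor;
    import org.apache.flink.table.sources.tsextractors.ExistingField;
    import org.apache.flink.table.sources.wmstrategies.AscendingTimestamps;

    // Sketch only: a real source would additionally implement StreamTableSource.
    public class OrderSource implements DefinedRowtimeAttributes, DefinedProctimeAttribute {

        @Override
        public List<RowtimeAttributeDescriptor> getRowtimeAttributeDescriptors() {
            return Collections.singletonList(
                new RowtimeAttributeDescriptor(
                    "rowtime",                     // attribute name in the TableSchema
                    new ExistingField("ts"),       // derive its value from an existing field
                    new AscendingTimestamps()));   // watermark strategy for the attribute
        }

        @Override
        public String getProctimeAttribute() {
            // may return null if no processing time attribute is exposed
            return "proctime";
        }
    }
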
- */ -class RowtimeAttributeDescriptor( - val attributeName: String, - val timestampExtractor: TimestampExtractor, - val watermarkStrategy: WatermarkStrategy) { - - /** Returns the name of the rowtime attribute. */ - def getAttributeName: String = attributeName - - /** Returns the [[TimestampExtractor]] for the attribute. */ - def getTimestampExtractor: TimestampExtractor = timestampExtractor - - /** Returns the [[WatermarkStrategy]] for the attribute. */ - def getWatermarkStrategy: WatermarkStrategy = watermarkStrategy - - override def equals(other: Any): Boolean = other match { - case that: RowtimeAttributeDescriptor => - Objects.equals(attributeName, that.attributeName) && - Objects.equals(timestampExtractor, that.timestampExtractor) && - Objects.equals(watermarkStrategy, that.watermarkStrategy) - case _ => false - } - - override def hashCode(): Int = { - Objects.hash(attributeName, timestampExtractor, watermarkStrategy) - } -} From 668259c70c3adfa3f073dfedabb1921f620b3177 Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Fri, 24 May 2019 13:41:21 +0200 Subject: [PATCH 64/92] [hotfix][table-planner] Fix getting external catalog --- .../main/scala/org/apache/flink/table/api/TableEnvImpl.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala index 7f329dba4361fd..6db31eefec6e76 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala @@ -83,9 +83,6 @@ abstract class TableEnvImpl( // a counter for unique attribute names private[flink] val attrNameCntr: AtomicInteger = new AtomicInteger(0) - // registered external catalog names -> catalog - private val externalCatalogs = new mutable.HashMap[String, ExternalCatalog] - private[flink] val operationTreeBuilder = new OperationTreeBuilder(this) private val planningConfigurationBuilder: PlanningConfigurationBuilder = @@ -329,7 +326,7 @@ abstract class TableEnvImpl( } override def getRegisteredExternalCatalog(name: String): ExternalCatalog = { - this.externalCatalogs.get(name) match { + JavaScalaConversionUtil.toScala(catalogManager.getExternalCatalog(name)) match { case Some(catalog) => catalog case None => throw new ExternalCatalogNotExistException(name) } From 88c7d82a6148a491d9c11f8d4c7eedc1722ee7a9 Mon Sep 17 00:00:00 2001 From: Dawid Wysakowicz Date: Fri, 24 May 2019 14:10:29 +0200 Subject: [PATCH 65/92] [FLINK-12604][table] Register a TableSource/Sink directly as CatalogTables This closes #8549 --- .../flink/table/catalog/CatalogManager.java | 80 ++++-- .../table/catalog/ConnectorCatalogTable.java | 163 ++++++++++++ .../catalog/TableOperationCatalogView.java | 2 +- .../TableOperationDefaultVisitor.java | 5 + .../operations/TableOperationVisitor.java | 2 + .../operations/TableSourceTableOperation.java | 75 ++++++ .../table/catalog/CalciteCatalogTable.java | 88 ------- .../table/catalog/DatabaseCalciteSchema.java | 16 +- .../table/plan/TableOperationConverter.java | 34 +++ .../flink/table/api/BatchTableEnvImpl.scala | 153 +---------- .../flink/table/api/StreamTableEnvImpl.scala | 142 +---------- .../apache/flink/table/api/TableEnvImpl.scala | 239 ++++++++++++------ .../flink/table/calcite/FlinkRelBuilder.scala | 8 +- .../table/catalog/ExternalCatalogSchema.scala | 2 +- 
.../table/catalog/ExternalTableUtil.scala | 48 +--- .../plan/nodes/PhysicalTableSourceScan.scala | 14 +- .../nodes/dataset/BatchTableSourceScan.scala | 2 +- .../datastream/StreamTableSourceScan.scala | 2 +- .../logical/FlinkLogicalTableSourceScan.scala | 21 +- .../dataSet/BatchTableSourceScanRule.scala | 10 +- .../StreamTableSourceScanRule.scala | 10 +- .../table/plan/schema/TableSinkTable.scala | 49 ---- .../plan/schema/TableSourceSinkTable.scala | 67 ----- .../table/plan/schema/TableSourceTable.scala | 32 ++- .../catalog/CatalogStructureBuilder.java | 45 ++-- .../table/catalog/PathResolutionTest.java | 3 +- .../table/JavaTableEnvironmentITCase.java | 34 ++- .../catalog/ExternalCatalogSchemaTest.scala | 11 +- 28 files changed, 634 insertions(+), 723 deletions(-) create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableSourceTableOperation.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/CalciteCatalogTable.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceSinkTable.scala diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java index 975c87cda9912f..5704043d02afe7 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java @@ -24,7 +24,6 @@ import org.apache.flink.table.catalog.exceptions.CatalogException; import org.apache.flink.table.catalog.exceptions.TableNotExistException; import org.apache.flink.table.factories.TableFactoryUtil; -import org.apache.flink.table.operations.CatalogTableOperation; import org.apache.flink.table.sinks.TableSink; import org.apache.flink.util.StringUtils; @@ -64,6 +63,57 @@ public class CatalogManager { private String currentDatabaseName; + /** + * Temporary solution to handle both {@link CatalogBaseTable} and + * {@link ExternalCatalogTable} in a single call. 
+ */ + public static class ResolvedTable { + private final ExternalCatalogTable externalCatalogTable; + private final CatalogBaseTable catalogTable; + private final TableSchema tableSchema; + private final List tablePath; + + static ResolvedTable externalTable( + List tablePath, + ExternalCatalogTable table, + TableSchema tableSchema) { + return new ResolvedTable(table, null, tableSchema, tablePath); + } + + static ResolvedTable catalogTable( + List tablePath, + CatalogBaseTable table) { + return new ResolvedTable(null, table, table.getSchema(), tablePath); + } + + private ResolvedTable( + ExternalCatalogTable externalCatalogTable, + CatalogBaseTable catalogTable, + TableSchema tableSchema, + List tablePath) { + this.externalCatalogTable = externalCatalogTable; + this.catalogTable = catalogTable; + this.tableSchema = tableSchema; + this.tablePath = tablePath; + } + + public Optional getExternalCatalogTable() { + return Optional.ofNullable(externalCatalogTable); + } + + public Optional getCatalogTable() { + return Optional.ofNullable(catalogTable); + } + + public TableSchema getTableSchema() { + return tableSchema; + } + + public List getTablePath() { + return tablePath; + } + } + public CatalogManager(String defaultCatalogName, Catalog defaultCatalog) { checkArgument( !StringUtils.isNullOrWhitespaceOnly(defaultCatalogName), @@ -243,8 +293,8 @@ public void setCurrentDatabase(String databaseName) { } /** - * Tries to resolve a table path to a {@link CatalogTableOperation}. The algorithm looks for requested table - * in following paths in that order: + * Tries to resolve a table path to a {@link ResolvedTable}. The algorithm looks for requested table + * in the following paths in that order: *

 * <ol>
 * <li>{@code [current-catalog].[current-database].[tablePath]}</li>
 * <li>{@code [current-catalog].[tablePath]}</li>
 * <li>{@code [tablePath]}</li>
 * </ol>
@@ -252,10 +302,10 @@ public void setCurrentDatabase(String databaseName) { *
* * @param tablePath table path to look for - * @return {@link CatalogTableOperation} containing both fully qualified table identifier and its - * {@link TableSchema}. + * @return {@link ResolvedTable} wrapping original table with additional information about table path and + * unified access to {@link TableSchema}. */ - public Optional resolveTable(String... tablePath) { + public Optional resolveTable(String... tablePath) { checkArgument(tablePath != null && tablePath.length != 0, "Table path must not be null or empty."); List userPath = asList(tablePath); @@ -267,7 +317,7 @@ public Optional resolveTable(String... tablePath) { ); for (List prefix : prefixes) { - Optional potentialTable = lookupPath(prefix, userPath); + Optional potentialTable = lookupPath(prefix, userPath); if (potentialTable.isPresent()) { return potentialTable; } @@ -276,12 +326,12 @@ public Optional resolveTable(String... tablePath) { return Optional.empty(); } - private Optional lookupPath(List prefix, List userPath) { + private Optional lookupPath(List prefix, List userPath) { try { List path = new ArrayList<>(prefix); path.addAll(userPath); - Optional potentialTable = lookupCatalogTable(path); + Optional potentialTable = lookupCatalogTable(path); if (!potentialTable.isPresent()) { potentialTable = lookupExternalTable(path); @@ -292,7 +342,7 @@ private Optional lookupPath(List prefix, List lookupCatalogTable(List path) throws TableNotExistException { + private Optional lookupCatalogTable(List path) throws TableNotExistException { if (path.size() == 3) { Catalog currentCatalog = catalogs.get(path.get(0)); String currentDatabaseName = path.get(1); @@ -300,22 +350,22 @@ private Optional lookupCatalogTable(List path) th ObjectPath objectPath = new ObjectPath(currentDatabaseName, tableName); if (currentCatalog != null && currentCatalog.tableExists(objectPath)) { - TableSchema tableSchema = currentCatalog.getTable(objectPath).getSchema(); - return Optional.of(new CatalogTableOperation( + CatalogBaseTable table = currentCatalog.getTable(objectPath); + return Optional.of(ResolvedTable.catalogTable( asList(path.get(0), currentDatabaseName, tableName), - tableSchema)); + table)); } } return Optional.empty(); } - private Optional lookupExternalTable(List path) { + private Optional lookupExternalTable(List path) { ExternalCatalog currentCatalog = externalCatalogs.get(path.get(0)); return Optional.ofNullable(currentCatalog) .flatMap(externalCatalog -> extractPath(externalCatalog, path.subList(1, path.size() - 1))) .map(finalCatalog -> finalCatalog.getTable(path.get(path.size() - 1))) - .map(table -> new CatalogTableOperation(path, getTableSchema(table))); + .map(table -> ResolvedTable.externalTable(path, table, getTableSchema(table))); } private Optional extractPath(ExternalCatalog rootExternalCatalog, List path) { diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java new file mode 100644 index 00000000000000..ce860842349e0a --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.sinks.TableSink; +import org.apache.flink.table.sources.DefinedProctimeAttribute; +import org.apache.flink.table.sources.DefinedRowtimeAttributes; +import org.apache.flink.table.sources.RowtimeAttributeDescriptor; +import org.apache.flink.table.sources.TableSource; +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A {@link CatalogTable} that wraps a {@link TableSource} and/or {@link TableSink}. + * This allows registering those in a {@link Catalog}. It can not be persisted as the + * source and/or sink might be inline implementations and not be representable in a + * property based form. + * + * @param type of the produced elements by the {@link TableSource} + * @param type of the expected elements by the {@link TableSink} + */ +@Internal +public class ConnectorCatalogTable extends AbstractCatalogTable { + private final TableSource tableSource; + private final TableSink tableSink; + private final boolean isBatch; + + public static ConnectorCatalogTable source(TableSource source, boolean isBatch) { + final TableSchema tableSchema = calculateSourceSchema(source, isBatch); + return new ConnectorCatalogTable<>(source, null, tableSchema, isBatch); + } + + public static ConnectorCatalogTable sink(TableSink sink, boolean isBatch) { + TableSchema tableSchema = new TableSchema(sink.getFieldNames(), sink.getFieldTypes()); + return new ConnectorCatalogTable<>(null, sink, tableSchema, isBatch); + } + + public static ConnectorCatalogTable sourceAndSink( + TableSource source, + TableSink sink, + boolean isBatch) { + TableSchema tableSchema = calculateSourceSchema(source, isBatch); + return new ConnectorCatalogTable<>(source, sink, tableSchema, isBatch); + } + + @VisibleForTesting + protected ConnectorCatalogTable( + TableSource tableSource, + TableSink tableSink, + TableSchema tableSchema, + boolean isBatch) { + super(tableSchema, Collections.emptyMap(), ""); + this.tableSource = tableSource; + this.tableSink = tableSink; + this.isBatch = isBatch; + } + + public Optional> getTableSource() { + return Optional.ofNullable(tableSource); + } + + public Optional> getTableSink() { + return Optional.ofNullable(tableSink); + } + + public boolean isBatch() { + return isBatch; + } + + @Override + public Map toProperties() { + // This effectively makes sure the table cannot be persisted in a catalog. 
+ throw new UnsupportedOperationException("ConnectorCatalogTable cannot be converted to properties"); + } + + @Override + public CatalogBaseTable copy() { + return this; + } + + @Override + public Optional getDescription() { + return Optional.empty(); + } + + @Override + public Optional getDetailedDescription() { + return Optional.empty(); + } + + private static TableSchema calculateSourceSchema(TableSource source, boolean isBatch) { + TableSchema tableSchema = source.getTableSchema(); + if (isBatch) { + return tableSchema; + } + + TypeInformation[] types = Arrays.copyOf(tableSchema.getFieldTypes(), tableSchema.getFieldCount()); + String[] fieldNames = tableSchema.getFieldNames(); + if (source instanceof DefinedRowtimeAttributes) { + updateRowtimeIndicators((DefinedRowtimeAttributes) source, fieldNames, types); + } + if (source instanceof DefinedProctimeAttribute) { + updateProctimeIndicator((DefinedProctimeAttribute) source, fieldNames, types); + } + return new TableSchema(fieldNames, types); + } + + private static void updateRowtimeIndicators( + DefinedRowtimeAttributes source, + String[] fieldNames, + TypeInformation[] types) { + List rowtimeAttributes = source.getRowtimeAttributeDescriptors() + .stream() + .map(RowtimeAttributeDescriptor::getAttributeName) + .collect(Collectors.toList()); + + for (int i = 0; i < fieldNames.length; i++) { + if (rowtimeAttributes.contains(fieldNames[i])) { + types[i] = TimeIndicatorTypeInfo.ROWTIME_INDICATOR; + } + } + } + + private static void updateProctimeIndicator( + DefinedProctimeAttribute source, + String[] fieldNames, + TypeInformation[] types) { + String proctimeAttribute = source.getProctimeAttribute(); + + for (int i = 0; i < fieldNames.length; i++) { + if (fieldNames[i].equals(proctimeAttribute)) { + types[i] = TimeIndicatorTypeInfo.PROCTIME_INDICATOR; + break; + } + } + } +} diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java index 0e9aafe8d2ac90..89596e81154748 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/TableOperationCatalogView.java @@ -25,7 +25,7 @@ import java.util.Optional; /** - * A view created from {@link TableOperation} via operations on {@link org.apache.flink.table.api.Table}. + * A view created from a {@link TableOperation} via operations on {@link org.apache.flink.table.api.Table}. 
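To make the new wrapper concrete, the following is a minimal, hypothetical sketch of registering a source and a configured sink directly in a catalog as ConnectorCatalogTables, which is what registerTableSource/registerTableSink now do behind the scenes. The catalog, database, and table names, the file paths, and the CSV connector choice are assumptions for the example.

    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.table.catalog.Catalog;
    import org.apache.flink.table.catalog.ConnectorCatalogTable;
    import org.apache.flink.table.catalog.GenericInMemoryCatalog;
    import org.apache.flink.table.catalog.ObjectPath;
    import org.apache.flink.table.sinks.CsvTableSink;
    import org.apache.flink.table.sinks.TableSink;
    import org.apache.flink.table.sources.CsvTableSource;

    public class ConnectorCatalogTableExample {
        public static void main(String[] args) throws Exception {
            Catalog catalog = new GenericInMemoryCatalog("my_catalog", "my_db");
            catalog.open();

            CsvTableSource source = CsvTableSource.builder()
                .path("/tmp/orders.csv")                 // illustrative path
                .field("id", Types.LONG)
                .field("amount", Types.INT)
                .build();

            TableSink<?> sink = new CsvTableSink("/tmp/results.csv", ",")
                .configure(new String[]{"id", "amount"},
                           new TypeInformation[]{Types.LONG, Types.INT});

            // isBatch = false: for streaming sources the schema is enriched with
            // rowtime/proctime indicators by calculateSourceSchema(...)
            catalog.createTable(
                new ObjectPath("my_db", "Orders"),
                ConnectorCatalogTable.source(source, false),
                false);
            catalog.createTable(
                new ObjectPath("my_db", "Results"),
                ConnectorCatalogTable.sink(sink, false),
                false);
        }
    }

Note that a table registered only through its sink side can still be written to, but reading it is rejected by the DatabaseCalciteSchema change below with "Cannot query a sink only table."
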
*/ @Internal public class TableOperationCatalogView extends AbstractCatalogView { diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationDefaultVisitor.java index 303673ea701843..0773d3051a8608 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationDefaultVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationDefaultVisitor.java @@ -78,6 +78,11 @@ public T visitCatalogTable(CatalogTableOperation catalogTable) { return defaultMethod(catalogTable); } + @Override + public T visitTableSourceTable(TableSourceTableOperation tableSourceTable) { + return defaultMethod(tableSourceTable); + } + @Override public T visitOther(TableOperation other) { return defaultMethod(other); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationVisitor.java index 82c80656f458b3..c7c761839c4d32 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableOperationVisitor.java @@ -47,5 +47,7 @@ public interface TableOperationVisitor { T visitCatalogTable(CatalogTableOperation catalogTable); + T visitTableSourceTable(TableSourceTableOperation tableSourceTable); + T visitOther(TableOperation other); } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableSourceTableOperation.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableSourceTableOperation.java new file mode 100644 index 00000000000000..acd12fb269dc7e --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/TableSourceTableOperation.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.operations; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.sources.TableSource; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Inline scan of a {@link TableSource}. Used only when a {@link org.apache.flink.table.api.Table} was created + * from {@link org.apache.flink.table.api.TableEnvironment#fromTableSource(TableSource)}. 
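As a hypothetical usage sketch of the inline path described above (the file path and field layout are assumptions): a Table created via fromTableSource is backed by a TableSourceTableOperation and never touches the catalog.

    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.table.api.java.StreamTableEnvironment;
    import org.apache.flink.table.sources.CsvTableSource;

    public class FromTableSourceExample {
        public static void main(String[] args) {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

            CsvTableSource source = CsvTableSource.builder()
                .path("/tmp/orders.csv")      // illustrative path
                .field("id", Types.LONG)
                .field("amount", Types.INT)
                .build();

            // No registration step: the returned Table wraps a TableSourceTableOperation.
            Table orders = tEnv.fromTableSource(source);
            Table filtered = orders.filter("amount > 10").select("id, amount");
        }
    }

During translation, TableOperationConverter#visitTableSourceTable (below) turns this operation into a FlinkLogicalTableSourceScan whose digest uses a generated "unregistered_..." name.
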
+ */ +@Internal +public class TableSourceTableOperation extends TableOperation { + + private final TableSource tableSource; + private final boolean isBatch; + + public TableSourceTableOperation(TableSource tableSource, boolean isBatch) { + this.tableSource = tableSource; + this.isBatch = isBatch; + } + + @Override + public TableSchema getTableSchema() { + return tableSource.getTableSchema(); + } + + @Override + public String asSummaryString() { + Map args = new HashMap<>(); + args.put("fields", tableSource.getTableSchema().getFieldNames()); + + return formatWithChildren("TableSource", args); + } + + public TableSource getTableSource() { + return tableSource; + } + + public boolean isBatch() { + return isBatch; + } + + @Override + public List getChildren() { + return Collections.emptyList(); + } + + @Override + public R accept(TableOperationVisitor visitor) { + return visitor.visitTableSourceTable(this); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/CalciteCatalogTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/CalciteCatalogTable.java deleted file mode 100644 index d6a91bc9190b32..00000000000000 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/CalciteCatalogTable.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.catalog; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.table.api.TableSchema; -import org.apache.flink.table.calcite.FlinkTypeFactory; - -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.schema.Table; - -import java.util.Collections; -import java.util.Map; -import java.util.Optional; - -/** - * Thin wrapper around Calcite specific {@link Table}, this is a temporary solution - * that allows to register those tables in the {@link CatalogManager}. - * TODO remove once we decouple TableEnvironment from Calcite. 
- */ -@Internal -public class CalciteCatalogTable implements CatalogBaseTable { - private final Table table; - private final FlinkTypeFactory typeFactory; - - public CalciteCatalogTable(Table table, FlinkTypeFactory typeFactory) { - this.table = table; - this.typeFactory = typeFactory; - } - - public Table getTable() { - return table; - } - - @Override - public Map getProperties() { - return Collections.emptyMap(); - } - - @Override - public TableSchema getSchema() { - RelDataType relDataType = table.getRowType(typeFactory); - - String[] fieldNames = relDataType.getFieldNames().toArray(new String[0]); - TypeInformation[] fieldTypes = relDataType.getFieldList() - .stream() - .map(field -> FlinkTypeFactory.toTypeInfo(field.getType())).toArray(TypeInformation[]::new); - - return new TableSchema(fieldNames, fieldTypes); - } - - @Override - public String getComment() { - return null; - } - - @Override - public CatalogBaseTable copy() { - return this; - } - - @Override - public Optional getDescription() { - return Optional.empty(); - } - - @Override - public Optional getDetailedDescription() { - return Optional.empty(); - } -} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java index 7d93b3a6bac1bf..6584068292d340 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/DatabaseCalciteSchema.java @@ -22,6 +22,8 @@ import org.apache.flink.table.catalog.exceptions.CatalogException; import org.apache.flink.table.catalog.exceptions.DatabaseNotExistException; import org.apache.flink.table.catalog.exceptions.TableNotExistException; +import org.apache.flink.table.plan.schema.TableSourceTable; +import org.apache.flink.table.plan.stats.FlinkStatistic; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rel.type.RelProtoDataType; @@ -65,10 +67,16 @@ public Table getTable(String tableName) { CatalogBaseTable table = catalog.getTable(tablePath); - if (table instanceof CalciteCatalogTable) { - return ((CalciteCatalogTable) table).getTable(); - } else if (table instanceof TableOperationCatalogView) { + if (table instanceof TableOperationCatalogView) { return TableOperationCatalogViewTable.createCalciteTable(((TableOperationCatalogView) table)); + } else if (table instanceof ConnectorCatalogTable) { + ConnectorCatalogTable connectorTable = (ConnectorCatalogTable) table; + return connectorTable.getTableSource() + .map(tableSource -> new TableSourceTable<>( + tableSource, + !connectorTable.isBatch(), + FlinkStatistic.UNKNOWN())) + .orElseThrow(() -> new TableException("Cannot query a sink only table.")); } else { throw new TableException("Unsupported table type: " + table); } @@ -76,7 +84,7 @@ public Table getTable(String tableName) { // TableNotExistException should never happen, because we are checking it exists // via catalog.tableExists throw new TableException(format( - "A failure occured when accesing table. Table path [%s, %s, %s]", + "A failure occurred when accessing table. 
Table path [%s, %s, %s]", catalogName, databaseName, tableName), e); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java index f3aaac8ba67175..01fa8c613d5a93 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/TableOperationConverter.java @@ -24,6 +24,7 @@ import org.apache.flink.table.api.TableException; import org.apache.flink.table.calcite.FlinkRelBuilder; import org.apache.flink.table.calcite.FlinkTypeFactory; +import org.apache.flink.table.catalog.CatalogReader; import org.apache.flink.table.expressions.AggFunctionCall; import org.apache.flink.table.expressions.Aggregation; import org.apache.flink.table.expressions.CallExpression; @@ -53,6 +54,7 @@ import org.apache.flink.table.operations.TableOperation; import org.apache.flink.table.operations.TableOperationDefaultVisitor; import org.apache.flink.table.operations.TableOperationVisitor; +import org.apache.flink.table.operations.TableSourceTableOperation; import org.apache.flink.table.operations.WindowAggregateTableOperation; import org.apache.flink.table.operations.WindowAggregateTableOperation.ResolvedGroupWindow; import org.apache.flink.table.plan.logical.LogicalWindow; @@ -62,15 +64,21 @@ import org.apache.flink.table.plan.nodes.FlinkConventions; import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataSetScan; import org.apache.flink.table.plan.nodes.logical.FlinkLogicalDataStreamScan; +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableSourceScan; import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl; import org.apache.flink.table.plan.schema.RowSchema; +import org.apache.flink.table.plan.schema.TableSourceTable; +import org.apache.flink.table.plan.stats.FlinkStatistic; +import org.apache.calcite.prepare.RelOptTableImpl; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.logical.LogicalTableFunctionScan; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.schema.Schemas; +import org.apache.calcite.schema.Table; import org.apache.calcite.tools.RelBuilder.AggCall; import org.apache.calcite.tools.RelBuilder.GroupKey; @@ -79,6 +87,7 @@ import java.util.Set; import java.util.stream.IntStream; +import scala.Option; import scala.Some; import static java.util.Arrays.asList; @@ -264,6 +273,31 @@ public RelNode visitOther(TableOperation other) { throw new TableException("Unknown table operation: " + other); } + @Override + public RelNode visitTableSourceTable(TableSourceTableOperation tableSourceTable) { + final Table relTable = new TableSourceTable<>( + tableSourceTable.getTableSource(), + !tableSourceTable.isBatch(), + FlinkStatistic.UNKNOWN()); + + CatalogReader catalogReader = (CatalogReader) relBuilder.getRelOptSchema(); + + // TableSourceScan requires a unique name of a Table for computing a digest. + // We are using the identity hash of the TableSource object. 
+ String refId = "unregistered_" + System.identityHashCode(tableSourceTable.getTableSource()); + return new FlinkLogicalTableSourceScan( + relBuilder.getCluster(), + relBuilder.getCluster().traitSet().replace(FlinkConventions.LOGICAL()), + RelOptTableImpl.create( + catalogReader, + relTable.getRowType(relBuilder.getTypeFactory()), + relTable, + Schemas.path(catalogReader.getRootSchema(), Collections.singletonList(refId))), + tableSourceTable.getTableSource(), + Option.empty() + ); + } + private RelNode convertToDataStreamScan(DataStreamTableOperation tableOperation) { RelDataType logicalRowType = relBuilder.getTypeFactory() .buildLogicalRowType(tableOperation.getTableSchema()); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala index 6dad03aa8307f7..66a1d83c86b823 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala @@ -18,8 +18,6 @@ package org.apache.flink.table.api -import _root_.java.util.concurrent.atomic.AtomicInteger - import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType @@ -42,7 +40,7 @@ import org.apache.flink.table.plan.rules.FlinkRuleSets import org.apache.flink.table.plan.schema._ import org.apache.flink.table.runtime.MapRunner import org.apache.flink.table.sinks._ -import org.apache.flink.table.sources.{BatchTableSource, TableSource} +import org.apache.flink.table.sources.{BatchTableSource, TableSource, TableSourceUtil} import org.apache.flink.table.typeutils.FieldInfoUtils.{calculateTableSchema, getFieldsInfo, validateInputTypeInfo} import org.apache.flink.types.Row @@ -58,160 +56,31 @@ abstract class BatchTableEnvImpl( catalogManager: CatalogManager) extends TableEnvImpl(config, catalogManager) { - // a counter for unique table names. - private val nameCntr: AtomicInteger = new AtomicInteger(0) - - // the naming pattern for internally registered tables. - private val internalNamePattern = "^_DataSetTable_[0-9]+$".r - override def queryConfig: BatchQueryConfig = new BatchQueryConfig - /** - * Checks if the chosen table name is valid. - * - * @param name The table name to check. - */ - override protected def checkValidTableName(name: String): Unit = { - val m = internalNamePattern.findFirstIn(name) - m match { - case Some(_) => - throw new TableException(s"Illegal Table name. " + - s"Please choose a name that does not contain the pattern $internalNamePattern") - case None => - } - } - - /** Returns a unique table name according to the internal naming pattern. */ - override protected def createUniqueTableName(): String = - "_DataSetTable_" + nameCntr.getAndIncrement() - /** * Registers an internal [[BatchTableSource]] in this [[TableEnvImpl]]'s catalog without * name checking. Registered tables can be referenced in SQL queries. * - * @param name The name under which the [[TableSource]] is registered. * @param tableSource The [[TableSource]] to register. 
*/ - override protected def registerTableSourceInternal( - name: String, - tableSource: TableSource[_]) - : Unit = { - - tableSource match { + override protected def validateTableSource(tableSource: TableSource[_]): Unit = { + TableSourceUtil.validateTableSource(tableSource) - // check for proper batch table source - case batchTableSource: BatchTableSource[_] => - // check if a table (source or sink) is registered - getTable(defaultCatalogName, defaultDatabaseName, name) match { - - // table source and/or sink is registered - case Some(table: TableSourceSinkTable[_, _]) => table.tableSourceTable match { - - // wrapper contains source - case Some(_: TableSourceTable[_]) => - throw new TableException(s"Table '$name' already exists. " + - s"Please choose a different name.") - - // wrapper contains only sink (not source) - case _ => - val enrichedTable = new TableSourceSinkTable( - Some(new BatchTableSourceTable(batchTableSource)), - table.tableSinkTable) - replaceRegisteredTableSourceSinkInternal(name, enrichedTable) - } - - // no table is registered - case _ => - val newTable = new TableSourceSinkTable( - Some(new BatchTableSourceTable(batchTableSource)), - None) - registerTableSourceSinkInternal(name, newTable) - } - - // not a batch table source - case _ => - throw new TableException("Only BatchTableSource can be registered in " + - "BatchTableEnvironment.") + if (!tableSource.isInstanceOf[BatchTableSource[_]]) { + throw new TableException("Only BatchTableSource can be registered in " + + "BatchTableEnvironment.") } } - def connect(connectorDescriptor: ConnectorDescriptor): BatchTableDescriptor = { - new BatchTableDescriptor(this, connectorDescriptor) - } - - def registerTableSink( - name: String, - fieldNames: Array[String], - fieldTypes: Array[TypeInformation[_]], - tableSink: TableSink[_]): Unit = { - // validate - checkValidTableName(name) - if (fieldNames == null) throw new TableException("fieldNames must not be null.") - if (fieldTypes == null) throw new TableException("fieldTypes must not be null.") - if (fieldNames.length == 0) throw new TableException("fieldNames must not be empty.") - if (fieldNames.length != fieldTypes.length) { - throw new TableException("Same number of field names and types required.") + override protected def validateTableSink(configuredSink: TableSink[_]): Unit = { + if (!configuredSink.isInstanceOf[BatchTableSink[_]]) { + throw new TableException("Only BatchTableSink can be registered in BatchTableEnvironment.") } - - // configure and register - val configuredSink = tableSink.configure(fieldNames, fieldTypes) - registerTableSinkInternal(name, configuredSink) } - def registerTableSink(name: String, configuredSink: TableSink[_]): Unit = { - registerTableSinkInternal(name, configuredSink) - } - - private def registerTableSinkInternal(name: String, configuredSink: TableSink[_]): Unit = { - // validate - checkValidTableName(name) - if (configuredSink.getFieldNames == null || configuredSink.getFieldTypes == null) { - throw new TableException("Table sink is not configured.") - } - if (configuredSink.getFieldNames.length == 0) { - throw new TableException("Field names must not be empty.") - } - if (configuredSink.getFieldNames.length != configuredSink.getFieldTypes.length) { - throw new TableException("Same number of field names and types required.") - } - - // register - configuredSink match { - - // check for proper batch table sink - case _: BatchTableSink[_] => - - // check if a table (source or sink) is registered - getTable(name) match { - - // table source 
and/or sink is registered - case Some(table: TableSourceSinkTable[_, _]) => table.tableSinkTable match { - - // wrapper contains sink - case Some(_: TableSinkTable[_]) => - throw new TableException(s"Table '$name' already exists. " + - s"Please choose a different name.") - - // wrapper contains only source (not sink) - case _ => - val enrichedTable = new TableSourceSinkTable( - table.tableSourceTable, - Some(new TableSinkTable(configuredSink))) - replaceRegisteredTableSourceSinkInternal(name, enrichedTable) - } - - // no table is registered - case _ => - val newTable = new TableSourceSinkTable( - None, - Some(new TableSinkTable(configuredSink))) - registerTableSourceSinkInternal(name, newTable) - } - - // not a batch table sink - case _ => - throw new TableException("Only BatchTableSink can be registered in BatchTableEnvironment.") - } + def connect(connectorDescriptor: ConnectorDescriptor): BatchTableDescriptor = { + new BatchTableDescriptor(this, connectorDescriptor) } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala index 5c96e7a609eb98..32b3f8c51b01e7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala @@ -19,7 +19,6 @@ package org.apache.flink.table.api import _root_.java.lang.{Boolean => JBool} -import _root_.java.util.concurrent.atomic.AtomicInteger import org.apache.calcite.plan.RelOptUtil import org.apache.calcite.plan.hep.HepMatchOrder @@ -70,45 +69,17 @@ abstract class StreamTableEnvImpl( catalogManager: CatalogManager) extends TableEnvImpl(config, catalogManager) { - // a counter for unique table names - private val nameCntr: AtomicInteger = new AtomicInteger(0) - - // the naming pattern for internally registered tables. - private val internalNamePattern = "^_DataStreamTable_[0-9]+$".r - override def queryConfig: StreamQueryConfig = new StreamQueryConfig - /** - * Checks if the chosen table name is valid. - * - * @param name The table name to check. - */ - override protected def checkValidTableName(name: String): Unit = { - val m = internalNamePattern.findFirstIn(name) - m match { - case Some(_) => - throw new TableException(s"Illegal Table name. " + - s"Please choose a name that does not contain the pattern $internalNamePattern") - case None => - } - } - - /** Returns a unique table name according to the internal naming pattern. */ - override protected def createUniqueTableName(): String = - "_DataStreamTable_" + nameCntr.getAndIncrement() - /** * Registers an internal [[StreamTableSource]] in this [[TableEnvImpl]]'s catalog without * name checking. Registered tables can be referenced in SQL queries. * - * @param name The name under which the [[TableSource]] is registered. * @param tableSource The [[TableSource]] to register. */ - override protected def registerTableSourceInternal( - name: String, - tableSource: TableSource[_]) - : Unit = { + override protected def validateTableSource(tableSource: TableSource[_]): Unit = { + TableSourceUtil.validateTableSource(tableSource) tableSource match { // check for proper stream table source @@ -121,33 +92,6 @@ abstract class StreamTableEnvImpl( s"environment. 
But is: ${execEnv.getStreamTimeCharacteristic}") } - // register - getTable(defaultCatalogName, defaultDatabaseName, name) match { - - // check if a table (source or sink) is registered - case Some(table: TableSourceSinkTable[_, _]) => table.tableSourceTable match { - - // wrapper contains source - case Some(_: TableSourceTable[_]) => - throw new TableException(s"Table '$name' already exists. " + - s"Please choose a different name.") - - // wrapper contains only sink (not source) - case _ => - val enrichedTable = new TableSourceSinkTable( - Some(new StreamTableSourceTable(streamTableSource)), - table.tableSinkTable) - replaceRegisteredTableSourceSinkInternal(name, enrichedTable) - } - - // no table is registered - case _ => - val newTable = new TableSourceSinkTable( - Some(new StreamTableSourceTable(streamTableSource)), - None) - registerTableSourceSinkInternal(name, newTable) - } - // not a stream table source case _ => throw new TableException("Only StreamTableSource can be registered in " + @@ -155,84 +99,16 @@ abstract class StreamTableEnvImpl( } } - def connect(connectorDescriptor: ConnectorDescriptor): StreamTableDescriptor = { - new StreamTableDescriptor(this, connectorDescriptor) - } - - def registerTableSink( - name: String, - fieldNames: Array[String], - fieldTypes: Array[TypeInformation[_]], - tableSink: TableSink[_]): Unit = { - - checkValidTableName(name) - if (fieldNames == null) throw new TableException("fieldNames must not be null.") - if (fieldTypes == null) throw new TableException("fieldTypes must not be null.") - if (fieldNames.length == 0) throw new TableException("fieldNames must not be empty.") - if (fieldNames.length != fieldTypes.length) { - throw new TableException("Same number of field names and types required.") + override protected def validateTableSink(configuredSink: TableSink[_]): Unit = { + if (!configuredSink.isInstanceOf[StreamTableSink[_]]) { + throw new TableException( + "Only AppendStreamTableSink, UpsertStreamTableSink, and RetractStreamTableSink can be " + + "registered in StreamTableEnvironment.") } - - val configuredSink = tableSink.configure(fieldNames, fieldTypes) - registerTableSinkInternal(name, configuredSink) - } - - def registerTableSink(name: String, configuredSink: TableSink[_]): Unit = { - registerTableSinkInternal(name, configuredSink) } - private def registerTableSinkInternal(name: String, configuredSink: TableSink[_]): Unit = { - // validate - checkValidTableName(name) - if (configuredSink.getFieldNames == null || configuredSink.getFieldTypes == null) { - throw new TableException("Table sink is not configured.") - } - if (configuredSink.getFieldNames.length == 0) { - throw new TableException("Field names must not be empty.") - } - if (configuredSink.getFieldNames.length != configuredSink.getFieldTypes.length) { - throw new TableException("Same number of field names and types required.") - } - - // register - configuredSink match { - - // check for proper batch table sink - case _: StreamTableSink[_] => - - // check if a table (source or sink) is registered - getTable(name) match { - - // table source and/or sink is registered - case Some(table: TableSourceSinkTable[_, _]) => table.tableSinkTable match { - - // wrapper contains sink - case Some(_: TableSinkTable[_]) => - throw new TableException(s"Table '$name' already exists. 
" + - s"Please choose a different name.") - - // wrapper contains only source (not sink) - case _ => - val enrichedTable = new TableSourceSinkTable( - table.tableSourceTable, - Some(new TableSinkTable(configuredSink))) - replaceRegisteredTableSourceSinkInternal(name, enrichedTable) - } - - // no table is registered - case _ => - val newTable = new TableSourceSinkTable( - None, - Some(new TableSinkTable(configuredSink))) - registerTableSourceSinkInternal(name, newTable) - } - - // not a stream table sink - case _ => - throw new TableException( - "Only AppendStreamTableSink, UpsertStreamTableSink, and RetractStreamTableSink can be " + - "registered in StreamTableEnvironment.") - } + def connect(connectorDescriptor: ConnectorDescriptor): StreamTableDescriptor = { + new StreamTableDescriptor(this, connectorDescriptor) } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala index 6db31eefec6e76..f2818299cb3846 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala @@ -22,13 +22,11 @@ import _root_.java.util.Optional import _root_.java.util.concurrent.atomic.AtomicInteger import com.google.common.collect.ImmutableList -import org.apache.calcite.jdbc.CalciteSchema import org.apache.calcite.jdbc.CalciteSchemaBuilder.asRootSchema import org.apache.calcite.plan.RelOptPlanner.CannotPlanException import org.apache.calcite.plan._ import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgram, HepProgramBuilder} import org.apache.calcite.rel.RelNode -import org.apache.calcite.schema.SchemaPlus import org.apache.calcite.sql._ import org.apache.calcite.sql.parser.SqlParser import org.apache.calcite.tools._ @@ -40,12 +38,13 @@ import org.apache.flink.table.calcite._ import org.apache.flink.table.catalog._ import org.apache.flink.table.codegen.{FunctionCodeGenerator, GeneratedFunction} import org.apache.flink.table.expressions._ +import org.apache.flink.table.factories.{TableFactoryService, TableFactoryUtil, TableSinkFactory} import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction, UserDefinedAggregateFunction} -import org.apache.flink.table.operations.{CatalogTableOperation, OperationTreeBuilder, PlannerTableOperation} +import org.apache.flink.table.operations.{CatalogTableOperation, OperationTreeBuilder, PlannerTableOperation, TableSourceTableOperation} import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.rules.FlinkRuleSets -import org.apache.flink.table.plan.schema.{RowSchema, TableSourceSinkTable} +import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.planner.PlanningConfigurationBuilder import org.apache.flink.table.sinks.TableSink import org.apache.flink.table.sources.TableSource @@ -53,9 +52,9 @@ import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import org.apache.flink.table.util.JavaScalaConversionUtil import org.apache.flink.table.validate.FunctionCatalog import org.apache.flink.types.Row +import org.apache.flink.util.StringUtils import _root_.scala.collection.JavaConverters._ -import _root_.scala.collection.mutable /** * The abstract base class for the implementation of batch and stream TableEnvironments. 
@@ -73,9 +72,6 @@ abstract class TableEnvImpl( protected val defaultCatalogName: String = config.getBuiltInCatalogName protected val defaultDatabaseName: String = config.getBuiltInDatabaseName - private val internalSchema: CalciteSchema = - asRootSchema(new CatalogManagerCalciteSchema(catalogManager, isBatch)) - // temporary bridge between API and planner private[flink] val expressionBridge: ExpressionBridge[PlannerExpression] = new ExpressionBridge[PlannerExpression](functionCatalog, PlannerExpressionConverter.INSTANCE) @@ -89,7 +85,7 @@ abstract class TableEnvImpl( new PlanningConfigurationBuilder( config, functionCatalog, - internalSchema, + asRootSchema(new CatalogManagerCalciteSchema(catalogManager, isBatch)), expressionBridge) protected def calciteConfig: CalciteConfig = config.getPlannerConfig @@ -425,23 +421,138 @@ abstract class TableEnvImpl( "Only tables that belong to this TableEnvironment can be registered.") } - checkValidTableName(name) - val tableTable = new TableOperationCatalogView(table.getTableOperation) registerTableInternal(name, tableTable) } override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = { + validateTableSource(tableSource) registerTableSourceInternal(name, tableSource) } + override def registerTableSink( + name: String, + fieldNames: Array[String], + fieldTypes: Array[TypeInformation[_]], + tableSink: TableSink[_]): Unit = { + + if (fieldNames == null) { + throw new TableException("fieldNames must not be null.") + } + if (fieldTypes == null) { + throw new TableException("fieldTypes must not be null.") + } + if (fieldNames.length == 0) { + throw new TableException("fieldNames must not be empty.") + } + if (fieldNames.length != fieldTypes.length) { + throw new TableException("Same number of field names and types required.") + } + + val configuredSink = tableSink.configure(fieldNames, fieldTypes) + registerTableSinkInternal(name, configuredSink) + } + + override def registerTableSink(name: String, configuredSink: TableSink[_]): Unit = { + // validate + if (configuredSink.getFieldNames == null || configuredSink.getFieldTypes == null) { + throw new TableException("Table sink is not configured.") + } + if (configuredSink.getFieldNames.length == 0) { + throw new TableException("Field names must not be empty.") + } + if (configuredSink.getFieldNames.length != configuredSink.getFieldTypes.length) { + throw new TableException("Same number of field names and types required.") + } + + validateTableSink(configuredSink) + registerTableSinkInternal(name, configuredSink) + } + override def fromTableSource(source: TableSource[_]): Table = { - val name = createUniqueTableName() - registerTableSourceInternal(name, source) - scan(name) + new TableImpl(this, new TableSourceTableOperation(source, isBatch)) + } + + /** + * Perform batch or streaming specific validations of the [[TableSource]]. + * This method should throw [[ValidationException]] if the [[TableSource]] cannot be used + * in this [[TableEnvironment]]. + * + * @param tableSource table source to validate + */ + protected def validateTableSource(tableSource: TableSource[_]): Unit + + /** + * Perform batch or streaming specific validations of the [[TableSink]]. + * This method should throw [[ValidationException]] if the [[TableSink]] cannot be used + * in this [[TableEnvironment]]. 
+ * + * @param tableSink table source to validate + */ + protected def validateTableSink(tableSink: TableSink[_]): Unit + + private def registerTableSourceInternal( + name: String, + tableSource: TableSource[_]) + : Unit = { + // register + getCatalogTable(defaultCatalogName, defaultDatabaseName, name) match { + + // check if a table (source or sink) is registered + case Some(table: ConnectorCatalogTable[_, _]) => + if (table.getTableSource.isPresent) { + // wrapper contains source + throw new TableException(s"Table '$name' already exists. " + + s"Please choose a different name.") + } else { + // wrapper contains only sink (not source) + replaceTableInternal( + name, + ConnectorCatalogTable + .sourceAndSink(tableSource, table.getTableSink.get, isBatch)) + } + + // no table is registered + case _ => + registerTableInternal(name, ConnectorCatalogTable.source(tableSource, isBatch)) + } + } + + private def registerTableSinkInternal( + name: String, + tableSink: TableSink[_]) + : Unit = { + // check if a table (source or sink) is registered + getCatalogTable(defaultCatalogName, defaultDatabaseName, name) match { + + // table source and/or sink is registered + case Some(table: ConnectorCatalogTable[_, _]) => + if (table.getTableSink.isPresent) { + // wrapper contains sink + throw new TableException(s"Table '$name' already exists. " + + s"Please choose a different name.") + } else { + // wrapper contains only source (not sink) + replaceTableInternal( + name, + ConnectorCatalogTable + .sourceAndSink(table.getTableSource.get, tableSink, isBatch)) + } + + // no table is registered + case _ => + registerTableInternal(name, ConnectorCatalogTable.sink(tableSink, isBatch)) + } + } + + private def checkValidTableName(name: String) = { + if (StringUtils.isNullOrWhitespaceOnly(name)) { + throw new ValidationException("A table name cannot be null or consist of only whitespaces.") + } } protected def registerTableInternal(name: String, table: CatalogBaseTable): Unit = { + checkValidTableName(name) val path = new ObjectPath(defaultDatabaseName, name) JavaScalaConversionUtil.toScala(catalogManager.getCatalog(defaultCatalogName)) match { case Some(catalog) => @@ -453,27 +564,14 @@ abstract class TableEnvImpl( } } - protected def registerTableSourceInternal(name: String, tableSource: TableSource[_]): Unit - - protected def registerTableSourceSinkInternal[T1, T2]( - name: String, - table: TableSourceSinkTable[T1, T2]) - : Unit = { - registerTableInternal( - name, - new CalciteCatalogTable(table, planningConfigurationBuilder.getTypeFactory)) - } - - protected def replaceRegisteredTableSourceSinkInternal[T1, T2]( - name: String, - table: TableSourceSinkTable[T1, T2]) - : Unit = { + protected def replaceTableInternal(name: String, table: CatalogBaseTable): Unit = { + checkValidTableName(name) val path = new ObjectPath(defaultDatabaseName, name) JavaScalaConversionUtil.toScala(catalogManager.getCatalog(defaultCatalogName)) match { case Some(catalog) => catalog.alterTable( path, - new CalciteCatalogTable(table, planningConfigurationBuilder.getTypeFactory), + table, false) case None => throw new TableException("The default catalog does not exist.") } @@ -488,7 +586,8 @@ abstract class TableEnvImpl( } private[flink] def scanInternal(tablePath: Array[String]): Option[CatalogTableOperation] = { - JavaScalaConversionUtil.toScala(catalogManager.resolveTable(tablePath : _*)) + JavaScalaConversionUtil.toScala(catalogManager.resolveTable(tablePath: _*)) + .map(t => new CatalogTableOperation(t.getTablePath, t.getTableSchema)) } 
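For readers following the new registration flow, here is a minimal caller-side sketch of how the two internal methods above cooperate when a source and a sink are registered under one name. It is an illustration only, not part of this change: the table name "Orders", the helper class and its arguments are placeholders.

import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.sinks.TableSink;
import org.apache.flink.table.sources.TableSource;

public class RegistrationSketch {

	/** Registers a source and a sink under the same name, mirroring the enrichment logic above. */
	static void registerOrders(
			TableEnvironment tEnv,
			TableSource<?> ordersSource,
			TableSink<?> configuredOrdersSink) {

		// No entry yet: the catalog stores ConnectorCatalogTable.source(ordersSource, isBatch).
		tEnv.registerTableSource("Orders", ordersSource);

		// Same name, sink side: the entry is replaced via replaceTableInternal with
		// ConnectorCatalogTable.sourceAndSink(ordersSource, configuredOrdersSink, isBatch).
		tEnv.registerTableSink("Orders", configuredOrdersSink);

		// A second source registration under "Orders" would now throw
		// TableException("Table 'Orders' already exists. Please choose a different name.").
	}
}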
override def listTables(): Array[String] = { @@ -578,16 +677,19 @@ abstract class TableEnvImpl( private[flink] def insertInto(table: Table, conf: QueryConfig, sinkTablePath: String*): Unit = { // check that sink table exists - if (null == sinkTablePath) throw new TableException("Name of TableSink must not be null.") - if (sinkTablePath.isEmpty) throw new TableException("Name of TableSink must not be empty.") + if (null == sinkTablePath) { + throw new TableException("Name of TableSink must not be null.") + } + if (sinkTablePath.isEmpty) { + throw new TableException("Name of TableSink must not be empty.") + } - getTable(sinkTablePath: _*) match { + getTableSink(sinkTablePath: _*) match { case None => throw new TableException(s"No table was registered under the name $sinkTablePath.") - case Some(s: TableSourceSinkTable[_, _]) if s.tableSinkTable.isDefined => - val tableSink = s.tableSinkTable.get.tableSink + case Some(tableSink) => // validate schema of source table and table sink val srcFieldTypes = table.getSchema.getFieldTypes val sinkFieldTypes = tableSink.getFieldTypes @@ -614,58 +716,35 @@ abstract class TableEnvImpl( } // emit the table to the configured table sink writeToSink(table, tableSink, conf) - - case Some(_) => - throw new TableException(s"The table registered as $sinkTablePath is not a TableSink. " + - s"You can only emit query results to a registered TableSink.") } } - /** Returns a unique table name according to the internal naming pattern. */ - protected def createUniqueTableName(): String + private def getTableSink(name: String*): Option[TableSink[_]] = { + JavaScalaConversionUtil.toScala(catalogManager.resolveTable(name: _*)) match { + case Some(s) if s.getExternalCatalogTable.isPresent => - /** - * Checks if the chosen table name is valid. - * - * @param name The table name to check. - */ - protected def checkValidTableName(name: String): Unit + Option(TableFactoryUtil.findAndCreateTableSink(s.getExternalCatalogTable.get())) - /** - * Get a table from either internal or external catalogs. - * - * @param name The name of the table. - * @return The table registered either internally or externally, None otherwise. - */ - protected def getTable(name: String*): Option[org.apache.calcite.schema.Table] = { - - // recursively fetches a table from a schema. 
- def getTableFromSchema( - schema: SchemaPlus, - path: List[String]): Option[org.apache.calcite.schema.Table] = { - - path match { - case tableName :: Nil => - // look up table - Option(schema.getTable(tableName)) - case subschemaName :: remain => - // look up subschema - val subschema = Option(schema.getSubSchema(subschemaName)) - subschema match { - case Some(s) => - // search for table in subschema - getTableFromSchema(s, remain) - case None => - // subschema does not exist - None - } - } + case Some(s) if JavaScalaConversionUtil.toScala(s.getCatalogTable) + .exists(_.isInstanceOf[ConnectorCatalogTable[_, _]]) => + + JavaScalaConversionUtil + .toScala(s.getCatalogTable.get().asInstanceOf[ConnectorCatalogTable[_, _]].getTableSink) + + case Some(s) if JavaScalaConversionUtil.toScala(s.getCatalogTable) + .exists(_.isInstanceOf[CatalogTable]) => + + val sinkProperties = s.getCatalogTable.get().asInstanceOf[CatalogTable].toProperties + Option(TableFactoryService.find(classOf[TableSinkFactory[_]], sinkProperties) + .createTableSink(sinkProperties)) + + case _ => None } + } + protected def getCatalogTable(name: String*): Option[CatalogBaseTable] = { JavaScalaConversionUtil.toScala(catalogManager.resolveTable(name: _*)) - .flatMap(t => - getTableFromSchema(internalSchema.plus(), t.getTablePath.asScala.toList) - ) + .flatMap(t => JavaScalaConversionUtil.toScala(t.getCatalogTable)) } /** Returns a unique temporary attribute name. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala index 54962f59ef3772..eae99c13c7b5ee 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala @@ -19,16 +19,12 @@ package org.apache.flink.table.calcite import java.lang.Iterable -import java.util.{Collections, List => JList} +import java.util.{List => JList} -import org.apache.calcite.jdbc.CalciteSchema import org.apache.calcite.plan._ -import org.apache.calcite.plan.volcano.VolcanoPlanner -import org.apache.calcite.prepare.CalciteCatalogReader import org.apache.calcite.rel.logical.LogicalAggregate -import org.apache.calcite.rex.RexBuilder +import org.apache.calcite.tools.RelBuilder import org.apache.calcite.tools.RelBuilder.{AggCall, GroupKey} -import org.apache.calcite.tools.{FrameworkConfig, RelBuilder} import org.apache.flink.table.api.TableException import org.apache.flink.table.expressions.{Alias, ExpressionBridge, PlannerExpression, WindowProperty} import org.apache.flink.table.operations.TableOperation diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala index 501772766dfa42..4b2287fd4b3973 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalCatalogSchema.scala @@ -81,7 +81,7 @@ class ExternalCatalogSchema( */ override def getTable(name: String): Table = try { val externalCatalogTable = catalog.getTable(name) - ExternalTableUtil.fromExternalCatalogTable(isBatch, externalCatalogTable) + ExternalTableUtil.fromExternalCatalogTable(isBatch, 
externalCatalogTable).orNull } catch { case _: TableNotExistException => { LOG.warn(s"Table $name does not exist in externalCatalog $catalogIdentifier") diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalTableUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalTableUtil.scala index f91f3094bf3bec..4ac24dd214419a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalTableUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/catalog/ExternalTableUtil.scala @@ -22,13 +22,12 @@ import org.apache.flink.table.api._ import org.apache.flink.table.factories._ import org.apache.flink.table.plan.schema._ import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.sinks.{BatchTableSink, StreamTableSink} -import org.apache.flink.table.sources.{BatchTableSource, StreamTableSource} +import org.apache.flink.table.sources.TableSource import org.apache.flink.table.util.JavaScalaConversionUtil.toScala import org.apache.flink.table.util.Logging /** - * The utility class is used to convert [[ExternalCatalogTable]] to [[TableSourceSinkTable]]. + * The utility class is used to convert [[ExternalCatalogTable]]. * * It uses [[TableFactoryService]] for discovering. */ @@ -40,24 +39,16 @@ object ExternalTableUtil extends Logging { * @param externalTable the [[ExternalCatalogTable]] instance which to convert * @return converted [[TableSourceTable]] instance from the input catalog table */ - def fromExternalCatalogTable[T1, T2](isBatch: Boolean, externalTable: ExternalCatalogTable) - : TableSourceSinkTable[T1, T2] = { + def fromExternalCatalogTable[T](isBatch: Boolean, externalTable: ExternalCatalogTable) + : Option[TableSourceTable[T]] = { val statistics = new FlinkStatistic(toScala(externalTable.getTableStats)) - val source: Option[TableSourceTable[T1]] = if (externalTable.isTableSource) { + if (externalTable.isTableSource) { Some(createTableSource(isBatch, externalTable, statistics)) } else { None } - - val sink: Option[TableSinkTable[T2]] = if (externalTable.isTableSink) { - Some(createTableSink(isBatch, externalTable, statistics)) - } else { - None - } - - new TableSourceSinkTable[T1, T2](source, sink) } private def createTableSource[T]( @@ -65,33 +56,20 @@ object ExternalTableUtil extends Logging { externalTable: ExternalCatalogTable, statistics: FlinkStatistic) : TableSourceTable[T] = { - if (isBatch && externalTable.isBatchTable) { - val source = TableFactoryUtil.findAndCreateTableSource(externalTable) - new BatchTableSourceTable[T](source.asInstanceOf[BatchTableSource[T]], statistics) - } else if (!isBatch && externalTable.isStreamTable) { - val source = TableFactoryUtil.findAndCreateTableSource(externalTable) - new StreamTableSourceTable[T](source.asInstanceOf[StreamTableSource[T]], statistics) + val source = if (isModeCompatibleWithTable(isBatch, externalTable)) { + TableFactoryUtil.findAndCreateTableSource(externalTable) } else { throw new ValidationException( "External catalog table does not support the current environment for a table source.") } + + new TableSourceTable[T](source.asInstanceOf[TableSource[T]], !isBatch, statistics) } - private def createTableSink[T]( + private def isModeCompatibleWithTable[T]( isBatch: Boolean, - externalTable: ExternalCatalogTable, - statistics: FlinkStatistic) - : TableSinkTable[T] = { - - if (isBatch && externalTable.isBatchTable) { - val sink = 
TableFactoryUtil.findAndCreateTableSink(externalTable) - new TableSinkTable[T](sink.asInstanceOf[BatchTableSink[T]], statistics) - } else if (!isBatch && externalTable.isStreamTable) { - val sink = TableFactoryUtil.findAndCreateTableSink(externalTable) - new TableSinkTable[T](sink.asInstanceOf[StreamTableSink[T]], statistics) - } else { - throw new ValidationException( - "External catalog table does not support the current environment for a table sink.") - } + externalTable: ExternalCatalogTable) + : Boolean = { + isBatch && externalTable.isBatchTable || !isBatch && externalTable.isStreamTable } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala index 03b442bb1fb277..4aad828df99da3 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/PhysicalTableSourceScan.scala @@ -23,8 +23,7 @@ import org.apache.calcite.rel.RelWriter import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.TableScan import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.schema.TableSourceSinkTable -import org.apache.flink.table.sources.{TableSource, TableSourceUtil} +import org.apache.flink.table.sources.{StreamTableSource, TableSource, TableSourceUtil} import scala.collection.JavaConverters._ @@ -32,19 +31,16 @@ abstract class PhysicalTableSourceScan( cluster: RelOptCluster, traitSet: RelTraitSet, table: RelOptTable, - val tableSource: TableSource[_], + tableSource: TableSource[_], val selectedFields: Option[Array[Int]]) extends TableScan(cluster, traitSet, table) { override def deriveRowType(): RelDataType = { val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] - val streamingTable = table.unwrap(classOf[TableSourceSinkTable[_, _]]) match { - case t: TableSourceSinkTable[_, _] if t.isStreamSourceTable => true - // null - case _ => false - } + val streamingTable = tableSource.isInstanceOf[StreamTableSource[_]] - TableSourceUtil.getRelDataType(tableSource, selectedFields, streamingTable, flinkTypeFactory) + TableSourceUtil + .getRelDataType(tableSource, selectedFields, streamingTable, flinkTypeFactory) } override def explainTerms(pw: RelWriter): RelWriter = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala index 0b91ab70f213ee..862523bd0a4355 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala @@ -73,7 +73,7 @@ class BatchTableSourceScan( cluster, traitSet, getTable, - newTableSource.asInstanceOf[BatchTableSource[_]], + tableSource, selectedFields ) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala index 7e451090028839..20e2234556ab02 100644 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala @@ -32,7 +32,7 @@ import org.apache.flink.table.plan.nodes.PhysicalTableSourceScan import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.sources._ -import org.apache.flink.table.sources.wmstrategies.{PeriodicWatermarkAssigner, PunctuatedWatermarkAssigner, PreserveWatermarks} +import org.apache.flink.table.sources.wmstrategies.{PeriodicWatermarkAssigner, PreserveWatermarks, PunctuatedWatermarkAssigner} import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo /** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala index 0c3d25add45ee2..89aff89e75f07b 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableSourceScan.scala @@ -25,10 +25,9 @@ import org.apache.calcite.rel.core.TableScan import org.apache.calcite.rel.logical.LogicalTableScan import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.schema.TableSourceSinkTable +import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.sources.{FilterableTableSource, TableSource, TableSourceUtil} import scala.collection.JavaConverters._ @@ -51,11 +50,7 @@ class FlinkLogicalTableSourceScan( override def deriveRowType(): RelDataType = { val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] - val streamingTable = table.unwrap(classOf[TableSourceSinkTable[_, _]]) match { - case t: TableSourceSinkTable[_, _] if t.isStreamSourceTable => true - // null - case _ => false - } + val streamingTable = table.unwrap(classOf[TableSourceTable[_]]).isStreaming TableSourceUtil.getRelDataType(tableSource, selectedFields, streamingTable, flinkTypeFactory) } @@ -111,26 +106,18 @@ class FlinkLogicalTableSourceScanConverter override def matches(call: RelOptRuleCall): Boolean = { val scan = call.rel[TableScan](0) - scan.getTable.unwrap(classOf[TableSourceSinkTable[_, _]]) match { - case t: TableSourceSinkTable[_, _] if t.isSourceTable => true - // null - case _ => false - } + scan.getTable.unwrap(classOf[TableSourceTable[_]]) != null } def convert(rel: RelNode): RelNode = { val scan = rel.asInstanceOf[TableScan] val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL) - val tableSource = scan.getTable.unwrap(classOf[TableSourceSinkTable[_, _]]) - .tableSourceTable - .map(_.tableSource) - .getOrElse(throw new TableException("Table source expected.")) new FlinkLogicalTableSourceScan( rel.getCluster, traitSet, scan.getTable, - tableSource, + scan.getTable.unwrap(classOf[TableSourceTable[_]]).tableSource, None ) } diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/BatchTableSourceScanRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/BatchTableSourceScanRule.scala index cee6eef769f670..8ce97ed0e1708f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/BatchTableSourceScanRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/dataSet/BatchTableSourceScanRule.scala @@ -25,7 +25,7 @@ import org.apache.calcite.rel.core.TableScan import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.BatchTableSourceScan import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableSourceScan -import org.apache.flink.table.plan.schema.TableSourceSinkTable +import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.sources.BatchTableSource class BatchTableSourceScanRule @@ -38,11 +38,9 @@ class BatchTableSourceScanRule /** Rule must only match if TableScan targets a [[BatchTableSource]] */ override def matches(call: RelOptRuleCall): Boolean = { val scan: TableScan = call.rel(0).asInstanceOf[TableScan] - scan.getTable.unwrap(classOf[TableSourceSinkTable[_, _]]) match { - case t: TableSourceSinkTable[_, _] if t.isBatchSourceTable => true - // null - case _ => false - } + + val sourceTable = scan.getTable.unwrap(classOf[TableSourceTable[_]]) + sourceTable != null && !sourceTable.isStreaming } def convert(rel: RelNode): RelNode = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/StreamTableSourceScanRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/StreamTableSourceScanRule.scala index e99118f7a9b3a5..a23830ca3309ab 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/StreamTableSourceScanRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/StreamTableSourceScanRule.scala @@ -25,7 +25,7 @@ import org.apache.calcite.rel.core.TableScan import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.datastream.StreamTableSourceScan import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableSourceScan -import org.apache.flink.table.plan.schema.TableSourceSinkTable +import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.sources.StreamTableSource class StreamTableSourceScanRule @@ -39,11 +39,9 @@ class StreamTableSourceScanRule /** Rule must only match if TableScan targets a [[StreamTableSource]] */ override def matches(call: RelOptRuleCall): Boolean = { val scan: TableScan = call.rel(0).asInstanceOf[TableScan] - scan.getTable.unwrap(classOf[TableSourceSinkTable[_, _]]) match { - case t: TableSourceSinkTable[_, _] if t.isStreamSourceTable => true - // null - case _ => false - } + + val sourceTable = scan.getTable.unwrap(classOf[TableSourceTable[_]]) + sourceTable != null && sourceTable.isStreaming } def convert(rel: RelNode): RelNode = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala deleted file mode 100644 index 75ce3dad166069..00000000000000 --- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.schema - -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.calcite.schema.Statistic -import org.apache.calcite.schema.impl.AbstractTable -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.sinks.TableSink - -/** Class which implements the logic to convert a [[TableSink]] to Calcite Table */ -class TableSinkTable[T]( - val tableSink: TableSink[T], - val statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) { - - /** Returns the row type of the table with this tableSink. - * - * @param typeFactory Type factory with which to create the type - * @return Row type - */ - def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] - flinkTypeFactory.buildLogicalRowType(tableSink.getFieldNames, tableSink.getFieldTypes) - } - - /** - * Returns statistics of current table - * - * @return statistics of current table - */ - def getStatistic: Statistic = statistic -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceSinkTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceSinkTable.scala deleted file mode 100644 index 923d82ce84e552..00000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceSinkTable.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.plan.schema - -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} -import org.apache.calcite.schema.Statistic -import org.apache.calcite.schema.impl.AbstractTable -import org.apache.flink.table.api.TableException - -/** - * Wrapper for both a [[TableSourceTable]] and [[TableSinkTable]] under a common name. - * - * @param tableSourceTable table source table (if available) - * @param tableSinkTable table sink table (if available) - * @tparam T1 type of the table source table - * @tparam T2 type of the table sink table - */ -class TableSourceSinkTable[T1, T2]( - val tableSourceTable: Option[TableSourceTable[T1]], - val tableSinkTable: Option[TableSinkTable[T2]]) - extends AbstractTable { - - // In the streaming case, the table schema of source and sink can differ because of extra - // rowtime/proctime fields. We will always return the source table schema if tableSourceTable - // is not None, otherwise return the sink table schema. We move the Calcite validation logic of - // the sink table schema into Flink. This allows us to have different schemas as source and sink - // of the same table. - override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { - tableSourceTable.map(_.getRowType(typeFactory)) - .orElse(tableSinkTable.map(_.getRowType(typeFactory))) - .getOrElse(throw new TableException("Unable to get row type of table source sink table.")) - } - - override def getStatistic: Statistic = { - tableSourceTable.map(_.getStatistic) - .orElse(tableSinkTable.map(_.getStatistic)) - .getOrElse(throw new TableException("Unable to get statistics of table source sink table.")) - } - - def isSourceTable: Boolean = tableSourceTable.isDefined - - def isStreamSourceTable: Boolean = tableSourceTable match { - case Some(_: StreamTableSourceTable[_]) => true - case _ => false - } - - def isBatchSourceTable: Boolean = tableSourceTable match { - case Some(_: BatchTableSourceTable[_]) => true - case _ => false - } -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala index 26224b61c03deb..9dab24928f939c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala @@ -20,27 +20,35 @@ package org.apache.flink.table.plan.schema import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} import org.apache.calcite.schema.Statistic +import org.apache.calcite.schema.impl.AbstractTable +import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.stats.FlinkStatistic -import org.apache.flink.table.sources.TableSource +import org.apache.flink.table.sources.{TableSource, TableSourceUtil} -/** Abstract class which define the interfaces required to convert a [[TableSource]] to - * a Calcite Table */ -abstract class TableSourceTable[T]( +/** + * Abstract class which define the interfaces required to convert a [[TableSource]] to + * a Calcite Table. + */ +class TableSourceTable[T]( val tableSource: TableSource[T], - val statistic: FlinkStatistic) { + val isStreaming: Boolean, + val statistic: FlinkStatistic) + extends AbstractTable { - /** Returns the row type of the table with this tableSource. 
- * - * @param typeFactory Type factory with which to create the type - * @return Row type - */ - def getRowType(typeFactory: RelDataTypeFactory): RelDataType + TableSourceUtil.validateTableSource(tableSource) /** * Returns statistics of current table * * @return statistics of current table */ - def getStatistic: Statistic = statistic + override def getStatistic: Statistic = statistic + def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { + TableSourceUtil.getRelDataType( + tableSource, + None, + isStreaming, + typeFactory.asInstanceOf[FlinkTypeFactory]) + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java index eb247a418ec4fc..8787e94d13f17c 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java @@ -22,20 +22,12 @@ import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.TableSchema; -import org.apache.flink.table.calcite.FlinkTypeFactory; -import org.apache.flink.table.calcite.FlinkTypeSystem; -import org.apache.flink.table.plan.schema.StreamTableSourceTable; -import org.apache.flink.table.plan.schema.TableSourceSinkTable; -import org.apache.flink.table.plan.stats.FlinkStatistic; import org.apache.flink.table.sources.StreamTableSource; import java.util.HashMap; import java.util.Map; import java.util.Objects; -import scala.Option; -import scala.Some; - import static org.apache.flink.table.descriptors.ConnectorDescriptorValidator.CONNECTOR_TYPE; /** @@ -273,33 +265,28 @@ public int hashCode() { } } - private static class TestTable extends CalciteCatalogTable { - + private static class TestTable extends ConnectorCatalogTable { private final String fullyQualifiedPath; - private static final StreamTableSourceTable tableSourceTable = new StreamTableSourceTable<>( - new StreamTableSource() { - @Override - public DataStream getDataStream(StreamExecutionEnvironment execEnv) { - return null; - } + private static final StreamTableSource tableSource = new StreamTableSource() { + @Override + public DataStream getDataStream(StreamExecutionEnvironment execEnv) { + return null; + } - @Override - public TypeInformation getReturnType() { - return null; - } + @Override + public TypeInformation getReturnType() { + return null; + } - @Override - public TableSchema getTableSchema() { - return new TableSchema(new String[] {}, new TypeInformation[] {}); - } - }, FlinkStatistic.UNKNOWN()); + @Override + public TableSchema getTableSchema() { + return TableSchema.builder().build(); + } + }; private TestTable(String fullyQualifiedPath) { - super(new TableSourceSinkTable<>( - new Some<>(tableSourceTable), - Option.empty() - ), new FlinkTypeFactory(new FlinkTypeSystem())); + super(tableSource, null, tableSource.getTableSchema(), false); this.fullyQualifiedPath = fullyQualifiedPath; } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/PathResolutionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/PathResolutionTest.java index c9fc524cca48eb..f68aa5035c2b75 100644 --- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/PathResolutionTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/PathResolutionTest.java @@ -19,7 +19,6 @@ package org.apache.flink.table.catalog; import org.apache.flink.table.api.java.StreamTableEnvImpl; -import org.apache.flink.table.operations.CatalogTableOperation; import org.apache.flink.table.utils.StreamTableTestUtil; import org.apache.flink.util.Preconditions; @@ -202,7 +201,7 @@ public void testTableApiPathResolution() { testSpec.getDefaultCatalog().ifPresent(catalogManager::setCurrentCatalog); testSpec.getDefaultDatabase().ifPresent(catalogManager::setCurrentDatabase); - CatalogTableOperation tab = catalogManager.resolveTable(lookupPath.toArray(new String[0])).get(); + CatalogManager.ResolvedTable tab = catalogManager.resolveTable(lookupPath.toArray(new String[0])).get(); assertThat(tab.getTablePath(), CoreMatchers.equalTo(testSpec.getExpectedPath())); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/runtime/batch/table/JavaTableEnvironmentITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/runtime/batch/table/JavaTableEnvironmentITCase.java index 78d723d263d807..ffe40f67a0b767 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/runtime/batch/table/JavaTableEnvironmentITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/runtime/batch/table/JavaTableEnvironmentITCase.java @@ -31,6 +31,7 @@ import org.apache.flink.table.api.PlannerConfig; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.api.java.BatchTableEnvironment; import org.apache.flink.table.calcite.CalciteConfigBuilder; import org.apache.flink.table.catalog.exceptions.TableAlreadyExistException; @@ -71,6 +72,28 @@ public static Collection parameters() { }); } + @Test(expected = ValidationException.class) + public void testIllegalEmptyName() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env, config()); + + DataSet> ds = CollectionDataSets.get3TupleDataSet(env); + Table t = tableEnv.fromDataSet(ds); + // Must fail. Table is empty + tableEnv.registerTable("", t); + } + + @Test(expected = ValidationException.class) + public void testIllegalWhitespaceOnlyName() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env, config()); + + DataSet> ds = CollectionDataSets.get3TupleDataSet(env); + Table t = tableEnv.fromDataSet(ds); + // Must fail. Table is empty + tableEnv.registerTable(" ", t); + } + @Test public void testSimpleRegister() throws Exception { final String tableName = "MyTable"; @@ -156,17 +179,6 @@ public void testTableRegister() throws Exception { compareResultAsText(results, expected); } - @Test(expected = TableException.class) - public void testIllegalName() throws Exception { - ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); - BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env, config()); - - DataSet> ds = CollectionDataSets.get3TupleDataSet(env); - Table t = tableEnv.fromDataSet(ds); - // Must fail. Table name matches internal name pattern. 
- tableEnv.registerTable("_DataSetTable_42", t); - } - @Test(expected = TableException.class) public void testRegisterTableFromOtherEnv() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/catalog/ExternalCatalogSchemaTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/catalog/ExternalCatalogSchemaTest.scala index cec23b09dcfc2c..c49cd9a8d4bc08 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/catalog/ExternalCatalogSchemaTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/catalog/ExternalCatalogSchemaTest.scala @@ -27,7 +27,7 @@ import org.apache.calcite.prepare.CalciteCatalogReader import org.apache.calcite.schema.SchemaPlus import org.apache.calcite.sql.validate.SqlMonikerType import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem} -import org.apache.flink.table.plan.schema.{TableSourceSinkTable, TableSourceTable} +import org.apache.flink.table.plan.schema.TableSourceTable import org.apache.flink.table.runtime.utils.CommonTestData import org.apache.flink.table.sources.CsvTableSource import org.apache.flink.table.utils.TableTestBase @@ -77,13 +77,8 @@ class ExternalCatalogSchemaTest extends TableTestBase { def testGetTable(): Unit = { val relOptTable = calciteCatalogReader.getTable(Lists.newArrayList(schemaName, db, tb)) assertNotNull(relOptTable) - val tableSourceSinkTable = relOptTable.unwrap(classOf[TableSourceSinkTable[_, _]]) - tableSourceSinkTable.tableSourceTable match { - case Some(tst: TableSourceTable[_]) => - assertTrue(tst.tableSource.isInstanceOf[CsvTableSource]) - case _ => - fail("unexpected table type!") - } + val tableSourceTable = relOptTable.unwrap(classOf[TableSourceTable[_]]) + assertTrue(tableSourceTable.tableSource.isInstanceOf[CsvTableSource]) } @Test From d952e710351e5a7c872f8b5379b1e8bab78ac9f6 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Wed, 29 May 2019 23:07:48 +0800 Subject: [PATCH 66/92] [hotfix][tests] Refactor the creation of InputGate for StreamNetworkBenchmarkEnvironment --- .../StreamNetworkBenchmarkEnvironment.java | 79 +++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java index 4b28961231f59f..f00a6730510485 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java @@ -26,7 +26,6 @@ import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; -import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.runtime.io.network.ConnectionID; @@ -77,7 +76,6 @@ public class StreamNetworkBenchmarkEnvironment { protected final JobID jobId = new JobID(); protected final IntermediateDataSetID dataSetID = new IntermediateDataSetID(); - 
protected final ExecutionAttemptID executionAttemptID = new ExecutionAttemptID(); protected NetworkEnvironment senderEnv; protected NetworkEnvironment receiverEnv; @@ -91,6 +89,8 @@ public class StreamNetworkBenchmarkEnvironment { private int dataPort; + private SingleInputGateFactory gateFactory; + public void setUp( int writers, int channels, @@ -152,6 +152,13 @@ public void setUp( receiverEnv.start(); } + gateFactory = new SingleInputGateFactory( + receiverEnv.getConfiguration(), + receiverEnv.getConnectionManager(), + receiverEnv.getResultPartitionManager(), + new TaskEventDispatcher(), + receiverEnv.getNetworkBufferPool()); + generatePartitionIds(); } @@ -167,12 +174,7 @@ public SerializingLongReceiver createReceiver() throws Exception { LOCAL_ADDRESS, dataPort); - InputGate receiverGate = createInputGate( - dataSetID, - executionAttemptID, - senderLocation, - receiverEnv, - channels); + InputGate receiverGate = createInputGate(senderLocation); SerializingLongReceiver receiver = new SerializingLongReceiver(receiverGate, channels * partitionIds.length); @@ -228,40 +230,19 @@ protected ResultPartitionWriter createResultPartition( return resultPartition; } - private InputGate createInputGate( - IntermediateDataSetID dataSetID, - ExecutionAttemptID executionAttemptID, - final TaskManagerLocation senderLocation, - NetworkEnvironment environment, - final int channels) throws IOException { - + private InputGate createInputGate(TaskManagerLocation senderLocation) throws IOException { InputGate[] gates = new InputGate[channels]; for (int channel = 0; channel < channels; ++channel) { - int finalChannel = channel; - InputChannelDeploymentDescriptor[] channelDescriptors = Arrays.stream(partitionIds) - .map(partitionId -> new InputChannelDeploymentDescriptor( - partitionId, - localMode ? ResultPartitionLocation.createLocal() : ResultPartitionLocation.createRemote(new ConnectionID(senderLocation, finalChannel)))) - .toArray(InputChannelDeploymentDescriptor[]::new); - - final InputGateDeploymentDescriptor gateDescriptor = new InputGateDeploymentDescriptor( - dataSetID, - ResultPartitionType.PIPELINED_BOUNDED, - channel, - channelDescriptors); - - SingleInputGate gate = new SingleInputGateFactory( - environment.getConfiguration(), - environment.getConnectionManager(), - environment.getResultPartitionManager(), - new TaskEventDispatcher(), - environment.getNetworkBufferPool()) - .create( - "receiving task[" + channel + "]", - gateDescriptor, - SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, - InputChannelTestUtils.newUnregisteredInputChannelMetrics(), - new SimpleCounter()); + final InputGateDeploymentDescriptor gateDescriptor = createInputGateDeploymentDescriptor( + senderLocation, + channel); + + final SingleInputGate gate = gateFactory.create( + "receiving task[" + channel + "]", + gateDescriptor, + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, + InputChannelTestUtils.newUnregisteredInputChannelMetrics(), + new SimpleCounter()); gate.setup(); gates[channel] = gate; @@ -273,4 +254,22 @@ private InputGate createInputGate( return gates[0]; } } + + private InputGateDeploymentDescriptor createInputGateDeploymentDescriptor( + TaskManagerLocation senderLocation, + int consumedSubpartitionIndex) { + + final InputChannelDeploymentDescriptor[] channelDescriptors = Arrays.stream(partitionIds) + .map(partitionId -> new InputChannelDeploymentDescriptor( + partitionId, + localMode ? 
ResultPartitionLocation.createLocal() : ResultPartitionLocation.createRemote( + new ConnectionID(senderLocation, consumedSubpartitionIndex)))) + .toArray(InputChannelDeploymentDescriptor[]::new); + + return new InputGateDeploymentDescriptor( + dataSetID, + ResultPartitionType.PIPELINED_BOUNDED, + consumedSubpartitionIndex, + channelDescriptors); + } } From 3351ca11a366713328a6d0cb0e87100cb8c24200 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Thu, 30 May 2019 23:15:27 +0800 Subject: [PATCH 67/92] [FLINK-12201][network,metrics] Introduce InputGateWithMetrics in Task to increment numBytesIn metric Incrementing of numBytesIn metric in SingleInputGate does not depend on shuffle service and can be moved out of network internals into Task. Task could wrap InputGate provided by ShuffleService with InputGateWithMetrics which would increment numBytesIn metric. --- .../io/network/NetworkEnvironment.java | 7 +- .../partition/consumer/BufferOrEvent.java | 21 +++- .../partition/consumer/SingleInputGate.java | 9 +- .../consumer/SingleInputGateFactory.java | 5 +- .../taskmanager/InputGateWithMetrics.java | 107 ++++++++++++++++++ .../flink/runtime/taskmanager/Task.java | 14 ++- .../partition/InputGateFairnessTest.java | 2 - .../consumer/SingleInputGateBuilder.java | 5 - .../consumer/SingleInputGateTest.java | 7 +- .../streaming/runtime/io/BufferSpiller.java | 4 +- .../StreamNetworkBenchmarkEnvironment.java | 22 +++- 11 files changed, 156 insertions(+), 47 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index b90e9038eae2fa..d381203218935c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -20,7 +20,6 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.JobID; -import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.Gauge; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -248,8 +247,7 @@ public SingleInputGate[] createInputGates( Collection inputGateDeploymentDescriptors, MetricGroup parentGroup, MetricGroup inputGroup, - MetricGroup buffersGroup, - Counter numBytesInCounter) { + MetricGroup buffersGroup) { synchronized (lock) { Preconditions.checkState(!isShutdown, "The NetworkEnvironment has already been shut down."); @@ -261,8 +259,7 @@ public SingleInputGate[] createInputGates( taskName, igdd, partitionProducerStateProvider, - inputChannelMetrics, - numBytesInCounter); + inputChannelMetrics); InputGateID id = new InputGateID(igdd.getConsumedResultId(), executionId); inputGatesById.put(id, inputGate); inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id)); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java index d1da4388c1b2be..f3ba122c5a6101 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java @@ -18,6 +18,7 @@ package 
org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.buffer.Buffer; @@ -43,26 +44,32 @@ public class BufferOrEvent { private int channelIndex; - BufferOrEvent(Buffer buffer, int channelIndex, boolean moreAvailable) { + private final int size; + + public BufferOrEvent(Buffer buffer, int channelIndex, boolean moreAvailable) { this.buffer = checkNotNull(buffer); this.event = null; this.channelIndex = channelIndex; this.moreAvailable = moreAvailable; + this.size = buffer.getSize(); } - BufferOrEvent(AbstractEvent event, int channelIndex, boolean moreAvailable) { + public BufferOrEvent(AbstractEvent event, int channelIndex, boolean moreAvailable, int size) { this.buffer = null; this.event = checkNotNull(event); this.channelIndex = channelIndex; this.moreAvailable = moreAvailable; + this.size = size; } + @VisibleForTesting public BufferOrEvent(Buffer buffer, int channelIndex) { this(buffer, channelIndex, true); } + @VisibleForTesting public BufferOrEvent(AbstractEvent event, int channelIndex) { - this(event, channelIndex, true); + this(event, channelIndex, true, 0); } public boolean isBuffer() { @@ -96,11 +103,15 @@ boolean moreAvailable() { @Override public String toString() { - return String.format("BufferOrEvent [%s, channelIndex = %d]", - isBuffer() ? buffer : event, channelIndex); + return String.format("BufferOrEvent [%s, channelIndex = %d, size = %d]", + isBuffer() ? buffer : event, channelIndex, size); } public void setMoreAvailable(boolean moreAvailable) { this.moreAvailable = moreAvailable; } + + public int getSize() { + return size; + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 5e5a722ffd51f2..4e718d713b0897 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; import org.apache.flink.runtime.event.AbstractEvent; @@ -165,8 +164,6 @@ public class SingleInputGate extends InputGate { /** A timer to retrigger local partition requests. Only initialized if actually needed. 
*/ private Timer retriggerLocalRequestTimer; - private final Counter numBytesIn; - private final SupplierWithException bufferPoolFactory; private final CompletableFuture closeFuture; @@ -178,7 +175,6 @@ public SingleInputGate( int consumedSubpartitionIndex, int numberOfInputChannels, PartitionProducerStateProvider partitionProducerStateProvider, - Counter numBytesIn, boolean isCreditBased, SupplierWithException bufferPoolFactory) { @@ -200,8 +196,6 @@ public SingleInputGate( this.partitionProducerStateProvider = checkNotNull(partitionProducerStateProvider); - this.numBytesIn = checkNotNull(numBytesIn); - this.isCreditBased = isCreditBased; this.closeFuture = new CompletableFuture<>(); @@ -566,7 +560,6 @@ private BufferOrEvent transformToBufferOrEvent( Buffer buffer, boolean moreAvailable, InputChannel currentChannel) throws IOException, InterruptedException { - numBytesIn.inc(buffer.getSizeUnsafe()); if (buffer.isBuffer()) { return new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable); } @@ -596,7 +589,7 @@ private BufferOrEvent transformToBufferOrEvent( currentChannel.releaseAllResources(); } - return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable); + return new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable, buffer.getSize()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java index fcc36659edd78f..6caf0174fac706 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.metrics.Counter; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; @@ -99,8 +98,7 @@ public SingleInputGate create( @Nonnull String owningTaskName, @Nonnull InputGateDeploymentDescriptor igdd, @Nonnull PartitionProducerStateProvider partitionProducerStateProvider, - @Nonnull InputChannelMetrics metrics, - @Nonnull Counter numBytesInCounter) { + @Nonnull InputChannelMetrics metrics) { final IntermediateDataSetID consumedResultId = checkNotNull(igdd.getConsumedResultId()); final ResultPartitionType consumedPartitionType = checkNotNull(igdd.getConsumedPartitionType()); @@ -116,7 +114,6 @@ public SingleInputGate create( consumedSubpartitionIndex, icdd.length, partitionProducerStateProvider, - numBytesInCounter, isCreditBased, createBufferPoolFactory(icdd.length, consumedPartitionType)); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java new file mode 100644 index 00000000000000..e9e303830a287a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import org.apache.flink.metrics.Counter; +import org.apache.flink.runtime.event.TaskEvent; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; + +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * This class wraps {@link InputGate} provided by shuffle service and it is mainly + * used for increasing general input metrics from {@link TaskIOMetricGroup}. + */ +public class InputGateWithMetrics extends InputGate { + + private final InputGate inputGate; + + private final Counter numBytesIn; + + public InputGateWithMetrics(InputGate inputGate, Counter numBytesIn) { + this.inputGate = checkNotNull(inputGate); + this.numBytesIn = checkNotNull(numBytesIn); + } + + @Override + public CompletableFuture isAvailable() { + return inputGate.isAvailable(); + } + + @Override + public int getNumberOfInputChannels() { + return inputGate.getNumberOfInputChannels(); + } + + @Override + public String getOwningTaskName() { + return inputGate.getOwningTaskName(); + } + + @Override + public boolean isFinished() { + return inputGate.isFinished(); + } + + @Override + public void setup() throws IOException { + inputGate.setup(); + } + + @Override + public void requestPartitions() throws IOException, InterruptedException { + inputGate.requestPartitions(); + } + + @Override + public Optional getNextBufferOrEvent() throws IOException, InterruptedException { + return updateMetrics(inputGate.getNextBufferOrEvent()); + } + + @Override + public Optional pollNextBufferOrEvent() throws IOException, InterruptedException { + return updateMetrics(inputGate.pollNextBufferOrEvent()); + } + + @Override + public void sendTaskEvent(TaskEvent event) throws IOException { + inputGate.sendTaskEvent(event); + } + + @Override + public int getPageSize() { + return inputGate.getPageSize(); + } + + @Override + public void close() throws Exception { + inputGate.close(); + } + + private Optional updateMetrics(Optional bufferOrEvent) { + bufferOrEvent.ifPresent(b -> numBytesIn.inc(b.getSize())); + return bufferOrEvent; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 57d8c6bdc20a32..02ef419e746ac9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -56,7 +56,6 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; -import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; @@ -192,7 +191,7 @@ public class Task implements Runnable, TaskActions, PartitionProducerStateProvid private final ResultPartitionWriter[] producedPartitions; - private final SingleInputGate[] inputGates; + private final InputGate[] inputGates; /** Connection to the task manager. */ private final TaskManagerActions taskManagerActions; @@ -380,15 +379,20 @@ public Task( buffersGroup); // consumed intermediate result partitions - this.inputGates = networkEnvironment.createInputGates( + InputGate[] gates = networkEnvironment.createInputGates( taskNameWithSubtaskAndId, executionId, this, inputGateDeploymentDescriptors, metrics.getIOMetricGroup(), inputGroup, - buffersGroup, - metrics.getIOMetricGroup().getNumBytesInCounter()); + buffersGroup); + + this.inputGates = new InputGate[gates.length]; + int counter = 0; + for (InputGate gate : gates) { + inputGates[counter++] = new InputGateWithMetrics(gate, metrics.getIOMetricGroup().getNumBytesInCounter()); + } invokableHasBeenCanceled = new AtomicBoolean(false); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index d670d01743e09f..10221c5e86baf8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition; -import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; @@ -331,7 +330,6 @@ public FairnessVerifyingInputGate( consumedSubpartitionIndex, numberOfInputChannels, SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, - new SimpleCounter(), isCreditBased, STUB_BUFFER_POOL_FACTORY); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java index 51eba307403441..8c04d5f60b85dc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java @@ -18,8 +18,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.metrics.Counter; -import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; @@ -51,8 +49,6 @@ public class SingleInputGateBuilder { private final PartitionProducerStateProvider partitionProducerStateProvider = NO_OP_PRODUCER_CHECKER; - private final Counter numBytesInCounter = new 
SimpleCounter(); - private boolean isCreditBased = true; private SupplierWithException bufferPoolFactory = () -> { @@ -99,7 +95,6 @@ public SingleInputGate build() { consumedSubpartitionIndex, numberOfChannels, partitionProducerStateProvider, - numBytesInCounter, isCreditBased, bufferPoolFactory); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index c87f99f4e932ff..5bbb26a32944a2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.core.memory.MemorySegmentFactory; -import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -347,8 +346,7 @@ public void testRequestBackoffConfiguration() throws Exception { "TestTask", gateDesc, SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, - InputChannelTestUtils.newUnregisteredInputChannelMetrics(), - new SimpleCounter()); + InputChannelTestUtils.newUnregisteredInputChannelMetrics()); try { assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType()); @@ -599,8 +597,7 @@ private static Map createInputGateWithLocalChannel Arrays.asList(gateDescs), new UnregisteredMetricsGroup(), new UnregisteredMetricsGroup(), - new UnregisteredMetricsGroup(), - new SimpleCounter()); + new UnregisteredMetricsGroup()); Map inputGatesById = new HashMap<>(); for (int i = 0; i < numberOfGates; i++) { inputGatesById.put(new InputGateID(ids[i], consumerID), gates[i]); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BufferSpiller.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BufferSpiller.java index 700043092dfc42..5a7c4969bc14a4 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BufferSpiller.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BufferSpiller.java @@ -374,7 +374,7 @@ public BufferOrEvent getNext() throws IOException { Buffer buf = new NetworkBuffer(seg, FreeingBufferRecycler.INSTANCE); buf.setSize(length); - return new BufferOrEvent(buf, channel); + return new BufferOrEvent(buf, channel, true); } else { // deserialize event @@ -399,7 +399,7 @@ public BufferOrEvent getNext() throws IOException { AbstractEvent evt = EventSerializer.fromSerializedEvent(buffer, getClass().getClassLoader()); buffer.limit(oldLimit); - return new BufferOrEvent(evt, channel); + return new BufferOrEvent(evt, channel, true, length); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java index f00a6730510485..7a9c863d088b7b 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java +++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java @@ -47,6 +47,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory; import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.taskmanager.InputGateWithMetrics; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.ConfigurationParserUtils; @@ -237,12 +238,7 @@ private InputGate createInputGate(TaskManagerLocation senderLocation) throws IOE senderLocation, channel); - final SingleInputGate gate = gateFactory.create( - "receiving task[" + channel + "]", - gateDescriptor, - SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, - InputChannelTestUtils.newUnregisteredInputChannelMetrics(), - new SimpleCounter()); + final InputGate gate = createInputGateWithMetrics(gateFactory, gateDescriptor, channel); gate.setup(); gates[channel] = gate; @@ -272,4 +268,18 @@ private InputGateDeploymentDescriptor createInputGateDeploymentDescriptor( consumedSubpartitionIndex, channelDescriptors); } + + private InputGate createInputGateWithMetrics( + SingleInputGateFactory gateFactory, + InputGateDeploymentDescriptor gateDescriptor, + int channelIndex) { + + final SingleInputGate singleGate = gateFactory.create( + "receiving task[" + channelIndex + "]", + gateDescriptor, + SingleInputGateBuilder.NO_OP_PRODUCER_CHECKER, + InputChannelTestUtils.newUnregisteredInputChannelMetrics()); + + return new InputGateWithMetrics(singleGate, new SimpleCounter()); + } } From a36ec7fcecf909a1d6a8c2b85a125d0ccc8b0934 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Fri, 31 May 2019 16:47:38 +0800 Subject: [PATCH 68/92] [hotfix][network] Drop redundant getSizeUnsafe from Buffer interface The implementations of getSize and getSizeUnsafe are exactly the same now, which do not need the synchronized way. So We could remove the getSizeUnsafe to make it clean and clear. --- .../flink/runtime/io/network/buffer/Buffer.java | 11 ----------- .../runtime/io/network/buffer/NetworkBuffer.java | 5 ----- .../network/buffer/ReadOnlySlicedNetworkBuffer.java | 5 ----- .../network/partition/consumer/LocalInputChannel.java | 2 +- .../partition/consumer/RemoteInputChannel.java | 2 +- .../runtime/io/network/buffer/NetworkBufferTest.java | 6 ------ .../io/network/buffer/ReadOnlySlicedBufferTest.java | 10 ++-------- 7 files changed, 4 insertions(+), 37 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/Buffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/Buffer.java index 96b18ee354a70f..8e5d85ff59ee29 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/Buffer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/Buffer.java @@ -158,17 +158,6 @@ public interface Buffer { */ void setReaderIndex(int readerIndex) throws IndexOutOfBoundsException; - /** - * Returns the size of the written data, i.e. the writer index, of this buffer in an - * non-synchronized fashion. - * - *

This is where writable bytes start in the backing memory segment. - * - * @return writer index (from 0 (inclusive) to the size of the backing {@link MemorySegment} - * (inclusive)) - */ - int getSizeUnsafe(); - /** * Returns the size of the written data, i.e. the writer index, of this buffer. * diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBuffer.java index 05b75821e1b6c3..56afbb888bdf40 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBuffer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBuffer.java @@ -290,11 +290,6 @@ public void setReaderIndex(int readerIndex) throws IndexOutOfBoundsException { readerIndex(readerIndex); } - @Override - public int getSizeUnsafe() { - return writerIndex(); - } - @Override public int getSize() { return writerIndex(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedNetworkBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedNetworkBuffer.java index 00e11547cee47a..dd421b703aa1a8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedNetworkBuffer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedNetworkBuffer.java @@ -152,11 +152,6 @@ public void setReaderIndex(int readerIndex) throws IndexOutOfBoundsException { readerIndex(readerIndex); } - @Override - public int getSizeUnsafe() { - return writerIndex(); - } - @Override public int getSize() { return writerIndex(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index feb6b72f65197d..3a543104d24b63 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -193,7 +193,7 @@ Optional getNextBuffer() throws IOException, InterruptedE } } - numBytesIn.inc(next.buffer().getSizeUnsafe()); + numBytesIn.inc(next.buffer().getSize()); numBuffersIn.inc(); return Optional.of(new BufferAndAvailability(next.buffer(), next.isMoreAvailable(), next.buffersInBacklog())); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 50bf1d07945dd3..a42a93bee0f383 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -201,7 +201,7 @@ Optional getNextBuffer() throws IOException { moreAvailable = !receivedBuffers.isEmpty(); } - numBytesIn.inc(next.getSizeUnsafe()); + numBytesIn.inc(next.getSize()); numBuffersIn.inc(); return Optional.of(new BufferAndAvailability(next, moreAvailable, getSenderBacklog())); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferTest.java index 
47615d9ea3ec1a..1cc4fc32201871 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferTest.java @@ -216,7 +216,6 @@ private static void testCreateSlice1(boolean isBuffer) { assertEquals(0, slice.getReaderIndex()); assertEquals(10, slice.getSize()); - assertEquals(10, slice.getSizeUnsafe()); assertSame(buffer, slice.unwrap().unwrap()); // slice indices should be independent: @@ -224,7 +223,6 @@ private static void testCreateSlice1(boolean isBuffer) { buffer.setReaderIndex(2); assertEquals(0, slice.getReaderIndex()); assertEquals(10, slice.getSize()); - assertEquals(10, slice.getSizeUnsafe()); } @Test @@ -244,7 +242,6 @@ private static void testCreateSlice2(boolean isBuffer) { assertEquals(0, slice.getReaderIndex()); assertEquals(10, slice.getSize()); - assertEquals(10, slice.getSizeUnsafe()); assertSame(buffer, slice.unwrap().unwrap()); // slice indices should be independent: @@ -252,7 +249,6 @@ private static void testCreateSlice2(boolean isBuffer) { buffer.setReaderIndex(2); assertEquals(0, slice.getReaderIndex()); assertEquals(10, slice.getSize()); - assertEquals(10, slice.getSizeUnsafe()); } @Test @@ -312,13 +308,11 @@ private static void testSetGetSize(boolean isBuffer) { NetworkBuffer buffer = newBuffer(1024, 1024, isBuffer); assertEquals(0, buffer.getSize()); // initially 0 - assertEquals(0, buffer.getSizeUnsafe()); assertEquals(buffer.writerIndex(), buffer.getSize()); assertEquals(0, buffer.readerIndex()); // initially 0 buffer.setSize(10); assertEquals(10, buffer.getSize()); - assertEquals(10, buffer.getSizeUnsafe()); assertEquals(buffer.writerIndex(), buffer.getSize()); assertEquals(0, buffer.readerIndex()); // independent } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedBufferTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedBufferTest.java index d5814ee20c0616..444fdc73a81a79 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedBufferTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/ReadOnlySlicedBufferTest.java @@ -233,8 +233,7 @@ private void testGetSetReaderIndex(ReadOnlySlicedNetworkBuffer slice) { /** * Tests the independence of the writer index via * {@link ReadOnlySlicedNetworkBuffer#setSize(int)}, - * {@link ReadOnlySlicedNetworkBuffer#getSize()}, and - * {@link ReadOnlySlicedNetworkBuffer#getSizeUnsafe()}. + * {@link ReadOnlySlicedNetworkBuffer#getSize()}. */ @Test public void testGetSetSize1() { @@ -244,8 +243,7 @@ public void testGetSetSize1() { /** * Tests the independence of the writer index via * {@link ReadOnlySlicedNetworkBuffer#setSize(int)}, - * {@link ReadOnlySlicedNetworkBuffer#getSize()}, and - * {@link ReadOnlySlicedNetworkBuffer#getSizeUnsafe()}. + * {@link ReadOnlySlicedNetworkBuffer#getSize()}. 
*/ @Test public void testGetSetSize2() { @@ -254,14 +252,10 @@ public void testGetSetSize2() { private void testGetSetSize(ReadOnlySlicedNetworkBuffer slice, int sliceSize) { assertEquals(DATA_SIZE, buffer.getSize()); - assertEquals(DATA_SIZE, buffer.getSizeUnsafe()); assertEquals(sliceSize, slice.getSize()); - assertEquals(sliceSize, slice.getSizeUnsafe()); buffer.setSize(DATA_SIZE + 1); assertEquals(DATA_SIZE + 1, buffer.getSize()); - assertEquals(DATA_SIZE + 1, buffer.getSizeUnsafe()); assertEquals(sliceSize, slice.getSize()); - assertEquals(sliceSize, slice.getSizeUnsafe()); } @Test From 70fa80e3862b367be22b593db685f9898a2838ef Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 3 Jan 2019 23:14:09 +0800 Subject: [PATCH 69/92] [FLINK-11249][kafka] Use custom serializers for NextTransactionalIdHint in 0.11 and universal FlinkKafkaProducer --- .../kafka/FlinkKafkaProducer011.java | 132 ++++++++++++++++- ...NextTransactionalIdHintSerializerTest.java | 56 ++++++++ .../connectors/kafka/FlinkKafkaProducer.java | 136 +++++++++++++++++- ...NextTransactionalIdHintSerializerTest.java | 56 ++++++++ 4 files changed, 376 insertions(+), 4 deletions(-) create mode 100644 flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java create mode 100644 flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java diff --git a/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java b/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java index 8b3cccddacdbc7..277f89124b6626 100644 --- a/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java +++ b/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java @@ -174,10 +174,16 @@ public enum Semantic { /** * Descriptor of the transactional IDs list. + * Note: This state is serialized by Kryo Serializer and it has compatibility problem that will be removed later. + * Please use NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2. */ + @Deprecated private static final ListStateDescriptor NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR = new ListStateDescriptor<>("next-transactional-id-hint", TypeInformation.of(NextTransactionalIdHint.class)); + private static final ListStateDescriptor NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2 = + new ListStateDescriptor<>("next-transactional-id-hint-v2", new NextTransactionalIdHintSerializer()); + /** * State for nextTransactionalIdHint. 
*/ @@ -816,8 +822,8 @@ public void initializeState(FunctionInitializationContext context) throws Except semantic = Semantic.NONE; } - nextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState( - NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR); + migrateNextTransactionalIdHindState(context); + transactionalIdsGenerator = new TransactionalIdsGenerator( getRuntimeContext().getTaskName() + "-" + ((StreamingRuntimeContext) getRuntimeContext()).getOperatorUniqueID(), getRuntimeContext().getIndexOfThisSubtask(), @@ -1002,6 +1008,19 @@ private void readObject(java.io.ObjectInputStream in) throws IOException, ClassN in.defaultReadObject(); } + private void migrateNextTransactionalIdHindState(FunctionInitializationContext context) throws Exception { + ListState oldNextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState( + NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR); + nextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState(NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2); + + ArrayList oldTransactionalIdHints = Lists.newArrayList(oldNextTransactionalIdHintState.get()); + if (!oldTransactionalIdHints.isEmpty()) { + nextTransactionalIdHintState.addAll(oldTransactionalIdHints); + //clear old state + oldNextTransactionalIdHintState.clear(); + } + } + private static Properties getPropertiesFromBrokerList(String brokerList) { String[] elements = brokerList.split(","); @@ -1355,5 +1374,114 @@ public NextTransactionalIdHint(int parallelism, long nextFreeTransactionalId) { this.lastParallelism = parallelism; this.nextFreeTransactionalId = nextFreeTransactionalId; } + + @Override + public String toString() { + return "NextTransactionalIdHint[" + + "lastParallelism=" + lastParallelism + + ", nextFreeTransactionalId=" + nextFreeTransactionalId + + ']'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + NextTransactionalIdHint that = (NextTransactionalIdHint) o; + + if (lastParallelism != that.lastParallelism) { + return false; + } + return nextFreeTransactionalId == that.nextFreeTransactionalId; + } + + @Override + public int hashCode() { + int result = lastParallelism; + result = 31 * result + (int) (nextFreeTransactionalId ^ (nextFreeTransactionalId >>> 32)); + return result; + } + } + + /** + * {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * {@link NextTransactionalIdHint}. 
+ */ + @VisibleForTesting + @Internal + public static class NextTransactionalIdHintSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public NextTransactionalIdHint createInstance() { + return new NextTransactionalIdHint(); + } + + @Override + public NextTransactionalIdHint copy(NextTransactionalIdHint from) { + return from; + } + + @Override + public NextTransactionalIdHint copy(NextTransactionalIdHint from, NextTransactionalIdHint reuse) { + return from; + } + + @Override + public int getLength() { + return Long.BYTES + Integer.BYTES; + } + + @Override + public void serialize(NextTransactionalIdHint record, DataOutputView target) throws IOException { + target.writeLong(record.nextFreeTransactionalId); + target.writeInt(record.lastParallelism); + } + + @Override + public NextTransactionalIdHint deserialize(DataInputView source) throws IOException { + long nextFreeTransactionalId = source.readLong(); + int lastParallelism = source.readInt(); + return new NextTransactionalIdHint(lastParallelism, nextFreeTransactionalId); + } + + @Override + public NextTransactionalIdHint deserialize(NextTransactionalIdHint reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.writeLong(source.readLong()); + target.writeInt(source.readInt()); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new NextTransactionalIdHintSerializerSnapshot(); + } + + /** + * Serializer configuration snapshot for compatibility and format evolution. + */ + @SuppressWarnings("WeakerAccess") + public static final class NextTransactionalIdHintSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public NextTransactionalIdHintSerializerSnapshot() { + super(NextTransactionalIdHintSerializer::new); + } + } } + } diff --git a/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java b/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java new file mode 100644 index 00000000000000..db9bbc6016f233 --- /dev/null +++ b/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.kafka; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +/** + * A test for the {@link TypeSerializer TypeSerializers} used for + * {@link FlinkKafkaProducer011.NextTransactionalIdHint}. + */ +public class NextTransactionalIdHintSerializerTest extends + SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return new FlinkKafkaProducer011.NextTransactionalIdHintSerializer(); + } + + @Override + protected int getLength() { + return Long.BYTES + Integer.BYTES; + } + + @Override + protected Class getTypeClass() { + return (Class) FlinkKafkaProducer011.NextTransactionalIdHint.class; + } + + @Override + protected FlinkKafkaProducer011.NextTransactionalIdHint[] getTestData() { + return new FlinkKafkaProducer011.NextTransactionalIdHint[] { + new FlinkKafkaProducer011.NextTransactionalIdHint(1, 0L), + new FlinkKafkaProducer011.NextTransactionalIdHint(1, 1L), + new FlinkKafkaProducer011.NextTransactionalIdHint(1, -1L), + new FlinkKafkaProducer011.NextTransactionalIdHint(2, 0L), + new FlinkKafkaProducer011.NextTransactionalIdHint(2, 1L), + new FlinkKafkaProducer011.NextTransactionalIdHint(2, -1L), + }; + } +} diff --git a/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer.java b/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer.java index 9eb2df8ac053c6..3ab6a06b5374be 100644 --- a/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer.java +++ b/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer.java @@ -176,9 +176,15 @@ public enum Semantic { /** * Descriptor of the transactional IDs list. + * Note: This state is serialized by Kryo Serializer and it has compatibility problem that will be removed later. + * Please use NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2. */ + @Deprecated private static final ListStateDescriptor NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR = - new ListStateDescriptor<>("next-transactional-id-hint", TypeInformation.of(FlinkKafkaProducer.NextTransactionalIdHint.class)); + new ListStateDescriptor<>("next-transactional-id-hint", TypeInformation.of(NextTransactionalIdHint.class)); + + private static final ListStateDescriptor NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2 = + new ListStateDescriptor<>("next-transactional-id-hint-v2", new NextTransactionalIdHintSerializer()); /** * State for nextTransactionalIdHint. 
@@ -819,7 +825,12 @@ public void initializeState(FunctionInitializationContext context) throws Except } nextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState( - NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR); + NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2); + + if (context.getOperatorStateStore().getRegisteredStateNames().contains(NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR)) { + migrateNextTransactionalIdHindState(context); + } + transactionalIdsGenerator = new TransactionalIdsGenerator( getRuntimeContext().getTaskName() + "-" + ((StreamingRuntimeContext) getRuntimeContext()).getOperatorUniqueID(), getRuntimeContext().getIndexOfThisSubtask(), @@ -1008,6 +1019,19 @@ private void readObject(java.io.ObjectInputStream in) throws IOException, ClassN in.defaultReadObject(); } + private void migrateNextTransactionalIdHindState(FunctionInitializationContext context) throws Exception { + ListState oldNextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState( + NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR); + nextTransactionalIdHintState = context.getOperatorStateStore().getUnionListState(NEXT_TRANSACTIONAL_ID_HINT_DESCRIPTOR_V2); + + ArrayList oldTransactionalIdHints = Lists.newArrayList(oldNextTransactionalIdHintState.get()); + if (!oldTransactionalIdHints.isEmpty()) { + nextTransactionalIdHintState.addAll(oldTransactionalIdHints); + //clear old state + oldNextTransactionalIdHintState.clear(); + } + } + private static Properties getPropertiesFromBrokerList(String brokerList) { String[] elements = brokerList.split(","); @@ -1363,5 +1387,113 @@ public NextTransactionalIdHint(int parallelism, long nextFreeTransactionalId) { this.lastParallelism = parallelism; this.nextFreeTransactionalId = nextFreeTransactionalId; } + + @Override + public String toString() { + return "NextTransactionalIdHint[" + + "lastParallelism=" + lastParallelism + + ", nextFreeTransactionalId=" + nextFreeTransactionalId + + ']'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + NextTransactionalIdHint that = (NextTransactionalIdHint) o; + + if (lastParallelism != that.lastParallelism) { + return false; + } + return nextFreeTransactionalId == that.nextFreeTransactionalId; + } + + @Override + public int hashCode() { + int result = lastParallelism; + result = 31 * result + (int) (nextFreeTransactionalId ^ (nextFreeTransactionalId >>> 32)); + return result; + } + } + + /** + * {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * {@link FlinkKafkaProducer.NextTransactionalIdHint}. 
+ */ + @VisibleForTesting + @Internal + public static class NextTransactionalIdHintSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + @Override + public boolean isImmutableType() { + return true; + } + + @Override + public NextTransactionalIdHint createInstance() { + return new NextTransactionalIdHint(); + } + + @Override + public NextTransactionalIdHint copy(NextTransactionalIdHint from) { + return from; + } + + @Override + public NextTransactionalIdHint copy(NextTransactionalIdHint from, NextTransactionalIdHint reuse) { + return from; + } + + @Override + public int getLength() { + return Long.BYTES + Integer.BYTES; + } + + @Override + public void serialize(NextTransactionalIdHint record, DataOutputView target) throws IOException { + target.writeLong(record.nextFreeTransactionalId); + target.writeInt(record.lastParallelism); + } + + @Override + public NextTransactionalIdHint deserialize(DataInputView source) throws IOException { + long nextFreeTransactionalId = source.readLong(); + int lastParallelism = source.readInt(); + return new NextTransactionalIdHint(lastParallelism, nextFreeTransactionalId); + } + + @Override + public NextTransactionalIdHint deserialize(NextTransactionalIdHint reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + target.writeLong(source.readLong()); + target.writeInt(source.readInt()); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new NextTransactionalIdHintSerializerSnapshot(); + } + + /** + * Serializer configuration snapshot for compatibility and format evolution. + */ + @SuppressWarnings("WeakerAccess") + public static final class NextTransactionalIdHintSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public NextTransactionalIdHintSerializerSnapshot() { + super(NextTransactionalIdHintSerializer::new); + } + } } } diff --git a/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java new file mode 100644 index 00000000000000..e21afc81422b6d --- /dev/null +++ b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/NextTransactionalIdHintSerializerTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.connectors.kafka; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +/** + * A test for the {@link TypeSerializer TypeSerializers} used for + * {@link FlinkKafkaProducer.NextTransactionalIdHint}. + */ +public class NextTransactionalIdHintSerializerTest extends + SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return new FlinkKafkaProducer.NextTransactionalIdHintSerializer(); + } + + @Override + protected int getLength() { + return Long.BYTES + Integer.BYTES; + } + + @Override + protected Class getTypeClass() { + return (Class) FlinkKafkaProducer.NextTransactionalIdHint.class; + } + + @Override + protected FlinkKafkaProducer.NextTransactionalIdHint[] getTestData() { + return new FlinkKafkaProducer.NextTransactionalIdHint[] { + new FlinkKafkaProducer.NextTransactionalIdHint(1, 0L), + new FlinkKafkaProducer.NextTransactionalIdHint(1, 1L), + new FlinkKafkaProducer.NextTransactionalIdHint(1, -1L), + new FlinkKafkaProducer.NextTransactionalIdHint(2, 0L), + new FlinkKafkaProducer.NextTransactionalIdHint(2, 1L), + new FlinkKafkaProducer.NextTransactionalIdHint(2, -1L), + }; + } +} From 9be40f681d40be448dfdb925684e3d6b0d0635af Mon Sep 17 00:00:00 2001 From: Piotr Nowojski Date: Tue, 26 Mar 2019 15:30:14 +0100 Subject: [PATCH 70/92] [FLINK-11249][kafka,test] Update Kafka migration resources with new NextTransactionalIdHint serializer --- ...igration-kafka-producer-flink-1.8-snapshot | Bin 1746 -> 2032 bytes ...igration-kafka-producer-flink-1.8-snapshot | Bin 1731 -> 1232 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/flink-connectors/flink-connector-kafka-0.11/src/test/resources/kafka-migration-kafka-producer-flink-1.8-snapshot b/flink-connectors/flink-connector-kafka-0.11/src/test/resources/kafka-migration-kafka-producer-flink-1.8-snapshot index 29c6ccc245b2e78186f88e17df27c10cc45f4df9..0a22c69360deb3230038dc96f8bc96ba092f1712 100644 GIT binary patch delta 273 zcmcb_`+?t%0SPd&06EMIvU#Z$CAuX=iFw6|$t9Wjd5Jl?nJKy%nRz9;Wk!rZH6RQG ztXC#B$Yeq|ti>gXC8;1$nB)US1{Fhd3$wIj<3wFUQ&SUN69XetT}zW>OWh;`*PdMoynqXdjWO_MP&d0 delta 165 zcmeyse~DL!0SPd&069#P7c$BOF#=g23P-6XSQLyJVCWaH!{ zpdJ=)1_osYR)RK8oa4jTHQA6cj3YR;C^Io9vubiKqtxULEaIE*2xoD g4JX&J3NU{(mI1OIIX)V4fLz5oIgw3gG8@}o0F384F8}}l delta 572 zcmcb>d6>7J0SPd&069zyQhBKrCAuX=iFw6|$t9Wjd5Jl?nJKy%nRz9QKm{NS1S}^Q zSc^*%OHx4`nAmkj29-2Z1JfjPGh3RWB_kGcOy+$<)hAEKAfYsVqn>Ey>I& z)+;K_E6L1F)eFea%Ab6INtBt1A#CFNFcqLvKOLK8TL83=iGe>SvA854u_!SoCp9Ot zI2YvbOt88Tta2d?epu~J&d<%w&qK03DX}}MnF6xA5NvlG13ooj3;`s#!6!d29mx?OF31()h)@7p3l9;7&9#hI83mYt oVz5XAMhoL)M;4vQ-OK{ak9HMJp2zIS@n}Z?$hoXw31ODK0H5iqC;$Ke From 1de854a7fa907c47ecbcd9aee625467e0127f929 Mon Sep 17 00:00:00 2001 From: Piotr Nowojski Date: Mon, 11 Feb 2019 16:22:45 +0100 Subject: [PATCH 71/92] [FLINK-11249][kafka] Fix migration from FlinkKafkaProducer0.11 to universal Add backward compatibile classes to the universal FlinkKafkaProducer, so that it can restore from 0.11 checkpoints. 
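To make the migration concrete, the following is a minimal, hedged sketch of the user-side job this commit enables, not code from the patch itself: the topic name, broker address and uid value are placeholders, and the only requirement taken from this commit and the documentation change below is that the sink keeps the same uid it had while the job still used FlinkKafkaProducer011.

// Illustrative sketch only (not part of this patch): a job whose dependency was switched from
// flink-connector-kafka-0.11 to the universal flink-connector-kafka. The uid on the sink is what
// lets the operator state written by FlinkKafkaProducer011 be matched when restoring from the
// savepoint. Topic, broker address and uid are placeholder values.
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;

import java.util.Properties;

public class MigratedKafkaJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        DataStream<String> records = env.fromElements("a", "b", "c");

        Properties props = new Properties();
        props.setProperty("bootstrap.servers", "localhost:9092");

        records
            .addSink(new FlinkKafkaProducer<>("output-topic", new SimpleStringSchema(), props))
            // Must be the exact uid that was assigned while the job still used FlinkKafkaProducer011,
            // otherwise the restored savepoint cannot be mapped onto this operator.
            .uid("kafka-producer");

        env.execute("job migrated from FlinkKafkaProducer011");
    }
}

With the uid unchanged, the savepoint taken from the 0.11 producer resolves to this operator, and the compatibility classes added below take care of deserializing its state.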
--- docs/dev/connectors/kafka.md | 9 +++ .../kafka/FlinkKafkaProducer011.java | 69 ++++++++++++++++++ ...inkKafkaProducerMigrationOperatorTest.java | 58 +++++++++++++++ ...igration-kafka-producer-flink-1.8-snapshot | Bin 0 -> 2032 bytes 4 files changed, 136 insertions(+) create mode 100644 flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java create mode 100644 flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerMigrationOperatorTest.java create mode 100644 flink-connectors/flink-connector-kafka/src/test/resources/kafka-0.11-migration-kafka-producer-flink-1.8-snapshot diff --git a/docs/dev/connectors/kafka.md b/docs/dev/connectors/kafka.md index 8f893c4c05a6ff..4d0759f37a7011 100644 --- a/docs/dev/connectors/kafka.md +++ b/docs/dev/connectors/kafka.md @@ -135,6 +135,15 @@ The universal Kafka connector is compatible with older and newer Kafka brokers t It is compatible with broker versions 0.11.0 or newer, depending on the features used. For details on Kafka compatibility, please refer to the [Kafka documentation](https://kafka.apache.org/protocol.html#protocol_compatibility). +### Migrating Kafka Connector from 0.11 to universal + +In order to perform the migration, see the [upgrading jobs and Flink versions guide]({{ site.baseurl }}/ops/upgrading.html) +and: +* Use Flink 1.9 or newer for the whole process. +* Do not upgrade the Flink and operators at the same time. +* Make sure that Kafka Consumer and/or Kafka Producer used in your job have assigned unique identifiers (`uid`): +* Use stop with savepoint feature to take the savepoint (for example by using `stop --withSavepoint`)[CLI command]({{ site.baseurl }}/ops/cli.html). + ### Usage To use the universal Kafka connector add a dependency to it: diff --git a/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java b/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java new file mode 100644 index 00000000000000..3d10a8e3abb494 --- /dev/null +++ b/flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kafka; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; + +/** + * Compatibility class to make migration possible from the 0.11 connector to the universal one. + * + *

Problem is that FlinkKafkaProducer (universal) and FlinkKafkaProducer011 have different names and + * they both defined static classes NextTransactionalIdHint, KafkaTransactionState and + * KafkaTransactionContext inside the parent classes. This is causing incompatibility problems since + * for example FlinkKafkaProducer011.KafkaTransactionState and FlinkKafkaProducer.KafkaTransactionState + * are treated as completely incompatible classes, despite being identical. + * + *

This issue is solved by using custom serialization logic: keeping a fake/dummy + * FlinkKafkaProducer011.*Serializer classes in the universal connector + * (this class), as entry points for the deserialization and converting them to + * FlinkKafkaProducer.*Serializer counter parts. After all serialized binary data are exactly + * the same in all of those cases. + * + *

For more details check FLINK-11249 and the discussion in the pull requests. + */ +//CHECKSTYLE:OFF: JavadocType +public class FlinkKafkaProducer011 { + public static class NextTransactionalIdHintSerializer { + public static final class NextTransactionalIdHintSerializerSnapshot extends SimpleTypeSerializerSnapshot { + public NextTransactionalIdHintSerializerSnapshot() { + super(FlinkKafkaProducer.NextTransactionalIdHintSerializer::new); + } + } + } + + public static class ContextStateSerializer { + public static final class ContextStateSerializerSnapshot extends SimpleTypeSerializerSnapshot { + public ContextStateSerializerSnapshot() { + super(FlinkKafkaProducer.ContextStateSerializer::new); + } + } + } + + public static class TransactionStateSerializer { + public static final class TransactionStateSerializerSnapshot extends SimpleTypeSerializerSnapshot { + public TransactionStateSerializerSnapshot() { + super(FlinkKafkaProducer.TransactionStateSerializer::new); + } + } + } + + public static class NextTransactionalIdHint extends FlinkKafkaProducer.NextTransactionalIdHint { + } +} +//CHECKSTYLE:ON: JavadocType diff --git a/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerMigrationOperatorTest.java b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerMigrationOperatorTest.java new file mode 100644 index 00000000000000..48a754d9258c92 --- /dev/null +++ b/flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerMigrationOperatorTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.connectors.kafka; + +import org.apache.flink.testutils.migration.MigrationVersion; + +import org.junit.Ignore; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +/** + * Migration test from FlinkKafkaProducer011 operator. + */ +public class FlinkKafkaProducerMigrationOperatorTest extends FlinkKafkaProducerMigrationTest { + @Parameterized.Parameters(name = "Migration Savepoint: {0}") + public static Collection parameters() { + return Arrays.asList( + MigrationVersion.v1_8); + } + + public FlinkKafkaProducerMigrationOperatorTest(MigrationVersion testMigrateVersion) { + super(testMigrateVersion); + } + + @Override + public String getOperatorSnapshotPath(MigrationVersion version) { + return "src/test/resources/kafka-0.11-migration-kafka-producer-flink-" + version + "-snapshot"; + } + + /** + * This test depends on the resources generated by {@link FlinkKafkaProducer011MigrationTest}. 
+ * Run {@link FlinkKafkaProducer011MigrationTest#writeSnapshot()} and copy the created resource + * file to the path specified by the {@link #getOperatorSnapshotPath(MigrationVersion)} method. + */ + @Ignore + @Override + public void writeSnapshot() throws Exception { + throw new UnsupportedOperationException(); + } +} diff --git a/flink-connectors/flink-connector-kafka/src/test/resources/kafka-0.11-migration-kafka-producer-flink-1.8-snapshot b/flink-connectors/flink-connector-kafka/src/test/resources/kafka-0.11-migration-kafka-producer-flink-1.8-snapshot new file mode 100644 index 0000000000000000000000000000000000000000..0a22c69360deb3230038dc96f8bc96ba092f1712 GIT binary patch literal 2032 zcmcIl&2G~$9CzAPJ0W&JLPA2~(AyL%Nn2nC4&9cDK-Wf1Is}*Erf%yswX4|OfD7UQ zcmW=W7k~%ghPX}Kk+9R1G;P-{VzZRUc7Fel-~THBBum?p`#g9^@jIc52vLp#K^R4$ zPI~%~P@$h%n#8C_uEKZl;BxE?RNw*$tk83z9~vl{Po6&R2PLGNl}cGJ8&*Z%D+hae z*C=(dRq9quQ_5_=P)4rw=t8#!9yMLNRrQ;#j^|f>yL0Gxe#_Z!`%bgbIcgr-pwMo} zeJrqZT&=h5j%T-=YTbEbx8!XOd}h%Aq6rFy81_R#N05sMqcNca7%)n4AXvoV2=zw@ z))e5S+PV?zO#>Vmrde#r==qDmb9&EZ=sX+|6p{}(x*Fk8G~q)gq)YitA-+lG`zy@R zguoMYilBHu!Bas(4x=d*WQ?K9PFO-{!QC8qy`tx)A_&0N@2_8;P2xn|4G|YEiclEh zknpjpuJAUEuQTX_#tK^%wSbMsj4s5ji#Uc(st8i2g$U-sebrO72;-$HQWV8)Qru+t z7`Pdz))^fn^-)Y8Q)H8zN_PM9E8BiPHx%pZwDykY)*V0Tc{`R{-{o&ts3bM2^4Fiz zc<^z^)d=v Date: Thu, 9 May 2019 18:53:15 +0800 Subject: [PATCH 72/92] [FLINK-11403][network] Introduce ResultPartitionWithConsumableNotifier for simplifying creation of ResultPartitionWriter The creation of ResultPartitionWriter from NetworkEnvironment relies on TaskAction, ResultPartitionConsumableNotifier. For breaking this tie, ResultPartitionWithConsumableNotifier is introduced for wrapping the logic of notification. In this way the later interface method ShuffleService#createResultPartitionWriter would be simple. This closes #7549. 
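Before the diff, a hedged illustration of the decorator idea may help. The sketch below uses simplified placeholder interfaces instead of Flink's real ResultPartitionWriter and ResultPartitionConsumableNotifier types, and only shows the core behaviour: the wrapped writer reports whether a buffer was enqueued, and the decorator fires the consumable notification once, after the first successful add.

// Simplified placeholder interfaces; the real Flink types carry more methods.
interface SimpleWriter {
    /** Returns true if the buffer was enqueued for consumption. */
    boolean addBuffer(byte[] buffer, int subpartitionIndex);
}

interface ConsumableNotifier {
    void notifyPartitionConsumable();
}

/** Decorator that keeps the notification concern out of the writer implementation. */
final class NotifyingWriter implements SimpleWriter {
    private final SimpleWriter delegate;
    private final ConsumableNotifier notifier;
    private boolean notified;

    NotifyingWriter(SimpleWriter delegate, ConsumableNotifier notifier) {
        this.delegate = delegate;
        this.notifier = notifier;
    }

    @Override
    public boolean addBuffer(byte[] buffer, int subpartitionIndex) {
        boolean enqueued = delegate.addBuffer(buffer, subpartitionIndex);
        if (enqueued && !notified) {
            notified = true;
            // For pipelined partitions this is what triggers deployment of the consumers.
            notifier.notifyPartitionConsumable();
        }
        return enqueued;
    }
}

Seen this way, the boolean return value added to ResultPartitionWriter#addBufferConsumer in the diff below is what allows the notification logic to live outside ResultPartition.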
--- .../io/network/NetworkEnvironment.java | 9 +- .../api/writer/ResultPartitionWriter.java | 7 +- .../io/network/partition/ResultPartition.java | 67 +------ .../partition/ResultPartitionFactory.java | 19 +- ...tifyingResultPartitionWriterDecorator.java | 168 ++++++++++++++++++ .../flink/runtime/taskmanager/Task.java | 38 ++-- ...stractCollectingResultPartitionWriter.java | 3 +- .../network/api/writer/RecordWriterTest.java | 7 +- .../network/partition/PartitionTestUtils.java | 11 -- .../partition/ResultPartitionBuilder.java | 31 ---- .../partition/ResultPartitionTest.java | 135 +++++++++----- .../consumer/LocalInputChannelTest.java | 1 - .../StreamNetworkBenchmarkEnvironment.java | 17 +- 13 files changed, 305 insertions(+), 208 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index d381203218935c..7eac38140623a5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.io.network; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.JobID; import org.apache.flink.metrics.Gauge; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; @@ -39,7 +38,6 @@ import org.apache.flink.runtime.io.network.netty.NettyConnectionManager; import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider; import org.apache.flink.runtime.io.network.partition.ResultPartition; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionFactory; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; @@ -50,7 +48,6 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.taskexecutor.TaskExecutor; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -218,10 +215,7 @@ public void releasePartitions(Collection partitionIds) { public ResultPartition[] createResultPartitionWriters( String taskName, - JobID jobId, ExecutionAttemptID executionId, - TaskActions taskActions, - ResultPartitionConsumableNotifier partitionConsumableNotifier, Collection resultPartitionDeploymentDescriptors, MetricGroup outputGroup, MetricGroup buffersGroup) { @@ -231,8 +225,7 @@ public ResultPartition[] createResultPartitionWriters( ResultPartition[] resultPartitions = new ResultPartition[resultPartitionDeploymentDescriptors.size()]; int counter = 0; for (ResultPartitionDeploymentDescriptor rpdd : resultPartitionDeploymentDescriptors) { - resultPartitions[counter++] = resultPartitionFactory.create( - taskName, taskActions, jobId, executionId, rpdd, partitionConsumableNotifier); + resultPartitions[counter++] = resultPartitionFactory.create(taskName, executionId, rpdd); } registerOutputMetrics(outputGroup, buffersGroup, resultPartitions); diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java index cc1e49abb5240f..6c869e972f373d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java @@ -50,16 +50,15 @@ public interface ResultPartitionWriter extends AutoCloseable { /** * Adds the bufferConsumer to the subpartition with the given index. * - *

<p>For PIPELINED {@link org.apache.flink.runtime.io.network.partition.ResultPartitionType}s,
- * this will trigger the deployment of consuming tasks after the first buffer has been added.
- *
 * <p>This method takes ownership of the passed {@code bufferConsumer} and thus is responsible for releasing
 * its resources.
 *
 * <p>
To avoid problems with data re-ordering, before adding new {@link BufferConsumer} the previously added one * the given {@code subpartitionIndex} must be marked as {@link BufferConsumer#isFinished()}. + * + * @return true if operation succeeded and bufferConsumer was enqueued for consumption. */ - void addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException; + boolean addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException; /** * Manually trigger consumption from enqueued {@link BufferConsumer BufferConsumers} in all subpartitions. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java index fef0278e9b48a6..d46727600bb942 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.io.network.partition; -import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; @@ -30,7 +29,6 @@ import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.taskexecutor.TaskExecutor; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.FunctionWithException; import org.slf4j.Logger; @@ -67,13 +65,6 @@ *

 * <li>Release: </li>
 * </ol>
 *
- * <h2>Lazy deployment and updates of consuming tasks</h2>
- *
- * <p>Before a consuming task can request the result, it has to be deployed. The time of deployment
- * depends on the PIPELINED vs. BLOCKING characteristic of the result partition. With pipelined
- * results, receivers are deployed as soon as the first buffer is added to the result partition.
- * With blocking results on the other hand, receivers are deployed after the partition is finished.
- *
 * <h2>Buffer management</h2>
 *
 * <h2>State management</h2>
    @@ -84,10 +75,6 @@ public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner { private final String owningTaskName; - private final TaskActions taskActions; - - private final JobID jobId; - private final ResultPartitionID partitionId; /** Type of this partition. Defines the concrete subpartition implementation to use. */ @@ -98,12 +85,8 @@ public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner { private final ResultPartitionManager partitionManager; - private final ResultPartitionConsumableNotifier partitionConsumableNotifier; - public final int numTargetKeyGroups; - private final boolean sendScheduleOrUpdateConsumersMessage; - // - Runtime state -------------------------------------------------------- private final AtomicBoolean isReleased = new AtomicBoolean(); @@ -117,8 +100,6 @@ public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner { private BufferPool bufferPool; - private boolean hasNotifiedPipelinedConsumers; - private boolean isFinished; private volatile Throwable cause; @@ -127,27 +108,19 @@ public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner { public ResultPartition( String owningTaskName, - TaskActions taskActions, // actions on the owning task - JobID jobId, ResultPartitionID partitionId, ResultPartitionType partitionType, ResultSubpartition[] subpartitions, int numTargetKeyGroups, ResultPartitionManager partitionManager, - ResultPartitionConsumableNotifier partitionConsumableNotifier, - boolean sendScheduleOrUpdateConsumersMessage, FunctionWithException bufferPoolFactory) { this.owningTaskName = checkNotNull(owningTaskName); - this.taskActions = checkNotNull(taskActions); - this.jobId = checkNotNull(jobId); this.partitionId = checkNotNull(partitionId); this.partitionType = checkNotNull(partitionType); this.subpartitions = checkNotNull(subpartitions); this.numTargetKeyGroups = numTargetKeyGroups; this.partitionManager = checkNotNull(partitionManager); - this.partitionConsumableNotifier = checkNotNull(partitionConsumableNotifier); - this.sendScheduleOrUpdateConsumersMessage = sendScheduleOrUpdateConsumersMessage; this.bufferPoolFactory = bufferPoolFactory; } @@ -171,10 +144,6 @@ public void setup() throws IOException { partitionManager.registerResultPartition(this); } - public JobID getJobId() { - return jobId; - } - public String getOwningTaskName() { return owningTaskName; } @@ -221,7 +190,7 @@ public BufferBuilder getBufferBuilder() throws IOException, InterruptedException } @Override - public void addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException { + public boolean addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException { checkNotNull(bufferConsumer); ResultSubpartition subpartition; @@ -234,9 +203,7 @@ public void addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionInd throw ex; } - if (subpartition.add(bufferConsumer)) { - notifyPipelinedConsumers(); - } + return subpartition.add(bufferConsumer); } @Override @@ -260,24 +227,13 @@ public void flush(int subpartitionIndex) { */ @Override public void finish() throws IOException { - boolean success = false; - - try { - checkInProduceState(); - - for (ResultSubpartition subpartition : subpartitions) { - subpartition.finish(); - } + checkInProduceState(); - success = true; + for (ResultSubpartition subpartition : subpartitions) { + subpartition.finish(); } - finally { - if (success) { - isFinished = true; - notifyPipelinedConsumers(); - } - } 
+ isFinished = true; } public void release() { @@ -439,15 +395,4 @@ public ResultSubpartition[] getAllPartitions() { private void checkInProduceState() throws IllegalStateException { checkState(!isFinished, "Partition already finished."); } - - /** - * Notifies pipelined consumers of this result partition once. - */ - private void notifyPipelinedConsumers() { - if (sendScheduleOrUpdateConsumersMessage && !hasNotifiedPipelinedConsumers && partitionType.isPipelined()) { - partitionConsumableNotifier.notifyPartitionConsumable(jobId, partitionId, taskActions); - - hasNotifiedPipelinedConsumers = true; - } - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java index 247f2e252b9d25..1d0a0402a22dc2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java @@ -19,14 +19,12 @@ package org.apache.flink.runtime.io.network.partition; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory; import org.apache.flink.runtime.io.network.buffer.BufferPoolOwner; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FlinkRuntimeException; import org.apache.flink.util.function.FunctionWithException; @@ -74,51 +72,36 @@ public ResultPartitionFactory( public ResultPartition create( @Nonnull String taskNameWithSubtaskAndId, - @Nonnull TaskActions taskActions, - @Nonnull JobID jobId, @Nonnull ExecutionAttemptID executionAttemptID, - @Nonnull ResultPartitionDeploymentDescriptor desc, - @Nonnull ResultPartitionConsumableNotifier partitionConsumableNotifier) { + @Nonnull ResultPartitionDeploymentDescriptor desc) { return create( taskNameWithSubtaskAndId, - taskActions, - jobId, new ResultPartitionID(desc.getPartitionId(), executionAttemptID), desc.getPartitionType(), desc.getNumberOfSubpartitions(), desc.getMaxParallelism(), - partitionConsumableNotifier, - desc.sendScheduleOrUpdateConsumersMessage(), createBufferPoolFactory(desc.getNumberOfSubpartitions(), desc.getPartitionType())); } @VisibleForTesting public ResultPartition create( @Nonnull String taskNameWithSubtaskAndId, - @Nonnull TaskActions taskActions, - @Nonnull JobID jobId, @Nonnull ResultPartitionID id, @Nonnull ResultPartitionType type, int numberOfSubpartitions, int maxParallelism, - @Nonnull ResultPartitionConsumableNotifier partitionConsumableNotifier, - boolean sendScheduleOrUpdateConsumersMessage, FunctionWithException bufferPoolFactory) { ResultSubpartition[] subpartitions = new ResultSubpartition[numberOfSubpartitions]; ResultPartition partition = new ResultPartition( taskNameWithSubtaskAndId, - taskActions, - jobId, id, type, subpartitions, maxParallelism, partitionManager, - partitionConsumableNotifier, - sendScheduleOrUpdateConsumersMessage, bufferPoolFactory); createSubpartitions(partition, type, subpartitions); diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java new file mode 100644 index 00000000000000..d9add0689112df --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.taskmanager; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.runtime.io.network.buffer.BufferConsumer; +import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; + +import java.io.IOException; +import java.util.Collection; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A wrapper of result partition writer for handling the logic of consumable notification. + * + *

    Before a consuming task can request the result, it has to be deployed. The time of deployment + * depends on the PIPELINED vs. BLOCKING characteristic of the result partition. With pipelined + * results, receivers are deployed as soon as the first buffer is added to the result partition. + * With blocking results on the other hand, receivers are deployed after the partition is finished. + */ +public class ConsumableNotifyingResultPartitionWriterDecorator implements ResultPartitionWriter { + + private final TaskActions taskActions; + + private final JobID jobId; + + private final ResultPartitionWriter partitionWriter; + + private final ResultPartitionConsumableNotifier partitionConsumableNotifier; + + private boolean hasNotifiedPipelinedConsumers; + + public ConsumableNotifyingResultPartitionWriterDecorator( + TaskActions taskActions, + JobID jobId, + ResultPartitionWriter partitionWriter, + ResultPartitionConsumableNotifier partitionConsumableNotifier) { + this.taskActions = checkNotNull(taskActions); + this.jobId = checkNotNull(jobId); + this.partitionWriter = checkNotNull(partitionWriter); + this.partitionConsumableNotifier = checkNotNull(partitionConsumableNotifier); + } + + @Override + public BufferBuilder getBufferBuilder() throws IOException, InterruptedException { + return partitionWriter.getBufferBuilder(); + } + + @Override + public ResultPartitionID getPartitionId() { + return partitionWriter.getPartitionId(); + } + + @Override + public int getNumberOfSubpartitions() { + return partitionWriter.getNumberOfSubpartitions(); + } + + @Override + public int getNumTargetKeyGroups() { + return partitionWriter.getNumTargetKeyGroups(); + } + + @Override + public void setup() throws IOException { + partitionWriter.setup(); + } + + @Override + public boolean addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException { + boolean success = partitionWriter.addBufferConsumer(bufferConsumer, subpartitionIndex); + if (success) { + notifyPipelinedConsumers(); + } + + return success; + } + + @Override + public void flushAll() { + partitionWriter.flushAll(); + } + + @Override + public void flush(int subpartitionIndex) { + partitionWriter.flush(subpartitionIndex); + } + + @Override + public void finish() throws IOException { + partitionWriter.finish(); + + notifyPipelinedConsumers(); + } + + @Override + public void fail(Throwable throwable) { + partitionWriter.fail(throwable); + } + + @Override + public void close() throws Exception { + partitionWriter.close(); + } + + /** + * Notifies pipelined consumers of this result partition once. + * + *

    For PIPELINED {@link org.apache.flink.runtime.io.network.partition.ResultPartitionType}s, + * this will trigger the deployment of consuming tasks after the first buffer has been added. + */ + private void notifyPipelinedConsumers() { + if (!hasNotifiedPipelinedConsumers) { + partitionConsumableNotifier.notifyPartitionConsumable(jobId, partitionWriter.getPartitionId(), taskActions); + + hasNotifiedPipelinedConsumers = true; + } + } + + // ------------------------------------------------------------------------ + // Factory + // ------------------------------------------------------------------------ + + public static ResultPartitionWriter[] decorate( + Collection descs, + ResultPartitionWriter[] partitionWriters, + TaskActions taskActions, + JobID jobId, + ResultPartitionConsumableNotifier notifier) { + + ResultPartitionWriter[] consumableNotifyingPartitionWriters = new ResultPartitionWriter[partitionWriters.length]; + int counter = 0; + for (ResultPartitionDeploymentDescriptor desc : descs) { + if (desc.sendScheduleOrUpdateConsumersMessage() && desc.getPartitionType().isPipelined()) { + consumableNotifyingPartitionWriters[counter] = new ConsumableNotifyingResultPartitionWriterDecorator( + taskActions, + jobId, + partitionWriters[counter], + notifier); + } else { + consumableNotifyingPartitionWriters[counter] = partitionWriters[counter]; + } + counter++; + } + return consumableNotifyingPartitionWriters; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 02ef419e746ac9..06df24e706bf0e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -189,7 +189,7 @@ public class Task implements Runnable, TaskActions, PartitionProducerStateProvid /** Serialized version of the job specific execution configuration (see {@link ExecutionConfig}). 
*/ private final SerializedValue serializedExecutionConfig; - private final ResultPartitionWriter[] producedPartitions; + private final ResultPartitionWriter[] consumableNotifyingPartitionWriters; private final InputGate[] inputGates; @@ -368,16 +368,20 @@ public Task( final MetricGroup inputGroup = networkGroup.addGroup("Input"); // produced intermediate result partitions - this.producedPartitions = networkEnvironment.createResultPartitionWriters( + final ResultPartitionWriter[] resultPartitionWriters = networkEnvironment.createResultPartitionWriters( taskNameWithSubtaskAndId, - jobId, executionId, - this, - resultPartitionConsumableNotifier, resultPartitionDeploymentDescriptors, outputGroup, buffersGroup); + this.consumableNotifyingPartitionWriters = ConsumableNotifyingResultPartitionWriterDecorator.decorate( + resultPartitionDeploymentDescriptors, + resultPartitionWriters, + this, + jobId, + resultPartitionConsumableNotifier); + // consumed intermediate result partitions InputGate[] gates = networkEnvironment.createInputGates( taskNameWithSubtaskAndId, @@ -589,10 +593,10 @@ else if (current == ExecutionState.CANCELING) { LOG.info("Registering task at network: {}.", this); - setupPartionsAndGates(producedPartitions, inputGates); + setupPartionsAndGates(consumableNotifyingPartitionWriters, inputGates); - for (ResultPartitionWriter partition : producedPartitions) { - taskEventDispatcher.registerPartition(partition.getPartitionId()); + for (ResultPartitionWriter partitionWriter : consumableNotifyingPartitionWriters) { + taskEventDispatcher.registerPartition(partitionWriter.getPartitionId()); } // next, kick off the background copying of files for the distributed cache @@ -637,7 +641,7 @@ else if (current == ExecutionState.CANCELING) { kvStateRegistry, inputSplitProvider, distributedCacheEntries, - producedPartitions, + consumableNotifyingPartitionWriters, inputGates, taskEventDispatcher, checkpointResponder, @@ -681,9 +685,9 @@ else if (current == ExecutionState.CANCELING) { // ---------------------------------------------------------------- // finish the produced partitions. if this fails, we consider the execution failed. - for (ResultPartitionWriter partition : producedPartitions) { - if (partition != null) { - partition.finish(); + for (ResultPartitionWriter partitionWriter : consumableNotifyingPartitionWriters) { + if (partitionWriter != null) { + partitionWriter.finish(); } } @@ -838,10 +842,10 @@ public static void setupPartionsAndGates( private void releaseNetworkResources() { LOG.debug("Release task {} network resources (state: {}).", taskNameWithSubtask, getExecutionState()); - for (ResultPartitionWriter partition : producedPartitions) { - taskEventDispatcher.unregisterPartition(partition.getPartitionId()); + for (ResultPartitionWriter partitionWriter : consumableNotifyingPartitionWriters) { + taskEventDispatcher.unregisterPartition(partitionWriter.getPartitionId()); if (isCanceledOrFailed()) { - partition.fail(getFailureCause()); + partitionWriter.fail(getFailureCause()); } } @@ -853,9 +857,9 @@ private void releaseNetworkResources() { * release partitions and gates. Another is from task thread during task exiting. 
*/ private void closeNetworkResources() { - for (ResultPartitionWriter partition : producedPartitions) { + for (ResultPartitionWriter partitionWriter : consumableNotifyingPartitionWriters) { try { - partition.close(); + partitionWriter.close(); } catch (Throwable t) { ExceptionUtils.rethrowIfFatalError(t); LOG.error("Failed to release result partition for task {}.", taskNameWithSubtask, t); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java index 8633fe317f3896..035fa565b1e9ba 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AbstractCollectingResultPartitionWriter.java @@ -70,10 +70,11 @@ public BufferBuilder getBufferBuilder() throws IOException, InterruptedException } @Override - public synchronized void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { + public synchronized boolean addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { checkState(targetChannel < getNumberOfSubpartitions()); bufferConsumers.add(bufferConsumer); processBufferConsumers(); + return true; } private void processBufferConsumers() throws IOException { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index f8c6fdd1871679..882f83c5a9bf40 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -500,8 +500,8 @@ public BufferBuilder getBufferBuilder() throws IOException, InterruptedException } @Override - public void addBufferConsumer(BufferConsumer buffer, int targetChannel) throws IOException { - queues[targetChannel].add(buffer); + public boolean addBufferConsumer(BufferConsumer buffer, int targetChannel) throws IOException { + return queues[targetChannel].add(buffer); } @Override @@ -575,8 +575,9 @@ public BufferBuilder getBufferBuilder() throws IOException, InterruptedException } @Override - public void addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { + public boolean addBufferConsumer(BufferConsumer bufferConsumer, int targetChannel) throws IOException { bufferConsumer.close(); + return true; } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionTestUtils.java index 1b9b895b30f40b..df07db73af7616 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartitionTestUtils.java @@ -35,17 +35,6 @@ public static ResultPartition createPartition(ResultPartitionType type) { return new ResultPartitionBuilder().setResultPartitionType(type).build(); } - public static ResultPartition createPartition( - ResultPartitionConsumableNotifier notifier, - ResultPartitionType type, - boolean sendScheduleOrUpdateConsumersMessage) { - return new 
ResultPartitionBuilder() - .setResultPartitionConsumableNotifier(notifier) - .setResultPartitionType(type) - .setSendScheduleOrUpdateConsumersMessage(sendScheduleOrUpdateConsumersMessage) - .build(); - } - public static ResultPartition createPartition( NetworkEnvironment environment, ResultPartitionType partitionType, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java index 6a7c4b1172bca3..f34a5981d4e066 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java @@ -18,15 +18,12 @@ package org.apache.flink.runtime.io.network.partition; -import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferPoolOwner; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; -import org.apache.flink.runtime.taskmanager.NoOpTaskActions; -import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.util.function.FunctionWithException; import java.io.IOException; @@ -36,9 +33,6 @@ * Utility class to encapsulate the logic of building a {@link ResultPartition} instance. */ public class ResultPartitionBuilder { - private JobID jobId = new JobID(); - - private final TaskActions taskActions = new NoOpTaskActions(); private ResultPartitionID partitionId = new ResultPartitionID(); @@ -50,12 +44,8 @@ public class ResultPartitionBuilder { private ResultPartitionManager partitionManager = new ResultPartitionManager(); - private ResultPartitionConsumableNotifier partitionConsumableNotifier = new NoOpResultPartitionConsumableNotifier(); - private IOManager ioManager = new IOManagerAsync(); - private boolean sendScheduleOrUpdateConsumersMessage = false; - private NetworkBufferPool networkBufferPool = new NetworkBufferPool(1, 1, 1); private int networkBuffersPerChannel = 1; @@ -65,11 +55,6 @@ public class ResultPartitionBuilder { @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private Optional> bufferPoolFactory = Optional.empty(); - public ResultPartitionBuilder setJobId(JobID jobId) { - this.jobId = jobId; - return this; - } - public ResultPartitionBuilder setResultPartitionId(ResultPartitionID partitionId) { this.partitionId = partitionId; return this; @@ -95,23 +80,11 @@ public ResultPartitionBuilder setResultPartitionManager(ResultPartitionManager p return this; } - ResultPartitionBuilder setResultPartitionConsumableNotifier(ResultPartitionConsumableNotifier notifier) { - this.partitionConsumableNotifier = notifier; - return this; - } - public ResultPartitionBuilder setIOManager(IOManager ioManager) { this.ioManager = ioManager; return this; } - public ResultPartitionBuilder setSendScheduleOrUpdateConsumersMessage( - boolean sendScheduleOrUpdateConsumersMessage) { - - this.sendScheduleOrUpdateConsumersMessage = sendScheduleOrUpdateConsumersMessage; - return this; - } - public ResultPartitionBuilder setupBufferPoolFactoryFromNetworkEnvironment(NetworkEnvironment environment) { return 
setNetworkBuffersPerChannel(environment.getConfiguration().networkBuffersPerChannel()) .setFloatingNetworkBuffersPerGate(environment.getConfiguration().floatingNetworkBuffersPerGate()) @@ -152,14 +125,10 @@ public ResultPartition build() { return resultPartitionFactory.create( "Result Partition task", - taskActions, - jobId, partitionId, partitionType, numberOfSubpartitions, numTargetKeyGroups, - partitionConsumableNotifier, - sendScheduleOrUpdateConsumersMessage, factory); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java index 3340fea8a9150e..6166e20a716a86 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java @@ -19,22 +19,29 @@ package org.apache.flink.runtime.io.network.partition; import org.apache.flink.api.common.JobID; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironmentBuilder; +import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.BufferBuilder; import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils; import org.apache.flink.runtime.io.network.buffer.BufferConsumer; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.taskmanager.ConsumableNotifyingResultPartitionWriterDecorator; +import org.apache.flink.runtime.taskmanager.NoOpTaskActions; import org.apache.flink.runtime.taskmanager.TaskActions; import org.junit.Assert; import org.junit.Test; +import java.util.Collections; + import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createFilledBufferConsumer; import static org.apache.flink.runtime.io.network.partition.PartitionTestUtils.createPartition; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -52,40 +59,32 @@ public class ResultPartitionTest { */ @Test public void testSendScheduleOrUpdateConsumersMessage() throws Exception { + JobID jobId = new JobID(); + TaskActions taskActions = new NoOpTaskActions(); + { // Pipelined, send message => notify ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); - ResultPartition partition = createPartition(notifier, ResultPartitionType.PIPELINED, true); - partition.addBufferConsumer(createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE), 0); + ResultPartitionWriter consumableNotifyingPartitionWriter = createConsumableNotifyingResultPartitionWriter( + ResultPartitionType.PIPELINED, + taskActions, + jobId, + notifier); + consumableNotifyingPartitionWriter.addBufferConsumer(createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE), 0); verify(notifier, times(1)) - .notifyPartitionConsumable( - eq(partition.getJobId()), - eq(partition.getPartitionId()), - any(TaskActions.class)); - } - - { - // Pipelined, don't send message => don't notify - 
ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); - ResultPartition partition = createPartition(notifier, ResultPartitionType.PIPELINED, false); - partition.addBufferConsumer(createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE), 0); - verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class)); + .notifyPartitionConsumable(eq(jobId), eq(consumableNotifyingPartitionWriter.getPartitionId()), eq(taskActions)); } { // Blocking, send message => don't notify ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); - ResultPartition partition = createPartition(notifier, ResultPartitionType.BLOCKING, true); + ResultPartitionWriter partition = createConsumableNotifyingResultPartitionWriter( + ResultPartitionType.BLOCKING, + taskActions, + jobId, + notifier); partition.addBufferConsumer(createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE), 0); - verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class)); - } - - { - // Blocking, don't send message => don't notify - ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); - ResultPartition partition = createPartition(notifier, ResultPartitionType.BLOCKING, false); - partition.addBufferConsumer(createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE), 0); - verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class)); + verify(notifier, never()).notifyPartitionConsumable(eq(jobId), eq(partition.getPartitionId()), eq(taskActions)); } } @@ -102,18 +101,23 @@ public void testAddOnFinishedBlockingPartition() throws Exception { /** * Tests {@link ResultPartition#addBufferConsumer} on a partition which has already finished. 
* - * @param pipelined the result partition type to set up + * @param partitionType the result partition type to set up */ - protected void testAddOnFinishedPartition(final ResultPartitionType pipelined) - throws Exception { + private void testAddOnFinishedPartition(final ResultPartitionType partitionType) throws Exception { BufferConsumer bufferConsumer = createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE); ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); + JobID jobId = new JobID(); + TaskActions taskActions = new NoOpTaskActions(); + ResultPartitionWriter consumableNotifyingPartitionWriter = createConsumableNotifyingResultPartitionWriter( + partitionType, + taskActions, + jobId, + notifier); try { - ResultPartition partition = createPartition(notifier, pipelined, true); - partition.finish(); + consumableNotifyingPartitionWriter.finish(); reset(notifier); // partition.add() should fail - partition.addBufferConsumer(bufferConsumer, 0); + consumableNotifyingPartitionWriter.addBufferConsumer(bufferConsumer, 0); Assert.fail("exception expected"); } catch (IllegalStateException e) { // expected => ignored @@ -123,7 +127,10 @@ protected void testAddOnFinishedPartition(final ResultPartitionType pipelined) Assert.fail("bufferConsumer not recycled"); } // should not have notified either - verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class)); + verify(notifier, never()).notifyPartitionConsumable( + eq(jobId), + eq(consumableNotifyingPartitionWriter.getPartitionId()), + eq(taskActions)); } } @@ -140,17 +147,24 @@ public void testAddOnReleasedBlockingPartition() throws Exception { /** * Tests {@link ResultPartition#addBufferConsumer} on a partition which has already been released. 
* - * @param pipelined the result partition type to set up + * @param partitionType the result partition type to set up */ - protected void testAddOnReleasedPartition(final ResultPartitionType pipelined) - throws Exception { + private void testAddOnReleasedPartition(final ResultPartitionType partitionType) throws Exception { BufferConsumer bufferConsumer = createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE); ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); + JobID jobId = new JobID(); + TaskActions taskActions = new NoOpTaskActions(); + ResultPartition partition = createPartition(partitionType); + ResultPartitionWriter consumableNotifyingPartitionWriter = ConsumableNotifyingResultPartitionWriterDecorator.decorate( + Collections.singleton(createPartitionDeploymentDescriptor(partitionType)), + new ResultPartitionWriter[] {partition}, + taskActions, + jobId, + notifier)[0]; try { - ResultPartition partition = createPartition(notifier, pipelined, true); partition.release(); // partition.add() silently drops the bufferConsumer but recycles it - partition.addBufferConsumer(bufferConsumer, 0); + consumableNotifyingPartitionWriter.addBufferConsumer(bufferConsumer, 0); assertTrue(partition.isReleased()); } finally { if (!bufferConsumer.isRecycled()) { @@ -158,7 +172,7 @@ protected void testAddOnReleasedPartition(final ResultPartitionType pipelined) Assert.fail("bufferConsumer not recycled"); } // should not have notified either - verify(notifier, never()).notifyPartitionConsumable(any(JobID.class), any(ResultPartitionID.class), any(TaskActions.class)); + verify(notifier, never()).notifyPartitionConsumable(eq(jobId), eq(partition.getPartitionId()), eq(taskActions)); } } @@ -175,28 +189,30 @@ public void testAddOnBlockingPartition() throws Exception { /** * Tests {@link ResultPartition#addBufferConsumer(BufferConsumer, int)} on a working partition. 
* - * @param pipelined the result partition type to set up + * @param partitionType the result partition type to set up */ - protected void testAddOnPartition(final ResultPartitionType pipelined) - throws Exception { + private void testAddOnPartition(final ResultPartitionType partitionType) throws Exception { ResultPartitionConsumableNotifier notifier = mock(ResultPartitionConsumableNotifier.class); - ResultPartition partition = createPartition(notifier, pipelined, true); + JobID jobId = new JobID(); + TaskActions taskActions = new NoOpTaskActions(); + ResultPartitionWriter consumableNotifyingPartitionWriter = createConsumableNotifyingResultPartitionWriter( + partitionType, + taskActions, + jobId, + notifier); BufferConsumer bufferConsumer = createFilledBufferConsumer(BufferBuilderTestUtils.BUFFER_SIZE); try { // partition.add() adds the bufferConsumer without recycling it (if not spilling) - partition.addBufferConsumer(bufferConsumer, 0); + consumableNotifyingPartitionWriter.addBufferConsumer(bufferConsumer, 0); assertFalse("bufferConsumer should not be recycled (still in the queue)", bufferConsumer.isRecycled()); } finally { if (!bufferConsumer.isRecycled()) { bufferConsumer.close(); } // should have been notified for pipelined partitions - if (pipelined.isPipelined()) { + if (partitionType.isPipelined()) { verify(notifier, times(1)) - .notifyPartitionConsumable( - eq(partition.getJobId()), - eq(partition.getPartitionId()), - any(TaskActions.class)); + .notifyPartitionConsumable(eq(jobId), eq(consumableNotifyingPartitionWriter.getPartitionId()), eq(taskActions)); } } } @@ -243,4 +259,27 @@ private void testReleaseMemory(final ResultPartitionType resultPartitionType) th network.shutdown(); } } + + private ResultPartitionWriter createConsumableNotifyingResultPartitionWriter( + ResultPartitionType partitionType, + TaskActions taskActions, + JobID jobId, + ResultPartitionConsumableNotifier notifier) { + return ConsumableNotifyingResultPartitionWriterDecorator.decorate( + Collections.singleton(createPartitionDeploymentDescriptor(partitionType)), + new ResultPartitionWriter[] {createPartition(partitionType)}, + taskActions, + jobId, + notifier)[0]; + } + + private ResultPartitionDeploymentDescriptor createPartitionDeploymentDescriptor(ResultPartitionType partitionType) { + return new ResultPartitionDeploymentDescriptor( + new IntermediateDataSetID(), + new IntermediateResultPartitionID(), + partitionType, + 1, + 1, + true); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index 74c4968ccaac7d..35be8e7c199046 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -117,7 +117,6 @@ public void testConcurrentConsumeMultiplePartitions() throws Exception { .setNumberOfSubpartitions(parallelism) .setNumTargetKeyGroups(parallelism) .setResultPartitionManager(partitionManager) - .setSendScheduleOrUpdateConsumersMessage(true) .setBufferPoolFactory(p -> networkBuffers.createBufferPool(producerBufferPoolSize, producerBufferPoolSize)) .build(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java index 7a9c863d088b7b..86e1d6f436efa3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/StreamNetworkBenchmarkEnvironment.java @@ -37,7 +37,7 @@ import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.netty.NettyConfig; import org.apache.flink.runtime.io.network.partition.InputChannelTestUtils; -import org.apache.flink.runtime.io.network.partition.ResultPartition; +import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; @@ -47,8 +47,10 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateFactory; import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.taskmanager.ConsumableNotifyingResultPartitionWriterDecorator; import org.apache.flink.runtime.taskmanager.InputGateWithMetrics; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; +import org.apache.flink.runtime.taskmanager.NoOpTaskActions; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.ConfigurationParserUtils; @@ -216,8 +218,7 @@ protected ResultPartitionWriter createResultPartition( NetworkEnvironment environment, int channels) throws Exception { - ResultPartition resultPartition = new ResultPartitionBuilder() - .setJobId(jobId) + ResultPartitionWriter resultPartitionWriter = new ResultPartitionBuilder() .setResultPartitionId(partitionId) .setResultPartitionType(ResultPartitionType.PIPELINED_BOUNDED) .setNumberOfSubpartitions(channels) @@ -226,9 +227,15 @@ protected ResultPartitionWriter createResultPartition( .setupBufferPoolFactoryFromNetworkEnvironment(environment) .build(); - resultPartition.setup(); + ResultPartitionWriter consumableNotifyingPartitionWriter = new ConsumableNotifyingResultPartitionWriterDecorator( + new NoOpTaskActions(), + jobId, + resultPartitionWriter, + new NoOpResultPartitionConsumableNotifier()); - return resultPartition; + consumableNotifyingPartitionWriter.setup(); + + return consumableNotifyingPartitionWriter; } private InputGate createInputGate(TaskManagerLocation senderLocation) throws IOException { From 38557bf8a6f8bebef8733f3f4f3b3950e9678fca Mon Sep 17 00:00:00 2001 From: qiaoran Date: Thu, 23 May 2019 17:43:16 +0800 Subject: [PATCH 73/92] [FLINK-12572][hive]Implement HiveInputFormat to read Hive tables Implement HiveInputFormat to read data from Hive non-partition/partition tables. This closes #8522. 
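As a usage illustration (not part of this patch): given a JobConf and the HiveTablePartition list resolved from the metastore, the new input format can be plugged into a batch job roughly as follows. The wrapper class, variable names and generic parameters are assumptions made for this sketch; only the HiveTableInputFormat constructor added here and the existing ExecutionEnvironment#createInput API are taken as given.

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.typeutils.RowTypeInfo;
    import org.apache.flink.batch.connectors.hive.HiveTableInputFormat;
    import org.apache.flink.batch.connectors.hive.HiveTablePartition;
    import org.apache.flink.types.Row;

    import org.apache.hadoop.mapred.JobConf;

    import java.util.List;

    public class HiveReadSketch {

        /** Reads a non-partitioned Hive table as a DataSet of Rows (hypothetical helper). */
        public static DataSet<Row> readHiveTable(
                ExecutionEnvironment env,
                JobConf jobConf,                      // carries the Hive/HDFS configuration
                List<HiveTablePartition> partitions,  // a single entry describing the whole table
                RowTypeInfo rowTypeInfo) {            // Flink row type matching the Hive schema

            HiveTableInputFormat inputFormat = new HiveTableInputFormat(
                jobConf,
                false,       // table is not partitioned ...
                null,        // ... so no partition key columns are needed (assumed acceptable here)
                partitions,
                rowTypeInfo);

            // the format is ResultTypeQueryable, so the produced type is the given rowTypeInfo
            return env.createInput(inputFormat);
        }
    }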
--- flink-connectors/flink-connector-hive/pom.xml | 15 + .../connectors/hive/FlinkHiveException.java | 35 +++ .../connectors/hive/HiveRecordSerDe.java | 93 ++++++ .../connectors/hive/HiveTableInputFormat.java | 272 ++++++++++++++++++ .../connectors/hive/HiveTableInputSplit.java | 47 +++ .../connectors/hive/HiveTablePartition.java | 9 +- .../table/catalog/hive/HiveTableConfig.java | 1 + .../catalog/hive/util/HiveTableUtil.java | 42 +++ .../connectors/hive/HiveInputFormatTest.java | 126 ++++++++ .../src/test/resources/test/test.txt | 4 + pom.xml | 1 + 11 files changed, 643 insertions(+), 2 deletions(-) create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/FlinkHiveException.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveRecordSerDe.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputFormat.java create mode 100644 flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputSplit.java create mode 100644 flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveInputFormatTest.java create mode 100644 flink-connectors/flink-connector-hive/src/test/resources/test/test.txt diff --git a/flink-connectors/flink-connector-hive/pom.xml b/flink-connectors/flink-connector-hive/pom.xml index 545c964cdeb69b..21d64c1e48ed7b 100644 --- a/flink-connectors/flink-connector-hive/pom.xml +++ b/flink-connectors/flink-connector-hive/pom.xml @@ -353,6 +353,21 @@ under the License. test + + + + org.apache.flink + flink-java + ${project.version} + test + + + org.apache.flink + flink-clients_${scala.binary.version} + ${project.version} + test + + diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/FlinkHiveException.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/FlinkHiveException.java new file mode 100644 index 00000000000000..7c65491a04b309 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/FlinkHiveException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * Exception used by flink's hive data connector. 
+ */ +@PublicEvolving +public class FlinkHiveException extends RuntimeException { + + public FlinkHiveException(Throwable cause) { + super(cause); + } + + public FlinkHiveException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveRecordSerDe.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveRecordSerDe.java new file mode 100644 index 00000000000000..be046d9bb6137f --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveRecordSerDe.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DateObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveCharObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveVarcharObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; + +/** + * Class used to serialize to and from raw hdfs file type. + * Highly inspired by HCatRecordSerDe (almost copied from this class)in hive-catalog-core. + */ +public class HiveRecordSerDe { + + /** + * Return underlying Java Object from an object-representation + * that is readable by a provided ObjectInspector. + */ + public static Object obtainFlinkRowField(Object field, ObjectInspector fieldObjectInspector) { + Object res; + if (fieldObjectInspector.getCategory() == ObjectInspector.Category.PRIMITIVE) { + res = convertPrimitiveField(field, (PrimitiveObjectInspector) fieldObjectInspector); + } else { + throw new FlinkHiveException(new SerDeException( + String.format("HiveRecordSerDe doesn't support category %s, type %s yet", + fieldObjectInspector.getCategory(), fieldObjectInspector.getTypeName()))); + } + return res; + } + + /** + * This method actually convert java objects of Hive's scalar data types to those of Flink's internal data types. 
+ * + * @param field field value + * @param primitiveObjectInspector Hive's primitive object inspector for the field + * @return the java object conforming to Flink's internal data types. + * + * TODO: Comparing to original HCatRecordSerDe.java, we may need add more type converter according to conf. + */ + private static Object convertPrimitiveField(Object field, PrimitiveObjectInspector primitiveObjectInspector) { + if (field == null) { + return null; + } + + switch(primitiveObjectInspector.getPrimitiveCategory()) { + case DECIMAL: + HiveDecimalObjectInspector decimalOI = (HiveDecimalObjectInspector) primitiveObjectInspector; + BigDecimal bigDecimal = decimalOI.getPrimitiveJavaObject(field).bigDecimalValue(); + return bigDecimal; + case TIMESTAMP: + Timestamp ts = ((TimestampObjectInspector) primitiveObjectInspector).getPrimitiveJavaObject(field); + return ts; + case DATE: + Date date = ((DateObjectInspector) primitiveObjectInspector).getPrimitiveWritableObject(field).get(); + return date; + case CHAR: + HiveChar c = ((HiveCharObjectInspector) primitiveObjectInspector).getPrimitiveJavaObject(field); + return c.getStrippedValue(); + case VARCHAR: + HiveVarchar vc = ((HiveVarcharObjectInspector) primitiveObjectInspector).getPrimitiveJavaObject(field); + return vc.getValue(); + default: + return primitiveObjectInspector.getPrimitiveJavaObject(field); + } + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputFormat.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputFormat.java new file mode 100644 index 00000000000000..fb99ee4ce27cf6 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputFormat.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.common.io.LocatableInputSplitAssigner; +import org.apache.flink.api.common.io.statistics.BaseStatistics; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.hadoop.common.HadoopInputFormatCommonBase; +import org.apache.flink.api.java.hadoop.mapred.wrapper.HadoopDummyReporter; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.core.io.InputSplitAssigner; +import org.apache.flink.table.catalog.hive.util.HiveTableUtil; +import org.apache.flink.types.Row; + +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.hive.serde2.Deserializer; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.InputFormat; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.JobConfigurable; +import org.apache.hadoop.mapred.RecordReader; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.util.ReflectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR; + +/** + * The HiveTableInputFormat are inspired by the HCatInputFormat and HadoopInputFormatBase. + * It's used to read from hive partition/non-partition table. + */ +public class HiveTableInputFormat extends HadoopInputFormatCommonBase + implements ResultTypeQueryable { + private static final long serialVersionUID = 6351448428766433164L; + private static Logger logger = LoggerFactory.getLogger(HiveTableInputFormat.class); + + private JobConf jobConf; + + protected transient Writable key; + protected transient Writable value; + + private transient RecordReader recordReader; + protected transient boolean fetched = false; + protected transient boolean hasNext; + + private boolean isPartitioned; + private RowTypeInfo rowTypeInfo; + + //Necessary info to init deserializer + private String[] partitionColNames; + //For non-partition hive table, partitions only contains one partition which partitionValues is empty. + private List partitions; + private transient Deserializer deserializer; + //Hive StructField list contain all related info for specific serde. + private transient List structFields; + //StructObjectInspector in hive helps us to look into the internal structure of a struct object. 
+ private transient StructObjectInspector structObjectInspector; + private transient InputFormat mapredInputFormat; + private transient HiveTablePartition hiveTablePartition; + + public HiveTableInputFormat( + JobConf jobConf, + boolean isPartitioned, + String[] partitionColNames, + List partitions, + RowTypeInfo rowTypeInfo) { + super(jobConf.getCredentials()); + this.rowTypeInfo = checkNotNull(rowTypeInfo, "rowTypeInfo can not be null."); + this.jobConf = new JobConf(jobConf); + this.isPartitioned = isPartitioned; + this.partitionColNames = partitionColNames; + this.partitions = checkNotNull(partitions, "partitions can not be null."); + } + + @Override + public void open(HiveTableInputSplit split) throws IOException { + this.hiveTablePartition = split.getHiveTablePartition(); + StorageDescriptor sd = hiveTablePartition.getStorageDescriptor(); + jobConf.set(INPUT_DIR, sd.getLocation()); + try { + this.mapredInputFormat = (InputFormat) + Class.forName(sd.getInputFormat(), true, Thread.currentThread().getContextClassLoader()).newInstance(); + } catch (Exception e) { + throw new FlinkHiveException("Unable to instantiate the hadoop input format", e); + } + ReflectionUtils.setConf(mapredInputFormat, jobConf); + if (this.mapredInputFormat instanceof Configurable) { + ((Configurable) this.mapredInputFormat).setConf(this.jobConf); + } else if (this.mapredInputFormat instanceof JobConfigurable) { + ((JobConfigurable) this.mapredInputFormat).configure(this.jobConf); + } + this.recordReader = this.mapredInputFormat.getRecordReader(split.getHadoopInputSplit(), + jobConf, new HadoopDummyReporter()); + if (this.recordReader instanceof Configurable) { + ((Configurable) this.recordReader).setConf(jobConf); + } + key = this.recordReader.createKey(); + value = this.recordReader.createValue(); + this.fetched = false; + try { + deserializer = (Deserializer) Class.forName(sd.getSerdeInfo().getSerializationLib()).newInstance(); + Configuration conf = new Configuration(); + //properties are used to initialize hive Deserializer properly. + Properties properties = HiveTableUtil.createPropertiesFromStorageDescriptor(sd); + SerDeUtils.initializeSerDe(deserializer, conf, properties, null); + structObjectInspector = (StructObjectInspector) deserializer.getObjectInspector(); + structFields = structObjectInspector.getAllStructFieldRefs(); + } catch (Exception e) { + throw new FlinkHiveException("Error happens when deserialize from storage file.", e); + } + } + + @Override + public HiveTableInputSplit[] createInputSplits(int minNumSplits) + throws IOException { + List hiveSplits = new ArrayList<>(); + int splitNum = 0; + for (HiveTablePartition partition : partitions) { + StorageDescriptor sd = partition.getStorageDescriptor(); + InputFormat format; + try { + format = (InputFormat) + Class.forName(sd.getInputFormat(), true, Thread.currentThread().getContextClassLoader()).newInstance(); + } catch (Exception e) { + throw new FlinkHiveException("Unable to instantiate the hadoop input format", e); + } + ReflectionUtils.setConf(format, jobConf); + jobConf.set(INPUT_DIR, sd.getLocation()); + //TODO: we should consider how to calculate the splits according to minNumSplits in the future. 
+ org.apache.hadoop.mapred.InputSplit[] splitArray = format.getSplits(jobConf, minNumSplits); + for (int i = 0; i < splitArray.length; i++) { + hiveSplits.add(new HiveTableInputSplit(splitNum++, splitArray[i], jobConf, partition)); + } + } + + return hiveSplits.toArray(new HiveTableInputSplit[hiveSplits.size()]); + } + + @Override + public void configure(org.apache.flink.configuration.Configuration parameters) { + + } + + @Override + public BaseStatistics getStatistics(BaseStatistics cachedStats) throws IOException { + // no statistics available + return null; + } + + @Override + public InputSplitAssigner getInputSplitAssigner(HiveTableInputSplit[] inputSplits) { + return new LocatableInputSplitAssigner(inputSplits); + } + + @Override + public boolean reachedEnd() throws IOException { + if (!fetched) { + fetchNext(); + } + return !hasNext; + } + + @Override + public void close() throws IOException { + if (this.recordReader != null) { + this.recordReader.close(); + this.recordReader = null; + } + } + + protected void fetchNext() throws IOException { + hasNext = this.recordReader.next(key, value); + fetched = true; + } + + @Override + public Row nextRecord(Row ignore) throws IOException { + if (reachedEnd()) { + return null; + } + Row row = new Row(rowTypeInfo.getArity()); + try { + //Use HiveDeserializer to deserialize an object out of a Writable blob + Object hiveRowStruct = deserializer.deserialize(value); + int index = 0; + for (; index < structFields.size(); index++) { + StructField structField = structFields.get(index); + Object object = HiveRecordSerDe.obtainFlinkRowField( + structObjectInspector.getStructFieldData(hiveRowStruct, structField), structField.getFieldObjectInspector()); + row.setField(index, object); + } + if (isPartitioned) { + for (String partition : partitionColNames){ + row.setField(index++, hiveTablePartition.getPartitionSpec().get(partition)); + } + } + } catch (Exception e){ + logger.error("Error happens when converting hive data type to flink data type."); + throw new FlinkHiveException(e); + } + this.fetched = false; + return row; + } + + @Override + public TypeInformation getProducedType() { + return rowTypeInfo; + } + + // -------------------------------------------------------------------------------------------- + // Custom serialization methods + // -------------------------------------------------------------------------------------------- + + private void writeObject(ObjectOutputStream out) throws IOException { + super.write(out); + jobConf.write(out); + out.writeObject(isPartitioned); + out.writeObject(rowTypeInfo); + out.writeObject(partitionColNames); + out.writeObject(partitions); + } + + @SuppressWarnings("unchecked") + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + super.read(in); + if (jobConf == null) { + jobConf = new JobConf(); + } + jobConf.readFields(in); + jobConf.getCredentials().addAll(this.credentials); + Credentials currentUserCreds = getCredentialsFromUGI(UserGroupInformation.getCurrentUser()); + if (currentUserCreds != null) { + jobConf.getCredentials().addAll(currentUserCreds); + } + isPartitioned = (boolean) in.readObject(); + rowTypeInfo = (RowTypeInfo) in.readObject(); + partitionColNames = (String[]) in.readObject(); + partitions = (List) in.readObject(); + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputSplit.java 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputSplit.java new file mode 100644 index 00000000000000..c727347551769d --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTableInputSplit.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.java.hadoop.mapred.wrapper.HadoopInputSplit; + +import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.mapred.JobConf; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * An wrapper class that wraps info needed for a hadoop input split. + * Right now, it contains info about the partition of the split. + */ +public class HiveTableInputSplit extends HadoopInputSplit { + private final HiveTablePartition hiveTablePartition; + + public HiveTableInputSplit( + int splitNumber, + InputSplit hInputSplit, + JobConf jobconf, + HiveTablePartition hiveTablePartition) { + super(splitNumber, hInputSplit, jobconf); + this.hiveTablePartition = checkNotNull(hiveTablePartition, "hiveTablePartition can not be null"); + } + + public HiveTablePartition getHiveTablePartition() { + return hiveTablePartition; + } +} diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java index 21aeb16ca92c72..20e73be285fa09 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/batch/connectors/hive/HiveTablePartition.java @@ -23,19 +23,24 @@ import java.io.Serializable; import java.util.Map; +import static org.apache.flink.util.Preconditions.checkNotNull; + /** * A class that describes a partition of a Hive table. And it represents the whole table if table is not partitioned. * Please note that the class is serializable because all its member variables are serializable. */ public class HiveTablePartition implements Serializable { + private static final long serialVersionUID = 4145470177119940673L; + + /** Partition storage descriptor. */ private final StorageDescriptor storageDescriptor; - // Partition spec for the partition. Should be null if the table is not partitioned. + /** The map of partition key names and their values. 
*/ private final Map partitionSpec; public HiveTablePartition(StorageDescriptor storageDescriptor, Map partitionSpec) { - this.storageDescriptor = storageDescriptor; + this.storageDescriptor = checkNotNull(storageDescriptor, "storageDescriptor can not be null"); this.partitionSpec = partitionSpec; } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveTableConfig.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveTableConfig.java index e5063f21a7383a..273b4e93b139b4 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveTableConfig.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/HiveTableConfig.java @@ -25,5 +25,6 @@ public class HiveTableConfig { // Comment of the Flink table public static final String TABLE_COMMENT = "comment"; + public static final String DEFAULT_LIST_COLUMN_TYPES_SEPARATOR = ":"; } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTableUtil.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTableUtil.java index 1b424290aced63..2aeed298c0d4e0 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTableUtil.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/util/HiveTableUtil.java @@ -21,11 +21,20 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.table.api.TableSchema; +import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.metastore.api.SerDeInfo; +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.SerDeUtils; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.apache.flink.table.catalog.hive.HiveTableConfig.DEFAULT_LIST_COLUMN_TYPES_SEPARATOR; /** * Utils to for Hive-backed table. @@ -71,4 +80,37 @@ public static List createHiveColumns(TableSchema schema) { return columns; } + + // -------------------------------------------------------------------------------------------- + // Helper methods + // -------------------------------------------------------------------------------------------- + + /** + * Create properties info to initialize a SerDe. + * @param storageDescriptor + * @return + */ + public static Properties createPropertiesFromStorageDescriptor(StorageDescriptor storageDescriptor) { + SerDeInfo serDeInfo = storageDescriptor.getSerdeInfo(); + Map parameters = serDeInfo.getParameters(); + Properties properties = new Properties(); + properties.setProperty( + serdeConstants.SERIALIZATION_FORMAT, + parameters.get(serdeConstants.SERIALIZATION_FORMAT)); + List colTypes = new ArrayList<>(); + List colNames = new ArrayList<>(); + List cols = storageDescriptor.getCols(); + for (FieldSchema col: cols){ + colTypes.add(col.getType()); + colNames.add(col.getName()); + } + properties.setProperty(serdeConstants.LIST_COLUMNS, StringUtils.join(colNames, String.valueOf(SerDeUtils.COMMA))); + // Note: serdeConstants.COLUMN_NAME_DELIMITER is not defined in previous Hive. 
We use a literal to save on shim + properties.setProperty("column.name.delimite", String.valueOf(SerDeUtils.COMMA)); + properties.setProperty(serdeConstants.LIST_COLUMN_TYPES, StringUtils.join(colTypes, DEFAULT_LIST_COLUMN_TYPES_SEPARATOR)); + properties.setProperty(serdeConstants.SERIALIZATION_NULL_FORMAT, "NULL"); + properties.putAll(parameters); + return properties; + } + } diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveInputFormatTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveInputFormatTest.java new file mode 100644 index 00000000000000..80ff6debd26a2a --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/batch/connectors/hive/HiveInputFormatTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.batch.connectors.hive; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.catalog.hive.HiveCatalog; +import org.apache.flink.table.catalog.hive.HiveTestUtils; +import org.apache.flink.table.catalog.hive.util.HiveTableUtil; +import org.apache.flink.types.Row; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.HiveMetaStoreClient; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.metastore.RetryingMetaStoreClient; +import org.apache.hadoop.hive.metastore.api.SerDeInfo; +import org.apache.hadoop.hive.metastore.api.StorageDescriptor; +import org.apache.hadoop.mapred.JobConf; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +/** + * Tests {@link HiveTableInputFormat}. 
+ */ +public class HiveInputFormatTest { + + public static final String DEFAULT_HIVE_INPUT_FORMAT_TEST_SERDE_CLASS = org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.class.getName(); + public static final String DEFAULT_HIVE_INPUT_FORMAT_TEST_INPUT_FORMAT_CLASS = org.apache.hadoop.mapred.TextInputFormat.class.getName(); + public static final String DEFAULT_OUTPUT_FORMAT_CLASS = org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat.class.getName(); + + private static HiveCatalog hiveCatalog; + private static HiveConf hiveConf; + + @BeforeClass + public static void createCatalog() throws IOException { + hiveConf = HiveTestUtils.getHiveConf(); + hiveCatalog = HiveTestUtils.createHiveCatalog(hiveConf); + hiveCatalog.open(); + } + + @AfterClass + public static void closeCatalog() { + if (null != hiveCatalog) { + hiveCatalog.close(); + } + } + + @Test + public void testReadFromHiveInputFormat() throws Exception { + final String dbName = "default"; + final String tblName = "test"; + TableSchema tableSchema = new TableSchema( + new String[]{"a", "b", "c", "d", "e"}, + new TypeInformation[]{ + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO} + ); + //Now we used metaStore client to create hive table instead of using hiveCatalog for it doesn't support set + //serDe temporarily. + IMetaStoreClient client = RetryingMetaStoreClient.getProxy(hiveConf, null, null, HiveMetaStoreClient.class.getName(), true); + org.apache.hadoop.hive.metastore.api.Table tbl = new org.apache.hadoop.hive.metastore.api.Table(); + tbl.setDbName(dbName); + tbl.setTableName(tblName); + tbl.setCreateTime((int) (System.currentTimeMillis() / 1000)); + tbl.setParameters(new HashMap<>()); + StorageDescriptor sd = new StorageDescriptor(); + String location = HiveInputFormatTest.class.getResource("/test").getPath(); + sd.setLocation(location); + sd.setInputFormat(DEFAULT_HIVE_INPUT_FORMAT_TEST_INPUT_FORMAT_CLASS); + sd.setOutputFormat(DEFAULT_OUTPUT_FORMAT_CLASS); + sd.setSerdeInfo(new SerDeInfo()); + sd.getSerdeInfo().setSerializationLib(DEFAULT_HIVE_INPUT_FORMAT_TEST_SERDE_CLASS); + sd.getSerdeInfo().setParameters(new HashMap<>()); + sd.getSerdeInfo().getParameters().put("serialization.format", "1"); + sd.getSerdeInfo().getParameters().put("field.delim", ","); + sd.setCols(HiveTableUtil.createHiveColumns(tableSchema)); + tbl.setSd(sd); + tbl.setPartitionKeys(new ArrayList<>()); + + client.createTable(tbl); + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + RowTypeInfo rowTypeInfo = new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()); + List partitions = new ArrayList<>(); + partitions.add(new HiveTablePartition(sd, new HashMap<>())); + HiveTableInputFormat hiveTableInputFormat = new HiveTableInputFormat(new JobConf(hiveConf), false, null, + partitions, rowTypeInfo); + DataSet rowDataSet = env.createInput(hiveTableInputFormat); + List rows = rowDataSet.collect(); + Assert.assertEquals(4, rows.size()); + Assert.assertEquals("1,1,a,1000,1.11", rows.get(0).toString()); + Assert.assertEquals("2,2,a,2000,2.22", rows.get(1).toString()); + Assert.assertEquals("3,3,a,3000,3.33", rows.get(2).toString()); + Assert.assertEquals("4,4,a,4000,4.44", rows.get(3).toString()); + } +} diff --git a/flink-connectors/flink-connector-hive/src/test/resources/test/test.txt b/flink-connectors/flink-connector-hive/src/test/resources/test/test.txt new file mode 100644 
index 00000000000000..672650e39a5f88 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/resources/test/test.txt @@ -0,0 +1,4 @@ +1,1,a,1000,1.11 +2,2,a,2000,2.22 +3,3,a,3000,3.33 +4,4,a,4000,4.44 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 0334b334e8f991..4da8873b70156e 100644 --- a/pom.xml +++ b/pom.xml @@ -1321,6 +1321,7 @@ under the License. flink-end-to-end-tests/test-scripts/docker-hadoop-secure-cluster/config/keystore.jks flink-connectors/flink-connector-kafka/src/test/resources/** flink-connectors/flink-connector-kafka-0.11/src/test/resources/** + flink-connectors/flink-connector-hive/src/test/resources/** **/src/test/resources/*-snapshot From d11958281ea785b6b5fcc32a169ed4b96dcecd87 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Tue, 4 Jun 2019 11:33:41 +0200 Subject: [PATCH 74/92] [hotfix][core] Add notice about TypeSerializerSnapshot serialization --- .../flink/api/common/typeutils/TypeSerializerSnapshot.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeSerializerSnapshot.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeSerializerSnapshot.java index 20fcbb7bf5e79b..0dff5b29aa5aec 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeSerializerSnapshot.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/TypeSerializerSnapshot.java @@ -89,6 +89,8 @@ public interface TypeSerializerSnapshot { * @param out the {@link DataOutputView} to write the snapshot to. * * @throws IOException Thrown if the snapshot data could not be written. + * + * @see #writeVersionedSnapshot(DataOutputView, TypeSerializerSnapshot) */ void writeSnapshot(DataOutputView out) throws IOException; @@ -102,7 +104,9 @@ public interface TypeSerializerSnapshot { * @param in the {@link DataInputView} to read the snapshot from. * @param userCodeClassLoader the user code classloader * - * * @throws IOException Thrown if the snapshot data could be read or parsed. + * @throws IOException Thrown if the snapshot data could be read or parsed. + * + * @see #readVersionedSnapshot(DataInputView, ClassLoader) */ void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException; From 5f5f02b1272ceba5e72ac8bb29e3d260d66bd493 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Tue, 4 Jun 2019 11:47:32 +0200 Subject: [PATCH 75/92] [FLINK-12726][table-common] Fix ANY type serialization This closes #8612. 
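For illustration, a minimal round trip through the versioned snapshot helpers referenced above might look like the following sketch (not part of this patch; the Kryo serializer for String is only a stand-in and the class name is invented):

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;

import java.io.IOException;

public class VersionedSnapshotRoundTrip {

    public static void main(String[] args) throws IOException {
        TypeSerializer<String> serializer = new KryoSerializer<>(String.class, new ExecutionConfig());

        // Writes the snapshot class name and version alongside the snapshot data,
        // so the reader does not need to know the snapshot class up front.
        DataOutputSerializer out = new DataOutputSerializer(128);
        TypeSerializerSnapshot.writeVersionedSnapshot(out, serializer.snapshotConfiguration());

        // Restores the snapshot from the buffer; the classloader resolves the recorded snapshot class.
        DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
        TypeSerializerSnapshot<String> restored =
                TypeSerializerSnapshot.readVersionedSnapshot(in, Thread.currentThread().getContextClassLoader());

        System.out.println(restored.restoreSerializer());
    }
}

The second part of the hunk below additionally chains the original exception, so failures to generate the serializer string keep their root cause.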
--- .../flink/table/types/logical/AnyType.java | 4 ++-- .../flink/table/types/LogicalTypesTest.java | 20 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/AnyType.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/AnyType.java index 4f1753965bd304..ba849ca0dbd3a5 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/AnyType.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/AnyType.java @@ -141,7 +141,7 @@ private String getOrCreateSerializerString() { if (serializerString == null) { final DataOutputSerializer outputSerializer = new DataOutputSerializer(128); try { - serializer.snapshotConfiguration().writeSnapshot(outputSerializer); + TypeSerializerSnapshot.writeVersionedSnapshot(outputSerializer, serializer.snapshotConfiguration()); serializerString = EncodingUtils.encodeBytesToBase64(outputSerializer.getCopyOfBuffer()); return serializerString; } catch (Exception e) { @@ -149,7 +149,7 @@ private String getOrCreateSerializerString() { "Unable to generate a string representation of the serializer snapshot of '%s' " + "describing the class '%s' for the ANY type.", serializer.getClass().getName(), - clazz.toString())); + clazz.toString()), e); } } return serializerString; diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java index 600b7189b42b79..100982a21c8d0f 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypesTest.java @@ -568,15 +568,17 @@ public void testAnyType() { testAll( new AnyType<>(Human.class, new KryoSerializer<>(Human.class, new ExecutionConfig())), "ANY(org.apache.flink.table.types.LogicalTypesTest$Human, " + - "ADNvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLnR5cGVzLkxvZ2ljYWxUeXBlc1Rlc3QkSHVtYW4AAATyxpo9cAA" + - "AAAIAM29yZy5hcGFjaGUuZmxpbmsudGFibGUudHlwZXMuTG9naWNhbFR5cGVzVGVzdCRIdW1hbgEAAAA1AD" + - "NvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLnR5cGVzLkxvZ2ljYWxUeXBlc1Rlc3QkSHVtYW4BAAAAOQAzb3JnL" + - "mFwYWNoZS5mbGluay50YWJsZS50eXBlcy5Mb2dpY2FsVHlwZXNUZXN0JEh1bWFuAAAAAAApb3JnLmFwYWNo" + - "ZS5hdnJvLmdlbmVyaWMuR2VuZXJpY0RhdGEkQXJyYXkBAAAAKwApb3JnLmFwYWNoZS5hdnJvLmdlbmVyaWM" + - "uR2VuZXJpY0RhdGEkQXJyYXkBAAAAtgBVb3JnLmFwYWNoZS5mbGluay5hcGkuamF2YS50eXBldXRpbHMucn" + - "VudGltZS5rcnlvLlNlcmlhbGl6ZXJzJER1bW15QXZyb1JlZ2lzdGVyZWRDbGFzcwAAAAEAWW9yZy5hcGFja" + - "GUuZmxpbmsuYXBpLmphdmEudHlwZXV0aWxzLnJ1bnRpbWUua3J5by5TZXJpYWxpemVycyREdW1teUF2cm9L" + - "cnlvU2VyaWFsaXplckNsYXNzAAAE8saaPXAAAAAAAAAE8saaPXAAAAAA)", + "AEdvcmcuYXBhY2hlLmZsaW5rLmFwaS5qYXZhLnR5cGV1dGlscy5ydW50aW1lLmtyeW8uS3J5b1Nlcml" + + "hbGl6ZXJTbmFwc2hvdAAAAAIAM29yZy5hcGFjaGUuZmxpbmsudGFibGUudHlwZXMuTG9naWNhbFR5cG" + + "VzVGVzdCRIdW1hbgAABPLGmj1wAAAAAgAzb3JnLmFwYWNoZS5mbGluay50YWJsZS50eXBlcy5Mb2dpY" + + "2FsVHlwZXNUZXN0JEh1bWFuAQAAADUAM29yZy5hcGFjaGUuZmxpbmsudGFibGUudHlwZXMuTG9naWNh" + + "bFR5cGVzVGVzdCRIdW1hbgEAAAA5ADNvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLnR5cGVzLkxvZ2ljYWx" + + "UeXBlc1Rlc3QkSHVtYW4AAAAAAClvcmcuYXBhY2hlLmF2cm8uZ2VuZXJpYy5HZW5lcmljRGF0YSRBcn" + + "JheQEAAAArAClvcmcuYXBhY2hlLmF2cm8uZ2VuZXJpYy5HZW5lcmljRGF0YSRBcnJheQEAAAC2AFVvc" + + 
"mcuYXBhY2hlLmZsaW5rLmFwaS5qYXZhLnR5cGV1dGlscy5ydW50aW1lLmtyeW8uU2VyaWFsaXplcnMk" + + "RHVtbXlBdnJvUmVnaXN0ZXJlZENsYXNzAAAAAQBZb3JnLmFwYWNoZS5mbGluay5hcGkuamF2YS50eXB" + + "ldXRpbHMucnVudGltZS5rcnlvLlNlcmlhbGl6ZXJzJER1bW15QXZyb0tyeW9TZXJpYWxpemVyQ2xhc3" + + "MAAATyxpo9cAAAAAAAAATyxpo9cAAAAAA=)", "ANY(org.apache.flink.table.types.LogicalTypesTest$Human, ...)", new Class[]{Human.class, User.class}, // every User is Human new Class[]{Human.class}, From cedd9803a9d7e3312e49fa46025ce9ac985bf2a0 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Fri, 31 May 2019 17:13:42 +0200 Subject: [PATCH 76/92] [FLINK-12254][table] Update TableSource and related interfaces to new type system --- .../table/sources/DefinedFieldMapping.java | 18 ++++----- .../NestedFieldsProjectableTableSource.java | 10 ++--- .../table/sources/ProjectableTableSource.java | 8 ++-- .../flink/table/sources/TableSource.java | 36 ++++++++++++++--- .../plan/nodes/common/CommonLookupJoin.scala | 27 +++++++------ .../batch/BatchExecTableSourceScan.scala | 37 +++++++++-------- .../stream/StreamExecTableSourceScan.scala | 40 ++++++++++--------- .../flink/table/sources/TableSourceUtil.scala | 6 +-- .../nodes/dataset/BatchTableSourceScan.scala | 10 +++-- .../datastream/StreamTableSourceScan.scala | 10 +++-- .../flink/table/sources/TableSourceUtil.scala | 10 +++-- .../catalog/CatalogStructureBuilder.java | 12 +++--- 12 files changed, 134 insertions(+), 90 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedFieldMapping.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedFieldMapping.java index dd1a0cad1556d0..daa79247ec34f2 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedFieldMapping.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/DefinedFieldMapping.java @@ -19,8 +19,8 @@ package org.apache.flink.table.sources; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.types.DataType; import javax.annotation.Nullable; @@ -28,11 +28,11 @@ /** * The {@link DefinedFieldMapping} interface provides a mapping for the fields of the table schema - * ({@link TableSource#getTableSchema} to fields of the physical returned type - * {@link TableSource#getReturnType} of a {@link TableSource}. + * ({@link TableSource#getTableSchema} to fields of the physical produced data type + * {@link TableSource#getProducedDataType()} of a {@link TableSource}. * *

    If a {@link TableSource} does not implement the {@link DefinedFieldMapping} interface, the fields of - * its {@link TableSchema} are mapped to the fields of its return type {@link TypeInformation} by name. + * its {@link TableSchema} are mapped to the fields of its produced {@link DataType} by name. * *

    If the fields cannot or should not be implicitly mapped by name, an explicit mapping can be * provided by implementing this interface. @@ -44,17 +44,17 @@ public interface DefinedFieldMapping { /** * Returns the mapping for the fields of the {@link TableSource}'s {@link TableSchema} to the fields of - * its return type {@link TypeInformation}. + * its produced {@link DataType}. * *

    The mapping is done based on field names, e.g., a mapping "name" -> "f1" maps the schema field - * "name" to the field "f1" of the return type, for example in this case the second field of a + * "name" to the field "f1" of the produced data type, for example in this case the second field of a * {@link org.apache.flink.api.java.tuple.Tuple}. * - *

    The returned mapping must map all fields (except proctime and rowtime fields) to the return + *

    The returned mapping must map all fields (except proctime and rowtime fields) to the produced data * type. It can also provide a mapping for fields which are not in the {@link TableSchema} to make - * fields in the physical {@link TypeInformation} accessible for a {@code TimestampExtractor}. + * fields in the physical {@link DataType} accessible for a {@code TimestampExtractor}. * - * @return A mapping from {@link TableSchema} fields to {@link TypeInformation} fields or + * @return A mapping from {@link TableSchema} fields to {@link DataType} fields or * null if no mapping is necessary. */ @Nullable diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/NestedFieldsProjectableTableSource.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/NestedFieldsProjectableTableSource.java index e161c78b3a69ed..49fb22b1a0ed83 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/NestedFieldsProjectableTableSource.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/NestedFieldsProjectableTableSource.java @@ -34,13 +34,13 @@ public interface NestedFieldsProjectableTableSource { /** * Creates a copy of the {@link TableSource} that projects its output to the given field indexes. - * The field indexes relate to the physical return type ({@link TableSource#getReturnType()}) and not - * to the table schema ({@link TableSource#getTableSchema()} of the {@link TableSource}. + * The field indexes relate to the physical produced data type ({@link TableSource#getProducedDataType()}) + * and not to the table schema ({@link TableSource#getTableSchema()} of the {@link TableSource}. * *

    The table schema ({@link TableSource#getTableSchema()} of the {@link TableSource} copy must not be - * modified by this method, but only the return type ({@link TableSource#getReturnType()}) and the - * produced {@code DataSet} ({@code BatchTableSource.getDataSet(}) or {@code DataStream} - * ({@code StreamTableSource.getDataStream()}). The return type may only be changed by + * modified by this method, but only the produced data type ({@link TableSource#getProducedDataType()}) + * and the produced {@code DataSet} ({@code BatchTableSource.getDataSet(}) or {@code DataStream} + * ({@code StreamTableSource.getDataStream()}). The produced data type may only be changed by * removing or reordering first level fields. The type of the first level fields must not be * changed. * diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/ProjectableTableSource.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/ProjectableTableSource.java index 1c4b3c178e3e00..e79951efdeb447 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/ProjectableTableSource.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/ProjectableTableSource.java @@ -34,12 +34,12 @@ public interface ProjectableTableSource { /** * Creates a copy of the {@link TableSource} that projects its output to the given field indexes. - * The field indexes relate to the physical return type ({@link TableSource#getReturnType}) and not - * to the table schema ({@link TableSource#getTableSchema} of the {@link TableSource}. + * The field indexes relate to the physical produced data type ({@link TableSource#getProducedDataType()}) + * and not to the table schema ({@link TableSource#getTableSchema} of the {@link TableSource}. * *

    The table schema ({@link TableSource#getTableSchema} of the {@link TableSource} copy must not be - * modified by this method, but only the return type ({@link TableSource#getReturnType}) and the - * produced {@code DataSet} ({@code BatchTableSource#getDataSet(}) or {@code DataStream} + * modified by this method, but only the produced data type ({@link TableSource#getProducedDataType()}) + * and the produced {@code DataSet} ({@code BatchTableSource#getDataSet(}) or {@code DataStream} * ({@code StreamTableSource#getDataStream}). * *

    If the {@link TableSource} implements the {@link DefinedFieldMapping} interface, it might diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java index 05bf92c2c77696..bf379f632c4dd7 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java @@ -20,18 +20,23 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableException; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.types.DataType; import org.apache.flink.table.utils.TableConnectorUtils; +import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType; + /** * Defines an external table with the schema that is provided by {@link TableSource#getTableSchema}. * *

    The data of a {@link TableSource} is produced as a {@code DataSet} in case of a {@code BatchTableSource} * or as a {@code DataStream} in case of a {@code StreamTableSource}. The type of this produced - * {@code DataSet} or {@code DataStream} is specified by the {@link TableSource#getReturnType} method. + * {@code DataSet} or {@code DataStream} is specified by the {@link TableSource#getProducedDataType()} method. * *

    By default, the fields of the {@link TableSchema} are implicitly mapped by name to the fields of - * the return type {@link TypeInformation}. An explicit mapping can be defined by implementing the + * the produced {@link DataType}. An explicit mapping can be defined by implementing the * {@link DefinedFieldMapping} interface. * * @param The return type of the {@link TableSource}. @@ -40,12 +45,31 @@ public interface TableSource { /** - * Returns the {@link TypeInformation} for the return type of the {@link TableSource}. - * The fields of the return type are mapped to the table schema based on their name. + * Returns the {@link DataType} for the produced data of the {@link TableSource}. + * The fields of the data type are mapped to the table schema based on their name. * - * @return The type of the returned {@code DataSet} or {@code DataStream}. + * @return The data type of the returned {@code DataSet} or {@code DataStream}. + */ + default DataType getProducedDataType() { + final TypeInformation legacyType = getReturnType(); + if (legacyType == null) { + throw new TableException("Table source does not implement a produced data type."); + } + return fromLegacyInfoToDataType(getReturnType()); + } + + /** + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #getProducedDataType()} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new type + * system consistently to avoid unintended behavior. See the website documentation + * for more information. */ - TypeInformation getReturnType(); + @Deprecated + @SuppressWarnings("unchecked") + default TypeInformation getReturnType() { + return null; + } /** * Returns the schema of the produced table. 
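For illustration, a table source written directly against the new getProducedDataType() method introduced above might look roughly like the following sketch (not part of this patch; the schema, field names and sample rows are invented):

import java.util.Arrays;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.sources.StreamTableSource;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;

public class OrdersTableSource implements StreamTableSource<Row> {

    @Override
    public DataType getProducedDataType() {
        // Physical type declared with the new type system; fields are matched to the schema by name.
        return DataTypes.ROW(
            DataTypes.FIELD("id", DataTypes.BIGINT()),
            DataTypes.FIELD("product", DataTypes.STRING()),
            DataTypes.FIELD("amount", DataTypes.INT()));
    }

    @Override
    public TableSchema getTableSchema() {
        return TableSchema.builder()
            .field("id", Types.LONG)
            .field("product", Types.STRING)
            .field("amount", Types.INT)
            .build();
    }

    @Override
    public DataStream<Row> getDataStream(StreamExecutionEnvironment execEnv) {
        // Hard-coded sample rows; a real source would read from an external system.
        RowTypeInfo typeInfo = new RowTypeInfo(
            new TypeInformation<?>[]{Types.LONG, Types.STRING, Types.INT},
            new String[]{"id", "product", "amount"});
        return execEnv.fromCollection(
            Arrays.asList(Row.of(1L, "beer", 3), Row.of(2L, "diaper", 4)),
            typeInfo);
    }
}

Note how the field names of the produced row type line up with the table schema, which is what the by-name mapping described in DefinedFieldMapping relies on.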
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala index 6d2e40dea649dc..695e1167cfc638 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/common/CommonLookupJoin.scala @@ -17,6 +17,8 @@ */ package org.apache.flink.table.plan.nodes.common +import java.util.Collections + import com.google.common.primitives.Primitives import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -45,17 +47,15 @@ import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getParamClassesConsiderVarArgs, getUserDefinedMethod, signatureToString, signaturesToString} import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction} import org.apache.flink.table.plan.nodes.FlinkRelNode -import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType -import org.apache.flink.table.plan.util.{JoinTypeUtil, RelExplainUtil} import org.apache.flink.table.plan.util.LookupJoinUtil._ -import org.apache.flink.table.runtime.join.lookup.{AsyncLookupJoinRunner, LookupJoinRunner, AsyncLookupJoinWithCalcRunner, LookupJoinWithCalcRunner} +import org.apache.flink.table.plan.util.{JoinTypeUtil, RelExplainUtil} +import org.apache.flink.table.runtime.join.lookup.{AsyncLookupJoinRunner, AsyncLookupJoinWithCalcRunner, LookupJoinRunner, LookupJoinWithCalcRunner} import org.apache.flink.table.sources.TableIndex.IndexType import org.apache.flink.table.sources.{LookupConfig, LookupableTableSource, TableIndex, TableSource} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo import org.apache.flink.types.Row -import java.util.Collections - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -157,6 +157,9 @@ abstract class CommonLookupJoin( val resultRowType = FlinkTypeFactory.toInternalRowType(getRowType) val tableSchema = tableSource.getTableSchema + val producedDataType = tableSource.getProducedDataType + val producedTypeInfo = fromDataTypeToLegacyInfo(producedDataType) + // validate whether the node is valid and supported. 
validate( tableSource, @@ -202,7 +205,7 @@ abstract class CommonLookupJoin( 0) checkUdtfReturnType( tableSource.explainSource(), - tableSource.getReturnType, + producedTypeInfo, udtfResultType, extractedResultTypeInfo) val parameters = Array(new GenericType(classOf[ResultFuture[_]])) ++ lookupFieldTypesInOrder @@ -216,7 +219,7 @@ abstract class CommonLookupJoin( relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory], inputRowType, resultRowType, - tableSource.getReturnType, + producedTypeInfo, lookupFieldsInOrder, allLookupKeys, asyncLookupFunction) @@ -240,7 +243,7 @@ abstract class CommonLookupJoin( generatedFetcher, generatedCalc, generatedResultFuture, - tableSource.getReturnType, + producedTypeInfo, rightRowType.toTypeInfo, leftOuterJoin, lookupConfig.getAsyncBufferCapacity) @@ -256,7 +259,7 @@ abstract class CommonLookupJoin( new AsyncLookupJoinRunner( generatedFetcher, generatedResultFuture, - tableSource.getReturnType, + producedTypeInfo, rightRowType.toTypeInfo, leftOuterJoin, asyncBufferCapacity) @@ -277,7 +280,7 @@ abstract class CommonLookupJoin( 0) checkUdtfReturnType( tableSource.explainSource(), - tableSource.getReturnType, + producedTypeInfo, udtfResultType, extractedResultTypeInfo) checkEvalMethodSignature( @@ -290,7 +293,7 @@ abstract class CommonLookupJoin( relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory], inputRowType, resultRowType, - tableSource.getReturnType, + producedTypeInfo, lookupFieldsInOrder, allLookupKeys, lookupFunction, @@ -655,7 +658,7 @@ abstract class CommonLookupJoin( "but was " + joinType.toString + " JOIN") } - val tableReturnType = tableSource.getReturnType + val tableReturnType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) if (!tableReturnType.isInstanceOf[BaseRowTypeInfo] && !tableReturnType.isInstanceOf[RowTypeInfo]) { throw new TableException( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecTableSourceScan.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecTableSourceScan.scala index ee39511a3b5314..ade92db06e8c95 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecTableSourceScan.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecTableSourceScan.scala @@ -18,25 +18,25 @@ package org.apache.flink.table.plan.nodes.physical.batch +import java.util + +import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.metadata.RelMetadataQuery +import org.apache.calcite.rex.RexNode import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.transformations.StreamTransformation import org.apache.flink.table.api.{BatchTableEnvironment, TableException, Types} +import org.apache.flink.table.codegen.CodeGeneratorContext import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.plan.nodes.physical.PhysicalTableSourceScan import org.apache.flink.table.plan.schema.FlinkRelOptTable -import org.apache.flink.table.sources.{BatchTableSource, TableSourceUtil} -import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo -import org.apache.flink.table.codegen.CodeGeneratorContext import 
org.apache.flink.table.plan.util.ScanUtil - -import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.rex.RexNode - -import java.util +import org.apache.flink.table.sources.{BatchTableSource, TableSourceUtil} +import org.apache.flink.table.types.utils.TypeConversions +import org.apache.flink.table.types.utils.TypeConversions.{fromDataTypeToLegacyInfo, fromLegacyInfoToDataType} import scala.collection.JavaConversions._ @@ -89,13 +89,16 @@ class BatchExecTableSourceScan( isStreamTable = false, None) + val inputDataType = fromLegacyInfoToDataType(inputTransform.getOutputType) + val producedDataType = tableSource.getProducedDataType + val producedTypeInfo = fromDataTypeToLegacyInfo(producedDataType) + // check that declared and actual type of table source DataStream are identical - if (createInternalTypeFromTypeInfo(inputTransform.getOutputType) != - createInternalTypeFromTypeInfo(tableSource.getReturnType)) { + if (inputDataType != producedDataType) { throw new TableException(s"TableSource of type ${tableSource.getClass.getCanonicalName} " + - s"returned a DataSet of type ${inputTransform.getOutputType} that does not match with " + - s"the type ${tableSource.getReturnType} declared by the TableSource.getReturnType() " + - s"method. Please validate the implementation of the TableSource.") + s"returned a DataStream of data type $producedDataType that does not match with the " + + s"data type $producedDataType declared by the TableSource.getProducedDataType() method. " + + s"Please validate the implementation of the TableSource.") } // get expression to extract rowtime attribute @@ -111,7 +114,7 @@ class BatchExecTableSourceScan( CodeGeneratorContext(config), inputTransform.asInstanceOf[StreamTransformation[Any]], fieldIndexes, - tableSource.getReturnType, + producedTypeInfo, getRowType, getTable.getQualifiedName, config, @@ -129,7 +132,7 @@ class BatchExecTableSourceScan( None) ScanUtil.hasTimeAttributeField(fieldIndexes) || ScanUtil.needsConversion( - tableSource.getReturnType, + fromDataTypeToLegacyInfo(tableSource.getProducedDataType), TypeExtractor.createTypeInfo( tableSource, classOf[BatchTableSource[_]], tableSource.getClass, 0) .getTypeClass.asInstanceOf[Class[_]]) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTableSourceScan.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTableSourceScan.scala index 096604bedd8491..0579f9e2d01fea 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTableSourceScan.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTableSourceScan.scala @@ -18,30 +18,29 @@ package org.apache.flink.table.plan.nodes.physical.stream +import java.util + +import org.apache.calcite.plan._ +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.metadata.RelMetadataQuery +import org.apache.calcite.rex.RexNode import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.{AssignerWithPeriodicWatermarks, AssignerWithPunctuatedWatermarks} import org.apache.flink.streaming.api.transformations.StreamTransformation import org.apache.flink.streaming.api.watermark.Watermark import 
org.apache.flink.table.api.{StreamTableEnvironment, TableException, Types} +import org.apache.flink.table.codegen.CodeGeneratorContext +import org.apache.flink.table.codegen.OperatorCodeGenerator._ import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} import org.apache.flink.table.plan.nodes.physical.PhysicalTableSourceScan import org.apache.flink.table.plan.schema.FlinkRelOptTable -import org.apache.flink.table.sources.wmstrategies.{PeriodicWatermarkAssigner, PreserveWatermarks, PunctuatedWatermarkAssigner} -import org.apache.flink.table.sources.{RowtimeAttributeDescriptor, StreamTableSource, TableSourceUtil} -import org.apache.flink.table.`type`.TypeConverters.createInternalTypeFromTypeInfo -import org.apache.flink.table.codegen.CodeGeneratorContext -import org.apache.flink.table.codegen.OperatorCodeGenerator._ import org.apache.flink.table.plan.util.ScanUtil import org.apache.flink.table.runtime.AbstractProcessStreamOperator - -import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.rex.RexNode - -import java.util +import org.apache.flink.table.sources.wmstrategies.{PeriodicWatermarkAssigner, PreserveWatermarks, PunctuatedWatermarkAssigner} +import org.apache.flink.table.sources.{RowtimeAttributeDescriptor, StreamTableSource, TableSourceUtil} +import org.apache.flink.table.types.utils.TypeConversions.{fromDataTypeToLegacyInfo, fromLegacyInfoToDataType} import scala.collection.JavaConversions._ @@ -99,13 +98,16 @@ class StreamExecTableSourceScan( isStreamTable = true, None) + val inputDataType = fromLegacyInfoToDataType(inputTransform.getOutputType) + val producedDataType = tableSource.getProducedDataType + val producedTypeInfo = fromDataTypeToLegacyInfo(producedDataType) + // check that declared and actual type of table source DataStream are identical - if (createInternalTypeFromTypeInfo(inputTransform.getOutputType) != - createInternalTypeFromTypeInfo(tableSource.getReturnType)) { + if (inputDataType != producedDataType) { throw new TableException(s"TableSource of type ${tableSource.getClass.getCanonicalName} " + - s"returned a DataStream of type ${inputTransform.getOutputType} that does not match with " + - s"the type ${tableSource.getReturnType} declared by the TableSource.getReturnType() " + - s"method. Please validate the implementation of the TableSource.") + s"returned a DataStream of data type $producedDataType that does not match with the " + + s"data type $producedDataType declared by the TableSource.getProducedDataType() method. 
" + + s"Please validate the implementation of the TableSource.") } // get expression to extract rowtime attribute @@ -131,7 +133,7 @@ class StreamExecTableSourceScan( ctx, inputTransform.asInstanceOf[StreamTransformation[Any]], fieldIndexes, - tableSource.getReturnType, + producedTypeInfo, getRowType, getTable.getQualifiedName, config, @@ -177,7 +179,7 @@ class StreamExecTableSourceScan( None) ScanUtil.hasTimeAttributeField(fieldIndexes) || ScanUtil.needsConversion( - tableSource.getReturnType, + fromDataTypeToLegacyInfo(tableSource.getProducedDataType), TypeExtractor.createTypeInfo( tableSource, classOf[StreamTableSource[_]], tableSource.getClass, 0) .getTypeClass.asInstanceOf[Class[_]]) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala index a4dc4c74417f82..d39c7a907720ad 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala @@ -32,7 +32,7 @@ import org.apache.flink.table.`type`.{InternalType, TypeConverters} import org.apache.flink.table.api.{Types, ValidationException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.expressions._ -import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType +import org.apache.flink.table.types.utils.TypeConversions.{fromDataTypeToLegacyInfo, fromLegacyInfoToDataType} import scala.collection.JavaConversions._ @@ -113,7 +113,7 @@ object TableSourceUtil { } idx } - val inputType = tableSource.getReturnType + val inputType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) // ensure that only one field is mapped to an atomic type if (!inputType.isInstanceOf[CompositeType[_]] && mapping.count(_ >= 0) > 1) { @@ -315,7 +315,7 @@ object TableSourceUtil { fieldName: String, tableSource: TableSource[_]): (String, Int, TypeInformation[_]) = { - val returnType = tableSource.getReturnType + val returnType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) /** Look up a field by name in a type information */ def lookupField(fieldName: String, failMsg: String): (String, Int, TypeInformation[_]) = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala index 862523bd0a4355..7f4a5fb7c0146b 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala @@ -29,6 +29,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.PhysicalTableSourceScan import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.sources._ +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import org.apache.flink.types.Row /** Flink RelNode to read data from an external source defined by a [[BatchTableSource]]. 
*/ @@ -91,11 +92,14 @@ class BatchTableSourceScan( val inputDataSet = tableSource.getDataSet(tableEnv.execEnv).asInstanceOf[DataSet[Any]] val outputSchema = new RowSchema(this.getRowType) + val inputDataType = fromLegacyInfoToDataType(inputDataSet.getType) + val producedDataType = tableSource.getProducedDataType + // check that declared and actual type of table source DataSet are identical - if (inputDataSet.getType != tableSource.getReturnType) { + if (inputDataType != producedDataType) { throw new TableException(s"TableSource of type ${tableSource.getClass.getCanonicalName} " + - s"returned a DataSet of type ${inputDataSet.getType} that does not match with the " + - s"type ${tableSource.getReturnType} declared by the TableSource.getReturnType() method. " + + s"returned a DataSet of data type $producedDataType that does not match with the " + + s"data type $producedDataType declared by the TableSource.getProducedDataType() method. " + s"Please validate the implementation of the TableSource.") } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala index 20e2234556ab02..38216ba8c25f4d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala @@ -33,6 +33,7 @@ import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.runtime.types.CRow import org.apache.flink.table.sources._ import org.apache.flink.table.sources.wmstrategies.{PeriodicWatermarkAssigner, PreserveWatermarks, PunctuatedWatermarkAssigner} +import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo /** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */ @@ -95,11 +96,14 @@ class StreamTableSourceScan( val inputDataStream = tableSource.getDataStream(tableEnv.execEnv).asInstanceOf[DataStream[Any]] val outputSchema = new RowSchema(this.getRowType) + val inputDataType = fromLegacyInfoToDataType(inputDataStream.getType) + val producedDataType = tableSource.getProducedDataType + // check that declared and actual type of table source DataStream are identical - if (inputDataStream.getType != tableSource.getReturnType) { + if (inputDataType != producedDataType) { throw new TableException(s"TableSource of type ${tableSource.getClass.getCanonicalName} " + - s"returned a DataStream of type ${inputDataStream.getType} that does not match with the " + - s"type ${tableSource.getReturnType} declared by the TableSource.getReturnType() method. " + + s"returned a DataStream of data type $producedDataType that does not match with the " + + s"data type $producedDataType declared by the TableSource.getProducedDataType() method. 
" + s"Please validate the implementation of the TableSource.") } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala index 375b295e5cb7d7..417318845cb2d6 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala @@ -32,6 +32,7 @@ import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.table.api.{TableException, Types, ValidationException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.expressions.{Cast, PlannerExpression, PlannerResolvedFieldReference, ResolvedFieldReference} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo import scala.collection.JavaConverters._ @@ -95,9 +96,10 @@ object TableSourceUtil { mappedFieldCnt += 1 } // ensure that only one field is mapped to an atomic type - if (!tableSource.getReturnType.isInstanceOf[CompositeType[_]] && mappedFieldCnt > 1) { + val producedType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) + if (!producedType.isInstanceOf[CompositeType[_]] && mappedFieldCnt > 1) { throw new ValidationException( - s"More than one table field matched to atomic input type ${tableSource.getReturnType}.") + s"More than one table field matched to atomic input type $producedType.") } // validate rowtime attributes @@ -172,7 +174,7 @@ object TableSourceUtil { tableSource: TableSource[_], isStreamTable: Boolean, selectedFields: Option[Array[Int]]): Array[Int] = { - val inputType = tableSource.getReturnType + val inputType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) val tableSchema = tableSource.getTableSchema // get names of selected fields @@ -465,7 +467,7 @@ object TableSourceUtil { fieldName: String, tableSource: TableSource[_]): (String, Int, TypeInformation[_]) = { - val returnType = tableSource.getReturnType + val returnType = fromDataTypeToLegacyInfo(tableSource.getProducedDataType) /** Look up a field by name in a type information */ def lookupField(fieldName: String, failMsg: String): (String, Int, TypeInformation[_]) = { diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java index 8787e94d13f17c..ce5072024596e4 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/catalog/CatalogStructureBuilder.java @@ -22,7 +22,9 @@ import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.api.Types; import org.apache.flink.table.sources.StreamTableSource; +import org.apache.flink.types.Row; import java.util.HashMap; import java.util.Map; @@ -265,18 +267,18 @@ public int hashCode() { } } - private static class TestTable extends ConnectorCatalogTable { + private static class TestTable extends ConnectorCatalogTable { private final String fullyQualifiedPath; - private static final StreamTableSource 
tableSource = new StreamTableSource() { + private static final StreamTableSource tableSource = new StreamTableSource() { @Override - public DataStream getDataStream(StreamExecutionEnvironment execEnv) { + public DataStream getDataStream(StreamExecutionEnvironment execEnv) { return null; } @Override - public TypeInformation getReturnType() { - return null; + public TypeInformation getReturnType() { + return Types.ROW(); } @Override From 33519930a19735d531cd70d0fe32bf531a90ff0b Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Fri, 31 May 2019 17:16:27 +0200 Subject: [PATCH 77/92] [FLINK-12254][table] Update TableSink and related interfaces to new type system This closes #8596. --- .../gateway/local/ExecutionContextTest.java | 4 +- .../flink/table/catalog/CatalogManager.java | 4 +- .../table/catalog/ConnectorCatalogTable.java | 3 +- .../apache/flink/table/sinks/TableSink.java | 62 ++++++++++++++++--- .../flink/table/sinks/TableSinkBase.java | 3 + .../flink/table/sources/TableSource.java | 4 +- .../table/codegen/SinkCodeGenerator.scala | 2 +- .../flink/table/plan/nodes/calcite/Sink.scala | 12 ++-- .../nodes/physical/batch/BatchExecSink.scala | 12 ++-- .../physical/stream/StreamExecSink.scala | 18 +++--- .../table/plan/schema/TableSinkTable.scala | 5 +- .../flink/table/sinks/CollectTableSink.scala | 2 +- .../table/sinks/DataStreamTableSink.scala | 2 +- .../runtime/utils/BatchTableEnvUtil.scala | 7 ++- .../flink/table/api/BatchTableEnvImpl.scala | 4 +- .../flink/table/api/StreamTableEnvImpl.scala | 10 ++- .../apache/flink/table/api/TableEnvImpl.scala | 12 +--- .../flink/table/sinks/CsvTableSink.scala | 4 +- .../utils/MemoryTableSourceSinkUtil.scala | 6 +- 19 files changed, 113 insertions(+), 63 deletions(-) diff --git a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java index 5137f6316a77b0..bcdd9919e7c292 100644 --- a/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java +++ b/flink-table/flink-sql-client/src/test/java/org/apache/flink/table/client/gateway/local/ExecutionContextTest.java @@ -124,11 +124,11 @@ public void testTables() throws Exception { assertArrayEquals( new String[]{"BooleanField", "StringField"}, - sinks.get("TableSourceSink").getFieldNames()); + sinks.get("TableSourceSink").getTableSchema().getFieldNames()); assertArrayEquals( new TypeInformation[]{Types.BOOLEAN(), Types.STRING()}, - sinks.get("TableSourceSink").getFieldTypes()); + sinks.get("TableSourceSink").getTableSchema().getFieldTypes()); final TableEnvironment tableEnv = context.createEnvironmentInstance().getTableEnvironment(); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java index 5704043d02afe7..c965a98c6a8833 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/CatalogManager.java @@ -383,8 +383,8 @@ private static TableSchema getTableSchema(ExternalCatalogTable externalTable) { if (externalTable.isTableSource()) { return TableFactoryUtil.findAndCreateTableSource(externalTable).getTableSchema(); } else { - TableSink tableSink = TableFactoryUtil.findAndCreateTableSink(externalTable); 
- return new TableSchema(tableSink.getFieldNames(), tableSink.getFieldTypes()); + TableSink tableSink = TableFactoryUtil.findAndCreateTableSink(externalTable); + return tableSink.getTableSchema(); } } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java index ce860842349e0a..84b6e0d3941c16 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ConnectorCatalogTable.java @@ -57,8 +57,7 @@ public static ConnectorCatalogTable source(TableSource source, boolean } public static ConnectorCatalogTable sink(TableSink sink, boolean isBatch) { - TableSchema tableSchema = new TableSchema(sink.getFieldNames(), sink.getFieldTypes()); - return new ConnectorCatalogTable<>(null, sink, tableSchema, isBatch); + return new ConnectorCatalogTable<>(null, sink, sink.getTableSchema(), isBatch); } public static ConnectorCatalogTable sourceAndSink( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSink.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSink.java index 4b13cb57330b9d..f8e003cb661b9e 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSink.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSink.java @@ -20,6 +20,12 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.types.DataType; + +import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType; /** * A {@link TableSink} specifies how to emit a table to an external @@ -33,23 +39,59 @@ public interface TableSink { /** - * Returns the type expected by this {@link TableSink}. + * Returns the data type consumed by this {@link TableSink}. * - *

    This type should depend on the types returned by {@link TableSink#getFieldNames()}. + * @return The data type expected by this {@link TableSink}. + */ + default DataType getConsumedDataType() { + final TypeInformation legacyType = getOutputType(); + if (legacyType == null) { + throw new TableException("Table sink does not implement a consumed data type."); + } + return fromLegacyInfoToDataType(legacyType); + } + + /** + * @deprecated This method will be removed in future versions as it uses the old type system. It + * is recommended to use {@link #getConsumedDataType()} instead which uses the new type + * system based on {@link DataTypes}. Please make sure to use either the old or the new type + * system consistently to avoid unintended behavior. See the website documentation + * for more information. + */ + @Deprecated + default TypeInformation getOutputType() { + return null; + } + + /** + * Returns the schema of the consumed table. * - * @return The type expected by this {@link TableSink}. + * @return The {@link TableSchema} of the consumed table. */ - TypeInformation getOutputType(); + default TableSchema getTableSchema() { + final String[] fieldNames = getFieldNames(); + final TypeInformation[] legacyFieldTypes = getFieldTypes(); + if (fieldNames == null || legacyFieldTypes == null) { + throw new TableException("Table sink does not implement a table schema."); + } + return new TableSchema(fieldNames, legacyFieldTypes); + } /** - * Returns the names of the table fields. + * @deprecated Use the field names of {@link #getTableSchema()} instead. */ - String[] getFieldNames(); + @Deprecated + default String[] getFieldNames() { + return null; + } /** - * Returns the types of the table fields. + * @deprecated Use the field types of {@link #getTableSchema()} instead. */ - TypeInformation[] getFieldTypes(); + @Deprecated + default TypeInformation[] getFieldTypes() { + return null; + } /** * Returns a copy of this {@link TableSink} configured with the field names and types of the @@ -59,6 +101,10 @@ public interface TableSink { * @param fieldTypes The field types of the table to emit. * @return A copy of this {@link TableSink} configured with the field names and types of the * table to emit. + * + * @deprecated This method will be dropped in future versions. It is recommended to pass a + * static schema when instantiating the sink instead. */ + @Deprecated TableSink configure(String[] fieldNames, TypeInformation[] fieldTypes); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSinkBase.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSinkBase.java index 830bf1cff106df..3bf9f3f4157478 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSinkBase.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sinks/TableSinkBase.java @@ -42,6 +42,7 @@ public abstract class TableSinkBase implements TableSink { /** * Returns the field names of the table to emit. */ + @Override public String[] getFieldNames() { if (fieldNames.isPresent()) { return fieldNames.get(); @@ -54,6 +55,7 @@ public String[] getFieldNames() { /** * Returns the field types of the table to emit. */ + @Override public TypeInformation[] getFieldTypes() { if (fieldTypes.isPresent()) { return fieldTypes.get(); @@ -72,6 +74,7 @@ public TypeInformation[] getFieldTypes() { * @return A copy of this {@link TableSink} configured with the field names and types of the * table to emit. 
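For orientation, here is a minimal sketch (not part of the patch) of what the new defaults above give an existing sink: a legacy implementation that only overrides the now-deprecated getFieldNames(), getFieldTypes() and getOutputType() picks up working getTableSchema() and getConsumedDataType() implementations for free. The class name, field layout and main method below are illustrative assumptions only.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.sinks.TableSink;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;

/** Hypothetical legacy sink that only implements the old, now-deprecated methods. */
public class LegacyWordCountSink implements TableSink<Row> {

    private final String[] fieldNames = {"word", "cnt"};
    private final TypeInformation<?>[] fieldTypes = {Types.STRING, Types.LONG};

    @Override
    public TypeInformation<Row> getOutputType() {
        return Types.ROW_NAMED(fieldNames, fieldTypes);
    }

    @Override
    public String[] getFieldNames() {
        return fieldNames;
    }

    @Override
    public TypeInformation<?>[] getFieldTypes() {
        return fieldTypes;
    }

    @Override
    public TableSink<Row> configure(String[] names, TypeInformation<?>[] types) {
        // A real sink would return a configured copy; identity is enough for this sketch.
        return this;
    }

    public static void main(String[] args) {
        LegacyWordCountSink sink = new LegacyWordCountSink();
        // Both values are derived by the new default methods from the legacy ones above.
        TableSchema schema = sink.getTableSchema();
        DataType consumedType = sink.getConsumedDataType();
        System.out.println(schema + " -> " + consumedType);
    }
}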
*/ + @Override public final TableSink configure(String[] fieldNames, TypeInformation[] fieldTypes) { final TableSinkBase configuredSink = this.copy(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java index bf379f632c4dd7..386240be376def 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/sources/TableSource.java @@ -46,7 +46,6 @@ public interface TableSource { /** * Returns the {@link DataType} for the produced data of the {@link TableSource}. - * The fields of the data type are mapped to the table schema based on their name. * * @return The data type of the returned {@code DataSet} or {@code DataStream}. */ @@ -55,7 +54,7 @@ default DataType getProducedDataType() { if (legacyType == null) { throw new TableException("Table source does not implement a produced data type."); } - return fromLegacyInfoToDataType(getReturnType()); + return fromLegacyInfoToDataType(legacyType); } /** @@ -66,7 +65,6 @@ default DataType getProducedDataType() { * for more information. */ @Deprecated - @SuppressWarnings("unchecked") default TypeInformation getReturnType() { return null; } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala index 5ec3a383034a1d..6006ec1b5aa763 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/SinkCodeGenerator.scala @@ -44,7 +44,7 @@ object SinkCodeGenerator { try { sink match { // DataStreamTableSink has no generic class, so we need get the type to get type class. - case sink: DataStreamTableSink[_] => sink.getOutputType.getTypeClass + case sink: DataStreamTableSink[_] => sink.getConsumedDataType.getConversionClass case _ => TypeExtractor.createTypeInfo(sink, classOf[TableSink[_]], sink.getClass, 0) .getTypeClass.asInstanceOf[Class[_]] } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Sink.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Sink.scala index 29623c32de980c..9d0dffa384296e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Sink.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/calcite/Sink.scala @@ -18,13 +18,13 @@ package org.apache.flink.table.plan.nodes.calcite -import org.apache.flink.table.`type`.TypeConverters -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.sinks.TableSink - import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.table.`type`.TypeConverters +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.sinks.TableSink +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo /** * Relational expression that writes out data of input node into a [[TableSink]]. 
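The planner changes in this patch (SinkCodeGenerator, Sink, BatchExecSink, StreamExecSink) all lean on the converters in org.apache.flink.table.types.utils.TypeConversions to bridge between the legacy TypeInformation world and the new DataType world. A small, self-contained sketch of that round trip, with purely illustrative variable names:

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;

public class TypeBridgeSketch {
    public static void main(String[] args) {
        // Legacy description of a (word, cnt) row.
        TypeInformation<Row> rowInfo = Types.ROW_NAMED(
            new String[]{"word", "cnt"}, Types.STRING, Types.LONG);

        // Into the new type system, as the default getConsumedDataType() does.
        DataType dataType = TypeConversions.fromLegacyInfoToDataType(rowInfo);

        // The conversion class is what the updated SinkCodeGenerator asks the sink for.
        Class<?> conversionClass = dataType.getConversionClass();

        // And back to TypeInformation, as the planner nodes do via fromDataTypeToLegacyInfo.
        TypeInformation<?> legacyAgain = TypeConversions.fromDataTypeToLegacyInfo(dataType);

        System.out.println(conversionClass + " / " + legacyAgain);
    }
}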
@@ -45,7 +45,7 @@ abstract class Sink( override def deriveRowType(): RelDataType = { val typeFactory = getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] - val outputType = sink.getOutputType + val outputType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) val internalType = TypeConverters.createInternalTypeFromTypeInfo(outputType) typeFactory.createTypeFromInternalType(internalType, isNullable = true) } @@ -53,7 +53,7 @@ abstract class Sink( override def explainTerms(pw: RelWriter): RelWriter = { super.explainTerms(pw) .itemIf("name", sinkName, sinkName != null) - .item("fields", sink.getFieldNames.mkString(", ")) + .item("fields", sink.getTableSchema.getFieldNames.mkString(", ")) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSink.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSink.scala index 90e9a6e7c9c184..a6f5dcb10c10a4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSink.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/batch/BatchExecSink.scala @@ -18,6 +18,10 @@ package org.apache.flink.table.plan.nodes.physical.batch +import java.util + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelNode import org.apache.flink.runtime.operators.DamBehavior import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation} @@ -28,13 +32,9 @@ import org.apache.flink.table.dataformat.BaseRow import org.apache.flink.table.plan.nodes.calcite.Sink import org.apache.flink.table.plan.nodes.exec.{BatchExecNode, ExecNode} import org.apache.flink.table.sinks.{BatchTableSink, DataStreamTableSink, TableSink} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.RelNode - -import java.util - import scala.collection.JavaConversions._ /** @@ -98,7 +98,7 @@ class BatchExecSink[T]( private def translateToStreamTransformation( withChangeFlag: Boolean, tableEnv: BatchTableEnvironment): StreamTransformation[T] = { - val resultType = sink.getOutputType + val resultType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) TableEnvironment.validateType(resultType) val inputNode = getInputNodes.get(0) inputNode match { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecSink.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecSink.scala index fbfb776b9f38a4..d39b996c907cd8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecSink.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecSink.scala @@ -18,26 +18,26 @@ package org.apache.flink.table.plan.nodes.physical.stream +import java.util + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelNode import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.transformations.{OneInputTransformation, 
StreamTransformation} +import org.apache.flink.table.`type`.InternalTypes import org.apache.flink.table.api.{StreamTableEnvironment, Table, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext} import org.apache.flink.table.codegen.SinkCodeGenerator.{extractTableSinkTypeClass, generateRowConverterOperator} +import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext} import org.apache.flink.table.dataformat.BaseRow +import org.apache.flink.table.plan.`trait`.{AccMode, AccModeTraitDef} import org.apache.flink.table.plan.nodes.calcite.Sink import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode} -import org.apache.flink.table.plan.`trait`.{AccMode, AccModeTraitDef} import org.apache.flink.table.plan.util.UpdatingPlanChecker import org.apache.flink.table.sinks._ -import org.apache.flink.table.`type`.InternalTypes +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.BaseRowTypeInfo -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.RelNode - -import java.util - import scala.collection.JavaConversions._ /** @@ -178,7 +178,7 @@ class StreamExecSink[T]( } else { parTransformation.getOutputType } - val resultType = sink.getOutputType + val resultType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) val typeClass = extractTableSinkTypeClass(sink) if (CodeGenUtils.isInternalClass(typeClass, resultType)) { parTransformation.asInstanceOf[StreamTransformation[T]] diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala index 63d82043e407a1..fe02425b6d3895 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/schema/TableSinkTable.scala @@ -35,8 +35,9 @@ class TableSinkTable[T]( override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = { val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] - val fieldTypes = tableSink.getFieldTypes.map(TypeConverters.createInternalTypeFromTypeInfo) - flinkTypeFactory.buildRelDataType(tableSink.getFieldNames, fieldTypes) + val fieldTypes = tableSink.getTableSchema.getFieldTypes + .map(TypeConverters.createInternalTypeFromTypeInfo) + flinkTypeFactory.buildRelDataType(tableSink.getTableSchema.getFieldNames, fieldTypes) } /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/CollectTableSink.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/CollectTableSink.scala index 3cc3903b7eb043..8c50f570dfd834 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/CollectTableSink.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/CollectTableSink.scala @@ -51,7 +51,7 @@ class CollectTableSink[T](produceOutputType: (Array[TypeInformation[_]] => TypeI } override def getOutputType: TypeInformation[T] = { - produceOutputType(getFieldTypes) + produceOutputType(getTableSchema.getFieldTypes) } def init(typeSerializer: TypeSerializer[T], id: String): Unit = { diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/DataStreamTableSink.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/DataStreamTableSink.scala index c4a308dea88353..c421706ecf7bc7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/DataStreamTableSink.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/sinks/DataStreamTableSink.scala @@ -44,7 +44,7 @@ class DataStreamTableSink[T]( /** * Return the type expected by this [[TableSink]]. * - * This type should depend on the types returned by [[getFieldNames]]. + * This type should depend on the types returned by [[getTableSchema]]. * * @return The type expected by this [[TableSink]]. */ diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala index fe9315549e4822..2ac5a9659da980 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/utils/BatchTableEnvUtil.scala @@ -29,9 +29,10 @@ import org.apache.flink.table.plan.schema.DataStreamTable import org.apache.flink.table.plan.stats.FlinkStatistic import org.apache.flink.table.sinks.CollectTableSink import org.apache.flink.util.AbstractID - import _root_.java.util.{ArrayList => JArrayList} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo + import _root_.scala.collection.JavaConversions._ import _root_.scala.collection.JavaConverters._ @@ -42,7 +43,9 @@ object BatchTableEnvUtil { table: Table, sink: CollectTableSink[T], jobName: Option[String]): Seq[T] = { - val typeSerializer = sink.getOutputType.createSerializer(tEnv.streamEnv.getConfig) + val typeSerializer = fromDataTypeToLegacyInfo(sink.getConsumedDataType) + .asInstanceOf[TypeInformation[T]] + .createSerializer(tEnv.streamEnv.getConfig) val id = new AbstractID().toString sink.init(typeSerializer.asInstanceOf[TypeSerializer[T]], id) tEnv.writeToSink(table, sink) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala index 66a1d83c86b823..0d6e133758c6de 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/BatchTableEnvImpl.scala @@ -41,6 +41,7 @@ import org.apache.flink.table.plan.schema._ import org.apache.flink.table.runtime.MapRunner import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{BatchTableSource, TableSource, TableSourceUtil} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.FieldInfoUtils.{calculateTableSchema, getFieldsInfo, validateInputTypeInfo} import org.apache.flink.types.Row @@ -108,7 +109,8 @@ abstract class BatchTableEnvImpl( sink match { case batchSink: BatchTableSink[T] => - val outputType = sink.getOutputType + val outputType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) + .asInstanceOf[TypeInformation[T]] // translate the Table into a DataSet and provide the type that the TableSink expects. 
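The same bridge is used wherever a serializer for the sink input is needed (BatchTableEnvUtil above and the planner TableEnvImpls below): the type is now derived from getConsumedDataType() instead of getOutputType(). A rough Java equivalent of that pattern; the helper name and the assumption that the sink consumes Rows are illustrative only:

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.sinks.TableSink;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.types.Row;

public class SinkSerializerSketch {

    /** Derives a Row serializer from whatever the sink declares as its consumed data type. */
    @SuppressWarnings("unchecked")
    static TypeSerializer<Row> serializerFor(TableSink<Row> sink, ExecutionConfig config) {
        TypeInformation<Row> typeInfo = (TypeInformation<Row>)
            TypeConversions.fromDataTypeToLegacyInfo(sink.getConsumedDataType());
        return typeInfo.createSerializer(config);
    }
}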
val result: DataSet[T] = translate(table, batchQueryConfig)(outputType) // Give the DataSet to the TableSink to emit it. diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala index 32b3f8c51b01e7..d8f3edc03602e0 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/StreamTableEnvImpl.scala @@ -52,6 +52,7 @@ import org.apache.flink.table.runtime.{CRowMapRunner, OutputRowtimeProcessFuncti import org.apache.flink.table.sinks._ import org.apache.flink.table.sources.{StreamTableSource, TableSource, TableSourceUtil} import org.apache.flink.table.typeutils.FieldInfoUtils.{calculateTableSchema, getFieldsInfo, isReferenceByPosition} +import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo import org.apache.flink.table.typeutils.{TimeIndicatorTypeInfo, TypeCheckUtils} import _root_.scala.collection.JavaConverters._ @@ -139,7 +140,8 @@ abstract class StreamTableEnvImpl( case retractSink: RetractStreamTableSink[_] => // retraction sink can always be used - val outputType = sink.getOutputType + val outputType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) + .asInstanceOf[TypeInformation[T]] // translate the Table into a DataStream and provide the type that the TableSink expects. val result: DataStream[T] = translate( @@ -166,7 +168,8 @@ abstract class StreamTableEnvImpl( case None if !isAppendOnlyTable => throw new TableException( "UpsertStreamTableSink requires that Table has full primary keys if it is updated.") } - val outputType = sink.getOutputType + val outputType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) + .asInstanceOf[TypeInformation[T]] val resultType = getResultType(table.getRelNode, optimizedPlan) // translate the Table into a DataStream and provide the type that the TableSink expects. val result: DataStream[T] = @@ -187,7 +190,8 @@ abstract class StreamTableEnvImpl( throw new TableException( "AppendStreamTableSink requires that Table has only insert changes.") } - val outputType = sink.getOutputType + val outputType = fromDataTypeToLegacyInfo(sink.getConsumedDataType) + .asInstanceOf[TypeInformation[T]] val resultType = getResultType(table.getRelNode, optimizedPlan) // translate the Table into a DataStream and provide the type that the TableSink expects. 
val result: DataStream[T] = diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala index f2818299cb3846..b9ae9125e163c2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/TableEnvImpl.scala @@ -455,15 +455,9 @@ abstract class TableEnvImpl( override def registerTableSink(name: String, configuredSink: TableSink[_]): Unit = { // validate - if (configuredSink.getFieldNames == null || configuredSink.getFieldTypes == null) { - throw new TableException("Table sink is not configured.") - } - if (configuredSink.getFieldNames.length == 0) { + if (configuredSink.getTableSchema.getFieldNames.length == 0) { throw new TableException("Field names must not be empty.") } - if (configuredSink.getFieldNames.length != configuredSink.getFieldTypes.length) { - throw new TableException("Same number of field names and types required.") - } validateTableSink(configuredSink) registerTableSinkInternal(name, configuredSink) @@ -692,13 +686,13 @@ abstract class TableEnvImpl( case Some(tableSink) => // validate schema of source table and table sink val srcFieldTypes = table.getSchema.getFieldTypes - val sinkFieldTypes = tableSink.getFieldTypes + val sinkFieldTypes = tableSink.getTableSchema.getFieldTypes if (srcFieldTypes.length != sinkFieldTypes.length || srcFieldTypes.zip(sinkFieldTypes).exists { case (srcF, snkF) => srcF != snkF }) { val srcFieldNames = table.getSchema.getFieldNames - val sinkFieldNames = tableSink.getFieldNames + val sinkFieldNames = tableSink.getTableSchema.getFieldNames // format table and table sink schema strings val srcSchema = srcFieldNames.zip(srcFieldTypes) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sinks/CsvTableSink.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sinks/CsvTableSink.scala index 74efcf3f755ea4..8037f4d9bedabf 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sinks/CsvTableSink.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/sinks/CsvTableSink.scala @@ -80,7 +80,7 @@ class CsvTableSink( sink.setParallelism(numFiles.get) } - sink.name(TableConnectorUtils.generateRuntimeName(this.getClass, getFieldNames)) + sink.name(TableConnectorUtils.generateRuntimeName(this.getClass, getTableSchema.getFieldNames)) } override def emitDataStream(dataStream: DataStream[Row]): Unit = { @@ -99,7 +99,7 @@ class CsvTableSink( sink.setParallelism(numFiles.get) } - sink.name(TableConnectorUtils.generateRuntimeName(this.getClass, getFieldNames)) + sink.name(TableConnectorUtils.generateRuntimeName(this.getClass, getTableSchema.getFieldNames)) } override protected def copy: TableSinkBase[Row] = { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/MemoryTableSourceSinkUtil.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/MemoryTableSourceSinkUtil.scala index 2ee051b85f28c7..37342c935b24ba 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/MemoryTableSourceSinkUtil.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/MemoryTableSourceSinkUtil.scala @@ -104,7 +104,7 @@ object MemoryTableSourceSinkUtil { with AppendStreamTableSink[Row] { override def 
getOutputType: TypeInformation[Row] = { - new RowTypeInfo(getFieldTypes, getFieldNames) + new RowTypeInfo(getTableSchema.getFieldTypes, getTableSchema.getFieldNames) } override protected def copy: TableSinkBase[Row] = { @@ -114,7 +114,7 @@ object MemoryTableSourceSinkUtil { override def emitDataSet(dataSet: DataSet[Row]): Unit = { dataSet .output(new MemoryCollectionOutputFormat) - .name(TableConnectorUtils.generateRuntimeName(this.getClass, getFieldNames)) + .name(TableConnectorUtils.generateRuntimeName(this.getClass, getTableSchema.getFieldNames)) } override def emitDataStream(dataStream: DataStream[Row]): Unit = { @@ -122,7 +122,7 @@ object MemoryTableSourceSinkUtil { dataStream .addSink(new MemoryAppendSink) .setParallelism(inputParallelism) - .name(TableConnectorUtils.generateRuntimeName(this.getClass, getFieldNames)) + .name(TableConnectorUtils.generateRuntimeName(this.getClass, getTableSchema.getFieldNames)) } } From 16479f4d07327d6991d899d80246c4b48cc390af Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Mon, 20 May 2019 14:02:44 +0200 Subject: [PATCH 78/92] [hotfix][e2e] Refactor batch wordcount test to use different source types --- .../run-pre-commit-tests.sh | 6 ++-- .../test-scripts/test_batch_wordcount.sh | 31 +++++++++++++++++-- .../test-scripts/test_shaded_hadoop_s3a.sh | 30 ------------------ .../test-scripts/test_shaded_presto_s3.sh | 30 ------------------ 4 files changed, 32 insertions(+), 65 deletions(-) delete mode 100755 flink-end-to-end-tests/test-scripts/test_shaded_hadoop_s3a.sh delete mode 100755 flink-end-to-end-tests/test-scripts/test_shaded_presto_s3.sh diff --git a/flink-end-to-end-tests/run-pre-commit-tests.sh b/flink-end-to-end-tests/run-pre-commit-tests.sh index 3b81b3fec59200..e79542a5467766 100755 --- a/flink-end-to-end-tests/run-pre-commit-tests.sh +++ b/flink-end-to-end-tests/run-pre-commit-tests.sh @@ -52,14 +52,14 @@ run_test "State Migration end-to-end test from 1.6" "$END_TO_END_DIR/test-script run_test "State Evolution end-to-end test" "$END_TO_END_DIR/test-scripts/test_state_evolution.sh" run_test "Batch Python Wordcount end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_python_wordcount.sh" run_test "Streaming Python Wordcount end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_python_wordcount.sh" -run_test "Wordcount end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh" +run_test "Wordcount end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh file" +run_test "Shaded Hadoop S3A end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh hadoop" +run_test "Shaded Presto S3 end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh presto" run_test "Kafka 0.10 end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka010.sh" run_test "Kafka 0.11 end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka011.sh" run_test "Modern Kafka end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka.sh" run_test "Kinesis end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kinesis.sh" run_test "class loading end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_classloader.sh" -run_test "Shaded Hadoop S3A end-to-end test" "$END_TO_END_DIR/test-scripts/test_shaded_hadoop_s3a.sh" -run_test "Shaded Presto S3 end-to-end test" "$END_TO_END_DIR/test-scripts/test_shaded_presto_s3.sh" run_test "Distributed cache end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_distributed_cache_via_blob.sh" printf "\n[PASS] All tests passed\n" diff 
--git a/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh b/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh index 2c9a17597ca650..a5c82226e72da3 100755 --- a/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh +++ b/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh @@ -19,7 +19,34 @@ source "$(dirname "$0")"/common.sh +INPUT_TYPE=${1:-file} +case $INPUT_TYPE in + (file) + INPUT_LOCATION="${TEST_INFRA_DIR}/test-data/words" + ;; + (hadoop) + source "$(dirname "$0")"/common_s3.sh + s3_setup hadoop + INPUT_LOCATION="${S3_TEST_DATA_WORDS_URI}" + ;; + (presto) + source "$(dirname "$0")"/common_s3.sh + s3_setup presto + INPUT_LOCATION="${S3_TEST_DATA_WORDS_URI}" + ;; + (*) + echo "Unknown input type $INPUT_TYPE" + exit 1 + ;; +esac + +OUTPUT_LOCATION="${TEST_DATA_DIR}/out/wc_out" + +mkdir -p "${TEST_DATA_DIR}" + start_cluster -$FLINK_DIR/bin/flink run -p 1 $FLINK_DIR/examples/batch/WordCount.jar --input $TEST_INFRA_DIR/test-data/words --output $TEST_DATA_DIR/out/wc_out -check_result_hash "WordCount" $TEST_DATA_DIR/out/wc_out "72a690412be8928ba239c2da967328a5" \ No newline at end of file +# The test may run against different source types. +# But the sources should provide the same test data, so the checksum stays the same for all tests. +"${FLINK_DIR}/bin/flink" run -p 1 "${FLINK_DIR}/examples/batch/WordCount.jar" --input "${INPUT_LOCATION}" --output "${OUTPUT_LOCATION}" +check_result_hash "WordCount (${INPUT_TYPE})" "${OUTPUT_LOCATION}" "72a690412be8928ba239c2da967328a5" diff --git a/flink-end-to-end-tests/test-scripts/test_shaded_hadoop_s3a.sh b/flink-end-to-end-tests/test-scripts/test_shaded_hadoop_s3a.sh deleted file mode 100755 index ddbb6868424dd0..00000000000000 --- a/flink-end-to-end-tests/test-scripts/test_shaded_hadoop_s3a.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -################################################################################ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ - -# Tests for our shaded/bundled Hadoop S3A file system. 
- -source "$(dirname "$0")"/common.sh -source "$(dirname "$0")"/common_s3.sh - -s3_setup hadoop -start_cluster - -$FLINK_DIR/bin/flink run -p 1 $FLINK_DIR/examples/batch/WordCount.jar --input $S3_TEST_DATA_WORDS_URI --output $TEST_DATA_DIR/out/wc_out - -check_result_hash "WordCountWithShadedS3A" $TEST_DATA_DIR/out/wc_out "72a690412be8928ba239c2da967328a5" diff --git a/flink-end-to-end-tests/test-scripts/test_shaded_presto_s3.sh b/flink-end-to-end-tests/test-scripts/test_shaded_presto_s3.sh deleted file mode 100755 index 9ebbb0d8c08fe7..00000000000000 --- a/flink-end-to-end-tests/test-scripts/test_shaded_presto_s3.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -################################################################################ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ - -# Tests for our shaded/bundled Hadoop S3A file system. - -source "$(dirname "$0")"/common.sh -source "$(dirname "$0")"/common_s3.sh - -s3_setup presto -start_cluster - -$FLINK_DIR/bin/flink run -p 1 $FLINK_DIR/examples/batch/WordCount.jar --input $S3_TEST_DATA_WORDS_URI --output $TEST_DATA_DIR/out/wc_out - -check_result_hash "WordCountWithShadedPrestoS3" $TEST_DATA_DIR/out/wc_out "72a690412be8928ba239c2da967328a5" From 1c916a47322e04b42f8521cbdb833f8cc2147ee2 Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Mon, 20 May 2019 11:32:42 +0200 Subject: [PATCH 79/92] [FLINK-12556][e2e] Add read-only test FileSystem for end-to-end tests --- .../flink-plugins-test/pom.xml | 74 ++++++++ .../apache/flink/fs/dummy/DummyFSFactory.java | 51 ++++++ .../flink/fs/dummy/DummyFSFileStatus.java | 67 ++++++++ .../flink/fs/dummy/DummyFSFileSystem.java | 162 ++++++++++++++++++ .../flink/fs/dummy/DummyFSInputStream.java | 51 ++++++ ...org.apache.flink.core.fs.FileSystemFactory | 16 ++ flink-end-to-end-tests/pom.xml | 1 + 7 files changed, 422 insertions(+) create mode 100644 flink-end-to-end-tests/flink-plugins-test/pom.xml create mode 100644 flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFactory.java create mode 100644 flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileStatus.java create mode 100644 flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileSystem.java create mode 100644 flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSInputStream.java create mode 100644 flink-end-to-end-tests/flink-plugins-test/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory diff --git a/flink-end-to-end-tests/flink-plugins-test/pom.xml b/flink-end-to-end-tests/flink-plugins-test/pom.xml new file mode 100644 index 
00000000000000..1c68fbb1d3c3f3 --- /dev/null +++ b/flink-end-to-end-tests/flink-plugins-test/pom.xml @@ -0,0 +1,74 @@ + + + + + flink-end-to-end-tests + org.apache.flink + 1.9-SNAPSHOT + + 4.0.0 + + flink-plugins-test + + jar + + + + org.apache.flink + flink-core + ${project.version} + provided + + + + + + + + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + org.apache.maven.plugins + maven-shade-plugin + + + shade-flink + package + + shade + + + flink-dummy-fs + + + + + + + diff --git a/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFactory.java b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFactory.java new file mode 100644 index 00000000000000..351879661b4d34 --- /dev/null +++ b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFactory.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.dummy; + +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemFactory; + +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +/** + * Factory of dummy FileSystem. See documentation of {@link DummyFSFileSystem}. + */ +public class DummyFSFactory implements FileSystemFactory { + + private final FileSystem fileSystem = new DummyFSFileSystem(getData()); + + @Override + public String getScheme() { + return DummyFSFileSystem.FS_URI.getScheme(); + } + + @Override + public FileSystem create(URI fsUri) throws IOException { + return fileSystem; + } + + private static Map getData() { + Map data = new HashMap<>(); + data.put("/words", "Hello World how are you, my dear dear world\n"); + return data; + } +} diff --git a/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileStatus.java b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileStatus.java new file mode 100644 index 00000000000000..3c6037dc2892fa --- /dev/null +++ b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileStatus.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.dummy; + +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.Path; + +class DummyFSFileStatus implements FileStatus { + private final Path path; + private final int length; + + DummyFSFileStatus(Path path, int length) { + this.path = path; + this.length = length; + } + + @Override + public long getLen() { + return length; + } + + @Override + public long getBlockSize() { + return length; + } + + @Override + public short getReplication() { + return 0; + } + + @Override + public long getModificationTime() { + return 0; + } + + @Override + public long getAccessTime() { + return 0; + } + + @Override + public boolean isDir() { + return false; + } + + @Override + public Path getPath() { + return path; + } +} diff --git a/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileSystem.java b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileSystem.java new file mode 100644 index 00000000000000..ac118074a7c120 --- /dev/null +++ b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSFileSystem.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.dummy; + +import org.apache.flink.core.fs.BlockLocation; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemKind; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.fs.local.LocalBlockLocation; + +import javax.annotation.Nullable; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.HashMap; +import java.util.Map; + +/** + * A FileSystem implementation for integration testing purposes. Supports and serves read-only content from static + * key value map. 
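Assuming the shaded flink-dummy-fs.jar built above ends up on Flink's classpath (the test scripts later copy it into FLINK_DIR/lib) so that the FileSystemFactory service entry registers the dummy scheme, a job or test could read the static content roughly as follows; the class name is an illustrative assumption and /words is the key populated by DummyFSFactory:

import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;

public class DummyFSReadSketch {
    public static void main(String[] args) throws Exception {
        // Resolved through the DummyFSFactory registered via META-INF/services.
        FileSystem fs = FileSystem.get(URI.create("dummy://localhost/words"));

        try (FSDataInputStream in = fs.open(new Path("dummy://localhost/words"));
                BufferedReader reader = new BufferedReader(
                    new InputStreamReader(in, StandardCharsets.UTF_8))) {
            // Prints the single line served from DummyFSFactory's static data map.
            System.out.println(reader.readLine());
        }
    }
}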
+ */ +class DummyFSFileSystem extends FileSystem { + + static final URI FS_URI = URI.create("dummy:///"); + + private static final String HOSTNAME = "localhost"; + + private final URI workingDir; + + private final URI homeDir; + + private final Map contents; + + DummyFSFileSystem(Map contents) { + this.workingDir = new File(System.getProperty("user.dir")).toURI(); + this.homeDir = new File(System.getProperty("user.home")).toURI(); + this.contents = convertToByteArrayMap(contents); + } + + // ------------------------------------------------------------------------ + + @Override + public URI getUri() { + return FS_URI; + } + + @Override + public Path getWorkingDirectory() { + return new Path(workingDir); + } + + @Override + public Path getHomeDirectory() { + return new Path(homeDir); + } + + @Override + public boolean exists(Path f) throws IOException { + return getDataByPath(f) != null; + } + + @Override + public FileStatus[] listStatus(final Path f) throws IOException { + byte[] data = getDataByPath(f); + if (data == null) { + return null; + } + return new FileStatus[] { new DummyFSFileStatus(f, data.length) }; + } + + @Override + public BlockLocation[] getFileBlockLocations(FileStatus file, long start, long len) throws IOException { + return new BlockLocation[] { + new LocalBlockLocation(HOSTNAME, file.getLen()) + }; + } + + @Override + public FileStatus getFileStatus(Path f) throws IOException { + byte[] data = getDataByPath(f); + if (data == null) { + throw new FileNotFoundException("File " + f + " does not exist."); + } + return new DummyFSFileStatus(f, data.length); + } + + @Override + public FSDataInputStream open(final Path f, final int bufferSize) throws IOException { + return open(f); + } + + @Override + public FSDataInputStream open(final Path f) throws IOException { + return DummyFSInputStream.create(getDataByPath(f)); + } + + @Override + public boolean delete(final Path path, final boolean recursive) throws IOException { + throw new UnsupportedOperationException("Dummy FS doesn't support delete operation"); + } + + @Override + public boolean mkdirs(final Path path) throws IOException { + throw new UnsupportedOperationException("Dummy FS doesn't support mkdirs operation"); + } + + @Override + public FSDataOutputStream create(final Path path, final WriteMode overwrite) throws IOException { + throw new UnsupportedOperationException("Dummy FS doesn't support create operation"); + } + + @Override + public boolean rename(final Path src, final Path dst) throws IOException { + throw new UnsupportedOperationException("Dummy FS doesn't support rename operation"); + } + + @Override + public boolean isDistributedFS() { + return true; + } + + @Override + public FileSystemKind getKind() { + return FileSystemKind.OBJECT_STORE; + } + + @Nullable + private byte[] getDataByPath(Path path) { + return contents.get(path.toUri().getPath()); + } + + private static Map convertToByteArrayMap(Map content) { + Map data = new HashMap<>(); + Charset utf8 = Charset.forName("UTF-8"); + content.entrySet().forEach( + entry -> data.put(entry.getKey(), entry.getValue().getBytes(utf8)) + ); + return data; + } +} diff --git a/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSInputStream.java b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSInputStream.java new file mode 100644 index 00000000000000..77b798db897c1b --- /dev/null +++ 
b/flink-end-to-end-tests/flink-plugins-test/src/main/java/org/apache/flink/fs/dummy/DummyFSInputStream.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.dummy; + +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; + +import java.io.IOException; + +class DummyFSInputStream extends FSDataInputStream { + private final ByteArrayInputStreamWithPos stream; + + private DummyFSInputStream(ByteArrayInputStreamWithPos stream) { + this.stream = stream; + } + + static DummyFSInputStream create(byte[] buffer) { + return new DummyFSInputStream(new ByteArrayInputStreamWithPos(buffer)); + } + + @Override + public void seek(long desired) throws IOException { + stream.setPosition((int) desired); + } + + @Override + public long getPos() throws IOException { + return stream.getPosition(); + } + + @Override + public int read() throws IOException { + return stream.read(); + } +} diff --git a/flink-end-to-end-tests/flink-plugins-test/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory b/flink-end-to-end-tests/flink-plugins-test/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory new file mode 100644 index 00000000000000..1c5ec8c05e9295 --- /dev/null +++ b/flink-end-to-end-tests/flink-plugins-test/src/main/resources/META-INF/services/org.apache.flink.core.fs.FileSystemFactory @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.flink.fs.dummy.DummyFSFactory diff --git a/flink-end-to-end-tests/pom.xml b/flink-end-to-end-tests/pom.xml index 73ea7efa0be220..0950e2f3f7b062 100644 --- a/flink-end-to-end-tests/pom.xml +++ b/flink-end-to-end-tests/pom.xml @@ -65,6 +65,7 @@ under the License. 
flink-streaming-kafka-test flink-streaming-kafka011-test flink-streaming-kafka010-test + flink-plugins-test From fc7a3f8f8cb3f8cb58f41a46c7ce054d63697c0f Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:07:42 +0200 Subject: [PATCH 80/92] [hotfix][e2e] Refactor yarn kerberos test: setup Flink config by one command --- .../test-scripts/test_yarn_kerberos_docker.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index 5f2dea2ea6a01f..d587990182f426 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -116,10 +116,14 @@ docker cp $FLINK_TARBALL_DIR/$FLINK_TARBALL master:/home/hadoop-user/ docker exec -it master bash -c "tar xzf /home/hadoop-user/$FLINK_TARBALL --directory /home/hadoop-user/" # minimal Flink config, bebe -docker exec -it master bash -c "echo \"security.kerberos.login.keytab: /home/hadoop-user/hadoop-user.keytab\" > /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" -docker exec -it master bash -c "echo \"security.kerberos.login.principal: hadoop-user\" >> /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" -docker exec -it master bash -c "echo \"slot.request.timeout: 60000\" >> /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" -docker exec -it master bash -c "echo \"containerized.heap-cutoff-min: 100\" >> /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" +FLINK_CONFIG=$(cat << END +security.kerberos.login.keytab: /home/hadoop-user/hadoop-user.keytab +security.kerberos.login.principal: hadoop-user +slot.request.timeout: 60000 +containerized.heap-cutoff-min: 100 +END +) +docker exec -it master bash -c "echo \"$FLINK_CONFIG\" > /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" echo "Flink config:" docker exec -it master bash -c "cat /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" From 422328843b5c98a7bfeeef89c22e05ecb9578e1c Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:09:07 +0200 Subject: [PATCH 81/92] [hotfix][e2e] Refactor yarn kerberos test: use retry_times to start hadoop cluster --- .../test-scripts/test_yarn_kerberos_docker.sh | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index d587990182f426..03fde92714d6e0 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -92,20 +92,7 @@ do sleep 2 done -CLUSTER_STARTED=1 -for (( i = 0; i < $CLUSTER_SETUP_RETRIES; i++ )) -do - if start_hadoop_cluster; then - echo "Cluster started successfully." - CLUSTER_STARTED=0 - break #continue test, cluster set up succeeded - fi - - echo "ERROR: Could not start hadoop cluster. Retrying..." - docker-compose -f $END_TO_END_DIR/test-scripts/docker-hadoop-secure-cluster/docker-compose.yml down -done - -if [[ ${CLUSTER_STARTED} -ne 0 ]]; then +if ! retry_times $CLUSTER_SETUP_RETRIES 0 start_hadoop_cluster; then echo "ERROR: Could not start hadoop cluster. Aborting..." 
exit 1 fi From 341ae363fd01b76aad2f33acbbb5cacc51cac31b Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:10:52 +0200 Subject: [PATCH 82/92] [hotfix][e2e] Refactor yarn kerberos test: move Flink tarball creation to a later test step (just before when it's needed) --- .../test-scripts/test_yarn_kerberos_docker.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index 03fde92714d6e0..74a45b7b8ffba8 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -76,10 +76,6 @@ function start_hadoop_cluster() { return 0 } - -mkdir -p $FLINK_TARBALL_DIR -tar czf $FLINK_TARBALL_DIR/$FLINK_TARBALL -C $(dirname $FLINK_DIR) . - echo "Building Hadoop Docker container" until docker build --build-arg HADOOP_VERSION=2.8.4 \ -f $END_TO_END_DIR/test-scripts/docker-hadoop-secure-cluster/Dockerfile \ @@ -97,6 +93,9 @@ if ! retry_times $CLUSTER_SETUP_RETRIES 0 start_hadoop_cluster; then exit 1 fi +mkdir -p $FLINK_TARBALL_DIR +tar czf $FLINK_TARBALL_DIR/$FLINK_TARBALL -C $(dirname $FLINK_DIR) . + docker cp $FLINK_TARBALL_DIR/$FLINK_TARBALL master:/home/hadoop-user/ # now, at least the container is ready From f9f431b954405a3a091bcb085e3c60db9b335db5 Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:15:13 +0200 Subject: [PATCH 83/92] [hotfix][e2e] Refactor yarn kerberos test: group and generalize expected result checks --- .../test-scripts/test_yarn_kerberos_docker.sh | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index 74a45b7b8ffba8..f9e0e64ee175ad 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -144,29 +144,22 @@ else exit 1 fi -if [[ ! "$OUTPUT" =~ "consummation,1" ]]; then - echo "Output does not contain (consummation, 1) as required" - mkdir -p $TEST_DATA_DIR/logs - echo "Hadoop logs:" - docker cp master:/var/log/hadoop/* $TEST_DATA_DIR/logs/ - for f in $TEST_DATA_DIR/logs/*; do - echo "$f:" - cat $f - done - echo "Docker logs:" - docker logs master - exit 1 -fi - -if [[ ! "$OUTPUT" =~ "of,14" ]]; then - echo "Output does not contain (of, 14) as required" - exit 1 -fi - -if [[ ! "$OUTPUT" =~ "calamity,1" ]]; then - echo "Output does not contain (calamity, 1) as required" - exit 1 -fi +EXPECTED_RESULT_LOG_CONTAINS=("consummation,1" "of,14" "calamity,1") +for expected_result in ${EXPECTED_RESULT_LOG_CONTAINS[@]}; do + if [[ ! 
"$OUTPUT" =~ $expected_result ]]; then + echo "Output does not contain '$expected_result' as required" + mkdir -p $TEST_DATA_DIR/logs + echo "Hadoop logs:" + docker cp master:/var/log/hadoop/* $TEST_DATA_DIR/logs/ + for f in $TEST_DATA_DIR/logs/*; do + echo "$f:" + cat $f + done + echo "Docker logs:" + docker logs master + exit 1 + fi +done echo "Running Job without configured keytab, the exception you see below is expected" docker exec -it master bash -c "echo \"\" > /home/hadoop-user/$FLINK_DIRNAME/conf/flink-conf.yaml" From b330bac99e44ee462cf8f1edb1f853ed2c05eaba Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:19:45 +0200 Subject: [PATCH 84/92] [hotfix][e2e] Refactor yarn kerberos test: move logs copying and printing into a separate function --- .../test-scripts/test_yarn_kerberos_docker.sh | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index f9e0e64ee175ad..ca12985acf9ba9 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -118,6 +118,18 @@ docker exec -it master bash -c "cat /home/hadoop-user/$FLINK_DIRNAME/conf/flink- # had cached docker containers OUTPUT_PATH=hdfs:///user/hadoop-user/wc-out-$RANDOM +function copy_and_show_logs { + mkdir -p $TEST_DATA_DIR/logs + echo "Hadoop logs:" + docker cp master:/var/log/hadoop/* $TEST_DATA_DIR/logs/ + for f in $TEST_DATA_DIR/logs/*; do + echo "$f:" + cat $f + done + echo "Docker logs:" + docker logs master +} + start_time=$(date +%s) # it's important to run this with higher parallelism, otherwise we might risk that # JM and TM are on the same YARN node and that we therefore don't test the keytab shipping @@ -132,15 +144,7 @@ then echo "$OUTPUT" else echo "Running the job failed." - mkdir -p $TEST_DATA_DIR/logs - echo "Hadoop logs:" - docker cp master:/var/log/hadoop/* $TEST_DATA_DIR/logs/ - for f in $TEST_DATA_DIR/logs/*; do - echo "$f:" - cat $f - done - echo "Docker logs:" - docker logs master + copy_and_show_logs exit 1 fi @@ -148,15 +152,7 @@ EXPECTED_RESULT_LOG_CONTAINS=("consummation,1" "of,14" "calamity,1") for expected_result in ${EXPECTED_RESULT_LOG_CONTAINS[@]}; do if [[ ! "$OUTPUT" =~ $expected_result ]]; then echo "Output does not contain '$expected_result' as required" - mkdir -p $TEST_DATA_DIR/logs - echo "Hadoop logs:" - docker cp master:/var/log/hadoop/* $TEST_DATA_DIR/logs/ - for f in $TEST_DATA_DIR/logs/*; do - echo "$f:" - cat $f - done - echo "Docker logs:" - docker logs master + copy_and_show_logs exit 1 fi done From 7909797f7fc4345687015d124e106ed6db785dd0 Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Fri, 31 May 2019 10:22:07 +0200 Subject: [PATCH 85/92] [hotfix][e2e] Refactor docker embedded job test: use pushd/popd instead of cd --- .../test-scripts/test_docker_embedded_job.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh b/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh index 2d8aa4f47d415c..67e5486621b836 100755 --- a/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh +++ b/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh @@ -40,12 +40,12 @@ build_image() { mkdir -p $OUTPUT_VOLUME chmod 777 $OUTPUT_VOLUME -cd "$DOCKER_MODULE_DIR" +pushd "$DOCKER_MODULE_DIR" if ! 
retry_times $DOCKER_IMAGE_BUILD_RETRIES ${BUILD_BACKOFF_TIME} build_image; then echo "Failed to build docker image. Aborting..." exit 1 fi -cd "$END_TO_END_DIR" +popd docker-compose -f ${DOCKER_MODULE_DIR}/docker-compose.yml -f ${DOCKER_SCRIPTS}/docker-compose.test.yml up --abort-on-container-exit --exit-code-from job-cluster &> /dev/null docker-compose -f ${DOCKER_MODULE_DIR}/docker-compose.yml -f ${DOCKER_SCRIPTS}/docker-compose.test.yml logs job-cluster > ${FLINK_DIR}/log/jobmanager.log From 5dfb95c092604b9fe67ccc821820c09b7a39be94 Mon Sep 17 00:00:00 2001 From: Aleksey Pak Date: Tue, 21 May 2019 10:37:47 +0200 Subject: [PATCH 86/92] [FLINK-12556][e2e] Add some end-to-end tests with custom (input) file system plugin --- flink-end-to-end-tests/run-nightly-tests.sh | 3 ++- .../run-pre-commit-tests.sh | 1 + .../test-scripts/test_batch_wordcount.sh | 4 ++++ .../test-scripts/test_docker_embedded_job.sh | 18 ++++++++++++++- .../test-scripts/test_yarn_kerberos_docker.sh | 22 +++++++++++++++++-- tools/travis/splits/split_container.sh | 5 +++-- 6 files changed, 47 insertions(+), 6 deletions(-) diff --git a/flink-end-to-end-tests/run-nightly-tests.sh b/flink-end-to-end-tests/run-nightly-tests.sh index a861189ebd86dc..af2151273e5e4c 100755 --- a/flink-end-to-end-tests/run-nightly-tests.sh +++ b/flink-end-to-end-tests/run-nightly-tests.sh @@ -88,7 +88,8 @@ run_test "Resuming Externalized Checkpoint after terminal failure (rocks, increm # Docker ################################################################################ -run_test "Running Kerberized YARN on Docker test " "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh" +run_test "Running Kerberized YARN on Docker test (default input)" "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh" +run_test "Running Kerberized YARN on Docker test (custom fs plugin)" "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh dummy-fs" ################################################################################ # High Availability diff --git a/flink-end-to-end-tests/run-pre-commit-tests.sh b/flink-end-to-end-tests/run-pre-commit-tests.sh index e79542a5467766..eb4b87e5d3c374 100755 --- a/flink-end-to-end-tests/run-pre-commit-tests.sh +++ b/flink-end-to-end-tests/run-pre-commit-tests.sh @@ -55,6 +55,7 @@ run_test "Streaming Python Wordcount end-to-end test" "$END_TO_END_DIR/test-scri run_test "Wordcount end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh file" run_test "Shaded Hadoop S3A end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh hadoop" run_test "Shaded Presto S3 end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh presto" +run_test "Custom FS plugin end-to-end test" "$END_TO_END_DIR/test-scripts/test_batch_wordcount.sh dummy-fs" run_test "Kafka 0.10 end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka010.sh" run_test "Kafka 0.11 end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka011.sh" run_test "Modern Kafka end-to-end test" "$END_TO_END_DIR/test-scripts/test_streaming_kafka.sh" diff --git a/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh b/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh index a5c82226e72da3..3b8400bcd0e17a 100755 --- a/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh +++ b/flink-end-to-end-tests/test-scripts/test_batch_wordcount.sh @@ -34,6 +34,10 @@ case $INPUT_TYPE in s3_setup presto INPUT_LOCATION="${S3_TEST_DATA_WORDS_URI}" ;; + (dummy-fs) + cp 
"${END_TO_END_DIR}/flink-plugins-test/target/flink-dummy-fs.jar" "${FLINK_DIR}/lib/" + INPUT_LOCATION="dummy://localhost/words" + ;; (*) echo "Unknown input type $INPUT_TYPE" exit 1 diff --git a/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh b/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh index 67e5486621b836..dabb3ce757b6a6 100755 --- a/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh +++ b/flink-end-to-end-tests/test-scripts/test_docker_embedded_job.sh @@ -30,7 +30,23 @@ export INPUT_VOLUME=${END_TO_END_DIR}/test-scripts/test-data export OUTPUT_VOLUME=${TEST_DATA_DIR}/out export INPUT_PATH=/data/test/input export OUTPUT_PATH=/data/test/output -export FLINK_JOB_ARGUMENTS="--input ${INPUT_PATH}/words --output ${OUTPUT_PATH}/docker_wc_out" + +INPUT_TYPE=${1:-file} +case $INPUT_TYPE in + (file) + INPUT_LOCATION=${INPUT_PATH}/words + ;; + (dummy-fs) + cp "${END_TO_END_DIR}/flink-plugins-test/target/flink-dummy-fs.jar" "${FLINK_DIR}/lib/" + INPUT_LOCATION="dummy://localhost/words" + ;; + (*) + echo "Unknown input type $INPUT_TYPE" + exit 1 + ;; +esac + +export FLINK_JOB_ARGUMENTS="--input ${INPUT_LOCATION} --output ${OUTPUT_PATH}/docker_wc_out" build_image() { ./build.sh --from-local-dist --job-jar ${FLINK_DIR}/examples/batch/WordCount.jar --image-name ${FLINK_DOCKER_IMAGE_NAME} diff --git a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh index ca12985acf9ba9..4d62f6fbe9899d 100755 --- a/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh +++ b/flink-end-to-end-tests/test-scripts/test_yarn_kerberos_docker.sh @@ -34,6 +34,25 @@ echo "End-to-end directory $END_TO_END_DIR" docker --version docker-compose --version +# Configure Flink dir before making tarball. +INPUT_TYPE=${1:-default-input} +EXPECTED_RESULT_LOG_CONTAINS=() +case $INPUT_TYPE in + (default-input) + INPUT_ARGS="" + EXPECTED_RESULT_LOG_CONTAINS=("consummation,1" "of,14" "calamity,1") + ;; + (dummy-fs) + cp "${END_TO_END_DIR}/flink-plugins-test/target/flink-dummy-fs.jar" "${FLINK_DIR}/lib/" + INPUT_ARGS="--input dummy://localhost/words" + EXPECTED_RESULT_LOG_CONTAINS=("my,1" "dear,2" "world,2") + ;; + (*) + echo "Unknown input type $INPUT_TYPE" + exit 1 + ;; +esac + # make sure we stop our cluster at the end function cluster_shutdown { # don't call ourselves again for another signal interruption @@ -135,7 +154,7 @@ start_time=$(date +%s) # JM and TM are on the same YARN node and that we therefore don't test the keytab shipping if docker exec -it master bash -c "export HADOOP_CLASSPATH=\`hadoop classpath\` && \ /home/hadoop-user/$FLINK_DIRNAME/bin/flink run -m yarn-cluster -yn 3 -ys 1 -ytm 1000 -yjm 1000 \ - -p 3 /home/hadoop-user/$FLINK_DIRNAME/examples/streaming/WordCount.jar --output $OUTPUT_PATH"; + -p 3 /home/hadoop-user/$FLINK_DIRNAME/examples/streaming/WordCount.jar $INPUT_ARGS --output $OUTPUT_PATH"; then docker exec -it master bash -c "kinit -kt /home/hadoop-user/hadoop-user.keytab hadoop-user" docker exec -it master bash -c "hdfs dfs -ls $OUTPUT_PATH" @@ -148,7 +167,6 @@ else exit 1 fi -EXPECTED_RESULT_LOG_CONTAINS=("consummation,1" "of,14" "calamity,1") for expected_result in ${EXPECTED_RESULT_LOG_CONTAINS[@]}; do if [[ ! 
"$OUTPUT" =~ $expected_result ]]; then echo "Output does not contain '$expected_result' as required" diff --git a/tools/travis/splits/split_container.sh b/tools/travis/splits/split_container.sh index d34d55e38c8c7e..587f397b2f7e74 100755 --- a/tools/travis/splits/split_container.sh +++ b/tools/travis/splits/split_container.sh @@ -43,8 +43,9 @@ echo "Flink distribution directory: $FLINK_DIR" # run_test "" "$END_TO_END_DIR/test-scripts/" -run_test "Wordcount on Docker test" "$END_TO_END_DIR/test-scripts/test_docker_embedded_job.sh" -run_test "Running Kerberized YARN on Docker test " "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh" +run_test "Wordcount on Docker test (custom fs plugin)" "$END_TO_END_DIR/test-scripts/test_docker_embedded_job.sh dummy-fs" +run_test "Running Kerberized YARN on Docker test (default input)" "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh" +run_test "Running Kerberized YARN on Docker test (custom fs plugin)" "$END_TO_END_DIR/test-scripts/test_yarn_kerberos_docker.sh dummy-fs" run_test "Run kubernetes test" "$END_TO_END_DIR/test-scripts/test_kubernetes_embedded_job.sh" printf "\n[PASS] All tests passed\n" From 24053227b5526c78d6e19331c7efa4c22eee1151 Mon Sep 17 00:00:00 2001 From: Seth Wiesman Date: Mon, 3 Jun 2019 10:53:47 -0500 Subject: [PATCH 87/92] [FLINK-12650][docs] Redirect Users to Documentation Homepage if Requested Resource Does Not Exist --- docs/404.md | 26 ++++++++++++++++++++++++++ docs/_layouts/404_base.html | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 docs/404.md create mode 100644 docs/_layouts/404_base.html diff --git a/docs/404.md b/docs/404.md new file mode 100644 index 00000000000000..42e390b8798ab9 --- /dev/null +++ b/docs/404.md @@ -0,0 +1,26 @@ +--- +title: "404" +permalink: /404.html +layout: 404_base +--- + + diff --git a/docs/_layouts/404_base.html b/docs/_layouts/404_base.html new file mode 100644 index 00000000000000..0acea3e4643a30 --- /dev/null +++ b/docs/_layouts/404_base.html @@ -0,0 +1,32 @@ + + + + + + + + 404 + + + Home. + + From 1f8c7ff7a41429cf5e9f8b1f8011cde64b83971b Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Wed, 5 Jun 2019 12:34:03 +0200 Subject: [PATCH 88/92] [hotfix][chck] Remove Nullable annotation from method with primitive return type ZooKeeperStateHandleStore#releaseAndTryRemove returns a primitive boolean and, thus, does not need a @Nullable annotation. 
--- .../flink/runtime/zookeeper/ZooKeeperStateHandleStore.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java index 62630104942c85..0b2eaa5a08c9b1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStore.java @@ -31,8 +31,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; @@ -328,7 +326,6 @@ public List, String>> getAllAndLock() throws Ex * @return True if the state handle could be released * @throws Exception If the ZooKeeper operation or discarding the state handle fails */ - @Nullable public boolean releaseAndTryRemove(String pathInZooKeeper) throws Exception { checkNotNull(pathInZooKeeper, "Path in ZooKeeper"); From 41d0ac0c868466fbca21f9e69a41d868fbbe51ff Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Fri, 24 May 2019 14:41:16 +0800 Subject: [PATCH 89/92] [hotfix] Fix the typo issue --- .../apache/flink/runtime/taskmanager/Task.java | 4 ++-- .../io/network/NetworkEnvironmentTest.java | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 06df24e706bf0e..27d9cdcaa7f81b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -593,7 +593,7 @@ else if (current == ExecutionState.CANCELING) { LOG.info("Registering task at network: {}.", this); - setupPartionsAndGates(consumableNotifyingPartitionWriters, inputGates); + setupPartitionsAndGates(consumableNotifyingPartitionWriters, inputGates); for (ResultPartitionWriter partitionWriter : consumableNotifyingPartitionWriters) { taskEventDispatcher.registerPartition(partitionWriter.getPartitionId()); @@ -823,7 +823,7 @@ else if (transitionState(current, ExecutionState.FAILED, t)) { } @VisibleForTesting - public static void setupPartionsAndGates( + public static void setupPartitionsAndGates( ResultPartitionWriter[] producedPartitions, InputGate[] inputGates) throws IOException { for (ResultPartitionWriter partition : producedPartitions) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index 5e41ec018a3d15..933f38556d2ecf 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -64,7 +64,7 @@ public static List parameters() { public ExpectedException expectedException = ExpectedException.none(); /** - * Verifies that {@link Task#setupPartionsAndGates(ResultPartitionWriter[], InputGate[])}} sets up (un)bounded buffer pool + * Verifies that {@link Task#setupPartitionsAndGates(ResultPartitionWriter[], InputGate[])}} sets up (un)bounded buffer pool * instances for various types of input and output channels. 
*/ @Test @@ -87,7 +87,7 @@ public void testRegisterTaskUsesBoundedBuffers() throws Exception { SingleInputGate ig4 = createSingleInputGate(network, ResultPartitionType.PIPELINED_BOUNDED, 8); final SingleInputGate[] inputGates = new SingleInputGate[] {ig1, ig2, ig3, ig4}; - Task.setupPartionsAndGates(resultPartitions, inputGates); + Task.setupPartitionsAndGates(resultPartitions, inputGates); // verify buffer pools for the result partitions assertEquals(rp1.getNumberOfSubpartitions(), rp1.getBufferPool().getNumberOfRequiredMemorySegments()); @@ -128,7 +128,7 @@ public void testRegisterTaskUsesBoundedBuffers() throws Exception { } /** - * Verifies that {@link Task#setupPartionsAndGates(ResultPartitionWriter[], InputGate[])}} sets up (un)bounded buffer pool + * Verifies that {@link Task#setupPartitionsAndGates(ResultPartitionWriter[], InputGate[])}} sets up (un)bounded buffer pool * instances for various types of input and output channels working with the bare minimum of * required buffers. */ @@ -148,7 +148,7 @@ public void testRegisterTaskWithLimitedBuffers() throws Exception { } /** - * Verifies that {@link Task#setupPartionsAndGates(ResultPartitionWriter[], InputGate[])}} fails if the bare minimum of + * Verifies that {@link Task#setupPartitionsAndGates(ResultPartitionWriter[], InputGate[])}} fails if the bare minimum of * required buffers is not available (we are one buffer short). */ @Test @@ -208,7 +208,7 @@ private void testRegisterTaskWithLimitedBuffers(int bufferPoolSize) throws Excep createRemoteInputChannel(ig3, 3, rp4, connManager, network.getNetworkBufferPool()); } - Task.setupPartionsAndGates(resultPartitions, inputGates); + Task.setupPartitionsAndGates(resultPartitions, inputGates); // verify buffer pools for the result partitions assertEquals(Integer.MAX_VALUE, rp1.getBufferPool().getMaxNumberOfMemorySegments()); @@ -250,16 +250,16 @@ private void testRegisterTaskWithLimitedBuffers(int bufferPoolSize) throws Excep /** * Helper to create spy of a {@link SingleInputGate} for use by a {@link Task} inside - * {@link Task#setupPartionsAndGates(ResultPartitionWriter[], InputGate[])}}. + * {@link Task#setupPartitionsAndGates(ResultPartitionWriter[], InputGate[])}}. * * @param network - * network enviroment to create buffer pool factory for {@link SingleInputGate} + * network environment to create buffer pool factory for {@link SingleInputGate} * @param partitionType * the consumed partition type * @param numberOfChannels * the number of input channels * - * @return input gate with some fake settiFngs + * @return input gate with some fake settings */ private SingleInputGate createSingleInputGate( NetworkEnvironment network, ResultPartitionType partitionType, int numberOfChannels) { From 663b40d320baf94a5f81a30082240afab66ca085 Mon Sep 17 00:00:00 2001 From: Zhijiang Date: Wed, 29 May 2019 22:33:56 +0800 Subject: [PATCH 90/92] [FLINK-12603][network] Remove getOwningTaskName method from InputGate In order to make abstract InputGate simple for extending new implementations in shuffle service architecture, we could remove unnecessary methods from it. InputGate#getOwningTaskName is only used for debugging log in BarrierBuffer and StreamInputProcessor. This task name could also be generated in StreamTask via Environment#getTaskInfo and Environment#getExecutionId. Then it could be passed into the constructors of BarrierBuffer/StreamInputProcessor for use. This closes #8529. 
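
A minimal sketch of the wiring described above, with invented class names rather than the real Flink types: the task derives a descriptive name once from its environment (the patch below adds StreamTask#getTaskNameWithSubtaskAndId for this) and injects it into the barrier handler through its constructor, so the input gate no longer needs a getOwningTaskName() accessor.

    // Simplified, hypothetical sketch of the pattern; these are not the actual Flink classes.
    public class TaskNameInjectionSketch {

        // Mirrors the "name with subtasks (executionId)" format introduced in the patch below.
        static String taskNameWithSubtaskAndId(String taskNameWithSubtasks, String executionId) {
            return taskNameWithSubtasks + " (" + executionId + ')';
        }

        public static void main(String[] args) {
            String taskName = taskNameWithSubtaskAndId("WordCount -> Sink (2/4)", "attempt-42");
            BarrierHandlerSketch handler = new BarrierHandlerSketch(taskName);
            handler.logAlignmentStart(17L);
        }
    }

    // The handler receives the name at construction time instead of querying the input gate.
    class BarrierHandlerSketch {

        private final String taskName;

        BarrierHandlerSketch(String taskName) {
            this.taskName = taskName;
        }

        void logAlignmentStart(long checkpointId) {
            System.out.printf("%s: Starting stream alignment for checkpoint %d.%n", taskName, checkpointId);
        }
    }
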
--- .../network/partition/consumer/InputGate.java | 2 - .../partition/consumer/SingleInputGate.java | 5 -- .../partition/consumer/UnionInputGate.java | 6 --- .../taskmanager/InputGateWithMetrics.java | 5 -- .../streaming/runtime/io/BarrierBuffer.java | 54 ++++++++----------- .../runtime/io/InputProcessorUtil.java | 15 ++++-- .../runtime/io/StreamInputProcessor.java | 10 +++- .../runtime/io/StreamTwoInputProcessor.java | 10 +++- .../runtime/tasks/OneInputStreamTask.java | 3 +- .../streaming/runtime/tasks/StreamTask.java | 10 ++++ .../runtime/tasks/TwoInputStreamTask.java | 3 +- .../io/BarrierBufferAlignmentLimitTest.java | 4 +- .../io/BarrierBufferMassiveRandomTest.java | 12 ----- .../streaming/runtime/io/MockInputGate.java | 12 ----- 14 files changed, 67 insertions(+), 84 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java index 7b87d321d97f23..1c8300ccff0443 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -76,8 +76,6 @@ public abstract class InputGate implements AutoCloseable { public abstract int getNumberOfInputChannels(); - public abstract String getOwningTaskName(); - public abstract boolean isFinished(); public abstract void requestPartitions() throws IOException, InterruptedException; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index 4e718d713b0897..8360e5e343f5ed 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -273,11 +273,6 @@ public int getNumberOfQueuedBuffers() { return 0; } - @Override - public String getOwningTaskName() { - return owningTaskName; - } - public CompletableFuture getCloseFuture() { return closeFuture; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java index 9777f7804b5ed5..8b4da35b991aad 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -136,12 +136,6 @@ public int getNumberOfInputChannels() { return totalNumberOfInputChannels; } - @Override - public String getOwningTaskName() { - // all input gates have the same owning task - return inputGates[0].getOwningTaskName(); - } - @Override public boolean isFinished() { for (InputGate inputGate : inputGates) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java index e9e303830a287a..631583b5d45578 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java @@ -55,11 +55,6 @@ public int 
getNumberOfInputChannels() { return inputGate.getNumberOfInputChannels(); } - @Override - public String getOwningTaskName() { - return inputGate.getOwningTaskName(); - } - @Override public boolean isFinished() { return inputGate.isFinished(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BarrierBuffer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BarrierBuffer.java index cbeb4ba74d6ae3..41a6eb5802606d 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BarrierBuffer.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/BarrierBuffer.java @@ -18,6 +18,7 @@ package org.apache.flink.streaming.runtime.io; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.decline.AlignmentLimitExceededException; @@ -79,6 +80,8 @@ public class BarrierBuffer implements CheckpointBarrierHandler { */ private final long maxBufferedBytes; + private final String taskName; + /** * The sequence of buffers/events that has been unblocked and must now be consumed before * requesting further data from the input gate. @@ -119,11 +122,10 @@ public class BarrierBuffer implements CheckpointBarrierHandler { * * @param inputGate The input gate to draw the buffers and events from. * @param bufferBlocker The buffer blocker to hold the buffers and events for channels with barrier. - * - * @throws IOException Thrown, when the spilling to temp files cannot be initialized. */ - public BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker) throws IOException { - this (inputGate, bufferBlocker, -1); + @VisibleForTesting + BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker) { + this (inputGate, bufferBlocker, -1, "Testing: No task associated"); } /** @@ -136,11 +138,9 @@ public BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker) throws IO * @param inputGate The input gate to draw the buffers and events from. * @param bufferBlocker The buffer blocker to hold the buffers and events for channels with barrier. * @param maxBufferedBytes The maximum bytes to be buffered before the checkpoint aborts. - * - * @throws IOException Thrown, when the spilling to temp files cannot be initialized. + * @param taskName The task name for logging. 
*/ - public BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker, long maxBufferedBytes) - throws IOException { + BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker, long maxBufferedBytes, String taskName) { checkArgument(maxBufferedBytes == -1 || maxBufferedBytes > 0); this.inputGate = inputGate; @@ -150,6 +150,8 @@ public BarrierBuffer(InputGate inputGate, BufferBlocker bufferBlocker, long maxB this.bufferBlocker = checkNotNull(bufferBlocker); this.queuedBuffered = new ArrayDeque(); + + this.taskName = taskName; } // ------------------------------------------------------------------------ @@ -213,7 +215,7 @@ else if (bufferOrEvent.getEvent().getClass() == CancelCheckpointMarker.class) { } private void completeBufferedSequence() throws IOException { - LOG.debug("{}: Finished feeding back buffered data.", inputGate.getOwningTaskName()); + LOG.debug("{}: Finished feeding back buffered data.", taskName); currentBuffered.cleanup(); currentBuffered = queuedBuffered.pollFirst(); @@ -249,7 +251,7 @@ else if (barrierId > currentCheckpointId) { // we did not complete the current checkpoint, another started before LOG.warn("{}: Received checkpoint barrier for checkpoint {} before completing current checkpoint {}. " + "Skipping current checkpoint.", - inputGate.getOwningTaskName(), + taskName, barrierId, currentCheckpointId); @@ -283,7 +285,7 @@ else if (barrierId > currentCheckpointId) { // actually trigger checkpoint if (LOG.isDebugEnabled()) { LOG.debug("{}: Received all barriers, triggering checkpoint {} at {}.", - inputGate.getOwningTaskName(), + taskName, receivedBarrier.getId(), receivedBarrier.getTimestamp()); } @@ -314,9 +316,7 @@ private void processCancellationBarrier(CancelCheckpointMarker cancelBarrier) th if (barrierId == currentCheckpointId) { // cancel this alignment if (LOG.isDebugEnabled()) { - LOG.debug("{}: Checkpoint {} canceled, aborting alignment.", - inputGate.getOwningTaskName(), - barrierId); + LOG.debug("{}: Checkpoint {} canceled, aborting alignment.", taskName, barrierId); } releaseBlocksAndResetBarriers(); @@ -326,7 +326,7 @@ else if (barrierId > currentCheckpointId) { // we canceled the next which also cancels the current LOG.warn("{}: Received cancellation barrier for checkpoint {} before completing current checkpoint {}. 
" + "Skipping current checkpoint.", - inputGate.getOwningTaskName(), + taskName, barrierId, currentCheckpointId); @@ -357,9 +357,7 @@ else if (barrierId > currentCheckpointId) { latestAlignmentDurationNanos = 0L; if (LOG.isDebugEnabled()) { - LOG.debug("{}: Checkpoint {} canceled, skipping alignment.", - inputGate.getOwningTaskName(), - barrierId); + LOG.debug("{}: Checkpoint {} canceled, skipping alignment.", taskName, barrierId); } notifyAbortOnCancellationBarrier(barrierId); @@ -414,7 +412,7 @@ private void checkSizeLimit() throws Exception { if (maxBufferedBytes > 0 && (numQueuedBytes + bufferBlocker.getBytesBlocked()) > maxBufferedBytes) { // exceeded our limit - abort this checkpoint LOG.info("{}: Checkpoint {} aborted because alignment volume limit ({} bytes) exceeded.", - inputGate.getOwningTaskName(), + taskName, currentCheckpointId, maxBufferedBytes); @@ -458,9 +456,7 @@ private void beginNewAlignment(long checkpointId, int channelIndex) throws IOExc startOfAlignmentTimestamp = System.nanoTime(); if (LOG.isDebugEnabled()) { - LOG.debug("{}: Starting stream alignment for checkpoint {}.", - inputGate.getOwningTaskName(), - checkpointId); + LOG.debug("{}: Starting stream alignment for checkpoint {}.", taskName, checkpointId); } } @@ -486,9 +482,7 @@ private void onBarrier(int channelIndex) throws IOException { numBarriersReceived++; if (LOG.isDebugEnabled()) { - LOG.debug("{}: Received barrier from channel {}.", - inputGate.getOwningTaskName(), - channelIndex); + LOG.debug("{}: Received barrier from channel {}.", taskName, channelIndex); } } else { @@ -501,8 +495,7 @@ private void onBarrier(int channelIndex) throws IOException { * Makes sure the just written data is the next to be consumed. */ private void releaseBlocksAndResetBarriers() throws IOException { - LOG.debug("{}: End of stream alignment, feeding buffered data back.", - inputGate.getOwningTaskName()); + LOG.debug("{}: End of stream alignment, feeding buffered data back.", taskName); for (int i = 0; i < blockedChannels.length; i++) { blockedChannels[i] = false; @@ -519,8 +512,7 @@ private void releaseBlocksAndResetBarriers() throws IOException { // uncommon case: buffered data pending // push back the pending data, if we have any LOG.debug("{}: Checkpoint skipped via buffered data:" + - "Pushing back current alignment buffers and feeding back new alignment data first.", - inputGate.getOwningTaskName()); + "Pushing back current alignment buffers and feeding back new alignment data first.", taskName); // since we did not fully drain the previous sequence, we need to allocate a new buffer for this one BufferOrEventSequence bufferedNow = bufferBlocker.rollOverWithoutReusingResources(); @@ -534,7 +526,7 @@ private void releaseBlocksAndResetBarriers() throws IOException { if (LOG.isDebugEnabled()) { LOG.debug("{}: Size of buffered data: {} bytes", - inputGate.getOwningTaskName(), + taskName, currentBuffered == null ? 
0L : currentBuffered.size()); } @@ -577,7 +569,7 @@ public long getAlignmentDurationNanos() { @Override public String toString() { return String.format("%s: last checkpoint: %d, current barriers: %d, closed channels: %d", - inputGate.getOwningTaskName(), + taskName, currentCheckpointId, numBarriersReceived, numClosedChannels); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java index b4207819799bd6..da401a7fdeebdb 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java @@ -41,7 +41,8 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( CheckpointingMode checkpointMode, IOManager ioManager, InputGate inputGate, - Configuration taskManagerConfig) throws IOException { + Configuration taskManagerConfig, + String taskName) throws IOException { CheckpointBarrierHandler barrierHandler; if (checkpointMode == CheckpointingMode.EXACTLY_ONCE) { @@ -53,9 +54,17 @@ public static CheckpointBarrierHandler createCheckpointBarrierHandler( } if (taskManagerConfig.getBoolean(NetworkEnvironmentOptions.NETWORK_CREDIT_MODEL)) { - barrierHandler = new BarrierBuffer(inputGate, new CachedBufferBlocker(inputGate.getPageSize()), maxAlign); + barrierHandler = new BarrierBuffer( + inputGate, + new CachedBufferBlocker(inputGate.getPageSize()), + maxAlign, + taskName); } else { - barrierHandler = new BarrierBuffer(inputGate, new BufferSpiller(ioManager, inputGate.getPageSize()), maxAlign); + barrierHandler = new BarrierBuffer( + inputGate, + new BufferSpiller(ioManager, inputGate.getPageSize()), + maxAlign, + taskName); } } else if (checkpointMode == CheckpointingMode.AT_LEAST_ONCE) { barrierHandler = new BarrierTracker(inputGate); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java index a9c64b5f6fe9ef..8420ae044250c6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamInputProcessor.java @@ -121,12 +121,18 @@ public StreamInputProcessor( StreamStatusMaintainer streamStatusMaintainer, OneInputStreamOperator streamOperator, TaskIOMetricGroup metrics, - WatermarkGauge watermarkGauge) throws IOException { + WatermarkGauge watermarkGauge, + String taskName) throws IOException { InputGate inputGate = InputGateUtil.createInputGate(inputGates); this.barrierHandler = InputProcessorUtil.createCheckpointBarrierHandler( - checkpointedTask, checkpointMode, ioManager, inputGate, taskManagerConfig); + checkpointedTask, + checkpointMode, + ioManager, + inputGate, + taskManagerConfig, + taskName); this.lock = checkNotNull(lock); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java index ab4f90dcf23f06..e8c9c2a5bec0db 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessor.java @@ 
-140,12 +140,18 @@ public StreamTwoInputProcessor( TwoInputStreamOperator streamOperator, TaskIOMetricGroup metrics, WatermarkGauge input1WatermarkGauge, - WatermarkGauge input2WatermarkGauge) throws IOException { + WatermarkGauge input2WatermarkGauge, + String taskName) throws IOException { final InputGate inputGate = InputGateUtil.createInputGate(inputGates1, inputGates2); this.barrierHandler = InputProcessorUtil.createCheckpointBarrierHandler( - checkpointedTask, checkpointMode, ioManager, inputGate, taskManagerConfig); + checkpointedTask, + checkpointMode, + ioManager, + inputGate, + taskManagerConfig, + taskName); this.lock = checkNotNull(lock); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java index 7b82d8f67456cc..76091ffb578c45 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java @@ -88,7 +88,8 @@ public void init() throws Exception { getStreamStatusMaintainer(), this.headOperator, getEnvironment().getMetricGroup().getIOMetricGroup(), - inputWatermarkGauge); + inputWatermarkGauge, + getTaskNameWithSubtaskAndId()); } headOperator.getMetricGroup().gauge(MetricNames.IO_CURRENT_INPUT_WATERMARK, this.inputWatermarkGauge); // wrap watermark gauge since registered metrics must be unique diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 2df565d7cf4ed2..3927e46468a005 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -599,6 +599,16 @@ public String getName() { return getEnvironment().getTaskInfo().getTaskNameWithSubtasks(); } + /** + * Gets the name of the task, appended with the subtask indicator and execution id. + * + * @return The name of the task, with subtask indicator and execution id. + */ + String getTaskNameWithSubtaskAndId() { + return getEnvironment().getTaskInfo().getTaskNameWithSubtasks() + + " (" + getEnvironment().getExecutionId() + ')'; + } + /** * Gets the lock object on which all operations that involve data and state mutation have to lock. * @return The checkpoint lock object. 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java index 934f2cbbc875c9..2092c45d0a8305 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java @@ -98,7 +98,8 @@ public void init() throws Exception { this.headOperator, getEnvironment().getMetricGroup().getIOMetricGroup(), input1WatermarkGauge, - input2WatermarkGauge); + input2WatermarkGauge, + getTaskNameWithSubtaskAndId()); headOperator.getMetricGroup().gauge(MetricNames.IO_CURRENT_INPUT_WATERMARK, minInputWatermarkGauge); headOperator.getMetricGroup().gauge(MetricNames.IO_CURRENT_INPUT_1_WATERMARK, input1WatermarkGauge); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferAlignmentLimitTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferAlignmentLimitTest.java index 1c7ff350962ffa..91f2be44bd8916 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferAlignmentLimitTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferAlignmentLimitTest.java @@ -115,7 +115,7 @@ public void testBreakCheckpointAtAlignmentLimit() throws Exception { // the barrier buffer has a limit that only 1000 bytes may be spilled in alignment MockInputGate gate = new MockInputGate(PAGE_SIZE, 3, Arrays.asList(sequence)); - BarrierBuffer buffer = new BarrierBuffer(gate, new BufferSpiller(ioManager, gate.getPageSize()), 1000); + BarrierBuffer buffer = new BarrierBuffer(gate, new BufferSpiller(ioManager, gate.getPageSize()), 1000, "Testing"); AbstractInvokable toNotify = mock(AbstractInvokable.class); buffer.registerCheckpointEventHandler(toNotify); @@ -209,7 +209,7 @@ public void testAlignmentLimitWithQueuedAlignments() throws Exception { // the barrier buffer has a limit that only 1000 bytes may be spilled in alignment MockInputGate gate = new MockInputGate(PAGE_SIZE, 3, Arrays.asList(sequence)); - BarrierBuffer buffer = new BarrierBuffer(gate, new BufferSpiller(ioManager, gate.getPageSize()), 500); + BarrierBuffer buffer = new BarrierBuffer(gate, new BufferSpiller(ioManager, gate.getPageSize()), 500, "Testing"); AbstractInvokable toNotify = mock(AbstractInvokable.class); buffer.registerCheckpointEventHandler(toNotify); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java index 428407432a978c..9d0d70514b66a0 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java @@ -138,18 +138,11 @@ private static class RandomGeneratingInputGate extends InputGate { private int currentChannel = 0; private long c = 0; - private final String owningTaskName; - public RandomGeneratingInputGate(BufferPool[] bufferPools, BarrierGenerator[] barrierGens) { - this(bufferPools, barrierGens, "TestTask"); - } - - public RandomGeneratingInputGate(BufferPool[] bufferPools, BarrierGenerator[] barrierGens, String owningTaskName) { 
this.numberOfChannels = bufferPools.length; this.currentBarriers = new int[numberOfChannels]; this.bufferPools = bufferPools; this.barrierGens = barrierGens; - this.owningTaskName = owningTaskName; this.isAvailable = AVAILABLE; } @@ -158,11 +151,6 @@ public int getNumberOfInputChannels() { return numberOfChannels; } - @Override - public String getOwningTaskName() { - return owningTaskName; - } - @Override public boolean isFinished() { return false; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index b37dee4c25cd13..1a4c5b744b3ffa 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -43,18 +43,11 @@ public class MockInputGate extends InputGate { private int closedChannels; - private final String owningTaskName; - public MockInputGate(int pageSize, int numberOfChannels, List bufferOrEvents) { - this(pageSize, numberOfChannels, bufferOrEvents, "MockTask"); - } - - public MockInputGate(int pageSize, int numberOfChannels, List bufferOrEvents, String owningTaskName) { this.pageSize = pageSize; this.numberOfChannels = numberOfChannels; this.bufferOrEvents = new ArrayDeque(bufferOrEvents); this.closed = new boolean[numberOfChannels]; - this.owningTaskName = owningTaskName; isAvailable = AVAILABLE; } @@ -73,11 +66,6 @@ public int getNumberOfInputChannels() { return numberOfChannels; } - @Override - public String getOwningTaskName() { - return owningTaskName; - } - @Override public boolean isFinished() { return bufferOrEvents.isEmpty(); From 88baa34f24251f84d9cb7aa4fc15dd9d4e622fda Mon Sep 17 00:00:00 2001 From: leesf <490081539@qq.com> Date: Sun, 26 May 2019 10:41:22 +0800 Subject: [PATCH 91/92] [FLINK-12101] Race condition when concurrently running uploaded jars via REST --- .../flink/api/java/ExecutionEnvironment.java | 14 +++++++++++--- .../environment/StreamExecutionEnvironment.java | 8 ++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java index beb1b65c4a5d8c..f9fc1ef8eb6da0 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java @@ -104,6 +104,9 @@ public abstract class ExecutionEnvironment { /** The environment of the context (local by default, cluster if invoked through command line). */ private static ExecutionEnvironmentFactory contextEnvironmentFactory; + /** The ThreadLocal used to store {@link ExecutionEnvironmentFactory}. */ + private static ThreadLocal contextEnvironmentFactoryThreadLocal = new ThreadLocal<>(); + /** The default parallelism used by local environments. */ private static int defaultLocalDop = Runtime.getRuntime().availableProcessors(); @@ -1061,8 +1064,11 @@ private static String getDefaultName() { * @return The execution environment of the context in which the program is executed. */ public static ExecutionEnvironment getExecutionEnvironment() { - return contextEnvironmentFactory == null ? - createLocalEnvironment() : contextEnvironmentFactory.createExecutionEnvironment(); + + return contextEnvironmentFactoryThreadLocal.get() == null ? + (contextEnvironmentFactory == null ? 
+ createLocalEnvironment() : contextEnvironmentFactory.createExecutionEnvironment()) : + contextEnvironmentFactoryThreadLocal.get().createExecutionEnvironment(); } /** @@ -1253,6 +1259,7 @@ public static void setDefaultLocalParallelism(int parallelism) { */ protected static void initializeContextEnvironment(ExecutionEnvironmentFactory ctx) { contextEnvironmentFactory = Preconditions.checkNotNull(ctx); + contextEnvironmentFactoryThreadLocal.set(contextEnvironmentFactory); } /** @@ -1262,6 +1269,7 @@ protected static void initializeContextEnvironment(ExecutionEnvironmentFactory c */ protected static void resetContextEnvironment() { contextEnvironmentFactory = null; + contextEnvironmentFactoryThreadLocal.remove(); } /** @@ -1273,6 +1281,6 @@ protected static void resetContextEnvironment() { */ @Internal public static boolean areExplicitEnvironmentsAllowed() { - return contextEnvironmentFactory == null; + return contextEnvironmentFactory == null && contextEnvironmentFactoryThreadLocal.get() == null; } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java index 7ac1ac68985013..f93fd4c50da29a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java @@ -117,6 +117,9 @@ public abstract class StreamExecutionEnvironment { */ private static StreamExecutionEnvironmentFactory contextEnvironmentFactory; + /** The ThreadLocal used to store {@link StreamExecutionEnvironmentFactory}. */ + private static ThreadLocal contextEnvironmentFactoryThreadLocal = new ThreadLocal<>(); + /** The default parallelism used when creating a local environment. */ private static int defaultLocalParallelism = Runtime.getRuntime().availableProcessors(); @@ -1568,6 +1571,9 @@ public void addOperator(StreamTransformation transformation) { * executed. */ public static StreamExecutionEnvironment getExecutionEnvironment() { + if (contextEnvironmentFactoryThreadLocal.get() != null) { + return contextEnvironmentFactoryThreadLocal.get().createExecutionEnvironment(); + } if (contextEnvironmentFactory != null) { return contextEnvironmentFactory.createExecutionEnvironment(); } @@ -1766,10 +1772,12 @@ public static void setDefaultLocalParallelism(int parallelism) { protected static void initializeContextEnvironment(StreamExecutionEnvironmentFactory ctx) { contextEnvironmentFactory = ctx; + contextEnvironmentFactoryThreadLocal.set(contextEnvironmentFactory); } protected static void resetContextEnvironment() { contextEnvironmentFactory = null; + contextEnvironmentFactoryThreadLocal.remove(); } /** From 26241e1fe3393519953b6280ff7283c013dc384c Mon Sep 17 00:00:00 2001 From: Till Rohrmann Date: Wed, 5 Jun 2019 15:03:01 +0200 Subject: [PATCH 92/92] [FLINK-12101] Deduplicate code by introducing ExecutionEnvironment#resolveFactory ExecutionEnvironment#resolveFactory selects between the thread local and the global factory. This method is used by the ExecutionEnvironment as well as the StreamExecutionEnvironment. This closes #8543. 
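
The resolution rule described above can be shown in isolation: a factory registered for the current thread wins over the static (global) one, and the caller supplies the fallback when neither is set. The following standalone Java sketch mirrors the ExecutionEnvironment#resolveFactory helper introduced in the diff below, using plain strings in place of real environment factories.

    import java.util.Optional;

    // Standalone sketch of the thread-local-over-static resolution rule.
    public final class FactoryResolutionSketch {

        static <T> Optional<T> resolveFactory(ThreadLocal<T> threadLocalFactory, T staticFactory) {
            final T localFactory = threadLocalFactory.get();
            return Optional.ofNullable(localFactory == null ? staticFactory : localFactory);
        }

        public static void main(String[] args) {
            ThreadLocal<String> threadLocal = new ThreadLocal<>();
            String globalFactory = "globalFactory";

            // No thread-local value: the static factory is resolved.
            System.out.println(resolveFactory(threadLocal, globalFactory).orElse("default"));

            // Thread-local value set: it shadows the static factory for this thread only.
            threadLocal.set("threadLocalFactory");
            System.out.println(resolveFactory(threadLocal, globalFactory).orElse("default"));
        }
    }
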
--- .../flink/api/java/ExecutionEnvironment.java | 35 ++++++++++++++----- .../StreamExecutionEnvironment.java | 19 +++++----- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java index f9fc1ef8eb6da0..435cedf841eb36 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/ExecutionEnvironment.java @@ -64,6 +64,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; + import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; @@ -73,6 +75,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.Set; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -102,10 +105,10 @@ public abstract class ExecutionEnvironment { protected static final Logger LOG = LoggerFactory.getLogger(ExecutionEnvironment.class); /** The environment of the context (local by default, cluster if invoked through command line). */ - private static ExecutionEnvironmentFactory contextEnvironmentFactory; + private static ExecutionEnvironmentFactory contextEnvironmentFactory = null; /** The ThreadLocal used to store {@link ExecutionEnvironmentFactory}. */ - private static ThreadLocal contextEnvironmentFactoryThreadLocal = new ThreadLocal<>(); + private static final ThreadLocal threadLocalContextEnvironmentFactory = new ThreadLocal<>(); /** The default parallelism used by local environments. */ private static int defaultLocalDop = Runtime.getRuntime().availableProcessors(); @@ -1064,11 +1067,25 @@ private static String getDefaultName() { * @return The execution environment of the context in which the program is executed. */ public static ExecutionEnvironment getExecutionEnvironment() { + return resolveFactory(threadLocalContextEnvironmentFactory, contextEnvironmentFactory) + .map(ExecutionEnvironmentFactory::createExecutionEnvironment) + .orElseGet(ExecutionEnvironment::createLocalEnvironment); + } + + /** + * Resolves the given factories. The thread local factory has preference over the static factory. + * If none is set, the method returns {@link Optional#empty()}. + * + * @param threadLocalFactory containing the thread local factory + * @param staticFactory containing the global factory + * @param type of factory + * @return Optional containing the resolved factory if it exists, otherwise it's empty + */ + public static Optional resolveFactory(ThreadLocal threadLocalFactory, @Nullable T staticFactory) { + final T localFactory = threadLocalFactory.get(); + final T factory = localFactory == null ? staticFactory : localFactory; - return contextEnvironmentFactoryThreadLocal.get() == null ? - (contextEnvironmentFactory == null ? 
- createLocalEnvironment() : contextEnvironmentFactory.createExecutionEnvironment()) : - contextEnvironmentFactoryThreadLocal.get().createExecutionEnvironment(); + return Optional.ofNullable(factory); } /** @@ -1259,7 +1276,7 @@ public static void setDefaultLocalParallelism(int parallelism) { */ protected static void initializeContextEnvironment(ExecutionEnvironmentFactory ctx) { contextEnvironmentFactory = Preconditions.checkNotNull(ctx); - contextEnvironmentFactoryThreadLocal.set(contextEnvironmentFactory); + threadLocalContextEnvironmentFactory.set(contextEnvironmentFactory); } /** @@ -1269,7 +1286,7 @@ protected static void initializeContextEnvironment(ExecutionEnvironmentFactory c */ protected static void resetContextEnvironment() { contextEnvironmentFactory = null; - contextEnvironmentFactoryThreadLocal.remove(); + threadLocalContextEnvironmentFactory.remove(); } /** @@ -1281,6 +1298,6 @@ protected static void resetContextEnvironment() { */ @Internal public static boolean areExplicitEnvironmentsAllowed() { - return contextEnvironmentFactory == null && contextEnvironmentFactoryThreadLocal.get() == null; + return contextEnvironmentFactory == null && threadLocalContextEnvironmentFactory.get() == null; } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java index f93fd4c50da29a..29cc3247b93dbd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java @@ -115,10 +115,10 @@ public abstract class StreamExecutionEnvironment { /** * The environment of the context (local by default, cluster if invoked through command line). */ - private static StreamExecutionEnvironmentFactory contextEnvironmentFactory; + private static StreamExecutionEnvironmentFactory contextEnvironmentFactory = null; /** The ThreadLocal used to store {@link StreamExecutionEnvironmentFactory}. */ - private static ThreadLocal contextEnvironmentFactoryThreadLocal = new ThreadLocal<>(); + private static final ThreadLocal threadLocalContextEnvironmentFactory = new ThreadLocal<>(); /** The default parallelism used when creating a local environment. */ private static int defaultLocalParallelism = Runtime.getRuntime().availableProcessors(); @@ -1571,13 +1571,12 @@ public void addOperator(StreamTransformation transformation) { * executed. */ public static StreamExecutionEnvironment getExecutionEnvironment() { - if (contextEnvironmentFactoryThreadLocal.get() != null) { - return contextEnvironmentFactoryThreadLocal.get().createExecutionEnvironment(); - } - if (contextEnvironmentFactory != null) { - return contextEnvironmentFactory.createExecutionEnvironment(); - } + return ExecutionEnvironment.resolveFactory(threadLocalContextEnvironmentFactory, contextEnvironmentFactory) + .map(StreamExecutionEnvironmentFactory::createExecutionEnvironment) + .orElseGet(StreamExecutionEnvironment::createStreamExecutionEnvironment); + } + private static StreamExecutionEnvironment createStreamExecutionEnvironment() { // because the streaming project depends on "flink-clients" (and not the other way around) // we currently need to intercept the data set environment and create a dependent stream env. 
// this should be fixed once we rework the project dependencies @@ -1772,12 +1771,12 @@ public static void setDefaultLocalParallelism(int parallelism) { protected static void initializeContextEnvironment(StreamExecutionEnvironmentFactory ctx) { contextEnvironmentFactory = ctx; - contextEnvironmentFactoryThreadLocal.set(contextEnvironmentFactory); + threadLocalContextEnvironmentFactory.set(contextEnvironmentFactory); } protected static void resetContextEnvironment() { contextEnvironmentFactory = null; - contextEnvironmentFactoryThreadLocal.remove(); + threadLocalContextEnvironmentFactory.remove(); } /**

blob.client.socket.timeout (default: 300000) - The socket timeout in milliseconds for the blob client.
blob.client.connect.timeout (default: 0) - The connection timeout in milliseconds for the blob client.
blob.fetch.backlog (default: 1000)
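
A minimal sketch of setting the options listed above programmatically with Flink's Configuration class; the keys and values are taken from the list above, and in practice they would usually be set in flink-conf.yaml instead.

    import org.apache.flink.configuration.Configuration;

    public class BlobClientConfigSketch {
        public static void main(String[] args) {
            // Keys and defaults as listed above; shown only to illustrate programmatic configuration.
            Configuration config = new Configuration();
            config.setInteger("blob.client.socket.timeout", 300000);
            config.setInteger("blob.client.connect.timeout", 0);
            config.setInteger("blob.fetch.backlog", 1000);
            System.out.println(config);
        }
    }
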