diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/TableEnvHiveConnectorTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/TableEnvHiveConnectorTest.java index db318ae741cd1..1da85f8b10f0d 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/TableEnvHiveConnectorTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/TableEnvHiveConnectorTest.java @@ -225,8 +225,7 @@ public void testInsertOverwrite() throws Exception { HiveTestUtils.createTextTableInserter(hiveShell, "db1", "dest").addRow(new Object[]{1, "a"}).addRow(new Object[]{2, "b"}).commit(); verifyHiveQueryResult("select * from db1.dest", Arrays.asList("1\ta", "2\tb")); TableEnvironment tableEnv = getTableEnvWithHiveCatalog(); - // TODO: remove the cast once FLINK-15381 is fixed. - tableEnv.sqlUpdate("insert overwrite db1.dest values (3,cast('c' as varchar))"); + tableEnv.sqlUpdate("insert overwrite db1.dest values (3, 'c')"); tableEnv.execute("test insert overwrite"); verifyHiveQueryResult("select * from db1.dest", Collections.singletonList("3\tc")); diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMdCollation.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMdCollation.java new file mode 100644 index 0000000000000..7306684293975 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMdCollation.java @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.metadata; + +import org.apache.calcite.adapter.enumerable.EnumerableCorrelate; +import org.apache.calcite.adapter.enumerable.EnumerableHashJoin; +import org.apache.calcite.adapter.enumerable.EnumerableMergeJoin; +import org.apache.calcite.adapter.enumerable.EnumerableNestedLoopJoin; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.plan.volcano.RelSubset; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Match; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.SortExchange; +import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.core.Values; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.metadata.BuiltInMetadata; +import org.apache.calcite.rel.metadata.MetadataDef; +import org.apache.calcite.rel.metadata.MetadataHandler; +import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider; +import org.apache.calcite.rel.metadata.RelMetadataProvider; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCallBinding; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.sql.validate.SqlMonotonicity; +import org.apache.calcite.util.Bug; +import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.ImmutableIntList; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.stream.Collectors; + +/** + * FlinkRelMdCollation supplies a default implementation of + * {@link org.apache.calcite.rel.metadata.RelMetadataQuery#collations} + * for the standard logical algebra. + */ +public class FlinkRelMdCollation implements MetadataHandler { + public static final RelMetadataProvider SOURCE = + ReflectiveRelMetadataProvider.reflectiveSource(BuiltInMethod.COLLATIONS.method, new FlinkRelMdCollation()); + + //~ Constructors ----------------------------------------------------------- + + private FlinkRelMdCollation() { + } + + //~ Methods ---------------------------------------------------------------- + + public MetadataDef getDef() { + return BuiltInMetadata.Collation.DEF; + } + + public com.google.common.collect.ImmutableList collations(TableScan scan, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(table(scan.getTable())); + } + + public com.google.common.collect.ImmutableList collations(Values values, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(values(mq, values.getRowType(), values.getTuples())); + } + + public com.google.common.collect.ImmutableList collations(Project project, + RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(project(mq, project.getInput(), project.getProjects())); + } + + public com.google.common.collect.ImmutableList collations(Filter rel, RelMetadataQuery mq) { + return mq.collations(rel.getInput()); + } + + public com.google.common.collect.ImmutableList collations(Calc calc, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(calc(mq, calc.getInput(), calc.getProgram())); + } + + public com.google.common.collect.ImmutableList collations(SortExchange sort, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(sort(sort.getCollation())); + } + + public com.google.common.collect.ImmutableList collations(Sort sort, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(sort(sort.getCollation())); + } + + public com.google.common.collect.ImmutableList collations(Window rel, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(window(mq, rel.getInput(), rel.groups)); + } + + public com.google.common.collect.ImmutableList collations( + EnumerableCorrelate join, + RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf( + enumerableCorrelate(mq, join.getLeft(), join.getRight(), join.getJoinType())); + } + + public com.google.common.collect.ImmutableList collations( + EnumerableMergeJoin join, + RelMetadataQuery mq) { + // In general a join is not sorted. But a merge join preserves the sort + // order of the left and right sides. + return com.google.common.collect.ImmutableList.copyOf(mergeJoin( + mq, + join.getLeft(), + join.getRight(), + join.analyzeCondition().leftKeys, + join.analyzeCondition().rightKeys)); + } + + public com.google.common.collect.ImmutableList collations( + EnumerableHashJoin join, + RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf( + enumerableHashJoin(mq, join.getLeft(), join.getRight(), join.getJoinType())); + } + + public com.google.common.collect.ImmutableList collations( + EnumerableNestedLoopJoin join, + RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf( + enumerableNestedLoopJoin(mq, join.getLeft(), join.getRight(), join.getJoinType())); + } + + public com.google.common.collect.ImmutableList collations(Match rel, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.copyOf(match( + mq, + rel.getInput(), + rel.getRowType(), + rel.getPattern(), + rel.isStrictStart(), + rel.isStrictEnd(), + rel.getPatternDefinitions(), + rel.getMeasures(), + rel.getAfter(), + rel.getSubsets(), + rel.isAllRows(), + rel.getPartitionKeys(), + rel.getOrderKeys(), + rel.getInterval())); + } + + public com.google.common.collect.ImmutableList collations(TableModify rel, RelMetadataQuery mq) { + return mq.collations(rel.getInput()); + } + + public com.google.common.collect.ImmutableList collations(HepRelVertex rel, RelMetadataQuery mq) { + return mq.collations(rel.getCurrentRel()); + } + + public com.google.common.collect.ImmutableList collations(RelSubset subset, RelMetadataQuery mq) { + if (!Bug.CALCITE_1048_FIXED) { + //if the best node is null, so we can get the collation based original node, due to + //the original node is logically equivalent as the rel. + RelNode rel = Util.first(subset.getBest(), subset.getOriginal()); + return mq.collations(rel); + } else { + throw new RuntimeException("CALCITE_1048 is fixed, so check this method again!"); + } + } + + /** + * Catch-all implementation for + * {@link BuiltInMetadata.Collation#collations()}, + * invoked using reflection, for any relational expression not + * handled by a more specific method. + * + *

{@link org.apache.calcite.rel.core.Union}, + * {@link org.apache.calcite.rel.core.Intersect}, + * {@link org.apache.calcite.rel.core.Minus}, + * {@link org.apache.calcite.rel.core.Join}, + * {@link org.apache.calcite.rel.core.Correlate} + * do not in general return sorted results + * (but implementations using particular algorithms may). + * + * @param rel Relational expression + * @return Relational expression's collations + * @see org.apache.calcite.rel.metadata.RelMetadataQuery#collations(RelNode) + */ + public com.google.common.collect.ImmutableList collations(RelNode rel, RelMetadataQuery mq) { + return com.google.common.collect.ImmutableList.of(); + } + + + // Helper methods + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.TableScan}'s collation. + */ + public static List table(RelOptTable table) { + return table.getCollationList(); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Values}'s collation. + * + *

We actually under-report the collations. A Values with 0 or 1 rows - an + * edge case, but legitimate and very common - is ordered by every permutation + * of every subset of the columns. + * + *

So, our algorithm aims to:

    + *
  • produce at most N collations (where N is the number of columns); + *
  • make each collation as long as possible; + *
  • do not repeat combinations already emitted - + * if we've emitted {@code (a, b)} do not later emit {@code (b, a)}; + *
  • probe the actual values and make sure that each collation is + * consistent with the data + *
+ * + *

So, for an empty Values with 4 columns, we would emit + * {@code (a, b, c, d), (b, c, d), (c, d), (d)}. + */ + public static List values( + RelMetadataQuery mq, + RelDataType rowType, + com.google.common.collect.ImmutableList> tuples) { + Util.discard(mq); // for future use + final List list = new ArrayList<>(); + final int n = rowType.getFieldCount(); + final List>>> pairs = + new ArrayList<>(); + outer: + for (int i = 0; i < n; i++) { + pairs.clear(); + for (int j = i; j < n; j++) { + final RelFieldCollation fieldCollation = new RelFieldCollation(j); + com.google.common.collect.Ordering> comparator = comparator(fieldCollation); + com.google.common.collect.Ordering> ordering; + if (pairs.isEmpty()) { + ordering = comparator; + } else { + ordering = Util.last(pairs).right.compound(comparator); + } + pairs.add(Pair.of(fieldCollation, ordering)); + if (!ordering.isOrdered(tuples)) { + if (j == i) { + continue outer; + } + pairs.remove(pairs.size() - 1); + } + } + if (!pairs.isEmpty()) { + list.add(RelCollations.of(Pair.left(pairs))); + } + } + return list; + } + + private static com.google.common.collect.Ordering> comparator(RelFieldCollation fieldCollation) { + final int nullComparison = fieldCollation.nullDirection.nullComparison; + final int x = fieldCollation.getFieldIndex(); + switch (fieldCollation.direction) { + case ASCENDING: + return new com.google.common.collect.Ordering>() { + public int compare(List o1, List o2) { + final Comparable c1 = o1.get(x).getValueAs(Comparable.class); + final Comparable c2 = o2.get(x).getValueAs(Comparable.class); + return RelFieldCollation.compare(c1, c2, nullComparison); + } + }; + default: + return new com.google.common.collect.Ordering>() { + public int compare(List o1, List o2) { + final Comparable c1 = o1.get(x).getValueAs(Comparable.class); + final Comparable c2 = o2.get(x).getValueAs(Comparable.class); + return RelFieldCollation.compare(c2, c1, -nullComparison); + } + }; + } + } + + /** Helper method to determine a {@link Project}'s collation. */ + public static List project(RelMetadataQuery mq, RelNode input, List projects) { + final SortedSet collations = new TreeSet<>(); + final List inputCollations = mq.collations(input); + if (inputCollations == null || inputCollations.isEmpty()) { + return com.google.common.collect.ImmutableList.of(); + } + final com.google.common.collect.Multimap targets = + com.google.common.collect.LinkedListMultimap.create(); + final Map targetsWithMonotonicity = new HashMap<>(); + for (Ord project : Ord.zip(projects)) { + if (project.e instanceof RexInputRef) { + targets.put(((RexInputRef) project.e).getIndex(), project.i); + } else if (project.e instanceof RexCall) { + final RexCall call = (RexCall) project.e; + final RexCallBinding binding = + RexCallBinding.create(input.getCluster().getTypeFactory(), call, inputCollations); + targetsWithMonotonicity.put(project.i, call.getOperator().getMonotonicity(binding)); + } + } + final List fieldCollations = new ArrayList<>(); + loop: + for (RelCollation ic : inputCollations) { + if (ic.getFieldCollations().isEmpty()) { + continue; + } + fieldCollations.clear(); + for (RelFieldCollation ifc : ic.getFieldCollations()) { + final Collection integers = targets.get(ifc.getFieldIndex()); + if (integers.isEmpty()) { + continue loop; // cannot do this collation + } + fieldCollations.add(ifc.withFieldIndex(integers.iterator().next())); + } + assert !fieldCollations.isEmpty(); + collations.add(RelCollations.of(fieldCollations)); + } + + final List fieldCollationsForRexCalls = new ArrayList<>(); + for (Map.Entry entry : targetsWithMonotonicity.entrySet()) { + final SqlMonotonicity value = entry.getValue(); + switch (value) { + case NOT_MONOTONIC: + case CONSTANT: + break; + default: + fieldCollationsForRexCalls.add( + new RelFieldCollation(entry.getKey(), RelFieldCollation.Direction.of(value))); + break; + } + } + + if (!fieldCollationsForRexCalls.isEmpty()) { + collations.add(RelCollations.of(fieldCollationsForRexCalls)); + } + + return com.google.common.collect.ImmutableList.copyOf(collations); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Filter}'s collation. + */ + public static List filter(RelMetadataQuery mq, RelNode input) { + return mq.collations(input); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Calc}'s collation. + */ + public static List calc(RelMetadataQuery mq, RelNode input, RexProgram program) { + final List projects = program + .getProjectList() + .stream() + .map(program::expandLocalRef) + .collect(Collectors.toList()); + return project(mq, input, projects); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Snapshot}'s collation. + */ + public static List snapshot(RelMetadataQuery mq, RelNode input) { + return mq.collations(input); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Sort}'s collation. + */ + public static List sort(RelCollation collation) { + return com.google.common.collect.ImmutableList.of(collation); + } + + + /** + * Helper method to determine a + * limit's collation. + */ + public static List limit(RelMetadataQuery mq, RelNode input) { + return mq.collations(input); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Window}'s collation. + * + *

A Window projects the fields of its input first, followed by the output + * from each of its windows. Assuming (quite reasonably) that the + * implementation does not re-order its input rows, then any collations of its + * input are preserved. + */ + public static List window( + RelMetadataQuery mq, + RelNode input, + com.google.common.collect.ImmutableList groups) { + return mq.collations(input); + } + + /** + * Helper method to determine a + * {@link org.apache.calcite.rel.core.Match}'s collation. + */ + public static List match(RelMetadataQuery mq, RelNode input, + RelDataType rowType, + RexNode pattern, + boolean strictStart, + boolean strictEnd, + Map patternDefinitions, + Map measures, + RexNode after, + Map> subsets, + boolean allRows, + ImmutableBitSet partitionKeys, + RelCollation orderKeys, + RexNode interval) { + return mq.collations(input); + } + + /** + * Helper method to determine a {@link Join}'s collation assuming that it + * uses a merge-join algorithm. + * + *

If the inputs are sorted on other keys in addition to the join + * key, the result preserves those collations too. + */ + public static List mergeJoin( + RelMetadataQuery mq, + RelNode left, RelNode right, + ImmutableIntList leftKeys, + ImmutableIntList rightKeys) { + final com.google.common.collect.ImmutableList.Builder builder = + com.google.common.collect.ImmutableList.builder(); + + final com.google.common.collect.ImmutableList leftCollations = mq.collations(left); + assert RelCollations.contains(leftCollations, leftKeys) + : "cannot merge join: left input is not sorted on left keys"; + builder.addAll(leftCollations); + + final com.google.common.collect.ImmutableList rightCollations = mq.collations(right); + assert RelCollations.contains(rightCollations, rightKeys) + : "cannot merge join: right input is not sorted on right keys"; + final int leftFieldCount = left.getRowType().getFieldCount(); + for (RelCollation collation : rightCollations) { + builder.add(RelCollations.shift(collation, leftFieldCount)); + } + return builder.build(); + } + + /** + * Returns the collation of {@link EnumerableHashJoin} based on its inputs and the join type. + */ + public static List enumerableHashJoin(RelMetadataQuery mq, + RelNode left, RelNode right, JoinRelType joinType) { + if (joinType == JoinRelType.SEMI) { + return enumerableSemiJoin(mq, left, right); + } else { + return enumerableJoin0(mq, left, right, joinType); + } + } + + /** + * Returns the collation of {@link EnumerableNestedLoopJoin} + * based on its inputs and the join type. + */ + public static List enumerableNestedLoopJoin( + RelMetadataQuery mq, + RelNode left, + RelNode right, + JoinRelType joinType) { + return enumerableJoin0(mq, left, right, joinType); + } + + public static List enumerableCorrelate( + RelMetadataQuery mq, + RelNode left, + RelNode right, + JoinRelType joinType) { + // The current implementation always preserve the sort order of the left input + return mq.collations(left); + } + + public static List enumerableSemiJoin( + RelMetadataQuery mq, + RelNode left, + RelNode right) { + // The current implementation always preserve the sort order of the left input + return mq.collations(left); + } + + public static List enumerableBatchNestedLoopJoin( + RelMetadataQuery mq, + RelNode left, + RelNode right, + JoinRelType joinType) { + // The current implementation always preserve the sort order of the left input + return mq.collations(left); + } + + private static List enumerableJoin0( + RelMetadataQuery mq, + RelNode left, + RelNode right, + JoinRelType joinType) { + // The current implementation can preserve the sort order of the left input if one of the + // following conditions hold: + // (i) join type is INNER or LEFT; + // (ii) RelCollation always orders nulls last. + final com.google.common.collect.ImmutableList leftCollations = mq.collations(left); + switch (joinType) { + case SEMI: + case ANTI: + case INNER: + case LEFT: + return leftCollations; + case RIGHT: + case FULL: + for (RelCollation collation : leftCollations) { + for (RelFieldCollation field : collation.getFieldCollations()) { + if (!(RelFieldCollation.NullDirection.LAST == field.nullDirection)) { + return com.google.common.collect.ImmutableList.of(); + } + } + } + return leftCollations; + } + return com.google.common.collect.ImmutableList.of(); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala index a1c3cec278103..6eb2d3a7bb93b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala @@ -46,7 +46,7 @@ object FlinkDefaultRelMetadataProvider { RelMdMaxRowCount.SOURCE, RelMdMinRowCount.SOURCE, RelMdPredicates.SOURCE, - RelMdCollation.SOURCE, + FlinkRelMdCollation.SOURCE, RelMdExplainVisibility.SOURCE ) ) diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/CalcTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/CalcTest.xml index dd55968f65e09..9eb1c4d75878f 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/CalcTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/CalcTest.xml @@ -30,6 +30,23 @@ LogicalProject(EXPR$0=[ARRAY(_UTF-16LE'Hi', _UTF-16LE'Hello', $2)]) + + + + + + + + + + + diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml index 11599dff39df0..d90e9fcd847c8 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml @@ -408,9 +408,9 @@ LogicalSort(sort0=[$0], sort1=[$1], sort2=[$2], dir0=[ASC-nulls-first], dir1=[AS rexBuilder.makeLiteral(v.toLong, fieldType, true) + case INTEGER => rexBuilder.makeLiteral(v.toInt, fieldType, true) case BOOLEAN => rexBuilder.makeLiteral(v.toBoolean) case DATE => rexBuilder.makeDateLiteral(new DateString(v)) case TIME => rexBuilder.makeTimeLiteral(new TimeString(v), 0) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala new file mode 100644 index 0000000000000..6260215a6210f --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.metadata + +import com.google.common.collect.ImmutableList +import org.apache.calcite.rel.logical.{LogicalFilter, LogicalProject, LogicalValues} +import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation} +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.fun.SqlStdOperatorTable.{LESS_THAN, PLUS} +import org.junit.Assert.assertEquals +import org.junit.Test + +import scala.collection.JavaConversions._ + +class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { + + protected lazy val collationValues: LogicalValues = { + val valuesType = relBuilder.getTypeFactory + .builder() + .add("a", SqlTypeName.BIGINT) + .add("b", SqlTypeName.DOUBLE) + .add("c", SqlTypeName.BOOLEAN) + .add("d", SqlTypeName.INTEGER) + .build() + val tupleList = List( + List("1", "9.0", "true", "2"), + List("2", "6.0", "false", "3"), + List("3", "3.0", "true", "4") + ).map(createLiteralList(valuesType, _)) + relBuilder.clear() + relBuilder.values(tupleList, valuesType) + relBuilder.build().asInstanceOf[LogicalValues] + } + + @Test + def testCollationsOnTableScan(): Unit = { + Array(studentLogicalScan, studentBatchScan, studentStreamScan).foreach { scan => + assertEquals(ImmutableList.of(), mq.collations(scan)) + } + } + + @Test + def testCollationsOnValues(): Unit = { + assertEquals(ImmutableList.of(RelCollations.of(6)), mq.collations(logicalValues)) + assertEquals( + ImmutableList.of( + convertToRelCollation(List.range(0, 8)), + convertToRelCollation(List.range(1, 8)), + convertToRelCollation(List.range(2, 8)), + convertToRelCollation(List.range(3, 8)), + convertToRelCollation(List.range(4, 8)), + convertToRelCollation(List.range(5, 8)), + convertToRelCollation(List.range(6, 8)), + convertToRelCollation(List.range(7, 8)) + ), + mq.collations(emptyValues)) + assertEquals( + ImmutableList.of(convertToRelCollation(List.range(0, 4)), RelCollations.of(3)), + mq.collations(collationValues)) + } + + @Test + def testCollationsOnProject(): Unit = { + assertEquals(ImmutableList.of(), mq.collations(logicalProject)) + + val project: LogicalProject = { + relBuilder.push(collationValues) + val projects = List( + // a + b + relBuilder.call(PLUS, relBuilder.field(0), relBuilder.literal(1)), + // c + relBuilder.field(2), + // d + relBuilder.field(3), + // 2 + rexBuilder.makeLiteral(2L, longType, true) + ) + relBuilder.project(projects).build().asInstanceOf[LogicalProject] + } + assertEquals(ImmutableList.of(RelCollations.of(2)), mq.collations(project)) + } + + @Test + def testCollationsOnFilter(): Unit = { + assertEquals(ImmutableList.of(), mq.collations(logicalFilter)) + + relBuilder.push(studentLogicalScan) + val filter: LogicalFilter = { + relBuilder.push(collationValues) + // a < 10 + val expr = relBuilder.call(LESS_THAN, relBuilder.field(0), relBuilder.literal(10)) + relBuilder.filter(expr).build.asInstanceOf[LogicalFilter] + } + assertEquals( + ImmutableList.of(convertToRelCollation(List.range(0, 4)), RelCollations.of(3)), + mq.collations(filter)) + } + + @Test + def testCollationsOnExpand(): Unit = { + Array(logicalExpand, flinkLogicalExpand, batchExpand, streamExpand).foreach { + expand => assertEquals(ImmutableList.of(), mq.collations(expand)) + } + } + + @Test + def testCollationsOnExchange(): Unit = { + Array(batchExchange, streamExchange).foreach { + exchange => assertEquals(ImmutableList.of(), mq.collations(exchange)) + } + } + + @Test + def testCollationsOnRank(): Unit = { + Array(logicalRank, flinkLogicalRank, batchLocalRank, streamRank).foreach { + rank => assertEquals(ImmutableList.of(), mq.collations(rank)) + } + } + + @Test + def testCollationsOnSort(): Unit = { + Array(logicalSort, flinkLogicalSort, batchSort, streamSort, + logicalSortLimit, flinkLogicalSortLimit, batchSortLimit, streamSortLimit).foreach { sort => + assertEquals( + ImmutableList.of(RelCollations.of( + new RelFieldCollation(6), + new RelFieldCollation(2, RelFieldCollation.Direction.DESCENDING))), + mq.collations(sort)) + } + + Array(logicalLimit, logicalLimit, batchLimit, streamLimit).foreach { limit => + assertEquals(ImmutableList.of(RelCollations.of()), mq.collations(limit)) + } + } + + @Test + def testCollationsOnWindow(): Unit = { + assertEquals(ImmutableList.of(), mq.collations(flinkLogicalOverAgg)) + } + + @Test + def testCollationsOnAggregate(): Unit = { + Array(logicalAgg, flinkLogicalAgg, batchGlobalAggWithLocal, batchGlobalAggWithoutLocal, + batchLocalAgg).foreach { + agg => assertEquals(ImmutableList.of(), mq.collations(agg)) + } + } + + @Test + def testCollationsOnJoin(): Unit = { + Array(logicalInnerJoinOnUniqueKeys, logicalLeftJoinNotOnUniqueKeys, + logicalRightJoinOnRHSUniqueKeys, logicalFullJoinWithoutEquiCond, + logicalSemiJoinOnLHSUniqueKeys, logicalAntiJoinOnRHSUniqueKeys).foreach { + join => assertEquals(ImmutableList.of(), mq.collations(join)) + } + } + + @Test + def testCollationsOnUnion(): Unit = { + Array(logicalUnion, logicalUnionAll).foreach { + union => assertEquals(ImmutableList.of(), mq.collations(union)) + } + } + + @Test + def testCollationsOnIntersect(): Unit = { + Array(logicalIntersect, logicalIntersectAll).foreach { + intersect => assertEquals(ImmutableList.of(), mq.collations(intersect)) + } + } + + @Test + def testCollationsOnMinus(): Unit = { + Array(logicalMinus, logicalMinusAll).foreach { + minus => assertEquals(ImmutableList.of(), mq.collations(minus)) + } + } + + @Test + def testCollationsOnDefault(): Unit = { + assertEquals(ImmutableList.of(), mq.collations(testRel)) + } + + private def convertToRelCollation(relFieldCollations: List[Int]): RelCollation = { + RelCollations.of(relFieldCollations.map(i => new RelFieldCollation(i)): _*) + } +}