From 754fd6b73c3d7271b6c1fcc3416b2c3996073e43 Mon Sep 17 00:00:00 2001 From: lincoln lee Date: Fri, 26 Mar 2021 10:03:34 +0800 Subject: [PATCH] [FLINK-21946][table-planner-blink] FlinkRelMdUtil.numDistinctVals produces exceptional Double.NaN result when domainSize is in range(0,1) This closes #15357 --- .../planner/plan/utils/FlinkRelMdUtil.scala | 2 +- .../FlinkRelMdDistinctRowCountTest.scala | 3 +-- .../plan/utils/FlinkRelMdUtilTest.scala | 7 ++++++ .../runtime/batch/sql/join/JoinITCase.scala | 23 ++++++++++++++++--- 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala index 9204bdc8c72a6..b904d4461dc57 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala @@ -236,7 +236,7 @@ object FlinkRelMdUtil { */ def numDistinctVals(domainSize: Double, numSelected: Double): Double = { val EPS = 1e-9 - if (Math.abs(1 / domainSize) < EPS) { + if (Math.abs(1 / domainSize) < EPS || domainSize < 1) { // ln(1+x) ~= x for small x val dSize = RelMdUtil.capInfinity(domainSize) val numSel = RelMdUtil.capInfinity(numSelected) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala index d81e8cdcff01a..308870298e9e6 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala @@ -21,7 +21,6 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalRank import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil -import org.apache.calcite.rel.metadata.RelMdUtil import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.apache.calcite.util.ImmutableBitSet import org.junit.Assert._ @@ -82,7 +81,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { mq.getDistinctRowCount(logicalValues, ImmutableBitSet.of(0, 1), null)) (0 until logicalValues.getRowType.getFieldCount).foreach { idx => - assertEquals(Double.NaN, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null)) + assertEquals(1.0, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null)) } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala index cd3949252b6b4..f90d22db9373b 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala @@ -28,6 +28,13 @@ class FlinkRelMdUtilTest { Assert.assertEquals( RelMdUtil.numDistinctVals(1e5, 1e4), FlinkRelMdUtil.numDistinctVals(1e5, 1e4)) + + Assert.assertEquals( + BigDecimal(0.31606027941427883), + BigDecimal.valueOf(FlinkRelMdUtil.numDistinctVals(0.5, 0.5))) + + // This case should be removed once CALCITE-4351 is fixed. + Assert.assertEquals(Double.NaN, RelMdUtil.numDistinctVals(0.5, 0.5)) } @Test diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala index 148269a74e834..39ebeca0ca423 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala @@ -24,7 +24,8 @@ import org.apache.flink.api.common.typeinfo.Types import org.apache.flink.api.common.typeutils.TypeComparator import org.apache.flink.api.dag.Transformation import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo} -import org.apache.flink.streaming.api.transformations.{OneInputTransformation, LegacySinkTransformation, TwoInputTransformation} +import org.apache.flink.streaming.api.transformations.{LegacySinkTransformation, OneInputTransformation, TwoInputTransformation} +import org.apache.flink.table.api.internal.TableEnvironmentInternal import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.expressions.utils.FuncWithOpen import org.apache.flink.table.planner.runtime.batch.sql.join.JoinType.{BroadcastHashJoin, HashJoin, JoinType, NestedLoopJoin, SortMergeJoin} @@ -34,12 +35,13 @@ import org.apache.flink.table.planner.runtime.utils.TestData._ import org.apache.flink.table.planner.sinks.CollectRowTableSink import org.apache.flink.table.planner.utils.{TestingStatementSet, TestingTableEnvironment} import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory +import org.apache.flink.types.Row + import org.junit.runner.RunWith import org.junit.runners.Parameterized import org.junit.{Assert, Before, Test} -import java.util -import org.apache.flink.table.api.internal.TableEnvironmentInternal +import java.util import scala.collection.JavaConversions._ import scala.collection.Seq @@ -588,6 +590,21 @@ class JoinITCase(expectedJoinType: JoinType) extends BatchTestBase { Seq(row(2, 1.0), row(2, 1.0))) } + @Test + def testCorrelatedExist2(): Unit = { + val data: Seq[Row] = Seq( + row(0L), + row(123456L), + row(-123456L), + row(2147483647L), + row(-2147483647L)) + registerCollection("t1", data, new RowTypeInfo(LONG_TYPE_INFO), "f1") + + checkResult( + "select * from t1 o where exists (select 1 from t1 i where i.f1=o.f1 limit 0)", + Seq()) + } + @Test def testCorrelatedNotExist(): Unit = { checkResult(