Skip to content

Commit

Permalink
[FLINK-21946][table-planner-blink] FlinkRelMdUtil.numDistinctVals produces exceptional Double.NaN result when domainSize is in range(0,1)
Browse files Browse the repository at this point in the history

This closes apache#15357
  • Loading branch information
lincoln-lil committed Mar 26, 2021
1 parent 57e93c9 commit 754fd6b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ object FlinkRelMdUtil {
*/
def numDistinctVals(domainSize: Double, numSelected: Double): Double = {
val EPS = 1e-9
if (Math.abs(1 / domainSize) < EPS) {
if (Math.abs(1 / domainSize) < EPS || domainSize < 1) {
// ln(1+x) ~= x for small x
val dSize = RelMdUtil.capInfinity(domainSize)
val numSel = RelMdUtil.capInfinity(numSelected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package org.apache.flink.table.planner.plan.metadata
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalRank
import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil

import org.apache.calcite.rel.metadata.RelMdUtil
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.util.ImmutableBitSet
import org.junit.Assert._
Expand Down Expand Up @@ -82,7 +81,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase {
mq.getDistinctRowCount(logicalValues, ImmutableBitSet.of(0, 1), null))

(0 until logicalValues.getRowType.getFieldCount).foreach { idx =>
assertEquals(Double.NaN, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null))
assertEquals(1.0, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class FlinkRelMdUtilTest {
Assert.assertEquals(
RelMdUtil.numDistinctVals(1e5, 1e4),
FlinkRelMdUtil.numDistinctVals(1e5, 1e4))

Assert.assertEquals(
BigDecimal(0.31606027941427883),
BigDecimal.valueOf(FlinkRelMdUtil.numDistinctVals(0.5, 0.5)))

// This case should be removed once CALCITE-4351 is fixed.
Assert.assertEquals(Double.NaN, RelMdUtil.numDistinctVals(0.5, 0.5))
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.flink.api.common.typeinfo.Types
import org.apache.flink.api.common.typeutils.TypeComparator
import org.apache.flink.api.dag.Transformation
import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
import org.apache.flink.streaming.api.transformations.{OneInputTransformation, LegacySinkTransformation, TwoInputTransformation}
import org.apache.flink.streaming.api.transformations.{LegacySinkTransformation, OneInputTransformation, TwoInputTransformation}
import org.apache.flink.table.api.internal.TableEnvironmentInternal
import org.apache.flink.table.planner.delegation.PlannerBase
import org.apache.flink.table.planner.expressions.utils.FuncWithOpen
import org.apache.flink.table.planner.runtime.batch.sql.join.JoinType.{BroadcastHashJoin, HashJoin, JoinType, NestedLoopJoin, SortMergeJoin}
Expand All @@ -34,12 +35,13 @@ import org.apache.flink.table.planner.runtime.utils.TestData._
import org.apache.flink.table.planner.sinks.CollectRowTableSink
import org.apache.flink.table.planner.utils.{TestingStatementSet, TestingTableEnvironment}
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
import org.apache.flink.types.Row

import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Assert, Before, Test}
import java.util

import org.apache.flink.table.api.internal.TableEnvironmentInternal
import java.util

import scala.collection.JavaConversions._
import scala.collection.Seq
Expand Down Expand Up @@ -588,6 +590,21 @@ class JoinITCase(expectedJoinType: JoinType) extends BatchTestBase {
Seq(row(2, 1.0), row(2, 1.0)))
}

@Test
def testCorrelatedExist2(): Unit = {
  // Register a single-column table "t1" (column "f1", LONG) holding zero plus
  // a few positive and negative values, including +/-2147483647.
  val inputRows: Seq[Row] =
    Seq(0L, 123456L, -123456L, 2147483647L, -2147483647L).map(value => row(value))
  registerCollection("t1", inputRows, new RowTypeInfo(LONG_TYPE_INFO), "f1")

  // The correlated sub-query is capped with "limit 0", so it never yields a row
  // and EXISTS is false for every outer row: the expected result is empty.
  // NOTE(review): presumably guards the planner path touched by FLINK-21946
  // (NaN distinct-row estimate on an empty input) -- confirm against the commit.
  val query =
    "select * from t1 o where exists (select 1 from t1 i where i.f1=o.f1 limit 0)"
  checkResult(query, Seq.empty[Row])
}

@Test
def testCorrelatedNotExist(): Unit = {
checkResult(
Expand Down

0 comments on commit 754fd6b

Please sign in to comment.