Skip to content

Commit

Permalink
[SPARK-37468][SQL] Support ANSI intervals and TimestampNTZ for UnionE…
Browse files Browse the repository at this point in the history
…stimation

### What changes were proposed in this pull request?

This PR proposes to support ANSI intervals and TimestampNTZ for `UnionEstimation`.
Currently, `UnionEstimation` doesn't support ANSI intervals and TimestampNTZ. But I think it can support those types because their underlying types are integer or long, which `UnionEstimation` can compute stats for.

### Why are the changes needed?

To make CBO better.

### Does this PR introduce _any_ user-facing change?

No. Not interferes with the current behavior.

### How was this patch tested?

Modified `UnionEstimationSuite`.

Closes apache#34716 from sarutak/union-cbo.

Lead-authored-by: Kousuke Saruta <[email protected]>
Co-authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
2 people authored and MaxGekk committed Nov 29, 2021
1 parent a6ca481 commit 7484c1b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ object UnionEstimation {
case TimestampType => (a: Any, b: Any) =>
TimestampType.ordering.lt(a.asInstanceOf[TimestampType.InternalType],
b.asInstanceOf[TimestampType.InternalType])
case TimestampNTZType => (a: Any, b: Any) =>
TimestampNTZType.ordering.lt(a.asInstanceOf[TimestampNTZType.InternalType],
b.asInstanceOf[TimestampNTZType.InternalType])
case i: AnsiIntervalType => (a: Any, b: Any) =>
i.ordering.lt(a.asInstanceOf[i.InternalType], b.asInstanceOf[i.InternalType])
case _ =>
throw new IllegalStateException(s"Unsupported data type: ${dt.catalogString}")
}

private def isTypeSupported(dt: DataType): Boolean = dt match {
case ByteType | IntegerType | ShortType | FloatType | LongType |
DoubleType | DateType | _: DecimalType | TimestampType => true
DoubleType | DateType | _: DecimalType | TimestampType | TimestampNTZType |
_: AnsiIntervalType => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
val attrDecimal = AttributeReference("cdecimal", DecimalType(5, 4))()
val attrDate = AttributeReference("cdate", DateType)()
val attrTimestamp = AttributeReference("ctimestamp", TimestampType)()
val attrTimestampNTZ = AttributeReference("ctimestamp_ntz", TimestampNTZType)()
val attrYMInterval = AttributeReference("cyminterval", YearMonthIntervalType())()
val attrDTInterval = AttributeReference("cdtinterval", DayTimeIntervalType())()

val s1 = 1.toShort
val s2 = 4.toShort
Expand Down Expand Up @@ -84,7 +87,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
attrFloat -> ColumnStat(min = Some(1.1f), max = Some(4.1f)),
attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.5))),
attrDate -> ColumnStat(min = Some(1), max = Some(4)),
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L))))
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)),
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(4L)),
attrYMInterval -> ColumnStat(min = Some(2), max = Some(5)),
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L))))

val s3 = 2.toShort
val s4 = 6.toShort
Expand Down Expand Up @@ -118,7 +124,16 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
AttributeReference("cdate1", DateType)() -> ColumnStat(min = Some(3), max = Some(6)),
AttributeReference("ctimestamp1", TimestampType)() -> ColumnStat(
min = Some(3L),
max = Some(6L))))
max = Some(6L)),
AttributeReference("ctimestamp_ntz1", TimestampNTZType)() -> ColumnStat(
min = Some(3L),
max = Some(6L)),
AttributeReference("cymtimestamp1", YearMonthIntervalType())() -> ColumnStat(
min = Some(4),
max = Some(8)),
AttributeReference("cdttimestamp1", DayTimeIntervalType())() -> ColumnStat(
min = Some(4L),
max = Some(8L))))

val child1 = StatsTestPlan(
outputList = columnInfo.keys.toSeq.sortWith(_.exprId.id < _.exprId.id),
Expand Down Expand Up @@ -147,7 +162,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
attrFloat -> ColumnStat(min = Some(1.1f), max = Some(6.1f)),
attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.9))),
attrDate -> ColumnStat(min = Some(1), max = Some(6)),
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)))))
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)),
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(6L)),
attrYMInterval -> ColumnStat(min = Some(2), max = Some(8)),
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L)))))
assert(union.stats === expectedStats)
}

Expand Down

0 comments on commit 7484c1b

Please sign in to comment.