Skip to content

Commit

Permalink
Support ANSI intervals and TimestampNTZ for UnionEstimation.
Browse files Browse the repository at this point in the history
  • Loading branch information
sarutak committed Nov 26, 2021
1 parent 95fc4c5 commit 7927295
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,23 @@ object UnionEstimation {
case TimestampType => (a: Any, b: Any) =>
TimestampType.ordering.lt(a.asInstanceOf[TimestampType.InternalType],
b.asInstanceOf[TimestampType.InternalType])
case TimestampNTZType => (a: Any, b: Any) =>
TimestampNTZType.ordering.lt(a.asInstanceOf[TimestampNTZType.InternalType],
b.asInstanceOf[TimestampNTZType.InternalType])
case y: YearMonthIntervalType => (a: Any, b: Any) =>
y.ordering.lt(a.asInstanceOf[y.InternalType],
b.asInstanceOf[y.InternalType])
case d: DayTimeIntervalType => (a: Any, b: Any) =>
d.ordering.lt(a.asInstanceOf[d.InternalType],
b.asInstanceOf[d.InternalType])
case _ =>
throw new IllegalStateException(s"Unsupported data type: ${dt.catalogString}")
}

private def isTypeSupported(dt: DataType): Boolean = dt match {
case ByteType | IntegerType | ShortType | FloatType | LongType |
DoubleType | DateType | _: DecimalType | TimestampType => true
DoubleType | DateType | _: DecimalType | TimestampType | TimestampNTZType |
_: YearMonthIntervalType | _: DayTimeIntervalType => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
val attrDecimal = AttributeReference("cdecimal", DecimalType(5, 4))()
val attrDate = AttributeReference("cdate", DateType)()
val attrTimestamp = AttributeReference("ctimestamp", TimestampType)()
val attrTimestampNTZ = AttributeReference("ctimestamp_ntz", TimestampNTZType)()
val attrYMInterval = AttributeReference("cyminterval", YearMonthIntervalType())()
val attrDTInterval = AttributeReference("cdtinterval", DayTimeIntervalType())()

val s1 = 1.toShort
val s2 = 4.toShort
Expand Down Expand Up @@ -84,7 +87,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
attrFloat -> ColumnStat(min = Some(1.1f), max = Some(4.1f)),
attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.5))),
attrDate -> ColumnStat(min = Some(1), max = Some(4)),
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L))))
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)),
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(4L)),
attrYMInterval -> ColumnStat(min = Some(2), max = Some(5)),
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L))))

val s3 = 2.toShort
val s4 = 6.toShort
Expand Down Expand Up @@ -118,7 +124,16 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
AttributeReference("cdate1", DateType)() -> ColumnStat(min = Some(3), max = Some(6)),
AttributeReference("ctimestamp1", TimestampType)() -> ColumnStat(
min = Some(3L),
max = Some(6L))))
max = Some(6L)),
AttributeReference("ctimestamp_ntz1", TimestampNTZType)() -> ColumnStat(
min = Some(3L),
max = Some(6L)),
AttributeReference("cymtimestamp1", YearMonthIntervalType())() -> ColumnStat(
min = Some(4),
max = Some(8)),
AttributeReference("cdttimestamp1", DayTimeIntervalType())() -> ColumnStat(
min = Some(4L),
max = Some(8L))))

val child1 = StatsTestPlan(
outputList = columnInfo.keys.toSeq.sortWith(_.exprId.id < _.exprId.id),
Expand Down Expand Up @@ -147,7 +162,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
attrFloat -> ColumnStat(min = Some(1.1f), max = Some(6.1f)),
attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.9))),
attrDate -> ColumnStat(min = Some(1), max = Some(6)),
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)))))
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)),
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(6L)),
attrYMInterval -> ColumnStat(min = Some(2), max = Some(8)),
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L)))))
assert(union.stats === expectedStats)
}

Expand Down

0 comments on commit 7927295

Please sign in to comment.