Skip to content

Commit

Permalink
[FLINK-19449][table] Pass isBounded to AggFunctionFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
JingsongLi committed Apr 26, 2021
1 parent 9b9d4f0 commit bf8f998
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

final AggregateInfoList localAggInfoList =
AggregateUtil.deriveWindowAggregateInfoList(
AggregateUtil.deriveStreamWindowAggregateInfoList(
localAggInputRowType, // should use original input here
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
false); // isStateBackendDataViews

final AggregateInfoList globalAggInfoList =
AggregateUtil.deriveWindowAggregateInfoList(
AggregateUtil.deriveStreamWindowAggregateInfoList(
localAggInputRowType, // should use original input here
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

final AggregateInfoList aggInfoList =
AggregateUtil.deriveWindowAggregateInfoList(
AggregateUtil.deriveStreamWindowAggregateInfoList(
inputRowType,
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
// Hopping window requires additional COUNT(*) to determine whether to register next timer
// through whether the current fired window is empty, see SliceSharedWindowAggProcessor.
final AggregateInfoList aggInfoList =
AggregateUtil.deriveWindowAggregateInfoList(
AggregateUtil.deriveStreamWindowAggregateInfoList(
inputRowType,
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
def getAggCallFromLocalAgg(
index: Int,
aggCalls: Seq[AggregateCall],
inputType: RelDataType): AggregateCall = {
inputType: RelDataType,
isBounded: Boolean): AggregateCall = {
val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
aggCalls, inputType)
aggCalls, inputType, isBounded)
if (outputIndexToAggCallIndexMap.containsKey(index)) {
val realIndex = outputIndexToAggCallIndexMap.get(index)
aggCalls(realIndex)
Expand All @@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
def getAggCallIndexInLocalAgg(
index: Int,
globalAggCalls: Seq[AggregateCall],
inputRowType: RelDataType): Integer = {
inputRowType: RelDataType,
isBounded: Boolean): Integer = {
val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
globalAggCalls, inputRowType)
globalAggCalls, inputRowType, isBounded)

outputIndexToAggCallIndexMap.foreach {
case (k, v) => if (v == index) {
Expand All @@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
case agg: StreamPhysicalGlobalGroupAggregate
if agg.aggCalls.length > aggCallIndex =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
} else {
null
}
case agg: StreamPhysicalLocalGroupAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
getAggCallFromLocalAgg(
aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
case agg: StreamPhysicalIncrementalGroupAggregate
if agg.partialAggCalls.length > aggCallIndex =>
agg.partialAggCalls(aggCallIndex)
case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
agg.aggCalls(aggCallIndex)
case agg: BatchPhysicalLocalHashAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
getAggCallFromLocalAgg(
aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
case agg: BatchPhysicalHashAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
} else {
null
}
case agg: BatchPhysicalLocalSortAggregate =>
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
getAggCallFromLocalAgg(
aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
case agg: BatchPhysicalSortAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {

private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
aggCalls,
windowing.getWindow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {

private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
aggCalls,
windowing.getWindow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {

lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
aggCalls,
windowing.getWindow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
* as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
* returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
*
* @param inputType the input rel data type
* @param orderKeyIdx the indexes of order key (null when is not over agg)
* @param needRetraction true if need retraction
* @param inputRowType the input row type
* @param orderKeyIndexes the indexes of order key (null when is not over agg)
* @param aggCallNeedRetractions true if need retraction
* @param isBounded true if the source is bounded source
*/
class AggFunctionFactory(
inputRowType: RowType,
orderKeyIndexes: Array[Int],
aggCallNeedRetractions: Array[Boolean]) {
aggCallNeedRetractions: Array[Boolean],
isBounded: Boolean) {

/**
* The entry point to create an aggregate function from the given [[AggregateCall]].
Expand Down Expand Up @@ -94,8 +96,12 @@ class AggFunctionFactory(
case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
createDenseRankAggFunction(argTypes)

case _: SqlLeadLagAggFunction =>
createLeadLagAggFunction(argTypes, index)
case func: SqlLeadLagAggFunction =>
if (isBounded) {
createBatchLeadLagAggFunction(argTypes, index)
} else {
createStreamLeadLagAggFunction(func, argTypes, index)
}

case _: SqlSingleValueAggFunction =>
createSingleValueAggFunction(argTypes)
Expand Down Expand Up @@ -328,7 +334,22 @@ class AggFunctionFactory(
}
}

private def createLeadLagAggFunction(
private def createStreamLeadLagAggFunction(
func: SqlLeadLagAggFunction,
argTypes: Array[LogicalType],
index: Int): UserDefinedFunction = {
if (func.getKind == SqlKind.LEAD) {
throw new TableException("LEAD Function is not supported in stream mode.")
}

if (aggCallNeedRetractions(index)) {
throw new TableException("LAG Function with retraction is not supported in stream mode.")
}

new LagAggFunction(argTypes)
}

private def createBatchLeadLagAggFunction(
argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
argTypes(0).getTypeRoot match {
case TINYINT =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ object AggregateUtil extends Enumeration {
def getOutputIndexToAggCallIndexMap(
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
isBounded: Boolean,
orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
val aggInfos = transformToAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputType),
Expand All @@ -161,7 +162,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos
needDistinctInfo = false,
isBounded).aggInfos

val map = new util.HashMap[Integer, Integer]()
var outputIndex = 0
Expand Down Expand Up @@ -248,7 +250,7 @@ object AggregateUtil extends Enumeration {
isStateBackendDataViews = true)
}

def deriveWindowAggregateInfoList(
def deriveStreamWindowAggregateInfoList(
inputRowType: RowType,
aggCalls: Seq[AggregateCall],
windowSpec: WindowSpec,
Expand All @@ -271,7 +273,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes = null,
needInputCount,
isStateBackendDataViews,
needDistinctInfo = true)
needDistinctInfo = true,
isBounded = false)
}

def transformToBatchAggregateFunctions(
Expand All @@ -287,7 +290,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false).aggInfos
needDistinctInfo = false,
isBounded = true).aggInfos

val aggFields = aggInfos.map(_.argIndexes)
val bufferTypes = aggInfos.map(_.externalAccTypes)
Expand Down Expand Up @@ -315,7 +319,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
needDistinctInfo = false)
needDistinctInfo = false,
isBounded = true)
}

def transformToStreamAggregateInfoList(
Expand All @@ -332,7 +337,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes = null,
needInputCount,
isStateBackendDataViews,
needDistinctInfo)
needDistinctInfo,
isBounded = false)
}

/**
Expand All @@ -355,7 +361,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes: Array[Int],
needInputCount: Boolean,
isStateBackedDataViews: Boolean,
needDistinctInfo: Boolean): AggregateInfoList = {
needDistinctInfo: Boolean,
isBounded: Boolean): AggregateInfoList = {

// Step-1:
// if need inputCount, find count1 in the existed aggregate calls first,
Expand All @@ -375,7 +382,11 @@ object AggregateUtil extends Enumeration {

// Step-3:
// create aggregate information
val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions)
val factory = new AggFunctionFactory(
inputRowType,
orderKeyIndexes,
aggCallNeedRetractions,
isBounded)
val aggInfos = newAggCalls
.zipWithIndex
.map { case (call, index) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
Array.fill(aggCalls.size())(false),
false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
Expand Down Expand Up @@ -1157,7 +1158,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
Array.fill(aggCalls.size())(false),
false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
Expand Down Expand Up @@ -1324,7 +1326,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
Array.fill(aggCalls.size())(false))
Array.fill(aggCalls.size())(false),
false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
}
Expand Down

0 comments on commit bf8f998

Please sign in to comment.