Skip to content

Commit

Permalink
[FLINK-22781][table-planner-blink] Fix bug in emit behavior of GroupW…
Browse files Browse the repository at this point in the history
…indowAggregate to skip emit window result if input stream of GroupWindowAggregate contains retraction and input counter of input records for current window is zero

This closes apache#16219
  • Loading branch information
beyond1920 authored and godfreyhe committed Jul 6, 2021
1 parent f61d9af commit c919edf
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {

final LogicalType[] aggValueTypes = extractLogicalTypes(aggInfoList.getActualValueTypes());
final LogicalType[] accTypes = extractLogicalTypes(aggInfoList.getAccTypes());
final int inputCountIndex = aggInfoList.getIndexOfCountStar();

final WindowOperator<?, ?> operator =
createWindowOperator(
Expand All @@ -256,7 +257,8 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
aggValueTypes,
inputRowType.getChildren().toArray(new LogicalType[0]),
inputTimeFieldIndex,
shiftTimeZone);
shiftTimeZone,
inputCountIndex);

final OneInputTransformation<RowData, RowData> transform =
new OneInputTransformation<>(
Expand Down Expand Up @@ -351,11 +353,13 @@ private GeneratedClass<?> createAggsHandler(
LogicalType[] aggValueTypes,
LogicalType[] inputFields,
int timeFieldIndex,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
WindowOperatorBuilder builder =
WindowOperatorBuilder.builder()
.withInputFields(inputFields)
.withShiftTimezone(shiftTimeZone);
.withShiftTimezone(shiftTimeZone)
.withInputCountIndex(inputCountIndex);

if (window instanceof TumblingGroupWindow) {
TumblingGroupWindow tumblingWindow = (TumblingGroupWindow) window;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import org.apache.flink.table.api.internal.TableEnvironmentInternal
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.factories.TestValuesTableFactory.{changelogRow, registerData}
import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{ConcatDistinctAggFunction, WeightedAvg}
import org.apache.flink.table.planner.plan.utils.WindowEmitStrategy.{TABLE_EXEC_EMIT_LATE_FIRE_DELAY, TABLE_EXEC_EMIT_LATE_FIRE_ENABLED, TABLE_EXEC_EMIT_ALLOW_LATENESS}
import org.apache.flink.table.planner.plan.utils.WindowEmitStrategy.{TABLE_EXEC_EMIT_ALLOW_LATENESS, TABLE_EXEC_EMIT_LATE_FIRE_DELAY, TABLE_EXEC_EMIT_LATE_FIRE_ENABLED}
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset
import org.apache.flink.table.planner.runtime.utils._
import org.apache.flink.table.runtime.types.TypeInfoDataTypeConverter.fromDataTypeToTypeInfo
import org.apache.flink.types.Row

import org.junit.Assert.assertEquals
Expand All @@ -53,10 +54,11 @@ class GroupWindowITCase(mode: StateBackendMode, useTimestampLtz: Boolean)

val upsertSourceCurrencyData = List(
changelogRow("+U", "Euro", "no1", JLong.valueOf(114L), localDateTime(1L)),
changelogRow("+U", "US Dollar", "no1", JLong.valueOf(100L), localDateTime(1L)),
changelogRow("+U", "US Dollar", "no1", JLong.valueOf(102L), localDateTime(2L)),
changelogRow("+U", "Yen", "no1", JLong.valueOf(1L), localDateTime(3L)),
changelogRow("+U", "RMB", "no1", JLong.valueOf(702L), localDateTime(4L)),
changelogRow("+U", "Euro", "no1", JLong.valueOf(118L), localDateTime(6L)),
changelogRow("+U", "Euro", "no1", JLong.valueOf(118L), localDateTime(18L)),
changelogRow("+U", "US Dollar", "no1", JLong.valueOf(104L), localDateTime(4L)),
changelogRow("-D", "RMB", "no1", JLong.valueOf(702L), localDateTime(4L)))

Expand Down Expand Up @@ -403,6 +405,7 @@ class GroupWindowITCase(mode: StateBackendMode, useTimestampLtz: Boolean)
|SELECT
|currency,
|COUNT(1) AS cnt,
|MAX(rate),
|TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start,
|TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end
|FROM upsert_currency
Expand All @@ -412,14 +415,61 @@ class GroupWindowITCase(mode: StateBackendMode, useTimestampLtz: Boolean)
tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink)
env.execute()
val expected = Seq(
"Euro,0,1970-01-01T00:00,1970-01-01T00:00:05",
"US Dollar,1,1970-01-01T00:00,1970-01-01T00:00:05",
"Yen,1,1970-01-01T00:00,1970-01-01T00:00:05",
"RMB,0,1970-01-01T00:00,1970-01-01T00:00:05",
"Euro,1,1970-01-01T00:00:05,1970-01-01T00:00:10")
"US Dollar,1,102,1970-01-01T00:00,1970-01-01T00:00:05",
"Yen,1,1,1970-01-01T00:00,1970-01-01T00:00:05",
"Euro,1,118,1970-01-01T00:00:15,1970-01-01T00:00:20")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

@Test
def testWindowAggregateOnUpsertSourceWithAllowLateness(): Unit = {
// wait 15 second for late elements
tEnv.getConfig.getConfiguration.set(
TABLE_EXEC_EMIT_ALLOW_LATENESS, Duration.ofSeconds(15))
// emit result without delay after watermark
withLateFireDelay(tEnv.getConfig, Time.of(0, TimeUnit.NANOSECONDS))
val upsertSourceDataId = registerData(upsertSourceCurrencyData)
tEnv.executeSql(
s"""
|CREATE TABLE upsert_currency (
| currency STRING,
| currency_no STRING,
| rate BIGINT,
| currency_time TIMESTAMP(3),
| WATERMARK FOR currency_time AS currency_time - interval '5' SECOND,
| PRIMARY KEY(currency) NOT ENFORCED
|) WITH (
| 'connector' = 'values',
| 'changelog-mode' = 'UA,D',
| 'data-id' = '$upsertSourceDataId'
|)
|""".stripMargin)
val sql =
"""
|SELECT
|currency,
|COUNT(1) AS cnt,
|MAX(rate),
|TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start,
|TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end
|FROM upsert_currency
|GROUP BY currency, TUMBLE(currency_time, INTERVAL '5' SECOND)
|""".stripMargin
val table = tEnv.sqlQuery(sql)
val schema = table.getSchema
val sink = new TestingRetractTableSink().
configure(schema.getFieldNames,
schema.getFieldDataTypes.map(_.nullable()).map(fromDataTypeToTypeInfo))
tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("MySink1", sink)
table.executeInsert("MySink1").await()

val expected = Seq(
"US Dollar,1,104,1970-01-01T00:00,1970-01-01T00:00:05",
"Yen,1,1,1970-01-01T00:00,1970-01-01T00:00:05",
"Euro,1,118,1970-01-01T00:00:15,1970-01-01T00:00:20")
assertEquals(expected.sorted, sink.getRetractResults.sorted)
}

@Test
def testWindowAggregateOnUpsertSourcePushdownWatermark(): Unit = {
val upsertSourceDataId = registerData(upsertSourceCurrencyData)
Expand Down Expand Up @@ -451,8 +501,8 @@ class GroupWindowITCase(mode: StateBackendMode, useTimestampLtz: Boolean)
tEnv.sqlQuery(sql).toAppendStream[Row].addSink(sink)
env.execute()
val expected = Seq(
"1970-01-01T00:00,1970-01-01T00:00:05,104",
"1970-01-01T00:00:05,1970-01-01T00:00:10,118")
"1970-01-01T00:00,1970-01-01T00:00:05,102",
"1970-01-01T00:00:15,1970-01-01T00:00:20,118")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

Expand Down Expand Up @@ -483,9 +533,7 @@ class GroupWindowITCase(mode: StateBackendMode, useTimestampLtz: Boolean)
val expected = Seq(
"Hi,1970-01-01T00:00,1970-01-01T00:00:00.005,1",
"Hallo,1970-01-01T00:00,1970-01-01T00:00:00.005,1",
"Hello,1970-01-01T00:00,1970-01-01T00:00:00.005,0",
"Hello,1970-01-01T00:00:00.005,1970-01-01T00:00:00.010,1",
"Hello world,1970-01-01T00:00:00.005,1970-01-01T00:00:00.010,0",
"Hello world,1970-01-01T00:00:00.015,1970-01-01T00:00:00.020,1",
"null,1970-01-01T00:00:00.030,1970-01-01T00:00:00.035,1")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class RecordCounter implements Serializable {
*
* @return true if input record count is zero, false if not.
*/
abstract boolean recordCountIsZero(RowData acc);
public abstract boolean recordCountIsZero(RowData acc);

/**
* Creates a {@link RecordCounter} depends on the index of count(*). If index is less than zero,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public class AggregateWindowOperator<K, W extends Window> extends WindowOperator
int rowtimeIndex,
boolean produceUpdates,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
super(
windowAggregator,
windowAssigner,
Expand All @@ -95,7 +96,8 @@ public class AggregateWindowOperator<K, W extends Window> extends WindowOperator
rowtimeIndex,
produceUpdates,
allowedLateness,
shiftTimeZone);
shiftTimeZone,
inputCountIndex);
this.aggWindowAggregator = windowAggregator;
this.equaliser = checkNotNull(equaliser);
}
Expand All @@ -113,7 +115,8 @@ public class AggregateWindowOperator<K, W extends Window> extends WindowOperator
int rowtimeIndex,
boolean sendRetraction,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
super(
windowAssigner,
trigger,
Expand All @@ -125,7 +128,8 @@ public class AggregateWindowOperator<K, W extends Window> extends WindowOperator
rowtimeIndex,
sendRetraction,
allowedLateness,
shiftTimeZone);
shiftTimeZone,
inputCountIndex);
this.generatedAggWindowAggregator = generatedAggWindowAggregator;
this.generatedEqualiser = checkNotNull(generatedEqualiser);
}
Expand Down Expand Up @@ -156,35 +160,55 @@ protected void compileGeneratedCode() {
@Override
protected void emitWindowResult(W window) throws Exception {
windowFunction.prepareAggregateAccumulatorForEmit(window);
RowData acc = aggWindowAggregator.getAccumulators();
RowData aggResult = aggWindowAggregator.getValue(window);
if (produceUpdates) {
previousState.setCurrentNamespace(window);
RowData previousAggResult = previousState.value();

// has emitted result for the window
if (previousAggResult != null) {
// current agg is not equal to the previous emitted, should emit retract
if (!equaliser.equals(aggResult, previousAggResult)) {
// send UPDATE_BEFORE
collect(RowKind.UPDATE_BEFORE, (RowData) getCurrentKey(), previousAggResult);
// send UPDATE_AFTER
collect(RowKind.UPDATE_AFTER, (RowData) getCurrentKey(), aggResult);
if (!recordCounter.recordCountIsZero(acc)) {
// has emitted result for the window
if (previousAggResult != null) {
// current agg is not equal to the previous emitted, should emit retract
if (!equaliser.equals(aggResult, previousAggResult)) {
// send UPDATE_BEFORE
collect(
RowKind.UPDATE_BEFORE,
(RowData) getCurrentKey(),
previousAggResult);
// send UPDATE_AFTER
collect(RowKind.UPDATE_AFTER, (RowData) getCurrentKey(), aggResult);
// update previousState
previousState.update(aggResult);
}
// if the previous agg equals to the current agg, no need to send retract and
// accumulate
}
// the first fire for the window, only send INSERT
else {
// send INSERT
collect(RowKind.INSERT, (RowData) getCurrentKey(), aggResult);
// update previousState
previousState.update(aggResult);
}
// if the previous agg equals to the current agg, no need to send retract and
// accumulate
} else {
// has emitted result for the window
// we retracted the last record for this key
if (previousAggResult != null) {
// send DELETE
collect(RowKind.DELETE, (RowData) getCurrentKey(), previousAggResult);
// clear previousState
previousState.clear();
}
// if the counter is zero, no need to send accumulate
}
// the first fire for the window, only send INSERT
else {
} else {
if (!recordCounter.recordCountIsZero(acc)) {
// send INSERT
collect(RowKind.INSERT, (RowData) getCurrentKey(), aggResult);
// update previousState
previousState.update(aggResult);
}
} else {
// send INSERT
collect(RowKind.INSERT, (RowData) getCurrentKey(), aggResult);
// if the counter is zero, no need to send accumulate
// there is no possible skip `if` branch when `produceUpdates` is false
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public class TableAggregateWindowOperator<K, W extends Window> extends WindowOpe
int rowtimeIndex,
boolean produceUpdates,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
super(
windowTableAggregator,
windowAssigner,
Expand All @@ -79,7 +80,8 @@ public class TableAggregateWindowOperator<K, W extends Window> extends WindowOpe
rowtimeIndex,
produceUpdates,
allowedLateness,
shiftTimeZone);
shiftTimeZone,
inputCountIndex);
this.tableAggWindowAggregator = windowTableAggregator;
}

Expand All @@ -95,7 +97,8 @@ public class TableAggregateWindowOperator<K, W extends Window> extends WindowOpe
int rowtimeIndex,
boolean sendRetraction,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
super(
windowAssigner,
trigger,
Expand All @@ -107,7 +110,8 @@ public class TableAggregateWindowOperator<K, W extends Window> extends WindowOpe
rowtimeIndex,
sendRetraction,
allowedLateness,
shiftTimeZone);
shiftTimeZone,
inputCountIndex);
this.generatedTableAggWindowAggregator = generatedTableAggWindowAggregator;
}

Expand All @@ -124,6 +128,9 @@ protected void compileGeneratedCode() {
@Override
protected void emitWindowResult(W window) throws Exception {
windowFunction.prepareAggregateAccumulatorForEmit(window);
tableAggWindowAggregator.emitValue(window, (RowData) getCurrentKey(), collector);
RowData acc = tableAggWindowAggregator.getAccumulators();
if (!recordCounter.recordCountIsZero(acc)) {
tableAggWindowAggregator.emitValue(window, (RowData) getCurrentKey(), collector);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.flink.table.data.util.RowDataUtil;
import org.apache.flink.table.runtime.dataview.PerWindowStateDataViewStore;
import org.apache.flink.table.runtime.generated.NamespaceAggsHandleFunctionBase;
import org.apache.flink.table.runtime.operators.aggregate.RecordCounter;
import org.apache.flink.table.runtime.operators.window.assigners.MergingWindowAssigner;
import org.apache.flink.table.runtime.operators.window.assigners.PanedWindowAssigner;
import org.apache.flink.table.runtime.operators.window.assigners.WindowAssigner;
Expand Down Expand Up @@ -143,6 +144,9 @@ public abstract class WindowOperator<K, W extends Window> extends AbstractStream
*/
private final long allowedLateness;

/** Used to count the number of added and retracted input records. */
protected final RecordCounter recordCounter;

// --------------------------------------------------------------------------------

protected NamespaceAggsHandleFunctionBase<W> windowAggregator;
Expand Down Expand Up @@ -182,7 +186,8 @@ public abstract class WindowOperator<K, W extends Window> extends AbstractStream
int rowtimeIndex,
boolean produceUpdates,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
checkArgument(allowedLateness >= 0);
this.windowAggregator = checkNotNull(windowAggregator);
this.windowAssigner = checkNotNull(windowAssigner);
Expand All @@ -199,6 +204,8 @@ public abstract class WindowOperator<K, W extends Window> extends AbstractStream
checkArgument(!windowAssigner.isEventTime() || rowtimeIndex >= 0);
this.rowtimeIndex = rowtimeIndex;
this.shiftTimeZone = shiftTimeZone;
this.recordCounter = RecordCounter.of(inputCountIndex);

setChainingStrategy(ChainingStrategy.ALWAYS);
}

Expand All @@ -213,7 +220,8 @@ public abstract class WindowOperator<K, W extends Window> extends AbstractStream
int rowtimeIndex,
boolean produceUpdates,
long allowedLateness,
ZoneId shiftTimeZone) {
ZoneId shiftTimeZone,
int inputCountIndex) {
checkArgument(allowedLateness >= 0);
this.windowAssigner = checkNotNull(windowAssigner);
this.trigger = checkNotNull(trigger);
Expand All @@ -229,6 +237,7 @@ public abstract class WindowOperator<K, W extends Window> extends AbstractStream
checkArgument(!windowAssigner.isEventTime() || rowtimeIndex >= 0);
this.rowtimeIndex = rowtimeIndex;
this.shiftTimeZone = shiftTimeZone;
this.recordCounter = RecordCounter.of(inputCountIndex);

setChainingStrategy(ChainingStrategy.ALWAYS);
}
Expand Down
Loading

0 comments on commit c919edf

Please sign in to comment.