diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/AppendOnlyTopNFunction.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/AppendOnlyTopNFunction.java index 6a3cf5acd8176..82e2fae4fead7 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/AppendOnlyTopNFunction.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/AppendOnlyTopNFunction.java @@ -211,16 +211,22 @@ private void processElementWithoutRowNumber(RowData input, Collector ou if (buffer.getCurrentTopNum() > rankEnd) { Map.Entry> lastEntry = buffer.lastEntry(); RowData lastKey = lastEntry.getKey(); - List lastList = (List) lastEntry.getValue(); + Collection lastList = lastEntry.getValue(); + RowData lastElement = buffer.lastElement(); + int size = lastList.size(); // remove last one - RowData lastElement = lastList.remove(lastList.size() - 1); - if (lastList.isEmpty()) { + if (size <= 1) { buffer.removeAll(lastKey); dataState.remove(lastKey); } else { - dataState.put(lastKey, lastList); + buffer.removeLast(); + // last element has been removed from lastList, we have to copy a new collection + // for lastList to avoid mutating state values, see CopyOnWriteStateMap, + // otherwise, the result might be corrupt. + // don't need to perform a deep copy, because RowData elements will not be updated + dataState.put(lastKey, new ArrayList<>(lastList)); } - if (input.equals(lastElement)) { + if (size == 0 || input.equals(lastElement)) { return; } else { // lastElement shouldn't be null diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/TopNBuffer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/TopNBuffer.java index 5d3ec52dfd4d1..3c8f7f56a8f01 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/TopNBuffer.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/operators/rank/TopNBuffer.java @@ -23,7 +23,7 @@ import java.io.Serializable; import java.util.Collection; import java.util.Comparator; -import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; @@ -94,12 +94,12 @@ public Collection get(RowData sortKey) { } public void remove(RowData sortKey, RowData value) { - Collection list = treeMap.get(sortKey); - if (list != null) { - if (list.remove(value)) { + Collection collection = treeMap.get(sortKey); + if (collection != null) { + if (collection.remove(value)) { currentTopNum -= 1; } - if (list.size() == 0) { + if (collection.size() == 0) { treeMap.remove(sortKey); } } @@ -111,9 +111,9 @@ public void remove(RowData sortKey, RowData value) { * @param sortKey key to remove */ void removeAll(RowData sortKey) { - Collection list = treeMap.get(sortKey); - if (list != null) { - currentTopNum -= list.size(); + Collection collection = treeMap.get(sortKey); + if (collection != null) { + currentTopNum -= collection.size(); treeMap.remove(sortKey); } } @@ -127,20 +127,47 @@ RowData removeLast() { Map.Entry> last = treeMap.lastEntry(); RowData lastElement = null; if (last != null) { - Collection list = last.getValue(); - lastElement = getLastElement(list); - if (lastElement != null) { - if (list.remove(lastElement)) { - currentTopNum -= 1; - } - if (list.size() == 0) { - treeMap.remove(last.getKey()); + Collection collection = last.getValue(); + if (collection != null) { + if (collection instanceof List) { + // optimization for List + List list = (List) collection; + if (!list.isEmpty()) { + lastElement = list.remove(list.size() - 1); + currentTopNum -= 1; + if (list.isEmpty()) { + treeMap.remove(last.getKey()); + } + } + } else { + lastElement = getLastElement(collection); + if (lastElement != null) { + if (collection.remove(lastElement)) { + currentTopNum -= 1; + } + if (collection.size() == 0) { + treeMap.remove(last.getKey()); + } + } } } } return lastElement; } + /** + * Returns the last record of the last Entry in the buffer. + */ + RowData lastElement() { + Map.Entry> last = treeMap.lastEntry(); + RowData lastElement = null; + if (last != null) { + Collection collection = last.getValue(); + lastElement = getLastElement(collection); + } + return lastElement; + } + /** * Gets record which rank is given value. * @@ -150,28 +177,32 @@ RowData removeLast() { RowData getElement(int rank) { int curRank = 0; for (Map.Entry> entry : treeMap.entrySet()) { - Collection list = entry.getValue(); - - if (curRank + list.size() >= rank) { - for (RowData elem : list) { + Collection collection = entry.getValue(); + if (curRank + collection.size() >= rank) { + for (RowData elem : collection) { curRank += 1; if (curRank == rank) { return elem; } } } else { - curRank += list.size(); + curRank += collection.size(); } } return null; } - private RowData getLastElement(Collection list) { + private RowData getLastElement(Collection collection) { RowData element = null; - if (list != null && !list.isEmpty()) { - Iterator iter = list.iterator(); - while (iter.hasNext()) { - element = iter.next(); + if (collection != null && !collection.isEmpty()) { + if (collection instanceof List) { + // optimize for List + List list = (List) collection; + return list.get(list.size() - 1); + } else { + for (RowData data : collection) { + element = data; + } } } return element;