Skip to content

Commit

Permalink
Fix last openRecordGroup not processed in FlatArrayBuilder
Browse files Browse the repository at this point in the history
When record size is a multiple of RECORDS_PER_GROUP, the last group is skipped
  • Loading branch information
jinyangli34 authored and dain committed Apr 22, 2024
1 parent 2cb497d commit 202534a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import static io.trino.operator.VariableWidthData.EMPTY_CHUNK;
import static io.trino.operator.VariableWidthData.POINTER_SIZE;
import static io.trino.operator.VariableWidthData.getChunkOffset;
import static java.lang.Math.toIntExact;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
import static java.util.Objects.checkIndex;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -208,7 +209,8 @@ public void writeAll(BlockBuilder blockBuilder)
recordOffset += recordSize;
}
}
int recordsInOpenGroup = ((int) size) & RECORDS_PER_GROUP_MASK;

int recordsInOpenGroup = toIntExact(size - ((long) closedRecordGroups.size() * RECORDS_PER_GROUP));
int recordOffset = 0;
for (int recordIndex = 0; recordIndex < recordsInOpenGroup; recordIndex++) {
write(openRecordGroup, recordOffset, blockBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import java.util.List;
import java.util.OptionalInt;
import java.util.Random;
import java.util.stream.LongStream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.block.BlockAssertions.createArrayBigintBlock;
import static io.trino.block.BlockAssertions.createBooleansBlock;
import static io.trino.block.BlockAssertions.createLongsBlock;
Expand Down Expand Up @@ -103,6 +105,19 @@ public void testBigInt()
createLongsBlock(new Long[] {2L, 1L, 2L}));
}

@Test
public void testBigIntOnFlatArrayGroupSize()
{
long flatArrayGroupSize = 1 << 10;
long inputCount = flatArrayGroupSize * 2; // data will be split into two pages in assertAggregation
assertAggregation(
FUNCTION_RESOLUTION,
"array_agg",
fromTypes(BIGINT),
LongStream.rangeClosed(1L, inputCount).boxed().collect(toImmutableList()),
createLongsBlock(LongStream.rangeClosed(1L, inputCount).boxed().collect(toImmutableList())));
}

@Test
public void testVarchar()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.arrayagg;

import io.trino.spi.block.Block;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.LongArrayBlockBuilder;
import io.trino.spi.block.ValueBlock;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.lang.invoke.MethodHandle;
import java.util.Optional;
import java.util.stream.IntStream;

import static io.trino.operator.PagesIndex.TestingFactory.TYPE_OPERATORS;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.type.BigintType.BIGINT;

public class TestFlatArrayBuilder
{
private final MethodHandle valueReadFlat = TYPE_OPERATORS.getReadValueOperator(BIGINT, simpleConvention(BLOCK_BUILDER, FLAT));
private final MethodHandle valueWriteFlat = TYPE_OPERATORS.getReadValueOperator(BIGINT, simpleConvention(FLAT_RETURN, VALUE_BLOCK_POSITION_NOT_NULL));

@Test
public void testWriteAll()
{
int size = 1024;
FlatArrayBuilder flatArrayBuilder = new FlatArrayBuilder(BIGINT, valueReadFlat, valueWriteFlat, false);

ValueBlock valueBlock = new LongArrayBlock(size, Optional.empty(),
IntStream.range(0, size).mapToLong(i -> i).toArray());

for (int i = 0; i < size; i++) {
flatArrayBuilder.add(valueBlock, i);
}

LongArrayBlockBuilder blockBuilder = new LongArrayBlockBuilder(null, size);
flatArrayBuilder.writeAll(blockBuilder);

Block block = blockBuilder.build();
Assertions.assertEquals(size, block.getPositionCount());
for (int i = 0; i < size; i++) {
Assertions.assertEquals(i, BIGINT.getLong(block, i));
}
}

@Test
public void testWrite()
{
int size = 1024;
FlatArrayBuilder flatArrayBuilder = new FlatArrayBuilder(BIGINT, valueReadFlat, valueWriteFlat, true);

ValueBlock valueBlock = new LongArrayBlock(size, Optional.empty(),
IntStream.range(0, size).mapToLong(i -> i).toArray());

for (int i = 0; i < size; i++) {
flatArrayBuilder.add(valueBlock, i);