Skip to content

Commit

Permalink
Merge pull request apache#11046: [BEAM-9442] Properly handle nullable…
Browse files Browse the repository at this point in the history
… fields in Select
  • Loading branch information
reuvenlax committed Mar 13, 2020
1 parent b59ec15 commit 18948d9
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.collection.ArrayAccess;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.collection.ArrayFactory;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.constant.IntegerConstant;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.constant.NullConstant;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.FieldAccess;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodInvocation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodReturn;
Expand Down Expand Up @@ -149,14 +150,14 @@ private static boolean overridePackage(@Nullable String targetPackage) {
}
};

// This StackManipulation returns onNotNull if the result of readValue is not null. Otherwise it
// returns null.
static class ShortCircuitReturnNull implements StackManipulation {
static class IfNullElse implements StackManipulation {
private final StackManipulation readValue;
private final StackManipulation onNull;
private final StackManipulation onNotNull;

ShortCircuitReturnNull(StackManipulation readValue, StackManipulation onNotNull) {
IfNullElse(StackManipulation readValue, StackManipulation onNull, StackManipulation onNotNull) {
this.readValue = readValue;
this.onNull = onNull;
this.onNotNull = onNotNull;
}

Expand All @@ -173,7 +174,7 @@ public Size apply(MethodVisitor methodVisitor, Context context) {
Label skipLabel = new Label();
methodVisitor.visitJumpInsn(Opcodes.IFNONNULL, label);
size = size.aggregate(new Size(-1, 0));
methodVisitor.visitInsn(Opcodes.ACONST_NULL);
size = size.aggregate(onNull.apply(methodVisitor, context));
methodVisitor.visitJumpInsn(Opcodes.GOTO, skipLabel);
size = size.aggregate(new Size(0, 1));
methodVisitor.visitLabel(label);
Expand All @@ -185,6 +186,14 @@ public Size apply(MethodVisitor methodVisitor, Context context) {
}
}

// This StackManipulation returns onNotNull if the result of readValue is not null. Otherwise it
// returns null.
//
// It is a specialization of IfNullElse whose null branch is fixed to NullConstant.INSTANCE, so
// callers only supply the value to test and the manipulation to apply when it is non-null.
static class ShortCircuitReturnNull extends IfNullElse {
// readValue: manipulation whose result is null-checked.
// onNotNull: manipulation applied when that result is not null.
ShortCircuitReturnNull(StackManipulation readValue, StackManipulation onNotNull) {
super(readValue, NullConstant.INSTANCE, onNotNull);
}
}

// Create a new FieldValueGetter subclass.
@SuppressWarnings("unchecked")
public static DynamicType.Builder<FieldValueGetter> subclassGetterInterface(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.auto.value.AutoValue;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -32,6 +33,8 @@
import org.apache.beam.sdk.schemas.FieldAccessDescriptor.FieldDescriptor.Qualifier;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.IfNullElse;
import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ShortCircuitReturnNull;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.ByteBuddy;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.asm.AsmVisitorWrapper;
Expand All @@ -54,6 +57,7 @@
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.assign.TypeCasting;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.collection.ArrayAccess;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.constant.IntegerConstant;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.constant.NullConstant;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.FieldAccess;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodInvocation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodReturn;
Expand All @@ -67,7 +71,7 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;

public class SelectByteBuddyHelpers {
class SelectByteBuddyHelpers {
private static final ByteBuddy BYTE_BUDDY = new ByteBuddy();
private static final String SELECT_SCHEMA_FIELD_NAME = "OUTPUTSCHEMA";

Expand Down Expand Up @@ -137,7 +141,7 @@ static SchemaAndDescriptor of(Schema schema, FieldAccessDescriptor fieldAccessDe
private static final Map<SchemaAndDescriptor, RowSelector> CACHED_SELECTORS =
Maps.newConcurrentMap();

public static RowSelector getRowSelector(
static RowSelector getRowSelector(
Schema inputSchema, FieldAccessDescriptor fieldAccessDescriptor) {
return CACHED_SELECTORS.computeIfAbsent(
SchemaAndDescriptor.of(inputSchema, fieldAccessDescriptor),
Expand Down Expand Up @@ -234,7 +238,7 @@ private static class ArrayManager {
this.arraySize = arraySize;
}

public StackManipulation createArray() {
StackManipulation createArray() {
return new StackManipulation() {
@Override
public boolean isValid() {
Expand All @@ -252,11 +256,15 @@ public Size apply(MethodVisitor methodVisitor, Context context) {
};
}

public StackManipulation append(StackManipulation valueToWrite) {
StackManipulation append(StackManipulation valueToWrite) {
return store(currentArrayField++, valueToWrite);
}

public StackManipulation store(int arrayIndexToWrite, StackManipulation valueToWrite) {
// Claims the next unused array index and advances the append cursor past it, without
// generating any store. Returns the claimed index so the caller can pass it to store()
// itself, e.g. when the same slot is the target of more than one alternative store.
int reserveSlot() {
return currentArrayField++;
}

StackManipulation store(int arrayIndexToWrite, StackManipulation valueToWrite) {
Preconditions.checkArgument(arrayIndexToWrite < arraySize);
return new StackManipulation() {
@Override
Expand Down Expand Up @@ -374,16 +382,20 @@ public ByteCodeAppender appender(final Target implementationTarget) {
// Selects a field from the current row being selected (the one stored in
// currentSelectRowArg).
private StackManipulation getCurrentRowFieldValue(int i) {
return new StackManipulation.Compound(
localVariables.readVariable(currentSelectRowArg, Row.class),
IntegerConstant.forValue(i),
MethodInvocation.invoke(
ROW_LOADED_TYPE
.getDeclaredMethods()
.filter(
ElementMatchers.named("getValue")
.and(ElementMatchers.takesArguments(int.class)))
.getOnly()));
StackManipulation readRow = localVariables.readVariable(currentSelectRowArg, Row.class);
StackManipulation getValue =
new StackManipulation.Compound(
localVariables.readVariable(currentSelectRowArg, Row.class),
IntegerConstant.forValue(i),
MethodInvocation.invoke(
ROW_LOADED_TYPE
.getDeclaredMethods()
.filter(
ElementMatchers.named("getValue")
.and(ElementMatchers.takesArguments(int.class)))
.getOnly()));

return new ShortCircuitReturnNull(readRow, getValue);
}

// Generate bytecode to select all specified fields from the Row. The current row being selected
Expand Down Expand Up @@ -536,25 +548,43 @@ private StackManipulation.Size processList(
IntStream.range(0, nestedSchema.getFieldCount())
.map(i -> localVariables.createVariable())
.toArray();

// Each field returned in nestedSchema will become its own list in the output. So let's
// iterate and create arrays and store each one in the output.
StackManipulation createAllArrayLists =
new StackManipulation.Compound(
IntStream.range(0, nestedSchema.getFieldCount())
Arrays.stream(localVariablesForArrays)
.mapToObj(
i -> {
v -> {
StackManipulation createArrayList =
new StackManipulation.Compound(
MethodInvocation.invoke(LISTS_NEW_ARRAYLIST),
// Store the ArrayList in a local variable.
Duplication.SINGLE,
localVariables.writeVariable(localVariablesForArrays[i]));
return arrayManager.append(createArrayList);
localVariables.writeVariable(v));
StackManipulation storeNull =
new StackManipulation.Compound(
NullConstant.INSTANCE,
localVariables.writeVariable(v),
NullConstant.INSTANCE);

// Create the array only if the input isn't null. Otherwise store a null value into the
// output array.
int arraySlot = arrayManager.reserveSlot();
return new IfNullElse(
loadFieldValue(fieldId),
arrayManager.store(arraySlot, storeNull),
arrayManager.store(arraySlot, createArrayList));
})
.collect(Collectors.toList()));
size = size.aggregate(createAllArrayLists.apply(methodVisitor, implementationContext));

// If the input variable is null, then don't try to iterate over it.
Label onNullLabel = new Label();
size = size.aggregate(loadFieldValue(fieldId).apply(methodVisitor, implementationContext));
methodVisitor.visitJumpInsn(Opcodes.IFNULL, onNullLabel);
size = size.aggregate(StackSize.SINGLE.toDecreasingSize());

// Now iterate over the value, selecting from each element.
StackManipulation readListIterator =
new StackManipulation.Compound(
Expand All @@ -563,9 +593,9 @@ private StackManipulation.Size processList(
MethodInvocation.invoke(ITERABLE_ITERATOR));
size = size.aggregate(readListIterator.apply(methodVisitor, implementationContext));

// Loop over the entire iterable.
Label startLoopLabel = new Label();
Label exitLoopLabel = new Label();
// Loop over the entire iterable.
methodVisitor.visitLabel(startLoopLabel);

StackManipulation checkTerminationCondition =
Expand Down Expand Up @@ -663,6 +693,7 @@ private StackManipulation.Size processList(
methodVisitor.visitLabel(exitLoopLabel);
// Remove the iterator from the top of the stack.
size = size.aggregate(Removal.SINGLE.apply(methodVisitor, implementationContext));
methodVisitor.visitLabel(onNullLabel);
return size;
}

Expand Down Expand Up @@ -690,20 +721,35 @@ private StackManipulation.Size processMap(
// iterate and create hash maps and store each one in the output.
StackManipulation createAllHashMaps =
new StackManipulation.Compound(
IntStream.range(0, nestedSchema.getFieldCount())
Arrays.stream(localVariablesForMaps)
.mapToObj(
i -> {
v -> {
StackManipulation createHashMap =
new StackManipulation.Compound(
MethodInvocation.invoke(MAPS_NEW_HASHMAP),
// Store the HashMap in a local variable.
Duplication.SINGLE,
localVariables.writeVariable(localVariablesForMaps[i]));
return arrayManager.append(createHashMap);
localVariables.writeVariable(v));
StackManipulation storeNull =
new StackManipulation.Compound(
NullConstant.INSTANCE,
localVariables.writeVariable(v),
NullConstant.INSTANCE);
int arraySlot = arrayManager.reserveSlot();
return new IfNullElse(
loadFieldValue(fieldId),
arrayManager.store(arraySlot, storeNull),
arrayManager.store(arraySlot, createHashMap));
})
.collect(Collectors.toList()));
size = size.aggregate(createAllHashMaps.apply(methodVisitor, implementationContext));

// If the input variable is null, then don't try to iterate over it.
Label onNullLabel = new Label();
size = size.aggregate(loadFieldValue(fieldId).apply(methodVisitor, implementationContext));
methodVisitor.visitJumpInsn(Opcodes.IFNULL, onNullLabel);
size = size.aggregate(StackSize.SINGLE.toDecreasingSize());

// Now iterate over the value, selecting from each element.
StackManipulation readMapEntriesIterator =
new StackManipulation.Compound(
Expand Down Expand Up @@ -822,6 +868,7 @@ private StackManipulation.Size processMap(
methodVisitor.visitLabel(exitLoopLabel);
// Remove the iterator from the top of the stack.
size = size.aggregate(Removal.SINGLE.apply(methodVisitor, implementationContext));
methodVisitor.visitLabel(onNullLabel);
return size;
}

Expand Down
Loading

0 comments on commit 18948d9

Please sign in to comment.