diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index dc276064354b5..a6ecc4500bd8a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -17,17 +17,11 @@ */ package org.apache.beam.sdk.schemas; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; - import com.google.auto.value.AutoValue; import java.io.Serializable; import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; import java.util.Arrays; -import java.util.Collection; -import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; @@ -129,9 +123,13 @@ public static FieldValueTypeInformation forGetter(Method method) { } public static FieldValueTypeInformation forSetter(Method method) { + return forSetter(method, "set"); + } + + public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { String name; - if (method.getName().startsWith("set")) { - name = ReflectUtils.stripPrefix(method.getName(), "set"); + if (method.getName().startsWith(setterPrefix)) { + name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); } else { throw new RuntimeException("Setter has wrong prefix " + method.getName()); } @@ -162,25 +160,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { } @Nullable - private static FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { + static FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. 
- TypeDescriptor componentType = null; - if (valueType.isArray()) { - Type component = valueType.getComponentType().getType(); - if (!component.equals(byte.class)) { - componentType = TypeDescriptor.of(component); - } - } else if (valueType.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - TypeDescriptor> collection = valueType.getSupertype(Iterable.class); - if (collection.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) collection.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - checkArgument(params.length == 1); - componentType = TypeDescriptor.of(params[0]); - } else { - throw new RuntimeException("Collection parameter is not parameterized!"); - } - } + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); if (componentType == null) { return null; } @@ -223,17 +205,7 @@ private static FieldValueTypeInformation getMapValueType(TypeDescriptor typeDesc @SuppressWarnings("unchecked") @Nullable private static FieldValueTypeInformation getMapType(TypeDescriptor valueType, int index) { - TypeDescriptor mapType = null; - if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { - TypeDescriptor> map = valueType.getSupertype(Map.class); - if (map.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) map.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - mapType = TypeDescriptor.of(params[index]); - } else { - throw new RuntimeException("Map type is not parameterized! 
" + map); - } - } + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index b1b8ee80c38f0..61c0d0520d3cd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -21,7 +21,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import java.lang.reflect.Type; -import java.util.Iterator; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; @@ -31,6 +31,9 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.RowWithGetters; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; @@ -107,7 +110,8 @@ private ValueT fromValue( return (ValueT) fromRow((Row) value, (Class) fieldType, typeFactory); } else if (TypeName.ARRAY.equals(type.getTypeName())) { return (ValueT) - fromListValue(type.getCollectionElementType(), (List) value, elementType, typeFactory); + fromCollectionValue( + type.getCollectionElementType(), (Collection) value, elementType, typeFactory); } else if (TypeName.ITERABLE.equals(type.getTypeName())) { return (ValueT) fromIterableValue( @@ -127,25 +131,35 @@ private ValueT fromValue( } } + private static Collection transformCollection( + Collection collection, Function function) 
{ + if (collection instanceof List) { + // For performance reasons if the input is a list, make sure that we produce a list. Otherwise + // Row unwrapping + // is forced to physically copy the collection into a new List object. + return Lists.transform((List) collection, function); + } else { + return Collections2.transform(collection, function); + } + } + @SuppressWarnings("unchecked") - private List fromListValue( + private Collection fromCollectionValue( FieldType elementType, - List rowList, + Collection rowCollection, FieldValueTypeInformation elementTypeInformation, Factory> typeFactory) { - List list = Lists.newArrayList(); - for (ElementT element : rowList) { - list.add( - fromValue( - elementType, - element, - elementTypeInformation.getType().getType(), - elementTypeInformation.getElementType(), - elementTypeInformation.getMapKeyType(), - elementTypeInformation.getMapValueType(), - typeFactory)); - } - return list; + return transformCollection( + rowCollection, + element -> + fromValue( + elementType, + element, + elementTypeInformation.getType().getType(), + elementTypeInformation.getElementType(), + elementTypeInformation.getMapKeyType(), + elementTypeInformation.getMapValueType(), + typeFactory)); } @SuppressWarnings("unchecked") @@ -154,32 +168,17 @@ private Iterable fromIterableValue( Iterable rowIterable, FieldValueTypeInformation elementTypeInformation, Factory> typeFactory) { - return new Iterable() { - @Override - public Iterator iterator() { - return new Iterator() { - Iterator innerIter = rowIterable.iterator(); - - @Override - public boolean hasNext() { - return innerIter.hasNext(); - } - - @Override - public ElementT next() { - ElementT element = innerIter.next(); - return fromValue( + return Iterables.transform( + rowIterable, + element -> + fromValue( elementType, element, elementTypeInformation.getType().getType(), elementTypeInformation.getElementType(), elementTypeInformation.getMapKeyType(), elementTypeInformation.getMapValueType(), - 
typeFactory); - } - }; - } - }; + typeFactory)); } @SuppressWarnings("unchecked") diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index b12ad5e3f9c0a..c9980377f18ff 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -75,6 +75,11 @@ public boolean equals(Object other) { public int hashCode() { return Arrays.hashCode(array); } + + @Override + public String toString() { + return Arrays.toString(array); + } } // A mapping between field names an indices. private final BiMap fieldIndices = HashBiMap.create(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 9604b950bde75..791dafbb40f0b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -31,6 +31,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.SortedMap; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -44,6 +46,7 @@ import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.description.type.TypeDescription; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.description.type.TypeDescription.ForLoadedType; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.DynamicType; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.loading.ClassLoadingStrategy; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.scaffold.InstrumentedType; import 
org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.Implementation; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.ByteCodeAppender; @@ -64,8 +67,12 @@ import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.matcher.ElementMatchers; import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.utility.RandomString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Primitives; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ClassUtils; import org.joda.time.DateTimeZone; @@ -82,7 +89,7 @@ public class ByteBuddyUtils { private static final ForLoadedType CHAR_SEQUENCE_TYPE = new ForLoadedType(CharSequence.class); private static final ForLoadedType INSTANT_TYPE = new ForLoadedType(Instant.class); private static final ForLoadedType DATE_TIME_ZONE_TYPE = new ForLoadedType(DateTimeZone.class); - private static final ForLoadedType LIST_TYPE = new ForLoadedType(List.class); + private static final ForLoadedType COLLECTION_TYPE = new ForLoadedType(Collection.class); private static final ForLoadedType READABLE_INSTANT_TYPE = new ForLoadedType(ReadableInstant.class); private static final ForLoadedType READABLE_PARTIAL_TYPE = @@ -90,6 +97,8 @@ public class ByteBuddyUtils { private static final ForLoadedType OBJECT_TYPE = new ForLoadedType(Object.class); private static final ForLoadedType INTEGER_TYPE = new ForLoadedType(Integer.class); private static final 
ForLoadedType ENUM_TYPE = new ForLoadedType(Enum.class); + private static final ForLoadedType BYTE_BUDDY_UTILS_TYPE = + new ForLoadedType(ByteBuddyUtils.class); /** * A naming strategy for ByteBuddy classes. @@ -98,7 +107,7 @@ public class ByteBuddyUtils { * This way, if the class fields or methods are package private, our generated class can still * access them. */ - static class InjectPackageStrategy extends NamingStrategy.AbstractBase { + public static class InjectPackageStrategy extends NamingStrategy.AbstractBase { /** A resolver for the base name for naming the unnamed type. */ private static final BaseNameResolver baseNameResolver = BaseNameResolver.ForUnnamedType.INSTANCE; @@ -123,6 +132,30 @@ protected String name(TypeDescription superClass) { } }; + // Create a new FieldValueGetter subclass. + @SuppressWarnings("unchecked") + static DynamicType.Builder subclassGetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { + TypeDescription.Generic getterGenericType = + TypeDescription.Generic.Builder.parameterizedType( + FieldValueGetter.class, objectType, fieldType) + .build(); + return (DynamicType.Builder) + byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); + } + + // Create a new FieldValueSetter subclass. + @SuppressWarnings("unchecked") + static DynamicType.Builder subclassSetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { + TypeDescription.Generic setterGenericType = + TypeDescription.Generic.Builder.parameterizedType( + FieldValueSetter.class, objectType, fieldType) + .build(); + return (DynamicType.Builder) + byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); + } + public interface TypeConversionsFactory { TypeConversion createTypeConversion(boolean returnRawTypes); @@ -148,30 +181,6 @@ public TypeConversion createSetterConversions(StackManipulati } } - // Create a new FieldValueGetter subclass. 
- @SuppressWarnings("unchecked") - static DynamicType.Builder subclassGetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { - TypeDescription.Generic getterGenericType = - TypeDescription.Generic.Builder.parameterizedType( - FieldValueGetter.class, objectType, fieldType) - .build(); - return (DynamicType.Builder) - byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); - } - - // Create a new FieldValueSetter subclass. - @SuppressWarnings("unchecked") - static DynamicType.Builder subclassSetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { - TypeDescription.Generic setterGenericType = - TypeDescription.Generic.Builder.parameterizedType( - FieldValueSetter.class, objectType, fieldType) - .build(); - return (DynamicType.Builder) - byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); - } - // Base class used below to convert types. @SuppressWarnings("unchecked") public abstract static class TypeConversion { @@ -195,7 +204,9 @@ public T convert(TypeDescriptor typeDescriptor) { } else if (typeDescriptor.getRawType().isEnum()) { return convertEnum(typeDescriptor); } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Collection.class))) { + if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(List.class))) { + return convertList(typeDescriptor); + } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Collection.class))) { return convertCollection(typeDescriptor); } else { return convertIterable(typeDescriptor); @@ -211,6 +222,8 @@ public T convert(TypeDescriptor typeDescriptor) { protected abstract T convertCollection(TypeDescriptor type); + protected abstract T convertList(TypeDescriptor type); + protected abstract T convertMap(TypeDescriptor type); protected abstract T convertDateTime(TypeDescriptor type); @@ -253,18 +266,26 @@ protected ConvertType(boolean returnRawTypes) { @Override 
protected Type convertArray(TypeDescriptor type) { - TypeDescriptor ret = createListType(type); + TypeDescriptor ret = createCollectionType(type.getComponentType()); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertCollection(TypeDescriptor type) { - return Collection.class; + TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + return returnRawTypes ? ret.getRawType() : ret.getType(); + } + + @Override + protected Type convertList(TypeDescriptor type) { + TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - return Iterable.class; + TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override @@ -305,11 +326,190 @@ protected Type convertDefault(TypeDescriptor type) { } @SuppressWarnings("unchecked") - private TypeDescriptor> createListType(TypeDescriptor type) { - TypeDescriptor componentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(type.getComponentType().getRawType())); - return new TypeDescriptor>() {}.where( - new TypeParameter() {}, componentType); + private TypeDescriptor> createCollectionType( + TypeDescriptor componentType) { + TypeDescriptor wrappedComponentType = + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + return new TypeDescriptor>() {}.where( + new TypeParameter() {}, wrappedComponentType); + } + + @SuppressWarnings("unchecked") + private TypeDescriptor> createIterableType( + TypeDescriptor componentType) { + TypeDescriptor wrappedComponentType = + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + return new TypeDescriptor>() {}.where( + new TypeParameter() {}, wrappedComponentType); + } + } + + private static final ByteBuddy BYTE_BUDDY = new 
ByteBuddy(); + + // When processing a container (e.g. List) we need to recursively process the element type. + // This function + // generates a subclass of Function that can be used to recursively transform each element of the + // container. + static Class createCollectionTransformFunction( + Type fromType, Type toType, Function convertElement) { + // Generate a TypeDescription for the class we want to generate. + TypeDescription.Generic functionGenericType = + TypeDescription.Generic.Builder.parameterizedType( + Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType)) + .build(); + + DynamicType.Builder builder = + (DynamicType.Builder) + BYTE_BUDDY + .subclass(functionGenericType) + .method(ElementMatchers.named("apply")) + .intercept( + new Implementation() { + @Override + public ByteCodeAppender appender(Target target) { + return (methodVisitor, implementationContext, instrumentedMethod) -> { + // this + method parameters. + int numLocals = 1 + instrumentedMethod.getParameters().size(); + + StackManipulation readValue = MethodVariableAccess.REFERENCE.loadFrom(1); + StackManipulation stackManipulation = + new StackManipulation.Compound( + convertElement.apply(readValue), MethodReturn.REFERENCE); + + StackManipulation.Size size = + stackManipulation.apply(methodVisitor, implementationContext); + return new Size(size.getMaximalSize(), numLocals); + }; + } + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } + }); + + return builder + .make() + .load(ByteBuddyUtils.class.getClassLoader(), ClassLoadingStrategy.Default.INJECTION) + .getLoaded(); + } + + // A function to transform a container, special casing List and Collection types. This is used in + // byte-buddy + // generated code. 
+ public static Iterable transformContainer( + Iterable iterable, Function function) { + if (iterable instanceof List) { + return Lists.transform((List) iterable, function); + } else if (iterable instanceof Collection) { + return Collections2.transform((Collection) iterable, function); + } else { + return Iterables.transform(iterable, function); + } + } + + static StackManipulation createTransformingContainer( + ForLoadedType functionType, StackManipulation readValue) { + StackManipulation stackManipulation = + new Compound( + readValue, + TypeCreation.of(functionType), + Duplication.SINGLE, + MethodInvocation.invoke( + functionType + .getDeclaredMethods() + .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + MethodInvocation.invoke( + BYTE_BUDDY_UTILS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("transformContainer")) + .getOnly())); + return stackManipulation; + } + + public static TransformingMap getTransformingMap( + Map sourceMap, Function keyFunction, Function valueFunction) { + return new TransformingMap<>(sourceMap, keyFunction, valueFunction); + } + + public static class TransformingMap implements Map { + private final Map delegateMap; + + public TransformingMap( + Map sourceMap, Function keyFunction, Function valueFunction) { + if (sourceMap instanceof SortedMap) { + delegateMap = + (Map) + Maps.newTreeMap(); // We don't support copying the comparator. Makes no sense if key + // is changing. 
+ } else { + delegateMap = Maps.newHashMap(); + } + for (Map.Entry entry : sourceMap.entrySet()) { + delegateMap.put(keyFunction.apply(entry.getKey()), valueFunction.apply(entry.getValue())); + } + } + + @Override + public int size() { + return delegateMap.size(); + } + + @Override + public boolean isEmpty() { + return delegateMap.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return delegateMap.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return delegateMap.containsValue(value); + } + + @Override + public V2 get(Object key) { + return delegateMap.get(key); + } + + @Override + public V2 put(K2 key, V2 value) { + return delegateMap.put(key, value); + } + + @Override + public V2 remove(Object key) { + return delegateMap.remove(key); + } + + @Override + public void putAll(Map m) { + delegateMap.putAll(m); + } + + @Override + public void clear() { + delegateMap.clear(); + ; + } + + @Override + public Set keySet() { + return delegateMap.keySet(); + } + + @Override + public Collection values() { + return delegateMap.values(); + } + + @Override + public Set> entrySet() { + return delegateMap.entrySet(); } } @@ -338,46 +538,153 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isComponentTypePrimitive ? Arrays.asList(ArrayUtils.toObject(value)) // : Arrays.asList(value); - ForLoadedType loadedType = new ForLoadedType(type.getRawType()); - StackManipulation stackManipulation = readValue; + TypeDescriptor componentType = type.getComponentType(); + ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType()); + StackManipulation readArrayValue = readValue; // Row always expects to get an Iterable back for array types. Wrap this array into a // List using Arrays.asList before returning. 
- if (loadedType.getComponentType().isPrimitive()) { + if (loadedArrayType.getComponentType().isPrimitive()) { // Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject. - stackManipulation = + readArrayValue = new Compound( - stackManipulation, + readArrayValue, MethodInvocation.invoke( ARRAY_UTILS_TYPE .getDeclaredMethods() .filter( ElementMatchers.isStatic() .and(ElementMatchers.named("toObject")) - .and(ElementMatchers.takesArguments(loadedType))) + .and(ElementMatchers.takesArguments(loadedArrayType))) .getOnly())); + + componentType = TypeDescriptor.of(Primitives.wrap(componentType.getRawType())); + } + // Now convert to a List object. + StackManipulation readListValue = + new Compound( + readArrayValue, + MethodInvocation.invoke( + ARRAYS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.isStatic().and(ElementMatchers.named("asList"))) + .getOnly())); + + // Generate a SerializableFunction to convert the element-type objects. + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + if (!finalComponentType.hasUnresolvedParameters()) { + Type convertedComponentType = + getFactory().createTypeConversion(true).convert(componentType); + ForLoadedType functionType = + new ForLoadedType( + createCollectionTransformFunction( + componentType.getRawType(), + convertedComponentType, + (s) -> getFactory().createGetterConversions(s).convert(finalComponentType))); + return createTransformingContainer(functionType, readListValue); + } else { + return readListValue; } - return new Compound( - stackManipulation, - MethodInvocation.invoke( - ARRAYS_TYPE - .getDeclaredMethods() - .filter(ElementMatchers.isStatic().and(ElementMatchers.named("asList"))) - .getOnly())); } @Override protected StackManipulation convertIterable(TypeDescriptor type) { - return readValue; + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + Type convertedComponentType = 
getFactory().createTypeConversion(true).convert(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + if (!finalComponentType.hasUnresolvedParameters()) { + ForLoadedType functionType = + new ForLoadedType( + createCollectionTransformFunction( + componentType.getRawType(), + convertedComponentType, + (s) -> getFactory().createGetterConversions(s).convert(finalComponentType))); + return createTransformingContainer(functionType, readValue); + } else { + return readValue; + } } @Override protected StackManipulation convertCollection(TypeDescriptor type) { - return readValue; + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + if (!finalComponentType.hasUnresolvedParameters()) { + ForLoadedType functionType = + new ForLoadedType( + createCollectionTransformFunction( + componentType.getRawType(), + convertedComponentType, + (s) -> getFactory().createGetterConversions(s).convert(finalComponentType))); + return createTransformingContainer(functionType, readValue); + } else { + return readValue; + } + } + + @Override + protected StackManipulation convertList(TypeDescriptor type) { + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + if (!finalComponentType.hasUnresolvedParameters()) { + ForLoadedType functionType = + new ForLoadedType( + createCollectionTransformFunction( + componentType.getRawType(), + convertedComponentType, + (s) -> getFactory().createGetterConversions(s).convert(finalComponentType))); + return createTransformingContainer(functionType, readValue); + } else { + return readValue; + } } 
@Override protected StackManipulation convertMap(TypeDescriptor type) { - return readValue; + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + + Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); + Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); + + if (!keyType.hasUnresolvedParameters() && !valueType.hasUnresolvedParameters()) { + ForLoadedType keyFunctionType = + new ForLoadedType( + createCollectionTransformFunction( + keyType.getRawType(), + convertedKeyType, + (s) -> getFactory().createGetterConversions(s).convert(keyType))); + ForLoadedType valueFunctionType = + new ForLoadedType( + createCollectionTransformFunction( + valueType.getRawType(), + convertedValueType, + (s) -> getFactory().createGetterConversions(s).convert(valueType))); + return new Compound( + readValue, + TypeCreation.of(keyFunctionType), + Duplication.SINGLE, + MethodInvocation.invoke( + keyFunctionType + .getDeclaredMethods() + .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + TypeCreation.of(valueFunctionType), + Duplication.SINGLE, + MethodInvocation.invoke( + valueFunctionType + .getDeclaredMethods() + .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + MethodInvocation.invoke( + BYTE_BUDDY_UTILS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("getTransformingMap")) + .getOnly())); + } else { + return readValue; + } } @Override @@ -529,7 +836,7 @@ protected StackManipulation convertDefault(TypeDescriptor type) { * there. This class generates code to convert between these types. 
*/ public static class ConvertValueForSetter extends TypeConversion { - StackManipulation readValue; + protected StackManipulation readValue; protected ConvertValueForSetter(StackManipulation readValue) { this.readValue = readValue; @@ -553,18 +860,31 @@ protected StackManipulation convertArray(TypeDescriptor type) { .build() .asErasure(); + Type rowElementType = + getFactory().createTypeConversion(false).convert(type.getComponentType()); + final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType()); + if (!arrayElementType.hasUnresolvedParameters()) { + ForLoadedType conversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowElementType).getRawType(), + Primitives.wrap(arrayElementType.getRawType()), + (s) -> getFactory().createSetterConversions(s).convert(arrayElementType))); + readValue = createTransformingContainer(conversionFunction, readValue); + } + // Extract an array from the collection. StackManipulation stackManipulation = new Compound( readValue, - TypeCasting.to(LIST_TYPE), + TypeCasting.to(COLLECTION_TYPE), // Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack // before // calling toArray. 
ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType()) .withValues(Collections.emptyList()), MethodInvocation.invoke( - LIST_TYPE + COLLECTION_TYPE .getDeclaredMethods() .filter( ElementMatchers.named("toArray").and(ElementMatchers.takesArguments(1))) @@ -591,16 +911,128 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - return readValue; + Type rowElementType = + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + if (!iterableElementType.hasUnresolvedParameters()) { + ForLoadedType conversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowElementType).getRawType(), + iterableElementType.getRawType(), + (s) -> getFactory().createSetterConversions(s).convert(iterableElementType))); + StackManipulation transformedContainer = + createTransformingContainer(conversionFunction, readValue); + return transformedContainer; + } else { + return readValue; + } } @Override protected StackManipulation convertCollection(TypeDescriptor type) { - return readValue; + Type rowElementType = + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + + if (!collectionElementType.hasUnresolvedParameters()) { + ForLoadedType conversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowElementType).getRawType(), + collectionElementType.getRawType(), + (s) -> getFactory().createSetterConversions(s).convert(collectionElementType))); + StackManipulation transformedContainer = + createTransformingContainer(conversionFunction, readValue); + return transformedContainer; + } else { + return readValue; + } + } + + @Override 
+ protected StackManipulation convertList(TypeDescriptor type) { + Type rowElementType = + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + + if (!collectionElementType.hasUnresolvedParameters()) { + ForLoadedType conversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowElementType).getRawType(), + collectionElementType.getRawType(), + (s) -> getFactory().createSetterConversions(s).convert(collectionElementType))); + readValue = createTransformingContainer(conversionFunction, readValue); + } + // TODO: Don't copy if already a list! + StackManipulation transformedList = + new Compound( + readValue, + MethodInvocation.invoke( + new ForLoadedType(Lists.class) + .getDeclaredMethods() + .filter( + ElementMatchers.named("newArrayList") + .and(ElementMatchers.takesArguments(Iterable.class))) + .getOnly())); + return transformedList; } @Override protected StackManipulation convertMap(TypeDescriptor type) { + Type rowKeyType = + getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); + Type rowValueType = + getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); + final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + + if (!keyElementType.hasUnresolvedParameters() + && !valueElementType.hasUnresolvedParameters()) { + ForLoadedType keyConversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowKeyType).getRawType(), + keyElementType.getRawType(), + (s) -> getFactory().createSetterConversions(s).convert(keyElementType))); + ForLoadedType valueConversionFunction = + new ForLoadedType( + createCollectionTransformFunction( + TypeDescriptor.of(rowValueType).getRawType(), + 
valueElementType.getRawType(), + (s) -> getFactory().createSetterConversions(s).convert(valueElementType))); + readValue = + new Compound( + readValue, + TypeCreation.of(keyConversionFunction), + Duplication.SINGLE, + MethodInvocation.invoke( + keyConversionFunction + .getDeclaredMethods() + .filter( + ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + TypeCreation.of(valueConversionFunction), + Duplication.SINGLE, + MethodInvocation.invoke( + valueConversionFunction + .getDeclaredMethods() + .filter( + ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + MethodInvocation.invoke( + BYTE_BUDDY_UTILS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("getTransformingMap")) + .getOnly())); + } return readValue; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index b9f1ae5980f47..d56f0bd152f4d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -23,8 +23,11 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.security.InvalidParameterException; import java.util.Arrays; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -33,16 +36,19 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Primitives; /** A set of reflection helper methods. */ public class ReflectUtils { - static class ClassWithSchema { + /** Represents a class and a schema. */ + public static class ClassWithSchema { private final Class clazz; private final Schema schema; - ClassWithSchema(Class clazz, Schema schema) { + public ClassWithSchema(Class clazz, Schema schema) { this.clazz = clazz; this.schema = schema; } @@ -78,6 +84,9 @@ public static List getMethods(Class clazz) { clazz, c -> { return Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. .filter(m -> !Modifier.isPrivate(m.getModifiers())) .filter(m -> !Modifier.isProtected(m.getModifiers())) .filter(m -> !Modifier.isStatic(m.getModifiers())) @@ -183,4 +192,49 @@ public static String stripGetterPrefix(String method) { public static String stripSetterPrefix(String method) { return stripPrefix(method, "set"); } + + /** For an array T[] or a subclass of Iterable, return a TypeDescriptor describing T. */ + @Nullable + public static TypeDescriptor getIterableComponentType(TypeDescriptor valueType) { + TypeDescriptor componentType = null; + if (valueType.isArray()) { + Type component = valueType.getComponentType().getType(); + if (!component.equals(byte.class)) { + // Byte arrays are special cased since we have a schema type corresponding to them. 
+ componentType = TypeDescriptor.of(component); + } + } else if (valueType.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { + TypeDescriptor> collection = valueType.getSupertype(Iterable.class); + if (collection.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) collection.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + checkArgument(params.length == 1); + componentType = TypeDescriptor.of(params[0]); + } else { + throw new RuntimeException("Collection parameter is not parameterized!"); + } + } + return componentType; + } + + public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) { + TypeDescriptor mapType = null; + if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { + TypeDescriptor> map = valueType.getSupertype(Map.class); + if (map.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) map.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + mapType = TypeDescriptor.of(params[index]); + } else { + throw new RuntimeException("Map type is not parameterized! " + map); + } + } + return mapType; + } + + public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) { + return typeDescriptor.getRawType().isPrimitive() + ? 
TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType())) + : typeDescriptor; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index d437b063ef145..be28467b97b2d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -25,6 +25,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -187,7 +188,7 @@ public Boolean getBoolean(String fieldName) { * match. */ @Nullable - public List getArray(String fieldName) { + public Collection getArray(String fieldName) { return getArray(getSchema().indexOf(fieldName)); } @@ -332,7 +333,7 @@ public Boolean getBoolean(int idx) { * match. */ @Nullable - public List getArray(int idx) { + public Collection getArray(int idx) { return getValue(idx); } @@ -421,8 +422,8 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType) } else if (fieldType.getTypeName() == Schema.TypeName.BYTES) { return Arrays.equals((byte[]) a, (byte[]) b); } else if (fieldType.getTypeName() == TypeName.ARRAY) { - return deepEqualsForList( - (List) a, (List) b, fieldType.getCollectionElementType()); + return deepEqualsForCollection( + (Collection) a, (Collection) b, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == TypeName.ITERABLE) { return deepEqualsForIterable( (Iterable) a, (Iterable) b, fieldType.getCollectionElementType()); @@ -493,7 +494,8 @@ static int deepHashCodeForMap( return h; } - static boolean deepEqualsForList(List a, List b, Schema.FieldType elementType) { + static boolean deepEqualsForCollection( + Collection a, Collection b, Schema.FieldType elementType) { if (a == b) { return true; } @@ -584,7 +586,7 @@ public Builder addValues(Object... 
values) { return addValues(Arrays.asList(values)); } - public Builder addArray(List values) { + public Builder addArray(Collection values) { this.values.add(values); return this; } @@ -662,16 +664,16 @@ private Object verifyLogicalType(Object value, LogicalType logicalType, String f private List verifyArray( Object value, FieldType collectionElementType, String fieldName) { boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof List)) { + if (!(value instanceof Collection)) { throw new IllegalArgumentException( String.format( - "For field name %s and array type expected List class. Instead " + "For field name %s and array type expected Collection class. Instead " + "class type was %s.", fieldName, value.getClass())); } - List valueList = (List) value; - List verifiedList = Lists.newArrayListWithCapacity(valueList.size()); - for (Object listValue : valueList) { + Collection valueCollection = (Collection) value; + List verifiedList = Lists.newArrayListWithCapacity(valueCollection.size()); + for (Object listValue : valueCollection) { if (listValue == null) { if (!collectionElementTypeNullable) { throw new IllegalArgumentException( @@ -696,8 +698,8 @@ private Iterable verifyIterable( + "class type was %s.", fieldName, value.getClass())); } - Iterable valueList = (Iterable) value; - for (Object listValue : valueList) { + Iterable valueIterable = (Iterable) value; + for (Object listValue : valueIterable) { if (listValue == null) { if (!collectionElementTypeNullable) { throw new IllegalArgumentException( @@ -708,7 +710,7 @@ private Iterable verifyIterable( verify(listValue, collectionElementType, fieldName); } } - return valueList; + return valueIterable; } private Map verifyMap( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 0d787311c6b7f..ebf59b9216fcf 100644 --- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.Iterator; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; @@ -29,6 +29,8 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; @@ -44,7 +46,7 @@ public class RowWithGetters extends Row { private final Object getterTarget; private final List getters; - private final Map cachedLists = Maps.newHashMap(); + private final Map cachedCollections = Maps.newHashMap(); private final Map cachedIterables = Maps.newHashMap(); private final Map cachedMaps = Maps.newHashMap(); @@ -69,36 +71,22 @@ public T getValue(int fieldIdx) { return fieldValue != null ? getValue(type, fieldValue, fieldIdx) : null; } - private List getListValue(FieldType elementType, Object fieldValue) { - Iterable iterable = (Iterable) fieldValue; - List list = Lists.newArrayList(); - for (Object o : iterable) { - list.add(getValue(elementType, o, null)); + private Collection getCollectionValue(FieldType elementType, Object fieldValue) { + Collection collection = (Collection) fieldValue; + if (collection instanceof List) { + // For performance reasons if the input is a list, make sure that we produce a list. Otherwise + // Row forwarding + // is forced to physically copy the collection into a new List object. 
+ return Lists.transform((List) collection, v -> getValue(elementType, v, null)); + } else { + return Collections2.transform(collection, v -> getValue(elementType, v, null)); } - return list; } private Iterable getIterableValue(FieldType elementType, Object fieldValue) { Iterable iterable = (Iterable) fieldValue; // Wrap the iterable to avoid having to materialize the entire collection. - return new Iterable() { - @Override - public Iterator iterator() { - return new Iterator() { - Iterator iterator = iterable.iterator(); - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - - @Override - public Object next() { - return getValue(elementType, iterator.next(), null); - } - }; - } - }; + return Iterables.transform(iterable, v -> getValue(elementType, v, null)); } private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { @@ -117,9 +105,9 @@ private T getValue(FieldType type, Object fieldValue, @Nullable Integer cach } else if (type.getTypeName().equals(TypeName.ARRAY)) { return cacheKey != null ? (T) - cachedLists.computeIfAbsent( - cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) - : (T) getListValue(type.getCollectionElementType(), fieldValue); + cachedCollections.computeIfAbsent( + cacheKey, i -> getCollectionValue(type.getCollectionElementType(), fieldValue)) + : (T) getCollectionValue(type.getCollectionElementType(), fieldValue); } else if (type.getTypeName().equals(TypeName.ITERABLE)) { return cacheKey != null ? 
(T) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 6ceac8b907be9..feb51db325793 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ITERABLE_BEAM_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.NESTED_ARRAYS_BEAM_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.NESTED_ARRAY_BEAN_SCHEMA; @@ -30,11 +31,13 @@ import static org.junit.Assert.assertTrue; import java.math.BigDecimal; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; +import org.apache.beam.sdk.schemas.utils.TestJavaBeans.ArrayOfByteArray; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.IterableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.MismatchingNullableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.NestedArrayBean; @@ -272,7 +275,7 @@ public void testRecursiveArrayGetters() throws NoSuchSchemaException { NestedArrayBean bean = new NestedArrayBean(simple1, simple2, simple3); Row row = registry.getToRowFunction(NestedArrayBean.class).apply(bean); - List rows = row.getArray("beans"); + List rows = (List) row.getArray("beans"); assertSame(simple1, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(0))); assertSame(simple2, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(1))); assertSame(simple3, registry.getFromRowFunction(SimpleBean.class).apply(rows.get(2))); @@ -422,4 +425,38 @@ public void 
testFromRowIterable() throws NoSuchSchemaException { list.add("three"); assertEquals(list, Lists.newArrayList(converted.getStrings())); } + + @Test + public void testToRowArrayOfBytes() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema schema = registry.getSchema(ArrayOfByteArray.class); + SchemaTestUtils.assertSchemaEquivalent(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA, schema); + + ArrayOfByteArray arrayOfByteArray = + new ArrayOfByteArray( + ImmutableList.of(ByteBuffer.wrap(BYTE_ARRAY), ByteBuffer.wrap(BYTE_ARRAY))); + Row expectedRow = + Row.withSchema(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA) + .addArray(ImmutableList.of(BYTE_ARRAY, BYTE_ARRAY)) + .build(); + Row converted = registry.getToRowFunction(ArrayOfByteArray.class).apply(arrayOfByteArray); + assertEquals(expectedRow, converted); + } + + @Test + public void testFromRowArrayOfBytes() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema schema = registry.getSchema(ArrayOfByteArray.class); + SchemaTestUtils.assertSchemaEquivalent(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA, schema); + + ArrayOfByteArray expectedArrayOfByteArray = + new ArrayOfByteArray( + ImmutableList.of(ByteBuffer.wrap(BYTE_ARRAY), ByteBuffer.wrap(BYTE_ARRAY))); + Row row = + Row.withSchema(ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA) + .addArray(ImmutableList.of(BYTE_ARRAY, BYTE_ARRAY)) + .build(); + ArrayOfByteArray converted = registry.getFromRowFunction(ArrayOfByteArray.class).apply(row); + assertEquals(expectedArrayOfByteArray, converted); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 992c6bb62286d..4134a577ac27b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -305,7 +305,7 @@ public void 
testRecursiveArrayGetters() throws NoSuchSchemaException { NestedArrayPOJO pojo = new NestedArrayPOJO(simple1, simple2, simple3); Row row = registry.getToRowFunction(NestedArrayPOJO.class).apply(pojo); - List rows = row.getArray("pojos"); + List rows = (List) row.getArray("pojos"); assertSame(simple1, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(0))); assertSame(simple2, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(1))); assertSame(simple3, registry.getFromRowFunction(SimplePOJO.class).apply(rows.get(2))); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java index 73aa0855245ad..81517afcdf146 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CoGroupTest.java @@ -23,6 +23,7 @@ import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.junit.Assert.assertThat; +import java.util.Collection; import java.util.List; import org.apache.beam.sdk.TestUtils.KvMatcher; import org.apache.beam.sdk.schemas.Schema; @@ -691,7 +692,7 @@ private static Void containsJoinedFields( Schema valueSchema = value.getSchema(); for (int i = 0; i < valueSchema.getFieldCount(); ++i) { assertEquals(TypeName.ARRAY, valueSchema.getField(i).getType().getTypeName()); - fieldMatchers.add(new ArrayFieldMatchesAnyOrder(i, value.getArray(i))); + fieldMatchers.add(new ArrayFieldMatchesAnyOrder(i, (List) value.getArray(i))); } matchers.add( KvMatcher.isKv(equalTo(row.getKey()), allOf(fieldMatchers.toArray(new Matcher[0])))); @@ -715,7 +716,7 @@ public boolean matches(Object item) { return false; } Row row = (Row) item; - List actual = row.getArray(fieldIndex); + Collection actual = row.getArray(fieldIndex); return containsInAnyOrder(expected).matches(actual); } diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java index 15d1379ff815e..6deab6d2c564f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java @@ -424,12 +424,17 @@ public boolean equals(Object o) { return false; } PartialRowMultipleArray that = (PartialRowMultipleArray) o; - return Objects.equals(field1, that.field1); + return Objects.equals(field1, that.field1) && Objects.equals(field3, that.field3); } @Override public int hashCode() { - return Objects.hash(field1); + return Objects.hash(field1, field3); + } + + @Override + public String toString() { + return "PartialRowMultipleArray{" + "field1=" + field1 + ", field3=" + field3 + '}'; } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index f137477d79b9a..32cf264591c9f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -972,4 +972,45 @@ public int hashCode() { /** The schema for {@link NestedArrayBean}. * */ public static final Schema ITERABLE_BEAM_SCHEMA = Schema.builder().addIterableField("strings", FieldType.STRING).build(); + + /** A bean containing an Array of ByteArray. 
* */ + @DefaultSchema(JavaBeanSchema.class) + public static class ArrayOfByteArray { + private List byteBuffers; + + public ArrayOfByteArray(List byteBuffers) { + this.byteBuffers = byteBuffers; + } + + public ArrayOfByteArray() {} + + public List getByteBuffers() { + return byteBuffers; + } + + public void setByteBuffers(List byteBuffers) { + this.byteBuffers = byteBuffers; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayOfByteArray that = (ArrayOfByteArray) o; + return Objects.equals(byteBuffers, that.byteBuffers); + } + + @Override + public int hashCode() { + return Objects.hash(byteBuffers); + } + } + + /** The schema for {@link ArrayOfByteArray}. * */ + public static final Schema ARRAY_OF_BYTE_ARRAY_BEAM_SCHEMA = + Schema.builder().addArrayField("byteBuffers", FieldType.BYTES).build(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java index 1263b3d0acf24..9f27a4ad93756 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rel; -import java.util.List; +import java.util.Collection; import javax.annotation.Nullable; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel; import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats; @@ -129,7 +129,7 @@ private UnnestFn(Schema outputSchema, int unnestIndex) { @ProcessElement public void process(@Element Row row, OutputReceiver out) { - @Nullable List rawValues = row.getArray(unnestIndex); + @Nullable Collection rawValues = row.getArray(unnestIndex); if (rawValues == null) {
return;