diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java index 35ba2ff19dd1b..ef5546069d4a9 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java @@ -238,7 +238,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException(); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 2009f90d1a427..7a7bc60d8b155 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -121,7 +121,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException("Not supported in SplittableDoFn"); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index bbe27308ae849..8fd2ad3acf143 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -39,6 +39,7 @@ import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; 
import org.apache.beam.sdk.transforms.DoFnOutputReceivers; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; @@ -301,7 +302,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException( "Element parameters are not supported outside of @ProcessElement method."); } @@ -415,7 +416,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException( "Cannot access element outside of @ProcessElement method."); } @@ -631,9 +632,9 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { - Row row = schemaCoder.getToRowFunction().apply(element()); - return doFnSchemaInformation.getElementParameterSchema().getFromRowFunction().apply(row); + public Object schemaElement(int index) { + SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index); + return converter.apply(element()); } @Override @@ -781,7 +782,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException("Element parameters are not supported."); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java index 6f565acaeaa6c..8074500c6ecf6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java @@ -78,22 +78,52 @@ void 
registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid @Nullable @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = providers.get(typeDescriptor); - return (schemaProvider != null) ? schemaProvider.schemaFor(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return schemaProvider.schemaFor(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } @Nullable @Override public SerializableFunction toRowFunction(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = providers.get(typeDescriptor); - return (schemaProvider != null) ? schemaProvider.toRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.toRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } @Nullable @Override public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = providers.get(typeDescriptor); - return (schemaProvider != null) ? 
schemaProvider.fromRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.fromRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java index 00c93767476bc..d3b7d10b34d74 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; +import java.io.Serializable; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -71,46 +72,64 @@ * delegates to that provider. 
*/ class DefaultSchemaProvider implements SchemaProvider { - final Map cachedProviders = Maps.newConcurrentMap(); + final Map cachedProviders = Maps.newConcurrentMap(); + + private static final class ProviderAndDescriptor implements Serializable { + final SchemaProvider schemaProvider; + final TypeDescriptor typeDescriptor; + + public ProviderAndDescriptor( + SchemaProvider schemaProvider, TypeDescriptor typeDescriptor) { + this.schemaProvider = schemaProvider; + this.typeDescriptor = typeDescriptor; + } + } @Nullable - private SchemaProvider getSchemaProvider(TypeDescriptor typeDescriptor) { + private ProviderAndDescriptor getSchemaProvider(TypeDescriptor typeDescriptor) { return cachedProviders.computeIfAbsent( typeDescriptor, type -> { Class clazz = type.getRawType(); - DefaultSchema annotation = clazz.getAnnotation(DefaultSchema.class); - if (annotation == null) { - return null; - } - Class providerClass = annotation.value(); - checkArgument( - providerClass != null, - "Type " + type + " has a @DefaultSchema annotation with a null argument."); + do { + DefaultSchema annotation = clazz.getAnnotation(DefaultSchema.class); + if (annotation != null) { + Class providerClass = annotation.value(); + checkArgument( + providerClass != null, + "Type " + type + " has a @DefaultSchema annotation with a null argument."); - try { - return providerClass.getDeclaredConstructor().newInstance(); - } catch (NoSuchMethodException - | InstantiationException - | IllegalAccessException - | InvocationTargetException e) { - throw new IllegalStateException( - "Failed to create SchemaProvider " - + providerClass.getSimpleName() - + " which was" - + " specified as the default SchemaProvider for type " - + type - + ". 
Make " - + " sure that this class has a public default constructor.", - e); - } + try { + return new ProviderAndDescriptor( + providerClass.getDeclaredConstructor().newInstance(), + TypeDescriptor.of(clazz)); + } catch (NoSuchMethodException + | InstantiationException + | IllegalAccessException + | InvocationTargetException e) { + throw new IllegalStateException( + "Failed to create SchemaProvider " + + providerClass.getSimpleName() + + " which was" + + " specified as the default SchemaProvider for type " + + type + + ". Make " + + " sure that this class has a public default constructor.", + e); + } + } + clazz = clazz.getSuperclass(); + } while (clazz != null && !clazz.equals(Object.class)); + return null; }); } @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor); - return (schemaProvider != null) ? schemaProvider.schemaFor(typeDescriptor) : null; + ProviderAndDescriptor providerAndDescriptor = getSchemaProvider(typeDescriptor); + return (providerAndDescriptor != null) + ? providerAndDescriptor.schemaProvider.schemaFor(providerAndDescriptor.typeDescriptor) + : null; } /** @@ -119,8 +138,11 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { */ @Override public SerializableFunction toRowFunction(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor); - return (schemaProvider != null) ? schemaProvider.toRowFunction(typeDescriptor) : null; + ProviderAndDescriptor providerAndDescriptor = getSchemaProvider(typeDescriptor); + return (providerAndDescriptor != null) + ? 
providerAndDescriptor.schemaProvider.toRowFunction( + (TypeDescriptor) providerAndDescriptor.typeDescriptor) + : null; } /** @@ -129,8 +151,11 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc */ @Override public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { - SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor); - return (schemaProvider != null) ? schemaProvider.fromRowFunction(typeDescriptor) : null; + ProviderAndDescriptor providerAndDescriptor = getSchemaProvider(typeDescriptor); + return (providerAndDescriptor != null) + ? providerAndDescriptor.schemaProvider.fromRowFunction( + (TypeDescriptor) providerAndDescriptor.typeDescriptor) + : null; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java index 9b01b352c333e..b137c6a062751 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java @@ -20,15 +20,13 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; -import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.utils.ConvertHelpers; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; @@ -99,7 +97,6 @@ public static PTransform, PCollection extends PTransform, 
PCollection> { TypeDescriptor outputTypeDescriptor; - Schema unboxedSchema = null; ConvertTransform(TypeDescriptor outputTypeDescriptor) { this.outputTypeDescriptor = outputTypeDescriptor; @@ -124,62 +121,34 @@ public PCollection expand(PCollection input) { throw new RuntimeException("Convert requires a schema on the input."); } - final SchemaCoder outputSchemaCoder; - boolean toRow = outputTypeDescriptor.equals(TypeDescriptor.of(Row.class)); - if (toRow) { - // If the output is of type Row, then just forward the schema of the input type to the - // output. - outputSchemaCoder = - (SchemaCoder) - SchemaCoder.of( - input.getSchema(), - SerializableFunctions.identity(), - SerializableFunctions.identity()); - } else { - // Otherwise, try to find a schema for the output type in the schema registry. - SchemaRegistry registry = input.getPipeline().getSchemaRegistry(); - try { - outputSchemaCoder = - SchemaCoder.of( - registry.getSchema(outputTypeDescriptor), - registry.getToRowFunction(outputTypeDescriptor), - registry.getFromRowFunction(outputTypeDescriptor)); - - Schema outputSchema = outputSchemaCoder.getSchema(); - if (!outputSchema.assignableToIgnoreNullable(input.getSchema())) { - // We also support unboxing nested Row schemas, so attempt that. - // TODO: Support unboxing to primitive types as well. - unboxedSchema = getBoxedNestedSchema(input.getSchema()); - if (unboxedSchema == null || !outputSchema.assignableToIgnoreNullable(unboxedSchema)) { - Schema checked = (unboxedSchema == null) ? input.getSchema() : unboxedSchema; - throw new RuntimeException( - "Cannot convert between types that don't have equivalent schemas." 
- + " input schema: " - + checked - + " output schema: " - + outputSchemaCoder.getSchema()); - } - } - } catch (NoSuchSchemaException e) { - throw new RuntimeException("No schema registered for " + outputTypeDescriptor); - } - } - - return input - .apply( + SchemaRegistry registry = input.getPipeline().getSchemaRegistry(); + ConvertHelpers.ConvertedSchemaInformation converted = + ConvertHelpers.getConvertedSchemaInformation( + input.getSchema(), outputTypeDescriptor, registry); + boolean unbox = converted.unboxedType != null; + PCollection output = + input.apply( ParDo.of( new DoFn() { @ProcessElement public void processElement(@Element Row row, OutputReceiver o) { // Read the row, potentially unboxing if necessary. - Row input = (unboxedSchema == null) ? row : row.getValue(0); - o.output(outputSchemaCoder.getFromRowFunction().apply(input)); + Object input = unbox ? row.getValue(0) : row; + // The output has a schema, so we need to convert to the appropriate type. + o.output(converted.outputSchemaCoder.getFromRowFunction().apply((Row) input)); } - })) - .setSchema( - outputSchemaCoder.getSchema(), - outputSchemaCoder.getToRowFunction(), - outputSchemaCoder.getFromRowFunction()); + })); + if (converted.outputSchemaCoder != null) { + output = + output.setSchema( + converted.outputSchemaCoder.getSchema(), + converted.outputSchemaCoder.getToRowFunction(), + converted.outputSchemaCoder.getFromRowFunction()); + } else { + // TODO: Support full unboxing and boxing in Create. 
+ throw new RuntimeException("Unboxing is not yet supported in the Create transform"); + } + return output; } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java index 077cc33f34de3..7626869f4be98 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java @@ -64,7 +64,7 @@ * *
{@code
  * PCollection events = readUserEvents();
- * PCollection rows = event.apply(Select.fieldNames("location")
+ * PCollection<Location> rows = events.apply(Select.fieldNames("location"))
  *                              .apply(Convert.to(Location.class));
  * }
*/ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java new file mode 100644 index 0000000000000..e74f85f415c43 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.schemas.utils; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Type; +import javax.annotation.Nullable; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.DynamicType; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.dynamic.scaffold.InstrumentedType; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.bytecode.ByteCodeAppender; +import net.bytebuddy.implementation.bytecode.ByteCodeAppender.Size; +import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.member.MethodReturn; +import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; +import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.schemas.JavaFieldSchema.JavaFieldTypeSupplier; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertType; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForSetter; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v20_0.com.google.common.primitives.Primitives; + +/** Helper functions for converting between equivalent schema types. */ +public class ConvertHelpers { + /** Return value after converting a schema. 
*/ + public static class ConvertedSchemaInformation implements Serializable { + // If the output type is a composite type, this is the schema coder. + @Nullable public final SchemaCoder outputSchemaCoder; + @Nullable public final FieldType unboxedType; + + public ConvertedSchemaInformation( + @Nullable SchemaCoder outputSchemaCoder, @Nullable FieldType unboxedType) { + this.outputSchemaCoder = outputSchemaCoder; + this.unboxedType = unboxedType; + } + } + + /** Get the coder used for converting from an inputSchema to a given type. */ + public static ConvertedSchemaInformation getConvertedSchemaInformation( + Schema inputSchema, TypeDescriptor outputType, SchemaRegistry schemaRegistry) { + ConvertedSchemaInformation convertedSchema = null; + boolean toRow = outputType.equals(TypeDescriptor.of(Row.class)); + if (toRow) { + // If the output is of type Row, then just forward the schema of the input type to the + // output. + convertedSchema = + new ConvertedSchemaInformation<>( + (SchemaCoder) + SchemaCoder.of( + inputSchema, + SerializableFunctions.identity(), + SerializableFunctions.identity()), + null); + } else { + // Otherwise, try to find a schema for the output type in the schema registry. + Schema outputSchema = null; + SchemaCoder outputSchemaCoder = null; + try { + outputSchema = schemaRegistry.getSchema(outputType); + outputSchemaCoder = + SchemaCoder.of( + outputSchema, + schemaRegistry.getToRowFunction(outputType), + schemaRegistry.getFromRowFunction(outputType)); + } catch (NoSuchSchemaException e) { + + } + FieldType unboxedType = null; + // TODO: Properly handle nullable. + if (outputSchema == null || !outputSchema.assignableToIgnoreNullable(inputSchema)) { + // The schema is not convertible directly. Attempt to unbox it and see if the schema matches + // then. 
+ Schema checkedSchema = inputSchema; + if (inputSchema.getFieldCount() == 1) { + unboxedType = inputSchema.getField(0).getType(); + if (unboxedType.getTypeName().isCompositeType() + && !outputSchema.assignableToIgnoreNullable(unboxedType.getRowSchema())) { + checkedSchema = unboxedType.getRowSchema(); + } else { + checkedSchema = null; + } + } + if (checkedSchema != null) { + throw new RuntimeException( + "Cannot convert between types that don't have equivalent schemas." + + " input schema: " + + checkedSchema + + " output schema: " + + outputSchema); + } + } + convertedSchema = new ConvertedSchemaInformation(outputSchemaCoder, unboxedType); + } + return convertedSchema; + } + + /** + * Returns a function to convert a Row into a primitive type. This only works when the row schema + * contains a single field, and that field is convertible to the primitive type. + */ + @SuppressWarnings("unchecked") + public static SerializableFunction getConvertPrimitive( + FieldType fieldType, TypeDescriptor outputTypeDescriptor) { + FieldType expectedFieldType = + StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE); + if (!expectedFieldType.equals(fieldType)) { + throw new IllegalArgumentException( + "Element argument type " + + outputTypeDescriptor + + " does not work with expected schema field type " + + fieldType); + } + + Type expectedInputType = new ConvertType(true).convert(outputTypeDescriptor); + + TypeDescriptor outputType = outputTypeDescriptor; + if (outputType.getRawType().isPrimitive()) { + // A SerializableFunction can only return an Object type, so if the DoFn parameter is a + // primitive type, then box it for the return. The return type will be unboxed before being + // forwarded to the DoFn parameter. 
+ outputType = TypeDescriptor.of(Primitives.wrap(outputType.getRawType())); + } + + TypeDescription.Generic genericType = + TypeDescription.Generic.Builder.parameterizedType( + SerializableFunction.class, expectedInputType, outputType.getType()) + .build(); + DynamicType.Builder builder = + (DynamicType.Builder) new ByteBuddy().subclass(genericType); + try { + return builder + .method(ElementMatchers.named("apply")) + .intercept(new ConvertPrimitiveInstruction(outputType)) + .make() + .load(ReflectHelpers.findClassLoader(), ClassLoadingStrategy.Default.INJECTION) + .getLoaded() + .getDeclaredConstructor() + .newInstance(); + } catch (InstantiationException + | IllegalAccessException + | NoSuchMethodException + | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + static class ConvertPrimitiveInstruction implements Implementation { + private final TypeDescriptor outputFieldType; + + public ConvertPrimitiveInstruction(TypeDescriptor outputFieldType) { + this.outputFieldType = outputFieldType; + } + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } + + @Override + public ByteCodeAppender appender(final Target implementationTarget) { + return (methodVisitor, implementationContext, instrumentedMethod) -> { + int numLocals = 1 + instrumentedMethod.getParameters().size(); + + // Method param is offset 1 (offset 0 is the this parameter). 
+ StackManipulation readValue = MethodVariableAccess.REFERENCE.loadFrom(1); + StackManipulation stackManipulation = + new StackManipulation.Compound( + new ConvertValueForSetter(readValue).convert(outputFieldType), + MethodReturn.REFERENCE); + + StackManipulation.Size size = stackManipulation.apply(methodVisitor, implementationContext); + return new Size(size.getMaximalSize(), numLocals); + }; + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 073ead1e0bd14..7de5fae289a88 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -91,8 +91,8 @@ public static Schema schemaFromClass( return builder.build(); } - // Map a Java field type to a Beam Schema FieldType. - private static Schema.FieldType fieldFromType( + /** Map a Java field type to a Beam Schema FieldType. 
*/ + public static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java index 5c0347fdb049d..ab54fb39f9e0c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java @@ -19,8 +19,17 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; -import javax.annotation.Nullable; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.utils.ConvertHelpers; +import org.apache.beam.sdk.schemas.utils.SelectHelpers; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList; /** Represents information about how a DoFn extracts schemas. */ @AutoValue @@ -29,25 +38,201 @@ public abstract class DoFnSchemaInformation implements Serializable { * The schema of the @Element parameter. If the Java type does not match the input PCollection but * the schemas are compatible, Beam will automatically convert between the Java types. */ - @Nullable - public abstract SchemaCoder getElementParameterSchema(); + public abstract List> getElementConverters(); /** Create an instance. 
*/ public static DoFnSchemaInformation create() { - return new AutoValue_DoFnSchemaInformation.Builder().build(); + return new AutoValue_DoFnSchemaInformation.Builder() + .setElementConverters(Collections.emptyList()) + .build(); } /** The builder object. */ @AutoValue.Builder public abstract static class Builder { - abstract Builder setElementParameterSchema(@Nullable SchemaCoder schemaCoder); + abstract Builder setElementConverters(List> converters); abstract DoFnSchemaInformation build(); } public abstract Builder toBuilder(); - public DoFnSchemaInformation withElementParameterSchema(SchemaCoder schemaCoder) { - return toBuilder().setElementParameterSchema(schemaCoder).build(); + /** + * Specified a parameter that is a selection from an input schema (specified using FieldAccess). + * This method is called when the input parameter itself has a schema. The input parameter does + * not need to be a Row. If it is a type with a compatible registered schema, then the conversion + * will be done automatically. + * + * @param inputCoder The coder for the ParDo's input elements. + * @param selectDescriptor The descriptor describing which field to select. + * @param selectOutputSchema The schema of the selected parameter. + * @param parameterCoder The coder for the input parameter to the method. + * @param unbox If unbox is true, then the select result is a 1-field schema that needs to be + * unboxed. 
+ * @return + */ + DoFnSchemaInformation withSelectFromSchemaParameter( + SchemaCoder inputCoder, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + SchemaCoder parameterCoder, + boolean unbox) { + List> converters = + ImmutableList.>builder() + .addAll(getElementConverters()) + .add( + ConversionFunction.of( + inputCoder.getSchema(), + inputCoder.getToRowFunction(), + parameterCoder.getFromRowFunction(), + selectDescriptor, + selectOutputSchema, + unbox)) + .build(); + + return toBuilder().setElementConverters(converters).build(); + } + + /** + * Specified a parameter that is a selection from an input schema (specified using FieldAccess). + * This method is called when the input parameter is a Java type that does not itself have a + * schema, e.g. long, or String. In this case we expect the selection predicate to return a + * single-field row with a field of the output type. + * + * @param inputCoder The coder for the ParDo's input elements. + * @param selectDescriptor The descriptor describing which field to select. + * @param selectOutputSchema The schema of the selected parameter. + * @param elementT The type of the method's input parameter. 
+ * @return + */ + DoFnSchemaInformation withUnboxPrimitiveParameter( + SchemaCoder inputCoder, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + TypeDescriptor elementT) { + if (selectOutputSchema.getFieldCount() != 1) { + throw new RuntimeException("Parameter has no schema and the input is not a simple type."); + } + FieldType fieldType = selectOutputSchema.getField(0).getType(); + if (fieldType.getTypeName().isCompositeType()) { + throw new RuntimeException("Parameter has no schema and the input is not a primitive type."); + } + + List> converters = + ImmutableList.>builder() + .addAll(getElementConverters()) + .add( + UnboxingConversionFunction.of( + inputCoder.getSchema(), + inputCoder.getToRowFunction(), + selectDescriptor, + selectOutputSchema, + elementT)) + .build(); + + return toBuilder().setElementConverters(converters).build(); + } + + private static class ConversionFunction + implements SerializableFunction { + private final Schema inputSchema; + private final SerializableFunction toRowFunction; + private final SerializableFunction fromRowFunction; + private final FieldAccessDescriptor selectDescriptor; + private final Schema selectOutputSchema; + private final boolean unbox; + + private ConversionFunction( + Schema inputSchema, + SerializableFunction toRowFunction, + SerializableFunction fromRowFunction, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + boolean unbox) { + this.inputSchema = inputSchema; + this.toRowFunction = toRowFunction; + this.fromRowFunction = fromRowFunction; + this.selectDescriptor = selectDescriptor; + this.selectOutputSchema = selectOutputSchema; + this.unbox = unbox; + } + + public static ConversionFunction of( + Schema inputSchema, + SerializableFunction toRowFunction, + SerializableFunction fromRowFunction, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + boolean unbox) { + return new ConversionFunction<>( + inputSchema, toRowFunction, fromRowFunction, 
selectDescriptor, selectOutputSchema, unbox); + } + + @Override + public OutputT apply(InputT input) { + Row row = toRowFunction.apply(input); + Row selected = + SelectHelpers.selectRow(row, selectDescriptor, inputSchema, selectOutputSchema); + if (unbox) { + selected = selected.getRow(0); + } + return fromRowFunction.apply(selected); + } + } + + /** + * This function is used when the schema is a singleton schema containing a single primitive field + * and the Java type we are converting to is that of the primitive field. + */ + private static class UnboxingConversionFunction + implements SerializableFunction { + private final Schema inputSchema; + private final SerializableFunction toRowFunction; + private final FieldAccessDescriptor selectDescriptor; + private final Schema selectOutputSchema; + private final FieldType primitiveType; + private final TypeDescriptor primitiveOutputType; + private transient SerializableFunction conversionFunction; + + private UnboxingConversionFunction( + Schema inputSchema, + SerializableFunction toRowFunction, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + TypeDescriptor primitiveOutputType) { + this.inputSchema = inputSchema; + this.toRowFunction = toRowFunction; + this.selectDescriptor = selectDescriptor; + this.selectOutputSchema = selectOutputSchema; + this.primitiveType = selectOutputSchema.getField(0).getType(); + this.primitiveOutputType = primitiveOutputType; + } + + public static UnboxingConversionFunction of( + Schema inputSchema, + SerializableFunction toRowFunction, + FieldAccessDescriptor selectDescriptor, + Schema selectOutputSchema, + TypeDescriptor primitiveOutputType) { + return new UnboxingConversionFunction<>( + inputSchema, toRowFunction, selectDescriptor, selectOutputSchema, primitiveOutputType); + } + + @Override + public OutputT apply(InputT input) { + Row row = toRowFunction.apply(input); + Row selected = + SelectHelpers.selectRow(row, selectDescriptor, inputSchema, 
selectOutputSchema); + return getConversionFunction().apply(selected.getValue(0)); + } + + private SerializableFunction getConversionFunction() { + if (conversionFunction == null) { + conversionFunction = + (SerializableFunction) + ConvertHelpers.getConvertPrimitive(primitiveType, primitiveOutputType); + } + return conversionFunction; + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index c3b1f82ddd0ab..1c5f4b62954f8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -251,7 +251,7 @@ public InputT element(DoFn doFn) { } @Override - public InputT schemaElement(DoFn doFn) { + public InputT schemaElement(int index) { throw new UnsupportedOperationException("Schemas are not supported by DoFnTester"); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 9febcba6bc3a0..08bae25413447 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -39,6 +39,8 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.utils.ConvertHelpers; +import org.apache.beam.sdk.schemas.utils.SelectHelpers; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.transforms.DoFn.WindowedContext; import org.apache.beam.sdk.transforms.display.DisplayData; @@ -59,7 +61,6 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; 
import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; @@ -435,7 +436,7 @@ private static void validateStateApplicableForInput(DoFn fn, PCollection fieldAccessDeclarations, @@ -448,25 +449,25 @@ private static void validateFieldAccessParameter( // here as well to catch these errors. FieldAccessDescriptor fieldAccessDescriptor = null; if (fieldAccessString == null) { - // This is the case where no FieldId is defined, just an @Element Row row. Default to all - // fields accessed. + // This is the case where no FieldId is defined. Default to all fields accessed. fieldAccessDescriptor = FieldAccessDescriptor.withAllFields(); } else { - // In this case, we expect to have a FieldAccessDescriptor defined in the class. + // If there is a FieldAccessDescriptor in the class with this id, use that. FieldAccessDeclaration fieldAccessDeclaration = fieldAccessDeclarations.get(fieldAccessString); - checkArgument( - fieldAccessDeclaration != null, - "No FieldAccessDeclaration defined with id", - fieldAccessString); - checkArgument(fieldAccessDeclaration.field().getType().equals(FieldAccessDescriptor.class)); - try { - fieldAccessDescriptor = (FieldAccessDescriptor) fieldAccessDeclaration.field().get(fn); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); + if (fieldAccessDeclaration != null) { + checkArgument(fieldAccessDeclaration.field().getType().equals(FieldAccessDescriptor.class)); + try { + fieldAccessDescriptor = (FieldAccessDescriptor) fieldAccessDeclaration.field().get(fn); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } else { + // Otherwise, interpret the string as a field-name expression. 
+ fieldAccessDescriptor = FieldAccessDescriptor.withFieldNames(fieldAccessString); } } - fieldAccessDescriptor.resolve(inputSchema); + return fieldAccessDescriptor.resolve(inputSchema); } /** @@ -571,64 +572,44 @@ public static DoFnSchemaInformation getDoFnSchemaInformation( DoFn fn, PCollection input) { DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); DoFnSignature.ProcessElementMethod processElementMethod = signature.processElement(); - SchemaElementParameter elementParameter = processElementMethod.getSchemaElementParameter(); - boolean validateInputSchema = elementParameter != null; - TypeDescriptor elementT = null; - if (validateInputSchema) { - elementT = (TypeDescriptor) elementParameter.elementT(); - } - - DoFnSchemaInformation doFnSchemaInformation = DoFnSchemaInformation.create(); - if (validateInputSchema) { - // Element type doesn't match input type, so we need to covnert. + if (!processElementMethod.getSchemaElementParameters().isEmpty()) { if (!input.hasSchema()) { throw new IllegalArgumentException("Type of @Element must match the DoFn type" + input); } + } - validateFieldAccessParameter( - elementParameter.fieldAccessString(), - input.getSchema(), - signature.fieldAccessDeclarations(), - fn); - - boolean toRow = elementT.equals(TypeDescriptor.of(Row.class)); - if (toRow) { + SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry(); + DoFnSchemaInformation doFnSchemaInformation = DoFnSchemaInformation.create(); + for (SchemaElementParameter parameter : processElementMethod.getSchemaElementParameters()) { + TypeDescriptor elementT = parameter.elementT(); + FieldAccessDescriptor accessDescriptor = + getFieldAccessDescriptorFromParameter( + parameter.fieldAccessString(), + input.getSchema(), + signature.fieldAccessDeclarations(), + fn); + Schema selectedSchema = SelectHelpers.getOutputSchema(input.getSchema(), accessDescriptor); + ConvertHelpers.ConvertedSchemaInformation converted = + 
ConvertHelpers.getConvertedSchemaInformation(selectedSchema, elementT, schemaRegistry); + if (converted.outputSchemaCoder != null) { doFnSchemaInformation = - doFnSchemaInformation.withElementParameterSchema( - SchemaCoder.of( - input.getSchema(), - SerializableFunctions.identity(), - SerializableFunctions.identity())); + doFnSchemaInformation.withSelectFromSchemaParameter( + (SchemaCoder) input.getCoder(), + accessDescriptor, + selectedSchema, + converted.outputSchemaCoder, + converted.unboxedType != null); } else { - // For now we assume the parameter is not of type Row (TODO: change this) - SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry(); - try { - Schema schema = schemaRegistry.getSchema(elementT); - SerializableFunction toRowFunction = schemaRegistry.getToRowFunction(elementT); - SerializableFunction fromRowFunction = schemaRegistry.getFromRowFunction(elementT); - doFnSchemaInformation = - doFnSchemaInformation.withElementParameterSchema( - SchemaCoder.of(schema, toRowFunction, fromRowFunction)); - - // assert matches input schema. - // TODO: Properly handle nullable. - if (!doFnSchemaInformation - .getElementParameterSchema() - .getSchema() - .assignableToIgnoreNullable(input.getSchema())) { - throw new IllegalArgumentException( - "Input to DoFn has schema: " - + input.getSchema() - + " However @ElementParameter of type " - + elementT - + " has incompatible schema " - + doFnSchemaInformation.getElementParameterSchema().getSchema()); - } - } catch (NoSuchSchemaException e) { - throw new RuntimeException("No schema registered for " + elementT); - } + // If the selected schema is a Row containing a single primitive type (which is the output + // of Select when selecting a primitive), attempt to unbox it and match against the + // parameter. 
+ checkArgument(converted.unboxedType != null); + doFnSchemaInformation = + doFnSchemaInformation.withUnboxPrimitiveParameter( + (SchemaCoder) input.getCoder(), accessDescriptor, selectedSchema, elementT); } } + return doFnSchemaInformation; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 47f09f4dcea17..457ee55a2f221 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -34,6 +34,7 @@ import net.bytebuddy.description.method.MethodDescription; import net.bytebuddy.description.modifier.Visibility; import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.description.type.TypeList; import net.bytebuddy.dynamic.DynamicType; import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; @@ -46,10 +47,12 @@ import net.bytebuddy.implementation.MethodDelegation; import net.bytebuddy.implementation.bytecode.ByteCodeAppender; import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.StackManipulation.Compound; import net.bytebuddy.implementation.bytecode.Throw; import net.bytebuddy.implementation.bytecode.assign.Assigner; import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.constant.IntegerConstant; import net.bytebuddy.implementation.bytecode.constant.TextConstant; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; @@ -87,6 +90,7 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v20_0.com.google.common.primitives.Primitives; /** Dynamically generates a {@link DoFnInvoker} instances for invoking a {@link DoFn}. */ public class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory { @@ -663,13 +667,29 @@ public StackManipulation dispatch(ElementParameter p) { @Override public StackManipulation dispatch(SchemaElementParameter p) { - // Ignore FieldAccess id for now. - return new StackManipulation.Compound( - pushDelegate, - MethodInvocation.invoke( - getExtraContextFactoryMethodDescription( - SCHEMA_ELEMENT_PARAMETER_METHOD, DoFn.class)), - TypeCasting.to(new TypeDescription.ForLoadedType(p.elementT().getRawType()))); + ForLoadedType elementType = new ForLoadedType(p.elementT().getRawType()); + ForLoadedType castType = + elementType.isPrimitive() + ? new ForLoadedType(Primitives.wrap(p.elementT().getRawType())) + : elementType; + + StackManipulation stackManipulation = + new StackManipulation.Compound( + IntegerConstant.forValue(p.index()), + MethodInvocation.invoke( + getExtraContextFactoryMethodDescription( + SCHEMA_ELEMENT_PARAMETER_METHOD, int.class)), + TypeCasting.to(castType)); + if (elementType.isPrimitive()) { + stackManipulation = + new Compound( + stackManipulation, + Assigner.DEFAULT.assign( + elementType.asBoxed().asGenericType(), + elementType.asUnboxed().asGenericType(), + Typing.STATIC)); + } + return stackManipulation; } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index 438a918b99a40..d8504eef220d7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -132,16 +132,19 @@ interface ArgumentProvider { /** Provide a {@link DoFn.OnTimerContext} to use with the given {@link DoFn}. */ DoFn.OnTimerContext onTimerContext(DoFn doFn); - /** Provide a link to the input element. */ + /** Provide a reference to the input element. */ InputT element(DoFn doFn); - /** Provide a link to the input element. */ - Object schemaElement(DoFn doFn); + /** + * Provide a reference to the selected schema field corresponding to the input argument + * specified by index. + */ + Object schemaElement(int index); - /** Provide a link to the input element timestamp. */ + /** Provide a reference to the input element timestamp. */ Instant timestamp(DoFn doFn); - /** Provide a link to the time domain for a timer firing. */ + /** Provide a reference to the time domain for a timer firing. */ TimeDomain timeDomain(DoFn doFn); /** Provide a {@link OutputReceiver} for outputting to the default output. 
*/ @@ -188,7 +191,7 @@ public InputT element(DoFn doFn) { } @Override - public InputT schemaElement(DoFn doFn) { + public InputT schemaElement(int index) { throw new UnsupportedOperationException( String.format( "Should never call non-overridden methods of %s", diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 6151cab20f54b..727d4f7cb4df6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -398,8 +399,12 @@ public static ElementParameter elementParameter(TypeDescriptor elementT) { } public static SchemaElementParameter schemaElementParameter( - TypeDescriptor elementT, @Nullable String fieldAccessId) { - return new AutoValue_DoFnSignature_Parameter_SchemaElementParameter(elementT, fieldAccessId); + TypeDescriptor elementT, @Nullable String fieldAccessString, int index) { + return new AutoValue_DoFnSignature_Parameter_SchemaElementParameter.Builder() + .setElementT(elementT) + .setFieldAccessString(fieldAccessString) + .setIndex(index) + .build(); } public static TimestampParameter timestampParameter() { @@ -511,6 +516,22 @@ public abstract static class SchemaElementParameter extends Parameter { @Nullable public abstract String fieldAccessString(); + + public abstract int index(); + + /** Builder class. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setElementT(TypeDescriptor elementT); + + public abstract Builder setFieldAccessString(@Nullable String fieldAccess); + + public abstract Builder setIndex(int index); + + public abstract SchemaElementParameter build(); + } + + public abstract Builder toBuilder(); } /** @@ -691,12 +712,11 @@ public boolean observesWindow() { } @Nullable - public SchemaElementParameter getSchemaElementParameter() { - Optional parameter = - extraParameters().stream() - .filter(Predicates.instanceOf(SchemaElementParameter.class)::apply) - .findFirst(); - return parameter.isPresent() ? ((SchemaElementParameter) parameter.get()) : null; + public List getSchemaElementParameters() { + return extraParameters().stream() + .filter(Predicates.instanceOf(SchemaElementParameter.class)::apply) + .map(SchemaElementParameter.class::cast) + .collect(Collectors.toList()); } /** The {@link OutputReceiverParameter} for a main output, or null if there is none. 
*/ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 9889adc6c6155..69b03a43d0d9f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter; @@ -260,6 +261,10 @@ public List getExtraParameters() { return Collections.unmodifiableList(extraParameters); } + public void setParameter(int index, Parameter parameter) { + extraParameters.set(index, parameter); + } + /** * Returns an {@link MethodAnalysisContext} like this one but including the provided {@link * StateParameter}. 
@@ -814,6 +819,16 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( methodContext.addParameter(extraParam); } + int schemaElementIndex = 0; + for (int i = 0; i < methodContext.getExtraParameters().size(); ++i) { + Parameter parameter = methodContext.getExtraParameters().get(i); + if (parameter instanceof SchemaElementParameter) { + SchemaElementParameter schemaParameter = (SchemaElementParameter) parameter; + schemaParameter = schemaParameter.toBuilder().setIndex(schemaElementIndex).build(); + methodContext.setParameter(i, schemaParameter); + ++schemaElementIndex; + } + } // The allowed parameters depend on whether this DoFn is splittable if (methodContext.hasRestrictionTrackerParameter()) { @@ -867,13 +882,13 @@ private static Parameter analyzeExtraParameter( ErrorReporter paramErrors = methodErrors.forParameter(param); - if (hasElementAnnotation(param.getAnnotations())) { - if (paramT.equals(inputT)) { - return Parameter.elementParameter(paramT); - } else { - String fieldAccessString = getFieldAccessId(param.getAnnotations()); - return Parameter.schemaElementParameter(paramT, fieldAccessString); - } + String fieldAccessString = getFieldAccessId(param.getAnnotations()); + if (fieldAccessString != null) { + return Parameter.schemaElementParameter(paramT, fieldAccessString, param.getIndex()); + } else if (hasElementAnnotation(param.getAnnotations())) { + return (paramT.equals(inputT)) + ? 
Parameter.elementParameter(paramT) + : Parameter.schemaElementParameter(paramT, null, param.getIndex()); } else if (hasTimestampAnnotation(param.getAnnotations())) { methodErrors.checkArgument( rawType.equals(Instant.class), diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java index 96ef3b635e2f7..9408fc35012cb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java @@ -17,15 +17,18 @@ */ package org.apache.beam.sdk.transforms; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import com.google.auto.value.AutoValue; import java.io.Serializable; +import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; -import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.annotations.SchemaCreate; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -307,8 +310,7 @@ public void testFieldAccessSchemaPipeline() { FieldAccessDescriptor.withAllFields(); @ProcessElement - public void process( - @FieldAccess("foo") @Element Row row, OutputReceiver r) { + public void process(@FieldAccess("foo") Row row, OutputReceiver r) { r.output(row.getString(0) + ":" + row.getInt32(1)); } })); @@ -361,31 +363,29 @@ public void process(@FieldAccess("a") Row row) {} } /** POJO used for testing. 
*/ - @DefaultSchema(JavaFieldSchema.class) - static class InferredPojo { - final String stringField; - final Integer integerField; + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Inferred { + abstract String getStringField(); - @SchemaCreate - InferredPojo(String stringField, Integer integerField) { - this.stringField = stringField; - this.integerField = integerField; - } + abstract Integer getIntegerField(); } @Test @Category({ValidatesRunner.class, UsesSchema.class}) public void testInferredSchemaPipeline() { - List pojoList = + List pojoList = Lists.newArrayList( - new InferredPojo("a", 1), new InferredPojo("b", 2), new InferredPojo("c", 3)); + new AutoValue_ParDoSchemaTest_Inferred("a", 1), + new AutoValue_ParDoSchemaTest_Inferred("b", 2), + new AutoValue_ParDoSchemaTest_Inferred("c", 3)); PCollection output = pipeline .apply(Create.of(pojoList)) .apply( ParDo.of( - new DoFn() { + new DoFn() { @ProcessElement public void process(@Element Row row, OutputReceiver r) { r.output(row.getString(0) + ":" + row.getInt32(1)); @@ -398,61 +398,57 @@ public void process(@Element Row row, OutputReceiver r) { @Test @Category({ValidatesRunner.class, UsesSchema.class}) public void testSchemasPassedThrough() { - List pojoList = + List pojoList = Lists.newArrayList( - new InferredPojo("a", 1), new InferredPojo("b", 2), new InferredPojo("c", 3)); + new AutoValue_ParDoSchemaTest_Inferred("a", 1), + new AutoValue_ParDoSchemaTest_Inferred("b", 2), + new AutoValue_ParDoSchemaTest_Inferred("c", 3)); - PCollection out = pipeline.apply(Create.of(pojoList)).apply(Filter.by(e -> true)); + PCollection out = pipeline.apply(Create.of(pojoList)).apply(Filter.by(e -> true)); assertTrue(out.hasSchema()); pipeline.run(); } /** Pojo used for testing. 
*/ - @DefaultSchema(JavaFieldSchema.class) - static class InferredPojo2 { - final Integer integerField; - final String stringField; + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Inferred2 { + abstract Integer getIntegerField(); - @SchemaCreate - InferredPojo2(String stringField, Integer integerField) { - this.stringField = stringField; - this.integerField = integerField; - } + abstract String getStringField(); } @Test @Category({ValidatesRunner.class, UsesSchema.class}) public void testSchemaConversionPipeline() { - List pojoList = + List pojoList = Lists.newArrayList( - new InferredPojo("a", 1), new InferredPojo("b", 2), new InferredPojo("c", 3)); + new AutoValue_ParDoSchemaTest_Inferred("a", 1), + new AutoValue_ParDoSchemaTest_Inferred("b", 2), + new AutoValue_ParDoSchemaTest_Inferred("c", 3)); PCollection output = pipeline .apply(Create.of(pojoList)) .apply( ParDo.of( - new DoFn() { + new DoFn() { @ProcessElement - public void process(@Element InferredPojo2 pojo, OutputReceiver r) { - r.output(pojo.stringField + ":" + pojo.integerField); + public void process(@Element Inferred2 pojo, OutputReceiver r) { + r.output(pojo.getStringField() + ":" + pojo.getIntegerField()); } })); PAssert.that(output).containsInAnyOrder("a:1", "b:2", "c:3"); pipeline.run(); } - @DefaultSchema(JavaFieldSchema.class) - static class Nested { - final int field1; - final InferredPojo inner; + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Nested { + abstract int getField1(); - @SchemaCreate - public Nested(int field1, InferredPojo inner) { - this.field1 = field1; - this.inner = inner; - } + abstract Inferred getInner(); } @Test @@ -460,9 +456,10 @@ public Nested(int field1, InferredPojo inner) { public void testNestedSchema() { List pojoList = Lists.newArrayList( - new Nested(1, new InferredPojo("a", 1)), - new Nested(2, new InferredPojo("b", 2)), - new Nested(3, new InferredPojo("c", 3))); + new AutoValue_ParDoSchemaTest_Nested(1, 
new AutoValue_ParDoSchemaTest_Inferred("a", 1)), + new AutoValue_ParDoSchemaTest_Nested(2, new AutoValue_ParDoSchemaTest_Inferred("b", 2)), + new AutoValue_ParDoSchemaTest_Nested( + 3, new AutoValue_ParDoSchemaTest_Inferred("c", 3))); PCollection output = pipeline @@ -475,10 +472,154 @@ public void testNestedSchema() { new DoFn() { @ProcessElement public void process(@Element Nested nested, OutputReceiver r) { - r.output(nested.inner.stringField + ":" + nested.inner.integerField); + r.output( + nested.getInner().getStringField() + + ":" + + nested.getInner().getIntegerField()); } })); PAssert.that(output).containsInAnyOrder("a:1", "b:2", "c:3"); pipeline.run(); } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class ForExtraction { + abstract Integer getIntegerField(); + + abstract String getStringField(); + + abstract List getInts(); + } + + @Test + @Category({ValidatesRunner.class, UsesSchema.class}) + public void testSchemaFieldSelectionUnboxing() { + List pojoList = + Lists.newArrayList( + new AutoValue_ParDoSchemaTest_ForExtraction(1, "a", Lists.newArrayList(1, 2)), + new AutoValue_ParDoSchemaTest_ForExtraction(2, "b", Lists.newArrayList(2, 3)), + new AutoValue_ParDoSchemaTest_ForExtraction(3, "c", Lists.newArrayList(3, 4))); + + PCollection output = + pipeline + .apply(Create.of(pojoList)) + .apply( + ParDo.of( + new DoFn() { + // Read the list twice as two equivalent types to ensure that Beam properly + // converts. 
+ @ProcessElement + public void process( + @FieldAccess("stringField") String stringField, + @FieldAccess("integerField") Integer integerField, + @FieldAccess("ints") Integer[] intArray, + @FieldAccess("ints") List intList, + OutputReceiver r) { + + r.output( + stringField + + ":" + + integerField + + ":" + + Arrays.toString(intArray) + + ":" + + intList.toString()); + } + })); + PAssert.that(output) + .containsInAnyOrder("a:1:[1, 2]:[1, 2]", "b:2:[2, 3]:[2, 3]", "c:3:[3, 4]:[3, 4]"); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesSchema.class}) + public void testSchemaFieldDescriptorSelectionUnboxing() { + List pojoList = + Lists.newArrayList( + new AutoValue_ParDoSchemaTest_ForExtraction(1, "a", Lists.newArrayList(1, 2)), + new AutoValue_ParDoSchemaTest_ForExtraction(2, "b", Lists.newArrayList(2, 3)), + new AutoValue_ParDoSchemaTest_ForExtraction(3, "c", Lists.newArrayList(3, 4))); + + PCollection output = + pipeline + .apply(Create.of(pojoList)) + .apply( + ParDo.of( + new DoFn() { + @FieldAccess("stringSelector") + final FieldAccessDescriptor stringSelector = + FieldAccessDescriptor.withFieldNames("stringField"); + + @FieldAccess("intSelector") + final FieldAccessDescriptor intSelector = + FieldAccessDescriptor.withFieldNames("integerField"); + + @FieldAccess("intsSelector") + final FieldAccessDescriptor intsSelector = + FieldAccessDescriptor.withFieldNames("ints"); + + @ProcessElement + public void process( + @FieldAccess("stringSelector") String stringField, + @FieldAccess("intSelector") int integerField, + @FieldAccess("intsSelector") int[] intArray, + OutputReceiver r) { + r.output( + stringField + ":" + integerField + ":" + Arrays.toString(intArray)); + } + })); + PAssert.that(output).containsInAnyOrder("a:1:[1, 2]", "b:2:[2, 3]", "c:3:[3, 4]"); + pipeline.run(); + } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class NestedForExtraction { + abstract ForExtraction getInner(); + } + + @Test + 
@Category({ValidatesRunner.class, UsesSchema.class}) + public void testSchemaFieldSelectionNested() { + List pojoList = + Lists.newArrayList( + new AutoValue_ParDoSchemaTest_ForExtraction(1, "a", Lists.newArrayList(1, 2)), + new AutoValue_ParDoSchemaTest_ForExtraction(2, "b", Lists.newArrayList(2, 3)), + new AutoValue_ParDoSchemaTest_ForExtraction(3, "c", Lists.newArrayList(3, 4))); + List outerList = + pojoList.stream() + .map(AutoValue_ParDoSchemaTest_NestedForExtraction::new) + .collect(Collectors.toList()); + + PCollection output = + pipeline + .apply(Create.of(outerList)) + .apply( + ParDo.of( + new DoFn() { + + @ProcessElement + public void process( + @FieldAccess("inner.*") ForExtraction extracted, + @FieldAccess("inner") ForExtraction extracted1, + @FieldAccess("inner.stringField") String stringField, + @FieldAccess("inner.integerField") int integerField, + @FieldAccess("inner.ints") List intArray, + OutputReceiver r) { + assertEquals(extracted, extracted1); + assertEquals(stringField, extracted.getStringField()); + assertEquals(integerField, (int) extracted.getIntegerField()); + assertEquals(intArray, extracted.getInts()); + r.output( + extracted.getStringField() + + ":" + + extracted.getIntegerField() + + ":" + + extracted.getInts().toString()); + } + })); + PAssert.that(output).containsInAnyOrder("a:1:[1, 2]", "b:2:[2, 3]", "c:3:[3, 4]"); + pipeline.run(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index 063af02d90ceb..8c80cfc42497f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -25,6 +25,8 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -66,6 +68,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.hamcrest.Matcher; import org.hamcrest.Matchers; import org.joda.time.Instant; @@ -181,7 +184,35 @@ public void testRowParameterWithoutFieldAccess() { @ProcessElement public void process(@Element Row row) {} }.getClass()); - assertThat(sig.processElement().getSchemaElementParameter(), notNullValue()); + assertFalse(sig.processElement().getSchemaElementParameters().isEmpty()); + } + + @Test + public void testMultipleSchemaParameters() { + DoFnSignature sig = + DoFnSignatures.getSignature( + new DoFn() { + @ProcessElement + public void process( + @Element Row row1, + @Timestamp Instant ts, + @Element Row row2, + OutputReceiver o, + @Element Integer intParameter) {} + }.getClass()); + assertEquals(3, sig.processElement().getSchemaElementParameters().size()); + assertEquals(0, sig.processElement().getSchemaElementParameters().get(0).index()); + assertEquals( + TypeDescriptors.rows(), + sig.processElement().getSchemaElementParameters().get(0).elementT()); + assertEquals(1, sig.processElement().getSchemaElementParameters().get(1).index()); + assertEquals( + TypeDescriptors.rows(), + sig.processElement().getSchemaElementParameters().get(1).elementT()); + assertEquals(2, sig.processElement().getSchemaElementParameters().get(2).index()); + assertEquals( + TypeDescriptors.integers(), + sig.processElement().getSchemaElementParameters().get(2).elementT()); } @Test @@ -202,7 +233,7 @@ public void process(@FieldAccess("foo") @Element Row row) {} assertThat(field.getName(), equalTo("fieldAccess")); assertThat(field.get(doFn), equalTo(descriptor)); - assertThat(sig.processElement().getSchemaElementParameter(), 
notNullValue()); + assertFalse(sig.processElement().getSchemaElementParameters().isEmpty()); } @Test diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index b8a2dfbeca1ac..78f035b0e2acb 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -44,6 +44,7 @@ import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFnOutputReceivers; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; @@ -120,6 +121,7 @@ public FnApiDoFnRunner createRunner(Context co this.mainOutputConsumers = (Collection>>) (Collection) context.localNameToConsumer.get(context.mainOutputTag.getId()); + this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload); this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn); this.doFnInvoker.invokeSetup(); @@ -157,7 +159,6 @@ public void output( outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } }; - this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload); } @Override @@ -396,9 +397,9 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { - Row row = context.schemaCoder.getToRowFunction().apply(element()); - return doFnSchemaInformation.getElementParameterSchema().getFromRowFunction().apply(row); + public Object schemaElement(int index) { + SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index); + return converter.apply(element()); } 
@Override @@ -580,7 +581,7 @@ public InputT element(DoFn doFn) { } @Override - public Object schemaElement(DoFn doFn) { + public Object schemaElement(int index) { throw new UnsupportedOperationException("Element parameters are not supported."); }