Skip to content

Commit

Permalink
Merge pull request apache#10983: [BEAM-9393] Support schemas in state…
Browse files Browse the repository at this point in the history
… API
  • Loading branch information
reuvenlax committed Apr 28, 2020
2 parents 0276d85 + 06fe9a7 commit 74a6565
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.values.Row;

/** Static methods for working with {@link StateSpec StateSpecs}. */
@Experimental(Kind.STATE)
Expand All @@ -42,14 +45,22 @@ private StateSpecs() {}
/**
* Create a {@link StateSpec} for a single value of type {@code T}.
*
* <p>This method attempts to infer the accumulator coder automatically.
* <p>This method attempts to infer the value coder automatically.
*
* <p>If the value type has a schema registered, then the schema will be used to encode the
* values.
*
* @see #value(Coder)
*/
public static <T> StateSpec<ValueState<T>> value() {
return new ValueStateSpec<>(null);
}

/** Create a {@link StateSpec} for a row value with the specified schema. */
public static StateSpec<ValueState<Row>> rowValue(Schema schema) {
return value(RowCoder.of(schema));
}

/**
* Identical to {@link #value()}, but with a coder explicitly supplied.
*
Expand Down Expand Up @@ -129,12 +140,25 @@ StateSpec<CombiningState<InputT, AccumT, OutputT>> combining(
*
* <p>This method attempts to infer the element coder automatically.
*
* <p>If the element type has a schema registered, then the schema will be used to encode the
* values.
*
* @see #bag(Coder)
*/
public static <T> StateSpec<BagState<T>> bag() {
return new BagStateSpec<>(null);
}

/**
* Create a {@link StateSpec} for a {@link BagState}, optimized for adding values frequently and
* occasionally retrieving all the values that have been added.
*
* <p>This method is for storing row elements with the given schema.
*/
public static StateSpec<BagState<Row>> rowBag(Schema schema) {
return new BagStateSpec<>(RowCoder.of(schema));
}

/**
* Identical to {@link #bag()}, but with an element coder explicitly supplied.
*
Expand All @@ -149,12 +173,24 @@ public static <T> StateSpec<BagState<T>> bag(Coder<T> elemCoder) {
*
* <p>This method attempts to infer the element coder automatically.
*
* <p>If the element type has a schema registered, then the schema will be used to encode the
* values.
*
* @see #set(Coder)
*/
public static <T> StateSpec<SetState<T>> set() {
return new SetStateSpec<>(null);
}

/**
* Create a {@link StateSpec} for a {@link SetState}, optimized for checking membership.
*
* <p>This method is for storing row elements with the given schema.
*/
public static StateSpec<SetState<Row>> rowSet(Schema schema) {
return new SetStateSpec<>(RowCoder.of(schema));
}

/**
* Identical to {@link #set()}, but with an element coder explicitly supplied.
*
Expand All @@ -169,12 +205,27 @@ public static <T> StateSpec<SetState<T>> set(Coder<T> elemCoder) {
*
* <p>This method attempts to infer the key and value coders automatically.
*
* <p>If the key and value types have schemas registered, then the schemas will be used to encode
* the elements.
*
* @see #map(Coder, Coder)
*/
public static <K, V> StateSpec<MapState<K, V>> map() {
return new MapStateSpec<>(null, null);
}

/**
* Create a {@link StateSpec} for a {@link SetState}, optimized for key lookups and writes.
*
* <p>This method is for storing maps where both the keys and the values are rows with the
* specified schemas.
*
* @see #map(Coder, Coder)
*/
public static StateSpec<MapState<Row, Row>> rowMap(Schema keySchema, Schema valueSchema) {
return new MapStateSpec<>(RowCoder.of(keySchema), RowCoder.of(valueSchema));
}

/**
* Identical to {@link #map()}, but with key and value coders explicitly supplied.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,22 @@ private static <T> DisplayData.ItemSpec<? extends Class<?>> displayDataForFn(T f
}

private static void finishSpecifyingStateSpecs(
DoFn<?, ?> fn, CoderRegistry coderRegistry, Coder<?> inputCoder) {
DoFn<?, ?> fn,
CoderRegistry coderRegistry,
SchemaRegistry schemaRegistry,
Coder<?> inputCoder) {
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
Map<String, DoFnSignature.StateDeclaration> stateDeclarations = signature.stateDeclarations();
for (DoFnSignature.StateDeclaration stateDeclaration : stateDeclarations.values()) {
try {
StateSpec<?> stateSpec = (StateSpec<?>) stateDeclaration.field().get(fn);
stateSpec.offerCoders(codersForStateSpecTypes(stateDeclaration, coderRegistry, inputCoder));
Coder[] coders;
try {
coders = schemasForStateSpecTypes(stateDeclaration, schemaRegistry);
} catch (NoSuchSchemaException e) {
coders = codersForStateSpecTypes(stateDeclaration, coderRegistry, inputCoder);
}
stateSpec.offerCoders(coders);
stateSpec.finishSpecifying();
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -494,6 +503,32 @@ private static FieldAccessDescriptor getFieldAccessDescriptorFromParameter(
return fieldAccessDescriptor.resolve(inputSchema);
}

private static SchemaCoder[] schemasForStateSpecTypes(
DoFnSignature.StateDeclaration stateDeclaration, SchemaRegistry schemaRegistry)
throws NoSuchSchemaException {
Type stateType = stateDeclaration.stateType().getType();

if (!(stateType instanceof ParameterizedType)) {
// No type arguments means no coders to infer.
return new SchemaCoder[0];
}

Type[] typeArguments = ((ParameterizedType) stateType).getActualTypeArguments();
SchemaCoder[] coders = new SchemaCoder[typeArguments.length];

for (int i = 0; i < typeArguments.length; i++) {
Type typeArgument = typeArguments[i];
TypeDescriptor typeDescriptor = TypeDescriptor.of(typeArgument);
coders[i] =
SchemaCoder.of(
schemaRegistry.getSchema(typeDescriptor),
typeDescriptor,
schemaRegistry.getToRowFunction(typeDescriptor),
schemaRegistry.getFromRowFunction(typeDescriptor));
}
return coders;
}

/**
* Try to provide coders for as many of the type arguments of given {@link
* DoFnSignature.StateDeclaration} as possible.
Expand Down Expand Up @@ -741,8 +776,8 @@ public MultiOutput<InputT, OutputT> withOutputTags(
@Override
public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry();
CoderRegistry registry = input.getPipeline().getCoderRegistry();
finishSpecifyingStateSpecs(fn, registry, input.getCoder());
CoderRegistry coderRegistry = input.getPipeline().getCoderRegistry();
finishSpecifyingStateSpecs(fn, coderRegistry, schemaRegistry, input.getCoder());
TupleTag<OutputT> mainOutput = new TupleTag<>(MAIN_OUTPUT_TAG);
PCollection<OutputT> res =
input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput);
Expand All @@ -757,7 +792,7 @@ public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
} catch (NoSuchSchemaException e) {
try {
res.setCoder(
registry.getCoder(
coderRegistry.getCoder(
outputTypeDescriptor,
getFn().getInputTypeDescriptor(),
((PCollection<InputT>) input).getCoder()));
Expand Down Expand Up @@ -895,8 +930,9 @@ public PCollectionTuple expand(PCollection<? extends InputT> input) {
validateWindowType(input, fn);

// Use coder registry to determine coders for all StateSpec defined in the fn signature.
CoderRegistry registry = input.getPipeline().getCoderRegistry();
finishSpecifyingStateSpecs(fn, registry, input.getCoder());
CoderRegistry coderRegistry = input.getPipeline().getCoderRegistry();
SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry();
finishSpecifyingStateSpecs(fn, coderRegistry, schemaRegistry, input.getCoder());

DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
if (signature.usesState() || signature.usesTimers()) {
Expand All @@ -923,7 +959,7 @@ public PCollectionTuple expand(PCollection<? extends InputT> input) {
try {
out.setCoder(
(Coder)
registry.getCoder(
coderRegistry.getCoder(
out.getTypeDescriptor(), getFn().getInputTypeDescriptor(), inputCoder));
} catch (CannotProvideCoderException e) {
// Ignore and let coder inference happen later.
Expand Down
Loading

0 comments on commit 74a6565

Please sign in to comment.