Skip to content

Commit

Permalink
[BEAM-10028] Add support for the state backed iterable coder to the J…
Browse files Browse the repository at this point in the history
…ava SDK harness. (apache#11746)

* [BEAM-10028] Add support for the state backed iterable coder to the Java SDK harness.

This required threading a translation context through CoderTranslator so that coder translators can access the BeamFnStateClient and the current process bundle instruction id.

* fixup! Address PR comments
  • Loading branch information
lukecwik authored May 21, 2020
1 parent 9cf6f5f commit ffd74b0
Show file tree
Hide file tree
Showing 27 changed files with 600 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Collections;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.coders.AvroGenericCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
Expand All @@ -37,7 +38,8 @@ public byte[] getPayload(AvroGenericCoder from) {
}

@Override
public AvroGenericCoder fromComponents(List<Coder<?>> components, byte[] payload) {
public AvroGenericCoder fromComponents(
List<Coder<?>> components, byte[] payload, TranslationContext context) {
Schema schema = new Schema.Parser().parse(new String(payload, Charsets.UTF_8));
return AvroGenericCoder.of(schema);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@

/** Converts to and from Beam Runner API representations of {@link Coder Coders}. */
public class CoderTranslation {

/**
 * Carries additional parameters, beyond the components and payload, that some coder translators
 * need in order to translate specific coders.
 *
 * <p>Coders backed by the portability state API are an example: their translators require
 * additional parameters.
 */
public interface TranslationContext {
  /**
   * The default translation context containing no additional parameters.
   *
   * <p>Translation code in this class compares a context against this instance by reference, so
   * pass this exact instance when no custom context is needed.
   */
  TranslationContext DEFAULT = new DefaultTranslationContext();
}

/**
 * A convenient class representing a default context containing no additional parameters; it is
 * the implementation behind {@link TranslationContext#DEFAULT}.
 */
private static class DefaultTranslationContext implements TranslationContext {}

// This URN says that the coder is just a UDF blob this SDK understands
// TODO: standardize such things
public static final String JAVA_SERIALIZED_CODER_URN = "beam:coders:javasdk:0.1";
Expand Down Expand Up @@ -115,21 +131,29 @@ private static RunnerApi.Coder toCustomCoder(Coder<?> coder) throws IOException
.build();
}

public static Coder<?> fromProto(RunnerApi.Coder protoCoder, RehydratedComponents components)
public static Coder<?> fromProto(
RunnerApi.Coder protoCoder, RehydratedComponents components, TranslationContext context)
throws IOException {
String coderSpecUrn = protoCoder.getSpec().getUrn();
if (coderSpecUrn.equals(JAVA_SERIALIZED_CODER_URN)) {
return fromCustomCoder(protoCoder);
}
return fromKnownCoder(protoCoder, components);
return fromKnownCoder(protoCoder, components, context);
}

private static Coder<?> fromKnownCoder(RunnerApi.Coder coder, RehydratedComponents components)
private static Coder<?> fromKnownCoder(
RunnerApi.Coder coder, RehydratedComponents components, TranslationContext context)
throws IOException {
String coderUrn = coder.getSpec().getUrn();
List<Coder<?>> coderComponents = new ArrayList<>();
for (String componentId : coder.getComponentCoderIdsList()) {
Coder<?> innerCoder = components.getCoder(componentId);
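// Only resolve component coders through RehydratedComponents when the default translation
// context is in use. With a custom translation context, translate each component directly
// so that the custom context is propagated to it.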
Coder<?> innerCoder =
context == TranslationContext.DEFAULT
? components.getCoder(componentId)
: fromProto(
components.getComponents().getCodersOrThrow(componentId), components, context);
coderComponents.add(innerCoder);
}
Class<? extends Coder> coderType = KNOWN_CODER_URNS.inverse().get(coderUrn);
Expand All @@ -139,7 +163,8 @@ private static Coder<?> fromKnownCoder(RunnerApi.Coder coder, RehydratedComponen
"Unknown Coder URN %s. Known URNs: %s",
coderUrn,
KNOWN_CODER_URNS.values());
return translator.fromComponents(coderComponents, coder.getSpec().getPayload().toByteArray());
return translator.fromComponents(
coderComponents, coder.getSpec().getPayload().toByteArray(), context);
}

private static Coder<?> fromCustomCoder(RunnerApi.Coder protoCoder) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.runners.core.construction;

import java.util.List;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.coders.Coder;

/**
Expand All @@ -41,6 +42,9 @@ default byte[] getPayload(T from) {
return new byte[0];
}

/** Create a {@link Coder} from its component {@link Coder coders}. */
T fromComponents(List<Coder<?>> components, byte[] payload);
/**
* Create a {@link Coder} from its component {@link Coder coders} using the specified translation
* context.
*/
T fromComponents(List<Coder<?>> components, byte[] payload, TranslationContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.List;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
Expand Down Expand Up @@ -139,7 +140,7 @@ public byte[] getPayload(WindowedValue.ParamWindowedValueCoder<?> from) {

@Override
public WindowedValue.ParamWindowedValueCoder<?> fromComponents(
List<Coder<?>> components, byte[] payload) {
List<Coder<?>> components, byte[] payload, TranslationContext context) {
return WindowedValue.ParamWindowedValueCoder.fromComponents(components, payload);
}
};
Expand All @@ -158,7 +159,8 @@ public byte[] getPayload(RowCoder from) {
}

@Override
public RowCoder fromComponents(List<Coder<?>> components, byte[] payload) {
public RowCoder fromComponents(
List<Coder<?>> components, byte[] payload, TranslationContext context) {
checkArgument(
components.isEmpty(), "Expected empty component list, but received: " + components);
Schema schema;
Expand All @@ -175,7 +177,8 @@ public RowCoder fromComponents(List<Coder<?>> components, byte[] payload) {
public abstract static class SimpleStructuredCoderTranslator<T extends Coder<?>>
implements CoderTranslator<T> {
@Override
public final T fromComponents(List<Coder<?>> components, byte[] payload) {
public final T fromComponents(
List<Coder<?>> components, byte[] payload, TranslationContext context) {
return fromComponents(components);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.service.AutoService;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdk.coders.BooleanCoder;
Expand Down Expand Up @@ -97,7 +98,15 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar {
CoderTranslator.class.getSimpleName(),
Sets.difference(BEAM_MODEL_CODER_URNS.keySet(), BEAM_MODEL_CODERS.keySet()));
checkState(
ModelCoders.urns().equals(BEAM_MODEL_CODER_URNS.values()),
Sets.symmetricDifference(
ModelCoders.urns(),
// The state backed iterable coder implementation is environment specific and hence
// is not part of the coder translation checks, as these are meant to be used only
// during pipeline construction.
Collections.singleton(ModelCoders.STATE_BACKED_ITERABLE_CODER_URN))
.equals(BEAM_MODEL_CODER_URNS.values()),
"All Model %ss should have an associated java %s",
Coder.class.getSimpleName(),
Coder.class.getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.util.Set;
Expand Down Expand Up @@ -59,6 +60,14 @@ private ModelCoders() {}

public static final String ROW_CODER_URN = getUrn(StandardCoders.Enum.ROW);

/**
 * URN of the state backed iterable coder.
 *
 * <p>The value is spelled out as a literal; the static check below verifies that it matches the
 * URN declared for {@code StandardCoders.Enum.STATE_BACKED_ITERABLE}.
 */
public static final String STATE_BACKED_ITERABLE_CODER_URN =
    "beam:coder:state_backed_iterable:v1";

static {
  // Fail class initialization with a descriptive message (instead of a bare
  // IllegalStateException) if the literal above drifts from the declared URN.
  String declaredUrn = getUrn(StandardCoders.Enum.STATE_BACKED_ITERABLE);
  checkState(
      STATE_BACKED_ITERABLE_CODER_URN.equals(declaredUrn),
      "STATE_BACKED_ITERABLE_CODER_URN is %s but StandardCoders.Enum.STATE_BACKED_ITERABLE declares %s",
      STATE_BACKED_ITERABLE_CODER_URN,
      declaredUrn);
}

private static final Set<String> MODEL_CODER_URNS =
ImmutableSet.of(
BYTES_CODER_URN,
Expand All @@ -74,7 +83,8 @@ private ModelCoders() {}
WINDOWED_VALUE_CODER_URN,
DOUBLE_CODER_URN,
ROW_CODER_URN,
PARAM_WINDOWED_VALUE_CODER_URN);
PARAM_WINDOWED_VALUE_CODER_URN,
STATE_BACKED_ITERABLE_CODER_URN);

public static Set<String> urns() {
return MODEL_CODER_URNS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.values.PCollection;
Expand Down Expand Up @@ -81,7 +82,8 @@ public class RehydratedComponents {
public Coder<?> load(String id) throws Exception {
@Nullable RunnerApi.Coder coder = components.getCodersOrDefault(id, null);
checkState(coder != null, "No coder with id '%s' in serialized components", id);
return CoderTranslation.fromProto(coder, RehydratedComponents.this);
return CoderTranslation.fromProto(
coder, RehydratedComponents.this, TranslationContext.DEFAULT);
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.avro.SchemaBuilder;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.BooleanCoder;
Expand Down Expand Up @@ -165,7 +166,9 @@ public void toAndFromProto() throws Exception {
Components encodedComponents = sdkComponents.toComponents();
Coder<?> decodedCoder =
CoderTranslation.fromProto(
coderProto, RehydratedComponents.forComponents(encodedComponents));
coderProto,
RehydratedComponents.forComponents(encodedComponents),
TranslationContext.DEFAULT);
assertThat(decodedCoder, equalTo(coder));

if (KNOWN_CODERS.contains(coder)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardCoders;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.ByteCoder;
import org.apache.beam.sdk.coders.Coder;
Expand Down Expand Up @@ -407,7 +408,7 @@ private static Coder<?> instantiateCoder(CommonCoder coder) {
checkNotNull(
translator, "No translator found for common coder class: " + coderType.getSimpleName());

return translator.fromComponents(components, coder.getPayload());
return translator.fromComponents(components, coder.getPayload(), new TranslationContext() {});
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
Expand Down Expand Up @@ -200,7 +201,8 @@ public void toTransformProto() throws Exception {
Coder<?> timerCoder =
CoderTranslation.fromProto(
components.getCodersOrThrow(timerFamilySpec.getTimerFamilyCoderId()),
rehydratedComponents);
rehydratedComponents,
TranslationContext.DEFAULT);
assertEquals(
org.apache.beam.runners.core.construction.Timer.Coder.of(
VarLongCoder.of(), GlobalWindow.Coder.INSTANCE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
Expand Down Expand Up @@ -213,7 +214,9 @@ private static Coder<?> getAccumulatorCoder(AppliedPTransform<?, ?, ?> transform
.orElseThrow(() -> new IOException("Transform does not contain an AccumulatorCoder"));
Components components = sdkComponents.toComponents();
return CoderTranslation.fromProto(
components.getCodersOrThrow(id), RehydratedComponents.forComponents(components));
components.getCodersOrThrow(id),
RehydratedComponents.forComponents(components),
TranslationContext.DEFAULT);
}

private static Optional<CombinePayload> getCombinePayload(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.BeamUrns;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory;
import org.apache.beam.runners.core.construction.EmptyFlattenAsCreateFactory;
import org.apache.beam.runners.core.construction.Environments;
Expand Down Expand Up @@ -1531,7 +1532,8 @@ private Coder<T> getCoder() throws IOException {
(Coder)
CoderTranslation.fromProto(
coderSpec.getCoder(),
RehydratedComponents.forComponents(coderSpec.getComponents()));
RehydratedComponents.forComponents(coderSpec.getComponents()),
TranslationContext.DEFAULT);
}
return coder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;
Expand Down Expand Up @@ -136,11 +137,15 @@ public static CacheKey create(FunctionSpec windowMappingFn, BoundedWindow mainWi
outboundCoder =
(Coder)
CoderTranslation.fromProto(
components.getCodersOrThrow(mainInputWindowCoderId), rehydratedComponents);
components.getCodersOrThrow(mainInputWindowCoderId),
rehydratedComponents,
TranslationContext.DEFAULT);
inboundCoder =
(Coder)
CoderTranslation.fromProto(
components.getCodersOrThrow(sideInputWindowCoderId), rehydratedComponents);
components.getCodersOrThrow(sideInputWindowCoderId),
rehydratedComponents,
TranslationContext.DEFAULT);
} catch (IOException e) {
throw new IllegalStateException(
"Unable to create side input window mapping process bundle specification.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,25 @@ public Coder<T> getElemCoder() {
/**
 * Builds an instance of {@code IterableT}, this coder's associated {@link Iterable}-like subtype,
 * from a list of decoded elements.
 *
 * <p>Override {@link #decodeToIterable(List, long, InputStream)} if you need access to the
 * terminator value and the {@link InputStream}.
 *
 * @param decodedElements the elements that were decoded
 * @return the {@code IterableT} built from {@code decodedElements}
 */
protected abstract IterableT decodeToIterable(List<T> decodedElements);

/**
 * Builds an instance of {@code IterableT} from the decoded elements, additionally receiving the
 * terminator value and the {@link InputStream}, which is positioned where this coder detected the
 * end of the stream.
 *
 * <p>This base implementation always throws; subclasses that understand non zero terminator
 * values are expected to override it.
 *
 * @param decodedElements the elements that were decoded
 * @param terminatorValue the terminator value read from the stream
 * @param in the stream the elements were decoded from
 * @throws IllegalStateException always, in this base implementation
 * @throws IOException declared so that overriding implementations may read from {@code in}
 */
protected IterableT decodeToIterable(
    List<T> decodedElements, long terminatorValue, InputStream in) throws IOException {
  String unsupportedTerminatorMessage =
      String.format(
          "%s does not support non zero terminator values. Received stream with terminator %s.",
          iterableName, terminatorValue);
  throw new IllegalStateException(unsupportedTerminatorMessage);
}

/////////////////////////////////////////////////////////////////////////////
// Internal operations below here.

Expand Down Expand Up @@ -136,7 +152,11 @@ public IterableT decode(InputStream inStream) throws IOException, CoderException
count = VarInt.decodeLong(dataInStream);
}
}
return decodeToIterable(elements);
if (count == 0) {
return decodeToIterable(elements);
} else {
return decodeToIterable(elements, count, inStream);
}
}

@Override
Expand Down
Loading

0 comments on commit ffd74b0

Please sign in to comment.