Skip to content

Commit

Permalink
Validate side input parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
Reuven Lax committed Aug 25, 2019
1 parent 0e23ca3 commit 83dc627
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.MethodWithExtraParameters;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
Expand Down Expand Up @@ -437,6 +438,29 @@ private static void validateStateApplicableForInput(DoFn<?, ?> fn, PCollection<?
}
}

/**
 * Checks every {@code @SideInput} parameter declared on the {@code @ProcessElement} method of
 * {@code fn} against the side inputs that were supplied by tag.
 *
 * <p>For each declared parameter, a view must be registered under the parameter's tag id, and the
 * type descriptor reported by that view's {@code ViewFn} must be equal to the parameter's declared
 * type.
 *
 * @param sideInputs the supplied side input views, keyed by tag id
 * @param fn the {@link DoFn} whose {@code @ProcessElement} signature is inspected
 * @throws IllegalArgumentException if a declared side input has no view registered under its tag,
 *     or if the view's type differs from the declared parameter type
 */
private static void validateSideInputTypes(
    Map<String, PCollectionView<?>> sideInputs, DoFn<?, ?> fn) {
  DoFnSignature.ProcessElementMethod processElement =
      DoFnSignatures.getSignature(fn.getClass()).processElement();

  for (SideInputParameter declared : processElement.getSideInputParameters()) {
    final String tagId = declared.sideInputId();
    final PCollectionView<?> suppliedView = sideInputs.get(tagId);

    // A declared side input must have a view registered under the same tag.
    checkArgument(
        suppliedView != null,
        "the ProcessElement method expects a side input identified with the tag %s, but no such"
            + " side input was supplied. Use withSideInput(String, PCollectionView) to supply"
            + " this side input.",
        tagId);

    // Only exact type equality is accepted for now, even where the types would be convertible.
    final TypeDescriptor<?> suppliedType = suppliedView.getViewFn().getTypeDescriptor();
    checkArgument(
        suppliedType.equals(declared.elementT()),
        "Side Input with tag %s and type %s cannot be bound to ProcessElement parameter with type %s",
        tagId,
        suppliedType,
        declared.elementT());
  }
}

private static FieldAccessDescriptor getFieldAccessDescriptorFromParameter(
@Nullable String fieldAccessString,
Schema inputSchema,
Expand Down Expand Up @@ -865,6 +889,8 @@ public PCollectionTuple expand(PCollection<? extends InputT> input) {
validateStateApplicableForInput(fn, input);
}

validateSideInputTypes(sideInputs, fn);

// TODO: We should validate OutputReceiver<Row> only happens if the output PCollection
// has a schema. However coder/schema inference may not have happened yet at this point.
// Need to figure out where to validate this.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ public abstract class ViewFn<PrimitiveViewT, ViewT> implements Serializable {
/**
 * A function to adapt a primitive view type to a desired view type.
 *
 * @param primitiveViewT the primitive view value to adapt
 * @return the adapted value, of the desired view type {@code ViewT}
 */
public abstract ViewT apply(PrimitiveViewT primitiveViewT);

/**
 * Returns the {@link TypeDescriptor} describing the output of this fn, i.e. the {@code ViewT}
 * type produced by {@link #apply}.
 *
 * @return the type descriptor for {@code ViewT}
 */
public abstract TypeDescriptor<ViewT> getTypeDescriptor();
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
Expand Down Expand Up @@ -760,6 +761,14 @@ public List<SchemaElementParameter> getSchemaElementParameters() {
.collect(Collectors.toList());
}

/**
 * Returns the {@link SideInputParameter}s found among this method's extra parameters, in the
 * order they appear in {@code extraParameters()}.
 *
 * <p>The result is never {@code null}; when the method declares no side input parameters the
 * returned list is empty, so callers can iterate over it directly.
 *
 * @return the side input parameters, possibly empty
 */
public List<SideInputParameter> getSideInputParameters() {
  return extraParameters().stream()
      .filter(SideInputParameter.class::isInstance)
      .map(SideInputParameter.class::cast)
      .collect(Collectors.toList());
}

/** The {@link OutputReceiverParameter} for a main output, or null if there is none. */
@Nullable
public OutputReceiverParameter getMainOutputReceiver() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,76 @@ public void testParDoWithSideInputs() {
pipeline.run();
}

/**
 * A {@code DoFn} whose {@code @ProcessElement} method declares a {@code @SideInput} parameter is
 * expected to be rejected with an {@link IllegalArgumentException} when no side input is supplied
 * for that tag.
 */
@Test
@Category({NeedsRunner.class, UsesSideInputs.class})
public void testSideInputAnnotationFailedValidationMissing() {
  // SideInput tag id
  final String sideInputTag1 = "tag1";

  DoFn<Integer, List<Integer>> fn =
      new DoFn<Integer, List<Integer>>() {
        @ProcessElement
        public void processElement(@SideInput(sideInputTag1) String tag1) {}
      };

  thrown.expect(IllegalArgumentException.class);
  // No withSideInput(...) call here: the tag declared above is deliberately left unbound.
  // The resulting PCollection is not needed, so it is not captured.
  pipeline.apply("Create main input", Create.of(2)).apply(ParDo.of(fn));
  pipeline.run();
}

/**
 * A singleton side input of type {@code Integer} bound to a {@code @SideInput} parameter declared
 * as {@code String} is expected to be rejected with an {@link IllegalArgumentException}.
 */
@Test
@Category({NeedsRunner.class, UsesSideInputs.class})
public void testSideInputAnnotationFailedValidationSingletonType() {

  final PCollectionView<Integer> sideInput1 =
      pipeline
          .apply("CreateSideInput1", Create.of(2))
          .apply("ViewSideInput1", View.asSingleton());

  // SideInput tag id
  final String sideInputTag1 = "tag1";

  DoFn<Integer, List<Integer>> fn =
      new DoFn<Integer, List<Integer>>() {
        @ProcessElement
        // Declared as String on purpose: the supplied view holds an Integer.
        public void processElement(@SideInput(sideInputTag1) String tag1) {}
      };

  thrown.expect(IllegalArgumentException.class);
  // The resulting PCollection is not needed, so it is not captured.
  pipeline
      .apply("Create main input", Create.of(2))
      .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1));
  pipeline.run();
}

/**
 * A list side input of type {@code List<Integer>} bound to a {@code @SideInput} parameter
 * declared as {@code List<String>} is expected to be rejected with an {@link
 * IllegalArgumentException}.
 */
@Test
@Category({NeedsRunner.class, UsesSideInputs.class})
public void testSideInputAnnotationFailedValidationListType() {

  final PCollectionView<List<Integer>> sideInput1 =
      pipeline
          .apply("CreateSideInput1", Create.of(2, 1, 0))
          .apply("ViewSideInput1", View.asList());

  // SideInput tag id
  final String sideInputTag1 = "tag1";

  DoFn<Integer, List<Integer>> fn =
      new DoFn<Integer, List<Integer>>() {
        @ProcessElement
        // Declared as List<String> on purpose: the supplied view holds a List<Integer>.
        public void processElement(@SideInput(sideInputTag1) List<String> tag1) {}
      };

  thrown.expect(IllegalArgumentException.class);
  // The resulting PCollection is not needed, so it is not captured.
  pipeline
      .apply("Create main input", Create.of(2))
      .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1));
  pipeline.run();
}

@Test
@Category({ValidatesRunner.class, UsesSideInputs.class})
public void testSideInputAnnotation() {
Expand Down

0 comments on commit 83dc627

Please sign in to comment.