From 9e01c5a1b05c29a8ada57c913180bfdcc7522579 Mon Sep 17 00:00:00 2001 From: Andrew Crites Date: Thu, 2 Apr 2020 14:11:06 -0700 Subject: [PATCH] [BEAM-9624] Adds Convert to Accumulators operator for use in combiner lifting for streaming pipelines (#11271) * Adds convert_to_accumulators URN and associated implementations for Go/Java/Python SDKs. * Fixes copy-paste error in proto comments. * Fixes spelling error in comment. * Add test for Java ConvertToAccumulators. * Adds a test for Go and extra checking in Java translation. * Changes number of parameters to createRunnerForPTransform since signature changed. * Changes ConvertToAccumulators test to only test the ConvertToAccumulators phase. --- .../src/main/proto/beam_runner_api.proto | 8 +++ .../construction/PTransformTranslation.java | 5 ++ sdks/go/pkg/beam/core/runtime/exec/combine.go | 27 ++++++++ .../beam/core/runtime/exec/combine_test.go | 41 +++++++++++- .../pkg/beam/core/runtime/exec/translate.go | 7 ++- .../beam/fn/harness/CombineRunners.java | 15 +++++ .../beam/fn/harness/CombineRunnersTest.java | 63 +++++++++++++++++++ .../runners/worker/bundle_processor.py | 14 +++++ .../apache_beam/transforms/combiners.py | 6 ++ 9 files changed, 182 insertions(+), 4 deletions(-) diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index 17229d2e261fd..ed84e5bbbff2e 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -342,6 +342,14 @@ message StandardPTransforms { // https://s.apache.org/beam-runner-api-combine-model#heading=h.aj86ew4v1wk // Payload: CombinePayload COMBINE_GROUPED_VALUES = 3 [(beam_urn) = "beam:transform:combine_grouped_values:v1"]; + + // Represents the Convert To Accumulators transform, as described in the + // following document: + // https://s.apache.org/beam-runner-api-combine-model#heading=h.h5697l1scd9x + // Payload: CombinePayload + COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS = 4 + [(beam_urn) = + "beam:transform:combine_per_key_convert_to_accumulators:v1"]; } // Payload for all of these: ParDoPayload containing the user's SDF enum SplittableParDoComponents { diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java index 6fead0fbea8c0..1ff1d6940c353 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformTranslation.java @@ -101,6 +101,8 @@ public class PTransformTranslation { "beam:transform:combine_per_key_merge_accumulators:v1"; public static final String COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN = "beam:transform:combine_per_key_extract_outputs:v1"; + public static final String COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS_TRANSFORM_URN = + "beam:transform:combine_per_key_convert_to_accumulators:v1"; public static final String COMBINE_GROUPED_VALUES_TRANSFORM_URN = "beam:transform:combine_grouped_values:v1"; @@ -167,6 +169,9 @@ public class PTransformTranslation { checkState( COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN.equals( getUrn(CombineComponents.COMBINE_PER_KEY_EXTRACT_OUTPUTS))); + checkState( + COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS_TRANSFORM_URN.equals( + getUrn(CombineComponents.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS)));
checkState( COMBINE_GROUPED_VALUES_TRANSFORM_URN.equals( getUrn(CombineComponents.COMBINE_GROUPED_VALUES))); diff --git a/sdks/go/pkg/beam/core/runtime/exec/combine.go b/sdks/go/pkg/beam/core/runtime/exec/combine.go index 7815ac127f064..0d5432a31a83d 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/combine.go +++ b/sdks/go/pkg/beam/core/runtime/exec/combine.go @@ -499,3 +499,30 @@ func (n *ExtractOutput) ProcessElement(ctx context.Context, value *FullValue, va } return n.Out.ProcessElement(n.Combine.ctx, &FullValue{Windows: value.Windows, Elm: value.Elm, Elm2: out, Timestamp: value.Timestamp}) } + +// ConvertToAccumulators is an executor for converting an input value to an accumulator value. +type ConvertToAccumulators struct { + *Combine +} + +func (n *ConvertToAccumulators) String() string { + return fmt.Sprintf("ConvertToAccumulators[%v] Keyed:%v Out:%v", path.Base(n.Fn.Name()), n.UsesKey, n.Out.ID()) +} + +// ProcessElement accepts an input value and returns an accumulator containing that one value. +func (n *ConvertToAccumulators) ProcessElement(ctx context.Context, value *FullValue, values ...ReStream) error { + if n.status != Active { + return errors.Errorf("invalid status for combine convert %v: %v", n.UID, n.status) + } + a, err := n.newAccum(n.Combine.ctx, value.Elm) + if err != nil { + return n.fail(err) + } + + first := true + a, err = n.addInput(n.Combine.ctx, a, value.Elm, value.Elm2, value.Timestamp, first) + if err != nil { + return n.fail(err) + } + return n.Out.ProcessElement(n.Combine.ctx, &FullValue{Windows: value.Windows, Elm: value.Elm, Elm2: a, Timestamp: value.Timestamp}) +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/combine_test.go b/sdks/go/pkg/beam/core/runtime/exec/combine_test.go index 49f2d457a2b16..e49f9d4c4e21d 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/combine_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/combine_test.go @@ -34,6 +34,7 @@ import ( ) var intInput = []interface{}{int(1), int(2), int(3), int(4), int(5), int(6)} +var int64Input = []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5), int64(6)} var strInput = []interface{}{"1", "2", "3", "4", "5", "6"} var tests = []struct { @@ -113,6 +114,42 @@ func TestLiftedCombine(t *testing.T) { } +// TestConvertToAccumulators verifies that the ConvertToAccumulators phase +// correctly doesn't accumulate values at all. 
+func TestConvertToAccumulators(t *testing.T) { + tests := []struct { + Fn interface{} + AccumCoder *coder.Coder + Input []interface{} + Expected []interface{} + }{ + {Fn: mergeFn, AccumCoder: intCoder(reflectx.Int), Input: intInput, Expected: intInput}, + {Fn: nonBinaryMergeFn, AccumCoder: intCoder(reflectx.Int), Input: intInput, Expected: intInput}, + {Fn: &MyCombine{}, AccumCoder: intCoder(reflectx.Int64), Input: intInput, Expected: int64Input}, + {Fn: &MyOtherCombine{}, AccumCoder: intCoder(reflectx.Int64), Input: intInput, Expected: int64Input}, + {Fn: &MyThirdCombine{}, AccumCoder: intCoder(reflectx.Int), Input: strInput, Expected: intInput}, + {Fn: &MyContextCombine{}, AccumCoder: intCoder(reflectx.Int64), Input: intInput, Expected: int64Input}, + {Fn: &MyErrorCombine{}, AccumCoder: intCoder(reflectx.Int64), Input: intInput, Expected: int64Input}, + } + for _, test := range tests { + t.Run(fnName(test.Fn), func(t *testing.T) { + edge := getCombineEdge(t, test.Fn, reflectx.Int, test.AccumCoder) + + testKey := 42 + out := &CaptureNode{UID: 1} + convertToAccumulators := &ConvertToAccumulators{Combine: &Combine{UID: 2, Fn: edge.CombineFn, Out: out}} + n := &FixedRoot{UID: 3, Elements: makeKVInput(testKey, test.Input...), Out: convertToAccumulators} + + constructAndExecutePlan(t, []Unit{n, convertToAccumulators, out}) + + expected := makeKVValues(testKey, test.Expected...) + if !equalList(out.Elements, expected) { + t.Errorf("convertToAccumulators(%s) = %#v, want %#v", edge.CombineFn.Name(), extractKeyedValues(out.Elements...), extractKeyedValues(expected...)) + } + }) + } +} + type codable interface { EncodeMe() []byte DecodeMe([]byte) @@ -200,7 +237,7 @@ func constructAndExecutePlan(t *testing.T, us []Unit) { // MergeAccumulators(a, b AccumT) AccumT // ExtractOutput(v AccumT) OutputT // -// In addition, depending there can be three distinct types, depending on where +// In addition, there can be three distinct types, depending on where // they are used in the combine, the input type, InputT, the output type, OutputT, // and the accumulator type AccumT. Depending on the equality of the types, one // or more of the methods can be unspecified. @@ -264,7 +301,7 @@ func (*MyOtherCombine) ExtractOutput(a int64) string { return fmt.Sprintf("%d", a) } -// MyThirdCombine has parses strings as Input, and doesn't specify an ExtractOutput +// MyThirdCombine parses strings as Input, and doesn't specify an ExtractOutput // // InputT == string // AccumT == int diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 7a0381948d465..7041b455e2915 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -43,6 +43,7 @@ const ( urnPerKeyCombinePre = "beam:transform:combine_per_key_precombine:v1" urnPerKeyCombineMerge = "beam:transform:combine_per_key_merge_accumulators:v1" urnPerKeyCombineExtract = "beam:transform:combine_per_key_extract_outputs:v1" + urnPerKeyCombineConvert = "beam:transform:combine_per_key_convert_to_accumulators:v1" ) // UnmarshalPlan converts a model bundle descriptor into an execution Plan. 
@@ -332,7 +333,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { var u Node switch urn { - case graphx.URNParDo, graphx.URNJavaDoFn, urnPerKeyCombinePre, urnPerKeyCombineMerge, urnPerKeyCombineExtract: + case graphx.URNParDo, graphx.URNJavaDoFn, urnPerKeyCombinePre, urnPerKeyCombineMerge, urnPerKeyCombineExtract, urnPerKeyCombineConvert: var data string switch urn { case graphx.URNParDo: @@ -341,7 +342,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { return nil, errors.Wrapf(err, "invalid ParDo payload for %v", transform) } data = string(pardo.GetDoFn().GetPayload()) - case urnPerKeyCombinePre, urnPerKeyCombineMerge, urnPerKeyCombineExtract: + case urnPerKeyCombinePre, urnPerKeyCombineMerge, urnPerKeyCombineExtract, urnPerKeyCombineConvert: var cmb pb.CombinePayload if err := proto.Unmarshal(payload, &cmb); err != nil { return nil, errors.Wrapf(err, "invalid CombinePayload payload for %v", transform) @@ -426,6 +427,8 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { u = &MergeAccumulators{Combine: cn} case urnPerKeyCombineExtract: u = &ExtractOutput{Combine: cn} + case urnPerKeyCombineConvert: + u = &ConvertToAccumulators{Combine: cn} default: // For unlifted combines u = cn } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java index feb6b6b76faba..88be53692ba88 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java @@ -65,6 +65,8 @@ public Map getPTransformRunnerFactories() { MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction), PTransformTranslation.COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN, MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction), + PTransformTranslation.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS_TRANSFORM_URN, + MapFnRunners.forValueMapFnFactory(CombineRunners::createConvertToAccumulatorsMapFunction), PTransformTranslation.COMBINE_GROUPED_VALUES_TRANSFORM_URN, MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction)); } @@ -210,6 +212,19 @@ ThrowingFunction, KV> createExtractOutputsMapFun KV.of(input.getKey(), combineFn.extractOutput(input.getValue())); } + static + ThrowingFunction, KV> createConvertToAccumulatorsMapFunction( + String pTransformId, PTransform pTransform) throws IOException { + CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload()); + CombineFn combineFn = + (CombineFn) + SerializableUtils.deserializeFromByteArray( + combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn"); + + return (KV input) -> + KV.of(input.getKey(), combineFn.addInput(combineFn.createAccumulator(), input.getValue())); + } + static ThrowingFunction>, KV> createCombineGroupedValuesMapFunction(String pTransformId, PTransform pTransform) diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java index 0b0002f72b869..10239cc9f4ce9 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java @@ -320,6 +320,69 @@ public void testExtractOutputs() throws Exception { valueInGlobalWindow(KV.of("C", -7)))); 
} + /** + * Create a Convert To Accumulators function that is given keyed accumulators and validates that + * the input values were turned into the accumulator type. + */ + @Test + public void testConvertToAccumulators() throws Exception { + // Create a map of consumers and an output target to check output values. + MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); + PCollectionConsumerRegistry consumers = + new PCollectionConsumerRegistry( + metricsContainerRegistry, mock(ExecutionStateTracker.class)); + Deque>> mainOutputValues = new ArrayDeque<>(); + consumers.register( + Iterables.getOnlyElement(pTransform.getOutputsMap().values()), + TEST_COMBINE_ID, + (FnDataReceiver) + (FnDataReceiver>>) mainOutputValues::add); + + PTransformFunctionRegistry startFunctionRegistry = + new PTransformFunctionRegistry( + mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start"); + PTransformFunctionRegistry finishFunctionRegistry = + new PTransformFunctionRegistry( + mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + + // Create runner. + MapFnRunners.forValueMapFnFactory(CombineRunners::createConvertToAccumulatorsMapFunction) + .createRunnerForPTransform( + PipelineOptionsFactory.create(), + null, + null, + TEST_COMBINE_ID, + pTransform, + null, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + consumers, + startFunctionRegistry, + finishFunctionRegistry, + null, + null, + null); + + assertThat(startFunctionRegistry.getFunctions(), empty()); + assertThat(finishFunctionRegistry.getFunctions(), empty()); + + // Send elements to runner and check outputs. + mainOutputValues.clear(); + assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId)); + + FnDataReceiver> input = consumers.getMultiplexingConsumer(inputPCollectionId); + input.accept(valueInGlobalWindow(KV.of("A", "9"))); + input.accept(valueInGlobalWindow(KV.of("B", "5"))); + input.accept(valueInGlobalWindow(KV.of("C", "7"))); + + assertThat( + mainOutputValues, + contains( + valueInGlobalWindow(KV.of("A", 9)), + valueInGlobalWindow(KV.of("B", 5)), + valueInGlobalWindow(KV.of("C", 7)))); + } /** * Create a Combine Grouped Values function that is given lists of values that are grouped by key * and validates that the lists are properly combined. 
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 918069529bc37..ef70815ad811b 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -1666,6 +1666,20 @@ def create_combine_per_key_extract_outputs( factory, transform_id, transform_proto, payload, consumers, 'extract') +@BeamTransformFactory.register_urn( + common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, + beam_runner_api_pb2.CombinePayload) +def create_combine_per_key_convert_to_accumulators( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, # type: beam_runner_api_pb2.CombinePayload + consumers # type: Dict[str, List[operations.Operation]] +): + return _create_combine_phase_operation( + factory, transform_id, transform_proto, payload, consumers, 'convert') + + @BeamTransformFactory.register_urn( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 0c882784ee492..d50c69e16a7de 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -878,6 +878,8 @@ def __init__(self, phase, fn, args, kwargs): self.apply = self.merge_only elif phase == 'extract': self.apply = self.extract_only + elif phase == 'convert': + self.apply = self.convert_to_accumulator else: raise ValueError('Unexpected phase: %s' % phase) @@ -894,6 +896,10 @@ def merge_only(self, accumulators): def extract_only(self, accumulator): return self.combine_fn.extract_output(accumulator) + def convert_to_accumulator(self, element): + return self.combine_fn.add_input( + self.combine_fn.create_accumulator(), element) + class Latest(object): """Combiners for computing the latest element"""
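Editor's note: the sketch below is not part of the patch; it only illustrates the semantics the diff adds. The new beam:transform:combine_per_key_convert_to_accumulators:v1 phase sits alongside the existing lifted-combine phases (precombine, merge_accumulators, extract_outputs). Rather than accumulating across elements, it wraps each input element in its own accumulator (create_accumulator followed by a single add_input), which is what the Go ProcessElement, the Java createConvertToAccumulatorsMapFunction, and the Python convert_to_accumulator method above all do, presumably so a streaming runner can skip buffering for the precombine step and still hand accumulators to the downstream merge. The names ToySumFn and convert_to_accumulator in the sketch are hypothetical stand-ins, not Beam APIs.

class ToySumFn(object):
  """A minimal CombineFn-like object that sums integers."""

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    return accumulator + element

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator


def convert_to_accumulator(fn, element):
  # The convert phase: each element becomes a one-element accumulator;
  # nothing is combined across elements at this point.
  return fn.add_input(fn.create_accumulator(), element)


if __name__ == '__main__':
  fn = ToySumFn()
  elements = [1, 2, 3, 4]
  # Convert: one accumulator per element (values unchanged for a sum).
  accumulators = [convert_to_accumulator(fn, e) for e in elements]
  assert accumulators == [1, 2, 3, 4]
  # A runner can later merge the accumulators and extract the output,
  # yielding the same result as combining the raw elements directly.
  assert fn.extract_output(fn.merge_accumulators(accumulators)) == 10

For keyed input the real implementations keep the key and replace only the value with the accumulator (see the Go ProcessElement and the Java map function above); the sketch drops keys to stay minimal.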