Skip to content

Commit

Permalink
Merge pull request apache#10990: [BEAM-9569] disable coder inference …
Browse files Browse the repository at this point in the history
…for rows
  • Loading branch information
reuvenlax committed Mar 23, 2020
1 parent 7310ec2 commit fb59b6a
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -597,6 +598,10 @@ private <T> Coder<T> getCoderFromTypeDescriptor(
throws CannotProvideCoderException {
Type type = typeDescriptor.getType();
Coder<?> coder;
if (typeDescriptor.equals(TypeDescriptors.rows())) {
throw new CannotProvideCoderException(
"Cannot provide a coder for a Beam Row. Please provide a schema instead using PCollection.setRowSchema.");
}
if (typeCoderBindings.containsKey(type)) {
Set<Coder<?>> coders = typeCoderBindings.get(type);
if (coders.size() == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
*
* <p>The result will be a new row schema containing the fields total_cost, top_purchases, and
* transactionDurations, containing the sum of all purchases costs (for that user and country), the
* top ten purchases, and a histogram of transaction durations. The schema will als contain a key
* top ten purchases, and a histogram of transaction durations. The schema will also contain a key
* field, which will be a row containing userId and country.
*
* <p>Note that usually the field type can be automatically inferred from the {@link CombineFn}
Expand Down Expand Up @@ -291,9 +291,9 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> agg
@Override
public PCollection<Iterable<InputT>> expand(PCollection<InputT> input) {
return input
.apply(WithKeys.of((Void) null))
.apply(GroupByKey.create())
.apply(Values.create());
.apply("addNullKey", WithKeys.of((Void) null))
.apply("group", GroupByKey.create())
.apply("extractValues", Values.create());
}
}

Expand All @@ -308,7 +308,7 @@ public static class CombineGlobally<InputT, OutputT>

@Override
public PCollection<OutputT> expand(PCollection<InputT> input) {
return input.apply(Combine.globally(combineFn));
return input.apply("globalCombine", Combine.globally(combineFn));
}
}

Expand Down Expand Up @@ -460,8 +460,8 @@ public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> agg
public PCollection<Row> expand(PCollection<InputT> input) {
SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema());
return input
.apply(Convert.toRows())
.apply(Combine.globally(fn))
.apply("toRows", Convert.toRows())
.apply("Global Combine", Combine.globally(fn))
.setRowSchema(fn.getOutputSchema());
}
}
Expand Down Expand Up @@ -512,7 +512,7 @@ public PCollection<KV<Row, Iterable<Row>>> expand(PCollection<InputT> input) {
"selectKeys",
WithKeys.of((Row e) -> rowSelector.select(e)).withKeyType(TypeDescriptors.rows()))
.setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)))
.apply(GroupByKey.create());
.apply("GroupByKey", GroupByKey.create());
}
}

Expand Down Expand Up @@ -704,8 +704,9 @@ public PCollection<Row> expand(PCollection<InputT> input) {
.build();

return input
.apply(getToKvs())
.apply("ToKvs", getToKvs())
.apply(
"ToRow",
ParDo.of(
new DoFn<KV<Row, Iterable<Row>>, Row>() {
@ProcessElement
Expand Down Expand Up @@ -924,9 +925,10 @@ public PCollection<Row> expand(PCollection<InputT> input) {
.build();

return input
.apply(getByFields().getToKvs())
.apply(Combine.groupedValues(fn))
.apply("ToKvs", getByFields().getToKvs())
.apply("Combine", Combine.groupedValues(fn))
.apply(
"ToRow",
ParDo.of(
new DoFn<KV<Row, Row>, Row>() {
@ProcessElement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@

import java.io.Serializable;
import org.apache.beam.sdk.extensions.sql.impl.transform.BeamSetOperatorsTransforms;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.CoGroup;
import org.apache.beam.sdk.schemas.transforms.CoGroup.By;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;

/**
* Delegate for Set operators: {@code BeamUnionRel}, {@code BeamIntersectRel} and {@code
Expand Down Expand Up @@ -65,6 +63,16 @@ public PCollection<Row> expand(PCollectionList<Row> inputs) {
inputs);
PCollection<Row> leftRows = inputs.get(0);
PCollection<Row> rightRows = inputs.get(1);
Schema leftSchema = leftRows.getSchema();
Schema rightSchema = rightRows.getSchema();
if (!leftSchema.typesEqual(rightSchema)) {
throw new IllegalArgumentException(
"Can't intersect two tables with different schemas."
+ "lhsSchema: "
+ leftSchema
+ " rhsSchema: "
+ rightSchema);
}

WindowFn leftWindow = leftRows.getWindowingStrategy().getWindowFn();
WindowFn rightWindow = rightRows.getWindowingStrategy().getWindowFn();
Expand All @@ -78,25 +86,20 @@ public PCollection<Row> expand(PCollectionList<Row> inputs) {
+ rightWindow);
}

final TupleTag<Row> leftTag = new TupleTag<>();
final TupleTag<Row> rightTag = new TupleTag<>();

// co-group
PCollection<KV<Row, CoGbkResult>> coGbkResultCollection =
KeyedPCollectionTuple.of(
leftTag,
leftRows.apply(
"CreateLeftIndex",
MapElements.via(new BeamSetOperatorsTransforms.BeamSqlRow2KvFn())))
.and(
rightTag,
rightRows.apply(
"CreateRightIndex",
MapElements.via(new BeamSetOperatorsTransforms.BeamSqlRow2KvFn())))
.apply(CoGroupByKey.create());
return coGbkResultCollection.apply(
ParDo.of(
new BeamSetOperatorsTransforms.SetOperatorFilteringDoFn(
leftTag, rightTag, opType, all)));
// TODO: We may want to preaggregate the counts first using Group instead of calling CoGroup and
// measuring the
// iterable size. If on average there are duplicates in the input, this will be faster.
final String lhsTag = "lhs";
final String rhsTag = "rhs";
PCollection<Row> joined =
PCollectionTuple.of(lhsTag, leftRows, rhsTag, rightRows)
.apply("CoGroup", CoGroup.join(By.fieldNames("*")));
return joined
.apply(
"FilterResults",
ParDo.of(
new BeamSetOperatorsTransforms.SetOperatorFilteringDoFn(
lhsTag, rhsTag, opType, all)))
.setRowSchema(joined.getSchema().getField("key").getType().getRowSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform;

import java.util.Iterator;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSetOperatorRelBase;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;

/** Collections of {@code PTransform} and {@code DoFn} used to perform Set operations. */
public abstract class BeamSetOperatorsTransforms {
Expand All @@ -38,61 +35,54 @@ public KV<Row, Row> apply(Row input) {
}

/** Filter function used for Set operators. */
public static class SetOperatorFilteringDoFn extends DoFn<KV<Row, CoGbkResult>, Row> {
private TupleTag<Row> leftTag;
private TupleTag<Row> rightTag;
private BeamSetOperatorRelBase.OpType opType;
public static class SetOperatorFilteringDoFn extends DoFn<Row, Row> {
private final String leftTag;
private final String rightTag;
private final BeamSetOperatorRelBase.OpType opType;
// ALL?
private boolean all;
private final boolean all;

public SetOperatorFilteringDoFn(
TupleTag<Row> leftTag,
TupleTag<Row> rightTag,
BeamSetOperatorRelBase.OpType opType,
boolean all) {
String leftTag, String rightTag, BeamSetOperatorRelBase.OpType opType, boolean all) {
this.leftTag = leftTag;
this.rightTag = rightTag;
this.opType = opType;
this.all = all;
}

@ProcessElement
public void processElement(ProcessContext ctx) {
CoGbkResult coGbkResult = ctx.element().getValue();
Iterable<Row> leftRows = coGbkResult.getAll(leftTag);
Iterable<Row> rightRows = coGbkResult.getAll(rightTag);
public void processElement(@Element Row element, OutputReceiver<Row> o) {
Row key = element.getRow("key");
long numLeftRows = 0;
long numRightRows = 0;
if (!Iterables.isEmpty(element.<Row>getIterable(leftTag))) {
numLeftRows = Iterables.size(element.<Row>getIterable(leftTag));
}
if (!Iterables.isEmpty(element.<Row>getIterable(rightTag))) {
numRightRows = Iterables.size(element.<Row>getIterable(rightTag));
}

switch (opType) {
case UNION:
if (all) {
// output both left & right
Iterator<Row> iter = leftRows.iterator();
while (iter.hasNext()) {
ctx.output(iter.next());
}
iter = rightRows.iterator();
while (iter.hasNext()) {
ctx.output(iter.next());
for (int i = 0; i < numLeftRows + numRightRows; i++) {
o.output(key);
}
} else {
// only output the key
ctx.output(ctx.element().getKey());
o.output(key);
}
break;
case INTERSECT:
if (leftRows.iterator().hasNext() && rightRows.iterator().hasNext()) {
if (numLeftRows > 0 && numRightRows > 0) {
if (all) {
int leftCount = Iterators.size(leftRows.iterator());
int rightCount = Iterators.size(rightRows.iterator());

// Say for Row R, there are m instances on left and n instances on right,
// INTERSECT ALL outputs MIN(m, n) instances of R.
Iterator<Row> iter =
(leftCount <= rightCount) ? leftRows.iterator() : rightRows.iterator();
while (iter.hasNext()) {
ctx.output(iter.next());
for (int i = 0; i < Math.min(numLeftRows, numRightRows); i++) {
o.output(key);
}
} else {
ctx.output(ctx.element().getKey());
o.output(key);
}
}
break;
Expand All @@ -101,27 +91,23 @@ public void processElement(ProcessContext ctx) {
// - EXCEPT ALL outputs MAX(m - n, 0) instances of R.
// - EXCEPT [DISTINCT] outputs a single instance of R if m > 0 and n == 0, else
// they output 0 instances.
if (leftRows.iterator().hasNext() && !rightRows.iterator().hasNext()) {
Iterator<Row> iter = leftRows.iterator();
if (numLeftRows > 0 && numRightRows == 0) {
if (all) {
// output all
while (iter.hasNext()) {
ctx.output(iter.next());
for (int i = 0; i < numLeftRows; i++) {
o.output(key);
}
} else {
// only output one
ctx.output(iter.next());
o.output(key);
}
} else if (leftRows.iterator().hasNext() && rightRows.iterator().hasNext()) {
int leftCount = Iterators.size(leftRows.iterator());
int rightCount = Iterators.size(rightRows.iterator());

int outputCount = leftCount - rightCount;
} else if (numLeftRows > 0 && numRightRows > 0) {
long outputCount = numLeftRows - numRightRows;
if (outputCount > 0) {
if (all) {
while (outputCount > 0) {
outputCount--;
ctx.output(ctx.element().getKey());
o.output(key);
}
}
// Dont output any in DISTINCT (if (!all)) case
Expand Down
Loading

0 comments on commit fb59b6a

Please sign in to comment.