Skip to content

Commit

Permalink
Merge pull request apache#10151: [BEAM-7116] Remove use of KV in Sche…
Browse files Browse the repository at this point in the history
…ma transforms
  • Loading branch information
reuvenlax committed Dec 9, 2019
1 parent 194d1e7 commit 9501152
Show file tree
Hide file tree
Showing 10 changed files with 698 additions and 464 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -56,59 +57,61 @@
*
* <p>This transform has similarities to {@link CoGroupByKey}, however works on PCollections that
* have schemas. This allows users of the transform to simply specify schema fields to join on. The
* output type of the transform is a {@code KV<Row, Row>} where the value contains one field for
* every input PCollection and the key represents the fields that were joined on. By default the
* cross product is not expanded, so all fields in the output row are array fields.
* output type of the transform is {@code Row} that contains one row field for the key and an ITERABLE
* field for each input containing the rows that joined on that key; by default the cross product is
* not expanded, but the cross product can be optionally expanded. By default the key field is named
* "key" (the name can be overridden using withKeyField) and has index 0. The tags in the
* PCollectionTuple control the names of the value fields in the Row.
*
* <p>For example, the following demonstrates joining three PCollections on the "user" and "country"
* fields:
*
* <pre>{@code PCollection<KV<Row, Row>> joined =
* <pre>{@code PCollection<Row> joined =
* PCollectionTuple.of("input1", input1, "input2", input2, "input3", input3)
* .apply(CoGroup.join(By.fieldNames("user", "country")));
* }</pre>
*
* <p>In the above case, the key schema will contain the two string fields "user" and "country"; in
* this case, the schemas for Input1, Input2, Input3 must all have fields named "user" and
* "country". The value schema will contain three array of Row fields named "input1" "input2" and
* "input3". The value Row contains all inputs that came in on any of the inputs for that key.
* "country". The remainder of the Row will contain three iterable of Row fields named "input1"
* "input2" and "input3". This contains all inputs that came in on any of the inputs for that key.
* Standard join types (inner join, outer join, etc.) can be accomplished by expanding the cross
* product of these arrays in various ways.
* product of these iterables in various ways.
*
* <p>To put it in other words, the key schema is convertible to the following POJO:
*
* <pre>{@code @DefaultSchema(JavaFieldSchema.class)
* public class JoinedKey {
* public String user;
* public String country;
* }
*
* PCollection<JoinedKey> keys = joined
* .apply(Keys.create())
* .apply(Convert.to(JoinedKey.class));
* }</pre>
*
* <p>The value schema is convertible to the following POJO:
* <p>The value schema is convertible to the following POJO:
*
* <pre>{@code @DefaultSchema(JavaFieldSchema.class)
* public class JoinedValue {
* // The below lists contain all values from each of the three inputs that match on the given
* // key.
* public List<Input1Type> input1;
* public List<Input2Type> input2;
* public List<Input3Type> input3;
* }
* <pre>{@code @DefaultSchema(JavaFieldSchema.class)
* public class JoinedValue {
* public JoinedKey key;
* // The below lists contain all values from each of the three inputs that match on the given
* // key.
* public Iterable<Input1Type> input1;
* public Iterable<Input2Type> input2;
* public Iterable<Input3Type> input3;
* }
*
* PCollection<JoinedValue> values = joined.apply(Convert.to(JoinedValue.class));
*
* PCollection<JoinedValue> values = joined
* .apply(Values.create())
* .apply(Convert.to(JoinedValue.class));
* PCollection<JoinedKey> keys = values
* .apply(Select.fieldNames("key"))
* .apply(Convert.to(JoinedKey.class));
* }</pre>
*
*
*
* <p>It's also possible to join between different fields in two inputs, as long as the types of
* those fields match. In this case, fields must be specified for every input PCollection. For
* example:
*
* <pre>{@code PCollection<KV<Row, Row>> joined
* <pre>{@code PCollection<Row> joined
* = PCollectionTuple.of("input1Tag", input1, "input2Tag", input2)
* .apply(CoGroup
* .join("input1Tag", By.fieldNames("referringUser")))
Expand Down Expand Up @@ -191,7 +194,7 @@
* {@link CoGroup} transform supports any number of inputs, and optional participation can be
* specified on any subset of them.
*
* <p>Do note that cross-product joins while simpler and easier to program, can cause
* <p>Do note that cross-product joins while simpler and easier to program, can cause performance problems.
*/
@Experimental(Experimental.Kind.SCHEMAS)
public class CoGroup {
Expand Down Expand Up @@ -423,15 +426,25 @@ static void verify(PCollectionTuple input, JoinArguments joinArgs) {
}

/** The implementing PTransform. */
public static class Impl extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {
public static class Impl extends PTransform<PCollectionTuple, PCollection<Row>> {
private final JoinArguments joinArgs;
private final String keyFieldName;

private Impl() {
this(new JoinArguments(Collections.emptyMap()));
}

private Impl(JoinArguments joinArgs) {
this(joinArgs, "key");
}

private Impl(JoinArguments joinArgs, String keyFieldName) {
this.joinArgs = joinArgs;
this.keyFieldName = keyFieldName;
}

public Impl withKeyField(String keyFieldName) {
return new Impl(joinArgs, keyFieldName);
}

/**
Expand All @@ -443,85 +456,92 @@ public Impl join(String tag, By clause) {
if (joinArgs.allInputsJoinArgs != null) {
throw new IllegalStateException("Cannot set both a global and per-tag fields.");
}
return new Impl(joinArgs.with(tag, clause));
return new Impl(joinArgs.with(tag, clause), keyFieldName);
}

/** Expand the join into individual rows, similar to SQL joins. */
public ExpandCrossProduct crossProductJoin() {
return new ExpandCrossProduct(joinArgs);
}

private Schema getOutputSchema(JoinInformation joinInformation) {
// Construct the output schema. It contains one field for each input PCollection, of type
// ARRAY[ROW].
Schema.Builder joinedSchemaBuilder = Schema.builder();
for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
joinedSchemaBuilder.addArrayField(entry.getKey(), FieldType.row(entry.getValue()));
}
return joinedSchemaBuilder.build();
}

@Override
public PCollection<KV<Row, Row>> expand(PCollectionTuple input) {
public PCollection<Row> expand(PCollectionTuple input) {
verify(input, joinArgs);

JoinInformation joinInformation =
JoinInformation.from(input, joinArgs::getFieldAccessDescriptor);

Schema joinedSchema = getOutputSchema(joinInformation);

ConvertToRow convertToRow = new ConvertToRow(joinInformation, keyFieldName);
return joinInformation
.keyedPCollectionTuple
.apply("CoGroupByKey", CoGroupByKey.create())
.apply(
"ConvertToRow",
ParDo.of(
new ConvertToRow(
joinInformation.sortedTags,
joinInformation.toRows,
joinedSchema,
joinInformation.tagToKeyedTag)))
.setCoder(
KvCoder.of(SchemaCoder.of(joinInformation.keySchema), SchemaCoder.of(joinedSchema)));
.apply("ConvertToRow", ParDo.of(convertToRow))
.setRowSchema(convertToRow.getOutputSchema());
}

// Used by the unexpanded join to create the output rows.
private static class ConvertToRow extends DoFn<KV<Row, CoGbkResult>, KV<Row, Row>> {
private static class ConvertToRow extends DoFn<KV<Row, CoGbkResult>, Row> {
private final List<String> sortedTags;
private final Map<Integer, SerializableFunction<Object, Row>> toRows;
private final Map<Integer, String> tagToKeyedTag;
private final Schema joinedSchema;
private final Map<Integer, SerializableFunction<Object, Row>> toRows;

ConvertToRow(
List<String> sortedTags,
Map<Integer, SerializableFunction<Object, Row>> toRows,
Schema joinedSchema,
Map<Integer, String> tagToKeyedTag) {
this.sortedTags = sortedTags;
this.toRows = toRows;
this.joinedSchema = joinedSchema;
this.tagToKeyedTag = tagToKeyedTag;
private final Schema outputSchema;

ConvertToRow(JoinInformation joinInformation, String keyFieldName) {
this.sortedTags = joinInformation.sortedTags;
this.tagToKeyedTag = joinInformation.tagToKeyedTag;
this.toRows = joinInformation.toRows;
Schema.Builder schemaBuilder =
Schema.builder().addRowField(keyFieldName, joinInformation.keySchema);
for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
schemaBuilder.addIterableField(entry.getKey(), FieldType.row(entry.getValue()));
}
outputSchema = schemaBuilder.build();
}

Schema getOutputSchema() {
return outputSchema;
}

/** Lazy iterable that wraps the result returned from CoGroupByKey. */
static final class Result implements Iterable<Row> {
private Iterable<Row> coGbkIterable;
private SerializableFunction<Object, Row> toRow;

Result(Iterable<Row> coGbkIterable, SerializableFunction<Object, Row> toRow) {
this.coGbkIterable = coGbkIterable;
this.toRow = toRow;
}

@Override
public Iterator<Row> iterator() {
return new Iterator<Row>() {
private Iterator<Row> coGbkIterator = coGbkIterable.iterator();

@Override
public boolean hasNext() {
return coGbkIterator.hasNext();
}

@Override
public Row next() {
return toRow.apply(coGbkIterator.next());
}
};
}
}

@ProcessElement
public void process(@Element KV<Row, CoGbkResult> kv, OutputReceiver<KV<Row, Row>> o) {
public void process(@Element KV<Row, CoGbkResult> kv, OutputReceiver<Row> o) {
Row key = kv.getKey();
CoGbkResult result = kv.getValue();
List<Object> fields = Lists.newArrayListWithCapacity(sortedTags.size());
for (int i = 0; i < sortedTags.size(); ++i) {
String tag = sortedTags.get(i);
// TODO: This forces the entire join to materialize in memory. We should create a
// lazy Row interface on top of the iterable returned by CoGbkResult. This will
// allow the data to be streamed in. Tracked in [BEAM-6756].
SerializableFunction<Object, Row> toRow = toRows.get(i);
String tupleTag = tagToKeyedTag.get(i);
List<Row> joined = Lists.newArrayList();
for (Object item : result.getAll(tupleTag)) {
joined.add(toRow.apply(item));
}
fields.add(joined);
SerializableFunction<Object, Row> toRow = toRows.get(i);
fields.add(new Result(result.getAll(tupleTag), toRow));
}
o.output(KV.of(key, Row.withSchema(joinedSchema).addValues(fields).build()));
Row row = Row.withSchema(outputSchema).addValue(key).addValues(fields).build();
o.output(row);
}
}
}
Expand Down
Loading

0 comments on commit 9501152

Please sign in to comment.