Skip to content

Commit

Permalink
Add support for FULL OUTER JOIN
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhun committed Apr 29, 2015
1 parent 7c44f84 commit 4e74609
Show file tree
Hide file tree
Showing 19 changed files with 340 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
import com.google.common.primitives.Ints;
import io.airlift.slice.XxHash64;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.longs.AbstractLongIterator;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongIterator;

import java.util.Arrays;
import java.util.List;

import static com.facebook.presto.operator.SyntheticAddress.decodePosition;
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.airlift.slice.SizeOf.sizeOfBooleanArray;
import static io.airlift.slice.SizeOf.sizeOfIntArray;

// This implementation assumes arrays used in the hash are always a power of 2
Expand All @@ -41,6 +44,7 @@ public final class InMemoryJoinHash
private final int channelCount;
private final int mask;
private final int[] key;
private final boolean[] keyVisited;
private final int[] positionLinks;
private final long size;
private final List<Type> hashTypes;
Expand All @@ -54,10 +58,11 @@ public InMemoryJoinHash(LongArrayList addresses, List<Type> hashTypes, PagesHash

// reserve memory for the arrays
int hashSize = HashCommon.arraySize(addresses.size(), 0.75f);
size = sizeOfIntArray(hashSize) + sizeOfIntArray(addresses.size());
size = sizeOfIntArray(hashSize) + sizeOfBooleanArray(hashSize) + sizeOfIntArray(addresses.size());

mask = hashSize - 1;
key = new int[hashSize];
keyVisited = new boolean[hashSize];
Arrays.fill(key, -1);

this.positionLinks = new int[addresses.size()];
Expand Down Expand Up @@ -111,6 +116,7 @@ public long getJoinPosition(int position, Page page, int rawHash)

while (key[pos] != -1) {
if (positionEqualsCurrentRow(key[pos], position, page.getBlocks())) {
keyVisited[pos] = true;
return key[pos];
}
// increment position and mask to handler wrap around
Expand All @@ -125,6 +131,56 @@ public final long getNextJoinPosition(long currentPosition)
return positionLinks[Ints.checkedCast(currentPosition)];
}

@Override
public LongIterator getUnvisitedJoinPositions()
{
return new UnvisitedJoinPositionIterator();
}

public class UnvisitedJoinPositionIterator extends AbstractLongIterator
{
private int nextKeyId = 0;
private long nextJoinPosition = -1;

private UnvisitedJoinPositionIterator()
{
findUnvisitedKeyId();
}

@Override
public long nextLong()
{
long result = nextJoinPosition;

nextJoinPosition = getNextJoinPosition(nextJoinPosition);
if (nextJoinPosition < 0) {
nextKeyId++;
findUnvisitedKeyId();
}

return result;
}

@Override
public boolean hasNext()
{
return nextKeyId < keyVisited.length;
}

private void findUnvisitedKeyId()
{
while (nextKeyId < keyVisited.length) {
if (key[nextKeyId] != -1 && !keyVisited[nextKeyId]) {
break;
}
nextKeyId++;
}
if (nextKeyId < keyVisited.length) {
nextJoinPosition = key[nextKeyId];
}
}
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
*/
package com.facebook.presto.operator;

import com.facebook.presto.operator.LookupJoinOperators.JoinType;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.longs.LongIterator;

import java.io.Closeable;
import java.util.List;

import static com.facebook.presto.operator.LookupJoinOperators.JoinType.FULL_OUTER;
import static com.facebook.presto.operator.LookupJoinOperators.JoinType.INNER;
import static com.facebook.presto.util.MoreFutures.tryGetUnchecked;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -33,8 +37,9 @@ public class LookupJoinOperator

private final OperatorContext operatorContext;
private final JoinProbeFactory joinProbeFactory;
private final boolean enableOuterJoin;
private final JoinType joinType;
private final List<Type> types;
private final List<Type> probeTypes;
private final PageBuilder pageBuilder;

private LookupSource lookupSource;
Expand All @@ -43,11 +48,13 @@ public class LookupJoinOperator
private boolean finishing;
private long joinPosition = -1;

private LongIterator unvisitedJoinPositions;

public LookupJoinOperator(
OperatorContext operatorContext,
LookupSourceSupplier lookupSourceSupplier,
List<Type> probeTypes,
boolean enableOuterJoin,
JoinType joinType,
JoinProbeFactory joinProbeFactory)
{
this.operatorContext = checkNotNull(operatorContext, "operatorContext is null");
Expand All @@ -58,12 +65,13 @@ public LookupJoinOperator(

this.lookupSourceFuture = lookupSourceSupplier.getLookupSource(operatorContext);
this.joinProbeFactory = joinProbeFactory;
this.enableOuterJoin = enableOuterJoin;
this.joinType = joinType;

this.types = ImmutableList.<Type>builder()
.addAll(probeTypes)
.addAll(lookupSourceSupplier.getTypes())
.build();
this.probeTypes = probeTypes;
this.pageBuilder = new PageBuilder(types);
}

Expand All @@ -88,7 +96,11 @@ public void finish()
@Override
public boolean isFinished()
{
boolean finished = finishing && probe == null && pageBuilder.isEmpty();
boolean finished =
finishing &&
probe == null &&
pageBuilder.isEmpty() &&
(joinType != FULL_OUTER || (unvisitedJoinPositions != null && !unvisitedJoinPositions.hasNext()));

// if finished drop references so memory is freed early
if (finished) {
Expand Down Expand Up @@ -139,6 +151,14 @@ public void addInput(Page page)
@Override
public Page getOutput()
{
// If needsInput was never called, lookupSource has not been initialized so far.
if (lookupSource == null) {
lookupSource = tryGetUnchecked(lookupSourceFuture);
if (lookupSource == null) {
return null;
}
}

// join probe page with the lookup source
if (probe != null) {
while (joinCurrentPosition()) {
Expand All @@ -151,6 +171,10 @@ public Page getOutput()
}
}

if (joinType == FULL_OUTER && finishing && probe == null) {
buildSideOuterJoinUnvisitedPositions();
}

// only flush full pages unless we are done
if (pageBuilder.isFull() || (finishing && !pageBuilder.isEmpty() && probe == null)) {
Page page = pageBuilder.build();
Expand Down Expand Up @@ -205,7 +229,7 @@ private boolean advanceProbePosition()

private boolean outerJoinCurrentPosition()
{
if (enableOuterJoin && joinPosition < 0) {
if (joinType != INNER && joinPosition < 0) {
// write probe columns
pageBuilder.declarePosition();
probe.appendTo(pageBuilder);
Expand All @@ -222,4 +246,28 @@ private boolean outerJoinCurrentPosition()
}
return true;
}

private void buildSideOuterJoinUnvisitedPositions()
{
if (unvisitedJoinPositions == null) {
unvisitedJoinPositions = lookupSource.getUnvisitedJoinPositions();
}

while (unvisitedJoinPositions.hasNext()) {
long buildSideOuterJoinPosition = unvisitedJoinPositions.nextLong();
pageBuilder.declarePosition();

// write nulls into probe columns
for (int probeChannel = 0; probeChannel < probeTypes.size(); probeChannel++) {
pageBuilder.getBlockBuilder(probeChannel).appendNull();
}

// write build columns
lookupSource.appendTo(buildSideOuterJoinPosition, pageBuilder, probeTypes.size());

if (pageBuilder.isFull()) {
return;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.operator;

import com.facebook.presto.operator.LookupJoinOperators.JoinType;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;

Expand All @@ -26,21 +27,21 @@ public class LookupJoinOperatorFactory
private final int operatorId;
private final LookupSourceSupplier lookupSourceSupplier;
private final List<Type> probeTypes;
private final boolean enableOuterJoin;
private final JoinType joinType;
private final List<Type> types;
private final JoinProbeFactory joinProbeFactory;
private boolean closed;

public LookupJoinOperatorFactory(int operatorId,
LookupSourceSupplier lookupSourceSupplier,
List<Type> probeTypes,
boolean enableOuterJoin,
JoinType joinType,
JoinProbeFactory joinProbeFactory)
{
this.operatorId = operatorId;
this.lookupSourceSupplier = lookupSourceSupplier;
this.probeTypes = probeTypes;
this.enableOuterJoin = enableOuterJoin;
this.joinType = joinType;

this.joinProbeFactory = joinProbeFactory;

Expand All @@ -61,7 +62,7 @@ public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, LookupJoinOperator.class.getSimpleName());
return new LookupJoinOperator(operatorContext, lookupSourceSupplier, probeTypes, enableOuterJoin, joinProbeFactory);
return new LookupJoinOperator(operatorContext, lookupSourceSupplier, probeTypes, joinType, joinProbeFactory);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@

public class LookupJoinOperators
{
public enum JoinType {
INNER,
PROBE_OUTER,
FULL_OUTER,
}

private LookupJoinOperators()
{
}
Expand All @@ -29,12 +35,16 @@ private LookupJoinOperators()

public static OperatorFactory innerJoin(int operatorId, LookupSourceSupplier lookupSourceSupplier, List<? extends Type> probeTypes, List<Integer> probeJoinChannel, Optional<Integer> probeHashChannel)
{
OperatorFactory operatorFactory = JOIN_PROBE_COMPILER.compileJoinOperatorFactory(operatorId, lookupSourceSupplier, probeTypes, probeJoinChannel, probeHashChannel, false);
return operatorFactory;
return JOIN_PROBE_COMPILER.compileJoinOperatorFactory(operatorId, lookupSourceSupplier, probeTypes, probeJoinChannel, probeHashChannel, JoinType.INNER);
}

public static OperatorFactory probeOuterJoin(int operatorId, LookupSourceSupplier lookupSourceSupplier, List<? extends Type> probeTypes, List<Integer> probeJoinChannel, Optional<Integer> probeHashChannel)
{
return JOIN_PROBE_COMPILER.compileJoinOperatorFactory(operatorId, lookupSourceSupplier, probeTypes, probeJoinChannel, probeHashChannel, JoinType.PROBE_OUTER);
}

public static OperatorFactory outerJoin(int operatorId, LookupSourceSupplier lookupSourceSupplier, List<? extends Type> probeTypes, List<Integer> probeJoinChannel, Optional<Integer> probeHashChannel)
public static OperatorFactory fullOuterJoin(int operatorId, LookupSourceSupplier lookupSourceSupplier, List<? extends Type> probeTypes, List<Integer> probeJoinChannel, Optional<Integer> probeHashChannel)
{
return JOIN_PROBE_COMPILER.compileJoinOperatorFactory(operatorId, lookupSourceSupplier, probeTypes, probeJoinChannel, probeHashChannel, true);
return JOIN_PROBE_COMPILER.compileJoinOperatorFactory(operatorId, lookupSourceSupplier, probeTypes, probeJoinChannel, probeHashChannel, JoinType.FULL_OUTER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import it.unimi.dsi.fastutil.longs.LongIterator;

import java.io.Closeable;

Expand All @@ -33,6 +34,8 @@ public interface LookupSource

void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset);

LongIterator getUnvisitedJoinPositions();

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import it.unimi.dsi.fastutil.longs.LongIterator;

import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.NotThreadSafe;
Expand Down Expand Up @@ -389,6 +390,12 @@ public long getNextJoinPosition(long currentPosition)
return IndexSnapshot.UNLOADED_INDEX_KEY;
}

@Override
public LongIterator getUnvisitedJoinPositions()
{
throw new UnsupportedOperationException();
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import it.unimi.dsi.fastutil.longs.LongIterator;

import javax.annotation.concurrent.NotThreadSafe;

Expand Down Expand Up @@ -80,6 +81,12 @@ public long getNextJoinPosition(long currentPosition)
return nextPosition;
}

@Override
public LongIterator getUnvisitedJoinPositions()
{
throw new UnsupportedOperationException();
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand Down
Loading

0 comments on commit 4e74609

Please sign in to comment.