Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-3942] [tableAPI] Add support for INTERSECT #2159

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion docs/apis/table.md
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,30 @@ Table result = left.union(right);
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "a, b, c");
Table result = left.unionAll(right);
{% endhighlight %}
</td>
</tr>

<tr>
<td><strong>Intersect</strong></td>
<td>
<p>Similar to a SQL INTERSECT clause. Intersect returns records that exist in both tables. If a record is present one or both tables more than once, it is returned just once, i.e., the resulting table has no duplicate records. Both tables must have identical field types.</p>
{% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.intersect(right);
{% endhighlight %}
</td>
</tr>

<tr>
<td><strong>IntersectAll</strong></td>
<td>
<p>Similar to a SQL INTERSECT ALL clause. Intersect All returns records that exist in both tables. If a record is present in both tables more than once, it is returned as many times as it is present in both tables, i.e., the resulting table might have duplicate records. Both tables must have identical field types.</p>
{% highlight java %}
Table left = tableEnv.fromDataSet(ds1, "a, b, c");
Table right = tableEnv.fromDataSet(ds2, "d, e, f");
Table result = left.intersectAll(right);
{% endhighlight %}
</td>
</tr>
Expand Down Expand Up @@ -690,6 +714,30 @@ val result = left.union(right);
val left = ds1.toTable(tableEnv, 'a, 'b, 'c);
val right = ds2.toTable(tableEnv, 'a, 'b, 'c);
val result = left.unionAll(right);
{% endhighlight %}
</td>
</tr>

<tr>
<td><strong>Intersect</strong></td>
<td>
<p>Similar to a SQL INTERSECT clause. Intersect returns records that exist in both tables. If a record is present one or both tables more than once, it is returned just once, i.e., the resulting table has no duplicate records. Both tables must have identical field types.</p>
{% highlight scala %}
val left = ds1.toTable(tableEnv, 'a, 'b, 'c);
val right = ds2.toTable(tableEnv, 'e, 'f, 'g);
val result = left.intersect(right);
{% endhighlight %}
</td>
</tr>

<tr>
<td><strong>IntersectAll</strong></td>
<td>
<p>Similar to a SQL INTERSECT ALL clause. Intersect All returns records that exist in both tables. If a record is present in both tables more than once, it is returned as many times as it is present in both tables, i.e., the resulting table might have duplicate records. Both tables must have identical field types.</p>
{% highlight scala %}
val left = ds1.toTable(tableEnv, 'a, 'b, 'c);
val right = ds2.toTable(tableEnv, 'e, 'f, 'g);
val result = left.intersectAll(right);
{% endhighlight %}
</td>
</tr>
Expand Down Expand Up @@ -831,7 +879,7 @@ Among others, the following SQL features are not supported, yet:
- Non-equi joins and Cartesian products
- Result selection by order position (`ORDER BY OFFSET FETCH`)
- Grouping sets
- `INTERSECT` and `EXCEPT` set operations
- `EXCEPT` set operation

*Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products.*

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,33 @@ case class Union(left: LogicalNode, right: LogicalNode, all: Boolean) extends Bi
}
}

case class Intersect(left: LogicalNode, right: LogicalNode, all: Boolean) extends BinaryNode {
override def output: Seq[Attribute] = left.output

override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
left.construct(relBuilder)
right.construct(relBuilder)
relBuilder.intersect(all)
}

override def validate(tableEnv: TableEnvironment): LogicalNode = {
val resolvedIntersect = super.validate(tableEnv).asInstanceOf[Intersect]
if (left.output.length != right.output.length) {
failValidation(s"Intersect two table of different column sizes:" +
s" ${left.output.size} and ${right.output.size}")
}
// allow different column names between tables
val sameSchema = left.output.zip(right.output).forall { case (l, r) =>
l.resultType == r.resultType}
if (!sameSchema) {
failValidation(s"Intersect two table of different schema:" +
s" [${left.output.map(a => (a.name, a.resultType)).mkString(", ")}] and" +
s" [${right.output.map(a => (a.name, a.resultType)).mkString(", ")}]")
}
resolvedIntersect
}
}

case class Join(
left: LogicalNode,
right: LogicalNode,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.plan.nodes.dataset

import org.apache.calcite.plan.{RelOptCost, RelOptPlanner, RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelWriter, BiRel, RelNode}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.runtime.IntersectCoGroupFunction
import org.apache.flink.api.table.typeutils.TypeConverter._
import org.apache.flink.api.table.BatchTableEnvironment

import scala.collection.JavaConverters._
import scala.collection.JavaConversions._

/**
* Flink RelNode which translate Intersect into Join Operator.
*
*/
class DataSetIntersect(
cluster: RelOptCluster,
traitSet: RelTraitSet,
left: RelNode,
right: RelNode,
rowType: RelDataType,
all: Boolean,
ruleDescription: String)
extends BiRel(cluster, traitSet, left, right)
with DataSetRel {

override def deriveRowType() = rowType

override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataSetIntersect(
cluster,
traitSet,
inputs.get(0),
inputs.get(1),
rowType,
all,
ruleDescription
)
}

override def toString: String = {
s"Intersect(intersect: ($intersectSelectionToString))"
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw).item("intersect", intersectSelectionToString)
}

override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val children = this.getInputs
children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) =>
val rowCnt = metadata.getRowCount(child)
val rowSize = this.estimateRowSize(child.getRowType)
cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize))
}
}

override def translateToPlan(
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {

val leftDataSet: DataSet[Any] = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
val rightDataSet: DataSet[Any] = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv)

val coGroupedDs = leftDataSet.coGroup(rightDataSet)

val coGroupOpName = s"intersect: ($intersectSelectionToString)"
val coGroupFunction = new IntersectCoGroupFunction[Any](all)

val intersectDs = coGroupedDs.where("*").equalTo("*")
.`with`(coGroupFunction).name(coGroupOpName)

val config = tableEnv.getConfig
val leftType = leftDataSet.getType

// here we only care about left type information, because we emit records from left dataset
expectedType match {
case None if config.getEfficientTypeUsage =>
intersectDs

case _ =>
val determinedType = determineReturnType(
getRowType,
expectedType,
config.getNullCheck,
config.getEfficientTypeUsage)

// conversion
if (determinedType != leftType) {

val mapFunc = getConversionMapper(
config,
false,
leftType,
determinedType,
"DataSetIntersectConversion",
getRowType.getFieldNames)

val opName = s"convert: (${rowType.getFieldNames.asScala.toList.mkString(", ")})"

intersectDs.map(mapFunc).name(opName)
}
// no conversion necessary, forward
else {
intersectDs
}
}
}

private def intersectSelectionToString: String = {
rowType.getFieldNames.asScala.toList.mkString(", ")
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
package org.apache.flink.api.table.plan.rules

import org.apache.calcite.rel.rules._
import org.apache.calcite.rel.stream.StreamRules
import org.apache.calcite.tools.{RuleSets, RuleSet}
import org.apache.flink.api.table.plan.rules.dataSet._
import org.apache.flink.api.table.plan.rules.datastream._
import org.apache.flink.api.table.plan.rules.datastream.{DataStreamCalcRule, DataStreamScanRule, DataStreamUnionRule}
import scala.collection.JavaConversions._

object FlinkRuleSets {

Expand Down Expand Up @@ -102,6 +100,7 @@ object FlinkRuleSets {
DataSetJoinRule.INSTANCE,
DataSetScanRule.INSTANCE,
DataSetUnionRule.INSTANCE,
DataSetIntersectRule.INSTANCE,
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
BatchTableSourceScanRule.INSTANCE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.plan.rules.dataSet

import org.apache.calcite.plan.{Convention, RelOptRule, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.logical.LogicalIntersect
import org.apache.flink.api.table.plan.nodes.dataset.{DataSetIntersect, DataSetConvention}

class DataSetIntersectRule
extends ConverterRule(
classOf[LogicalIntersect],
Convention.NONE,
DataSetConvention.INSTANCE,
"DataSetIntersectRule")
{

def convert(rel: RelNode): RelNode = {

val intersect: LogicalIntersect = rel.asInstanceOf[LogicalIntersect]
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
val convLeft: RelNode = RelOptRule.convert(intersect.getInput(0), DataSetConvention.INSTANCE)
val convRight: RelNode = RelOptRule.convert(intersect.getInput(1), DataSetConvention.INSTANCE)

new DataSetIntersect(
rel.getCluster,
traitSet,
convLeft,
convRight,
rel.getRowType,
intersect.all,
description)
}
}

object DataSetIntersectRule {
val INSTANCE: RelOptRule = new DataSetIntersectRule
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.runtime

import java.lang.{Iterable => JIterable}

import org.apache.flink.api.common.functions.CoGroupFunction
import org.apache.flink.util.Collector


class IntersectCoGroupFunction[T](all: Boolean) extends CoGroupFunction[T, T, T]{
override def coGroup(first: JIterable[T], second: JIterable[T], out: Collector[T]): Unit = {
if (first == null || second == null) return
val leftIter = first.iterator()
val rightIter = second.iterator()
if (all) {
while (leftIter.hasNext && rightIter.hasNext) {
out.collect(leftIter.next)
rightIter.next
}
} else {
if (leftIter.hasNext && rightIter.hasNext) {
out.collect(leftIter.next)
}
}
}
}
Loading