[FLINK-3848] [table] Add ProjectableTableSource and push projections into BatchTableSourceScan.

This closes apache#2923.
beyond1920 authored and fhueske committed Dec 13, 2016
1 parent 5c86efb commit 5baea3f
Showing 7 changed files with 522 additions and 2 deletions.
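
The change wires a new optimizer rule into the batch rule set so that a projection sitting directly on top of a scan of a ProjectableTableSource is pushed into the scan itself. Conceptually, the rewrite looks as follows (an illustrative plan with hypothetical field names, not output taken from this commit):

    before:  DataSetCalc(select=[name, score])
               BatchTableSourceScan(fields=[id, name, score, comment])
    after:   BatchTableSourceScan(fields=[name, score])

The table source then only has to produce the two referenced fields.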
BatchTableSourceScan.scala
@@ -19,7 +19,8 @@
package org.apache.flink.api.table.plan.nodes.dataset

import org.apache.calcite.plan._
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
import org.apache.flink.api.table.{BatchTableEnvironment, FlinkTypeFactory}
@@ -39,6 +40,11 @@ class BatchTableSourceScan(
flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
}

override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
val rowCnt = metadata.getRowCount(this)
planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
}

override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new BatchTableSourceScan(
cluster,
@@ -48,6 +54,11 @@
)
}

override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
.item("fields", tableSource.getFieldsNames.mkString(", "))
}

override def translateToPlan(
tableEnv: BatchTableEnvironment,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
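Two notes on the methods added above. computeSelfCost prices the scan as (rowCount, rowCount cpu, rowCount * rowSize io), so a scan that has been reduced to fewer fields reports a smaller row size and thus a cheaper plan, which lets the planner prefer the pushed-down alternative. As a rough worked example with assumed numbers (not from this commit): 1000 estimated rows at 32 bytes per row gives a cost of (1000, 1000, 32000); projecting away half of the columns might shrink the row size to 16 bytes and the I/O component to 16000. explainTerms additionally prints the source's field names, which makes the projected scan visible in plan output (e.g. the fields=[...] items in the illustrative plan above).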
FlinkRuleSets.scala
@@ -109,7 +109,9 @@ object FlinkRuleSets {
DataSetSortRule.INSTANCE,
DataSetValuesRule.INSTANCE,
DataSetCorrelateRule.INSTANCE,
BatchTableSourceScanRule.INSTANCE
BatchTableSourceScanRule.INSTANCE,
// project pushdown optimization
PushProjectIntoBatchTableSourceScanRule.INSTANCE
)

/**
PushProjectIntoBatchTableSourceScanRule.scala (new file)
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.plan.rules.dataSet

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.plan.RelOptRule.{none, operand}
import org.apache.flink.api.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc}
import org.apache.flink.api.table.plan.rules.util.RexProgramProjectExtractor._
import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}

/**
* This rule tries to push projections into a BatchTableSourceScan.
*/
class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
operand(classOf[DataSetCalc],
operand(classOf[BatchTableSourceScan], none)),
"PushProjectIntoBatchTableSourceScanRule") {

override def matches(call: RelOptRuleCall): Boolean = {
val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
scan.tableSource match {
case _: ProjectableTableSource[_] => true
case _ => false
}
}

override def onMatch(call: RelOptRuleCall): Unit = {
val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc]
val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]

val usedFields: Array[Int] = extractRefInputFields(calc.calcProgram)

// only rewrite the subtree if not all input fields are accessed, i.e., some fields can be projected away
if (scan.tableSource.getNumberOfFields != usedFields.length) {
val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
val newTableSource = originTableSource.projectFields(usedFields)
val newScan = new BatchTableSourceScan(
scan.getCluster,
scan.getTraitSet,
scan.getTable,
newTableSource.asInstanceOf[BatchTableSource[_]])

val newCalcProgram = rewriteRexProgram(
calc.calcProgram,
newScan.getRowType,
usedFields,
calc.getCluster.getRexBuilder)

// if the program merely forwards its input and has no filter, the DataSetCalc node can be removed
if (newCalcProgram.isTrivial) {
call.transformTo(newScan)
} else {
val newCalc = new DataSetCalc(
calc.getCluster,
calc.getTraitSet,
newScan,
calc.getRowType,
newCalcProgram,
description)
call.transformTo(newCalc)
}
}
}
}

object PushProjectIntoBatchTableSourceScanRule {
val INSTANCE: RelOptRule = new PushProjectIntoBatchTableSourceScanRule
}
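
To make the rule's steps concrete, consider a hypothetical source with fields (id, name, score, comment) and a program that selects id and score and filters on score > 10 (names and values are purely illustrative). extractRefInputFields returns Array(0, 2); since the source has four fields but only two are used, the rule calls projectFields(Array(0, 2)) to obtain a two-field source and builds a new scan over it. rewriteRexProgram then remaps the field references of the Calc program against the narrower row type, so the reference to input field 2 (score) becomes a reference to field 1. Because the filter is still needed, the rewritten program is not trivial and a new DataSetCalc is placed on top of the new scan.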
RexProgramProjectExtractor.scala (new file)
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.plan.rules.util

import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.JavaConverters._

object RexProgramProjectExtractor {

/**
* Extracts the indexes of input fields accessed by the RexProgram.
*
* @param rexProgram The RexProgram to analyze
* @return The indexes of accessed input fields
*/
def extractRefInputFields(rexProgram: RexProgram): Array[Int] = {
val visitor = new RefFieldsVisitor
// extract input fields from project expressions
rexProgram.getProjectList.foreach(exp => rexProgram.expandLocalRef(exp).accept(visitor))
val condition = rexProgram.getCondition
// extract input fields from condition expression
if (condition != null) {
rexProgram.expandLocalRef(condition).accept(visitor)
}
visitor.getFields
}

/**
* Generates a new RexProgram based on mapped input fields.
*
* @param rexProgram original RexProgram
* @param inputRowType input row type
* @param usedInputFields indexes of used input fields
* @param rexBuilder builder for Rex expressions
*
* @return A RexProgram with mapped input field expressions.
*/
def rewriteRexProgram(
rexProgram: RexProgram,
inputRowType: RelDataType,
usedInputFields: Array[Int],
rexBuilder: RexBuilder): RexProgram = {

val inputRewriter = new InputRewriter(usedInputFields)
val newProjectExpressions = rexProgram.getProjectList.map(
exp => rexProgram.expandLocalRef(exp).accept(inputRewriter)
).toList.asJava

val oldCondition = rexProgram.getCondition
val newConditionExpression = {
oldCondition match {
case ref: RexLocalRef => rexProgram.expandLocalRef(ref).accept(inputRewriter)
case _ => null // the program has no condition (getCondition returned null)
}
}
RexProgram.create(
inputRowType,
newProjectExpressions,
newConditionExpression,
rexProgram.getOutputRowType,
rexBuilder
)
}
}

/**
* A RexVisitor to extract used input fields
*/
class RefFieldsVisitor extends RexVisitorImpl[Unit](true) {
private val fields = mutable.LinkedHashSet[Int]()

def getFields: Array[Int] = fields.toArray

override def visitInputRef(inputRef: RexInputRef): Unit = fields += inputRef.getIndex

override def visitCall(call: RexCall): Unit =
call.operands.foreach(operand => operand.accept(this))
}

/**
* A RexShuttle to rewrite field accesses of a RexProgram.
*
* @param fields fields mapping
*/
class InputRewriter(fields: Array[Int]) extends RexShuttle {

/** mapping from old input field index to new input field index */
private val fieldMap: Map[Int, Int] =
fields.zipWithIndex.toMap

override def visitInputRef(inputRef: RexInputRef): RexNode =
new RexInputRef(relNodeIndex(inputRef), inputRef.getType)

override def visitLocalRef(localRef: RexLocalRef): RexNode =
new RexInputRef(relNodeIndex(localRef), localRef.getType)

private def relNodeIndex(ref: RexSlot): Int =
fieldMap.getOrElse(ref.getIndex,
throw new IllegalArgumentException("input field contains invalid index"))
}
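
The index remapping done by InputRewriter is easiest to see in isolation. A minimal, self-contained sketch of the same mapping logic (hypothetical indexes, no Calcite types involved):

    object InputMappingExample {
      def main(args: Array[String]): Unit = {
        // the projection keeps the old input fields 1 and 3
        val usedFields = Array(1, 3)
        // same construction as InputRewriter's fieldMap: old index -> new index
        val fieldMap: Map[Int, Int] = usedFields.zipWithIndex.toMap
        println(fieldMap)    // Map(1 -> 0, 3 -> 1)
        println(fieldMap(3)) // an old reference to $3 is rewritten to $1
      }
    }

An access to a field that was projected away (e.g. index 2 here) has no entry in the map, which is exactly the IllegalArgumentException case in relNodeIndex.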
ProjectableTableSource.scala (new file)
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.sources

/**
* Adds support for projection push-down to a [[TableSource]].
* A [[TableSource]] that extends this trait can project the fields of its returned table.
*
* @tparam T The return type of the [[ProjectableTableSource]].
*/
trait ProjectableTableSource[T] {

/**
* Creates a copy of the [[ProjectableTableSource]] that projects its output on the specified
* fields.
*
* @param fields The indexes of the fields to return.
* @return A copy of the [[ProjectableTableSource]] that projects its output.
*/
def projectFields(fields: Array[Int]): ProjectableTableSource[T]

}
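
A hedged sketch of how a source might implement the new trait (the class name, its constructor parameters, and the field handling below are assumptions for illustration; a real source would also implement BatchTableSource so that the push-down rule can place it into the new scan):

    class SampleProjectableSource(
        fieldNames: Array[String],
        fieldIndexes: Array[Int])
      extends ProjectableTableSource[Array[Any]] {

      // return a copy restricted to the requested fields; the planner calls this
      // from PushProjectIntoBatchTableSourceScanRule with the extracted indexes
      override def projectFields(fields: Array[Int]): SampleProjectableSource =
        new SampleProjectableSource(
          fields.map(fieldNames(_)),
          fields.map(fieldIndexes(_)))
    }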
(Two further changed files did not load and are not shown.)
