Skip to content

Commit

Permalink
[FLINK-10974][table] Add support for flatMap to table API
Browse files Browse the repository at this point in the history
This closes apache#7196.
  • Loading branch information
dianfu authored and sunjincheng121 committed Apr 19, 2019
1 parent e41b6d4 commit 58e69a0
Show file tree
Hide file tree
Showing 9 changed files with 335 additions and 6 deletions.
88 changes: 86 additions & 2 deletions docs/dev/table/tableApi.md
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ The `OverWindow` defines a range of rows over which aggregates are computed. `Ov

### Row-based Operations

The row-based operations generate outputs with multiple columns.
<div class="codetabs" markdown="1">
<div data-lang="java" markdown="1">
<table class="table table-bordered">
Expand All @@ -1838,15 +1839,59 @@ The `OverWindow` defines a range of rows over which aggregates are computed. `Ov
<td>
<p>Performs a map operation with a user-defined scalar function or built-in scalar function. The output will be flattened if the output type is a composite type.</p>
{% highlight java %}
// Example scalar function for map(): turns one string into a two-field row.
public class MyMapFunction extends ScalarFunction {
    // Returns a row holding the input and the input prefixed with "pre-".
    public Row eval(String a) {
        return Row.of(a, "pre-" + a);
    }

    // Declares the result as ROW(STRING, STRING), matching the two fields built in eval().
    @Override
    public TypeInformation<?> getResultType(Class<?>[] signature) {
        return Types.ROW(Types.STRING(), Types.STRING());
    }
}

ScalarFunction func = new MyMapFunction();
tableEnv.registerFunction("func", func);

Table table = input
.map(func("c")).as("a, b")
.map("func(c)").as("a, b");
{% endhighlight %}
</td>
</tr>

<tr>
<td>
<strong>FlatMap</strong><br>
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
</td>
<td>
<p>Performs a flatMap operation with a table function.</p>
{% highlight java %}
// Example table function for flatMap(): splits a string on '#' and emits one row per token.
public class MyFlatMapFunction extends TableFunction<Row> {

    // Emits (token, token length) for every '#'-separated token; emits nothing if there is no '#'.
    public void eval(String str) {
        if (!str.contains("#")) {
            return;
        }
        for (String token : str.split("#")) {
            collect(Row.of(token, token.length()));
        }
    }

    // Declares the emitted rows as ROW(STRING, INT), matching the rows collected in eval().
    @Override
    public TypeInformation<Row> getResultType() {
        return Types.ROW(Types.STRING(), Types.INT());
    }
}

TableFunction func = new MyFlatMapFunction();
tableEnv.registerFunction("func", func);

Table table = input
.flatMap("func(c)").as("a, b");
{% endhighlight %}
</td>
</tr>
</tbody>
</table>
</div>
Expand All @@ -1868,14 +1913,53 @@ Table table = input
<td>
<p>Performs a map operation with a user-defined scalar function or built-in scalar function. The output will be flattened if the output type is a composite type.</p>
{% highlight scala %}
val func: ScalarFunction = new MyMapFunction()
// Example scalar function for map(): turns one string into a two-field row.
class MyMapFunction extends ScalarFunction {
  // Returns a row holding the input and the input prefixed with "pre-".
  def eval(a: String): Row = {
    Row.of(a, "pre-" + a)
  }

  // Declares the result as ROW(STRING, STRING), matching the two fields built in eval.
  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] =
    Types.ROW(Types.STRING, Types.STRING)
}

val func = new MyMapFunction()
val table = input
.map(func('c)).as('a, 'b)
{% endhighlight %}
</td>
</tr>

<tr>
<td>
<strong>FlatMap</strong><br>
<span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span>
</td>
<td>
<p>Performs a flatMap operation with a table function.</p>
{% highlight scala %}
// Example table function for flatMap(): splits a string on '#' and emits one row per token.
class MyFlatMapFunction extends TableFunction[Row] {
  // Emits (token, token length) for every '#'-separated token; emits nothing if there is no '#'.
  def eval(str: String): Unit = {
    if (str.contains("#")) {
      for (token <- str.split("#")) {
        val row = new Row(2)
        row.setField(0, token)
        row.setField(1, token.length)
        collect(row)
      }
    }
  }

  // Declares the emitted rows as ROW(STRING, INT), matching the rows collected in eval.
  override def getResultType: TypeInformation[Row] = Types.ROW(Types.STRING, Types.INT)
}

val func = new MyFlatMapFunction
val table = input
.flatMap(func('c)).as('a, 'b)
{% endhighlight %}
</td>
</tr>
</tbody>
</table>
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,8 @@ public interface Table {
* <p>Example:
*
* <pre>
* {@code ScalarFunction func = new MyMapFunction();
* {@code
* ScalarFunction func = new MyMapFunction();
* tableEnv.registerFunction("func", func);
* tab.map("func(c)");
* }
Expand All @@ -1019,10 +1020,42 @@ public interface Table {
* <p>Scala Example:
*
* <pre>
* {@code val func = new MyMapFunction()
* {@code
* val func = new MyMapFunction()
* tab.map(func('c))
* }
* </pre>
*/
Table map(Expression mapFunction);

/**
 * Performs a flatMap operation with a user-defined table function or built-in table function.
 * The output will be flattened if the output type is a composite type.
 *
 * <p>Example:
 *
 * <pre>
 * {@code
 *   TableFunction func = new MyFlatMapFunction();
 *   tableEnv.registerFunction("func", func);
 *   table.flatMap("func(c)");
 * }
 * </pre>
 *
 * @param tableFunction the table function call written as a string expression, e.g. "func(c)"
 * @return the table resulting from the flatMap operation
 */
Table flatMap(String tableFunction);

/**
 * Performs a flatMap operation with a user-defined table function or built-in table function.
 * The output will be flattened if the output type is a composite type.
 *
 * <p>Scala Example:
 *
 * <pre>
 * {@code
 *   val func = new MyFlatMapFunction
 *   table.flatMap(func('c))
 * }
 * </pre>
 *
 * @param tableFunction the table function call given as an expression, e.g. func('c)
 * @return the table resulting from the flatMap operation
 */
Table flatMap(Expression tableFunction);
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,9 @@ class TableImpl(val tableEnv: TableEnvironment, relNode: RelNode) extends Table

override def map(mapFunction: Expression): Table = ???

// Placeholder only: `???` throws scala.NotImplementedError when this overload is called.
override def flatMap(tableFunction: String): Table = ???

// Placeholder only: `???` throws scala.NotImplementedError when this overload is called.
override def flatMap(tableFunction: Expression): Table = ???

override def getTableOperation: TableOperation = ???
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.calcite.rel.RelNode
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.table.expressions.{Expression, ExpressionParser, LookupCallResolver}
import org.apache.flink.table.functions.{TemporalTableFunction, TemporalTableFunctionImpl}
import org.apache.flink.table.operations.OperationExpressionsUtils.{extractAggregationsAndProperties}
import org.apache.flink.table.operations.OperationExpressionsUtils.extractAggregationsAndProperties
import org.apache.flink.table.operations.{OperationTreeBuilder, TableOperation}
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.util.JavaScalaConversionUtil.toJava
Expand Down Expand Up @@ -434,6 +434,14 @@ class TableImpl(
wrap(operationTreeBuilder.map(mapFunction, operationTree))
}

/**
  * String variant of flatMap: parses the expression string with
  * `ExpressionParser.parseExpression` and delegates to the `Expression` overload.
  *
  * @param tableFunction table function call written as a string expression
  */
override def flatMap(tableFunction: String): Table = {
  val parsedTableFunction = ExpressionParser.parseExpression(tableFunction)
  flatMap(parsedTableFunction)
}

/**
  * Builds the flatMap operation on top of this table's `operationTree` through
  * `operationTreeBuilder.flatMap` and hands the result to `wrap`.
  *
  * @param tableFunction table function call expression
  */
override def flatMap(tableFunction: Expression): Table = {
  val flatMapOperation = operationTreeBuilder.flatMap(tableFunction, operationTree)
  wrap(flatMapOperation)
}

/**
* Registers a unique table name under the table environment
* and returns the registered table name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ import java.util.{Collections, Optional, List => JList}
import org.apache.flink.api.java.operators.join.JoinType
import org.apache.flink.table.api._
import org.apache.flink.table.expressions.ExpressionResolver.resolverFor
import org.apache.flink.table.expressions.FunctionDefinition.Type.SCALAR_FUNCTION
import org.apache.flink.table.expressions.FunctionDefinition.Type.{SCALAR_FUNCTION, TABLE_FUNCTION}
import org.apache.flink.table.expressions._
import org.apache.flink.table.expressions.catalog.FunctionDefinitionCatalog
import org.apache.flink.table.expressions.lookups.TableReferenceLookup
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.operations.AliasOperationUtils.createAliasList
import org.apache.flink.table.plan.logical.{Minus => LMinus, _}
import org.apache.flink.table.util.JavaScalaConversionUtil
import org.apache.flink.table.util.JavaScalaConversionUtil.toScala
import org.apache.flink.util.Preconditions

import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.JavaConverters._

/**
Expand Down Expand Up @@ -368,6 +370,53 @@ class OperationTreeBuilder(private val tableEnv: TableEnvironment) {
mapFunction.asInstanceOf[CallExpression].getFunctionDefinition.getType == SCALAR_FUNCTION
}

/**
  * Builds the operation tree for a flatMap call.
  *
  * The resolved table function is joined laterally (inner join) to `child`, the columns of
  * `child` are then dropped, and the remaining columns are aliased to the field names of the
  * function's result type. For the join, function fields whose names clash with an already
  * used name are renamed by appending `_<n>` suffixes until the name is unique.
  *
  * @param tableFunction expression that must resolve to a call of a table function
  * @param child         operation the table function is applied to
  * @return the operation that exposes only the table function's output columns
  * @throws ValidationException if the resolved expression is not a table function call, or if
  *                             its definition is not a `TableFunctionDefinition`
  */
def flatMap(tableFunction: Expression, child: TableOperation): TableOperation = {

  val resolver = resolverFor(tableCatalog, functionCatalog, child).build()
  val resolvedTableFunction = resolveSingleExpression(tableFunction, resolver)

  if (!isTableFunction(resolvedTableFunction)) {
    throw new ValidationException("Only TableFunction can be used in the flatMap operator.")
  }

  // The cast is safe: isTableFunction only accepts CallExpression instances.
  val originFieldNames: Seq[String] =
    resolvedTableFunction.asInstanceOf[CallExpression].getFunctionDefinition match {
      case tfd: TableFunctionDefinition =>
        UserDefinedFunctionUtils.getFieldInfo(tfd.getResultType)._1
      case other =>
        // Report an unexpected definition as a validation error instead of a scala.MatchError.
        throw new ValidationException(
          s"Unsupported table function definition in the flatMap operator: $other")
    }

  // Read the child's column names once from its schema; they are needed both to avoid name
  // clashes in the join and to drop the child's columns afterwards.
  val childFieldNames = child.getTableSchema.getFieldNames

  // Appends "_0", "_1", ... (cumulatively) to the name until it is not contained in
  // usedFieldNames.
  def getUniqueName(inputName: String, usedFieldNames: Seq[String]): String = {
    var i = 0
    var resultName = inputName
    while (usedFieldNames.contains(resultName)) {
      resultName = resultName + "_" + i
      i += 1
    }
    resultName
  }

  // Every generated name is recorded so later function fields cannot clash with it either.
  val usedFieldNames = childFieldNames.toBuffer
  val newFieldNames = originFieldNames.map({ e =>
    val resultName = getUniqueName(e, usedFieldNames)
    usedFieldNames.append(resultName)
    resultName
  })

  val renamedTableFunction = ApiExpressionUtils.call(
    BuiltInFunctionDefinitions.AS,
    resolvedTableFunction +: newFieldNames.map(ApiExpressionUtils.valueLiteral(_)): _*)
  val joinNode = joinLateral(child, renamedTableFunction, JoinType.INNER, Optional.empty())
  val rightNode = dropColumns(
    childFieldNames.map(a => new UnresolvedReferenceExpression(a)).toList,
    joinNode)
  // Restore the function's original field names now that the child's columns are gone.
  alias(originFieldNames.map(a => new UnresolvedReferenceExpression(a)), rightNode)
}

/**
  * Returns true only if the expression is a `CallExpression` whose function definition has
  * type `TABLE_FUNCTION`.
  */
private def isTableFunction(tableFunction: Expression) = tableFunction match {
  case call: CallExpression => call.getFunctionDefinition.getType == TABLE_FUNCTION
  case _ => false
}

class NoWindowPropertyChecker(val exceptionMessage: String)
extends ApiExpressionDefaultVisitor[Void] {
override def visitCall(call: CallExpression): Void = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,31 @@ class CorrelateTest extends TableTestBase {

util.verifyTable(result, expected)
}

/**
  * Verifies the plan of a flatMap on a stream table with columns f1, f2, f3: the expected plan
  * is a DataStreamCalc on top of an inner DataStreamCorrelate that invokes TableFunc2 on f3.
  */
@Test
def testFlatMap(): Unit = {
  val util = streamTestUtil()

  val func2 = new TableFunc2
  val resultTable = util.addTable[(Int, Long, String)]("MyTable", 'f1, 'f2, 'f3)
    .flatMap(func2('f3))

  val expected = unaryNode(
    "DataStreamCalc",
    unaryNode(
      "DataStreamCorrelate",
      streamTableNode(0),
      term("invocation", s"${func2.functionIdentifier}($$2)"),
      term("correlate", "table(TableFunc2(f3))"),
      // Inside the correlate the function outputs appear as f0 and f1_0, next to the input
      // columns f1, f2, f3.
      term("select", "f1", "f2", "f3", "f0", "f1_0"),
      term("rowType",
        "RecordType(INTEGER f1, BIGINT f2, VARCHAR(65536) f3, VARCHAR(65536) f0, " +
          "INTEGER f1_0)"),
      term("joinType", "INNER")
    ),
    // The calc keeps only the function outputs and exposes f1_0 under the name f1.
    term("select", "f0", "f1_0 AS f1")
  )

  util.verifyTable(resultTable, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,51 @@ class CorrelateStringExpressionTest extends TableTestBase {
"func1(substring(c, 2)) as (s)").select("a, c, s")
verifyTableEquals(scalaTable, javaTable)
}

/**
  * Checks, via verifyTableEquals, that flatMap written with Scala expressions and flatMap
  * written with string expressions yield equal tables, for several table functions and
  * follow-up operations.
  */
@Test
def testFlatMap(): Unit = {

  val util = streamTestUtil()
  val sTab = util.addTable[(Int, Long, String)]('a, 'b, 'c)
  val typeInfo = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING): _*)
  val jTab = util.addJavaTable[Row](typeInfo,"MyTab","a, b, c")

  // test flatMap
  // Each function is registered under a name so the string expressions below can refer to it.
  val func1 = new TableFunc1
  util.javaTableEnv.registerFunction("func1", func1)
  var scalaTable = sTab.flatMap(func1('c)).as('s).select('s)
  var javaTable = jTab.flatMap("func1(c)").as("s").select("s")
  verifyTableEquals(scalaTable, javaTable)

  // test custom result type
  val func2 = new TableFunc2
  util.javaTableEnv.registerFunction("func2", func2)
  scalaTable = sTab.flatMap(func2('c)).as('name, 'len).select('name, 'len)
  javaTable = jTab.flatMap("func2(c)").as("name, len").select("name, len")
  verifyTableEquals(scalaTable, javaTable)

  // test hierarchy generic type
  val hierarchy = new HierarchyTableFunction
  util.javaTableEnv.registerFunction("hierarchy", hierarchy)
  scalaTable = sTab.flatMap(hierarchy('c)).as('name, 'adult, 'len).select('name, 'len, 'adult)
  javaTable = jTab.flatMap("hierarchy(c)").as("name, adult, len").select("name, len, adult")
  verifyTableEquals(scalaTable, javaTable)

  // test pojo type
  // No as(...) here: name and age are selected directly from the flatMap result.
  val pojo = new PojoTableFunc
  util.javaTableEnv.registerFunction("pojo", pojo)
  scalaTable = sTab.flatMap(pojo('c)).select('name, 'age)
  javaTable = jTab.flatMap("pojo(c)").select("name, age")
  verifyTableEquals(scalaTable, javaTable)

  // test with filter
  scalaTable = sTab.flatMap(func2('c)).as('name, 'len).select('name, 'len).filter('len > 2)
  javaTable = jTab.flatMap("func2(c)").as("name, len").select("name, len").filter("len > 2")
  verifyTableEquals(scalaTable, javaTable)

  // test with scalar function
  // The argument of the table function is itself a scalar function call (substring).
  scalaTable = sTab.flatMap(func1('c.substring(2))).as('s).select('s)
  javaTable = jTab.flatMap("func1(substring(c, 2))").as("s").select("s")
  verifyTableEquals(scalaTable, javaTable)
}
}
Loading

0 comments on commit 58e69a0

Please sign in to comment.