Skip to content

Commit

Permalink
[FLINK-21223][python] Support to specify the output types of Python U…
Browse files Browse the repository at this point in the history
…DFs via string

This closes apache#21332.
  • Loading branch information
HuangXingBo committed Dec 12, 2022
1 parent 7bff4ea commit 6cc00a7
Show file tree
Hide file tree
Showing 15 changed files with 639 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ The output will be flattened if the output type is a composite type.
from pyflink.common import Row
from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.types import DataTypes
from pyflink.table.udf import udf

env_settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(env_settings)

table = table_env.from_elements([(1, 'Hi'), (2, 'Hello')], ['id', 'data'])

@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
DataTypes.FIELD("data", DataTypes.STRING())]))
@udf(result_type='ROW<id BIGINT, data STRING>')
def func1(id: int, data: str) -> Row:
return Row(id, data * 2)

Expand All @@ -62,8 +60,7 @@ table.map(func1(col('id'), col('data'))).execute().print()
It also supports taking a Row object (containing all the columns of the input table) as input.

```python
@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
DataTypes.FIELD("data", DataTypes.STRING())]))
@udf(result_type='ROW<id BIGINT, data STRING>')
def func2(data: Row) -> Row:
return Row(data.id, data.data * 2)

Expand All @@ -85,9 +82,7 @@ It should be noted that the input type and output type should be pandas.DataFram

```python
import pandas as pd
@udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
DataTypes.FIELD("data", DataTypes.STRING())]),
func_type='pandas')
@udf(result_type='ROW<id BIGINT, data STRING>', func_type='pandas')
def func3(data: pd.DataFrame) -> pd.DataFrame:
res = pd.concat([data.id, data.data * 2], axis=1)
return res
Expand All @@ -109,14 +104,14 @@ Performs a `flat_map` operation with a python [table function]({{< ref "docs/dev
```python
from pyflink.common import Row
from pyflink.table.udf import udtf
from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table import EnvironmentSettings, TableEnvironment

env_settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(env_settings)

table = table_env.from_elements([(1, 'Hi,Flink'), (2, 'Hello')], ['id', 'data'])

@udtf(result_types=[DataTypes.INT(), DataTypes.STRING()])
@udtf(result_types=['INT', 'STRING'])
def split(x: Row) -> Row:
for s in x.data.split(","):
yield x.id, s
Expand Down Expand Up @@ -154,7 +149,7 @@ Performs an `aggregate` operation with a Python [general aggregate function]({{<

```python
from pyflink.common import Row
from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
from pyflink.table import EnvironmentSettings, TableEnvironment
from pyflink.table.expressions import col
from pyflink.table.udf import AggregateFunction, udaf

Expand All @@ -180,14 +175,10 @@ class CountAndSumAggregateFunction(AggregateFunction):
accumulator[1] += other_acc[1]

def get_accumulator_type(self):
return DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.BIGINT())])
return 'ROW<a BIGINT, b BIGINT>'

def get_result_type(self):
return DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT()),
DataTypes.FIELD("b", DataTypes.BIGINT())])
return 'ROW<a BIGINT, b BIGINT>'

function = CountAndSumAggregateFunction()
agg = udaf(function,
Expand Down Expand Up @@ -221,9 +212,7 @@ table_env = TableEnvironment.create(env_settings)
t = table_env.from_elements([(1, 2), (2, 1), (1, 3)], ['a', 'b'])

pandas_udaf = udaf(lambda pd: (pd.b.mean(), pd.b.max()),
result_type=DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.FLOAT()),
DataTypes.FIELD("b", DataTypes.INT())]),
result_type='ROW<a FLOAT, b INT>',
func_type="pandas")
t.aggregate(pandas_udaf.alias("a", "b")) \
.select(col('a'), col('b')).execute().print()
Expand All @@ -250,7 +239,7 @@ Similar to `aggregate`, you have to close the `flat_aggregate` with a select sta

```python
from pyflink.common import Row
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udtaf, TableAggregateFunction

Expand All @@ -272,11 +261,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row.a

def get_accumulator_type(self):
return DataTypes.ARRAY(DataTypes.BIGINT())
return 'ARRAY<BIGINT>'

def get_result_type(self):
return DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT())])
return 'ROW<a BIGINT>'


env_settings = EnvironmentSettings.in_streaming_mode()
Expand Down
33 changes: 15 additions & 18 deletions docs/content.zh/docs/dev/python/table/udfs/python_udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ class HashCode(ScalarFunction):
settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(settings)

hash_code = udf(HashCode(), result_type=DataTypes.BIGINT())
hash_code = udf(HashCode(), result_type='BIGINT')

# 在 Python Table API 中使用 Python 自定义函数
my_table.select(col("string"), col("bigint"), hash_code(col("bigint")), call(hash_code, col("bigint")))

# 在 SQL API 中使用 Python 自定义函数
table_env.create_temporary_function("hash_code", udf(HashCode(), result_type=DataTypes.BIGINT()))
table_env.create_temporary_function("hash_code", udf(HashCode(), result_type='BIGINT'))
table_env.sql_query("SELECT string, bigint, hash_code(bigint) FROM MyTable")
```

Expand Down Expand Up @@ -108,25 +108,25 @@ class Add(ScalarFunction):
add = udf(Add(), result_type=DataTypes.BIGINT())

# 方式二:普通 Python 函数
@udf(result_type=DataTypes.BIGINT())
@udf(result_type='BIGINT')
def add(i, j):
return i + j

# 方式三:lambda 函数
add = udf(lambda i, j: i + j, result_type=DataTypes.BIGINT())
add = udf(lambda i, j: i + j, result_type='BIGINT')

# 方式四:callable 函数
class CallableAdd(object):
def __call__(self, i, j):
return i + j

add = udf(CallableAdd(), result_type=DataTypes.BIGINT())
add = udf(CallableAdd(), result_type='BIGINT')

# 方式五:partial 函数
def partial_add(i, j, k):
return i + j + k

add = udf(functools.partial(partial_add, k=1), result_type=DataTypes.BIGINT())
add = udf(functools.partial(partial_add, k=1), result_type='BIGINT')

# 注册 Python 自定义函数
table_env.create_temporary_function("add", add)
Expand Down Expand Up @@ -160,14 +160,14 @@ table_env = TableEnvironment.create(env_settings)
my_table = ... # type: Table, table schema: [a: String]

# 注册 Python 表值函数
split = udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()])
split = udtf(Split(), result_types=['STRING', 'INT'])

# 在 Python Table API 中使用 Python 表值函数
my_table.join_lateral(split(col("a")).alias("word", "length"))
my_table.left_outer_join_lateral(split(col("a")).alias("word", "length"))

# 在 SQL API 中使用 Python 表值函数
table_env.create_temporary_function("split", udtf(Split(), result_types=[DataTypes.STRING(), DataTypes.INT()]))
table_env.create_temporary_function("split", udtf(Split(), result_types=['STRING', 'INT']))
table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")
```
Expand Down Expand Up @@ -219,18 +219,18 @@ table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE

```python
# 方式一:生成器函数
@udtf(result_types=DataTypes.BIGINT())
@udtf(result_types='BIGINT')
def generator_func(x):
yield 1
yield 2

# 方式二:返回迭代器
@udtf(result_types=DataTypes.BIGINT())
@udtf(result_types='BIGINT')
def iterator_func(x):
return range(5)

# 方式三:返回可迭代子类
@udtf(result_types=DataTypes.BIGINT())
@udtf(result_types='BIGINT')
def iterable_func(x):
result = [1, 2, 3]
return result
Expand Down Expand Up @@ -300,12 +300,10 @@ class WeightedAvg(AggregateFunction):
accumulator[1] -= weight

def get_result_type(self):
return DataTypes.BIGINT()
return 'BIGINT'

def get_accumulator_type(self):
return DataTypes.ROW([
DataTypes.FIELD("f0", DataTypes.BIGINT()),
DataTypes.FIELD("f1", DataTypes.BIGINT())])
return 'ROW<f0 BIGINT, f1 BIGINT>'


env_settings = EnvironmentSettings.in_streaming_mode()
Expand Down Expand Up @@ -475,11 +473,10 @@ class Top2(TableAggregateFunction):
accumulator[1] = row[0]

def get_accumulator_type(self):
return DataTypes.ARRAY(DataTypes.BIGINT())
return 'ARRAY<BIGINT>'

def get_result_type(self):
return DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.BIGINT())])
return 'ROW<a BIGINT>'


env_settings = EnvironmentSettings.in_streaming_mode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ under the License.
以下示例显示了如何定义自己的向量化 Python 标量函数,该函数计算两列的总和,并在查询中使用它:

```python
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udf

@udf(result_type=DataTypes.BIGINT(), func_type="pandas")
@udf(result_type='BIGINT', func_type="pandas")
def add(i, j):
return i + j

Expand Down Expand Up @@ -85,12 +85,12 @@ table_env.sql_query("SELECT add(bigint, bigint) FROM MyTable")
and `Over Window Aggregation` 使用它:

```python
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table import TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col, lit
from pyflink.table.udf import udaf
from pyflink.table.window import Tumble

@udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
@udaf(result_type='FLOAT', func_type="pandas")
def mean_udaf(v):
return v.mean()

Expand Down Expand Up @@ -126,7 +126,6 @@ table_env.sql_query("""
以下示例显示了多种定义向量化 Python 聚合函数的方式。该函数需要两个类型为 bigint 的参数作为输入参数,并返回它们的最大值的和作为结果。

```python
from pyflink.table import DataTypes
from pyflink.table.udf import AggregateFunction, udaf

# 方式一:扩展基类 `AggregateFunction`
Expand All @@ -152,26 +151,26 @@ class MaxAdd(AggregateFunction):
result += arg.max()
accumulator.append(result)

max_add = udaf(MaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
max_add = udaf(MaxAdd(), result_type='BIGINT', func_type="pandas")

# 方式二:普通 Python 函数
@udaf(result_type=DataTypes.BIGINT(), func_type="pandas")
@udaf(result_type='BIGINT', func_type="pandas")
def max_add(i, j):
return i.max() + j.max()

# 方式三:lambda 函数
max_add = udaf(lambda i, j: i.max() + j.max(), result_type=DataTypes.BIGINT(), func_type="pandas")
max_add = udaf(lambda i, j: i.max() + j.max(), result_type='BIGINT', func_type="pandas")

# 方式四:callable 函数
class CallableMaxAdd(object):
def __call__(self, i, j):
return i.max() + j.max()

max_add = udaf(CallableMaxAdd(), result_type=DataTypes.BIGINT(), func_type="pandas")
max_add = udaf(CallableMaxAdd(), result_type='BIGINT', func_type="pandas")

# 方式五:partial 函数
def partial_max_add(i, j, k):
return i.max() + j.max() + k

max_add = udaf(functools.partial(partial_max_add, k=1), result_type=DataTypes.BIGINT(), func_type="pandas")
max_add = udaf(functools.partial(partial_max_add, k=1), result_type='BIGINT', func_type="pandas")
```
Loading

0 comments on commit 6cc00a7

Please sign in to comment.