Skip to content

Commit

Permalink
[FLINK-16377][table] Support calls to inline functions in the
Browse files Browse the repository at this point in the history
expressions DSL.
  • Loading branch information
dawidwys committed Mar 25, 2020
1 parent 66101d8 commit 90a9b3e
Show file tree
Hide file tree
Showing 31 changed files with 873 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.flink.table.expressions.TimePointUnit;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.utils.TypeConversions;
import org.apache.flink.table.types.utils.ValueDataTypeConverter;
Expand Down Expand Up @@ -509,10 +510,20 @@ public static ApiExpression withoutColumns(Object head, Object... tail) {
* @see TableEnvironment#createTemporaryFunction
* @see TableEnvironment#createTemporarySystemFunction
*/
public static ApiExpression call(String path, Object... params) {
public static ApiExpression call(String path, Object... arguments) {
return new ApiExpression(ApiExpressionUtils.lookupCall(
path,
Arrays.stream(params).map(ApiExpressionUtils::objectToExpression).toArray(Expression[]::new)));
Arrays.stream(arguments).map(ApiExpressionUtils::objectToExpression).toArray(Expression[]::new)));
}

/**
* A call to an unregistered, inline function.
*
* <p>For functions that have been registered before and are identified by a name, use
* {@link #call(String, Object...)}.
*/
public static ApiExpression call(UserDefinedFunction function, Object... arguments) {
return apiCall(function, arguments);
}

private static ApiExpression apiCall(FunctionDefinition functionDefinition, Object... args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.AggregateFunctionDefinition;
import org.apache.flink.table.functions.BuiltInFunctionDefinition;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
Expand Down Expand Up @@ -174,14 +173,19 @@ private List<ResolvedExpression> flattenCompositeType(ResolvedExpression composi
* Temporary method until all calls define a type inference.
*/
private Optional<TypeInference> getOptionalTypeInference(FunctionDefinition definition) {
if (definition instanceof BuiltInFunctionDefinition) {
final BuiltInFunctionDefinition builtInDefinition = (BuiltInFunctionDefinition) definition;
final TypeInference inference = builtInDefinition.getTypeInference(resolutionContext.typeFactory());
if (inference.getOutputTypeStrategy() != TypeStrategies.MISSING) {
return Optional.of(inference);
}
if (definition instanceof ScalarFunctionDefinition ||
definition instanceof TableFunctionDefinition ||
definition instanceof AggregateFunctionDefinition ||
definition instanceof TableAggregateFunctionDefinition) {
return Optional.empty();
}

final TypeInference inference = definition.getTypeInference(resolutionContext.typeFactory());
if (inference.getOutputTypeStrategy() != TypeStrategies.MISSING) {
return Optional.of(inference);
} else {
return Optional.empty();
}
return Optional.empty();
}

private ResolvedExpression runTypeInference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,52 @@
package org.apache.flink.table.operations;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.functions.TableFunction;

import javax.annotation.Nullable;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Describes a relational operation that was created from applying a {@link TableFunction}.
*/
@Internal
public class CalculatedQueryOperation<T> implements QueryOperation {
public class CalculatedQueryOperation implements QueryOperation {

private final TableFunction<T> tableFunction;
private final List<ResolvedExpression> parameters;
private final TypeInformation<T> resultType;
private final FunctionDefinition functionDefinition;
private final @Nullable FunctionIdentifier functionIdentifier;
private final List<ResolvedExpression> arguments;
private final TableSchema tableSchema;

public CalculatedQueryOperation(
TableFunction<T> tableFunction,
List<ResolvedExpression> parameters,
TypeInformation<T> resultType,
FunctionDefinition functionDefinition,
@Nullable FunctionIdentifier functionIdentifier,
List<ResolvedExpression> arguments,
TableSchema tableSchema) {
this.tableFunction = tableFunction;
this.parameters = parameters;
this.resultType = resultType;
this.functionDefinition = functionDefinition;
this.functionIdentifier = functionIdentifier;
this.arguments = arguments;
this.tableSchema = tableSchema;
}

public TableFunction<T> getTableFunction() {
return tableFunction;
public FunctionDefinition getFunctionDefinition() {
return functionDefinition;
}

public List<ResolvedExpression> getParameters() {
return parameters;
public Optional<FunctionIdentifier> getFunctionIdentifier() {
return Optional.ofNullable(functionIdentifier);
}

public TypeInformation<T> getResultType() {
return resultType;
public List<ResolvedExpression> getArguments() {
return arguments;
}

@Override
Expand All @@ -71,8 +75,12 @@ public TableSchema getTableSchema() {
@Override
public String asSummaryString() {
Map<String, Object> args = new LinkedHashMap<>();
args.put("function", tableFunction);
args.put("parameters", parameters);
if (functionIdentifier != null) {
args.put("function", functionIdentifier);
} else {
args.put("function", functionDefinition.toString());
}
args.put("arguments", arguments);

return OperationUtils.formatWithChildren("CalculatedTable", args, getChildren(), Operation::asSummaryString);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public interface QueryOperationVisitor<T> {

T visit(SortQueryOperation sort);

<U> T visit(CalculatedQueryOperation<U> calculatedTable);
T visit(CalculatedQueryOperation calculatedTable);

T visit(CatalogQueryOperation catalogTable);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,14 @@ public QueryOperation flatMap(Expression tableFunction, QueryOperation child) {
throw new ValidationException("Only a table function can be used in the flatMap operator.");
}

TypeInformation<?> resultType = ((TableFunctionDefinition) ((UnresolvedCallExpression) resolvedTableFunction)
.getFunctionDefinition())
.getResultType();
FunctionDefinition functionDefinition = ((UnresolvedCallExpression) resolvedTableFunction)
.getFunctionDefinition();
if (!(functionDefinition instanceof TableFunctionDefinition)) {
throw new ValidationException(
"The new type inference for functions is not supported in the flatMap yet.");
}

TypeInformation<?> resultType = ((TableFunctionDefinition) functionDefinition).getResultType();
List<String> originFieldNames = Arrays.asList(FieldInfoUtils.getFieldNames(resultType));

List<String> childFields = Arrays.asList(child.getTableSchema().getFieldNames());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.apache.flink.table.operations.utils.factories;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.expressions.CallExpression;
Expand All @@ -28,19 +27,19 @@
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.utils.ResolvedExpressionDefaultVisitor;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.TableFunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.operations.CalculatedQueryOperation;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.typeutils.FieldInfoUtils;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.table.types.utils.DataTypeUtils;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static java.util.stream.Collectors.toList;
import static org.apache.flink.table.expressions.ApiExpressionUtils.isFunctionOfKind;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AS;
import static org.apache.flink.table.functions.FunctionKind.TABLE;

/**
* Utility class for creating a valid {@link CalculatedQueryOperation} operation.
Expand All @@ -59,82 +58,102 @@ public QueryOperation create(ResolvedExpression callExpr, String[] leftTableFiel
return callExpr.accept(calculatedTableCreator);
}

private class FunctionTableCallVisitor extends ResolvedExpressionDefaultVisitor<CalculatedQueryOperation<?>> {

private String[] leftTableFieldNames;
private static class FunctionTableCallVisitor extends ResolvedExpressionDefaultVisitor<CalculatedQueryOperation> {
private List<String> leftTableFieldNames;
private static final String ATOMIC_FIELD_NAME = "f0";

public FunctionTableCallVisitor(String[] leftTableFieldNames) {
this.leftTableFieldNames = leftTableFieldNames;
this.leftTableFieldNames = Arrays.asList(leftTableFieldNames);
}

@Override
public CalculatedQueryOperation<?> visit(CallExpression call) {
public CalculatedQueryOperation visit(CallExpression call) {
FunctionDefinition definition = call.getFunctionDefinition();
if (definition.equals(AS)) {
return unwrapFromAlias(call);
} else if (definition instanceof TableFunctionDefinition) {
return createFunctionCall(
(TableFunctionDefinition) definition,
Collections.emptyList(),
call.getResolvedChildren());
} else {
return defaultMethod(call);
}

return createFunctionCall(call, Collections.emptyList(), call.getResolvedChildren());
}

private CalculatedQueryOperation<?> unwrapFromAlias(CallExpression call) {
private CalculatedQueryOperation unwrapFromAlias(CallExpression call) {
List<Expression> children = call.getChildren();
List<String> aliases = children.subList(1, children.size())
.stream()
.map(alias -> ExpressionUtils.extractValue(alias, String.class)
.orElseThrow(() -> new ValidationException("Unexpected alias: " + alias)))
.collect(toList());

if (!isFunctionOfKind(children.get(0), TABLE)) {
if (!(children.get(0) instanceof CallExpression)) {
throw fail();
}

CallExpression tableCall = (CallExpression) children.get(0);
TableFunctionDefinition tableFunctionDefinition =
(TableFunctionDefinition) tableCall.getFunctionDefinition();
return createFunctionCall(tableFunctionDefinition, aliases, tableCall.getResolvedChildren());
return createFunctionCall(tableCall, aliases, tableCall.getResolvedChildren());
}

private CalculatedQueryOperation<?> createFunctionCall(
TableFunctionDefinition tableFunctionDefinition,
private CalculatedQueryOperation createFunctionCall(
CallExpression callExpression,
List<String> aliases,
List<ResolvedExpression> parameters) {
TypeInformation<?> resultType = tableFunctionDefinition.getResultType();

int callArity = resultType.getTotalFields();
int aliasesSize = aliases.size();
FunctionDefinition functionDefinition = callExpression.getFunctionDefinition();
final TableSchema tableSchema = adjustNames(
extractSchema(callExpression.getOutputDataType()),
aliases,
callExpression.getFunctionIdentifier()
.map(FunctionIdentifier::asSummaryString)
.orElse(functionDefinition.toString()));

return new CalculatedQueryOperation(
functionDefinition,
callExpression.getFunctionIdentifier().orElse(null),
parameters,
tableSchema);
}

String[] fieldNames;
private TableSchema extractSchema(DataType resultDataType) {
if (LogicalTypeChecks.isCompositeType(resultDataType.getLogicalType())) {
return DataTypeUtils.expandCompositeTypeToSchema(resultDataType);
}

int i = 0;
String fieldName = ATOMIC_FIELD_NAME;
while (leftTableFieldNames.contains(fieldName)) {
fieldName = ATOMIC_FIELD_NAME + "_" + i++;
}
return TableSchema.builder()
.field(fieldName, resultDataType)
.build();
}

private TableSchema adjustNames(
TableSchema tableSchema,
List<String> aliases,
String functionName) {
int aliasesSize = aliases.size();
if (aliasesSize == 0) {
fieldNames = FieldInfoUtils.getFieldNames(resultType, Arrays.asList(leftTableFieldNames));
} else if (aliasesSize != callArity) {
return tableSchema;
}

int callArity = tableSchema.getFieldCount();
if (callArity != aliasesSize) {
throw new ValidationException(String.format(
"List of column aliases must have same degree as table; " +
"the returned table of function '%s' has " +
"%d columns, whereas alias list has %d columns",
tableFunctionDefinition.toString(),
functionName,
callArity,
aliasesSize));
} else {
fieldNames = aliases.toArray(new String[aliasesSize]);
}

TypeInformation<?>[] fieldTypes = FieldInfoUtils.getFieldTypes(resultType);

return new CalculatedQueryOperation(
tableFunctionDefinition.getTableFunction(),
parameters,
tableFunctionDefinition.getResultType(),
new TableSchema(fieldNames, fieldTypes));
return TableSchema.builder()
.fields(aliases.toArray(new String[0]), tableSchema.getFieldDataTypes())
.build();
}

@Override
protected CalculatedQueryOperation<?> defaultMethod(ResolvedExpression expression) {
protected CalculatedQueryOperation defaultMethod(ResolvedExpression expression) {
throw fail();
}

Expand Down
Loading

0 comments on commit 90a9b3e

Please sign in to comment.