Skip to content

Commit

Permalink
Add implicit coersions to VALUES
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Apr 30, 2015
1 parent b645d9a commit 413cc96
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1042,4 +1042,11 @@ public static ExpressionAnalyzer create(Analysis analysis, Session session, Meta
metadata.getTypeManager(),
node -> new StatementAnalyzer(analysis, metadata, sqlParser, session, experimentalSyntaxEnabled, Optional.empty()));
}

public static ExpressionAnalyzer createWithoutSubqueries(FunctionRegistry functionRegistry, TypeManager typeManager, SemanticErrorCode errorCode, String message)
{
return new ExpressionAnalyzer(functionRegistry, typeManager, node -> {
throw new SemanticException(errorCode, node, message);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.Session;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.metadata.FunctionInfo;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataUtil;
import com.facebook.presto.metadata.QualifiedTableName;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.metadata.TableMetadata;
import com.facebook.presto.metadata.ViewDefinition;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ColumnMetadata;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.type.BigintType;
Expand Down Expand Up @@ -56,6 +56,7 @@
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.Relation;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SampledRelation;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.SingleColumn;
Expand Down Expand Up @@ -89,6 +90,7 @@
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.metadata.FunctionRegistry.getCommonSuperType;
import static com.facebook.presto.metadata.ViewDefinition.ViewColumn;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
Expand Down Expand Up @@ -534,32 +536,60 @@ protected TupleDescriptor visitValues(Values node, AnalysisContext context)
{
checkState(node.getRows().size() >= 1);

Set<Type> types = node.getRows().stream()
.map((row) -> analyzeExpression(row, new TupleDescriptor(), context).getType(row))
// get unique row types
Set<List<Type>> rowTypes = node.getRows().stream()
.map(row -> analyzeExpression(row, new TupleDescriptor(), context).getType(row))
.map(type -> {
if (type instanceof RowType) {
return type.getTypeParameters();
}
return ImmutableList.of(type);
})
.collect(ImmutableCollectors.toImmutableSet());

if (types.size() > 1) {
throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES,
node,
"Values rows have mismatched types: %s vs %s",
Iterables.get(types, 0),
Iterables.get(types, 1));
// determine common super type of the rows
List<Type> fieldTypes = new ArrayList<>(rowTypes.iterator().next());
for (List<Type> rowType : rowTypes) {
for (int i = 0; i < rowType.size(); i++) {
Type fieldType = rowType.get(i);
Type superType = fieldTypes.get(i);

Optional<Type> commonSuperType = getCommonSuperType(fieldType, superType);
if (!commonSuperType.isPresent()) {
throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES,
node,
"Values rows have mismatched types: %s vs %s",
Iterables.get(rowTypes, 0),
Iterables.get(rowTypes, 1));
}
fieldTypes.set(i, commonSuperType.get());
}
}

Type type = Iterables.getOnlyElement(types);

List<Field> fields;
if (type instanceof RowType) {
fields = ((RowType) type).getFields().stream()
.map(RowType.RowField::getType)
.map((valueType) -> Field.newUnqualified(Optional.empty(), valueType))
.collect(toImmutableList());
}
else {
fields = ImmutableList.of(Field.newUnqualified(Optional.empty(), type));
// add coercions for the rows
for (Expression row : node.getRows()) {
if (row instanceof Row) {
List<Expression> items = ((Row) row).getItems();
for (int i = 0; i < items.size(); i++) {
Type expectedType = fieldTypes.get(i);
Expression item = items.get(i);
if (!analysis.getType(item).equals(expectedType)) {
analysis.addCoercion(item, expectedType);
}
}
}
else {
Type expectedType = fieldTypes.get(0);
if (!analysis.getType(row).equals(expectedType)) {
analysis.addCoercion(row, expectedType);
}
}
}

TupleDescriptor descriptor = new TupleDescriptor(fields);
TupleDescriptor descriptor = new TupleDescriptor(fieldTypes.stream()
.map(valueType -> Field.newUnqualified(Optional.empty(), valueType))
.collect(toImmutableList()));

analysis.setOutputDescriptor(node, descriptor);
return descriptor;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.Analysis;
import com.facebook.presto.sql.analyzer.AnalysisContext;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Field;
import com.facebook.presto.sql.analyzer.FieldOrExpression;
import com.facebook.presto.sql.analyzer.SemanticException;
Expand All @@ -42,12 +44,16 @@
import com.facebook.presto.sql.tree.AliasedRelation;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.InputReference;
import com.facebook.presto.sql.tree.Join;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
Expand Down Expand Up @@ -438,14 +444,17 @@ protected RelationPlan visitValues(Values node, Void context)
ImmutableList.Builder<List<Expression>> rows = ImmutableList.builder();
for (Expression row : node.getRows()) {
ImmutableList.Builder<Expression> values = ImmutableList.builder();

if (row instanceof Row) {
for (Expression expression : ((Row) row).getItems()) {
values.add(evaluateConstantExpression(expression));
List<Expression> items = ((Row) row).getItems();
for (int i = 0; i < items.size(); i++) {
Expression expression = items.get(i);
Object constantValue = evaluateConstantExpression(expression);
values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(i).getType()));
}
}
else {
values.add(row);
Object constantValue = evaluateConstantExpression(row);
values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(0).getType()));
}

rows.add(values.build());
Expand All @@ -472,8 +481,9 @@ protected RelationPlan visitUnnest(Unnest node, Void context)
ImmutableMap.Builder<Symbol, List<Symbol>> unnestSymbols = ImmutableMap.builder();
Iterator<Symbol> unnestedSymbolsIterator = unnestedSymbols.iterator();
for (Expression expression : node.getExpressions()) {
values.add(evaluateConstantExpression(expression));
Object constantValue = evaluateConstantExpression(expression);
Type type = analysis.getType(expression);
values.add(LiteralInterpreter.toExpression(constantValue, type));
Symbol inputSymbol = symbolAllocator.newSymbol(expression, type);
argumentSymbols.add(inputSymbol);
if (type instanceof ArrayType) {
Expand All @@ -494,27 +504,61 @@ else if (type instanceof MapType) {
return new RelationPlan(unnestNode, descriptor, unnestedSymbols, Optional.empty());
}

private Expression evaluateConstantExpression(final Expression expression)
private Object evaluateConstantExpression(Expression expression)
{
// verify expression is constant
expression.accept(new DefaultTraversalVisitor<Void, Void>()
{
@Override
protected Void visitQualifiedNameReference(QualifiedNameReference node, Void context)
{
throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references");
}

@Override
protected Void visitInputReference(InputReference node, Void context)
{
throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain input references");
}
}, null);

// add coercions
Expression rewrite = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
{
@Override
public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Expression rewrittenExpression = treeRewriter.defaultRewrite(node, context);

// cast expression if coercion is registered
Type coercion = analysis.getCoercion(node);
if (coercion != null) {
rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString());
}

return rewrittenExpression;
}
}, expression);

try {
// expressionInterpreter/optimizer only understands a subset of expression types
// TODO: remove this when the new expression tree is implemented
Expression canonicalized = CanonicalizeExpressions.canonicalizeExpression(expression);

// verify the expression is constant (has no inputs)
ExpressionInterpreter.expressionOptimizer(canonicalized, metadata, session, analysis.getTypes()).optimize(new SymbolResolver() {
@Override
public Object getValue(Symbol symbol)
{
throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references");
}
});
Expression canonicalized = CanonicalizeExpressions.canonicalizeExpression(rewrite);

// The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis
// to re-analyze coercions that might be necessary
ExpressionAnalyzer analyzer = ExpressionAnalyzer.createWithoutSubqueries(
metadata.getFunctionRegistry(),
metadata.getTypeManager(),
EXPRESSION_NOT_CONSTANT,
"Constant expression cannot contain as sub-query");
analyzer.analyze(canonicalized, new TupleDescriptor(), new AnalysisContext());

// evaluate the expression
Object result = ExpressionInterpreter.expressionInterpreter(canonicalized, metadata, session, analysis.getTypes()).evaluate(0);
Object result = ExpressionInterpreter.expressionInterpreter(canonicalized, metadata, session, analyzer.getExpressionTypes()).evaluate(0);
checkState(!(result instanceof Expression), "Expression interpreter returned an unresolved expression");

return LiteralInterpreter.toExpression(result, analysis.getType(expression));
return result;
}
catch (Exception e) {
throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Error evaluating constant expression: %s", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ public void testValues()
"WITH a AS (VALUES (1.1, 2), (sin(3.3), 2+2)) " +
"SELECT * FROM a",
"VALUES (1.1, 2), (sin(3.3), 2+2)");

// implicity coersions
assertQuery("VALUES 1, 2.2, 3, 4.4");
assertQuery("VALUES (1, 2), (3.3, 4.4)");
assertQuery("VALUES true, 1.0 in (1, 2, 3)");
}

@Test
Expand Down

0 comments on commit 413cc96

Please sign in to comment.