From 413cc9616963f827779cd3bd1f7ea72b9c79946c Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sun, 19 Apr 2015 00:28:27 -0700 Subject: [PATCH] Add implicit coersions to VALUES --- .../sql/analyzer/ExpressionAnalyzer.java | 7 ++ .../presto/sql/analyzer/TupleAnalyzer.java | 72 ++++++++++++----- .../presto/sql/planner/RelationPlanner.java | 80 ++++++++++++++----- .../presto/tests/AbstractTestQueries.java | 5 ++ 4 files changed, 125 insertions(+), 39 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index da81ca33f00b..fdce9159bb6f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -1042,4 +1042,11 @@ public static ExpressionAnalyzer create(Analysis analysis, Session session, Meta metadata.getTypeManager(), node -> new StatementAnalyzer(analysis, metadata, sqlParser, session, experimentalSyntaxEnabled, Optional.empty())); } + + public static ExpressionAnalyzer createWithoutSubqueries(FunctionRegistry functionRegistry, TypeManager typeManager, SemanticErrorCode errorCode, String message) + { + return new ExpressionAnalyzer(functionRegistry, typeManager, node -> { + throw new SemanticException(errorCode, node, message); + }); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java index 02f1b6235de7..82d3c4c90a29 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/TupleAnalyzer.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; -import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.metadata.FunctionInfo; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataUtil; @@ -22,6 +21,7 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.metadata.TableMetadata; import com.facebook.presto.metadata.ViewDefinition; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.BigintType; @@ -56,6 +56,7 @@ import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.Relation; +import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SelectItem; import com.facebook.presto.sql.tree.SingleColumn; @@ -89,6 +90,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.metadata.FunctionRegistry.getCommonSuperType; import static com.facebook.presto.metadata.ViewDefinition.ViewColumn; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -534,32 +536,60 @@ protected TupleDescriptor visitValues(Values node, AnalysisContext context) { checkState(node.getRows().size() >= 1); - Set types = node.getRows().stream() - .map((row) -> analyzeExpression(row, new TupleDescriptor(), context).getType(row)) + // get unique row types + Set> rowTypes = node.getRows().stream() + .map(row -> analyzeExpression(row, new TupleDescriptor(), context).getType(row)) + .map(type -> { + if (type instanceof RowType) { + return type.getTypeParameters(); + } + return ImmutableList.of(type); + }) .collect(ImmutableCollectors.toImmutableSet()); - if (types.size() > 1) { - throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES, - node, - "Values rows have mismatched types: %s vs %s", - Iterables.get(types, 0), - Iterables.get(types, 1)); + // determine common super type of the rows + List fieldTypes = new ArrayList<>(rowTypes.iterator().next()); + for (List rowType : rowTypes) { + for (int i = 0; i < rowType.size(); i++) { + Type fieldType = rowType.get(i); + Type superType = fieldTypes.get(i); + + Optional commonSuperType = getCommonSuperType(fieldType, superType); + if (!commonSuperType.isPresent()) { + throw new SemanticException(MISMATCHED_SET_COLUMN_TYPES, + node, + "Values rows have mismatched types: %s vs %s", + Iterables.get(rowTypes, 0), + Iterables.get(rowTypes, 1)); + } + fieldTypes.set(i, commonSuperType.get()); + } } - Type type = Iterables.getOnlyElement(types); - - List fields; - if (type instanceof RowType) { - fields = ((RowType) type).getFields().stream() - .map(RowType.RowField::getType) - .map((valueType) -> Field.newUnqualified(Optional.empty(), valueType)) - .collect(toImmutableList()); - } - else { - fields = ImmutableList.of(Field.newUnqualified(Optional.empty(), type)); + // add coercions for the rows + for (Expression row : node.getRows()) { + if (row instanceof Row) { + List items = ((Row) row).getItems(); + for (int i = 0; i < items.size(); i++) { + Type expectedType = fieldTypes.get(i); + Expression item = items.get(i); + if (!analysis.getType(item).equals(expectedType)) { + analysis.addCoercion(item, expectedType); + } + } + } + else { + Type expectedType = fieldTypes.get(0); + if (!analysis.getType(row).equals(expectedType)) { + analysis.addCoercion(row, expectedType); + } + } } - TupleDescriptor descriptor = new TupleDescriptor(fields); + TupleDescriptor descriptor = new TupleDescriptor(fieldTypes.stream() + .map(valueType -> Field.newUnqualified(Optional.empty(), valueType)) + .collect(toImmutableList())); + analysis.setOutputDescriptor(node, descriptor); return descriptor; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 7da4f8569f39..ef016e7a094d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -22,6 +22,8 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.Analysis; +import com.facebook.presto.sql.analyzer.AnalysisContext; +import com.facebook.presto.sql.analyzer.ExpressionAnalyzer; import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.FieldOrExpression; import com.facebook.presto.sql.analyzer.SemanticException; @@ -42,12 +44,16 @@ import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.BooleanLiteral; +import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.InputReference; import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.QualifiedName; @@ -438,14 +444,17 @@ protected RelationPlan visitValues(Values node, Void context) ImmutableList.Builder> rows = ImmutableList.builder(); for (Expression row : node.getRows()) { ImmutableList.Builder values = ImmutableList.builder(); - if (row instanceof Row) { - for (Expression expression : ((Row) row).getItems()) { - values.add(evaluateConstantExpression(expression)); + List items = ((Row) row).getItems(); + for (int i = 0; i < items.size(); i++) { + Expression expression = items.get(i); + Object constantValue = evaluateConstantExpression(expression); + values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(i).getType())); } } else { - values.add(row); + Object constantValue = evaluateConstantExpression(row); + values.add(LiteralInterpreter.toExpression(constantValue, descriptor.getFieldByIndex(0).getType())); } rows.add(values.build()); @@ -472,8 +481,9 @@ protected RelationPlan visitUnnest(Unnest node, Void context) ImmutableMap.Builder> unnestSymbols = ImmutableMap.builder(); Iterator unnestedSymbolsIterator = unnestedSymbols.iterator(); for (Expression expression : node.getExpressions()) { - values.add(evaluateConstantExpression(expression)); + Object constantValue = evaluateConstantExpression(expression); Type type = analysis.getType(expression); + values.add(LiteralInterpreter.toExpression(constantValue, type)); Symbol inputSymbol = symbolAllocator.newSymbol(expression, type); argumentSymbols.add(inputSymbol); if (type instanceof ArrayType) { @@ -494,27 +504,61 @@ else if (type instanceof MapType) { return new RelationPlan(unnestNode, descriptor, unnestedSymbols, Optional.empty()); } - private Expression evaluateConstantExpression(final Expression expression) + private Object evaluateConstantExpression(Expression expression) { + // verify expression is constant + expression.accept(new DefaultTraversalVisitor() + { + @Override + protected Void visitQualifiedNameReference(QualifiedNameReference node, Void context) + { + throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references"); + } + + @Override + protected Void visitInputReference(InputReference node, Void context) + { + throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain input references"); + } + }, null); + + // add coercions + Expression rewrite = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { + @Override + public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter treeRewriter) + { + Expression rewrittenExpression = treeRewriter.defaultRewrite(node, context); + + // cast expression if coercion is registered + Type coercion = analysis.getCoercion(node); + if (coercion != null) { + rewrittenExpression = new Cast(rewrittenExpression, coercion.getTypeSignature().toString()); + } + + return rewrittenExpression; + } + }, expression); + try { // expressionInterpreter/optimizer only understands a subset of expression types // TODO: remove this when the new expression tree is implemented - Expression canonicalized = CanonicalizeExpressions.canonicalizeExpression(expression); - - // verify the expression is constant (has no inputs) - ExpressionInterpreter.expressionOptimizer(canonicalized, metadata, session, analysis.getTypes()).optimize(new SymbolResolver() { - @Override - public Object getValue(Symbol symbol) - { - throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references"); - } - }); + Expression canonicalized = CanonicalizeExpressions.canonicalizeExpression(rewrite); + + // The optimization above may have rewritten the expression tree which breaks all the identity maps, so redo the analysis + // to re-analyze coercions that might be necessary + ExpressionAnalyzer analyzer = ExpressionAnalyzer.createWithoutSubqueries( + metadata.getFunctionRegistry(), + metadata.getTypeManager(), + EXPRESSION_NOT_CONSTANT, + "Constant expression cannot contain as sub-query"); + analyzer.analyze(canonicalized, new TupleDescriptor(), new AnalysisContext()); // evaluate the expression - Object result = ExpressionInterpreter.expressionInterpreter(canonicalized, metadata, session, analysis.getTypes()).evaluate(0); + Object result = ExpressionInterpreter.expressionInterpreter(canonicalized, metadata, session, analyzer.getExpressionTypes()).evaluate(0); checkState(!(result instanceof Expression), "Expression interpreter returned an unresolved expression"); - return LiteralInterpreter.toExpression(result, analysis.getType(expression)); + return result; } catch (Exception e) { throw new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Error evaluating constant expression: %s", e.getMessage()); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 5cb6f499024a..dad959875d8f 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -219,6 +219,11 @@ public void testValues() "WITH a AS (VALUES (1.1, 2), (sin(3.3), 2+2)) " + "SELECT * FROM a", "VALUES (1.1, 2), (sin(3.3), 2+2)"); + + // implicity coersions + assertQuery("VALUES 1, 2.2, 3, 4.4"); + assertQuery("VALUES (1, 2), (3.3, 4.4)"); + assertQuery("VALUES true, 1.0 in (1, 2, 3)"); } @Test