diff --git a/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java b/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java index c316874..94f1624 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java +++ b/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java @@ -41,7 +41,7 @@ MModelInstance trainModelInstance( void dropModelInstance(String name) throws CatalogException; - Collection getModelInstances(String modelName) throws CatalogException; + Collection getModelInstances() throws CatalogException; boolean modelInstanceExists(String name) throws CatalogException; diff --git a/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java b/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java index 958f1ba..9f9d5b6 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java +++ b/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java @@ -131,15 +131,12 @@ public void dropModelInstance(String name) throws CatalogException { } @Override - public Collection getModelInstances(String modelName) throws CatalogException { + public Collection getModelInstances() throws CatalogException { try { Query query = pm.newQuery(MModelInstance.class); - query.setFilter("model.name == modelName"); - query.declareParameters("String modelName"); - - return (List) query.execute(modelName); + return (List) query.execute(); } catch (RuntimeException e) { - throw new CatalogException("failed to get model '" + modelName + "' instances", e); + throw new CatalogException("failed to get model instances", e); } } @@ -152,7 +149,9 @@ public boolean modelInstanceExists(String name) throws CatalogException { public MModelInstance getModelInstance(String name) throws CatalogException { try { Query query = pm.newQuery(MModelInstance.class); - query.setFilter("name == name"); + if (name != null) { + query.setFilter("name == name"); + } query.declareParameters("String name"); query.setUnique(true); diff --git a/traindb-common/src/main/scripts/traindb-config.sh b/traindb-common/src/main/scripts/traindb-config.sh index 79dbecf..20e80e7 100755 --- a/traindb-common/src/main/scripts/traindb-config.sh +++ b/traindb-common/src/main/scripts/traindb-config.sh @@ -57,6 +57,9 @@ fi TRAINDB_OPTS="$JAVA_HEAP_MAX" +# uncomment if you want to attach a debugger +#TRAINDB_OPTS="$TRAINDB_OPTS -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005" + lines=$("$JAVA" -version 2>&1 | tr '\r' '\n') ver=$(echo $lines | sed -e 's/.*version "\(.*\)"\(.*\)/\1/; 1q') if [[ $ver = "1."* ]]; then diff --git a/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 b/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 index adbe218..aa3e704 100644 --- a/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 +++ b/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 @@ -68,7 +68,7 @@ showStmt showTargets : K_MODELS # ShowModels - | K_MODEL modelName K_INSTANCES # ShowModelInstances + | K_MODEL K_INSTANCES # ShowModelInstances | K_SYNOPSES # ShowSynopses ; diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcConvention.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcConvention.java new file mode 100644 index 0000000..7901e3c --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcConvention.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.sql.SqlDialect; + +/** + * Calling convention for relational operations that occur in a JDBC + * database. + * + *

The convention is a slight misnomer. The operations occur in whatever + * data-flow architecture the database uses internally. Nevertheless, the result + * pops out in JDBC.

+ * + *

This is the only convention, thus far, that is not a singleton. Each + * instance contains a JDBC schema (and therefore a data source). If Calcite is + * working with two different databases, it would even make sense to convert + * from "JDBC#A" convention to "JDBC#B", even though we don't do it currently. + * (That would involve asking database B to open a database link to database + * A.)

+ * + *

As a result, converter rules from and to this convention need to be + * instantiated, at the start of planning, for each JDBC database in play.

+ */ +public class JdbcConvention extends Convention.Impl { + /** + * Cost of a JDBC node versus implementing an equivalent node in a "typical" + * calling convention. + */ + public static final double COST_MULTIPLIER = 0.8d; + + public final SqlDialect dialect; + public final Expression expression; + + public JdbcConvention(SqlDialect dialect, Expression expression, + String name) { + super("JDBC." + name, JdbcRel.class); + this.dialect = dialect; + this.expression = expression; + } + + public static JdbcConvention of(SqlDialect dialect, Expression expression, + String name) { + return new JdbcConvention(dialect, expression, name); + } + + @Override + public void register(RelOptPlanner planner) { + for (RelOptRule rule : JdbcRules.rules(this)) { + planner.addRule(rule); + } + planner.addRule(CoreRules.FILTER_SET_OP_TRANSPOSE); + planner.addRule(CoreRules.PROJECT_REMOVE); + } +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcImplementor.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcImplementor.java new file mode 100644 index 0000000..a8deafa --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcImplementor.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.util.Util; + +/** + * State for generating a SQL statement. + */ +public class JdbcImplementor extends RelToSqlConverter { + public JdbcImplementor(SqlDialect dialect, JavaTypeFactory typeFactory) { + super(dialect); + Util.discard(typeFactory); + } + + // CHECKSTYLE: IGNORE 1 + + /** + * @see #dispatch + */ + @SuppressWarnings("MissingSummary") + public Result visit(JdbcTableScan scan) { + return result(scan.jdbcTable.tableName(), + ImmutableList.of(Clause.FROM), scan, null); + } + + public Result implement(RelNode node) { + return dispatch(node); + } +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRel.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRel.java new file mode 100644 index 0000000..8739032 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRel.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import org.apache.calcite.rel.RelNode; + +/** + * Relational expression that uses JDBC calling convention. + */ +public interface JdbcRel extends RelNode { + SqlImplementor.Result implement(JdbcImplementor implementor); +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRules.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRules.java new file mode 100644 index 0000000..1b923b8 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcRules.java @@ -0,0 +1,1276 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import static java.util.Objects.requireNonNull; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.Convention; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.InvalidRelException; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Intersect; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Minus; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.core.Values; +import org.apache.calcite.rel.metadata.RelMdUtil; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexMultisetUtil; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.schema.ModifiableTable; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Util; +import org.apache.calcite.util.trace.CalciteTrace; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; + +/** + * Rules and relational operators for + * {@link JdbcConvention} + * calling convention. + */ +public class JdbcRules { + private JdbcRules() { + } + + protected static final Logger LOGGER = CalciteTrace.getPlannerTracer(); + + static final RelFactories.ProjectFactory PROJECT_FACTORY = + (input, hints, projects, fieldNames) -> { + final RelOptCluster cluster = input.getCluster(); + final RelDataType rowType = + RexUtil.createStructType(cluster.getTypeFactory(), projects, + fieldNames, SqlValidatorUtil.F_SUGGESTER); + return new JdbcProject(cluster, input.getTraitSet(), input, projects, + rowType); + }; + + static final RelFactories.FilterFactory FILTER_FACTORY = + (input, condition, variablesSet) -> { + Preconditions.checkArgument(variablesSet.isEmpty(), + "JdbcFilter does not allow variables"); + return new JdbcFilter(input.getCluster(), + input.getTraitSet(), input, condition); + }; + + static final RelFactories.JoinFactory JOIN_FACTORY = + (left, right, hints, condition, variablesSet, joinType, semiJoinDone) -> { + final RelOptCluster cluster = left.getCluster(); + final RelTraitSet traitSet = cluster.traitSetOf( + requireNonNull(left.getConvention(), "left.getConvention()")); + try { + return new JdbcJoin(cluster, traitSet, left, right, condition, + variablesSet, joinType); + } catch (InvalidRelException e) { + throw new AssertionError(e); + } + }; + + static final RelFactories.CorrelateFactory CORRELATE_FACTORY = + (left, right, correlationId, requiredColumns, joinType) -> { + throw new UnsupportedOperationException("JdbcCorrelate"); + }; + + public static final RelFactories.SortFactory SORT_FACTORY = + (input, collation, offset, fetch) -> { + throw new UnsupportedOperationException("JdbcSort"); + }; + + public static final RelFactories.ExchangeFactory EXCHANGE_FACTORY = + (input, distribution) -> { + throw new UnsupportedOperationException("JdbcExchange"); + }; + + public static final RelFactories.SortExchangeFactory SORT_EXCHANGE_FACTORY = + (input, distribution, collation) -> { + throw new UnsupportedOperationException("JdbcSortExchange"); + }; + + public static final RelFactories.AggregateFactory AGGREGATE_FACTORY = + (input, hints, groupSet, groupSets, aggCalls) -> { + final RelOptCluster cluster = input.getCluster(); + final RelTraitSet traitSet = cluster.traitSetOf( + requireNonNull(input.getConvention(), "input.getConvention()")); + try { + return new JdbcAggregate(cluster, traitSet, input, groupSet, + groupSets, aggCalls); + } catch (InvalidRelException e) { + throw new AssertionError(e); + } + }; + + public static final RelFactories.MatchFactory MATCH_FACTORY = + (input, pattern, rowType, strictStart, strictEnd, patternDefinitions, + measures, after, subsets, allRows, partitionKeys, orderKeys, + interval) -> { + throw new UnsupportedOperationException("JdbcMatch"); + }; + + public static final RelFactories.SetOpFactory SET_OP_FACTORY = + (kind, inputs, all) -> { + RelNode input = inputs.get(0); + RelOptCluster cluster = input.getCluster(); + final RelTraitSet traitSet = cluster.traitSetOf( + requireNonNull(input.getConvention(), "input.getConvention()")); + switch (kind) { + case UNION: + return new JdbcUnion(cluster, traitSet, inputs, all); + case INTERSECT: + return new JdbcIntersect(cluster, traitSet, inputs, all); + case EXCEPT: + return new JdbcMinus(cluster, traitSet, inputs, all); + default: + throw new AssertionError("unknown: " + kind); + } + }; + + public static final RelFactories.ValuesFactory VALUES_FACTORY = + (cluster, rowType, tuples) -> { + throw new UnsupportedOperationException(); + }; + + public static final RelFactories.TableScanFactory TABLE_SCAN_FACTORY = + (toRelContext, table) -> { + throw new UnsupportedOperationException(); + }; + + public static final RelFactories.SnapshotFactory SNAPSHOT_FACTORY = + (input, period) -> { + throw new UnsupportedOperationException(); + }; + + /** + * A {@link RelBuilderFactory} that creates a {@link RelBuilder} that will + * create JDBC relational expressions for everything. + */ + public static final RelBuilderFactory JDBC_BUILDER = + RelBuilder.proto( + Contexts.of(PROJECT_FACTORY, + FILTER_FACTORY, + JOIN_FACTORY, + SORT_FACTORY, + EXCHANGE_FACTORY, + SORT_EXCHANGE_FACTORY, + AGGREGATE_FACTORY, + MATCH_FACTORY, + SET_OP_FACTORY, + VALUES_FACTORY, + TABLE_SCAN_FACTORY, + SNAPSHOT_FACTORY)); + + /** + * Creates a list of rules with the given JDBC convention instance. + */ + public static List rules(JdbcConvention out) { + final ImmutableList.Builder b = ImmutableList.builder(); + foreachRule(out, b::add); + return b.build(); + } + + /** + * Creates a list of rules with the given JDBC convention instance + * and builder factory. + */ + public static List rules(JdbcConvention out, + RelBuilderFactory relBuilderFactory) { + final ImmutableList.Builder b = ImmutableList.builder(); + foreachRule(out, r -> + b.add(r.config.withRelBuilderFactory(relBuilderFactory).toRule())); + return b.build(); + } + + private static void foreachRule(JdbcConvention out, + Consumer> consumer) { + consumer.accept(JdbcToEnumerableConverterRule.create(out)); + consumer.accept(JdbcJoinRule.create(out)); + consumer.accept(JdbcCalcRule.create(out)); + consumer.accept(JdbcProjectRule.create(out)); + consumer.accept(JdbcFilterRule.create(out)); + consumer.accept(JdbcAggregateRule.create(out)); + consumer.accept(JdbcSortRule.create(out)); + consumer.accept(JdbcUnionRule.create(out)); + consumer.accept(JdbcIntersectRule.create(out)); + consumer.accept(JdbcMinusRule.create(out)); + consumer.accept(JdbcTableModificationRule.create(out)); + consumer.accept(JdbcValuesRule.create(out)); + } + + /** + * Abstract base class for rule that converts to JDBC. + */ + abstract static class JdbcConverterRule extends ConverterRule { + protected JdbcConverterRule(Config config) { + super(config); + } + } + + /** + * Rule that converts a join to JDBC. + */ + public static class JdbcJoinRule extends JdbcConverterRule { + /** + * Creates a JdbcJoinRule. + */ + public static JdbcJoinRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Join.class, Convention.NONE, out, "JdbcJoinRule") + .withRuleFactory(JdbcJoinRule::new) + .toRule(JdbcJoinRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcJoinRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Join join = (Join) rel; + switch (join.getJoinType()) { + case SEMI: + case ANTI: + // It's not possible to convert semi-joins or anti-joins. They have fewer columns + // than regular joins. + return null; + default: + return convert(join, true); + } + } + + /** + * Converts a {@code Join} into a {@code JdbcJoin}. + * + * @param join Join operator to convert + * @param convertInputTraits Whether to convert input to {@code join}'s + * JDBC convention + * @return A new JdbcJoin + */ + public @Nullable RelNode convert(Join join, boolean convertInputTraits) { + final List newInputs = new ArrayList<>(); + for (RelNode input : join.getInputs()) { + if (convertInputTraits && input.getConvention() != getOutTrait()) { + input = + convert(input, + input.getTraitSet().replace(out)); + } + newInputs.add(input); + } + if (convertInputTraits && !canJoinOnCondition(join.getCondition())) { + return null; + } + try { + return new JdbcJoin( + join.getCluster(), + join.getTraitSet().replace(out), + newInputs.get(0), + newInputs.get(1), + join.getCondition(), + join.getVariablesSet(), + join.getJoinType()); + } catch (InvalidRelException e) { + LOGGER.debug(e.toString()); + return null; + } + } + + /** + * Returns whether a condition is supported by {@link JdbcJoin}. + * + *

Corresponds to the capabilities of + * {@link SqlImplementor#convertConditionToSqlNode}. + * + * @param node Condition + * @return Whether condition is supported + */ + private static boolean canJoinOnCondition(RexNode node) { + final List operands; + switch (node.getKind()) { + case AND: + case OR: + operands = ((RexCall) node).getOperands(); + for (RexNode operand : operands) { + if (!canJoinOnCondition(operand)) { + return false; + } + } + return true; + + case EQUALS: + case IS_NOT_DISTINCT_FROM: + case NOT_EQUALS: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + operands = ((RexCall) node).getOperands(); + if ((operands.get(0) instanceof RexInputRef) + && (operands.get(1) instanceof RexInputRef)) { + return true; + } + // fall through + + default: + return false; + } + } + + @Override + public boolean matches(RelOptRuleCall call) { + Join join = call.rel(0); + JoinRelType joinType = join.getJoinType(); + return ((JdbcConvention) getOutConvention()).dialect.supportsJoinType(joinType); + } + } + + /** + * Join operator implemented in JDBC convention. + */ + public static class JdbcJoin extends Join implements JdbcRel { + /** + * Creates a JdbcJoin. + */ + public JdbcJoin(RelOptCluster cluster, RelTraitSet traitSet, + RelNode left, RelNode right, RexNode condition, + Set variablesSet, JoinRelType joinType) + throws InvalidRelException { + super(cluster, traitSet, ImmutableList.of(), left, right, condition, variablesSet, joinType); + } + + @Deprecated // to be removed before 2.0 + protected JdbcJoin( + RelOptCluster cluster, + RelTraitSet traitSet, + RelNode left, + RelNode right, + RexNode condition, + JoinRelType joinType, + Set variablesStopped) + throws InvalidRelException { + this(cluster, traitSet, left, right, condition, + CorrelationId.setOf(variablesStopped), joinType); + } + + @Override + public JdbcJoin copy(RelTraitSet traitSet, RexNode condition, + RelNode left, RelNode right, JoinRelType joinType, + boolean semiJoinDone) { + try { + return new JdbcJoin(getCluster(), traitSet, left, right, + condition, variablesSet, joinType); + } catch (InvalidRelException e) { + // Semantic error not possible. Must be a bug. Convert to + // internal error. + throw new AssertionError(e); + } + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + // We always "build" the + double rowCount = mq.getRowCount(this); + + return planner.getCostFactory().makeCost(rowCount, 0, 0); + } + + @Override + public double estimateRowCount(RelMetadataQuery mq) { + final double leftRowCount = left.estimateRowCount(mq); + final double rightRowCount = right.estimateRowCount(mq); + return Math.max(leftRowCount, rightRowCount); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Calc} to an + * {@link JdbcCalc}. + */ + private static class JdbcCalcRule extends JdbcConverterRule { + /** + * Creates a JdbcCalcRule. + */ + public static JdbcCalcRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Calc.class, Convention.NONE, out, "JdbcCalcRule") + .withRuleFactory(JdbcCalcRule::new) + .toRule(JdbcCalcRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcCalcRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Calc calc = (Calc) rel; + + // If there's a multiset, let FarragoMultisetSplitter work on it + // first. + if (RexMultisetUtil.containsMultiset(calc.getProgram())) { + return null; + } + + return new JdbcCalc(rel.getCluster(), rel.getTraitSet().replace(out), + convert(calc.getInput(), calc.getTraitSet().replace(out)), + calc.getProgram()); + } + } + + /** + * Calc operator implemented in JDBC convention. + * + * @see Calc + */ + public static class JdbcCalc extends SingleRel implements JdbcRel { + private final RexProgram program; + + public JdbcCalc(RelOptCluster cluster, + RelTraitSet traitSet, + RelNode input, + RexProgram program) { + super(cluster, traitSet, input); + assert getConvention() instanceof JdbcConvention; + this.program = program; + this.rowType = program.getOutputRowType(); + } + + @Deprecated // to be removed before 2.0 + public JdbcCalc(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, + RexProgram program, int flags) { + this(cluster, traitSet, input, program); + Util.discard(flags); + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return program.explainCalc(super.explainTerms(pw)); + } + + @Override + public double estimateRowCount(RelMetadataQuery mq) { + return RelMdUtil.estimateFilteredRows(getInput(), program, mq); + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + double dRows = mq.getRowCount(this); + double dCpu = mq.getRowCount(getInput()) + * program.getExprCount(); + double dIo = 0; + return planner.getCostFactory().makeCost(dRows, dCpu, dIo); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new JdbcCalc(getCluster(), traitSet, sole(inputs), program); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Project} to + * an {@link JdbcProject}. + */ + public static class JdbcProjectRule extends JdbcConverterRule { + /** + * Creates a JdbcProjectRule. + */ + public static JdbcProjectRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Project.class, project -> + (out.dialect.supportsWindowFunctions() + || !project.containsOver()) + && !userDefinedFunctionInProject(project), + Convention.NONE, out, "JdbcProjectRule") + .withRuleFactory(JdbcProjectRule::new) + .toRule(JdbcProjectRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcProjectRule(Config config) { + super(config); + } + + private static boolean userDefinedFunctionInProject(Project project) { + CheckingUserDefinedFunctionVisitor visitor = new CheckingUserDefinedFunctionVisitor(); + for (RexNode node : project.getProjects()) { + node.accept(visitor); + if (visitor.containsUserDefinedFunction()) { + return true; + } + } + return false; + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Project project = (Project) rel; + + return new JdbcProject( + rel.getCluster(), + rel.getTraitSet().replace(out), + convert( + project.getInput(), + project.getInput().getTraitSet().replace(out)), + project.getProjects(), + project.getRowType()); + } + } + + /** + * Implementation of {@link Project} in + * {@link JdbcConvention jdbc calling convention}. + */ + public static class JdbcProject + extends Project + implements JdbcRel { + public JdbcProject( + RelOptCluster cluster, + RelTraitSet traitSet, + RelNode input, + List projects, + RelDataType rowType) { + super(cluster, traitSet, ImmutableList.of(), input, projects, rowType); + assert getConvention() instanceof JdbcConvention; + } + + @Deprecated // to be removed before 2.0 + public JdbcProject(RelOptCluster cluster, RelTraitSet traitSet, + RelNode input, List projects, RelDataType rowType, int flags) { + this(cluster, traitSet, input, projects, rowType); + Util.discard(flags); + } + + @Override + public JdbcProject copy(RelTraitSet traitSet, RelNode input, + List projects, RelDataType rowType) { + return new JdbcProject(getCluster(), traitSet, input, projects, rowType); + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + RelOptCost cost = super.computeSelfCost(planner, mq); + if (cost == null) { + return null; + } + return cost.multiplyBy(JdbcConvention.COST_MULTIPLIER); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Filter} to + * an {@link JdbcFilter}. + */ + public static class JdbcFilterRule extends JdbcConverterRule { + /** + * Creates a JdbcFilterRule. + */ + public static JdbcFilterRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Filter.class, r -> !userDefinedFunctionInFilter(r), + Convention.NONE, out, "JdbcFilterRule") + .withRuleFactory(JdbcFilterRule::new) + .toRule(JdbcFilterRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcFilterRule(Config config) { + super(config); + } + + private static boolean userDefinedFunctionInFilter(Filter filter) { + CheckingUserDefinedFunctionVisitor visitor = new CheckingUserDefinedFunctionVisitor(); + filter.getCondition().accept(visitor); + return visitor.containsUserDefinedFunction(); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Filter filter = (Filter) rel; + + return new JdbcFilter( + rel.getCluster(), + rel.getTraitSet().replace(out), + convert(filter.getInput(), + filter.getInput().getTraitSet().replace(out)), + filter.getCondition()); + } + } + + /** + * Implementation of {@link Filter} in + * {@link JdbcConvention jdbc calling convention}. + */ + public static class JdbcFilter extends Filter implements JdbcRel { + public JdbcFilter( + RelOptCluster cluster, + RelTraitSet traitSet, + RelNode input, + RexNode condition) { + super(cluster, traitSet, input, condition); + assert getConvention() instanceof JdbcConvention; + } + + @Override + public JdbcFilter copy(RelTraitSet traitSet, RelNode input, + RexNode condition) { + return new JdbcFilter(getCluster(), traitSet, input, condition); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Aggregate} + * to a {@link JdbcAggregate}. + */ + public static class JdbcAggregateRule extends JdbcConverterRule { + /** + * Creates a JdbcAggregateRule. + */ + public static JdbcAggregateRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Aggregate.class, Convention.NONE, out, + "JdbcAggregateRule") + .withRuleFactory(JdbcAggregateRule::new) + .toRule(JdbcAggregateRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcAggregateRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Aggregate agg = (Aggregate) rel; + if (agg.getGroupSets().size() != 1) { + // GROUPING SETS not supported; see + // [CALCITE-734] Push GROUPING SETS to underlying SQL via JDBC adapter + return null; + } + final RelTraitSet traitSet = + agg.getTraitSet().replace(out); + try { + return new JdbcAggregate(rel.getCluster(), traitSet, + convert(agg.getInput(), out), agg.getGroupSet(), + agg.getGroupSets(), agg.getAggCallList()); + } catch (InvalidRelException e) { + LOGGER.debug(e.toString()); + return null; + } + } + } + + /** + * Returns whether this JDBC data source can implement a given aggregate + * function. + */ + private static boolean canImplement(AggregateCall aggregateCall, + SqlDialect sqlDialect) { + return sqlDialect.supportsAggregateFunction( + aggregateCall.getAggregation().getKind()) + && aggregateCall.distinctKeys == null; + } + + /** + * Aggregate operator implemented in JDBC convention. + */ + public static class JdbcAggregate extends Aggregate implements JdbcRel { + public JdbcAggregate( + RelOptCluster cluster, + RelTraitSet traitSet, + RelNode input, + ImmutableBitSet groupSet, + @Nullable List groupSets, + List aggCalls) + throws InvalidRelException { + super(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets, aggCalls); + assert getConvention() instanceof JdbcConvention; + assert this.groupSets.size() == 1 : "Grouping sets not supported"; + final SqlDialect dialect = ((JdbcConvention) getConvention()).dialect; + for (AggregateCall aggCall : aggCalls) { + if (!canImplement(aggCall, dialect)) { + throw new InvalidRelException("cannot implement aggregate function " + + aggCall); + } + if (aggCall.hasFilter() && !dialect.supportsAggregateFunctionFilter()) { + throw new InvalidRelException("dialect does not support aggregate " + + "functions FILTER clauses"); + } + } + } + + @Deprecated // to be removed before 2.0 + public JdbcAggregate(RelOptCluster cluster, RelTraitSet traitSet, + RelNode input, boolean indicator, ImmutableBitSet groupSet, + List groupSets, List aggCalls) + throws InvalidRelException { + this(cluster, traitSet, input, groupSet, groupSets, aggCalls); + checkIndicator(indicator); + } + + @Override + public JdbcAggregate copy(RelTraitSet traitSet, RelNode input, + ImmutableBitSet groupSet, + @Nullable List groupSets, + List aggCalls) { + try { + return new JdbcAggregate(getCluster(), traitSet, input, + groupSet, groupSets, aggCalls); + } catch (InvalidRelException e) { + // Semantic error not possible. Must be a bug. Convert to + // internal error. + throw new AssertionError(e); + } + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Sort} to an + * {@link JdbcSort}. + */ + public static class JdbcSortRule extends JdbcConverterRule { + /** + * Creates a JdbcSortRule. + */ + public static JdbcSortRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Sort.class, Convention.NONE, out, "JdbcSortRule") + .withRuleFactory(JdbcSortRule::new) + .toRule(JdbcSortRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcSortRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + return convert((Sort) rel, true); + } + + /** + * Converts a {@code Sort} into a {@code JdbcSort}. + * + * @param sort Sort operator to convert + * @param convertInputTraits Whether to convert input to {@code sort}'s + * JDBC convention + * @return A new JdbcSort + */ + public RelNode convert(Sort sort, boolean convertInputTraits) { + final RelTraitSet traitSet = sort.getTraitSet().replace(out); + + final RelNode input; + if (convertInputTraits) { + final RelTraitSet inputTraitSet = sort.getInput().getTraitSet().replace(out); + input = convert(sort.getInput(), inputTraitSet); + } else { + input = sort.getInput(); + } + + return new JdbcSort(sort.getCluster(), traitSet, + input, sort.getCollation(), sort.offset, sort.fetch); + } + } + + /** + * Sort operator implemented in JDBC convention. + */ + public static class JdbcSort + extends Sort + implements JdbcRel { + public JdbcSort( + RelOptCluster cluster, + RelTraitSet traitSet, + RelNode input, + RelCollation collation, + @Nullable RexNode offset, + @Nullable RexNode fetch) { + super(cluster, traitSet, input, collation, offset, fetch); + assert getConvention() instanceof JdbcConvention; + assert getConvention() == input.getConvention(); + } + + @Override + public JdbcSort copy(RelTraitSet traitSet, RelNode newInput, + RelCollation newCollation, @Nullable RexNode offset, + @Nullable RexNode fetch) { + return new JdbcSort(getCluster(), traitSet, newInput, newCollation, + offset, fetch); + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + RelOptCost cost = super.computeSelfCost(planner, mq); + if (cost == null) { + return null; + } + return cost.multiplyBy(0.9); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert an {@link Union} to a + * {@link JdbcUnion}. + */ + public static class JdbcUnionRule extends JdbcConverterRule { + /** + * Creates a JdbcUnionRule. + */ + public static JdbcUnionRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Union.class, Convention.NONE, out, "JdbcUnionRule") + .withRuleFactory(JdbcUnionRule::new) + .toRule(JdbcUnionRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcUnionRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Union union = (Union) rel; + final RelTraitSet traitSet = + union.getTraitSet().replace(out); + return new JdbcUnion(rel.getCluster(), traitSet, + convertList(union.getInputs(), out), union.all); + } + } + + /** + * Union operator implemented in JDBC convention. + */ + public static class JdbcUnion extends Union implements JdbcRel { + public JdbcUnion( + RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + boolean all) { + super(cluster, traitSet, inputs, all); + } + + @Override + public JdbcUnion copy( + RelTraitSet traitSet, List inputs, boolean all) { + return new JdbcUnion(getCluster(), traitSet, inputs, all); + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + RelOptCost cost = super.computeSelfCost(planner, mq); + if (cost == null) { + return null; + } + return cost.multiplyBy(.1); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Intersect} + * to a {@link JdbcIntersect}. + */ + public static class JdbcIntersectRule extends JdbcConverterRule { + /** + * Creates a JdbcIntersectRule. + */ + public static JdbcIntersectRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Intersect.class, Convention.NONE, out, + "JdbcIntersectRule") + .withRuleFactory(JdbcIntersectRule::new) + .toRule(JdbcIntersectRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcIntersectRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Intersect intersect = (Intersect) rel; + if (intersect.all) { + return null; // INTERSECT ALL not implemented + } + final RelTraitSet traitSet = + intersect.getTraitSet().replace(out); + return new JdbcIntersect(rel.getCluster(), traitSet, + convertList(intersect.getInputs(), out), false); + } + } + + /** + * Intersect operator implemented in JDBC convention. + */ + public static class JdbcIntersect + extends Intersect + implements JdbcRel { + public JdbcIntersect( + RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + boolean all) { + super(cluster, traitSet, inputs, all); + assert !all; + } + + @Override + public JdbcIntersect copy( + RelTraitSet traitSet, List inputs, boolean all) { + return new JdbcIntersect(getCluster(), traitSet, inputs, all); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule to convert a {@link Minus} to a + * {@link JdbcMinus}. + */ + public static class JdbcMinusRule extends JdbcConverterRule { + /** + * Creates a JdbcMinusRule. + */ + public static JdbcMinusRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Minus.class, Convention.NONE, out, "JdbcMinusRule") + .withRuleFactory(JdbcMinusRule::new) + .toRule(JdbcMinusRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcMinusRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final Minus minus = (Minus) rel; + if (minus.all) { + return null; // EXCEPT ALL not implemented + } + final RelTraitSet traitSet = + rel.getTraitSet().replace(out); + return new JdbcMinus(rel.getCluster(), traitSet, + convertList(minus.getInputs(), out), false); + } + } + + /** + * Minus operator implemented in JDBC convention. + */ + public static class JdbcMinus extends Minus implements JdbcRel { + public JdbcMinus(RelOptCluster cluster, RelTraitSet traitSet, + List inputs, boolean all) { + super(cluster, traitSet, inputs, all); + assert !all; + } + + @Override + public JdbcMinus copy(RelTraitSet traitSet, List inputs, + boolean all) { + return new JdbcMinus(getCluster(), traitSet, inputs, all); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule that converts a table-modification to JDBC. + */ + public static class JdbcTableModificationRule extends JdbcConverterRule { + /** + * Creates a JdbcToEnumerableConverterRule. + */ + public static JdbcTableModificationRule create( + JdbcConvention out) { + return Config.INSTANCE + .withConversion(TableModify.class, Convention.NONE, out, + "JdbcTableModificationRule") + .withRuleFactory(JdbcTableModificationRule::new) + .toRule(JdbcTableModificationRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcTableModificationRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + final TableModify modify = + (TableModify) rel; + final ModifiableTable modifiableTable = + modify.getTable().unwrap(ModifiableTable.class); + if (modifiableTable == null) { + return null; + } + final RelTraitSet traitSet = + modify.getTraitSet().replace(out); + return new JdbcTableModify( + modify.getCluster(), traitSet, + modify.getTable(), + modify.getCatalogReader(), + convert(modify.getInput(), traitSet), + modify.getOperation(), + modify.getUpdateColumnList(), + modify.getSourceExpressionList(), + modify.isFlattened()); + } + } + + /** + * Table-modification operator implemented in JDBC convention. + */ + public static class JdbcTableModify extends TableModify implements JdbcRel { + public JdbcTableModify(RelOptCluster cluster, + RelTraitSet traitSet, + RelOptTable table, + Prepare.CatalogReader catalogReader, + RelNode input, + Operation operation, + @Nullable List updateColumnList, + @Nullable List sourceExpressionList, + boolean flattened) { + super(cluster, traitSet, table, catalogReader, input, operation, + updateColumnList, sourceExpressionList, flattened); + assert input.getConvention() instanceof JdbcConvention; + assert getConvention() instanceof JdbcConvention; + final ModifiableTable modifiableTable = + table.unwrap(ModifiableTable.class); + if (modifiableTable == null) { + throw new AssertionError(); // TODO: user error in validator + } + Expression expression = table.getExpression(Queryable.class); + if (expression == null) { + throw new AssertionError(); // TODO: user error in validator + } + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + RelOptCost cost = super.computeSelfCost(planner, mq); + if (cost == null) { + return null; + } + return cost.multiplyBy(.1); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new JdbcTableModify( + getCluster(), traitSet, getTable(), getCatalogReader(), + sole(inputs), getOperation(), getUpdateColumnList(), + getSourceExpressionList(), isFlattened()); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Rule that converts a values operator to JDBC. + */ + public static class JdbcValuesRule extends JdbcConverterRule { + /** + * Creates a JdbcValuesRule. + */ + public static JdbcValuesRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(Values.class, Convention.NONE, out, "JdbcValuesRule") + .withRuleFactory(JdbcValuesRule::new) + .toRule(JdbcValuesRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcValuesRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + Values values = (Values) rel; + return new JdbcValues(values.getCluster(), values.getRowType(), + values.getTuples(), values.getTraitSet().replace(out)); + } + } + + /** + * Values operator implemented in JDBC convention. + */ + public static class JdbcValues extends Values implements JdbcRel { + JdbcValues(RelOptCluster cluster, RelDataType rowType, + ImmutableList> tuples, RelTraitSet traitSet) { + super(cluster, rowType, tuples, traitSet); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + assert inputs.isEmpty(); + return new JdbcValues(getCluster(), getRowType(), tuples, traitSet); + } + + @Override + public SqlImplementor.Result implement(JdbcImplementor implementor) { + return implementor.implement(this); + } + } + + /** + * Visitor that checks whether part of a projection is a user-defined + * function (UDF). + */ + private static class CheckingUserDefinedFunctionVisitor + extends RexVisitorImpl { + + private boolean containsUsedDefinedFunction = false; + + CheckingUserDefinedFunctionVisitor() { + super(true); + } + + public boolean containsUserDefinedFunction() { + return containsUsedDefinedFunction; + } + + @Override + public Void visitCall(RexCall call) { + SqlOperator operator = call.getOperator(); + if (operator instanceof SqlFunction + && ((SqlFunction) operator).getFunctionType().isUserDefined()) { + containsUsedDefinedFunction |= true; + } + return super.visitCall(call); + } + + } + +} diff --git a/traindb-core/src/main/java/traindb/schema/JdbcTableScan.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcTableScan.java similarity index 87% rename from traindb-core/src/main/java/traindb/schema/JdbcTableScan.java rename to traindb-core/src/main/java/traindb/adapter/jdbc/JdbcTableScan.java index ba818c0..29a8c31 100644 --- a/traindb-core/src/main/java/traindb/schema/JdbcTableScan.java +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcTableScan.java @@ -12,15 +12,12 @@ * limitations under the License. */ -package traindb.schema; +package traindb.adapter.jdbc; import static org.apache.calcite.linq4j.Nullness.castNonNull; import com.google.common.collect.ImmutableList; import java.util.List; -import org.apache.calcite.adapter.jdbc.JdbcConvention; -import org.apache.calcite.adapter.jdbc.JdbcImplementor; -import org.apache.calcite.adapter.jdbc.JdbcRel; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitSet; @@ -47,7 +44,7 @@ public RelNode copy(RelTraitSet traitSet, List inputs) { } @Override - public JdbcImplementor.Result implement(JdbcImplementor implementor) { + public SqlImplementor.Result implement(JdbcImplementor implementor) { return implementor.result(jdbcTable.tableName(), ImmutableList.of(JdbcImplementor.Clause.FROM), this, null); } diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverter.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverter.java new file mode 100644 index 0000000..59453cd --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverter.java @@ -0,0 +1,358 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.TimeZone; +import java.util.stream.Collectors; +import org.apache.calcite.DataContext; +import org.apache.calcite.adapter.enumerable.EnumerableRel; +import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor; +import org.apache.calcite.adapter.enumerable.JavaRowFormat; +import org.apache.calcite.adapter.enumerable.PhysType; +import org.apache.calcite.adapter.enumerable.PhysTypeImpl; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.config.CalciteSystemProperty; +import org.apache.calcite.linq4j.tree.BlockBuilder; +import org.apache.calcite.linq4j.tree.ConstantExpression; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.ParameterExpression; +import org.apache.calcite.linq4j.tree.Primitive; +import org.apache.calcite.linq4j.tree.UnaryExpression; +import org.apache.calcite.plan.ConventionTraitDef; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.plan.RelOptPlanner; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterImpl; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.runtime.SqlFunctions; +import org.apache.calcite.schema.Schemas; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.util.SqlString; +import org.apache.calcite.util.BuiltInMethod; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Relational expression representing a scan of a table in a JDBC data source. + */ +public class JdbcToEnumerableConverter + extends ConverterImpl + implements EnumerableRel { + protected JdbcToEnumerableConverter( + RelOptCluster cluster, + RelTraitSet traits, + RelNode input) { + super(cluster, ConventionTraitDef.INSTANCE, traits, input); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new JdbcToEnumerableConverter( + getCluster(), traitSet, sole(inputs)); + } + + @Override + public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, + RelMetadataQuery mq) { + RelOptCost cost = super.computeSelfCost(planner, mq); + if (cost == null) { + return null; + } + return cost.multiplyBy(.1); + } + + @Override + public Result implement(EnumerableRelImplementor implementor, Prefer pref) { + // Generate: + // ResultSetEnumerable.of(schema.getDataSource(), "select ...") + final BlockBuilder builder0 = new BlockBuilder(false); + final JdbcRel child = (JdbcRel) getInput(); + final PhysType physType = + PhysTypeImpl.of( + implementor.getTypeFactory(), getRowType(), + pref.prefer(JavaRowFormat.CUSTOM)); + final JdbcConvention jdbcConvention = + (JdbcConvention) requireNonNull(child.getConvention(), + () -> "child.getConvention() is null for " + child); + SqlString sqlString = generateSql(jdbcConvention.dialect); + String sql = sqlString.getSql(); + if (CalciteSystemProperty.DEBUG.value()) { + System.out.println("[" + sql + "]"); + } + Hook.QUERY_PLAN.run(sql); + final Expression sql_ = + builder0.append("sql", Expressions.constant(sql)); + final int fieldCount = getRowType().getFieldCount(); + BlockBuilder builder = new BlockBuilder(); + final ParameterExpression resultSet_ = + Expressions.parameter(Modifier.FINAL, ResultSet.class, + builder.newName("resultSet")); + final SqlDialect.CalendarPolicy calendarPolicy = + jdbcConvention.dialect.getCalendarPolicy(); + final Expression calendar_; + switch (calendarPolicy) { + case LOCAL: + calendar_ = + builder0.append("calendar", + Expressions.call(Calendar.class, "getInstance", + getTimeZoneExpression(implementor))); + break; + default: + calendar_ = null; + } + if (fieldCount == 1) { + final ParameterExpression value_ = + Expressions.parameter(Object.class, builder.newName("value")); + builder.add(Expressions.declare(Modifier.FINAL, value_, null)); + generateGet(implementor, physType, builder, resultSet_, 0, value_, + calendar_, calendarPolicy); + builder.add(Expressions.return_(null, value_)); + } else { + final Expression values_ = + builder.append("values", + Expressions.newArrayBounds(Object.class, 1, + Expressions.constant(fieldCount))); + for (int i = 0; i < fieldCount; i++) { + generateGet(implementor, physType, builder, resultSet_, i, + Expressions.arrayIndex(values_, Expressions.constant(i)), + calendar_, calendarPolicy); + } + builder.add( + Expressions.return_(null, values_)); + } + final ParameterExpression e_ = + Expressions.parameter(SQLException.class, builder.newName("e")); + final Expression rowBuilderFactory_ = + builder0.append("rowBuilderFactory", + Expressions.lambda( + Expressions.block( + Expressions.return_(null, + Expressions.lambda( + Expressions.block( + Expressions.tryCatch( + builder.toBlock(), + Expressions.catch_( + e_, + Expressions.throw_( + Expressions.new_( + RuntimeException.class, + e_)))))))), + resultSet_)); + + final Expression enumerable; + + if (sqlString.getDynamicParameters() != null + && !sqlString.getDynamicParameters().isEmpty()) { + final Expression preparedStatementConsumer_ = + builder0.append("preparedStatementConsumer", + Expressions.call(BuiltInMethod.CREATE_ENRICHER.method, + Expressions.newArrayInit(Integer.class, 1, + toIndexesTableExpression(sqlString)), + DataContext.ROOT)); + + enumerable = builder0.append("enumerable", + Expressions.call( + BuiltInMethod.RESULT_SET_ENUMERABLE_OF_PREPARED.method, + Expressions.call( + Schemas.unwrap(jdbcConvention.expression, TrainDBJdbcDataSource.class), + BuiltInMethod.JDBC_SCHEMA_DATA_SOURCE.method), + sql_, + rowBuilderFactory_, + preparedStatementConsumer_)); + } else { + enumerable = builder0.append("enumerable", + Expressions.call( + BuiltInMethod.RESULT_SET_ENUMERABLE_OF.method, + Expressions.call( + Schemas.unwrap(jdbcConvention.expression, TrainDBJdbcDataSource.class), + BuiltInMethod.JDBC_SCHEMA_DATA_SOURCE.method), + sql_, + rowBuilderFactory_)); + } + builder0.add( + Expressions.statement( + Expressions.call(enumerable, + BuiltInMethod.RESULT_SET_ENUMERABLE_SET_TIMEOUT.method, + DataContext.ROOT))); + builder0.add( + Expressions.return_(null, enumerable)); + return implementor.result(physType, builder0.toBlock()); + } + + private static List toIndexesTableExpression(SqlString sqlString) { + return requireNonNull(sqlString.getDynamicParameters(), + () -> "sqlString.getDynamicParameters() is null for " + sqlString).stream() + .map(Expressions::constant) + .collect(Collectors.toList()); + } + + private static UnaryExpression getTimeZoneExpression( + EnumerableRelImplementor implementor) { + return Expressions.convert_( + Expressions.call( + implementor.getRootExpression(), + "get", + Expressions.constant("timeZone")), + TimeZone.class); + } + + private static void generateGet(EnumerableRelImplementor implementor, + PhysType physType, BlockBuilder builder, + ParameterExpression resultSet_, + int i, Expression target, @Nullable Expression calendar_, + SqlDialect.CalendarPolicy calendarPolicy) { + final Primitive primitive = Primitive.ofBoxOr(physType.fieldClass(i)); + final RelDataType fieldType = + physType.getRowType().getFieldList().get(i).getType(); + final List dateTimeArgs = new ArrayList<>(); + dateTimeArgs.add(Expressions.constant(i + 1)); + SqlTypeName sqlTypeName = fieldType.getSqlTypeName(); + boolean offset = false; + switch (calendarPolicy) { + case LOCAL: + assert calendar_ != null : "calendar must not be null"; + dateTimeArgs.add(calendar_); + break; + case NULL: + // We don't specify a calendar at all, so we don't add an argument and + // instead use the version of the getXXX that doesn't take a Calendar + break; + case DIRECT: + sqlTypeName = SqlTypeName.ANY; + break; + case SHIFT: + switch (sqlTypeName) { + case TIMESTAMP: + case DATE: + offset = true; + break; + default: + break; + } + break; + default: + break; + } + final Expression source; + switch (sqlTypeName) { + case DATE: + case TIME: + case TIMESTAMP: + source = Expressions.call( + getMethod(sqlTypeName, fieldType.isNullable(), offset), + Expressions.list() + .append( + Expressions.call(resultSet_, + getMethod2(sqlTypeName), dateTimeArgs)) + .appendIf(offset, getTimeZoneExpression(implementor))); + break; + case ARRAY: + final Expression x = Expressions.convert_( + Expressions.call(resultSet_, jdbcGetMethod(primitive), + Expressions.constant(i + 1)), + java.sql.Array.class); + source = Expressions.call(BuiltInMethod.JDBC_ARRAY_TO_LIST.method, x); + break; + default: + source = Expressions.call( + resultSet_, jdbcGetMethod(primitive), Expressions.constant(i + 1)); + } + builder.add( + Expressions.statement( + Expressions.assign( + target, source))); + + // [CALCITE-596] If primitive type columns contain null value, returns null + // object + if (primitive != null) { + builder.add( + Expressions.ifThen( + Expressions.call(resultSet_, "wasNull"), + Expressions.statement( + Expressions.assign(target, + Expressions.constant(null))))); + } + } + + private static Method getMethod(SqlTypeName sqlTypeName, boolean nullable, + boolean offset) { + switch (sqlTypeName) { + case DATE: + return (nullable + ? BuiltInMethod.DATE_TO_INT_OPTIONAL + : BuiltInMethod.DATE_TO_INT).method; + case TIME: + return (nullable + ? BuiltInMethod.TIME_TO_INT_OPTIONAL + : BuiltInMethod.TIME_TO_INT).method; + case TIMESTAMP: + return (nullable + ? (offset + ? BuiltInMethod.TIMESTAMP_TO_LONG_OPTIONAL_OFFSET + : BuiltInMethod.TIMESTAMP_TO_LONG_OPTIONAL) + : (offset + ? BuiltInMethod.TIMESTAMP_TO_LONG_OFFSET + : BuiltInMethod.TIMESTAMP_TO_LONG)).method; + default: + throw new AssertionError(sqlTypeName + ":" + nullable); + } + } + + private static Method getMethod2(SqlTypeName sqlTypeName) { + switch (sqlTypeName) { + case DATE: + return BuiltInMethod.RESULT_SET_GET_DATE2.method; + case TIME: + return BuiltInMethod.RESULT_SET_GET_TIME2.method; + case TIMESTAMP: + return BuiltInMethod.RESULT_SET_GET_TIMESTAMP2.method; + default: + throw new AssertionError(sqlTypeName); + } + } + + /** + * E,g, {@code jdbcGetMethod(int)} returns "getInt". + */ + private static String jdbcGetMethod(@Nullable Primitive primitive) { + return primitive == null + ? "getObject" + : "get" + SqlFunctions.initcap(castNonNull(primitive.primitiveName)); + } + + private SqlString generateSql(SqlDialect dialect) { + final JdbcImplementor jdbcImplementor = + new JdbcImplementor(dialect, + (JavaTypeFactory) getCluster().getTypeFactory()); + final SqlImplementor.Result result = + jdbcImplementor.visitInput(this, 0); + return result.asStatement().toSqlString(dialect); + } +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverterRule.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverterRule.java new file mode 100644 index 0000000..6d5e604 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcToEnumerableConverterRule.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import org.apache.calcite.adapter.enumerable.EnumerableConvention; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Rule to convert a relational expression from + * {@link org.apache.calcite.adapter.jdbc.JdbcConvention} to + * {@link EnumerableConvention}. + */ +public class JdbcToEnumerableConverterRule extends ConverterRule { + /** + * Creates a JdbcToEnumerableConverterRule. + */ + public static JdbcToEnumerableConverterRule create(JdbcConvention out) { + return Config.INSTANCE + .withConversion(RelNode.class, out, EnumerableConvention.INSTANCE, + "JdbcToEnumerableConverterRule") + .withRuleFactory(JdbcToEnumerableConverterRule::new) + .toRule(JdbcToEnumerableConverterRule.class); + } + + /** + * Called from the Config. + */ + protected JdbcToEnumerableConverterRule(Config config) { + super(config); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + RelTraitSet newTraitSet = rel.getTraitSet().replace(getOutTrait()); + return new JdbcToEnumerableConverter(rel.getCluster(), newTraitSet, rel); + } +} diff --git a/traindb-core/src/main/java/traindb/schema/JdbcUtils.java b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcUtils.java similarity index 99% rename from traindb-core/src/main/java/traindb/schema/JdbcUtils.java rename to traindb-core/src/main/java/traindb/adapter/jdbc/JdbcUtils.java index b39e0cc..76cb33a 100644 --- a/traindb-core/src/main/java/traindb/schema/JdbcUtils.java +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/JdbcUtils.java @@ -12,7 +12,7 @@ * limitations under the License. */ -package traindb.schema; +package traindb.adapter.jdbc; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -288,6 +288,7 @@ public DataSource get(String url, @Nullable String driverClassName, return cache.getUnchecked(key); } } + public static void close( @Nullable Connection connection, @Nullable Statement statement, diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/RelToSqlConverter.java b/traindb-core/src/main/java/traindb/adapter/jdbc/RelToSqlConverter.java new file mode 100644 index 0000000..e096f23 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/RelToSqlConverter.java @@ -0,0 +1,1101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.rex.RexLiteral.stringValue; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Ordering; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Deque; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.stream.Collectors; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Intersect; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Match; +import org.apache.calcite.rel.core.Minus; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.TableFunctionScan; +import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.core.Uncollect; +import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.core.Values; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexLocalRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.sql.JoinConditionType; +import org.apache.calcite.sql.JoinType; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDelete; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlInsert; +import org.apache.calcite.sql.SqlIntervalLiteral; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlMatchRecognize; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlUpdate; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.fun.SqlInternalOperators; +import org.apache.calcite.sql.fun.SqlSingleValueAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.util.SqlShuttle; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Permutation; +import org.apache.calcite.util.ReflectUtil; +import org.apache.calcite.util.ReflectiveVisitor; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Utility to convert relational expressions to SQL abstract syntax tree. + */ +public class RelToSqlConverter extends SqlImplementor + implements ReflectiveVisitor { + private final ReflectUtil.MethodDispatcher dispatcher; + + private final Deque stack = new ArrayDeque<>(); + + /** + * Creates a RelToSqlConverter. + */ + @SuppressWarnings("argument.type.incompatible") + public RelToSqlConverter(SqlDialect dialect) { + super(dialect); + dispatcher = ReflectUtil.createMethodDispatcher(Result.class, this, "visit", + RelNode.class); + } + + /** + * Dispatches a call to the {@code visit(Xxx e)} method where {@code Xxx} + * most closely matches the runtime type of the argument. + */ + protected Result dispatch(RelNode e) { + return dispatcher.invoke(e); + } + + @Override + public Result visitInput(RelNode parent, int i, boolean anon, + boolean ignoreClauses, Set expectedClauses) { + try { + final RelNode e = parent.getInput(i); + stack.push(new Frame(parent, i, e, anon, ignoreClauses, expectedClauses)); + return dispatch(e); + } finally { + stack.pop(); + } + } + + @Override + protected boolean isAnon() { + Frame peek = stack.peek(); + return peek == null || peek.anon; + } + + @Override + protected Result result(SqlNode node, Collection clauses, + @Nullable String neededAlias, @Nullable RelDataType neededType, + Map aliases) { + final Frame frame = requireNonNull(stack.peek()); + return super.result(node, clauses, neededAlias, neededType, aliases) + .withAnon(isAnon()) + .withExpectedClauses(frame.ignoreClauses, frame.expectedClauses, + frame.parent); + } + + /** + * Visits a RelNode; called by {@link #dispatch} via reflection. + */ + public Result visit(RelNode e) { + throw new AssertionError("Need to implement " + e.getClass().getName()); + } + + /** + * A SqlShuttle to replace references to a column of a table alias with the expression + * from the select item that is the source of that column. + * ANTI- and SEMI-joins generate an alias for right hand side relation which + * is used in the ON condition. But that alias is never created, so we have to inline references. + */ + private static class AliasReplacementShuttle extends SqlShuttle { + private final String tableAlias; + private final RelDataType tableType; + private final SqlNodeList replaceSource; + + AliasReplacementShuttle(String tableAlias, RelDataType tableType, + SqlNodeList replaceSource) { + this.tableAlias = tableAlias; + this.tableType = tableType; + this.replaceSource = replaceSource; + } + + @Override + public SqlNode visit(SqlIdentifier id) { + if (tableAlias.equals(id.names.get(0))) { + int index = requireNonNull( + tableType.getField(id.names.get(1), false, false), + () -> "field " + id.names.get(1) + " is not found in " + tableType) + .getIndex(); + SqlNode selectItem = requireNonNull(replaceSource, "replaceSource").get(index); + if (selectItem.getKind() == SqlKind.AS) { + selectItem = ((SqlCall) selectItem).operand(0); + } + return selectItem.clone(id.getParserPosition()); + } + return id; + } + } + + /** + * Visits a Join; called by {@link #dispatch} via reflection. + */ + public Result visit(Join e) { + switch (e.getJoinType()) { + case ANTI: + case SEMI: + return visitAntiOrSemiJoin(e); + default: + break; + } + final Result leftResult = visitInput(e, 0).resetAlias(); + final Result rightResult = visitInput(e, 1).resetAlias(); + final Context leftContext = leftResult.qualifiedContext(); + final Context rightContext = rightResult.qualifiedContext(); + final SqlNode sqlCondition; + SqlLiteral condType = JoinConditionType.ON.symbol(POS); + JoinType joinType = joinType(e.getJoinType()); + if (isCrossJoin(e)) { + sqlCondition = null; + joinType = dialect.emulateJoinTypeForCrossJoin(); + condType = JoinConditionType.NONE.symbol(POS); + } else { + sqlCondition = + convertConditionToSqlNode(e.getCondition(), leftContext, + rightContext); + } + SqlNode join = + new SqlJoin(POS, + leftResult.asFrom(), + SqlLiteral.createBoolean(false, POS), + joinType.symbol(POS), + rightResult.asFrom(), + condType, + sqlCondition); + return result(join, leftResult, rightResult); + } + + protected Result visitAntiOrSemiJoin(Join e) { + final Result leftResult = visitInput(e, 0).resetAlias(); + final Result rightResult = visitInput(e, 1).resetAlias(); + final Context leftContext = leftResult.qualifiedContext(); + final Context rightContext = rightResult.qualifiedContext(); + + final SqlSelect sqlSelect = leftResult.asSelect(); + SqlNode sqlCondition = + convertConditionToSqlNode(e.getCondition(), leftContext, rightContext); + if (leftResult.neededAlias != null) { + SqlShuttle visitor = new AliasReplacementShuttle(leftResult.neededAlias, + e.getLeft().getRowType(), sqlSelect.getSelectList()); + sqlCondition = sqlCondition.accept(visitor); + } + SqlNode fromPart = rightResult.asFrom(); + SqlSelect existsSqlSelect; + if (fromPart.getKind() == SqlKind.SELECT) { + existsSqlSelect = (SqlSelect) fromPart; + existsSqlSelect.setSelectList( + new SqlNodeList(ImmutableList.of(SqlLiteral.createExactNumeric("1", POS)), POS)); + if (existsSqlSelect.getWhere() != null) { + sqlCondition = SqlStdOperatorTable.AND.createCall(POS, + existsSqlSelect.getWhere(), + sqlCondition); + } + existsSqlSelect.setWhere(sqlCondition); + } else { + existsSqlSelect = + new SqlSelect(POS, null, + new SqlNodeList( + ImmutableList.of(SqlLiteral.createExactNumeric("1", POS)), POS), + fromPart, sqlCondition, null, + null, null, null, null, null, null); + } + sqlCondition = SqlStdOperatorTable.EXISTS.createCall(POS, existsSqlSelect); + if (e.getJoinType() == JoinRelType.ANTI) { + sqlCondition = SqlStdOperatorTable.NOT.createCall(POS, sqlCondition); + } + if (sqlSelect.getWhere() != null) { + sqlCondition = SqlStdOperatorTable.AND.createCall(POS, + sqlSelect.getWhere(), + sqlCondition); + } + sqlSelect.setWhere(sqlCondition); + final SqlNode resultNode = + leftResult.neededAlias == null ? sqlSelect + : as(sqlSelect, leftResult.neededAlias); + return result(resultNode, leftResult, rightResult); + } + + private static boolean isCrossJoin(final Join e) { + return e.getJoinType() == JoinRelType.INNER && e.getCondition().isAlwaysTrue(); + } + + /** + * Visits a Correlate; called by {@link #dispatch} via reflection. + */ + public Result visit(Correlate e) { + final Result leftResult = + visitInput(e, 0) + .resetAlias(e.getCorrelVariable(), e.getRowType()); + parseCorrelTable(e, leftResult); + final Result rightResult = visitInput(e, 1); + final SqlNode rightLateral = + SqlStdOperatorTable.LATERAL.createCall(POS, rightResult.node); + final SqlNode rightLateralAs = + SqlStdOperatorTable.AS.createCall(POS, rightLateral, + new SqlIdentifier( + requireNonNull(rightResult.neededAlias, + () -> "rightResult.neededAlias is null, node is " + rightResult.node), POS)); + + final SqlNode join = + new SqlJoin(POS, + leftResult.asFrom(), + SqlLiteral.createBoolean(false, POS), + JoinType.COMMA.symbol(POS), + rightLateralAs, + JoinConditionType.NONE.symbol(POS), + null); + return result(join, leftResult, rightResult); + } + + /** + * Visits a Filter; called by {@link #dispatch} via reflection. + */ + public Result visit(Filter e) { + final RelNode input = e.getInput(); + if (input instanceof Aggregate) { + final Aggregate aggregate = (Aggregate) input; + final boolean ignoreClauses = aggregate.getInput() instanceof Project; + final Result x = visitInput(e, 0, isAnon(), ignoreClauses, + ImmutableSet.of(Clause.HAVING)); + parseCorrelTable(e, x); + final Builder builder = x.builder(e); + builder.setHaving(builder.context.toSql(null, e.getCondition())); + return builder.result(); + } else { + final Result x = visitInput(e, 0, Clause.WHERE); + parseCorrelTable(e, x); + final Builder builder = x.builder(e); + builder.setWhere(builder.context.toSql(null, e.getCondition())); + return builder.result(); + } + } + + /** + * Visits a Project; called by {@link #dispatch} via reflection. + */ + public Result visit(Project e) { + final Result x = visitInput(e, 0, Clause.SELECT); + parseCorrelTable(e, x); + final Builder builder = x.builder(e); + if (!isStar(e.getProjects(), e.getInput().getRowType(), e.getRowType())) { + final List selectList = new ArrayList<>(); + for (RexNode ref : e.getProjects()) { + SqlNode sqlExpr = builder.context.toSql(null, ref); + if (SqlUtil.isNullLiteral(sqlExpr, false)) { + final RelDataTypeField field = + e.getRowType().getFieldList().get(selectList.size()); + sqlExpr = castNullType(sqlExpr, field.getType()); + } + addSelect(selectList, sqlExpr, e.getRowType()); + } + + builder.setSelect(new SqlNodeList(selectList, POS)); + } + return builder.result(); + } + + /** + * Wraps a NULL literal in a CAST operator to a target type. + * + * @param nullLiteral NULL literal + * @param type Target type + * @return null literal wrapped in CAST call + */ + private SqlNode castNullType(SqlNode nullLiteral, RelDataType type) { + final SqlNode typeNode = dialect.getCastSpec(type); + if (typeNode == null) { + return nullLiteral; + } + return SqlStdOperatorTable.CAST.createCall(POS, nullLiteral, typeNode); + } + + /** + * Visits a Window; called by {@link #dispatch} via reflection. + */ + public Result visit(Window e) { + final Result x = visitInput(e, 0); + final Builder builder = x.builder(e); + final RelNode input = e.getInput(); + final int inputFieldCount = input.getRowType().getFieldCount(); + final List rexOvers = new ArrayList<>(); + for (Window.Group group : e.groups) { + rexOvers.addAll(builder.context.toSql(group, e.constants, inputFieldCount)); + } + final List selectList = new ArrayList<>(); + + for (RelDataTypeField field : input.getRowType().getFieldList()) { + addSelect(selectList, builder.context.field(field.getIndex()), e.getRowType()); + } + + for (SqlNode rexOver : rexOvers) { + addSelect(selectList, rexOver, e.getRowType()); + } + + builder.setSelect(new SqlNodeList(selectList, POS)); + return builder.result(); + } + + /** + * Visits an Aggregate; called by {@link #dispatch} via reflection. + */ + public Result visit(Aggregate e) { + final Builder builder = + visitAggregate(e, e.getGroupSet().toList(), Clause.GROUP_BY); + return builder.result(); + } + + private Builder visitAggregate(Aggregate e, List groupKeyList, + Clause... clauses) { + // "select a, b, sum(x) from ( ... ) group by a, b" + final boolean ignoreClauses = e.getInput() instanceof Project; + final Result x = visitInput(e, 0, isAnon(), ignoreClauses, + ImmutableSet.copyOf(clauses)); + final Builder builder = x.builder(e); + final List selectList = new ArrayList<>(); + final List groupByList = + generateGroupList(builder, selectList, e, groupKeyList); + return buildAggregate(e, builder, selectList, groupByList); + } + + /** + * Builds the group list for an Aggregate node. + * + * @param e The Aggregate node + * @param builder The SQL builder + * @param groupByList output group list + * @param selectList output select list + */ + protected void buildAggGroupList(Aggregate e, Builder builder, + List groupByList, List selectList) { + for (int group : e.getGroupSet()) { + final SqlNode field = builder.context.field(group); + addSelect(selectList, field, e.getRowType()); + groupByList.add(field); + } + } + + /** + * Builds an aggregate query. + * + * @param e The Aggregate node + * @param builder The SQL builder + * @param selectList The precomputed group list + * @param groupByList The precomputed select list + * @return The aggregate query result + */ + protected Builder buildAggregate(Aggregate e, Builder builder, + List selectList, List groupByList) { + for (AggregateCall aggCall : e.getAggCallList()) { + SqlNode aggCallSqlNode = builder.context.toSql(aggCall); + if (aggCall.getAggregation() instanceof SqlSingleValueAggFunction) { + aggCallSqlNode = dialect.rewriteSingleValueExpr(aggCallSqlNode); + } + addSelect(selectList, aggCallSqlNode, e.getRowType()); + } + builder.setSelect(new SqlNodeList(selectList, POS)); + if (!groupByList.isEmpty() || e.getAggCallList().isEmpty()) { + // Some databases don't support "GROUP BY ()". We can omit it as long + // as there is at least one aggregate function. + builder.setGroupBy(new SqlNodeList(groupByList, POS)); + } + return builder; + } + + /** + * Generates the GROUP BY items, for example {@code GROUP BY x, y}, + * {@code GROUP BY CUBE (x, y)} or {@code GROUP BY ROLLUP (x, y)}. + * + *

Also populates the SELECT clause. If the GROUP BY list is simple, the + * SELECT will be identical; if the GROUP BY list contains GROUPING SETS, + * CUBE or ROLLUP, the SELECT clause will contain the distinct leaf + * expressions. + */ + private List generateGroupList(Builder builder, + List selectList, Aggregate aggregate, + List groupList) { + final List sortedGroupList = + Ordering.natural().sortedCopy(groupList); + assert aggregate.getGroupSet().asList().equals(sortedGroupList) + : "groupList " + groupList + " must be equal to groupSet " + + aggregate.getGroupSet() + ", just possibly a different order"; + + final List groupKeys = new ArrayList<>(); + for (int key : groupList) { + final SqlNode field = builder.context.field(key); + groupKeys.add(field); + } + for (int key : sortedGroupList) { + final SqlNode field = builder.context.field(key); + addSelect(selectList, field, aggregate.getRowType()); + } + switch (aggregate.getGroupType()) { + case SIMPLE: + return ImmutableList.copyOf(groupKeys); + case CUBE: + if (aggregate.getGroupSet().cardinality() > 1) { + return ImmutableList.of( + SqlStdOperatorTable.CUBE.createCall(SqlParserPos.ZERO, groupKeys)); + } + // a singleton CUBE and ROLLUP are the same but we prefer ROLLUP; + // fall through + case ROLLUP: + return ImmutableList.of( + SqlStdOperatorTable.ROLLUP.createCall(SqlParserPos.ZERO, groupKeys)); + default: + case OTHER: + return ImmutableList.of( + SqlStdOperatorTable.GROUPING_SETS.createCall(SqlParserPos.ZERO, + aggregate.getGroupSets().stream() + .map(groupSet -> + groupItem(groupKeys, groupSet, aggregate.getGroupSet())) + .collect(Collectors.toList()))); + } + } + + private static SqlNode groupItem(List groupKeys, + ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) { + final List nodes = groupSet.asList().stream() + .map(key -> groupKeys.get(wholeGroupSet.indexOf(key))) + .collect(Collectors.toList()); + switch (nodes.size()) { + case 1: + return nodes.get(0); + default: + return SqlStdOperatorTable.ROW.createCall(SqlParserPos.ZERO, nodes); + } + } + + /** + * Visits a TableScan; called by {@link #dispatch} via reflection. + */ + public Result visit(TableScan e) { + final SqlIdentifier identifier = getSqlTargetTable(e); + return result(identifier, ImmutableList.of(Clause.FROM), e, null); + } + + /** + * Visits a Union; called by {@link #dispatch} via reflection. + */ + public Result visit(Union e) { + return setOpToSql(e.all + ? SqlStdOperatorTable.UNION_ALL + : SqlStdOperatorTable.UNION, e); + } + + /** + * Visits an Intersect; called by {@link #dispatch} via reflection. + */ + public Result visit(Intersect e) { + return setOpToSql(e.all + ? SqlStdOperatorTable.INTERSECT_ALL + : SqlStdOperatorTable.INTERSECT, e); + } + + /** + * Visits a Minus; called by {@link #dispatch} via reflection. + */ + public Result visit(Minus e) { + return setOpToSql(e.all + ? SqlStdOperatorTable.EXCEPT_ALL + : SqlStdOperatorTable.EXCEPT, e); + } + + /** + * Visits a Calc; called by {@link #dispatch} via reflection. + */ + public Result visit(Calc e) { + final RexProgram program = e.getProgram(); + final ImmutableSet expectedClauses = + program.getCondition() != null + ? ImmutableSet.of(Clause.WHERE) + : ImmutableSet.of(); + final Result x = visitInput(e, 0, expectedClauses); + parseCorrelTable(e, x); + final Builder builder = x.builder(e); + if (!isStar(program)) { + final List selectList = new ArrayList<>(program.getProjectList().size()); + for (RexLocalRef ref : program.getProjectList()) { + SqlNode sqlExpr = builder.context.toSql(program, ref); + addSelect(selectList, sqlExpr, e.getRowType()); + } + builder.setSelect(new SqlNodeList(selectList, POS)); + } + + if (program.getCondition() != null) { + builder.setWhere( + builder.context.toSql(program, program.getCondition())); + } + return builder.result(); + } + + /** + * Visits a Values; called by {@link #dispatch} via reflection. + */ + public Result visit(Values e) { + final List clauses = ImmutableList.of(Clause.SELECT); + final Map pairs = ImmutableMap.of(); + final Context context = aliasContext(pairs, false); + SqlNode query; + final boolean rename = stack.size() <= 1 + || !(Iterables.get(stack, 1).r instanceof TableModify); + final List fieldNames = e.getRowType().getFieldNames(); + if (!dialect.supportsAliasedValues() && rename) { + // Some dialects (such as Oracle and BigQuery) don't support + // "AS t (c1, c2)". So instead of + // (VALUES (v0, v1), (v2, v3)) AS t (c0, c1) + // we generate + // SELECT v0 AS c0, v1 AS c1 FROM DUAL + // UNION ALL + // SELECT v2 AS c0, v3 AS c1 FROM DUAL + // for Oracle and + // SELECT v0 AS c0, v1 AS c1 + // UNION ALL + // SELECT v2 AS c0, v3 AS c1 + // for dialects that support SELECT-without-FROM. + List list = new ArrayList<>(); + for (List tuple : e.getTuples()) { + final List values2 = new ArrayList<>(); + final SqlNodeList exprList = exprList(context, tuple); + for (Pair value : Pair.zip(exprList, fieldNames)) { + values2.add(as(value.left, value.right)); + } + list.add( + new SqlSelect(POS, null, + new SqlNodeList(values2, POS), + getDual(), null, null, + null, null, null, null, null, null)); + } + if (list.isEmpty()) { + // In this case we need to construct the following query: + // SELECT NULL as C0, NULL as C1, NULL as C2 ... FROM DUAL WHERE FALSE + // This would return an empty result set with the same number of columns as the field names. + final List nullColumnNames = new ArrayList<>(fieldNames.size()); + for (String fieldName : fieldNames) { + SqlCall nullColumnName = as(SqlLiteral.createNull(POS), fieldName); + nullColumnNames.add(nullColumnName); + } + final SqlIdentifier dual = getDual(); + if (dual == null) { + query = new SqlSelect(POS, null, + new SqlNodeList(nullColumnNames, POS), null, null, null, null, + null, null, null, null, null); + + // Wrap "SELECT 1 AS x" + // as "SELECT * FROM (SELECT 1 AS x) AS t WHERE false" + query = new SqlSelect(POS, null, SqlNodeList.SINGLETON_STAR, + as(query, "t"), createAlwaysFalseCondition(), null, null, + null, null, null, null, null); + } else { + query = new SqlSelect(POS, null, + new SqlNodeList(nullColumnNames, POS), + dual, createAlwaysFalseCondition(), null, + null, null, null, null, null, null); + } + } else if (list.size() == 1) { + query = list.get(0); + } else { + query = SqlStdOperatorTable.UNION_ALL.createCall( + new SqlNodeList(list, POS)); + } + } else { + // Generate ANSI syntax + // (VALUES (v0, v1), (v2, v3)) + // or, if rename is required + // (VALUES (v0, v1), (v2, v3)) AS t (c0, c1) + final SqlNodeList selects = new SqlNodeList(POS); + final boolean isEmpty = Values.isEmpty(e); + if (isEmpty) { + // In case of empty values, we need to build: + // SELECT * + // FROM (VALUES (NULL, NULL ...)) AS T (C1, C2 ...) + // WHERE 1 = 0 + selects.add( + SqlInternalOperators.ANONYMOUS_ROW.createCall(POS, + Collections.nCopies(fieldNames.size(), + SqlLiteral.createNull(POS)))); + } else { + for (List tuple : e.getTuples()) { + selects.add( + SqlInternalOperators.ANONYMOUS_ROW.createCall( + exprList(context, tuple))); + } + } + query = SqlStdOperatorTable.VALUES.createCall(selects); + if (rename) { + query = as(query, "t", fieldNames.toArray(new String[0])); + } + if (isEmpty) { + if (!rename) { + query = as(query, "t"); + } + query = + new SqlSelect(POS, null, SqlNodeList.SINGLETON_STAR, query, + createAlwaysFalseCondition(), null, null, null, + null, null, null, null); + } + } + return result(query, clauses, e, null); + } + + private @Nullable SqlIdentifier getDual() { + final List names = dialect.getSingleRowTableName(); + if (names == null) { + return null; + } + return new SqlIdentifier(names, POS); + } + + private static SqlNode createAlwaysFalseCondition() { + // Building the select query in the form: + // select * from VALUES(NULL,NULL ...) where 1=0 + // Use condition 1=0 since "where false" does not seem to be supported + // on some DB vendors. + return SqlStdOperatorTable.EQUALS.createCall(POS, + ImmutableList.of(SqlLiteral.createExactNumeric("1", POS), + SqlLiteral.createExactNumeric("0", POS))); + } + + /** + * Visits a Sort; called by {@link #dispatch} via reflection. + */ + public Result visit(Sort e) { + if (e.getInput() instanceof Aggregate) { + final Aggregate aggregate = (Aggregate) e.getInput(); + if (hasTrickyRollup(e, aggregate)) { + // MySQL 5 does not support standard "GROUP BY ROLLUP(x, y)", only + // the non-standard "GROUP BY x, y WITH ROLLUP". + // It does not allow "WITH ROLLUP" in combination with "ORDER BY", + // but "GROUP BY x, y WITH ROLLUP" implicitly sorts by x, y, + // so skip the ORDER BY. + final Set groupList = new LinkedHashSet<>(); + for (RelFieldCollation fc : e.collation.getFieldCollations()) { + groupList.add(aggregate.getGroupSet().nth(fc.getFieldIndex())); + } + groupList.addAll(Aggregate.Group.getRollup(aggregate.getGroupSets())); + final Builder builder = + visitAggregate(aggregate, ImmutableList.copyOf(groupList), + Clause.GROUP_BY, Clause.OFFSET, Clause.FETCH); + offsetFetch(e, builder); + return builder.result(); + } + } + if (e.getInput() instanceof Project) { + // Deal with the case Sort(Project(Aggregate ...)) + // by converting it to Project(Sort(Aggregate ...)). + final Project project = (Project) e.getInput(); + final Permutation permutation = project.getPermutation(); + if (permutation != null + && project.getInput() instanceof Aggregate) { + final Aggregate aggregate = (Aggregate) project.getInput(); + if (hasTrickyRollup(e, aggregate)) { + final RelCollation collation = + RelCollations.permute(e.collation, permutation); + final Sort sort2 = + LogicalSort.create(aggregate, collation, e.offset, e.fetch); + final Project project2 = + LogicalProject.create( + sort2, + ImmutableList.of(), + project.getProjects(), + project.getRowType()); + return visit(project2); + } + } + } + final Result x = visitInput(e, 0, Clause.ORDER_BY, Clause.OFFSET, + Clause.FETCH); + final Builder builder = x.builder(e); + if (stack.size() != 1 + && builder.select.getSelectList().equals(SqlNodeList.SINGLETON_STAR)) { + // Generates explicit column names instead of start(*) for + // non-root order by to avoid ambiguity. + final List selectList = Expressions.list(); + for (RelDataTypeField field : e.getRowType().getFieldList()) { + addSelect(selectList, builder.context.field(field.getIndex()), e.getRowType()); + } + builder.select.setSelectList(new SqlNodeList(selectList, POS)); + } + List orderByList = Expressions.list(); + for (RelFieldCollation field : e.getCollation().getFieldCollations()) { + builder.addOrderItem(orderByList, field); + } + if (!orderByList.isEmpty()) { + builder.setOrderBy(new SqlNodeList(orderByList, POS)); + } + offsetFetch(e, builder); + return builder.result(); + } + + /** + * Adds OFFSET and FETCH to a builder, if applicable. + * The builder must have been created with OFFSET and FETCH clauses. + */ + void offsetFetch(Sort e, Builder builder) { + if (e.fetch != null) { + builder.setFetch(builder.context.toSql(null, e.fetch)); + } + if (e.offset != null) { + builder.setOffset(builder.context.toSql(null, e.offset)); + } + } + + public boolean hasTrickyRollup(Sort e, Aggregate aggregate) { + return !dialect.supportsAggregateFunction(SqlKind.ROLLUP) + && dialect.supportsGroupByWithRollup() + && (aggregate.getGroupType() == Aggregate.Group.ROLLUP + || aggregate.getGroupType() == Aggregate.Group.CUBE + && aggregate.getGroupSet().cardinality() == 1) + && e.collation.getFieldCollations().stream().allMatch(fc -> + fc.getFieldIndex() < aggregate.getGroupSet().cardinality()); + } + + private static SqlIdentifier getSqlTargetTable(RelNode e) { + // Use the foreign catalog, schema and table names, if they exist, + // rather than the qualified name of the shadow table in Calcite. + final RelOptTable table = requireNonNull(e.getTable()); + return table.maybeUnwrap(TrainDBJdbcTable.class) + .map(TrainDBJdbcTable::tableName) + .orElseGet(() -> + new SqlIdentifier(table.getQualifiedName(), SqlParserPos.ZERO)); + } + + /** + * Visits a TableModify; called by {@link #dispatch} via reflection. + */ + public Result visit(TableModify modify) { + final Map pairs = ImmutableMap.of(); + final Context context = aliasContext(pairs, false); + + // Target Table Name + final SqlIdentifier sqlTargetTable = getSqlTargetTable(modify); + + switch (modify.getOperation()) { + case INSERT: { + // Convert the input to a SELECT query or keep as VALUES. Not all + // dialects support naked VALUES, but all support VALUES inside INSERT. + final SqlNode sqlSource = + visitInput(modify, 0).asQueryOrValues(); + + final SqlInsert sqlInsert = + new SqlInsert(POS, SqlNodeList.EMPTY, sqlTargetTable, sqlSource, + identifierList(modify.getTable().getRowType().getFieldNames())); + + return result(sqlInsert, ImmutableList.of(), modify, null); + } + case UPDATE: { + final Result input = visitInput(modify, 0); + + final SqlUpdate sqlUpdate = + new SqlUpdate(POS, sqlTargetTable, + identifierList( + requireNonNull(modify.getUpdateColumnList(), + () -> "modify.getUpdateColumnList() is null for " + modify)), + exprList(context, + requireNonNull(modify.getSourceExpressionList(), + () -> "modify.getSourceExpressionList() is null for " + modify)), + ((SqlSelect) input.node).getWhere(), input.asSelect(), + null); + + return result(sqlUpdate, input.clauses, modify, null); + } + case DELETE: { + final Result input = visitInput(modify, 0); + + final SqlDelete sqlDelete = + new SqlDelete(POS, sqlTargetTable, + input.asSelect().getWhere(), input.asSelect(), null); + + return result(sqlDelete, input.clauses, modify, null); + } + case MERGE: + default: + throw new AssertionError("not implemented: " + modify); + } + } + + /** + * Converts a list of {@link RexNode} expressions to {@link SqlNode} + * expressions. + */ + private static SqlNodeList exprList(final Context context, + List exprs) { + return new SqlNodeList( + Util.transform(exprs, e -> context.toSql(null, e)), POS); + } + + /** + * Converts a list of names expressions to a list of single-part + * {@link SqlIdentifier}s. + */ + private static SqlNodeList identifierList(List names) { + return new SqlNodeList( + Util.transform(names, name -> new SqlIdentifier(name, POS)), POS); + } + + /** + * Visits a Match; called by {@link #dispatch} via reflection. + */ + public Result visit(Match e) { + final RelNode input = e.getInput(); + final Result x = visitInput(e, 0); + final Context context = matchRecognizeContext(x.qualifiedContext()); + + SqlNode tableRef = x.asQueryOrValues(); + + final RexBuilder rexBuilder = input.getCluster().getRexBuilder(); + final List partitionSqlList = new ArrayList<>(); + for (int key : e.getPartitionKeys()) { + final RexInputRef ref = rexBuilder.makeInputRef(input, key); + SqlNode sqlNode = context.toSql(null, ref); + partitionSqlList.add(sqlNode); + } + final SqlNodeList partitionList = new SqlNodeList(partitionSqlList, POS); + + final List orderBySqlList = new ArrayList<>(); + if (e.getOrderKeys() != null) { + for (RelFieldCollation fc : e.getOrderKeys().getFieldCollations()) { + if (fc.nullDirection != RelFieldCollation.NullDirection.UNSPECIFIED) { + boolean first = fc.nullDirection == RelFieldCollation.NullDirection.FIRST; + SqlNode nullDirectionNode = + dialect.emulateNullDirection(context.field(fc.getFieldIndex()), + first, fc.direction.isDescending()); + if (nullDirectionNode != null) { + orderBySqlList.add(nullDirectionNode); + fc = new RelFieldCollation(fc.getFieldIndex(), fc.getDirection(), + RelFieldCollation.NullDirection.UNSPECIFIED); + } + } + orderBySqlList.add(context.toSql(fc)); + } + } + final SqlNodeList orderByList = new SqlNodeList(orderBySqlList, SqlParserPos.ZERO); + + final SqlLiteral rowsPerMatch = e.isAllRows() + ? SqlMatchRecognize.RowsPerMatchOption.ALL_ROWS.symbol(POS) + : SqlMatchRecognize.RowsPerMatchOption.ONE_ROW.symbol(POS); + + final SqlNode after; + if (e.getAfter() instanceof RexLiteral) { + SqlMatchRecognize.AfterOption value = (SqlMatchRecognize.AfterOption) + ((RexLiteral) e.getAfter()).getValue2(); + after = SqlLiteral.createSymbol(value, POS); + } else { + RexCall call = (RexCall) e.getAfter(); + String operand = requireNonNull(stringValue(call.getOperands().get(0)), + () -> "non-null string value expected for 0th operand of AFTER call " + call); + after = call.getOperator().createCall(POS, new SqlIdentifier(operand, POS)); + } + + RexNode rexPattern = e.getPattern(); + final SqlNode pattern = context.toSql(null, rexPattern); + final SqlLiteral strictStart = SqlLiteral.createBoolean(e.isStrictStart(), POS); + final SqlLiteral strictEnd = SqlLiteral.createBoolean(e.isStrictEnd(), POS); + + RexLiteral rexInterval = (RexLiteral) e.getInterval(); + SqlIntervalLiteral interval = null; + if (rexInterval != null) { + interval = (SqlIntervalLiteral) context.toSql(null, rexInterval); + } + + final SqlNodeList subsetList = new SqlNodeList(POS); + for (Map.Entry> entry : e.getSubsets().entrySet()) { + SqlNode left = new SqlIdentifier(entry.getKey(), POS); + List rhl = new ArrayList<>(); + for (String right : entry.getValue()) { + rhl.add(new SqlIdentifier(right, POS)); + } + subsetList.add( + SqlStdOperatorTable.EQUALS.createCall(POS, left, + new SqlNodeList(rhl, POS))); + } + + final SqlNodeList measureList = new SqlNodeList(POS); + for (Map.Entry entry : e.getMeasures().entrySet()) { + final String alias = entry.getKey(); + final SqlNode sqlNode = context.toSql(null, entry.getValue()); + measureList.add(as(sqlNode, alias)); + } + + final SqlNodeList patternDefList = new SqlNodeList(POS); + for (Map.Entry entry : e.getPatternDefinitions().entrySet()) { + final String alias = entry.getKey(); + final SqlNode sqlNode = context.toSql(null, entry.getValue()); + patternDefList.add(as(sqlNode, alias)); + } + + final SqlNode matchRecognize = new SqlMatchRecognize(POS, tableRef, + pattern, strictStart, strictEnd, patternDefList, measureList, after, + subsetList, rowsPerMatch, partitionList, orderByList, interval); + return result(matchRecognize, Expressions.list(Clause.FROM), e, null); + } + + private static SqlCall as(SqlNode e, String alias) { + return SqlStdOperatorTable.AS.createCall(POS, e, + new SqlIdentifier(alias, POS)); + } + + public Result visit(Uncollect e) { + final Result x = visitInput(e, 0); + final SqlNode unnestNode = SqlStdOperatorTable.UNNEST.createCall(POS, x.asStatement()); + final List operands = createAsFullOperands(e.getRowType(), unnestNode, + requireNonNull(x.neededAlias, () -> "x.neededAlias is null, node is " + x.node)); + final SqlNode asNode = SqlStdOperatorTable.AS.createCall(POS, operands); + return result(asNode, ImmutableList.of(Clause.FROM), e, null); + } + + public Result visit(TableFunctionScan e) { + final List inputSqlNodes = new ArrayList<>(); + final int inputSize = e.getInputs().size(); + for (int i = 0; i < inputSize; i++) { + final Result x = visitInput(e, i); + inputSqlNodes.add(x.asStatement()); + } + final Context context = tableFunctionScanContext(inputSqlNodes); + SqlNode callNode = context.toSql(null, e.getCall()); + // Convert to table function call, "TABLE($function_name(xxx))" + SqlNode tableCall = new SqlBasicCall( + SqlStdOperatorTable.COLLECTION_TABLE, + new SqlNode[] {callNode}, + SqlParserPos.ZERO); + SqlNode select = new SqlSelect( + SqlParserPos.ZERO, null, SqlNodeList.SINGLETON_STAR, tableCall, + null, null, null, null, null, null, null, SqlNodeList.EMPTY); + return result(select, ImmutableList.of(Clause.SELECT), e, null); + } + + /** + * Creates operands for a full AS operator. Format SqlNode AS alias(col_1, col_2,... ,col_n). + * + * @param rowType Row type of the SqlNode + * @param leftOperand SqlNode + * @param alias alias + */ + public List createAsFullOperands(RelDataType rowType, SqlNode leftOperand, + String alias) { + final List result = new ArrayList<>(); + result.add(leftOperand); + result.add(new SqlIdentifier(alias, POS)); + Ord.forEach(rowType.getFieldNames(), (fieldName, i) -> { + if (fieldName.toLowerCase(Locale.ROOT).startsWith("expr$")) { + fieldName = "col_" + i; + } + result.add(new SqlIdentifier(fieldName, POS)); + }); + return result; + } + + @Override + public void addSelect(List selectList, SqlNode node, + RelDataType rowType) { + String name = rowType.getFieldNames().get(selectList.size()); + String alias = SqlValidatorUtil.getAlias(node, -1); + if (alias == null || !alias.equals(name)) { + node = as(node, name); + } + selectList.add(node); + } + + private void parseCorrelTable(RelNode relNode, Result x) { + for (CorrelationId id : relNode.getVariablesSet()) { + correlTableMap.put(id, x.qualifiedContext()); + } + } + + /** + * Stack frame. + */ + private static class Frame { + private final RelNode parent; + @SuppressWarnings("unused") + private final int ordinalInParent; + private final RelNode r; + private final boolean anon; + private final boolean ignoreClauses; + private final ImmutableSet expectedClauses; + + Frame(RelNode parent, int ordinalInParent, RelNode r, boolean anon, + boolean ignoreClauses, Iterable expectedClauses) { + this.parent = requireNonNull(parent, "parent"); + this.ordinalInParent = ordinalInParent; + this.r = requireNonNull(r, "r"); + this.anon = anon; + this.ignoreClauses = ignoreClauses; + this.expectedClauses = ImmutableSet.copyOf(expectedClauses); + } + } +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/SqlImplementor.java b/traindb-core/src/main/java/traindb/adapter/jdbc/SqlImplementor.java new file mode 100644 index 0000000..e510d22 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/SqlImplementor.java @@ -0,0 +1,2200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import java.math.BigDecimal; +import java.util.AbstractList; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.function.IntFunction; +import java.util.function.Predicate; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeSystemImpl; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexDynamicParam; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexFieldCollation; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexLocalRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; +import org.apache.calcite.rex.RexPatternFieldRef; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.rex.RexUnknownAs; +import org.apache.calcite.rex.RexWindow; +import org.apache.calcite.rex.RexWindowBound; +import org.apache.calcite.sql.JoinType; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlBinaryOperator; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.sql.SqlDynamicParam; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlMatchRecognize; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOverOperator; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlSelectKeyword; +import org.apache.calcite.sql.SqlSetOperator; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.SqlWindow; +import org.apache.calcite.sql.fun.SqlCase; +import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.util.DateString; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.RangeSets; +import org.apache.calcite.util.Sarg; +import org.apache.calcite.util.TimeString; +import org.apache.calcite.util.TimestampString; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * State for generating a SQL statement. + */ +public abstract class SqlImplementor { + + // Always use quoted position, the "isQuoted" info is only used when + // unparsing a SqlIdentifier. For some rex nodes, saying RexInputRef, we have + // no idea about whether it is quoted or not for the original sql statement. + // So we just quote it. + public static final SqlParserPos POS = SqlParserPos.QUOTED_ZERO; + + public final SqlDialect dialect; + protected final Set aliasSet = new LinkedHashSet<>(); + + protected final Map correlTableMap = new HashMap<>(); + + /** + * Private RexBuilder for short-lived expressions. It has its own + * dedicated type factory, so don't trust the types to be canonized. + */ + final RexBuilder rexBuilder = + new RexBuilder(new SqlTypeFactoryImpl(RelDataTypeSystemImpl.DEFAULT)); + + protected SqlImplementor(SqlDialect dialect) { + this.dialect = requireNonNull(dialect, "dialect"); + } + + /** + * Visits a relational expression that has no parent. + */ + public final Result visitRoot(RelNode e) { + return visitInput(holder(e), 0); + } + + /** + * Creates a relational expression that has {@code r} as its input. + */ + private static RelNode holder(RelNode r) { + return new SingleRel(r.getCluster(), r.getTraitSet(), r) { + }; + } + + // CHECKSTYLE: IGNORE 1 + + /** + * @deprecated Use either {@link #visitRoot(RelNode)} or + * {@link #visitInput(RelNode, int)}. + */ + @Deprecated // to be removed before 2.0 + public final Result visitChild(int i, RelNode e) { + throw new UnsupportedOperationException(); + } + + /** + * Visits an input of the current relational expression, + * deducing {@code anon} using {@link #isAnon()}. + */ + public final Result visitInput(RelNode e, int i) { + return visitInput(e, i, ImmutableSet.of()); + } + + /** + * Visits an input of the current relational expression, + * with the given expected clauses. + */ + public final Result visitInput(RelNode e, int i, Clause... clauses) { + return visitInput(e, i, ImmutableSet.copyOf(clauses)); + } + + /** + * Visits an input of the current relational expression, + * deducing {@code anon} using {@link #isAnon()}. + */ + public final Result visitInput(RelNode e, int i, Set clauses) { + return visitInput(e, i, isAnon(), false, clauses); + } + + /** + * Visits the {@code i}th input of {@code e}, the current relational + * expression. + * + * @param e Current relational expression + * @param i Ordinal of input within {@code e} + * @param anon Whether to remove trivial aliases such as "EXPR$0" + * @param ignoreClauses Whether to ignore the expected clauses when deciding + * whether a sub-query is required + * @param expectedClauses Set of clauses that we expect the builder that + * consumes this result will create + * @return Result + * @see #isAnon() + */ + public abstract Result visitInput(RelNode e, int i, boolean anon, + boolean ignoreClauses, Set expectedClauses); + + public void addSelect(List selectList, SqlNode node, + RelDataType rowType) { + String name = rowType.getFieldNames().get(selectList.size()); + String alias = SqlValidatorUtil.getAlias(node, -1); + if (alias == null || !alias.equals(name)) { + node = as(node, name); + } + selectList.add(node); + } + + /** + * Convenience method for creating column and table aliases. + * + *

{@code AS(e, "c")} creates "e AS c"; + * {@code AS(e, "t", "c1", "c2"} creates "e AS t (c1, c2)". + */ + protected SqlCall as(SqlNode e, String alias, String... fieldNames) { + final List operandList = new ArrayList<>(); + operandList.add(e); + operandList.add(new SqlIdentifier(alias, POS)); + for (String fieldName : fieldNames) { + operandList.add(new SqlIdentifier(fieldName, POS)); + } + return SqlStdOperatorTable.AS.createCall(POS, operandList); + } + + /** + * Returns whether a list of expressions projects all fields, in order, + * from the input, with the same names. + */ + public static boolean isStar(List exps, RelDataType inputRowType, + RelDataType projectRowType) { + assert exps.size() == projectRowType.getFieldCount(); + int i = 0; + for (RexNode ref : exps) { + if (!(ref instanceof RexInputRef)) { + return false; + } else if (((RexInputRef) ref).getIndex() != i++) { + return false; + } + } + return i == inputRowType.getFieldCount() + && inputRowType.getFieldNames().equals(projectRowType.getFieldNames()); + } + + public static boolean isStar(RexProgram program) { + int i = 0; + for (RexLocalRef ref : program.getProjectList()) { + if (ref.getIndex() != i++) { + return false; + } + } + return i == program.getInputRowType().getFieldCount(); + } + + public Result setOpToSql(SqlSetOperator operator, RelNode rel) { + SqlNode node = null; + for (Ord input : Ord.zip(rel.getInputs())) { + final Result result = visitInput(rel, input.i); + if (node == null) { + node = result.asSelect(); + } else { + node = operator.createCall(POS, node, result.asSelect()); + } + } + assert node != null : "set op must have at least one input, operator = " + operator + + ", rel = " + rel; + final List clauses = + Expressions.list(Clause.SET_OP); + return result(node, clauses, rel, null); + } + + /** + * Converts a {@link RexNode} condition into a {@link SqlNode}. + * + * @param node Join condition + * @param leftContext Left context + * @param rightContext Right context + * @return SqlNode that represents the condition + */ + public static SqlNode convertConditionToSqlNode(RexNode node, + Context leftContext, + Context rightContext) { + if (node.isAlwaysTrue()) { + return SqlLiteral.createBoolean(true, POS); + } + if (node.isAlwaysFalse()) { + return SqlLiteral.createBoolean(false, POS); + } + final Context joinContext = + leftContext.implementor().joinContext(leftContext, rightContext); + return joinContext.toSql(null, node); + } + + /** + * Removes cast from string. + * + *

For example, {@code x > CAST('2015-01-07' AS DATE)} + * becomes {@code x > '2015-01-07'}. + */ + private static RexNode stripCastFromString(RexNode node, SqlDialect dialect) { + switch (node.getKind()) { + case EQUALS: + case IS_NOT_DISTINCT_FROM: + case NOT_EQUALS: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + final RexCall call = (RexCall) node; + final RexNode o0 = call.operands.get(0); + final RexNode o1 = call.operands.get(1); + if (o0.getKind() == SqlKind.CAST + && o1.getKind() != SqlKind.CAST) { + if (!dialect.supportsImplicitTypeCoercion((RexCall) o0)) { + // If the dialect does not support implicit type coercion, + // we definitely can not strip the cast. + return node; + } + final RexNode o0b = ((RexCall) o0).getOperands().get(0); + return call.clone(call.getType(), ImmutableList.of(o0b, o1)); + } + if (o1.getKind() == SqlKind.CAST + && o0.getKind() != SqlKind.CAST) { + if (!dialect.supportsImplicitTypeCoercion((RexCall) o1)) { + return node; + } + final RexNode o1b = ((RexCall) o1).getOperands().get(0); + return call.clone(call.getType(), ImmutableList.of(o0, o1b)); + } + break; + default: + break; + } + return node; + } + + public static JoinType joinType(JoinRelType joinType) { + switch (joinType) { + case LEFT: + return JoinType.LEFT; + case RIGHT: + return JoinType.RIGHT; + case INNER: + return JoinType.INNER; + case FULL: + return JoinType.FULL; + default: + throw new AssertionError(joinType); + } + } + + /** + * Creates a result based on a single relational expression. + */ + public Result result(SqlNode node, Collection clauses, + RelNode rel, @Nullable Map aliases) { + assert aliases == null + || aliases.size() < 2 + || aliases instanceof LinkedHashMap + || aliases instanceof ImmutableMap + : "must use a Map implementation that preserves order"; + final String alias2 = SqlValidatorUtil.getAlias(node, -1); + final String alias3 = alias2 != null ? alias2 : "t"; + final String alias4 = + SqlValidatorUtil.uniquify( + alias3, aliasSet, SqlValidatorUtil.EXPR_SUGGESTER); + final RelDataType rowType = adjustedRowType(rel, node); + if (aliases != null + && !aliases.isEmpty() + && (!dialect.hasImplicitTableAlias() + || aliases.size() > 1)) { + return result(node, clauses, alias4, rowType, aliases); + } + final String alias5; + if (alias2 == null + || !alias2.equals(alias4) + || !dialect.hasImplicitTableAlias()) { + alias5 = alias4; + } else { + alias5 = null; + } + return result(node, clauses, alias5, rowType, + ImmutableMap.of(alias4, rowType)); + } + + /** + * Factory method for {@link Result}. + * + *

Call this method rather than creating a {@code Result} directly, + * because sub-classes may override. + */ + protected Result result(SqlNode node, Collection clauses, + @Nullable String neededAlias, @Nullable RelDataType neededType, + Map aliases) { + return new Result(node, clauses, neededAlias, neededType, aliases); + } + + /** + * Returns the row type of {@code rel}, adjusting the field names if + * {@code node} is "(query) as tableAlias (fieldAlias, ...)". + */ + private static RelDataType adjustedRowType(RelNode rel, SqlNode node) { + final RelDataType rowType = rel.getRowType(); + final RelDataTypeFactory.Builder builder; + switch (node.getKind()) { + case UNION: + case INTERSECT: + case EXCEPT: + return adjustedRowType(rel, ((SqlCall) node).getOperandList().get(0)); + + case SELECT: + final SqlNodeList selectList = ((SqlSelect) node).getSelectList(); + if (selectList.equals(SqlNodeList.SINGLETON_STAR)) { + return rowType; + } + builder = rel.getCluster().getTypeFactory().builder(); + Pair.forEach(selectList, + rowType.getFieldList(), + (selectItem, field) -> + builder.add( + Util.first(SqlValidatorUtil.getAlias(selectItem, -1), + field.getName()), + field.getType())); + return builder.build(); + + case AS: + final List operandList = ((SqlCall) node).getOperandList(); + if (operandList.size() <= 2) { + return rowType; + } + builder = rel.getCluster().getTypeFactory().builder(); + Pair.forEach(Util.skip(operandList, 2), + rowType.getFieldList(), + (operand, field) -> + builder.add(operand.toString(), field.getType())); + return builder.build(); + + default: + return rowType; + } + } + + /** + * Creates a result based on a join. (Each join could contain one or more + * relational expressions.) + */ + public Result result(SqlNode join, Result leftResult, Result rightResult) { + final Map aliases; + if (join.getKind() == SqlKind.JOIN) { + final ImmutableMap.Builder builder = + ImmutableMap.builder(); + collectAliases(builder, join, + Iterables.concat(leftResult.aliases.values(), + rightResult.aliases.values()).iterator()); + aliases = builder.build(); + } else { + aliases = leftResult.aliases; + } + return result(join, ImmutableList.of(Clause.FROM), null, null, aliases); + } + + private static void collectAliases(ImmutableMap.Builder builder, + SqlNode node, Iterator aliases) { + if (node instanceof SqlJoin) { + final SqlJoin join = (SqlJoin) node; + collectAliases(builder, join.getLeft(), aliases); + collectAliases(builder, join.getRight(), aliases); + } else { + final String alias = SqlValidatorUtil.getAlias(node, -1); + assert alias != null; + builder.put(alias, aliases.next()); + } + } + + /** + * Returns whether to remove trivial aliases such as "EXPR$0" + * when converting the current relational expression into a SELECT. + * + *

For example, INSERT does not care about field names; + * we would prefer to generate without the "EXPR$0" alias: + * + *

{@code INSERT INTO t1 SELECT x, y + 1 FROM t2}
+ *

+ * rather than with it: + * + *

{@code INSERT INTO t1 SELECT x, y + 1 AS EXPR$0 FROM t2}
+ * + *

But JOIN does care about field names; we have to generate the "EXPR$0" + * alias: + * + *

{@code SELECT *
+   * FROM emp AS e
+   * JOIN (SELECT x, y + 1 AS EXPR$0) AS d
+   * ON e.deptno = d.EXPR$0}
+   * 
+ * + *

because if we omit "AS EXPR$0" we do not know the field we are joining + * to, and the following is invalid: + * + *

{@code SELECT *
+   * FROM emp AS e
+   * JOIN (SELECT x, y + 1) AS d
+   * ON e.deptno = d.EXPR$0}
+   * 
+ */ + protected boolean isAnon() { + return false; + } + + /** + * Wraps a node in a SELECT statement that has no clauses: + * "SELECT ... FROM (node)". + */ + SqlSelect wrapSelect(SqlNode node) { + assert node instanceof SqlJoin + || node instanceof SqlIdentifier + || node instanceof SqlMatchRecognize + || node instanceof SqlCall + && (((SqlCall) node).getOperator() instanceof SqlSetOperator + || ((SqlCall) node).getOperator() == SqlStdOperatorTable.AS + || ((SqlCall) node).getOperator() == SqlStdOperatorTable.VALUES) + : node; + if (requiresAlias(node)) { + node = as(node, "t"); + } + return new SqlSelect(POS, SqlNodeList.EMPTY, SqlNodeList.SINGLETON_STAR, + node, null, null, null, SqlNodeList.EMPTY, null, null, null, null); + } + + /** + * Returns whether we need to add an alias if this node is to be the FROM + * clause of a SELECT. + */ + private boolean requiresAlias(SqlNode node) { + if (!dialect.requiresAliasForFromItems()) { + return false; + } + switch (node.getKind()) { + case IDENTIFIER: + return !dialect.hasImplicitTableAlias(); + case AS: + case JOIN: + case EXPLICIT_TABLE: + return false; + default: + return true; + } + } + + /** + * Returns whether a node is a call to an aggregate function. + */ + private static boolean isAggregate(SqlNode node) { + return node instanceof SqlCall + && ((SqlCall) node).getOperator() instanceof SqlAggFunction; + } + + /** + * Returns whether a node is a call to a windowed aggregate function. + */ + private static boolean isWindowedAggregate(SqlNode node) { + return node instanceof SqlCall + && ((SqlCall) node).getOperator() instanceof SqlOverOperator; + } + + /** + * Context for translating a {@link RexNode} expression (within a + * {@link RelNode}) into a {@link SqlNode} expression (within a SQL parse + * tree). + */ + public abstract static class Context { + final SqlDialect dialect; + final int fieldCount; + private final boolean ignoreCast; + + protected Context(SqlDialect dialect, int fieldCount) { + this(dialect, fieldCount, false); + } + + protected Context(SqlDialect dialect, int fieldCount, boolean ignoreCast) { + this.dialect = dialect; + this.fieldCount = fieldCount; + this.ignoreCast = ignoreCast; + } + + public abstract SqlNode field(int ordinal); + + /** + * Creates a reference to a field to be used in an ORDER BY clause. + * + *

By default, it returns the same result as {@link #field}. + * + *

If the field has an alias, uses the alias. + * If the field is an unqualified column reference which is the same an + * alias, switches to a qualified column reference. + */ + public SqlNode orderField(int ordinal) { + return field(ordinal); + } + + /** + * Converts an expression from {@link RexNode} to {@link SqlNode} + * format. + * + * @param program Required only if {@code rex} contains {@link RexLocalRef} + * @param rex Expression to convert + */ + public SqlNode toSql(@Nullable RexProgram program, RexNode rex) { + final RexSubQuery subQuery; + final SqlNode sqlSubQuery; + final RexLiteral literal; + switch (rex.getKind()) { + case LOCAL_REF: + final int index = ((RexLocalRef) rex).getIndex(); + return toSql(program, requireNonNull(program, "program").getExprList().get(index)); + + case INPUT_REF: + return field(((RexInputRef) rex).getIndex()); + + case FIELD_ACCESS: + final Deque accesses = new ArrayDeque<>(); + RexNode referencedExpr = rex; + while (referencedExpr.getKind() == SqlKind.FIELD_ACCESS) { + accesses.offerLast((RexFieldAccess) referencedExpr); + referencedExpr = ((RexFieldAccess) referencedExpr).getReferenceExpr(); + } + SqlIdentifier sqlIdentifier; + switch (referencedExpr.getKind()) { + case CORREL_VARIABLE: + final RexCorrelVariable variable = (RexCorrelVariable) referencedExpr; + final Context correlAliasContext = getAliasContext(variable); + final RexFieldAccess lastAccess = accesses.pollLast(); + assert lastAccess != null; + sqlIdentifier = (SqlIdentifier) correlAliasContext + .field(lastAccess.getField().getIndex()); + break; + case ROW: + final SqlNode expr = toSql(program, referencedExpr); + sqlIdentifier = new SqlIdentifier(expr.toString(), POS); + break; + default: + sqlIdentifier = (SqlIdentifier) toSql(program, referencedExpr); + } + + int nameIndex = sqlIdentifier.names.size(); + RexFieldAccess access; + while ((access = accesses.pollLast()) != null) { + sqlIdentifier = sqlIdentifier.add(nameIndex++, access.getField().getName(), POS); + } + return sqlIdentifier; + + case PATTERN_INPUT_REF: + final RexPatternFieldRef ref = (RexPatternFieldRef) rex; + String pv = ref.getAlpha(); + SqlNode refNode = field(ref.getIndex()); + final SqlIdentifier id = (SqlIdentifier) refNode; + if (id.names.size() > 1) { + return id.setName(0, pv); + } else { + return new SqlIdentifier(ImmutableList.of(pv, id.names.get(0)), POS); + } + + case LITERAL: + return SqlImplementor.toSql(program, (RexLiteral) rex); + + case CASE: + final RexCall caseCall = (RexCall) rex; + final List caseNodeList = + toSql(program, caseCall.getOperands()); + final SqlNode valueNode; + final List whenList = Expressions.list(); + final List thenList = Expressions.list(); + final SqlNode elseNode; + if (caseNodeList.size() % 2 == 0) { + // switched: + // "case x when v1 then t1 when v2 then t2 ... else e end" + valueNode = caseNodeList.get(0); + for (int i = 1; i < caseNodeList.size() - 1; i += 2) { + whenList.add(caseNodeList.get(i)); + thenList.add(caseNodeList.get(i + 1)); + } + } else { + // other: "case when w1 then t1 when w2 then t2 ... else e end" + valueNode = null; + for (int i = 0; i < caseNodeList.size() - 1; i += 2) { + whenList.add(caseNodeList.get(i)); + thenList.add(caseNodeList.get(i + 1)); + } + } + elseNode = caseNodeList.get(caseNodeList.size() - 1); + return new SqlCase(POS, valueNode, new SqlNodeList(whenList, POS), + new SqlNodeList(thenList, POS), elseNode); + + case DYNAMIC_PARAM: + final RexDynamicParam caseParam = (RexDynamicParam) rex; + return new SqlDynamicParam(caseParam.getIndex(), POS); + + case IN: + subQuery = (RexSubQuery) rex; + sqlSubQuery = implementor().visitRoot(subQuery.rel).asQueryOrValues(); + final List operands = subQuery.operands; + SqlNode op0; + if (operands.size() == 1) { + op0 = toSql(program, operands.get(0)); + } else { + final List cols = toSql(program, operands); + op0 = new SqlNodeList(cols, POS); + } + return subQuery.getOperator().createCall(POS, op0, sqlSubQuery); + + case SEARCH: + final RexCall search = (RexCall) rex; + literal = (RexLiteral) search.operands.get(1); + final Sarg sarg = castNonNull(literal.getValueAs(Sarg.class)); + //noinspection unchecked + return toSql(program, search.operands.get(0), literal.getType(), sarg); + + case EXISTS: + case SCALAR_QUERY: + subQuery = (RexSubQuery) rex; + sqlSubQuery = + implementor().visitRoot(subQuery.rel).asQueryOrValues(); + return subQuery.getOperator().createCall(POS, sqlSubQuery); + + case NOT: + RexNode operand = ((RexCall) rex).operands.get(0); + final SqlNode node = toSql(program, operand); + final SqlOperator inverseOperator = getInverseOperator(operand); + if (inverseOperator != null) { + switch (operand.getKind()) { + case IN: + assert operand instanceof RexSubQuery + : "scalar IN is no longer allowed in RexCall: " + rex; + break; + default: + break; + } + return inverseOperator.createCall(POS, + ((SqlCall) node).getOperandList()); + } else { + return SqlStdOperatorTable.NOT.createCall(POS, node); + } + + default: + if (rex instanceof RexOver) { + return toSql(program, (RexOver) rex); + } + + return callToSql(program, (RexCall) rex, false); + } + } + + private SqlNode callToSql(@Nullable RexProgram program, RexCall call0, + boolean not) { + final RexCall call1 = reverseCall(call0); + final RexCall call = (RexCall) stripCastFromString(call1, dialect); + SqlOperator op = call.getOperator(); + switch (op.getKind()) { + case SUM0: + op = SqlStdOperatorTable.SUM; + break; + case NOT: + RexNode operand = call.operands.get(0); + if (getInverseOperator(operand) != null) { + return callToSql(program, (RexCall) operand, !not); + } + break; + default: + break; + } + if (not) { + op = requireNonNull(getInverseOperator(call), + () -> "unable to negate " + call.getKind()); + } + final List nodeList = toSql(program, call.getOperands()); + switch (call.getKind()) { + case CAST: + // CURSOR is used inside CAST, like 'CAST ($0): CURSOR NOT NULL', + // convert it to sql call of {@link SqlStdOperatorTable#CURSOR}. + final RelDataType dataType = call.getType(); + if (dataType.getSqlTypeName() == SqlTypeName.CURSOR) { + final RexNode operand0 = call.operands.get(0); + assert operand0 instanceof RexInputRef; + int ordinal = ((RexInputRef) operand0).getIndex(); + SqlNode fieldOperand = field(ordinal); + return SqlStdOperatorTable.CURSOR.createCall(SqlParserPos.ZERO, fieldOperand); + } + if (ignoreCast) { + assert nodeList.size() == 1; + return nodeList.get(0); + } else { + nodeList.add(castNonNull(dialect.getCastSpec(call.getType()))); + } + break; + default: + break; + } + return SqlUtil.createCall(op, POS, nodeList); + } + + /** + * Reverses the order of a call, while preserving semantics, if it improves + * readability. + * + *

In the base implementation, this method does nothing; + * in a join context, reverses a call such as + * "e.deptno = d.deptno" to + * "d.deptno = e.deptno" + * if "d" is the left input to the join + * and "e" is the right. + */ + protected RexCall reverseCall(RexCall call) { + return call; + } + + /** + * If {@code node} is a {@link RexCall}, extracts the operator and + * finds the corresponding inverse operator using {@link SqlOperator#not()}. + * Returns null if {@code node} is not a {@link RexCall}, + * or if the operator has no logical inverse. + */ + private static @Nullable SqlOperator getInverseOperator(RexNode node) { + if (node instanceof RexCall) { + return ((RexCall) node).getOperator().not(); + } else { + return null; + } + } + + /** + * Converts a Sarg to SQL, generating "operand IN (c1, c2, ...)" if the + * ranges are all points. + */ + @SuppressWarnings({"BetaApi", "UnstableApiUsage"}) + private > SqlNode toSql(@Nullable RexProgram program, + RexNode operand, RelDataType type, + Sarg sarg) { + final List orList = new ArrayList<>(); + final SqlNode operandSql = toSql(program, operand); + if (sarg.nullAs == RexUnknownAs.TRUE) { + orList.add(SqlStdOperatorTable.IS_NULL.createCall(POS, operandSql)); + } + if (sarg.isPoints()) { + // generate 'x = 10' or 'x IN (10, 20, 30)' + orList.add( + toIn(operandSql, SqlStdOperatorTable.EQUALS, + SqlStdOperatorTable.IN, program, type, sarg.rangeSet)); + } else if (sarg.isComplementedPoints()) { + // generate 'x <> 10' or 'x NOT IN (10, 20, 30)' + orList.add( + toIn(operandSql, SqlStdOperatorTable.NOT_EQUALS, + SqlStdOperatorTable.NOT_IN, program, type, + sarg.rangeSet.complement())); + } else { + final RangeSets.Consumer consumer = + new RangeToSql<>(operandSql, orList, v -> + toSql(program, + implementor().rexBuilder.makeLiteral(v, type))); + RangeSets.forEach(sarg.rangeSet, consumer); + } + return SqlUtil.createCall(SqlStdOperatorTable.OR, POS, orList); + } + + @SuppressWarnings("BetaApi") + private > SqlNode toIn(SqlNode operandSql, + SqlBinaryOperator eqOp, SqlBinaryOperator inOp, + @Nullable RexProgram program, RelDataType type, + RangeSet rangeSet) { + final SqlNodeList list = rangeSet.asRanges().stream() + .map(range -> + toSql(program, + implementor().rexBuilder.makeLiteral(range.lowerEndpoint(), + type, true, true))) + .collect(SqlNode.toList()); + switch (list.size()) { + case 1: + return eqOp.createCall(POS, operandSql, list.get(0)); + default: + return inOp.createCall(POS, operandSql, list); + } + } + + /** + * Converts an expression from {@link RexWindowBound} to {@link SqlNode} + * format. + * + * @param rexWindowBound Expression to convert + */ + public SqlNode toSql(RexWindowBound rexWindowBound) { + final SqlNode offsetLiteral = + rexWindowBound.getOffset() == null + ? null + : SqlLiteral.createCharString(rexWindowBound.getOffset().toString(), + SqlParserPos.ZERO); + if (rexWindowBound.isPreceding()) { + return offsetLiteral == null + ? SqlWindow.createUnboundedPreceding(POS) + : SqlWindow.createPreceding(offsetLiteral, POS); + } else if (rexWindowBound.isFollowing()) { + return offsetLiteral == null + ? SqlWindow.createUnboundedFollowing(POS) + : SqlWindow.createFollowing(offsetLiteral, POS); + } else { + assert rexWindowBound.isCurrentRow(); + return SqlWindow.createCurrentRow(POS); + } + } + + public List toSql(Window.Group group, ImmutableList constants, + int inputFieldCount) { + final List rexOvers = new ArrayList<>(); + final List partitionKeys = new ArrayList<>(); + final List orderByKeys = new ArrayList<>(); + for (int partition : group.keys) { + partitionKeys.add(this.field(partition)); + } + for (RelFieldCollation collation : group.orderKeys.getFieldCollations()) { + this.addOrderItem(orderByKeys, collation); + } + SqlLiteral isRows = SqlLiteral.createBoolean(group.isRows, POS); + SqlNode lowerBound = null; + SqlNode upperBound = null; + + final SqlLiteral allowPartial = null; + + for (Window.RexWinAggCall winAggCall : group.aggCalls) { + SqlAggFunction aggFunction = (SqlAggFunction) winAggCall.getOperator(); + final SqlWindow sqlWindow = SqlWindow.create(null, null, + new SqlNodeList(partitionKeys, POS), new SqlNodeList(orderByKeys, POS), + isRows, lowerBound, upperBound, allowPartial, POS); + if (aggFunction.allowsFraming()) { + lowerBound = createSqlWindowBound(group.lowerBound); + upperBound = createSqlWindowBound(group.upperBound); + sqlWindow.setLowerBound(lowerBound); + sqlWindow.setUpperBound(upperBound); + } + + RexShuttle replaceConstants = new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + int index = inputRef.getIndex(); + RexNode ref; + if (index > inputFieldCount - 1) { + ref = constants.get(index - inputFieldCount); + } else { + ref = inputRef; + } + return ref; + } + }; + RexCall aggCall = (RexCall) winAggCall.accept(replaceConstants); + List operands = toSql(null, aggCall.operands); + rexOvers.add(createOverCall(aggFunction, operands, sqlWindow, winAggCall.distinct)); + } + return rexOvers; + } + + protected Context getAliasContext(RexCorrelVariable variable) { + throw new UnsupportedOperationException(); + } + + private SqlCall toSql(@Nullable RexProgram program, RexOver rexOver) { + final RexWindow rexWindow = rexOver.getWindow(); + final SqlNodeList partitionList = new SqlNodeList( + toSql(program, rexWindow.partitionKeys), POS); + + List orderNodes = Expressions.list(); + if (rexWindow.orderKeys != null) { + for (RexFieldCollation rfc : rexWindow.orderKeys) { + addOrderItem(orderNodes, program, rfc); + } + } + final SqlNodeList orderList = + new SqlNodeList(orderNodes, POS); + + final SqlLiteral isRows = + SqlLiteral.createBoolean(rexWindow.isRows(), POS); + + // null defaults to true. + // During parsing the allowPartial == false (e.g. disallow partial) + // is expand into CASE expression and is handled as a such. + // Not sure if we can collapse this CASE expression back into + // "disallow partial" and set the allowPartial = false. + final SqlLiteral allowPartial = null; + + SqlAggFunction sqlAggregateFunction = rexOver.getAggOperator(); + + SqlNode lowerBound = null; + SqlNode upperBound = null; + + if (sqlAggregateFunction.allowsFraming()) { + lowerBound = createSqlWindowBound(rexWindow.getLowerBound()); + upperBound = createSqlWindowBound(rexWindow.getUpperBound()); + } + + final SqlWindow sqlWindow = SqlWindow.create(null, null, partitionList, + orderList, isRows, lowerBound, upperBound, allowPartial, POS); + + final List nodeList = toSql(program, rexOver.getOperands()); + return createOverCall(sqlAggregateFunction, nodeList, sqlWindow, rexOver.isDistinct()); + } + + private static SqlCall createOverCall(SqlAggFunction op, List operands, + SqlWindow window, boolean isDistinct) { + if (op instanceof SqlSumEmptyIsZeroAggFunction) { + // Rewrite "SUM0(x) OVER w" to "COALESCE(SUM(x) OVER w, 0)" + final SqlCall node = + createOverCall(SqlStdOperatorTable.SUM, operands, window, isDistinct); + return SqlStdOperatorTable.COALESCE.createCall(POS, node, + SqlLiteral.createExactNumeric("0", POS)); + } + SqlCall aggFunctionCall; + if (isDistinct) { + aggFunctionCall = op.createCall( + SqlSelectKeyword.DISTINCT.symbol(POS), + POS, + operands); + } else { + aggFunctionCall = op.createCall(POS, operands); + } + return SqlStdOperatorTable.OVER.createCall(POS, aggFunctionCall, + window); + } + + private SqlNode toSql(@Nullable RexProgram program, RexFieldCollation rfc) { + SqlNode node = toSql(program, rfc.left); + switch (rfc.getDirection()) { + case DESCENDING: + case STRICTLY_DESCENDING: + node = SqlStdOperatorTable.DESC.createCall(POS, node); + break; + default: + break; + } + if (rfc.getNullDirection() + != dialect.defaultNullDirection(rfc.getDirection())) { + switch (rfc.getNullDirection()) { + case FIRST: + node = SqlStdOperatorTable.NULLS_FIRST.createCall(POS, node); + break; + case LAST: + node = SqlStdOperatorTable.NULLS_LAST.createCall(POS, node); + break; + default: + break; + } + } + return node; + } + + private SqlNode createSqlWindowBound(RexWindowBound rexWindowBound) { + if (rexWindowBound.isCurrentRow()) { + return SqlWindow.createCurrentRow(POS); + } + if (rexWindowBound.isPreceding()) { + if (rexWindowBound.isUnbounded()) { + return SqlWindow.createUnboundedPreceding(POS); + } else { + SqlNode literal = toSql(null, rexWindowBound.getOffset()); + return SqlWindow.createPreceding(literal, POS); + } + } + if (rexWindowBound.isFollowing()) { + if (rexWindowBound.isUnbounded()) { + return SqlWindow.createUnboundedFollowing(POS); + } else { + SqlNode literal = toSql(null, rexWindowBound.getOffset()); + return SqlWindow.createFollowing(literal, POS); + } + } + + throw new AssertionError("Unsupported Window bound: " + + rexWindowBound); + } + + private List toSql(@Nullable RexProgram program, List operandList) { + final List list = new ArrayList<>(); + for (RexNode rex : operandList) { + list.add(toSql(program, rex)); + } + return list; + } + + public List fieldList() { + return new AbstractList() { + @Override + public SqlNode get(int index) { + return field(index); + } + + @Override + public int size() { + return fieldCount; + } + }; + } + + void addOrderItem(List orderByList, RelFieldCollation field) { + if (field.nullDirection != RelFieldCollation.NullDirection.UNSPECIFIED) { + final boolean first = + field.nullDirection == RelFieldCollation.NullDirection.FIRST; + SqlNode nullDirectionNode = + dialect.emulateNullDirection(field(field.getFieldIndex()), + first, field.direction.isDescending()); + if (nullDirectionNode != null) { + orderByList.add(nullDirectionNode); + field = new RelFieldCollation(field.getFieldIndex(), + field.getDirection(), + RelFieldCollation.NullDirection.UNSPECIFIED); + } + } + orderByList.add(toSql(field)); + } + + /** + * Converts a RexFieldCollation to an ORDER BY item. + */ + private void addOrderItem(List orderByList, + @Nullable RexProgram program, RexFieldCollation field) { + SqlNode node = toSql(program, field.left); + SqlNode nullDirectionNode = null; + if (field.getNullDirection() != RelFieldCollation.NullDirection.UNSPECIFIED) { + final boolean first = + field.getNullDirection() == RelFieldCollation.NullDirection.FIRST; + nullDirectionNode = dialect.emulateNullDirection( + node, first, field.getDirection().isDescending()); + } + if (nullDirectionNode != null) { + orderByList.add(nullDirectionNode); + switch (field.getDirection()) { + case DESCENDING: + case STRICTLY_DESCENDING: + node = SqlStdOperatorTable.DESC.createCall(POS, node); + break; + default: + break; + } + orderByList.add(node); + } else { + orderByList.add(toSql(program, field)); + } + } + + /** + * Converts a call to an aggregate function to an expression. + */ + public SqlNode toSql(AggregateCall aggCall) { + return toSql(aggCall.getAggregation(), aggCall.isDistinct(), + Util.transform(aggCall.getArgList(), this::field), + aggCall.filterArg, aggCall.collation); + } + + /** + * Converts a call to an aggregate function, with a given list of operands, + * to an expression. + */ + private SqlCall toSql(SqlOperator op, boolean distinct, + List operandList, int filterArg, RelCollation collation) { + final SqlLiteral qualifier = + distinct ? SqlSelectKeyword.DISTINCT.symbol(POS) : null; + if (op instanceof SqlSumEmptyIsZeroAggFunction) { + final SqlNode node = toSql(SqlStdOperatorTable.SUM, distinct, + operandList, filterArg, collation); + return SqlStdOperatorTable.COALESCE.createCall(POS, node, + SqlLiteral.createExactNumeric("0", POS)); + } + + // Handle filter on dialects that do support FILTER by generating CASE. + if (filterArg >= 0 && !dialect.supportsAggregateFunctionFilter()) { + // SUM(x) FILTER(WHERE b) ==> SUM(CASE WHEN b THEN x END) + // COUNT(*) FILTER(WHERE b) ==> COUNT(CASE WHEN b THEN 1 END) + // COUNT(x) FILTER(WHERE b) ==> COUNT(CASE WHEN b THEN x END) + // COUNT(x, y) FILTER(WHERE b) ==> COUNT(CASE WHEN b THEN x END, y) + final SqlNodeList whenList = SqlNodeList.of(field(filterArg)); + final SqlNodeList thenList = + SqlNodeList.of(operandList.isEmpty() + ? SqlLiteral.createExactNumeric("1", POS) + : operandList.get(0)); + final SqlNode elseList = SqlLiteral.createNull(POS); + final SqlCall caseCall = + SqlStdOperatorTable.CASE.createCall(null, POS, null, whenList, + thenList, elseList); + final List newOperandList = new ArrayList<>(); + newOperandList.add(caseCall); + if (operandList.size() > 1) { + newOperandList.addAll(Util.skip(operandList)); + } + return toSql(op, distinct, newOperandList, -1, collation); + } + + if (op instanceof SqlCountAggFunction && operandList.isEmpty()) { + // If there is no parameter in "count" function, add a star identifier + // to it. + operandList = ImmutableList.of(SqlIdentifier.STAR); + } + final SqlCall call = + op.createCall(qualifier, POS, operandList); + + // Handle filter by generating FILTER (WHERE ...) + final SqlCall call2; + if (filterArg < 0) { + call2 = call; + } else { + assert dialect.supportsAggregateFunctionFilter(); // we checked above + call2 = SqlStdOperatorTable.FILTER.createCall(POS, call, + field(filterArg)); + } + + // Handle collation + return withOrder(call2, collation); + } + + /** + * Wraps a call in a {@link SqlKind#WITHIN_GROUP} call, if + * {@code collation} is non-empty. + */ + private SqlCall withOrder(SqlCall call, RelCollation collation) { + if (collation.getFieldCollations().isEmpty()) { + return call; + } + final List orderByList = new ArrayList<>(); + for (RelFieldCollation field : collation.getFieldCollations()) { + addOrderItem(orderByList, field); + } + return SqlStdOperatorTable.WITHIN_GROUP.createCall(POS, call, + new SqlNodeList(orderByList, POS)); + } + + /** + * Converts a collation to an ORDER BY item. + */ + public SqlNode toSql(RelFieldCollation collation) { + SqlNode node = orderField(collation.getFieldIndex()); + switch (collation.getDirection()) { + case DESCENDING: + case STRICTLY_DESCENDING: + node = SqlStdOperatorTable.DESC.createCall(POS, node); + break; + default: + break; + } + if (collation.nullDirection != dialect.defaultNullDirection(collation.direction)) { + switch (collation.nullDirection) { + case FIRST: + node = SqlStdOperatorTable.NULLS_FIRST.createCall(POS, node); + break; + case LAST: + node = SqlStdOperatorTable.NULLS_LAST.createCall(POS, node); + break; + default: + break; + } + } + return node; + } + + public abstract SqlImplementor implementor(); + + /** + * Converts a {@link Range} to a SQL expression. + * + * @param Value type + */ + private static class RangeToSql> + implements RangeSets.Consumer { + private final List list; + private final Function literalFactory; + private final SqlNode arg; + + RangeToSql(SqlNode arg, List list, + Function literalFactory) { + this.arg = arg; + this.list = list; + this.literalFactory = literalFactory; + } + + private void addAnd(SqlNode... nodes) { + list.add( + SqlUtil.createCall(SqlStdOperatorTable.AND, POS, + ImmutableList.copyOf(nodes))); + } + + private SqlNode op(SqlOperator op, C value) { + return op.createCall(POS, arg, literalFactory.apply(value)); + } + + @Override + public void all() { + list.add(SqlLiteral.createBoolean(true, POS)); + } + + @Override + public void atLeast(C lower) { + list.add(op(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, lower)); + } + + @Override + public void atMost(C upper) { + list.add(op(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, upper)); + } + + @Override + public void greaterThan(C lower) { + list.add(op(SqlStdOperatorTable.GREATER_THAN, lower)); + } + + @Override + public void lessThan(C upper) { + list.add(op(SqlStdOperatorTable.LESS_THAN, upper)); + } + + @Override + public void singleton(C value) { + list.add(op(SqlStdOperatorTable.EQUALS, value)); + } + + @Override + public void closed(C lower, C upper) { + addAnd(op(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, lower), + op(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, upper)); + } + + @Override + public void closedOpen(C lower, C upper) { + addAnd(op(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, lower), + op(SqlStdOperatorTable.LESS_THAN, upper)); + } + + @Override + public void openClosed(C lower, C upper) { + addAnd(op(SqlStdOperatorTable.GREATER_THAN, lower), + op(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, upper)); + } + + @Override + public void open(C lower, C upper) { + addAnd(op(SqlStdOperatorTable.GREATER_THAN, lower), + op(SqlStdOperatorTable.LESS_THAN, upper)); + } + } + } + + /** + * Converts a {@link RexLiteral} in the context of a {@link RexProgram} + * to a {@link SqlNode}. + */ + public static SqlNode toSql(@Nullable RexProgram program, RexLiteral literal) { + switch (literal.getTypeName()) { + case SYMBOL: + final Enum symbol = (Enum) literal.getValue(); + return SqlLiteral.createSymbol(symbol, POS); + + case ROW: + //noinspection unchecked + final List list = castNonNull(literal.getValueAs(List.class)); + return SqlStdOperatorTable.ROW.createCall(POS, + list.stream().map(e -> toSql(program, e)) + .collect(Util.toImmutableList())); + + case SARG: + final Sarg arg = literal.getValueAs(Sarg.class); + throw new AssertionError("sargs [" + arg + + "] should be handled as part of predicates, not as literals"); + + default: + return toSql(literal); + } + } + + /** + * Converts a {@link RexLiteral} to a {@link SqlLiteral}. + */ + public static SqlNode toSql(RexLiteral literal) { + SqlTypeName typeName = literal.getTypeName(); + switch (typeName) { + case SYMBOL: + final Enum symbol = (Enum) literal.getValue(); + return SqlLiteral.createSymbol(symbol, POS); + + case ROW: + //noinspection unchecked + final List list = castNonNull(literal.getValueAs(List.class)); + return SqlStdOperatorTable.ROW.createCall(POS, + list.stream().map(e -> toSql(e)) + .collect(Util.toImmutableList())); + + case SARG: + final Sarg arg = literal.getValueAs(Sarg.class); + throw new AssertionError("sargs [" + arg + + "] should be handled as part of predicates, not as literals"); + default: + break; + } + SqlTypeFamily family = requireNonNull(typeName.getFamily(), + () -> "literal " + literal + " has null SqlTypeFamily, and is SqlTypeName is " + typeName); + switch (family) { + case CHARACTER: + return SqlLiteral.createCharString((String) castNonNull(literal.getValue2()), POS); + case NUMERIC: + case EXACT_NUMERIC: + return SqlLiteral.createExactNumeric( + castNonNull(literal.getValueAs(BigDecimal.class)).toPlainString(), POS); + case APPROXIMATE_NUMERIC: + return SqlLiteral.createApproxNumeric( + castNonNull(literal.getValueAs(BigDecimal.class)).toPlainString(), POS); + case BOOLEAN: + return SqlLiteral.createBoolean(castNonNull(literal.getValueAs(Boolean.class)), + POS); + case INTERVAL_YEAR_MONTH: + case INTERVAL_DAY_TIME: + final boolean negative = castNonNull(literal.getValueAs(Boolean.class)); + return SqlLiteral.createInterval(negative ? -1 : 1, + castNonNull(literal.getValueAs(String.class)), + castNonNull(literal.getType().getIntervalQualifier()), POS); + case DATE: + return SqlLiteral.createDate(castNonNull(literal.getValueAs(DateString.class)), + POS); + case TIME: + return SqlLiteral.createTime(castNonNull(literal.getValueAs(TimeString.class)), + literal.getType().getPrecision(), POS); + case TIMESTAMP: + return SqlLiteral.createTimestamp( + castNonNull(literal.getValueAs(TimestampString.class)), + literal.getType().getPrecision(), POS); + case ANY: + case NULL: + switch (typeName) { + case NULL: + return SqlLiteral.createNull(POS); + default: + break; + } + // fall through + default: + throw new AssertionError(literal + ": " + typeName); + } + } + + /** + * Simple implementation of {@link Context} that cannot handle sub-queries + * or correlations. Because it is so simple, you do not need to create a + * {@link SqlImplementor} or {@link org.apache.calcite.tools.RelBuilder} + * to use it. It is a good way to convert a {@link RexNode} to SQL text. + */ + public static class SimpleContext extends Context { + private final IntFunction field; + + public SimpleContext(SqlDialect dialect, IntFunction field) { + super(dialect, 0, false); + this.field = field; + } + + @Override + public SqlImplementor implementor() { + throw new UnsupportedOperationException(); + } + + @Override + public SqlNode field(int ordinal) { + return field.apply(ordinal); + } + } + + /** + * Implementation of {@link Context} that has an enclosing + * {@link SqlImplementor} and can therefore do non-trivial expressions. + */ + protected abstract class BaseContext extends Context { + BaseContext(SqlDialect dialect, int fieldCount) { + super(dialect, fieldCount); + } + + @Override + protected Context getAliasContext(RexCorrelVariable variable) { + return requireNonNull( + correlTableMap.get(variable.id), + () -> "variable " + variable.id + " is not found"); + } + + @Override + public SqlImplementor implementor() { + return SqlImplementor.this; + } + } + + private static int computeFieldCount( + Map aliases) { + int x = 0; + for (RelDataType type : aliases.values()) { + x += type.getFieldCount(); + } + return x; + } + + public Context aliasContext(Map aliases, + boolean qualified) { + return new AliasContext(dialect, aliases, qualified); + } + + public Context joinContext(Context leftContext, Context rightContext) { + return new JoinContext(dialect, leftContext, rightContext); + } + + public Context matchRecognizeContext(Context context) { + return new MatchRecognizeContext(dialect, ((AliasContext) context).aliases); + } + + public Context tableFunctionScanContext(List inputSqlNodes) { + return new TableFunctionScanContext(dialect, inputSqlNodes); + } + + /** + * Context for translating MATCH_RECOGNIZE clause. + */ + public class MatchRecognizeContext extends AliasContext { + protected MatchRecognizeContext(SqlDialect dialect, + Map aliases) { + super(dialect, aliases, false); + } + + @Override + public SqlNode toSql(@Nullable RexProgram program, RexNode rex) { + if (rex.getKind() == SqlKind.LITERAL) { + final RexLiteral literal = (RexLiteral) rex; + if (literal.getTypeName().getFamily() == SqlTypeFamily.CHARACTER) { + return new SqlIdentifier(castNonNull(RexLiteral.stringValue(literal)), POS); + } + } + return super.toSql(program, rex); + } + } + + /** + * Implementation of Context that precedes field references with their + * "table alias" based on the current sub-query's FROM clause. + */ + public class AliasContext extends BaseContext { + private final boolean qualified; + private final Map aliases; + + /** + * Creates an AliasContext; use {@link #aliasContext(Map, boolean)}. + */ + protected AliasContext(SqlDialect dialect, + Map aliases, boolean qualified) { + super(dialect, computeFieldCount(aliases)); + this.aliases = aliases; + this.qualified = qualified; + } + + @Override + public SqlNode field(int ordinal) { + for (Map.Entry alias : aliases.entrySet()) { + final List fields = alias.getValue().getFieldList(); + if (ordinal < fields.size()) { + RelDataTypeField field = fields.get(ordinal); + return new SqlIdentifier(!qualified + ? ImmutableList.of(field.getName()) + : ImmutableList.of(alias.getKey(), field.getName()), + POS); + } + ordinal -= fields.size(); + } + throw new AssertionError( + "field ordinal " + ordinal + " out of range " + aliases); + } + } + + /** + * Context for translating ON clause of a JOIN from {@link RexNode} to + * {@link SqlNode}. + */ + class JoinContext extends BaseContext { + private final Context leftContext; + private final Context rightContext; + + /** + * Creates a JoinContext; use {@link #joinContext(Context, Context)}. + */ + private JoinContext(SqlDialect dialect, Context leftContext, + Context rightContext) { + super(dialect, leftContext.fieldCount + rightContext.fieldCount); + this.leftContext = leftContext; + this.rightContext = rightContext; + } + + @Override + public SqlNode field(int ordinal) { + if (ordinal < leftContext.fieldCount) { + return leftContext.field(ordinal); + } else { + return rightContext.field(ordinal - leftContext.fieldCount); + } + } + + @Override + protected RexCall reverseCall(RexCall call) { + switch (call.getKind()) { + case EQUALS: + case IS_DISTINCT_FROM: + case IS_NOT_DISTINCT_FROM: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + assert call.operands.size() == 2; + final RexNode op0 = call.operands.get(0); + final RexNode op1 = call.operands.get(1); + if (op0 instanceof RexInputRef + && op1 instanceof RexInputRef + && ((RexInputRef) op1).getIndex() < leftContext.fieldCount + && ((RexInputRef) op0).getIndex() >= leftContext.fieldCount) { + // Arguments were of form 'op1 = op0' + final SqlOperator op2 = requireNonNull(call.getOperator().reverse()); + return (RexCall) rexBuilder.makeCall(op2, op1, op0); + } + // fall through + default: + return call; + } + } + } + + /** + * Context for translating call of a TableFunctionScan from {@link RexNode} to + * {@link SqlNode}. + */ + class TableFunctionScanContext extends BaseContext { + private final List inputSqlNodes; + + TableFunctionScanContext(SqlDialect dialect, List inputSqlNodes) { + super(dialect, inputSqlNodes.size()); + this.inputSqlNodes = inputSqlNodes; + } + + @Override + public SqlNode field(int ordinal) { + return inputSqlNodes.get(ordinal); + } + } + + /** + * Result of implementing a node. + */ + public class Result { + final SqlNode node; + final @Nullable String neededAlias; + private final @Nullable RelDataType neededType; + private final Map aliases; + final List clauses; + private final boolean anon; + /** + * Whether to treat {@link #expectedClauses} as empty for the + * purposes of figuring out whether we need a new sub-query. + */ + private final boolean ignoreClauses; + /** + * Clauses that will be generated to implement current relational + * expression. + */ + private final ImmutableSet expectedClauses; + private final @Nullable RelNode expectedRel; + private final boolean needNew; + + public Result(SqlNode node, Collection clauses, @Nullable String neededAlias, + @Nullable RelDataType neededType, Map aliases) { + this(node, clauses, neededAlias, neededType, aliases, false, false, + ImmutableSet.of(), null); + } + + private Result(SqlNode node, Collection clauses, @Nullable String neededAlias, + @Nullable RelDataType neededType, Map aliases, boolean anon, + boolean ignoreClauses, Set expectedClauses, + @Nullable RelNode expectedRel) { + this.node = node; + this.neededAlias = neededAlias; + this.neededType = neededType; + this.aliases = aliases; + this.clauses = ImmutableList.copyOf(clauses); + this.anon = anon; + this.ignoreClauses = ignoreClauses; + this.expectedClauses = ImmutableSet.copyOf(expectedClauses); + this.expectedRel = expectedRel; + final Set clauses2 = + ignoreClauses ? ImmutableSet.of() : expectedClauses; + this.needNew = expectedRel != null + && needNewSubQuery(expectedRel, this.clauses, clauses2); + } + + /** + * Creates a builder for the SQL of the given relational expression, + * using the clauses that you declared when you called + * {@link #visitInput(RelNode, int, Set)}. + */ + public Builder builder(RelNode rel) { + return builder(rel, expectedClauses); + } + + // CHECKSTYLE: IGNORE 3 + + /** + * @deprecated Provide the expected clauses up-front, when you call + * {@link #visitInput(RelNode, int, Set)}, then create a builder using + * {@link #builder(RelNode)}. + */ + @Deprecated // to be removed before 2.0 + public Builder builder(RelNode rel, Clause clause, Clause... clauses) { + return builder(rel, ImmutableSet.copyOf(Lists.asList(clause, clauses))); + } + + /** + * Once you have a Result of implementing a child relational expression, + * call this method to create a Builder to implement the current relational + * expression by adding additional clauses to the SQL query. + * + *

You need to declare which clauses you intend to add. If the clauses + * are "later", you can add to the same query. For example, "GROUP BY" comes + * after "WHERE". But if they are the same or earlier, this method will + * start a new SELECT that wraps the previous result. + * + *

When you have called + * {@link Builder#setSelect(SqlNodeList)}, + * {@link Builder#setWhere(SqlNode)} etc. call + * {@link Builder#result(SqlNode, Collection, RelNode, Map)} + * to fix the new query. + * + * @param rel Relational expression being implemented + * @return A builder + */ + private Builder builder(RelNode rel, Set clauses) { + assert expectedClauses.containsAll(clauses); + assert rel.equals(expectedRel); + final Set clauses2 = ignoreClauses ? ImmutableSet.of() : clauses; + final boolean needNew = needNewSubQuery(rel, this.clauses, clauses2); + assert needNew == this.needNew; + SqlSelect select; + Expressions.FluentList clauseList = Expressions.list(); + if (needNew) { + select = subSelect(); + } else { + select = asSelect(); + clauseList.addAll(this.clauses); + } + clauseList.appendAll(clauses); + final Context newContext; + Map newAliases = null; + final SqlNodeList selectList = select.getSelectList(); + if (!selectList.equals(SqlNodeList.SINGLETON_STAR)) { + final boolean aliasRef = expectedClauses.contains(Clause.HAVING) + && dialect.getConformance().isHavingAlias(); + newContext = new Context(dialect, selectList.size()) { + @Override + public SqlImplementor implementor() { + return SqlImplementor.this; + } + + @Override + public SqlNode field(int ordinal) { + final SqlNode selectItem = selectList.get(ordinal); + switch (selectItem.getKind()) { + case AS: + final SqlCall asCall = (SqlCall) selectItem; + if (aliasRef) { + // For BigQuery, given the query + // SELECT SUM(x) AS x FROM t HAVING(SUM(t.x) > 0) + // we can generate + // SELECT SUM(x) AS x FROM t HAVING(x > 0) + // because 'x' in HAVING resolves to the 'AS x' not 't.x'. + return asCall.operand(1); + } + return asCall.operand(0); + default: + break; + } + return selectItem; + } + + @Override + public SqlNode orderField(int ordinal) { + // If the field expression is an unqualified column identifier + // and matches a different alias, use an ordinal. + // For example, given + // SELECT deptno AS empno, empno AS x FROM emp ORDER BY emp.empno + // we generate + // SELECT deptno AS empno, empno AS x FROM emp ORDER BY 2 + // "ORDER BY empno" would give incorrect result; + // "ORDER BY x" is acceptable but is not preferred. + final SqlNode node = field(ordinal); + if (node instanceof SqlIdentifier + && ((SqlIdentifier) node).isSimple()) { + final String name = ((SqlIdentifier) node).getSimple(); + for (Ord selectItem : Ord.zip(selectList)) { + if (selectItem.i != ordinal) { + final String alias = + SqlValidatorUtil.getAlias(selectItem.e, -1); + if (name.equalsIgnoreCase(alias)) { + return SqlLiteral.createExactNumeric( + Integer.toString(ordinal + 1), SqlParserPos.ZERO); + } + } + } + } + return node; + } + }; + } else { + boolean qualified = + !dialect.hasImplicitTableAlias() || aliases.size() > 1; + // basically, we did a subSelect() since needNew is set and neededAlias is not null + // now, we need to make sure that we need to update the alias context. + // if our aliases map has a single element: , + // then we don't need to rewrite the alias but otherwise, it should be updated. + if (needNew + && neededAlias != null + && (aliases.size() != 1 || !aliases.containsKey(neededAlias))) { + newAliases = + ImmutableMap.of(neededAlias, rel.getInput(0).getRowType()); + newContext = aliasContext(newAliases, qualified); + } else { + newContext = aliasContext(aliases, qualified); + } + } + return new Builder(rel, clauseList, select, newContext, isAnon(), + needNew && !aliases.containsKey(neededAlias) ? newAliases : aliases); + } + + /** + * Returns whether a new sub-query is required. + */ + private boolean needNewSubQuery( + @UnknownInitialization Result this, + RelNode rel, List clauses, + Set expectedClauses) { + if (clauses.isEmpty()) { + return false; + } + final Clause maxClause = Collections.max(clauses); + // If old and new clause are equal and belong to below set, + // then new SELECT wrap is not required + final Set nonWrapSet = ImmutableSet.of(Clause.SELECT); + for (Clause clause : expectedClauses) { + if (maxClause.ordinal() > clause.ordinal() + || (maxClause == clause + && !nonWrapSet.contains(clause))) { + return true; + } + } + + if (rel instanceof Project + && clauses.contains(Clause.HAVING) + && dialect.getConformance().isHavingAlias()) { + return true; + } + + if (rel instanceof Project + && ((Project) rel).containsOver() + && maxClause == Clause.SELECT) { + // Cannot merge a Project that contains windowed functions onto an + // underlying Project + return true; + } + + if (rel instanceof Aggregate) { + final Aggregate agg = (Aggregate) rel; + final boolean hasNestedAgg = + hasNested(agg, SqlImplementor::isAggregate); + final boolean hasNestedWindowedAgg = + hasNested(agg, SqlImplementor::isWindowedAggregate); + if (!dialect.supportsNestedAggregations() + && (hasNestedAgg || hasNestedWindowedAgg)) { + return true; + } + + if (clauses.contains(Clause.GROUP_BY)) { + // Avoid losing the distinct attribute of inner aggregate. + return !hasNestedAgg || Aggregate.isNotGrandTotal(agg); + } + } + + return false; + } + + /** + * Returns whether an {@link Aggregate} contains nested operands that + * match the predicate. + * + * @param aggregate Aggregate node + * @param operandPredicate Predicate for the nested operands + * @return whether any nested operands matches the predicate + */ + private boolean hasNested( + @UnknownInitialization Result this, + Aggregate aggregate, + Predicate operandPredicate) { + if (node instanceof SqlSelect) { + final SqlNodeList selectList = ((SqlSelect) node).getSelectList(); + if (!selectList.equals(SqlNodeList.SINGLETON_STAR)) { + final Set aggregatesArgs = new HashSet<>(); + for (AggregateCall aggregateCall : aggregate.getAggCallList()) { + aggregatesArgs.addAll(aggregateCall.getArgList()); + } + for (int aggregatesArg : aggregatesArgs) { + if (selectList.get(aggregatesArg) instanceof SqlBasicCall) { + final SqlBasicCall call = + (SqlBasicCall) selectList.get(aggregatesArg); + for (SqlNode operand : call.getOperands()) { + if (operand != null && operandPredicate.test(operand)) { + return true; + } + } + } + } + } + } + return false; + } + + /** + * Returns the highest clause that is in use. + */ + @Deprecated + public Clause maxClause() { + return Collections.max(clauses); + } + + /** + * Returns a node that can be included in the FROM clause or a JOIN. It has + * an alias that is unique within the query. The alias is implicit if it + * can be derived using the usual rules (For example, "SELECT * FROM emp" is + * equivalent to "SELECT * FROM emp AS emp".) + */ + public SqlNode asFrom() { + if (neededAlias != null) { + if (node.getKind() == SqlKind.AS) { + // If we already have an AS node, we need to replace the alias + // This is especially relevant for the VALUES clause rendering + SqlCall sqlCall = (SqlCall) node; + @SuppressWarnings("assignment.type.incompatible") + SqlNode[] operands = sqlCall.getOperandList().toArray(new SqlNode[0]); + operands[1] = new SqlIdentifier(neededAlias, POS); + return SqlStdOperatorTable.AS.createCall(POS, operands); + } else { + return SqlStdOperatorTable.AS.createCall(POS, node, + new SqlIdentifier(neededAlias, POS)); + } + } + return node; + } + + public SqlSelect subSelect() { + return wrapSelect(asFrom()); + } + + /** + * Converts a non-query node into a SELECT node. Set operators (UNION, + * INTERSECT, EXCEPT) remain as is. + */ + public SqlSelect asSelect() { + if (node instanceof SqlSelect) { + return (SqlSelect) node; + } + if (!dialect.hasImplicitTableAlias()) { + return wrapSelect(asFrom()); + } + return wrapSelect(node); + } + + public void stripTrivialAliases(SqlNode node) { + switch (node.getKind()) { + case SELECT: + final SqlSelect select = (SqlSelect) node; + final SqlNodeList nodeList = select.getSelectList(); + if (nodeList != null) { + for (int i = 0; i < nodeList.size(); i++) { + final SqlNode n = nodeList.get(i); + if (n.getKind() == SqlKind.AS) { + final SqlCall call = (SqlCall) n; + final SqlIdentifier identifier = call.operand(1); + if (identifier.getSimple().toLowerCase(Locale.ROOT) + .startsWith("expr$")) { + nodeList.set(i, call.operand(0)); + } + } + } + } + break; + + case UNION: + case INTERSECT: + case EXCEPT: + case INSERT: + case UPDATE: + case DELETE: + case MERGE: + final SqlCall call = (SqlCall) node; + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + stripTrivialAliases(operand); + } + } + break; + default: + break; + } + } + + /** + * Strips trivial aliases if anon. + */ + private SqlNode maybeStrip(SqlNode node) { + if (anon) { + stripTrivialAliases(node); + } + return node; + } + + /** + * Converts a non-query node into a SELECT node. Set operators (UNION, + * INTERSECT, EXCEPT) and DML operators (INSERT, UPDATE, DELETE, MERGE) + * remain as is. + */ + public SqlNode asStatement() { + switch (node.getKind()) { + case UNION: + case INTERSECT: + case EXCEPT: + case INSERT: + case UPDATE: + case DELETE: + case MERGE: + return maybeStrip(node); + default: + return maybeStrip(asSelect()); + } + } + + /** + * Converts a non-query node into a SELECT node. Set operators (UNION, + * INTERSECT, EXCEPT) and VALUES remain as is. + */ + public SqlNode asQueryOrValues() { + switch (node.getKind()) { + case UNION: + case INTERSECT: + case EXCEPT: + case VALUES: + return maybeStrip(node); + default: + return maybeStrip(asSelect()); + } + } + + /** + * Returns a context that always qualifies identifiers. Useful if the + * Context deals with just one arm of a join, yet we wish to generate + * a join condition that qualifies column names to disambiguate them. + */ + public Context qualifiedContext() { + return aliasContext(aliases, true); + } + + /** + * In join, when the left and right nodes have been generated, + * update their alias with 'neededAlias' if not null. + */ + public Result resetAlias() { + if (neededAlias == null) { + return this; + } else { + return new Result(node, clauses, neededAlias, neededType, + ImmutableMap.of(neededAlias, castNonNull(neededType)), anon, ignoreClauses, + expectedClauses, expectedRel); + } + } + + /** + * Sets the alias of the join or correlate just created. + * + * @param alias New alias + * @param type type of the node associated with the alias + */ + public Result resetAlias(String alias, RelDataType type) { + return new Result(node, clauses, alias, neededType, + ImmutableMap.of(alias, type), anon, ignoreClauses, + expectedClauses, expectedRel); + } + + /** + * Returns a copy of this Result, overriding the value of {@code anon}. + */ + Result withAnon(boolean anon) { + return anon == this.anon ? this + : new Result(node, clauses, neededAlias, neededType, aliases, anon, + ignoreClauses, expectedClauses, expectedRel); + } + + /** + * Returns a copy of this Result, overriding the value of + * {@code ignoreClauses} and {@code expectedClauses}. + */ + Result withExpectedClauses(boolean ignoreClauses, + Set expectedClauses, RelNode expectedRel) { + return ignoreClauses == this.ignoreClauses + && expectedClauses.equals(this.expectedClauses) + && expectedRel == this.expectedRel + ? this + : new Result(node, clauses, neededAlias, neededType, aliases, anon, + ignoreClauses, ImmutableSet.copyOf(expectedClauses), expectedRel); + } + } + + /** + * Builder. + */ + public class Builder { + private final RelNode rel; + final List clauses; + final SqlSelect select; + public final Context context; + final boolean anon; + private final @Nullable Map aliases; + + public Builder(RelNode rel, List clauses, SqlSelect select, + Context context, boolean anon, + @Nullable Map aliases) { + this.rel = requireNonNull(rel, "rel"); + this.clauses = ImmutableList.copyOf(clauses); + this.select = requireNonNull(select, "select"); + this.context = requireNonNull(context, "context"); + this.anon = anon; + this.aliases = aliases; + } + + public void setSelect(SqlNodeList nodeList) { + select.setSelectList(nodeList); + } + + public void setWhere(SqlNode node) { + assert clauses.contains(Clause.WHERE); + select.setWhere(node); + } + + public void setGroupBy(SqlNodeList nodeList) { + assert clauses.contains(Clause.GROUP_BY); + select.setGroupBy(nodeList); + } + + public void setHaving(SqlNode node) { + assert clauses.contains(Clause.HAVING); + select.setHaving(node); + } + + public void setOrderBy(SqlNodeList nodeList) { + assert clauses.contains(Clause.ORDER_BY); + select.setOrderBy(nodeList); + } + + public void setFetch(SqlNode fetch) { + assert clauses.contains(Clause.FETCH); + select.setFetch(fetch); + } + + public void setOffset(SqlNode offset) { + assert clauses.contains(Clause.OFFSET); + select.setOffset(offset); + } + + public void addOrderItem(List orderByList, + RelFieldCollation field) { + context.addOrderItem(orderByList, field); + } + + public Result result() { + return SqlImplementor.this.result(select, clauses, rel, aliases) + .withAnon(anon); + } + } + + /** + * Clauses in a SQL query. Ordered by evaluation order. + * SELECT is set only when there is a NON-TRIVIAL SELECT clause. + */ + public enum Clause { + FROM, WHERE, GROUP_BY, HAVING, SELECT, SET_OP, ORDER_BY, FETCH, OFFSET + } +} diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcDataSource.java b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcDataSource.java similarity index 72% rename from traindb-core/src/main/java/traindb/schema/TrainDBJdbcDataSource.java rename to traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcDataSource.java index c0d2817..1c63d82 100644 --- a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcDataSource.java +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcDataSource.java @@ -12,7 +12,7 @@ * limitations under the License. */ -package traindb.schema; +package traindb.adapter.jdbc; import static java.util.Objects.requireNonNull; @@ -20,34 +20,29 @@ import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; -import java.util.Map; import javax.sql.DataSource; -import org.apache.calcite.adapter.jdbc.JdbcCatalogSchema; -import org.apache.calcite.adapter.jdbc.JdbcConvention; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.Schemas; -import org.apache.calcite.schema.impl.AbstractSchema; import org.apache.calcite.sql.SqlDialect; import traindb.common.TrainDBLogger; +import traindb.schema.TrainDBDataSource; -public class TrainDBJdbcDataSource extends AbstractSchema { +public class TrainDBJdbcDataSource extends TrainDBDataSource { private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBJdbcDataSource.class); - private final String name; private final DataSource dataSource; private final SqlDialect dialect; private final JdbcConvention convention; - private ImmutableMap subSchemaMap; public TrainDBJdbcDataSource(SchemaPlus parentSchema, DataSource dataSource) { - this.name = "traindb"; // FIXME + super(); this.dataSource = dataSource; final Expression expression = - Schemas.subSchemaExpression(parentSchema, name, JdbcCatalogSchema.class); + Schemas.subSchemaExpression(parentSchema, getName(), TrainDBJdbcSchema.class); this.dialect = createDialect(dataSource); - this.convention = JdbcConvention.of(dialect, expression, name); + this.convention = JdbcConvention.of(dialect, expression, getName()); computeSubSchemaMap(); } @@ -80,25 +75,6 @@ public static SqlDialect createDialect(DataSource dataSource) { return JdbcUtils.DialectPool.INSTANCE.get(dataSource); } - public final String getName() { - return name; - } - - @Override - public final boolean isMutable() { - return false; - } - - @Override - public final Map getSubSchemaMap() { - LOG.debug("getSubSchemaMap called. subSchemaMapSize=" + subSchemaMap.size()); - return subSchemaMap; - } - - public final void setSubSchemaMap(ImmutableMap subSchemaMap) { - this.subSchemaMap = subSchemaMap; - } - public DataSource getDataSource() { return dataSource; } diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcSchema.java b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcSchema.java new file mode 100644 index 0000000..54950e0 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcSchema.java @@ -0,0 +1,179 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import com.google.common.collect.ImmutableMap; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import javax.annotation.Nullable; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.avatica.SqlType; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.schema.Table; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Util; +import traindb.common.TrainDBLogger; +import traindb.schema.TrainDBSchema; + +public class TrainDBJdbcSchema extends TrainDBSchema { + private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBJdbcSchema.class); + + public TrainDBJdbcSchema(String name, TrainDBJdbcDataSource dataSource) { + super(name, dataSource); + computeTableMap(); + } + + public void computeTableMap() { + final ImmutableMap.Builder builder = ImmutableMap.builder(); + Connection connection = null; + ResultSet resultSet = null; + try { + TrainDBJdbcDataSource dataSource = (TrainDBJdbcDataSource) getDataSource(); + connection = dataSource.getDataSource().getConnection(); + DatabaseMetaData databaseMetaData = connection.getMetaData(); + resultSet = databaseMetaData.getTables(getName(), null, null, null); + while (resultSet.next()) { + final String catalogName = resultSet.getString(1); + final String schemaName = resultSet.getString(2); + final String tableName = resultSet.getString(3); + final String tableTypeName = resultSet.getString(4).replace(" ", "_"); + + MetaImpl.MetaTable tableDef = + new MetaImpl.MetaTable(catalogName, schemaName, tableName, tableTypeName); + + builder.put(tableName, new TrainDBJdbcTable(tableName, this, tableDef, + getProtoType(tableDef, databaseMetaData))); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } finally { + JdbcUtils.close(connection, null, resultSet); + } + setTableMap(builder.build()); + } + + private RelDataType getProtoType(MetaImpl.MetaTable tableDef, DatabaseMetaData databaseMetaData) { + RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT); + RelDataTypeFactory.Builder builder = typeFactory.builder(); + ResultSet resultSet = null; + try { + resultSet = databaseMetaData.getColumns( + tableDef.tableCat, tableDef.tableSchem, tableDef.tableName, null); + while (resultSet.next()) { + String columnName = resultSet.getString(4); + int dataType = resultSet.getInt(5); + String typeString = resultSet.getString(6); + int precision; + int scale; + switch (SqlType.valueOf(dataType)) { + case TIMESTAMP: + case TIME: + precision = resultSet.getInt(9); // SCALE + scale = 0; + break; + default: + precision = resultSet.getInt(7); // SIZE + scale = resultSet.getInt(9); // SCALE + break; + } + RelDataType sqlType = sqlType(typeFactory, dataType, precision, scale, typeString); + + builder.add(columnName, sqlType); + } + } catch (SQLException e) { + LOG.debug(e.getMessage()); + JdbcUtils.close(null, null, resultSet); + } + + return builder.build(); + } + + private static RelDataType sqlType(RelDataTypeFactory typeFactory, int dataType, + int precision, int scale, @Nullable String typeString) { + // Fall back to ANY if type is unknown + final SqlTypeName sqlTypeName = + Util.first(SqlTypeName.getNameForJdbcType(dataType), SqlTypeName.ANY); + switch (sqlTypeName) { + case ARRAY: + RelDataType component = null; + if (typeString != null && typeString.endsWith(" ARRAY")) { + // E.g. hsqldb gives "INTEGER ARRAY", so we deduce the component type + // "INTEGER". + final String remaining = typeString.substring(0, + typeString.length() - " ARRAY".length()); + component = parseTypeString(typeFactory, remaining); + } + if (component == null) { + component = typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.ANY), true); + } + return typeFactory.createArrayType(component, -1); + default: + break; + } + if (precision >= 0 + && scale >= 0 + && sqlTypeName.allowsPrecScale(true, true)) { + return typeFactory.createSqlType(sqlTypeName, precision, scale); + } else if (precision >= 0 && sqlTypeName.allowsPrecNoScale()) { + return typeFactory.createSqlType(sqlTypeName, precision); + } else { + assert sqlTypeName.allowsNoPrecNoScale(); + return typeFactory.createSqlType(sqlTypeName); + } + } + + /** + * Given "INTEGER", returns BasicSqlType(INTEGER). + * Given "VARCHAR(10)", returns BasicSqlType(VARCHAR, 10). + * Given "NUMERIC(10, 2)", returns BasicSqlType(NUMERIC, 10, 2). + */ + private static RelDataType parseTypeString(RelDataTypeFactory typeFactory, + String typeString) { + int precision = -1; + int scale = -1; + int open = typeString.indexOf("("); + if (open >= 0) { + int close = typeString.indexOf(")", open); + if (close >= 0) { + String rest = typeString.substring(open + 1, close); + typeString = typeString.substring(0, open); + int comma = rest.indexOf(","); + if (comma >= 0) { + precision = Integer.parseInt(rest.substring(0, comma)); + scale = Integer.parseInt(rest.substring(comma)); + } else { + precision = Integer.parseInt(rest); + } + } + } + try { + final SqlTypeName typeName = SqlTypeName.valueOf(typeString); + return typeName.allowsPrecScale(true, true) + ? typeFactory.createSqlType(typeName, precision, scale) + : typeName.allowsPrecScale(true, false) + ? typeFactory.createSqlType(typeName, precision) + : typeFactory.createSqlType(typeName); + } catch (IllegalArgumentException e) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.ANY), true); + } + } +} diff --git a/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcTable.java b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcTable.java new file mode 100644 index 0000000..be44dd8 --- /dev/null +++ b/traindb-core/src/main/java/traindb/adapter/jdbc/TrainDBJdbcTable.java @@ -0,0 +1,139 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.adapter.jdbc; + +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.DataContext; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.runtime.ResultSetEnumerable; +import org.apache.calcite.schema.ScannableTable; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.TranslatableTable; +import org.apache.calcite.schema.impl.AbstractTableQueryable; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWriterConfig; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.pretty.SqlPrettyWriter; +import org.apache.calcite.sql.util.SqlString; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; +import traindb.common.TrainDBLogger; +import traindb.schema.TrainDBTable; + +public final class TrainDBJdbcTable extends TrainDBTable + implements TranslatableTable, ScannableTable { + private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBJdbcTable.class); + + public TrainDBJdbcTable(String name, TrainDBJdbcSchema schema, MetaImpl.MetaTable tableDef, + RelDataType protoType) { + super(name, schema, Schema.TableType.valueOf(tableDef.tableType), protoType); + } + + @Override + public String toString() { + return "TrainDBTable {" + getName() + "}"; + } + + private List> fieldClasses(final JavaTypeFactory typeFactory) { + final RelDataType rowType = getRowType(typeFactory); + return Util.transform(rowType.getFieldList(), f -> { + final RelDataType type = f.getType(); + final Class clazz = (Class) typeFactory.getJavaClass(type); + final ColumnMetaData.Rep rep = + Util.first(ColumnMetaData.Rep.of(clazz), + ColumnMetaData.Rep.OBJECT); + return Pair.of(rep, type.getSqlTypeName().getJdbcOrdinal()); + }); + } + + SqlString generateSql() { + final SqlNodeList selectList = SqlNodeList.SINGLETON_STAR; + SqlSelect node = + new SqlSelect(SqlParserPos.ZERO, SqlNodeList.EMPTY, selectList, + tableName(), null, null, null, null, null, null, null, null); + final SqlWriterConfig config = SqlPrettyWriter.config() + .withAlwaysUseParentheses(true) + .withDialect(getDataSource().getDialect()); + final SqlPrettyWriter writer = new SqlPrettyWriter(config); + node.unparse(writer, 0, 0); + return writer.toSqlString(); + } + + SqlIdentifier tableName() { + final List strings = new ArrayList<>(); + strings.add(getSchema().getName()); + strings.add(getName()); + return new SqlIdentifier(strings, SqlParserPos.ZERO); + } + + private final TrainDBJdbcDataSource getDataSource() { + return (TrainDBJdbcDataSource) getSchema().getDataSource(); + } + + public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable relOptTable) { + return new JdbcTableScan(context.getCluster(), relOptTable, this, + getDataSource().getConvention()); + } + + @Override + public Queryable asQueryable(QueryProvider queryProvider, + SchemaPlus schema, String tableName) { + return new JdbcTableQueryable<>(queryProvider, schema, tableName); + } + + public Enumerable scan(DataContext root) { + JavaTypeFactory typeFactory = root.getTypeFactory(); + final SqlString sql = generateSql(); + return ResultSetEnumerable.of(getDataSource().getDataSource(), sql.getSql(), + JdbcUtils.rowBuilderFactory2(fieldClasses(typeFactory))); + } + + private class JdbcTableQueryable extends AbstractTableQueryable { + JdbcTableQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) { + super(queryProvider, schema, TrainDBJdbcTable.this, tableName); + } + + @Override + public String toString() { + return "JdbcTableQueryable {table: " + tableName + "}"; + } + + @Override + public Enumerator enumerator() { + final JavaTypeFactory typeFactory = + ((CalciteConnection) queryProvider).getTypeFactory(); + final SqlString sql = generateSql(); + final List> pairs = fieldClasses(typeFactory); + @SuppressWarnings({"rawtypes", "unchecked"}) final Enumerable enumerable = + (Enumerable) ResultSetEnumerable.of(getDataSource().getDataSource(), sql.getSql(), + JdbcUtils.rowBuilderFactory2(pairs)); + return enumerable.enumerator(); + } + } +} diff --git a/traindb-core/src/main/java/traindb/engine/SqlTableIdentifierFindVisitor.java b/traindb-core/src/main/java/traindb/engine/SqlTableIdentifierFindVisitor.java new file mode 100644 index 0000000..bcf4315 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/SqlTableIdentifierFindVisitor.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine; + +import java.util.ArrayList; +import java.util.Stack; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.util.SqlBasicVisitor; + +public final class SqlTableIdentifierFindVisitor extends SqlBasicVisitor { + private final Stack nodeStack = new Stack<>(); + private final ArrayList tableIds; + + /* + * To find table name identifiers, we use a state stack. + * It is used to indicate whether an identifier is in FROM clause. + */ + private enum State { + NOT_FROM, + FROM + } + + public SqlTableIdentifierFindVisitor(ArrayList tableIds) { + this.tableIds = tableIds; + } + + @Override + public SqlNode visit(SqlCall call) { + if (call instanceof SqlSelect) { + int i = 0; + for (SqlNode operand : call.getOperandList()) { + // FROM operand + if (i == 2) { + nodeStack.push(State.FROM); + } else { + nodeStack.push(State.NOT_FROM); + } + + i++; + + if (operand == null) { + continue; + } + + operand.accept(this); + nodeStack.pop(); + } + return null; + } + + SqlOperator operator = call.getOperator(); + if (operator != null && operator.getKind() == SqlKind.AS) { + // AS operator will be probed only if it is in FROM clause + if (nodeStack.peek() == State.FROM) { + call.operand(0).accept(this); + } + return null; + } + + return super.visit(call); + } + + @Override + public SqlNode visit(SqlIdentifier identifier) { + // check whether this is fully qualified table name + if (!nodeStack.empty() && nodeStack.peek() == State.FROM) { + tableIds.add(identifier); + } + + return identifier; + } +} diff --git a/traindb-core/src/main/java/traindb/engine/TableNameQualifier.java b/traindb-core/src/main/java/traindb/engine/TableNameQualifier.java new file mode 100644 index 0000000..f15043c --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/TableNameQualifier.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine; + +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlNode; +import traindb.schema.SchemaManager; + +public final class TableNameQualifier { + private TableNameQualifier() { + } + + // translate table name identifier to full qualified name + public static void toFullyQualifiedName(SchemaManager schemaManager, String defaultSchema, + SqlNode query) { + ArrayList tableIds = new ArrayList<>(); + query.accept(new SqlTableIdentifierFindVisitor(tableIds)); + + for (SqlIdentifier tableId : tableIds) { + List fqn = schemaManager.toFullyQualifiedTableName( + tableId.names, defaultSchema); + tableId.setNames(fqn, null); + } + } +} diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBExecContext.java b/traindb-core/src/main/java/traindb/engine/TrainDBExecContext.java index a396dea..45575b5 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBExecContext.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBExecContext.java @@ -74,7 +74,9 @@ public VerdictSingleResult sql(String query, boolean getResult) throws TrainDBEx // Pass input query to VerdictDB try { - // engine.processQuery(query); + if (System.getenv("ENABLE_TRAINDB_CALCITE") != null) { + return engine.processQuery(query); + } return executionContext.sql(query, getResult); } catch (Exception e) { throw new TrainDBException(e.getMessage()); diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java index aaee32f..9c2e985 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java @@ -21,7 +21,10 @@ import java.nio.file.Files; import java.nio.file.Path; import java.sql.Connection; +import java.sql.DriverManager; import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; import java.util.Arrays; @@ -29,6 +32,7 @@ import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.sql.SqlDialect; import org.apache.calcite.sql.SqlExplainFormat; import org.apache.calcite.sql.SqlExplainLevel; import org.apache.calcite.sql.SqlNode; @@ -36,12 +40,14 @@ import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Planner; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.json.simple.JSONObject; import org.verdictdb.VerdictSingleResult; import org.verdictdb.connection.CachedDbmsConnection; import org.verdictdb.connection.DbmsConnection; import org.verdictdb.connection.DbmsQueryResult; import org.verdictdb.connection.JdbcConnection; +import org.verdictdb.connection.JdbcQueryResult; import org.verdictdb.coordinator.VerdictSingleResultFromDbmsQueryResult; import traindb.catalog.CatalogContext; import traindb.catalog.CatalogException; @@ -379,13 +385,13 @@ public VerdictSingleResult showModels() throws Exception { } @Override - public VerdictSingleResult showModelInstances(String modelName) throws Exception { + public VerdictSingleResult showModelInstances() throws Exception { List header = Arrays.asList("model", "model_instance", "schema", "table", "columns"); List> modelInstanceInfo = new ArrayList<>(); - for (MModelInstance mModelInstance : catalogContext.getModelInstances(modelName)) { - modelInstanceInfo.add(Arrays.asList(modelName, mModelInstance.getName(), - mModelInstance.getSchemaName(), mModelInstance.getTableName(), + for (MModelInstance mModelInstance : catalogContext.getModelInstances()) { + modelInstanceInfo.add(Arrays.asList(mModelInstance.getModel().getName(), + mModelInstance.getName(), mModelInstance.getSchemaName(), mModelInstance.getTableName(), mModelInstance.getColumnNames().toString())); } @@ -417,13 +423,28 @@ public VerdictSingleResult processQuery(String query) throws Exception { .parserConfig(parserConf).build(); Planner planner = Frameworks.getPlanner(config); SqlNode parse = planner.parse(query); + TableNameQualifier.toFullyQualifiedName(schemaManager, conn.getDefaultSchema(), parse); LOG.debug("Parsed query: " + parse.toString()); + SqlNode validate = planner.validate(parse); RelRoot relRoot = planner.rel(validate); LOG.debug( RelOptUtil.dumpPlan("Generated plan: ", relRoot.rel, SqlExplainFormat.TEXT, SqlExplainLevel.ALL_ATTRIBUTES)); + SqlDialect.DatabaseProduct dp = SqlDialect.DatabaseProduct.POSTGRESQL; + String queryString = validate.toSqlString(dp.getDialect()).getSql(); + LOG.debug("query string: " + queryString); + + try { + Connection internalConn = DriverManager.getConnection("jdbc:traindb-calcite:"); + PreparedStatement stmt = internalConn.prepareStatement(queryString); + ResultSet rs = stmt.executeQuery(); + return new VerdictSingleResultFromDbmsQueryResult(new JdbcQueryResult(rs)); + } catch (SQLException e) { + LOG.debug(ExceptionUtils.getStackTrace(e)); + } + return null; } } diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalciteConnectionImpl.java b/traindb-core/src/main/java/traindb/engine/calcite/CalciteConnectionImpl.java new file mode 100644 index 0000000..88f9e30 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalciteConnectionImpl.java @@ -0,0 +1,681 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.lang.reflect.Type; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.calcite.DataContext; +import org.apache.calcite.DataContexts; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaFactory; +import org.apache.calcite.avatica.AvaticaSite; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Helper; +import org.apache.calcite.avatica.InternalProperty; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.avatica.NoSuchStatementException; +import org.apache.calcite.avatica.UnregisteredDriver; +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionConfigImpl; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalcitePrepare.Context; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.calcite.linq4j.BaseQueryable; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.linq4j.function.Function0; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.materialize.Lattice; +import org.apache.calcite.materialize.MaterializationService; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.DelegatingTypeSystem; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.SchemaVersion; +import org.apache.calcite.schema.Schemas; +import org.apache.calcite.schema.impl.AbstractSchema; +import org.apache.calcite.schema.impl.LongSchemaVersion; +import org.apache.calcite.server.CalciteServer; +import org.apache.calcite.server.CalciteServerStatement; +import org.apache.calcite.sql.advise.SqlAdvisor; +import org.apache.calcite.sql.advise.SqlAdvisorValidator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorWithHints; +import org.apache.calcite.tools.RelRunner; +import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.Holder; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Implementation of JDBC connection + * in the Calcite engine. + * + *

Abstract to allow newer versions of JDBC to add methods. + */ +abstract class CalciteConnectionImpl + extends AvaticaConnection + implements CalciteConnection, QueryProvider { + // must be package-protected + static final Trojan TROJAN = createTrojan(); + public final JavaTypeFactory typeFactory; + final Function0 prepareFactory; + final CalciteServer server = new CalciteServerImpl(); + CalciteSchema rootSchema; + + /** + * Creates a CalciteConnectionImpl. + * + *

Not public; method is called only from the driver. + * + * @param driver Driver + * @param factory Factory for JDBC objects + * @param url Server URL + * @param info Other connection properties + * @param rootSchema Root schema, or null + * @param typeFactory Type factory, or null + */ + protected CalciteConnectionImpl(Driver driver, AvaticaFactory factory, + String url, Properties info, @Nullable CalciteSchema rootSchema, + @Nullable JavaTypeFactory typeFactory) { + super(driver, factory, url, info); + CalciteConnectionConfig cfg = new CalciteConnectionConfigImpl(info); + this.prepareFactory = driver.prepareFactory; + if (typeFactory != null) { + this.typeFactory = typeFactory; + } else { + RelDataTypeSystem typeSystem = + cfg.typeSystem(RelDataTypeSystem.class, RelDataTypeSystem.DEFAULT); + if (cfg.conformance().shouldConvertRaggedUnionTypesToVarying()) { + typeSystem = + new DelegatingTypeSystem(typeSystem) { + @Override + public boolean + shouldConvertRaggedUnionTypesToVarying() { + return true; + } + }; + } + this.typeFactory = new JavaTypeFactoryImpl(typeSystem); + } + this.rootSchema = + requireNonNull(rootSchema != null + ? rootSchema + : CalciteSchema.createRootSchema(true)); + Preconditions.checkArgument(this.rootSchema.isRoot(), "must be root schema"); + this.properties.put(InternalProperty.CASE_SENSITIVE, cfg.caseSensitive()); + this.properties.put(InternalProperty.UNQUOTED_CASING, cfg.unquotedCasing()); + this.properties.put(InternalProperty.QUOTED_CASING, cfg.quotedCasing()); + this.properties.put(InternalProperty.QUOTING, cfg.quoting()); + } + + CalciteMetaImpl meta() { + return (CalciteMetaImpl) meta; + } + + @Override + public CalciteConnectionConfig config() { + return new CalciteConnectionConfigImpl(info); + } + + @Override + public Context createPrepareContext() { + return new ContextImpl(this); + } + + /** + * Called after the constructor has completed and the model has been + * loaded. + */ + void init() { + final MaterializationService service = MaterializationService.instance(); + for (CalciteSchema.LatticeEntry e : Schemas.getLatticeEntries(rootSchema)) { + final Lattice lattice = e.getLattice(); + for (Lattice.Tile tile : lattice.computeTiles()) { + service.defineTile(lattice, tile.bitSet(), tile.measures, e.schema, + true, true); + } + } + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface == RelRunner.class) { + return iface.cast(new RelRunner() { + @Override + public PreparedStatement prepareStatement(RelNode rel) + throws SQLException { + return prepareStatement_(CalcitePrepare.Query.of(rel), + ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + getHoldability()); + } + + @SuppressWarnings("deprecation") + @Override + public PreparedStatement prepare(RelNode rel) { + try { + return prepareStatement(rel); + } catch (SQLException e) { + throw Util.throwAsRuntime(e); + } + } + }); + } + return super.unwrap(iface); + } + + @Override + public CalciteStatement createStatement(int resultSetType, + int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return (CalciteStatement) super.createStatement(resultSetType, + resultSetConcurrency, resultSetHoldability); + } + + @Override + public CalcitePreparedStatement prepareStatement( + String sql, + int resultSetType, + int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + final CalcitePrepare.Query query = CalcitePrepare.Query.of(sql); + return prepareStatement_(query, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + private CalcitePreparedStatement prepareStatement_( + CalcitePrepare.Query query, + int resultSetType, + int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + try { + final Meta.Signature signature = + parseQuery(query, createPrepareContext(), -1); + final CalcitePreparedStatement calcitePreparedStatement = + (CalcitePreparedStatement) factory.newPreparedStatement(this, null, + signature, resultSetType, resultSetConcurrency, resultSetHoldability); + server.getStatement(calcitePreparedStatement.handle).setSignature(signature); + return calcitePreparedStatement; + } catch (Exception e) { + String message = query.rel == null + ? "Error while preparing statement [" + query.sql + "]" + : "Error while preparing plan [" + RelOptUtil.toString(query.rel) + "]"; + throw Helper.INSTANCE.createException(message, e); + } + } + + CalcitePrepare.CalciteSignature parseQuery( + CalcitePrepare.Query query, + CalcitePrepare.Context prepareContext, long maxRowCount) { + CalcitePrepare.Dummy.push(prepareContext); + try { + final CalcitePrepare prepare = prepareFactory.apply(); + return prepare.prepareSql(prepareContext, query, Object[].class, + maxRowCount); + } finally { + CalcitePrepare.Dummy.pop(prepareContext); + } + } + + @Override + public AtomicBoolean getCancelFlag(Meta.StatementHandle handle) + throws NoSuchStatementException { + final CalciteServerStatement serverStatement = server.getStatement(handle); + return ((CalciteServerStatementImpl) serverStatement).cancelFlag; + } + + // CalciteConnection methods + + @Override + public SchemaPlus getRootSchema() { + return rootSchema.plus(); + } + + public void setRootSchema(CalciteSchema rootSchema) { + this.rootSchema = rootSchema; + } + + @Override + public JavaTypeFactory getTypeFactory() { + return typeFactory; + } + + // QueryProvider methods + + @Override + public Properties getProperties() { + return info; + } + + @Override + public Queryable createQuery( + Expression expression, Class rowType) { + return new CalciteQueryable<>(this, rowType, expression); + } + + @Override + public Queryable createQuery(Expression expression, Type rowType) { + return new CalciteQueryable<>(this, rowType, expression); + } + + @Override + public T execute(Expression expression, Type type) { + return castNonNull(null); // TODO: + } + + @Override + public T execute(Expression expression, Class type) { + return castNonNull(null); // TODO: + } + + @Override + public Enumerator executeQuery(Queryable queryable) { + try { + CalciteStatement statement = (CalciteStatement) createStatement(); + CalcitePrepare.CalciteSignature signature = + statement.prepare(queryable); + return enumerable(statement.handle, signature).enumerator(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public Enumerable enumerable(Meta.StatementHandle handle, + CalcitePrepare.CalciteSignature signature) + throws SQLException { + Map map = new LinkedHashMap<>(); + AvaticaStatement statement = lookupStatement(handle); + final List parameterValues = + TROJAN.getParameterValues(statement); + + if (MetaImpl.checkParameterValueHasNull(parameterValues)) { + throw new SQLException("exception while executing query: unbound parameter"); + } + + Ord.forEach(parameterValues, + (e, i) -> map.put("?" + i, e.toLocal())); + map.putAll(signature.internalParameters); + final AtomicBoolean cancelFlag; + try { + cancelFlag = getCancelFlag(handle); + } catch (NoSuchStatementException e) { + throw new RuntimeException(e); + } + map.put(DataContext.Variable.CANCEL_FLAG.camelName, cancelFlag); + int queryTimeout = statement.getQueryTimeout(); + // Avoid overflow + if (queryTimeout > 0 && queryTimeout < Integer.MAX_VALUE / 1000) { + map.put(DataContext.Variable.TIMEOUT.camelName, queryTimeout * 1000L); + } + final DataContext dataContext = createDataContext(map, signature.rootSchema); + return signature.enumerable(dataContext); + } + + public DataContext createDataContext(Map parameterValues, + @Nullable CalciteSchema rootSchema) { + if (config().spark()) { + return DataContexts.EMPTY; + } + return new DataContextImpl(this, parameterValues, rootSchema); + } + + // do not make public + UnregisteredDriver getDriver() { + return driver; + } + + // do not make public + AvaticaFactory getFactory() { + return factory; + } + + /** + * Implementation of Queryable. + * + * @param element type + */ + static class CalciteQueryable extends BaseQueryable { + CalciteQueryable(CalciteConnection connection, Type elementType, + Expression expression) { + super(connection, elementType, expression); + } + + public CalciteConnection getConnection() { + return (CalciteConnection) provider; + } + } + + /** + * Implementation of Server. + */ + private static class CalciteServerImpl implements CalciteServer { + final Map statementMap = new HashMap<>(); + + @Override + public void removeStatement(Meta.StatementHandle h) { + statementMap.remove(h.id); + } + + @Override + public void addStatement(CalciteConnection connection, + Meta.StatementHandle h) { + final CalciteConnectionImpl c = (CalciteConnectionImpl) connection; + final CalciteServerStatement previous = + statementMap.put(h.id, new CalciteServerStatementImpl(c)); + if (previous != null) { + throw new AssertionError(); + } + } + + @Override + public CalciteServerStatement getStatement(Meta.StatementHandle h) + throws NoSuchStatementException { + CalciteServerStatement statement = statementMap.get(h.id); + if (statement == null) { + throw new NoSuchStatementException(h); + } + return statement; + } + } + + /** + * Schema that has no parents. + */ + static class RootSchema extends AbstractSchema { + RootSchema() { + super(); + } + + @Override + public Expression getExpression(@Nullable SchemaPlus parentSchema, + String name) { + return Expressions.call( + DataContext.ROOT, + BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method); + } + } + + /** + * Implementation of DataContext. + */ + static class DataContextImpl implements DataContext { + private final ImmutableMap map; + private final @Nullable CalciteSchema rootSchema; + private final QueryProvider queryProvider; + private final JavaTypeFactory typeFactory; + + DataContextImpl(CalciteConnectionImpl connection, + Map parameters, @Nullable CalciteSchema rootSchema) { + this.queryProvider = connection; + this.typeFactory = connection.getTypeFactory(); + this.rootSchema = rootSchema; + + // Store the time at which the query started executing. The SQL + // standard says that functions such as CURRENT_TIMESTAMP return the + // same value throughout the query. + final Holder timeHolder = Holder.of(System.currentTimeMillis()); + + // Give a hook chance to alter the clock. + Hook.CURRENT_TIME.run(timeHolder); + final long time = timeHolder.get(); + final TimeZone timeZone = connection.getTimeZone(); + final long localOffset = timeZone.getOffset(time); + final long currentOffset = localOffset; + final String user = "sa"; + final String systemUser = System.getProperty("user.name"); + final String localeName = connection.config().locale(); + final Locale locale = localeName != null + ? Util.parseLocale(localeName) : Locale.ROOT; + + // Give a hook chance to alter standard input, output, error streams. + final Holder streamHolder = + Holder.of(new Object[] {System.in, System.out, System.err}); + Hook.STANDARD_STREAMS.run(streamHolder); + + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put(Variable.UTC_TIMESTAMP.camelName, time) + .put(Variable.CURRENT_TIMESTAMP.camelName, time + currentOffset) + .put(Variable.LOCAL_TIMESTAMP.camelName, time + localOffset) + .put(Variable.TIME_ZONE.camelName, timeZone) + .put(Variable.USER.camelName, user) + .put(Variable.SYSTEM_USER.camelName, systemUser) + .put(Variable.LOCALE.camelName, locale) + .put(Variable.STDIN.camelName, streamHolder.get()[0]) + .put(Variable.STDOUT.camelName, streamHolder.get()[1]) + .put(Variable.STDERR.camelName, streamHolder.get()[2]); + for (Map.Entry entry : parameters.entrySet()) { + Object e = entry.getValue(); + if (e == null) { + e = AvaticaSite.DUMMY_VALUE; + } + builder.put(entry.getKey(), e); + } + map = builder.build(); + } + + @Override + public synchronized @Nullable Object get(String name) { + Object o = map.get(name); + if (o == AvaticaSite.DUMMY_VALUE) { + return null; + } + if (o == null && Variable.SQL_ADVISOR.camelName.equals(name)) { + return getSqlAdvisor(); + } + return o; + } + + private SqlAdvisor getSqlAdvisor() { + final CalciteConnectionImpl con = (CalciteConnectionImpl) queryProvider; + final String schemaName; + try { + schemaName = con.getSchema(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + final List schemaPath = + schemaName == null + ? ImmutableList.of() + : ImmutableList.of(schemaName); + final SqlValidatorWithHints validator = + new SqlAdvisorValidator(SqlStdOperatorTable.instance(), + new CalciteCatalogReader(requireNonNull(rootSchema, "rootSchema"), + schemaPath, typeFactory, con.config()), + typeFactory, SqlValidator.Config.DEFAULT); + final CalciteConnectionConfig config = con.config(); + // This duplicates org.apache.calcite.prepare.CalcitePrepareImpl.prepare2_ + final SqlParser.Config parserConfig = SqlParser.config() + .withQuotedCasing(config.quotedCasing()) + .withUnquotedCasing(config.unquotedCasing()) + .withQuoting(config.quoting()) + .withConformance(config.conformance()) + .withCaseSensitive(config.caseSensitive()); + return new SqlAdvisor(validator, parserConfig); + } + + @Override + public @Nullable SchemaPlus getRootSchema() { + return rootSchema == null ? null : rootSchema.plus(); + } + + @Override + public JavaTypeFactory getTypeFactory() { + return typeFactory; + } + + @Override + public QueryProvider getQueryProvider() { + return queryProvider; + } + } + + /** + * Implementation of Context. + */ + static class ContextImpl implements CalcitePrepare.Context { + private final CalciteConnectionImpl connection; + private final CalciteSchema mutableRootSchema; + private final CalciteSchema rootSchema; + + ContextImpl(CalciteConnectionImpl connection) { + this.connection = requireNonNull(connection, "connection"); + long now = System.currentTimeMillis(); + SchemaVersion schemaVersion = new LongSchemaVersion(now); + this.mutableRootSchema = connection.rootSchema; + this.rootSchema = mutableRootSchema.createSnapshot(schemaVersion); + } + + @Override + public JavaTypeFactory getTypeFactory() { + return connection.typeFactory; + } + + @Override + public CalciteSchema getRootSchema() { + return rootSchema; + } + + @Override + public CalciteSchema getMutableRootSchema() { + return mutableRootSchema; + } + + @Override + public List getDefaultSchemaPath() { + final String schemaName; + try { + schemaName = connection.getSchema(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + return schemaName == null + ? ImmutableList.of() + : ImmutableList.of(schemaName); + } + + @Override + public @Nullable List getObjectPath() { + return null; + } + + @Override + public CalciteConnectionConfig config() { + return connection.config(); + } + + @Override + public DataContext getDataContext() { + return connection.createDataContext(ImmutableMap.of(), + rootSchema); + } + + @Override + public RelRunner getRelRunner() { + final RelRunner runner; + try { + runner = connection.unwrap(RelRunner.class); + } catch (SQLException e) { + throw new RuntimeException(e); + } + if (runner == null) { + throw new UnsupportedOperationException(); + } + return runner; + } + + @Override + public CalcitePrepare.SparkHandler spark() { + final boolean enable = config().spark(); + return CalcitePrepare.Dummy.getSparkHandler(enable); + } + } + + /** + * Implementation of {@link CalciteServerStatement}. + */ + static class CalciteServerStatementImpl + implements CalciteServerStatement { + private final CalciteConnectionImpl connection; + private final AtomicBoolean cancelFlag = new AtomicBoolean(); + private @Nullable Iterator iterator; + private Meta.@Nullable Signature signature; + + CalciteServerStatementImpl(CalciteConnectionImpl connection) { + this.connection = requireNonNull(connection, "connection"); + } + + @Override + public Context createPrepareContext() { + return connection.createPrepareContext(); + } + + @Override + public CalciteConnection getConnection() { + return connection; + } + + @Override + public Meta.@Nullable Signature getSignature() { + return signature; + } + + @Override + public void setSignature(Meta.Signature signature) { + this.signature = signature; + } + + @Override + public @Nullable Iterator getResultSet() { + return iterator; + } + + @Override + public void setResultSet(Iterator iterator) { + this.iterator = iterator; + } + } + +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalciteJdbc41Factory.java b/traindb-core/src/main/java/traindb/engine/calcite/CalciteJdbc41Factory.java new file mode 100644 index 0000000..a10a7da --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalciteJdbc41Factory.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import java.io.InputStream; +import java.io.Reader; +import java.sql.NClob; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.util.Properties; +import java.util.TimeZone; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaDatabaseMetaData; +import org.apache.calcite.avatica.AvaticaFactory; +import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.UnregisteredDriver; +import org.apache.calcite.jdbc.CalciteFactory; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalciteSchema; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Implementation of {@link org.apache.calcite.avatica.AvaticaFactory} + * for Calcite and JDBC 4.1 (corresponds to JDK 1.7). + */ +@SuppressWarnings("UnusedDeclaration") +public class CalciteJdbc41Factory extends CalciteFactory { + /** + * Creates a factory for JDBC version 4.1. + */ + public CalciteJdbc41Factory() { + this(4, 1); + } + + /** + * Creates a JDBC factory with given major/minor version number. + */ + protected CalciteJdbc41Factory(int major, int minor) { + super(major, minor); + } + + @Override + public CalciteJdbc41Connection newConnection( + UnregisteredDriver driver, AvaticaFactory factory, String url, Properties info, + @Nullable CalciteSchema rootSchema, @Nullable JavaTypeFactory typeFactory) { + return new CalciteJdbc41Connection( + (Driver) driver, factory, url, info, rootSchema, typeFactory); + } + + @Override + public CalciteJdbc41DatabaseMetaData newDatabaseMetaData( + AvaticaConnection connection) { + return new CalciteJdbc41DatabaseMetaData( + (CalciteConnectionImpl) connection); + } + + @Override + public CalciteJdbc41Statement newStatement(AvaticaConnection connection, + Meta.@Nullable StatementHandle h, + int resultSetType, + int resultSetConcurrency, + int resultSetHoldability) { + return new CalciteJdbc41Statement( + (CalciteConnectionImpl) connection, + h, + resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public AvaticaPreparedStatement newPreparedStatement( + AvaticaConnection connection, + Meta.@Nullable StatementHandle h, + Meta.Signature signature, + int resultSetType, + int resultSetConcurrency, + int resultSetHoldability) throws SQLException { + return new CalciteJdbc41PreparedStatement( + (CalciteConnectionImpl) connection, h, + (CalcitePrepare.CalciteSignature) signature, resultSetType, + resultSetConcurrency, resultSetHoldability); + } + + @Override + public CalciteResultSet newResultSet( + AvaticaStatement statement, QueryState state, Meta.Signature signature, TimeZone timeZone, + Meta.Frame firstFrame) throws SQLException { + final ResultSetMetaData metaData = + newResultSetMetaData(statement, signature); + final CalcitePrepare.CalciteSignature calciteSignature = + (CalcitePrepare.CalciteSignature) signature; + return new CalciteResultSet(statement, calciteSignature, metaData, timeZone, + firstFrame); + } + + @Override + public ResultSetMetaData newResultSetMetaData(AvaticaStatement statement, + Meta.Signature signature) { + return new AvaticaResultSetMetaData(statement, null, signature); + } + + /** + * Implementation of connection for JDBC 4.1. + */ + private static class CalciteJdbc41Connection extends CalciteConnectionImpl { + CalciteJdbc41Connection(Driver driver, AvaticaFactory factory, String url, + Properties info, @Nullable CalciteSchema rootSchema, + @Nullable JavaTypeFactory typeFactory) { + super(driver, factory, url, info, rootSchema, typeFactory); + } + } + + /** + * Implementation of statement for JDBC 4.1. + */ + private static class CalciteJdbc41Statement extends CalciteStatement { + CalciteJdbc41Statement(CalciteConnectionImpl connection, + Meta.@Nullable StatementHandle h, int resultSetType, + int resultSetConcurrency, + int resultSetHoldability) { + super(connection, h, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + } + + /** + * Implementation of prepared statement for JDBC 4.1. + */ + private static class CalciteJdbc41PreparedStatement + extends CalcitePreparedStatement { + CalciteJdbc41PreparedStatement(CalciteConnectionImpl connection, + Meta.@Nullable StatementHandle h, + CalcitePrepare.CalciteSignature signature, + int resultSetType, int resultSetConcurrency, + int resultSetHoldability) + throws SQLException { + super(connection, h, signature, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public void setRowId( + int parameterIndex, + @Nullable RowId x) throws SQLException { + getSite(parameterIndex).setRowId(x); + } + + @Override + public void setNString( + int parameterIndex, @Nullable String value) throws SQLException { + getSite(parameterIndex).setNString(value); + } + + @Override + public void setNCharacterStream( + int parameterIndex, + @Nullable Reader value, + long length) throws SQLException { + getSite(parameterIndex) + .setNCharacterStream(value, length); + } + + @Override + public void setNClob( + int parameterIndex, + @Nullable NClob value) throws SQLException { + getSite(parameterIndex).setNClob(value); + } + + @Override + public void setClob( + int parameterIndex, + @Nullable Reader reader, + long length) throws SQLException { + getSite(parameterIndex) + .setClob(reader, length); + } + + @Override + public void setBlob( + int parameterIndex, + @Nullable InputStream inputStream, + long length) throws SQLException { + getSite(parameterIndex) + .setBlob(inputStream, length); + } + + @Override + public void setNClob( + int parameterIndex, + @Nullable Reader reader, + long length) throws SQLException { + getSite(parameterIndex).setNClob(reader, length); + } + + @Override + public void setSQLXML( + int parameterIndex, @Nullable SQLXML xmlObject) throws SQLException { + getSite(parameterIndex).setSQLXML(xmlObject); + } + + @Override + public void setAsciiStream( + int parameterIndex, + @Nullable InputStream x, + long length) throws SQLException { + getSite(parameterIndex) + .setAsciiStream(x, length); + } + + @Override + public void setBinaryStream( + int parameterIndex, + @Nullable InputStream x, + long length) throws SQLException { + getSite(parameterIndex) + .setBinaryStream(x, length); + } + + @Override + public void setCharacterStream( + int parameterIndex, + @Nullable Reader reader, + long length) throws SQLException { + getSite(parameterIndex) + .setCharacterStream(reader, length); + } + + @Override + public void setAsciiStream( + int parameterIndex, @Nullable InputStream x) throws SQLException { + getSite(parameterIndex).setAsciiStream(x); + } + + @Override + public void setBinaryStream( + int parameterIndex, @Nullable InputStream x) throws SQLException { + getSite(parameterIndex).setBinaryStream(x); + } + + @Override + public void setCharacterStream( + int parameterIndex, @Nullable Reader reader) throws SQLException { + getSite(parameterIndex) + .setCharacterStream(reader); + } + + @Override + public void setNCharacterStream( + int parameterIndex, @Nullable Reader value) throws SQLException { + getSite(parameterIndex) + .setNCharacterStream(value); + } + + @Override + public void setClob( + int parameterIndex, + @Nullable Reader reader) throws SQLException { + getSite(parameterIndex).setClob(reader); + } + + @Override + public void setBlob( + int parameterIndex, @Nullable InputStream inputStream) throws SQLException { + getSite(parameterIndex) + .setBlob(inputStream); + } + + @Override + public void setNClob( + int parameterIndex, @Nullable Reader reader) throws SQLException { + getSite(parameterIndex) + .setNClob(reader); + } + } + + /** + * Implementation of database metadata for JDBC 4.1. + */ + private static class CalciteJdbc41DatabaseMetaData + extends AvaticaDatabaseMetaData { + CalciteJdbc41DatabaseMetaData(CalciteConnectionImpl connection) { + super(connection); + } + } +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalciteMetaImpl.java b/traindb-core/src/main/java/traindb/engine/calcite/CalciteMetaImpl.java new file mode 100644 index 0000000..d885331 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalciteMetaImpl.java @@ -0,0 +1,955 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import com.google.common.primitives.Longs; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import org.apache.calcite.DataContext; +import org.apache.calcite.adapter.java.AbstractQueryableTable; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.AvaticaUtils; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.avatica.NoSuchStatementException; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalcitePrepare.Context; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.linq4j.Linq4j; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.linq4j.function.Functions; +import org.apache.calcite.linq4j.function.Predicate1; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFactoryImpl; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.runtime.FlatLists; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.AbstractTableQueryable; +import org.apache.calcite.schema.impl.MaterializedViewTable; +import org.apache.calcite.server.CalciteServerStatement; +import org.apache.calcite.sql.SqlJdbcFunctionCall; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.FrameworkConfig; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.util.Holder; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Helper for implementing the {@code getXxx} methods such as + * {@link org.apache.calcite.avatica.AvaticaDatabaseMetaData#getTables}. + */ +public class CalciteMetaImpl extends MetaImpl { + static final traindb.engine.calcite.Driver DRIVER = new Driver(); + + public CalciteMetaImpl(CalciteConnectionImpl connection) { + super(connection); + this.connProps + .setAutoCommit(false) + .setReadOnly(false) + .setTransactionIsolation(Connection.TRANSACTION_NONE); + this.connProps.setDirty(false); + } + + static Predicate1 namedMatcher(final Pat pattern) { + if (pattern.s == null || pattern.s.equals("%")) { + return Functions.truePredicate1(); + } + final Pattern regex = likeToRegex(pattern); + return v1 -> regex.matcher(v1.getName()).matches(); + } + + static Predicate1 matcher(final Pat pattern) { + if (pattern.s == null || pattern.s.equals("%")) { + return Functions.truePredicate1(); + } + final Pattern regex = likeToRegex(pattern); + return v1 -> regex.matcher(v1).matches(); + } + + /** + * Converts a LIKE-style pattern (where '%' represents a wild-card, escaped + * using '\') to a Java regex. + */ + public static Pattern likeToRegex(Pat pattern) { + StringBuilder buf = new StringBuilder("^"); + char[] charArray = pattern.s.toCharArray(); + int slash = -2; + for (int i = 0; i < charArray.length; i++) { + char c = charArray[i]; + if (slash == i - 1) { + buf.append('[').append(c).append(']'); + } else { + switch (c) { + case '\\': + slash = i; + break; + case '%': + buf.append(".*"); + break; + case '[': + buf.append("\\["); + break; + case ']': + buf.append("\\]"); + break; + default: + buf.append('[').append(c).append(']'); + } + } + } + buf.append("$"); + return Pattern.compile(buf.toString()); + } + + private static ImmutableMap.Builder addProperty( + ImmutableMap.Builder builder, + DatabaseProperty p) { + switch (p) { + case GET_S_Q_L_KEYWORDS: + return builder.put(p, + SqlParser.create("").getMetadata().getJdbcKeywords()); + case GET_NUMERIC_FUNCTIONS: + return builder.put(p, SqlJdbcFunctionCall.getNumericFunctions()); + case GET_STRING_FUNCTIONS: + return builder.put(p, SqlJdbcFunctionCall.getStringFunctions()); + case GET_SYSTEM_FUNCTIONS: + return builder.put(p, SqlJdbcFunctionCall.getSystemFunctions()); + case GET_TIME_DATE_FUNCTIONS: + return builder.put(p, SqlJdbcFunctionCall.getTimeDateFunctions()); + default: + return builder; + } + } + + /** + * Wraps the SQL string in a + * {@link org.apache.calcite.jdbc.CalcitePrepare.Query} object, giving the + * {@link Hook#STRING_TO_QUERY} hook chance to override. + */ + private static CalcitePrepare.Query toQuery( + Context context, String sql) { + final Holder> queryHolder = + Holder.of(CalcitePrepare.Query.of(sql)); + final FrameworkConfig config = Frameworks.newConfigBuilder() + .parserConfig(SqlParser.Config.DEFAULT) + .defaultSchema(context.getRootSchema().plus()) + .build(); + Hook.STRING_TO_QUERY.run(Pair.of(config, queryHolder)); + return queryHolder.get(); + } + + /** + * A trojan-horse method, subject to change without notice. + */ + @VisibleForTesting + public static DataContext createDataContext(CalciteConnection connection) { + return ((CalciteConnectionImpl) connection) + .createDataContext(ImmutableMap.of(), + CalciteSchema.from(connection.getRootSchema())); + } + + /** + * A trojan-horse method, subject to change without notice. + */ + @VisibleForTesting + public static CalciteConnection connect(CalciteSchema schema, + @Nullable JavaTypeFactory typeFactory) { + return DRIVER.connect(schema, typeFactory); + } + + @Override + public StatementHandle createStatement(ConnectionHandle ch) { + final StatementHandle h = super.createStatement(ch); + final CalciteConnectionImpl calciteConnection = getConnection(); + calciteConnection.server.addStatement(calciteConnection, h); + return h; + } + + @Override + public void closeStatement(StatementHandle h) { + final CalciteConnectionImpl calciteConnection = getConnection(); + @SuppressWarnings("unused") final CalciteServerStatement stmt; + try { + stmt = calciteConnection.server.getStatement(h); + } catch (NoSuchStatementException e) { + // statement is not valid; nothing to do + return; + } + // stmt.close(); // TODO: implement + calciteConnection.server.removeStatement(h); + } + + private MetaResultSet createResultSet(Enumerable enumerable, + Class clazz, String... names) { + requireNonNull(names, "names"); + final List columns = new ArrayList<>(names.length); + final List fields = new ArrayList<>(names.length); + final List fieldNames = new ArrayList<>(names.length); + for (String name : names) { + final int index = fields.size(); + final String fieldName = AvaticaUtils.toCamelCase(name); + final Field field; + try { + field = clazz.getField(fieldName); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + columns.add(columnMetaData(name, index, field.getType(), false)); + fields.add(field); + fieldNames.add(fieldName); + } + //noinspection unchecked + final Iterable iterable = (Iterable) (Iterable) enumerable; + return createResultSet(Collections.emptyMap(), + columns, CursorFactory.record(clazz, fields, fieldNames), + new Frame(0, true, iterable)); + } + + @Override + protected MetaResultSet createResultSet( + Map internalParameters, List columns, + CursorFactory cursorFactory, final Frame firstFrame) { + try { + final CalciteConnectionImpl connection = getConnection(); + final AvaticaStatement statement = connection.createStatement(); + final CalcitePrepare.CalciteSignature signature = + new CalcitePrepare.CalciteSignature("", + ImmutableList.of(), internalParameters, null, + columns, cursorFactory, null, ImmutableList.of(), -1, + null, Meta.StatementType.SELECT) { + @Override + public Enumerable enumerable( + DataContext dataContext) { + return Linq4j.asEnumerable(firstFrame.rows); + } + }; + return MetaResultSet.create(connection.id, statement.getId(), true, + signature, firstFrame); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + CalciteConnectionImpl getConnection() { + return (CalciteConnectionImpl) connection; + } + + @Override + public Map getDatabaseProperties(ConnectionHandle ch) { + final ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (DatabaseProperty p : DatabaseProperty.values()) { + addProperty(builder, p); + } + return builder.build(); + } + + @Override + public MetaResultSet getTables(ConnectionHandle ch, + String catalog, + final Pat schemaPattern, + final Pat tableNamePattern, + final List typeList) { + final Predicate1 typeFilter; + if (typeList == null) { + typeFilter = Functions.truePredicate1(); + } else { + typeFilter = v1 -> typeList.contains(v1.tableType); + } + final Predicate1 schemaMatcher = namedMatcher(schemaPattern); + return createResultSet(schemas(catalog) + .where(schemaMatcher) + .selectMany(schema -> tables(schema, matcher(tableNamePattern))) + .where(typeFilter), + MetaTable.class, + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "TABLE_TYPE", + "REMARKS", + "TYPE_CAT", + "TYPE_SCHEM", + "TYPE_NAME", + "SELF_REFERENCING_COL_NAME", + "REF_GENERATION"); + } + + @Override + public MetaResultSet getTypeInfo(ConnectionHandle ch) { + return createResultSet(allTypeInfo(), + MetaTypeInfo.class, + "TYPE_NAME", + "DATA_TYPE", + "PRECISION", + "LITERAL_PREFIX", + "LITERAL_SUFFIX", + "CREATE_PARAMS", + "NULLABLE", + "CASE_SENSITIVE", + "SEARCHABLE", + "UNSIGNED_ATTRIBUTE", + "FIXED_PREC_SCALE", + "AUTO_INCREMENT", + "LOCAL_TYPE_NAME", + "MINIMUM_SCALE", + "MAXIMUM_SCALE", + "SQL_DATA_TYPE", + "SQL_DATETIME_SUB", + "NUM_PREC_RADIX"); + } + + @Override + public MetaResultSet getColumns(ConnectionHandle ch, + String catalog, + Pat schemaPattern, + Pat tableNamePattern, + Pat columnNamePattern) { + final Predicate1 tableNameMatcher = matcher(tableNamePattern); + final Predicate1 schemaMatcher = namedMatcher(schemaPattern); + final Predicate1 columnMatcher = + namedMatcher(columnNamePattern); + return createResultSet(schemas(catalog) + .where(schemaMatcher) + .selectMany(schema -> tables(schema, tableNameMatcher)) + .selectMany(this::columns) + .where(columnMatcher), + MetaColumn.class, + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "COLUMN_NAME", + "DATA_TYPE", + "TYPE_NAME", + "COLUMN_SIZE", + "BUFFER_LENGTH", + "DECIMAL_DIGITS", + "NUM_PREC_RADIX", + "NULLABLE", + "REMARKS", + "COLUMN_DEF", + "SQL_DATA_TYPE", + "SQL_DATETIME_SUB", + "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", + "IS_NULLABLE", + "SCOPE_CATALOG", + "SCOPE_SCHEMA", + "SCOPE_TABLE", + "SOURCE_DATA_TYPE", + "IS_AUTOINCREMENT", + "IS_GENERATEDCOLUMN"); + } + + Enumerable catalogs() { + final String catalog; + try { + catalog = connection.getCatalog(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + return Linq4j.asEnumerable( + ImmutableList.of(new MetaCatalog(catalog))); + } + + Enumerable tableTypes() { + return Linq4j.asEnumerable( + ImmutableList.of( + new MetaTableType("TABLE"), new MetaTableType("VIEW"))); + } + + Enumerable schemas(final String catalog) { + return Linq4j.asEnumerable( + getConnection().rootSchema.getSubSchemaMap().values()) + .select((Function1) calciteSchema -> + new CalciteMetaSchema(calciteSchema, catalog, + calciteSchema.getName())) + .orderBy((Function1) metaSchema -> + (Comparable) FlatLists.of(Util.first(metaSchema.tableCatalog, ""), + metaSchema.tableSchem)); + } + + Enumerable tables(String catalog) { + return schemas(catalog) + .selectMany(schema -> + tables(schema, Functions.truePredicate1())); + } + + Enumerable tables(final MetaSchema schema_) { + final CalciteMetaSchema schema = (CalciteMetaSchema) schema_; + return Linq4j.asEnumerable(schema.calciteSchema.getTableNames()) + .select((Function1) name -> { + final Table table = requireNonNull( + schema.calciteSchema.getTable(name, true), + () -> "table " + name + " is not found (case sensitive)").getTable(); + return new CalciteMetaTable(table, + schema.tableCatalog, + schema.tableSchem, + name); + }) + .concat( + Linq4j.asEnumerable( + schema.calciteSchema.getTablesBasedOnNullaryFunctions() + .entrySet()) + .select(pair -> { + final Table table = pair.getValue(); + return new CalciteMetaTable(table, + schema.tableCatalog, + schema.tableSchem, + pair.getKey()); + })); + } + + Enumerable tables( + final MetaSchema schema, + final Predicate1 matcher) { + return tables(schema) + .where(v1 -> matcher.apply(v1.getName())); + } + + private ImmutableList getAllDefaultType() { + final ImmutableList.Builder allTypeList = + ImmutableList.builder(); + final CalciteConnectionImpl conn = (CalciteConnectionImpl) connection; + final RelDataTypeSystem typeSystem = conn.typeFactory.getTypeSystem(); + for (SqlTypeName sqlTypeName : SqlTypeName.values()) { + if (sqlTypeName.isSpecial()) { + // Skip internal types (NULL, ANY, SYMBOL, SARG). + continue; + } + allTypeList.add( + new MetaTypeInfo(sqlTypeName.getName(), + sqlTypeName.getJdbcOrdinal(), + typeSystem.getMaxPrecision(sqlTypeName), + typeSystem.getLiteral(sqlTypeName, true), + typeSystem.getLiteral(sqlTypeName, false), + // All types are nullable + (short) DatabaseMetaData.typeNullable, + typeSystem.isCaseSensitive(sqlTypeName), + // Making all type searchable; we may want to + // be specific and declare under SqlTypeName + (short) DatabaseMetaData.typeSearchable, + false, + false, + typeSystem.isAutoincrement(sqlTypeName), + (short) sqlTypeName.getMinScale(), + (short) typeSystem.getMaxScale(sqlTypeName), + typeSystem.getNumTypeRadix(sqlTypeName))); + } + return allTypeList.build(); + } + + protected Enumerable allTypeInfo() { + return Linq4j.asEnumerable(getAllDefaultType()); + } + + public Enumerable columns(final MetaTable table_) { + final CalciteMetaTable table = (CalciteMetaTable) table_; + final RelDataType rowType = + table.calciteTable.getRowType(getConnection().typeFactory); + return Linq4j.asEnumerable(rowType.getFieldList()) + .select(field -> { + final int precision = + field.getType().getSqlTypeName().allowsPrec() + && !(field.getType() + instanceof RelDataTypeFactoryImpl.JavaType) + ? field.getType().getPrecision() + : -1; + return new MetaColumn( + table.tableCat, + table.tableSchem, + table.tableName, + field.getName(), + field.getType().getSqlTypeName().getJdbcOrdinal(), + field.getType().getFullTypeString(), + precision, + field.getType().getSqlTypeName().allowsScale() + ? field.getType().getScale() + : null, + 10, + field.getType().isNullable() + ? DatabaseMetaData.columnNullable + : DatabaseMetaData.columnNoNulls, + precision, + field.getIndex() + 1, + field.getType().isNullable() ? "YES" : "NO"); + }); + } + + @Override + public MetaResultSet getSchemas(ConnectionHandle ch, String catalog, + Pat schemaPattern) { + final Predicate1 schemaMatcher = namedMatcher(schemaPattern); + return createResultSet(schemas(catalog).where(schemaMatcher), + MetaSchema.class, + "TABLE_SCHEM", + "TABLE_CATALOG"); + } + + @Override + public MetaResultSet getCatalogs(ConnectionHandle ch) { + return createResultSet(catalogs(), + MetaCatalog.class, + "TABLE_CAT"); + } + + @Override + public MetaResultSet getTableTypes(ConnectionHandle ch) { + return createResultSet(tableTypes(), + MetaTableType.class, + "TABLE_TYPE"); + } + + @Override + public MetaResultSet getFunctions(ConnectionHandle ch, + String catalog, + Pat schemaPattern, + Pat functionNamePattern) { + final Predicate1 schemaMatcher = namedMatcher(schemaPattern); + return createResultSet(schemas(catalog) + .where(schemaMatcher) + .selectMany(schema -> functions(schema, catalog, matcher(functionNamePattern))) + .orderBy(x -> + (Comparable) FlatLists.of( + x.functionCat, x.functionSchem, x.functionName, x.specificName + )), + MetaFunction.class, + "FUNCTION_CAT", + "FUNCTION_SCHEM", + "FUNCTION_NAME", + "REMARKS", + "FUNCTION_TYPE", + "SPECIFIC_NAME"); + } + + Enumerable functions(final MetaSchema schema_, final String catalog) { + final CalciteMetaSchema schema = (CalciteMetaSchema) schema_; + Enumerable opTableFunctions = Linq4j.emptyEnumerable(); + if (schema.calciteSchema.schema.equals(MetadataSchema.INSTANCE)) { + SqlOperatorTable opTable = getConnection().config() + .fun(SqlOperatorTable.class, SqlStdOperatorTable.instance()); + List q = opTable.getOperatorList(); + opTableFunctions = Linq4j.asEnumerable(q) + .where(op -> SqlKind.FUNCTION.contains(op.getKind())) + .select(op -> + new MetaFunction( + catalog, + schema.getName(), + op.getName(), + (short) DatabaseMetaData.functionResultUnknown, + op.getName() + ) + ); + } + return Linq4j.asEnumerable(schema.calciteSchema.getFunctionNames()) + .selectMany(name -> + Linq4j.asEnumerable(schema.calciteSchema.getFunctions(name, true)) + //exclude materialized views from the result set + .where(fn -> !(fn instanceof MaterializedViewTable.MaterializedViewTableMacro)) + .select(fnx -> + new MetaFunction( + catalog, + schema.getName(), + name, + (short) DatabaseMetaData.functionResultUnknown, + name + ) + ) + ) + .concat(opTableFunctions); + } + + Enumerable functions(final MetaSchema schema, final String catalog, + final Predicate1 functionNameMatcher) { + return functions(schema, catalog) + .where(v1 -> functionNameMatcher.apply(v1.functionName)); + } + + @Override + public Iterable createIterable(StatementHandle handle, QueryState state, + Signature signature, + @Nullable List parameterValues, + @Nullable Frame firstFrame) { + // Drop QueryState + return _createIterable(handle, signature, parameterValues, firstFrame); + } + + Iterable _createIterable(StatementHandle handle, + Signature signature, @Nullable List parameterValues, + @Nullable Frame firstFrame) { + try { + //noinspection unchecked + final CalcitePrepare.CalciteSignature calciteSignature = + (CalcitePrepare.CalciteSignature) signature; + return getConnection().enumerable(handle, calciteSignature); + } catch (SQLException e) { + throw new RuntimeException(e.getMessage()); + } + } + + @Override + public StatementHandle prepare(ConnectionHandle ch, String sql, + long maxRowCount) { + final StatementHandle h = createStatement(ch); + final CalciteConnectionImpl calciteConnection = getConnection(); + + final CalciteServerStatement statement; + try { + statement = calciteConnection.server.getStatement(h); + } catch (NoSuchStatementException e) { + // Not possible. We just created a statement. + throw new AssertionError("missing statement", e); + } + final Context context = statement.createPrepareContext(); + final CalcitePrepare.Query query = toQuery(context, sql); + h.signature = calciteConnection.parseQuery(query, context, maxRowCount); + statement.setSignature(h.signature); + return h; + } + + @SuppressWarnings("deprecation") + @Override + public ExecuteResult prepareAndExecute(StatementHandle h, + String sql, long maxRowCount, PrepareCallback callback) + throws NoSuchStatementException { + return prepareAndExecute(h, sql, maxRowCount, -1, callback); + } + + @Override + public ExecuteResult prepareAndExecute(StatementHandle h, + String sql, long maxRowCount, int maxRowsInFirstFrame, + PrepareCallback callback) throws NoSuchStatementException { + final CalcitePrepare.CalciteSignature signature; + try { + final int updateCount; + synchronized (callback.getMonitor()) { + callback.clear(); + final CalciteConnectionImpl calciteConnection = getConnection(); + final CalciteServerStatement statement = + calciteConnection.server.getStatement(h); + final Context context = statement.createPrepareContext(); + final CalcitePrepare.Query query = toQuery(context, sql); + signature = calciteConnection.parseQuery(query, context, maxRowCount); + statement.setSignature(signature); + switch (signature.statementType) { + case CREATE: + case DROP: + case ALTER: + case OTHER_DDL: + updateCount = 0; // DDL produces no result set + break; + default: + updateCount = -1; // SELECT and DML produces result set + break; + } + callback.assign(signature, null, updateCount); + } + callback.execute(); + final MetaResultSet metaResultSet = + MetaResultSet.create(h.connectionId, h.id, false, signature, null, updateCount); + return new ExecuteResult(ImmutableList.of(metaResultSet)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + // TODO: share code with prepare and createIterable + } + + @Override + public Frame fetch(StatementHandle h, long offset, + int fetchMaxRowCount) throws NoSuchStatementException { + final CalciteConnectionImpl calciteConnection = getConnection(); + CalciteServerStatement stmt = calciteConnection.server.getStatement(h); + final Signature signature = requireNonNull(stmt.getSignature(), + () -> "stmt.getSignature() is null for " + stmt); + final Iterator iterator; + Iterator stmtResultSet = stmt.getResultSet(); + if (stmtResultSet == null) { + final Iterable iterable = + _createIterable(h, signature, null, null); + iterator = iterable.iterator(); + stmt.setResultSet(iterator); + } else { + iterator = stmtResultSet; + } + final List rows = + MetaImpl.collect(signature.cursorFactory, + LimitIterator.of(iterator, fetchMaxRowCount), + new ArrayList>()); + boolean done = fetchMaxRowCount == 0 || rows.size() < fetchMaxRowCount; + @SuppressWarnings("unchecked") List rows1 = (List) rows; + return new Meta.Frame(offset, done, rows1); + } + + @SuppressWarnings("deprecation") + @Override + public ExecuteResult execute(StatementHandle h, + List parameterValues, long maxRowCount) + throws NoSuchStatementException { + return execute(h, parameterValues, Ints.saturatedCast(maxRowCount)); + } + + @Override + public ExecuteResult execute(StatementHandle h, + List parameterValues, int maxRowsInFirstFrame) + throws NoSuchStatementException { + final CalciteConnectionImpl calciteConnection = getConnection(); + CalciteServerStatement stmt = calciteConnection.server.getStatement(h); + final Signature signature = requireNonNull(stmt.getSignature(), + () -> "stmt.getSignature() is null for " + stmt); + + MetaResultSet metaResultSet; + if (signature.statementType.canUpdate()) { + final Iterable iterable = + _createIterable(h, signature, parameterValues, null); + final Iterator iterator = iterable.iterator(); + stmt.setResultSet(iterator); + metaResultSet = MetaResultSet.count(h.connectionId, h.id, + ((Number) iterator.next()).intValue()); + } else { + // Don't populate the first frame. + // It's not worth saving a round-trip, since we're local. + final Meta.Frame frame = + new Meta.Frame(0, false, Collections.emptyList()); + metaResultSet = + MetaResultSet.create(h.connectionId, h.id, false, signature, frame); + } + + return new ExecuteResult(ImmutableList.of(metaResultSet)); + } + + @Override + public ExecuteBatchResult executeBatch(StatementHandle h, + List> parameterValueLists) + throws NoSuchStatementException { + final List updateCounts = new ArrayList<>(); + for (List parameterValueList : parameterValueLists) { + ExecuteResult executeResult = execute(h, parameterValueList, -1); + final long updateCount = + executeResult.resultSets.size() == 1 + ? executeResult.resultSets.get(0).updateCount + : -1L; + updateCounts.add(updateCount); + } + return new ExecuteBatchResult(Longs.toArray(updateCounts)); + } + + @Override + public ExecuteBatchResult prepareAndExecuteBatch( + final StatementHandle h, + List sqlCommands) throws NoSuchStatementException { + final CalciteConnectionImpl calciteConnection = getConnection(); + final CalciteServerStatement statement = + calciteConnection.server.getStatement(h); + final List updateCounts = new ArrayList<>(); + final Meta.PrepareCallback callback = + new Meta.PrepareCallback() { + long updateCount; + @Nullable Signature signature; + + @Override + public Object getMonitor() { + return statement; + } + + @Override + public void clear() throws SQLException { + } + + @Override + public void assign(Meta.Signature signature, Meta.@Nullable Frame firstFrame, + long updateCount) throws SQLException { + this.signature = signature; + this.updateCount = updateCount; + } + + @Override + public void execute() throws SQLException { + Signature signature = requireNonNull(this.signature, "signature"); + if (signature.statementType.canUpdate()) { + final Iterable iterable = + _createIterable(h, signature, ImmutableList.of(), + null); + final Iterator iterator = iterable.iterator(); + updateCount = ((Number) iterator.next()).longValue(); + } + updateCounts.add(updateCount); + } + }; + for (String sqlCommand : sqlCommands) { + Util.discard(prepareAndExecute(h, sqlCommand, -1L, -1, callback)); + } + return new ExecuteBatchResult(Longs.toArray(updateCounts)); + } + + @Override + public boolean syncResults(StatementHandle h, QueryState state, long offset) + throws NoSuchStatementException { + // Doesn't have application in Calcite itself. + throw new UnsupportedOperationException(); + } + + @Override + public void commit(ConnectionHandle ch) { + throw new UnsupportedOperationException(); + } + + @Override + public void rollback(ConnectionHandle ch) { + throw new UnsupportedOperationException(); + } + + /** + * Metadata describing a Calcite table. + */ + private static class CalciteMetaTable extends MetaTable { + private final Table calciteTable; + + CalciteMetaTable(Table calciteTable, String tableCat, + String tableSchem, String tableName) { + super(tableCat, tableSchem, tableName, + calciteTable.getJdbcTableType().jdbcName); + this.calciteTable = requireNonNull(calciteTable, "calciteTable"); + } + } + + /** + * Metadata describing a Calcite schema. + */ + private static class CalciteMetaSchema extends MetaSchema { + private final CalciteSchema calciteSchema; + + CalciteMetaSchema(CalciteSchema calciteSchema, + String tableCatalog, String tableSchem) { + super(tableCatalog, tableSchem); + this.calciteSchema = calciteSchema; + } + } + + /** + * Table whose contents are metadata. + * + * @param element type + */ + abstract static class MetadataTable extends AbstractQueryableTable { + MetadataTable(Class clazz) { + super(clazz); + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return ((JavaTypeFactory) typeFactory).createType(elementType); + } + + @Override + public Schema.TableType getJdbcTableType() { + return Schema.TableType.SYSTEM_TABLE; + } + + @SuppressWarnings("unchecked") + @Override + public Class getElementType() { + return (Class) elementType; + } + + protected abstract Enumerator enumerator(CalciteMetaImpl connection); + + @Override + public Queryable asQueryable(QueryProvider queryProvider, + SchemaPlus schema, String tableName) { + return new AbstractTableQueryable(queryProvider, schema, this, + tableName) { + @SuppressWarnings("unchecked") + @Override + public Enumerator enumerator() { + return (Enumerator) MetadataTable.this.enumerator( + ((CalciteConnectionImpl) queryProvider).meta()); + } + }; + } + } + + /** + * Iterator that returns at most {@code limit} rows from an underlying + * {@link Iterator}. + * + * @param element type + */ + private static class LimitIterator implements Iterator { + private final Iterator iterator; + private final long limit; + int i = 0; + + private LimitIterator(Iterator iterator, long limit) { + this.iterator = iterator; + this.limit = limit; + } + + static Iterator of(Iterator iterator, long limit) { + if (limit <= 0) { + return iterator; + } + return new LimitIterator<>(iterator, limit); + } + + @Override + public boolean hasNext() { + return iterator.hasNext() && i < limit; + } + + @Override + public E next() { + ++i; + return iterator.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalcitePreparedStatement.java b/traindb-core/src/main/java/traindb/engine/calcite/CalcitePreparedStatement.java new file mode 100644 index 0000000..d0dbb37 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalcitePreparedStatement.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import java.sql.SQLException; +import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.Meta; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Implementation of {@link java.sql.PreparedStatement} + * for the Calcite engine. + * + *

This class has sub-classes which implement JDBC 3.0 and JDBC 4.0 APIs; + * it is instantiated using + * {@link org.apache.calcite.avatica.AvaticaFactory#newPreparedStatement}. + */ +abstract class CalcitePreparedStatement extends AvaticaPreparedStatement { + /** + * Creates a CalcitePreparedStatement. + * + * @param connection Connection + * @param h Statement handle + * @param signature Result of preparing statement + * @param resultSetType Result set type + * @param resultSetConcurrency Result set concurrency + * @param resultSetHoldability Result set holdability + * @throws SQLException if database error occurs + */ + protected CalcitePreparedStatement(CalciteConnectionImpl connection, + Meta.@Nullable StatementHandle h, Meta.Signature signature, + int resultSetType, + int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + super(connection, h, signature, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + @Override + public CalciteConnectionImpl getConnection() throws SQLException { + return (CalciteConnectionImpl) super.getConnection(); + } +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalciteResultSet.java b/traindb-core/src/main/java/traindb/engine/calcite/CalciteResultSet.java new file mode 100644 index 0000000..731f75b --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalciteResultSet.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import com.google.common.collect.ImmutableList; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.List; +import java.util.TimeZone; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.Handler; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.util.Cursor; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.linq4j.Linq4j; +import org.apache.calcite.runtime.ArrayEnumeratorCursor; +import org.apache.calcite.runtime.ObjectEnumeratorCursor; + +/** + * Implementation of {@link ResultSet} + * for the Calcite engine. + */ +public class CalciteResultSet extends AvaticaResultSet { + + /** + * Creates a CalciteResultSet. + */ + CalciteResultSet(AvaticaStatement statement, + CalcitePrepare.CalciteSignature calciteSignature, + ResultSetMetaData resultSetMetaData, TimeZone timeZone, + Meta.Frame firstFrame) throws SQLException { + super(statement, null, calciteSignature, resultSetMetaData, timeZone, firstFrame); + } + + private static Cursor createCursor(ColumnMetaData.AvaticaType elementType, + Iterable iterable) { + final Enumerator enumerator = Linq4j.iterableEnumerator(iterable); + //noinspection unchecked + return !(elementType instanceof ColumnMetaData.StructType) + || ((ColumnMetaData.StructType) elementType).columns.size() == 1 + ? new ObjectEnumeratorCursor(enumerator) + : new ArrayEnumeratorCursor(enumerator); + } + + @Override + protected CalciteResultSet execute() throws SQLException { + // Call driver's callback. It is permitted to throw a RuntimeException. + CalciteConnectionImpl connection = getCalciteConnection(); + final boolean autoTemp = connection.config().autoTemp(); + Handler.ResultSink resultSink = null; + if (autoTemp) { + resultSink = () -> { + }; + } + connection.getDriver().handler.onStatementExecute(statement, resultSink); + + super.execute(); + return this; + } + + @Override + public ResultSet create(ColumnMetaData.AvaticaType elementType, + Iterable iterable) throws SQLException { + final List columnMetaDataList; + if (elementType instanceof ColumnMetaData.StructType) { + columnMetaDataList = ((ColumnMetaData.StructType) elementType).columns; + } else { + columnMetaDataList = + ImmutableList.of(ColumnMetaData.dummy(elementType, false)); + } + final CalcitePrepare.CalciteSignature signature = + (CalcitePrepare.CalciteSignature) this.signature; + final CalcitePrepare.CalciteSignature newSignature = + new CalcitePrepare.CalciteSignature<>(signature.sql, + signature.parameters, signature.internalParameters, + signature.rowType, columnMetaDataList, Meta.CursorFactory.ARRAY, + signature.rootSchema, ImmutableList.of(), -1, null, + statement.getStatementType()); + ResultSetMetaData subResultSetMetaData = + new AvaticaResultSetMetaData(statement, null, newSignature); + final CalciteResultSet resultSet = + new CalciteResultSet(statement, signature, subResultSetMetaData, + localCalendar.getTimeZone(), new Meta.Frame(0, true, iterable)); + final Cursor cursor = CalciteResultSet.createCursor(elementType, iterable); + return resultSet.execute2(cursor, columnMetaDataList); + } + + // do not make public + CalcitePrepare.CalciteSignature getSignature() { + //noinspection unchecked + return (CalcitePrepare.CalciteSignature) signature; + } + + // do not make public + CalciteConnectionImpl getCalciteConnection() throws SQLException { + return (CalciteConnectionImpl) statement.getConnection(); + } +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/CalciteStatement.java b/traindb-core/src/main/java/traindb/engine/calcite/CalciteStatement.java new file mode 100644 index 0000000..f3e1ec6 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/CalciteStatement.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import java.sql.SQLException; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.NoSuchStatementException; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.server.CalciteServerStatement; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Implementation of {@link java.sql.Statement} + * for the Calcite engine. + */ +public abstract class CalciteStatement extends AvaticaStatement { + /** + * Creates a CalciteStatement. + * + * @param connection Connection + * @param h Statement handle + * @param resultSetType Result set type + * @param resultSetConcurrency Result set concurrency + * @param resultSetHoldability Result set holdability + */ + CalciteStatement(CalciteConnectionImpl connection, Meta.@Nullable StatementHandle h, + int resultSetType, int resultSetConcurrency, int resultSetHoldability) { + super(connection, h, resultSetType, resultSetConcurrency, + resultSetHoldability); + } + + // implement Statement + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface == CalciteServerStatement.class) { + final CalciteServerStatement statement; + try { + statement = getConnection().server.getStatement(handle); + } catch (NoSuchStatementException e) { + throw new AssertionError("invalid statement", e); + } + return iface.cast(statement); + } + return super.unwrap(iface); + } + + @Override + public CalciteConnectionImpl getConnection() { + return (CalciteConnectionImpl) connection; + } + + protected CalcitePrepare.CalciteSignature prepare( + Queryable queryable) { + final CalciteConnectionImpl calciteConnection = getConnection(); + final CalcitePrepare prepare = calciteConnection.prepareFactory.apply(); + final CalciteServerStatement serverStatement; + try { + serverStatement = calciteConnection.server.getStatement(handle); + } catch (NoSuchStatementException e) { + throw new AssertionError("invalid statement", e); + } + final CalcitePrepare.Context prepareContext = + serverStatement.createPrepareContext(); + return prepare.prepareQueryable(prepareContext, queryable); + } + + @Override + protected void close_() { + if (!closed) { + ((CalciteConnectionImpl) connection).server.removeStatement(handle); + super.close_(); + } + } +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/Driver.java b/traindb-core/src/main/java/traindb/engine/calcite/Driver.java new file mode 100644 index 0000000..5d649df --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/Driver.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.BuiltInConnectionProperty; +import org.apache.calcite.avatica.ConnectionProperty; +import org.apache.calcite.avatica.DriverVersion; +import org.apache.calcite.avatica.Handler; +import org.apache.calcite.avatica.HandlerImpl; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.UnregisteredDriver; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalciteFactory; +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.linq4j.function.Function0; +import org.checkerframework.checker.nullness.qual.Nullable; +import traindb.schema.SchemaManager; + +/** + * Calcite JDBC driver. + */ +public class Driver extends UnregisteredDriver { + public static final String CONNECT_STRING_PREFIX = "jdbc:traindb-calcite:"; + + static { + new Driver().register(); + } + + final Function0 prepareFactory; + + @SuppressWarnings("method.invocation.invalid") + public Driver() { + super(); + this.prepareFactory = createPrepareFactory(); + } + + protected Function0 createPrepareFactory() { + return CalcitePrepare.DEFAULT_FACTORY; + } + + @Override + protected String getConnectStringPrefix() { + return CONNECT_STRING_PREFIX; + } + + @Override + protected String getFactoryClassName(JdbcVersion jdbcVersion) { + switch (jdbcVersion) { + case JDBC_30: + case JDBC_40: + throw new IllegalArgumentException("JDBC version not supported: " + + jdbcVersion); + case JDBC_41: + default: + return "traindb.engine.calcite.CalciteJdbc41Factory"; + } + } + + @Override + protected DriverVersion createDriverVersion() { + return DriverVersion.load( + Driver.class, + "org-apache-calcite-jdbc.properties", + "Calcite JDBC Driver", + "unknown version", + "Calcite", + "unknown version"); + } + + @Override + protected Handler createHandler() { + return new HandlerImpl() { + @Override + public void onConnectionInit(AvaticaConnection connection_) + throws SQLException { + final CalciteConnectionImpl connection = (CalciteConnectionImpl) connection_; + super.onConnectionInit(connection); + CalciteSchema calciteSchema = CalciteSchema.from( + SchemaManager.getInstance(null).getCurrentSchema()); + connection.setRootSchema(calciteSchema); + connection.init(); + } + }; + } + + @Override + protected Collection getConnectionProperties() { + final List list = new ArrayList<>(); + Collections.addAll(list, BuiltInConnectionProperty.values()); + Collections.addAll(list, CalciteConnectionProperty.values()); + return list; + } + + @Override + public Meta createMeta(AvaticaConnection connection) { + return new CalciteMetaImpl((CalciteConnectionImpl) connection); + } + + /** + * Creates an internal connection. + */ + CalciteConnection connect(CalciteSchema rootSchema, + @Nullable JavaTypeFactory typeFactory) { + return (CalciteConnection) ((CalciteFactory) factory) + .newConnection(this, factory, CONNECT_STRING_PREFIX, new Properties(), + (CalciteSchema) SchemaManager.getInstance(null).getCurrentSchema(), typeFactory); + } + +} diff --git a/traindb-core/src/main/java/traindb/engine/calcite/MetadataSchema.java b/traindb-core/src/main/java/traindb/engine/calcite/MetadataSchema.java new file mode 100644 index 0000000..3947bb1 --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/calcite/MetadataSchema.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine.calcite; + +import static org.apache.calcite.avatica.MetaImpl.MetaColumn; +import static org.apache.calcite.avatica.MetaImpl.MetaTable; + +import com.google.common.collect.ImmutableMap; +import java.sql.SQLException; +import java.util.Map; +import org.apache.calcite.linq4j.Enumerator; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.AbstractSchema; + +/** + * Schema that contains metadata tables such as "TABLES" and "COLUMNS". + */ +class MetadataSchema extends AbstractSchema { + public static final Schema INSTANCE = new MetadataSchema(); + private static final Map TABLE_MAP = + ImmutableMap.of( + "COLUMNS", + new CalciteMetaImpl.MetadataTable(MetaColumn.class) { + @Override + public Enumerator enumerator( + final CalciteMetaImpl meta) { + final String catalog; + try { + catalog = meta.getConnection().getCatalog(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + return meta.tables(catalog) + .selectMany(meta::columns).enumerator(); + } + }, + "TABLES", + new CalciteMetaImpl.MetadataTable(MetaTable.class) { + @Override + public Enumerator enumerator(CalciteMetaImpl meta) { + final String catalog; + try { + catalog = meta.getConnection().getCatalog(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + return meta.tables(catalog).enumerator(); + } + }); + + /** + * Creates the data dictionary, also called the information schema. It is a + * schema called "metadata" that contains tables "TABLES", "COLUMNS" etc. + */ + private MetadataSchema() { + } + + @Override + protected Map getTableMap() { + return TABLE_MAP; + } +} diff --git a/traindb-core/src/main/java/traindb/schema/SchemaManager.java b/traindb-core/src/main/java/traindb/schema/SchemaManager.java index 7e5cfe9..362d7ed 100644 --- a/traindb-core/src/main/java/traindb/schema/SchemaManager.java +++ b/traindb-core/src/main/java/traindb/schema/SchemaManager.java @@ -14,6 +14,9 @@ package traindb.schema; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; @@ -21,7 +24,9 @@ import javax.sql.DataSource; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.Table; import org.apache.calcite.tools.Frameworks; +import traindb.adapter.jdbc.TrainDBJdbcDataSource; import traindb.catalog.CatalogStore; import traindb.common.TrainDBLogger; @@ -35,6 +40,10 @@ public final class SchemaManager { private final CatalogStore catalogStore; private TrainDBJdbcDataSource traindbDataSource; + private final Map> dataSourceMap; + private final Map> schemaMap; + private final Map> tableMap; + // to synchronize requests for Calcite Schema private final ReadWriteLock lock = new ReentrantReadWriteLock(false); private final Lock readLock = lock.readLock(); @@ -45,6 +54,15 @@ private SchemaManager(CatalogStore catalogStore) { this.catalogStore = catalogStore; rootSchema = Frameworks.createRootSchema(false); traindbDataSource = null; + dataSourceMap = new HashMap<>(); + schemaMap = new HashMap<>(); + tableMap = new HashMap<>(); + + try { + Class.forName("traindb.engine.calcite.Driver"); + } catch (ClassNotFoundException e) { + // FIXME + } } public static SchemaManager getInstance(CatalogStore catalogStore) { @@ -58,7 +76,8 @@ public static SchemaManager getInstance(CatalogStore catalogStore) { public void loadDataSource(DataSource dataSource) { SchemaPlus newRootSchema = Frameworks.createRootSchema(false); TrainDBJdbcDataSource newJdbcDataSource = new TrainDBJdbcDataSource(newRootSchema, dataSource); - addSchemaInfo(newRootSchema, newJdbcDataSource.getSubSchemaMap()); + newRootSchema.add(newJdbcDataSource.getName(), newJdbcDataSource); + addDataSourceToMaps(newJdbcDataSource); writeLock.lock(); this.traindbDataSource = newJdbcDataSource; @@ -67,14 +86,33 @@ public void loadDataSource(DataSource dataSource) { } public void refreshDataSource() { + writeLock.lock(); + dataSourceMap.clear(); + schemaMap.clear(); + tableMap.clear(); loadDataSource(traindbDataSource.getDataSource()); + writeLock.unlock(); + } + + private void addDataSourceToMaps(TrainDBDataSource traindbDataSource) { + addToListMap(dataSourceMap, traindbDataSource.getName(), traindbDataSource); + for (Schema schema : traindbDataSource.getSubSchemaMap().values()) { + TrainDBSchema traindbSchema = (TrainDBSchema) schema; + addToListMap(schemaMap, traindbSchema.getName(), traindbSchema); + for (Table table : traindbSchema.getTableMap().values()) { + TrainDBTable traindbTable = (TrainDBTable) table; + addToListMap(tableMap, traindbTable.getName(), traindbTable); + } + } } - private void addSchemaInfo(SchemaPlus parentSchema, Map subSchemaMap) { - for (Map.Entry entry : subSchemaMap.entrySet()) { - TrainDBJdbcSchema schema = (TrainDBJdbcSchema) entry.getValue(); - parentSchema.add(entry.getKey(), schema); + private void addToListMap(Map> map, String key, T value) { + List values = map.get(key); + if (values == null) { + values = new ArrayList<>(); + map.put(key, values); } + values.add(value); } public SchemaPlus getCurrentSchema() { @@ -96,4 +134,66 @@ public void lockRead() { public void unlockRead() { readLock.unlock(); } + + public List toFullyQualifiedTableName(List names, String defaultSchema) { + TrainDBDataSource dataSource = null; + TrainDBSchema schema = null; + TrainDBTable table = null; + + List candidateDataSources; + List candidateSchemas; + + switch (names.size()) { + case 1: // table + candidateSchemas = schemaMap.get(defaultSchema); + if (candidateSchemas == null || candidateSchemas.size() != 1) { + throw new RuntimeException("invalid name: " + defaultSchema + "." + names.get(0)); + } + schema = candidateSchemas.get(0); + table = (TrainDBTable) schema.getTable(names.get(0)); + if (table == null) { + throw new RuntimeException("invalid name: " + defaultSchema + "." + names.get(0)); + } + dataSource = schema.getDataSource(); + break; + case 2: // schema.table + candidateSchemas = schemaMap.get(names.get(0)); + if (candidateSchemas == null || candidateSchemas.size() != 1) { + throw new RuntimeException("invalid name: " + names.get(0) + "." + names.get(1)); + } + schema = candidateSchemas.get(0); + table = (TrainDBTable) schema.getTable(names.get(1)); + if (table == null) { + throw new RuntimeException("invalid name: " + names.get(0) + "." + names.get(1)); + } + dataSource = schema.getDataSource(); + break; + case 3: // dataSource.schema.table + candidateDataSources = dataSourceMap.get(names.get(0)); + if (candidateDataSources == null || candidateDataSources.size() != 1) { + throw new RuntimeException( + "invalid name: " + names.get(0) + "." + names.get(1) + "." + names.get(2)); + } + dataSource = candidateDataSources.get(0); + schema = (TrainDBSchema) dataSource.getSubSchemaMap().get(names.get(1)); + if (schema == null) { + throw new RuntimeException( + "invalid name: " + names.get(0) + "." + names.get(1) + "." + names.get(2)); + } + table = (TrainDBTable) schema.getTable(names.get(2)); + if (table == null) { + throw new RuntimeException( + "invalid name: " + names.get(0) + "." + names.get(1) + "." + names.get(2)); + } + break; + default: + throw new RuntimeException("invalid identifier length: " + names.size()); + } + + List fqn = new ArrayList<>(); + fqn.add(dataSource.getName()); + fqn.add(schema.getName()); + fqn.add(table.getName()); + return fqn; + } } diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBDataSource.java b/traindb-core/src/main/java/traindb/schema/TrainDBDataSource.java new file mode 100644 index 0000000..6a32819 --- /dev/null +++ b/traindb-core/src/main/java/traindb/schema/TrainDBDataSource.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.schema; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.impl.AbstractSchema; +import traindb.common.TrainDBLogger; + +public abstract class TrainDBDataSource extends AbstractSchema { + private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBDataSource.class); + + private final String name; + private ImmutableMap subSchemaMap; + + public TrainDBDataSource() { + this.name = "traindb"; // FIXME + } + + public final String getName() { + return name; + } + + @Override + public final boolean isMutable() { + return false; + } + + @Override + public final Map getSubSchemaMap() { + LOG.debug("getSubSchemaMap called. subSchemaMapSize=" + subSchemaMap.size()); + return subSchemaMap; + } + + public final void setSubSchemaMap(ImmutableMap subSchemaMap) { + this.subSchemaMap = subSchemaMap; + } +} diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcSchema.java b/traindb-core/src/main/java/traindb/schema/TrainDBJdbcSchema.java deleted file mode 100644 index 7a5bbe9..0000000 --- a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcSchema.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package traindb.schema; - -import com.google.common.collect.ImmutableMap; -import java.sql.Connection; -import java.sql.DatabaseMetaData; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.Map; -import org.apache.calcite.avatica.MetaImpl; -import org.apache.calcite.schema.Table; -import org.apache.calcite.schema.impl.AbstractSchema; -import traindb.common.TrainDBLogger; - -public class TrainDBJdbcSchema extends AbstractSchema { - private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBJdbcSchema.class); - private final String name; - private TrainDBJdbcDataSource jdbcDataSource; - private ImmutableMap tableMap; - - public TrainDBJdbcSchema(String name, TrainDBJdbcDataSource jdbcDataSource) { - this.name = name; - this.jdbcDataSource = jdbcDataSource; - computeTableMap(); - } - - public void computeTableMap() { - final ImmutableMap.Builder builder = ImmutableMap.builder(); - Connection connection = null; - ResultSet resultSet = null; - try { - connection = jdbcDataSource.getDataSource().getConnection(); - DatabaseMetaData databaseMetaData = connection.getMetaData(); - resultSet = databaseMetaData.getTables(name, null, null, null); - while (resultSet.next()) { - final String catalogName = resultSet.getString(1); - final String schemaName = resultSet.getString(2); - final String tableName = resultSet.getString(3); - final String tableTypeName = resultSet.getString(4).replace(" ", "_"); - - builder.put(tableName, new TrainDBJdbcTable(tableName, this, - new MetaImpl.MetaTable(catalogName, schemaName, tableName, tableTypeName), - databaseMetaData)); - } - } catch (SQLException e) { - throw new RuntimeException(e); - } finally { - JdbcUtils.close(connection, null, resultSet); - } - setTableMap(builder.build()); - } - - public final String getName() { - return name; - } - - @Override - public final boolean isMutable() { - return false; - } - - @Override - public final Map getTableMap() { - LOG.debug("getTableMap called. tableMapSize: " + tableMap.size()); - if (tableMap == null) { - computeTableMap(); - } - return tableMap; - } - - public final void setTableMap(ImmutableMap tableMap) { - this.tableMap = tableMap; - } - - public final TrainDBJdbcDataSource getJdbcDataSource() { - return jdbcDataSource; - } -} diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcTable.java b/traindb-core/src/main/java/traindb/schema/TrainDBJdbcTable.java deleted file mode 100644 index b85ea70..0000000 --- a/traindb-core/src/main/java/traindb/schema/TrainDBJdbcTable.java +++ /dev/null @@ -1,279 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package traindb.schema; - -import java.sql.DatabaseMetaData; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import javax.annotation.Nullable; -import org.apache.calcite.DataContext; -import org.apache.calcite.adapter.java.AbstractQueryableTable; -import org.apache.calcite.adapter.java.JavaTypeFactory; -import org.apache.calcite.avatica.ColumnMetaData; -import org.apache.calcite.avatica.SqlType; -import org.apache.calcite.jdbc.CalciteConnection; -import org.apache.calcite.linq4j.Enumerable; -import org.apache.calcite.linq4j.Enumerator; -import org.apache.calcite.linq4j.QueryProvider; -import org.apache.calcite.linq4j.Queryable; -import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rel.type.RelDataTypeImpl; -import org.apache.calcite.rel.type.RelDataTypeSystem; -import org.apache.calcite.rel.type.RelProtoDataType; -import org.apache.calcite.runtime.ResultSetEnumerable; -import org.apache.calcite.schema.ScannableTable; -import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.SchemaPlus; -import org.apache.calcite.schema.TranslatableTable; -import org.apache.calcite.schema.impl.AbstractTableQueryable; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlSelect; -import org.apache.calcite.sql.SqlWriterConfig; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.pretty.SqlPrettyWriter; -import org.apache.calcite.sql.type.SqlTypeFactoryImpl; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.util.SqlString; -import org.apache.calcite.util.Pair; -import org.apache.calcite.util.Util; -import traindb.common.TrainDBLogger; -import org.apache.calcite.avatica.MetaImpl; - -public final class TrainDBJdbcTable extends AbstractQueryableTable - implements TranslatableTable, ScannableTable { - private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBJdbcTable.class); - private final String name; - private final TrainDBJdbcSchema jdbcSchema; - private Schema.TableType tableType; - private RelProtoDataType protoRowType; - - public TrainDBJdbcTable(String name, TrainDBJdbcSchema schema, MetaImpl.MetaTable tableDef, - DatabaseMetaData databaseMetaData) { - super(Object[].class); - this.name = name; - this.tableType = Schema.TableType.valueOf(tableDef.tableType); - this.jdbcSchema = schema; - - RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT); - RelDataTypeFactory.Builder builder = typeFactory.builder(); - ResultSet resultSet = null; - try { - resultSet = databaseMetaData.getColumns( - tableDef.tableCat, tableDef.tableSchem, tableDef.tableName, null); - while (resultSet.next()) { - String columnName = resultSet.getString(4); - int dataType = resultSet.getInt(5); - String typeString = resultSet.getString(6); - int precision; - int scale; - switch (SqlType.valueOf(dataType)) { - case TIMESTAMP: - case TIME: - precision = resultSet.getInt(9); // SCALE - scale = 0; - break; - default: - precision = resultSet.getInt(7); // SIZE - scale = resultSet.getInt(9); // SCALE - break; - } - RelDataType sqlType = sqlType(typeFactory, dataType, precision, scale, typeString); - - builder.add(columnName, sqlType); - } - } catch (SQLException e) { - LOG.debug(e.getMessage()); - JdbcUtils.close(null, null, resultSet); - } - protoRowType = RelDataTypeImpl.proto(builder.build()); - } - - private static RelDataType sqlType(RelDataTypeFactory typeFactory, int dataType, - int precision, int scale, @Nullable String typeString) { - // Fall back to ANY if type is unknown - final SqlTypeName sqlTypeName = - Util.first(SqlTypeName.getNameForJdbcType(dataType), SqlTypeName.ANY); - switch (sqlTypeName) { - case ARRAY: - RelDataType component = null; - if (typeString != null && typeString.endsWith(" ARRAY")) { - // E.g. hsqldb gives "INTEGER ARRAY", so we deduce the component type - // "INTEGER". - final String remaining = typeString.substring(0, - typeString.length() - " ARRAY".length()); - component = parseTypeString(typeFactory, remaining); - } - if (component == null) { - component = typeFactory.createTypeWithNullability( - typeFactory.createSqlType(SqlTypeName.ANY), true); - } - return typeFactory.createArrayType(component, -1); - default: - break; - } - if (precision >= 0 - && scale >= 0 - && sqlTypeName.allowsPrecScale(true, true)) { - return typeFactory.createSqlType(sqlTypeName, precision, scale); - } else if (precision >= 0 && sqlTypeName.allowsPrecNoScale()) { - return typeFactory.createSqlType(sqlTypeName, precision); - } else { - assert sqlTypeName.allowsNoPrecNoScale(); - return typeFactory.createSqlType(sqlTypeName); - } - } - - /** Given "INTEGER", returns BasicSqlType(INTEGER). - * Given "VARCHAR(10)", returns BasicSqlType(VARCHAR, 10). - * Given "NUMERIC(10, 2)", returns BasicSqlType(NUMERIC, 10, 2). */ - private static RelDataType parseTypeString(RelDataTypeFactory typeFactory, - String typeString) { - int precision = -1; - int scale = -1; - int open = typeString.indexOf("("); - if (open >= 0) { - int close = typeString.indexOf(")", open); - if (close >= 0) { - String rest = typeString.substring(open + 1, close); - typeString = typeString.substring(0, open); - int comma = rest.indexOf(","); - if (comma >= 0) { - precision = Integer.parseInt(rest.substring(0, comma)); - scale = Integer.parseInt(rest.substring(comma)); - } else { - precision = Integer.parseInt(rest); - } - } - } - try { - final SqlTypeName typeName = SqlTypeName.valueOf(typeString); - return typeName.allowsPrecScale(true, true) - ? typeFactory.createSqlType(typeName, precision, scale) - : typeName.allowsPrecScale(true, false) - ? typeFactory.createSqlType(typeName, precision) - : typeFactory.createSqlType(typeName); - } catch (IllegalArgumentException e) { - return typeFactory.createTypeWithNullability( - typeFactory.createSqlType(SqlTypeName.ANY), true); - } - } - - @Override - public String toString() { - return "TrainDBTable {" + getName() + "}"; - } - - @Override - public final Schema.TableType getJdbcTableType() { - return tableType; - } - - @Override - public RelDataType getRowType(RelDataTypeFactory relDataTypeFactory) { - return protoRowType.apply(relDataTypeFactory); - } - - private List> fieldClasses(final JavaTypeFactory typeFactory) { - final RelDataType rowType = getRowType(typeFactory); - return Util.transform(rowType.getFieldList(), f -> { - final RelDataType type = f.getType(); - final Class clazz = (Class) typeFactory.getJavaClass(type); - final ColumnMetaData.Rep rep = - Util.first(ColumnMetaData.Rep.of(clazz), - ColumnMetaData.Rep.OBJECT); - return Pair.of(rep, type.getSqlTypeName().getJdbcOrdinal()); - }); - } - - SqlString generateSql() { - final SqlNodeList selectList = SqlNodeList.SINGLETON_STAR; - SqlSelect node = - new SqlSelect(SqlParserPos.ZERO, SqlNodeList.EMPTY, selectList, - tableName(), null, null, null, null, null, null, null, null); - final SqlWriterConfig config = SqlPrettyWriter.config() - .withAlwaysUseParentheses(true) - .withDialect(jdbcSchema.getJdbcDataSource().getDialect()); - final SqlPrettyWriter writer = new SqlPrettyWriter(config); - node.unparse(writer, 0, 0); - return writer.toSqlString(); - } - - SqlIdentifier tableName() { - final List strings = new ArrayList<>(); - strings.add(getJdbcSchema().getName()); - strings.add(getName()); - return new SqlIdentifier(strings, SqlParserPos.ZERO); - } - - public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable relOptTable) { - return new JdbcTableScan(context.getCluster(), relOptTable, this, - getJdbcDataSource().getConvention()); - } - - @Override - public Queryable asQueryable(QueryProvider queryProvider, - SchemaPlus schema, String tableName) { - return new JdbcTableQueryable<>(queryProvider, schema, tableName); - } - - public Enumerable scan(DataContext root) { - JavaTypeFactory typeFactory = root.getTypeFactory(); - final SqlString sql = generateSql(); - return ResultSetEnumerable.of(getJdbcDataSource().getDataSource(), sql.getSql(), - JdbcUtils.rowBuilderFactory2(fieldClasses(typeFactory))); - } - - public final String getName() { - return name; - } - - public final TrainDBJdbcSchema getJdbcSchema() { - return jdbcSchema; - } - - public final TrainDBJdbcDataSource getJdbcDataSource() { - return jdbcSchema.getJdbcDataSource(); - } - - private class JdbcTableQueryable extends AbstractTableQueryable { - JdbcTableQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) { - super(queryProvider, schema, TrainDBJdbcTable.this, tableName); - } - - @Override - public String toString() { - return "JdbcTableQueryable {table: " + tableName + "}"; - } - - @Override - public Enumerator enumerator() { - final JavaTypeFactory typeFactory = - ((CalciteConnection) queryProvider).getTypeFactory(); - final SqlString sql = generateSql(); - final List> pairs = fieldClasses(typeFactory); - @SuppressWarnings({"rawtypes", "unchecked"}) - final Enumerable enumerable = - (Enumerable) ResultSetEnumerable.of(getJdbcDataSource().getDataSource(), sql.getSql(), - JdbcUtils.rowBuilderFactory2(pairs)); - return enumerable.enumerator(); - } - } -} diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBSchema.java b/traindb-core/src/main/java/traindb/schema/TrainDBSchema.java new file mode 100644 index 0000000..5d056a7 --- /dev/null +++ b/traindb-core/src/main/java/traindb/schema/TrainDBSchema.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.schema; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.apache.calcite.schema.Table; +import org.apache.calcite.schema.impl.AbstractSchema; +import traindb.common.TrainDBLogger; + +public abstract class TrainDBSchema extends AbstractSchema { + private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBSchema.class); + private final String name; + private TrainDBDataSource dataSource; + private ImmutableMap tableMap; + + public TrainDBSchema(String name, TrainDBDataSource dataSource) { + this.name = name; + this.dataSource = dataSource; + } + + public final String getName() { + return name; + } + + @Override + public final boolean isMutable() { + return false; + } + + @Override + public final Map getTableMap() { + LOG.debug("getTableMap called. tableMapSize: " + tableMap.size()); + return tableMap; + } + + public final void setTableMap(ImmutableMap tableMap) { + this.tableMap = tableMap; + } + + public final TrainDBDataSource getDataSource() { + return dataSource; + } +} diff --git a/traindb-core/src/main/java/traindb/schema/TrainDBTable.java b/traindb-core/src/main/java/traindb/schema/TrainDBTable.java new file mode 100644 index 0000000..568f76e --- /dev/null +++ b/traindb-core/src/main/java/traindb/schema/TrainDBTable.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.schema; + +import org.apache.calcite.adapter.java.AbstractQueryableTable; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeImpl; +import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.SchemaPlus; +import traindb.common.TrainDBLogger; + +public abstract class TrainDBTable extends AbstractQueryableTable { + private static TrainDBLogger LOG = TrainDBLogger.getLogger(TrainDBTable.class); + private final String name; + private final TrainDBSchema schema; + private Schema.TableType tableType; + private RelProtoDataType protoRowType; + + public TrainDBTable(String name, TrainDBSchema schema, Schema.TableType tableType, + RelDataType protoType) { + super(Object[].class); + this.name = name; + this.tableType = tableType; + this.schema = schema; + this.protoRowType = RelDataTypeImpl.proto(protoType); + } + + @Override + public final Schema.TableType getJdbcTableType() { + return tableType; + } + + @Override + public RelDataType getRowType(RelDataTypeFactory relDataTypeFactory) { + return protoRowType.apply(relDataTypeFactory); + } + + @Override + public abstract Queryable asQueryable( + QueryProvider queryProvider, SchemaPlus schema, String tableName); + + public final String getName() { + return name; + } + + public final TrainDBSchema getSchema() { + return schema; + } +} diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSql.java b/traindb-core/src/main/java/traindb/sql/TrainDBSql.java index 630185f..921958f 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSql.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSql.java @@ -70,7 +70,7 @@ public static VerdictSingleResult run(TrainDBSqlCommand command, TrainDBSqlRunne return runner.showModels(); case SHOW_MODEL_INSTANCES: TrainDBSqlShowCommand showModelInstances = (TrainDBSqlShowCommand) command; - return runner.showModelInstances(showModelInstances.getModelName()); + return runner.showModelInstances(); case TRAIN_MODEL_INSTANCE: TrainDBSqlTrainModelInstance trainModelInstance = (TrainDBSqlTrainModelInstance) command; runner.trainModelInstance( @@ -145,7 +145,7 @@ public void exitShowModels(TrainDBSqlParser.ShowModelsContext ctx) { @Override public void exitShowModelInstances(TrainDBSqlParser.ShowModelInstancesContext ctx) { - commands.add(new TrainDBSqlShowCommand.ModelInstances(ctx.modelName().getText())); + commands.add(new TrainDBSqlShowCommand.ModelInstances()); } @Override diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java b/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java index 0d50ff4..04b8ab9 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java @@ -36,7 +36,7 @@ void createSynopsis(String synopsisName, String modelInstanceName, int limitNumb VerdictSingleResult showModels() throws Exception; - VerdictSingleResult showModelInstances(String modelName) throws Exception; + VerdictSingleResult showModelInstances() throws Exception; VerdictSingleResult showSynopses() throws Exception; diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java b/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java index 287b81f..bf38eaa 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java @@ -15,19 +15,13 @@ package traindb.sql; abstract class TrainDBSqlShowCommand extends TrainDBSqlCommand { - private String modelName; - protected TrainDBSqlShowCommand(String modelName) { - this.modelName = modelName; - } - - String getModelName() { - return modelName; + protected TrainDBSqlShowCommand() { } static class Models extends TrainDBSqlShowCommand { Models() { - super(null); + super(); } @Override @@ -37,8 +31,8 @@ public Type getType() { } static class ModelInstances extends TrainDBSqlShowCommand { - ModelInstances(String modelName) { - super(modelName); + ModelInstances() { + super(); } @Override @@ -49,7 +43,7 @@ public Type getType() { static class Synopses extends TrainDBSqlShowCommand { Synopses() { - super(null); + super(); } @Override diff --git a/traindb-core/src/test/resources/sql/basic.iq b/traindb-core/src/test/resources/sql/basic.iq index 48d42c5..7e17c37 100644 --- a/traindb-core/src/test/resources/sql/basic.iq +++ b/traindb-core/src/test/resources/sql/basic.iq @@ -22,7 +22,7 @@ TRAIN MODEL tablegan INSTANCE tgan ON instacart.order_products(product_id, add_t !update -SHOW MODEL tablegan INSTANCES; +SHOW MODEL INSTANCES; +----------+----------------+-----------+----------------+---------------------------------+ | model | model_instance | schema | table | columns | +----------+----------------+-----------+----------------+---------------------------------+ @@ -47,7 +47,7 @@ SHOW SYNOPSES; !ok -SELECT count(*) FROM instacart.order_products_syn; +SELECT count(*) as c2 FROM instacart.order_products_syn; +------+ | c2 | +------+ @@ -76,7 +76,7 @@ DROP MODEL INSTANCE tgan; !update -SHOW MODEL tablegan INSTANCES; +SHOW MODEL INSTANCES; + | +