Skip to content

Commit

Permalink
LibSQL+SQLServer: Bare bones INSERT and SELECT statements
Browse files Browse the repository at this point in the history
This patch provides very basic, bare bones implementations of the
INSERT and SELECT statements. They are *very* limited:
- The only variant of the INSERT statement that currently works is
   SELECT INTO schema.table (column1, column2, ....) VALUES
      (value11, value21, ...), (value12, value22, ...), ...
   where the values are literals.
- The SELECT statement is even more limited, and is only provided to
  allow verification of the INSERT statement. The only form implemented
  is: SELECT * FROM schema.table

These statements required a bit of change in the Statement::execute
API. Originally execute only received a Database object as parameter.
This is not enough; we now pass an ExecutionContext object which
contains the Database, the current result set, and the last Tuple read
from the database. This object will undoubtedly evolve over time.

This API change dragged SQLServer::SQLStatement into the patch.

Another API addition is Expression::evaluate. This method is,
unsurprisingly, used to evaluate expressions, like the values in the
INSERT statement.

Finally, a new test file is added: TestSqlStatementExecution, which
tests the currently implemented statements. As the number and flavour of
implemented statements grows, this test file will probably have to be
restructured.
  • Loading branch information
JanDeVisser authored and awesomekling committed Aug 21, 2021
1 parent 230118c commit d074a60
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 16 deletions.
101 changes: 101 additions & 0 deletions Tests/LibSQL/TestSqlStatementExecution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2021, Jan de Visser <[email protected]>
*
* SPDX-License-Identifier: BSD-2-Clause
*/

#include <unistd.h>

#include <AK/ScopeGuard.h>
#include <LibSQL/AST/Parser.h>
#include <LibSQL/Database.h>
#include <LibSQL/Row.h>
#include <LibSQL/SQLResult.h>
#include <LibSQL/Value.h>
#include <LibTest/TestCase.h>

namespace {

constexpr const char* db_name = "/tmp/test.db";

RefPtr<SQL::SQLResult> execute(NonnullRefPtr<SQL::Database> database, String const& sql)
{
auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
auto statement = parser.next_statement();
EXPECT(!parser.has_errors());
if (parser.has_errors()) {
outln(parser.errors()[0].to_string());
}
SQL::AST::ExecutionContext context { database };
auto result = statement->execute(context);
EXPECT(result->error().code == SQL::SQLErrorCode::NoError);
return result;
}

void create_schema(NonnullRefPtr<SQL::Database> database)
{
auto result = execute(database, "CREATE SCHEMA TestSchema;");
EXPECT(result->inserted() == 1);
}

void create_table(NonnullRefPtr<SQL::Database> database)
{
create_schema(database);
auto result = execute(database, "CREATE TABLE TestSchema.TestTable ( TextColumn text, IntColumn integer );");
EXPECT(result->inserted() == 1);
}

TEST_CASE(create_schema)
{
ScopeGuard guard([]() { unlink(db_name); });
auto database = SQL::Database::construct(db_name);
create_schema(database);
auto schema = database->get_schema("TESTSCHEMA");
EXPECT(schema);
}

TEST_CASE(create_table)
{
ScopeGuard guard([]() { unlink(db_name); });
auto database = SQL::Database::construct(db_name);
create_table(database);
auto table = database->get_table("TESTSCHEMA", "TESTTABLE");
EXPECT(table);
}

TEST_CASE(insert_into_table)
{
ScopeGuard guard([]() { unlink(db_name); });
auto database = SQL::Database::construct(db_name);
create_table(database);
auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test', 42 );");
EXPECT(result->inserted() == 1);

auto table = database->get_table("TESTSCHEMA", "TESTTABLE");

int count = 0;
for (auto& row : database->select_all(*table)) {
EXPECT_EQ(row["TEXTCOLUMN"].to_string(), "Test");
EXPECT_EQ(row["INTCOLUMN"].to_int().value(), 42);
count++;
}
EXPECT_EQ(count, 1);
}

TEST_CASE(select_from_table)
{
ScopeGuard guard([]() { unlink(db_name); });
auto database = SQL::Database::construct(db_name);
create_table(database);
auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_1', 42 ), ( 'Test_2', 43 );");
EXPECT(result->inserted() == 2);
result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_3', 44 ), ( 'Test_4', 45 );");
EXPECT(result->inserted() == 2);
result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_5', 46 );");
EXPECT(result->inserted() == 1);
result = execute(database, "SELECT * FROM TestSchema.TestTable;");
EXPECT(result->has_results());
EXPECT_EQ(result->results().size(), 5u);
}

}
24 changes: 21 additions & 3 deletions Userland/Libraries/LibSQL/AST/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,15 @@ class LimitClause : public ASTNode {
// Expressions
//==================================================================================================

struct ExecutionContext {
NonnullRefPtr<Database> database;
RefPtr<SQLResult> result { nullptr };
Tuple current_row {};
};

class Expression : public ASTNode {
public:
virtual Value evaluate(ExecutionContext&) const;
};

class ErrorExpression final : public Expression {
Expand All @@ -309,6 +317,7 @@ class NumericLiteral : public Expression {
}

double value() const { return m_value; }
virtual Value evaluate(ExecutionContext&) const override;

private:
double m_value;
Expand All @@ -322,6 +331,7 @@ class StringLiteral : public Expression {
}

const String& value() const { return m_value; }
virtual Value evaluate(ExecutionContext&) const override;

private:
String m_value;
Expand All @@ -341,11 +351,14 @@ class BlobLiteral : public Expression {
};

class NullLiteral : public Expression {
public:
virtual Value evaluate(ExecutionContext&) const override;
};

class NestedExpression : public Expression {
public:
const NonnullRefPtr<Expression>& expression() const { return m_expression; }
virtual Value evaluate(ExecutionContext&) const override;

protected:
explicit NestedExpression(NonnullRefPtr<Expression> expression)
Expand Down Expand Up @@ -439,6 +452,7 @@ class UnaryOperatorExpression : public NestedExpression {
}

UnaryOperator type() const { return m_type; }
virtual Value evaluate(ExecutionContext&) const override;

private:
UnaryOperator m_type;
Expand Down Expand Up @@ -488,6 +502,7 @@ class ChainedExpression : public Expression {
}

const NonnullRefPtrVector<Expression>& expressions() const { return m_expressions; }
virtual Value evaluate(ExecutionContext&) const override;

private:
NonnullRefPtrVector<Expression> m_expressions;
Expand Down Expand Up @@ -667,7 +682,7 @@ class InTableExpression : public InvertibleNestedExpression {

class Statement : public ASTNode {
public:
virtual RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const { return nullptr; }
virtual RefPtr<SQLResult> execute(ExecutionContext&) const { return nullptr; }
};

class ErrorStatement final : public Statement {
Expand All @@ -684,7 +699,7 @@ class CreateSchema : public Statement {
const String& schema_name() const { return m_schema_name; }
bool is_error_if_schema_exists() const { return m_is_error_if_schema_exists; }

RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override;
RefPtr<SQLResult> execute(ExecutionContext&) const override;

private:
String m_schema_name;
Expand Down Expand Up @@ -723,7 +738,7 @@ class CreateTable : public Statement {
bool is_temporary() const { return m_is_temporary; }
bool is_error_if_table_exists() const { return m_is_error_if_table_exists; }

RefPtr<SQLResult> execute(NonnullRefPtr<Database>) const override;
RefPtr<SQLResult> execute(ExecutionContext&) const override;

private:
String m_schema_name;
Expand Down Expand Up @@ -886,6 +901,8 @@ class Insert : public Statement {
bool has_selection() const { return !m_select_statement.is_null(); }
const RefPtr<Select>& select_statement() const { return m_select_statement; }

RefPtr<SQLResult> execute(ExecutionContext&) const;

private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
ConflictResolution m_conflict_resolution;
Expand Down Expand Up @@ -977,6 +994,7 @@ class Select : public Statement {
const RefPtr<GroupByClause>& group_by_clause() const { return m_group_by_clause; }
const NonnullRefPtrVector<OrderingTerm>& ordering_term_list() const { return m_ordering_term_list; }
const RefPtr<LimitClause>& limit_clause() const { return m_limit_clause; }
RefPtr<SQLResult> execute(ExecutionContext&) const override;

private:
RefPtr<CommonTableExpressionList> m_common_table_expression_list;
Expand Down
6 changes: 3 additions & 3 deletions Userland/Libraries/LibSQL/AST/CreateSchema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

namespace SQL::AST {

RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const
RefPtr<SQLResult> CreateSchema::execute(ExecutionContext& context) const
{
auto schema_def = database->get_schema(m_schema_name);
auto schema_def = context.database->get_schema(m_schema_name);
if (schema_def) {
if (m_is_error_if_schema_exists) {
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaExists, m_schema_name);
Expand All @@ -21,7 +21,7 @@ RefPtr<SQLResult> CreateSchema::execute(NonnullRefPtr<Database> database) const
}

schema_def = SchemaDef::construct(m_schema_name);
database->add_schema(*schema_def);
context.database->add_schema(*schema_def);
return SQLResult::construct(SQLCommand::Create, 0, 1);
}

Expand Down
8 changes: 4 additions & 4 deletions Userland/Libraries/LibSQL/AST/CreateTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

namespace SQL::AST {

RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const
RefPtr<SQLResult> CreateTable::execute(ExecutionContext& context) const
{
auto schema_name = (!m_schema_name.is_null() && !m_schema_name.is_empty()) ? m_schema_name : "default";
auto schema_def = database->get_schema(schema_name);
auto schema_def = context.database->get_schema(schema_name);
if (!schema_def)
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::SchemaDoesNotExist, m_schema_name);
auto table_def = database->get_table(schema_name, m_table_name);
auto table_def = context.database->get_table(schema_name, m_table_name);
if (table_def) {
if (m_is_error_if_table_exists) {
return SQLResult::construct(SQLCommand::Create, SQLErrorCode::TableExists, m_table_name);
Expand All @@ -37,7 +37,7 @@ RefPtr<SQLResult> CreateTable::execute(NonnullRefPtr<Database> database) const
}
table_def->append_column(column.name(), type);
}
database->add_table(*table_def);
context.database->add_table(*table_def);
return SQLResult::construct(SQLCommand::Create, 0, 1);
}

Expand Down
90 changes: 90 additions & 0 deletions Userland/Libraries/LibSQL/AST/Expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2021, Jan de Visser <[email protected]>
*
* SPDX-License-Identifier: BSD-2-Clause
*/

#include <LibSQL/AST/AST.h>
#include <LibSQL/Database.h>

namespace SQL::AST {

Value Expression::evaluate(ExecutionContext&) const
{
return Value::null();
}

Value NumericLiteral::evaluate(ExecutionContext&) const
{
Value ret(SQLType::Float);
ret = value();
return ret;
}

Value StringLiteral::evaluate(ExecutionContext&) const
{
Value ret(SQLType::Text);
ret = value();
return ret;
}

Value NullLiteral::evaluate(ExecutionContext&) const
{
return Value::null();
}

Value NestedExpression::evaluate(ExecutionContext& context) const
{
return expression()->evaluate(context);
}

Value ChainedExpression::evaluate(ExecutionContext& context) const
{
Value ret(SQLType::Tuple);
Vector<Value> values;
for (auto& expression : expressions()) {
values.append(expression.evaluate(context));
}
ret = values;
return ret;
}

Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
{
Value expression_value = NestedExpression::evaluate(context);
switch (type()) {
case UnaryOperator::Plus:
if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
return expression_value;
// TODO: Error handling.
VERIFY_NOT_REACHED();
case UnaryOperator::Minus:
if (expression_value.type() == SQLType::Integer) {
expression_value = -int(expression_value);
return expression_value;
}
if (expression_value.type() == SQLType::Float) {
expression_value = -double(expression_value);
return expression_value;
}
// TODO: Error handling.
VERIFY_NOT_REACHED();
case UnaryOperator::Not:
if (expression_value.type() == SQLType::Boolean) {
expression_value = !bool(expression_value);
return expression_value;
}
// TODO: Error handling.
VERIFY_NOT_REACHED();
case UnaryOperator::BitwiseNot:
if (expression_value.type() == SQLType::Integer) {
expression_value = ~u32(expression_value);
return expression_value;
}
// TODO: Error handling.
VERIFY_NOT_REACHED();
}
VERIFY_NOT_REACHED();
}

}
49 changes: 49 additions & 0 deletions Userland/Libraries/LibSQL/AST/Insert.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2021, Jan de Visser <[email protected]>
*
* SPDX-License-Identifier: BSD-2-Clause
*/

#include <LibSQL/AST/AST.h>
#include <LibSQL/Database.h>
#include <LibSQL/Meta.h>
#include <LibSQL/Row.h>

namespace SQL::AST {

RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
{
auto table_def = context.database->get_table(m_schema_name, m_table_name);
if (!table_def) {
auto schema_name = m_schema_name;
if (schema_name.is_null() || schema_name.is_empty())
schema_name = "default";
return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::TableDoesNotExist, String::formatted("{}.{}", schema_name, m_table_name));
}

Row row(table_def);
for (auto& column : m_column_names) {
if (!row.has(column)) {
return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::ColumnDoesNotExist, column);
}
}

for (auto& row_expr : m_chained_expressions) {
for (auto& column_def : table_def->columns()) {
if (!m_column_names.contains_slow(column_def.name())) {
row[column_def.name()] = column_def.default_value();
}
}
auto row_value = row_expr.evaluate(context);
VERIFY(row_value.type() == SQLType::Tuple);
auto values = row_value.to_vector().value();
for (auto ix = 0u; ix < values.size(); ix++) {
auto& column_name = m_column_names[ix];
row[column_name] = values[ix];
}
context.database->insert(row);
}
return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0);
}

}
Loading

0 comments on commit d074a60

Please sign in to comment.