Skip to content

Commit

Permalink
sql: move ResultSet to engine::StatementResult and simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
erikgrinaker committed Jun 16, 2024
1 parent 80d7fe5 commit 3457a9f
Show file tree
Hide file tree
Showing 13 changed files with 326 additions and 245 deletions.
30 changes: 14 additions & 16 deletions src/bin/toysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

#![warn(clippy::all)]

use itertools::Itertools as _;
use rustyline::history::DefaultHistory;
use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{error::ReadlineError, Editor, Modifiers};
use rustyline_derive::{Completer, Helper, Highlighter, Hinter};
use toydb::errinput;
use toydb::error::{Error, Result};
use toydb::sql::execution::ResultSet;
use toydb::sql::engine::StatementResult;
use toydb::sql::parser::{Lexer, Token};
use toydb::Client;

Expand Down Expand Up @@ -181,22 +182,22 @@ Storage: {keys} keys, {logical_size} MB logical, {nodes}x {disk_size} MB disk,
/// Runs a query and displays the results
fn execute_query(&mut self, query: &str) -> Result<()> {
match self.client.execute(query)? {
ResultSet::Begin { version, read_only } => match read_only {
StatementResult::Begin { version, read_only } => match read_only {
false => println!("Began transaction at new version {}", version),
true => println!("Began read-only transaction at version {}", version),
},
ResultSet::Commit { version: id } => println!("Committed transaction {}", id),
ResultSet::Rollback { version: id } => println!("Rolled back transaction {}", id),
ResultSet::Create { count } => println!("Created {} rows", count),
ResultSet::Delete { count } => println!("Deleted {} rows", count),
ResultSet::Update { count } => println!("Updated {} rows", count),
ResultSet::CreateTable { name } => println!("Created table {}", name),
ResultSet::DropTable { name, existed } => match existed {
StatementResult::Commit { version: id } => println!("Committed transaction {}", id),
StatementResult::Rollback { version: id } => println!("Rolled back transaction {}", id),
StatementResult::Create { count } => println!("Created {} rows", count),
StatementResult::Delete { count } => println!("Deleted {} rows", count),
StatementResult::Update { count } => println!("Updated {} rows", count),
StatementResult::CreateTable { name } => println!("Created table {}", name),
StatementResult::DropTable { name, existed } => match existed {
true => println!("Dropped table {}", name),
false => println!("Table {} did not exit", name),
},
ResultSet::Explain(plan) => println!("{}", plan),
ResultSet::Query { columns, mut rows } => {
StatementResult::Explain(plan) => println!("{}", plan),
StatementResult::Query { columns, rows } => {
if self.show_headers {
println!(
"{}",
Expand All @@ -207,11 +208,8 @@ Storage: {keys} keys, {logical_size} MB logical, {nodes}x {disk_size} MB disk,
.join("|")
);
}
while let Some(row) = rows.next().transpose()? {
println!(
"{}",
row.into_iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join("|")
);
for row in rows {
println!("{}", row.into_iter().map(|v| format!("{}", v)).join("|"));
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/bin/workload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashSet;
use std::io::Write as _;
use std::time::Duration;
use toydb::error::Result;
use toydb::{Client, ResultSet};
use toydb::{Client, StatementResult};

fn main() -> Result<()> {
let Command { runner, subcommand } = Command::parse();
Expand Down Expand Up @@ -337,7 +337,7 @@ impl Workload for Write {
r#"INSERT INTO "write" (id, value) VALUES {}"#,
item.iter().map(|(id, value)| format!("({}, '{}')", id, value)).join(", ")
);
if let ResultSet::Create { count } = client.execute(&query)? {
if let StatementResult::Create { count } = client.execute(&query)? {
assert_eq!(count as usize, batch_size, "Unexpected row count");
} else {
panic!("Unexpected result")
Expand Down
26 changes: 8 additions & 18 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::encoding::Value as _;
use crate::errdata;
use crate::error::{Error, Result};
use crate::server::{Request, Response, Status};
use crate::sql::execution::ResultSet;
use crate::sql::engine::StatementResult;
use crate::sql::types::schema::Table;

use rand::Rng;
Expand Down Expand Up @@ -33,27 +33,17 @@ impl Client {
}

/// Executes a query
pub fn execute(&mut self, query: &str) -> Result<ResultSet> {
let mut resultset = match self.call(Request::Execute(query.into()))? {
pub fn execute(&mut self, query: &str) -> Result<StatementResult> {
let resultset = match self.call(Request::Execute(query.into()))? {
Response::Execute(rs) => rs,
response => return errdata!("unexpected response {response:?}"),
};
if let ResultSet::Query { columns, .. } = resultset {
// FIXME We buffer rows for now to avoid lifetime hassles
let mut rows = Vec::new();
loop {
match Result::<Response>::decode_from(&mut self.reader)?? {
Response::Row(Some(row)) => rows.push(row),
Response::Row(None) => break,
response => return errdata!("unexpected response {response:?}"),
}
}
resultset = ResultSet::Query { columns, rows: Box::new(rows.into_iter().map(Ok)) }
};
match &resultset {
ResultSet::Begin { version, read_only } => self.txn = Some((*version, *read_only)),
ResultSet::Commit { .. } => self.txn = None,
ResultSet::Rollback { .. } => self.txn = None,
StatementResult::Begin { version, read_only } => {
self.txn = Some((*version, *read_only))
}
StatementResult::Commit { .. } => self.txn = None,
StatementResult::Rollback { .. } => self.txn = None,
_ => {}
}
Ok(resultset)
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ pub mod storage;

pub use client::Client;
pub use server::Server;
pub use sql::execution::ResultSet;
pub use sql::engine::StatementResult;
32 changes: 3 additions & 29 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use crate::encoding::{self, Value as _};
use crate::error::{Error, Result};
use crate::raft;
use crate::sql;
use crate::sql::engine::Engine as _;
use crate::sql::execution::ResultSet;
use crate::sql::engine::{Engine as _, StatementResult};
use crate::sql::types::schema::{Catalog as _, Table};
use crate::sql::types::Row;
use crate::storage;
Expand Down Expand Up @@ -288,7 +287,7 @@ impl Server {
while let Some(request) = Request::maybe_decode_from(&mut reader)? {
// Execute request.
debug!("Received request {request:?}");
let mut response = match request {
let response = match request {
Request::Execute(query) => session.execute(&query).map(Response::Execute),
Request::GetTable(table) => session
.with_txn_read_only(|txn| txn.must_read_table(&table))
Expand All @@ -304,32 +303,7 @@ impl Server {

// Process response.
debug!("Returning response {response:?}");
let mut rows: Box<dyn Iterator<Item = Result<Response>> + Send> =
Box::new(std::iter::empty());
if let Ok(Response::Execute(ResultSet::Query { rows: ref mut resultrows, .. })) =
&mut response
{
// TODO: don't stream results, for simplicity.
rows = Box::new(
std::mem::replace(resultrows, Box::new(std::iter::empty()))
.map(|result| result.map(|row| Response::Row(Some(row))))
.chain(std::iter::once(Ok(Response::Row(None))))
.scan(false, |err_sent, response| match (&err_sent, &response) {
(true, _) => None,
(_, Err(error)) => {
*err_sent = true;
Some(Err(error.clone()))
}
_ => Some(response),
})
.fuse(),
);
}

response.encode_into(&mut writer)?;
for row in rows {
row.encode_into(&mut writer)?;
}
writer.flush()?;
}
Ok(())
Expand All @@ -354,7 +328,7 @@ impl encoding::Value for Request {}
/// A SQL server response.
#[derive(Debug, Serialize, Deserialize)]
pub enum Response {
Execute(ResultSet),
Execute(StatementResult),
Row(Option<Row>),
GetTable(Table),
ListTables(Vec<String>),
Expand Down
104 changes: 91 additions & 13 deletions src/sql/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ mod kv;
pub mod raft;
pub use kv::KV;
pub use raft::{Raft, Status};
use serde::{Deserialize, Serialize};

use super::execution::ResultSet;
use super::execution::ExecutionResult;
use super::parser::{ast, Parser};
use super::plan::Plan;
use super::types::schema::Catalog;
use super::types::{Expression, Row, Value};
use crate::errinput;
use crate::error::Result;
use super::types::{Columns, Expression, Row, Rows, Value};
use crate::error::{Error, Result};
use crate::{errdata, errinput};

use std::collections::HashSet;

Expand Down Expand Up @@ -72,7 +73,7 @@ pub struct Session<E: Engine + 'static> {

impl<E: Engine + 'static> Session<E> {
/// Executes a query, managing transaction status for the session
pub fn execute(&mut self, query: &str) -> Result<ResultSet> {
pub fn execute(&mut self, query: &str) -> Result<StatementResult> {
// FIXME We should match on self.txn as well, but get this error:
// error[E0009]: cannot bind by-move and by-ref in the same pattern
// ...which seems like an arbitrary compiler limitation
Expand All @@ -82,13 +83,13 @@ impl<E: Engine + 'static> Session<E> {
}
ast::Statement::Begin { read_only: true, as_of: None } => {
let txn = self.engine.begin_read_only()?;
let result = ResultSet::Begin { version: txn.version(), read_only: true };
let result = StatementResult::Begin { version: txn.version(), read_only: true };
self.txn = Some(txn);
Ok(result)
}
ast::Statement::Begin { read_only: true, as_of: Some(version) } => {
let txn = self.engine.begin_as_of(version)?;
let result = ResultSet::Begin { version, read_only: true };
let result = StatementResult::Begin { version, read_only: true };
self.txn = Some(txn);
Ok(result)
}
Expand All @@ -97,7 +98,7 @@ impl<E: Engine + 'static> Session<E> {
}
ast::Statement::Begin { read_only: false, as_of: None } => {
let txn = self.engine.begin()?;
let result = ResultSet::Begin { version: txn.version(), read_only: false };
let result = StatementResult::Begin { version: txn.version(), read_only: false };
self.txn = Some(txn);
Ok(result)
}
Expand All @@ -108,17 +109,17 @@ impl<E: Engine + 'static> Session<E> {
let txn = self.txn.take().unwrap();
let version = txn.version();
txn.commit()?;
Ok(ResultSet::Commit { version })
Ok(StatementResult::Commit { version })
}
ast::Statement::Rollback => {
let txn = self.txn.take().unwrap();
let version = txn.version();
txn.rollback()?;
Ok(ResultSet::Rollback { version })
Ok(StatementResult::Rollback { version })
}
// TODO: this needs testing.
ast::Statement::Explain(statement) => self.with_txn_read_only(|txn| {
Ok(ResultSet::Explain(Plan::build(*statement, txn)?.optimize(txn)?))
Ok(StatementResult::Explain(Plan::build(*statement, txn)?.optimize(txn)?))
}),
statement if self.txn.is_some() => {
Self::execute_with(statement, self.txn.as_mut().unwrap())
Expand Down Expand Up @@ -146,8 +147,11 @@ impl<E: Engine + 'static> Session<E> {
/// or a temporary read-only or read/write transaction.
///
/// TODO: reconsider this.
fn execute_with(statement: ast::Statement, txn: &mut E::Transaction) -> Result<ResultSet> {
Ok(Plan::build(statement, txn)?.optimize(txn)?.execute(txn)?.into())
fn execute_with(
statement: ast::Statement,
txn: &mut E::Transaction,
) -> Result<StatementResult> {
Plan::build(statement, txn)?.optimize(txn)?.execute(txn)?.try_into()
}

/// Runs a read-only closure in the session's transaction, or a new
Expand Down Expand Up @@ -185,3 +189,77 @@ pub type Scan = Box<dyn DoubleEndedIterator<Item = Result<Row>> + Send>;

/// An index scan iterator
pub type IndexScan = Box<dyn DoubleEndedIterator<Item = Result<(Value, HashSet<Value>)>> + Send>;

/// A session statement result. This is also sent across the wire to SQL
/// clients.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub enum StatementResult {
// Transaction started
Begin { version: u64, read_only: bool },
// Transaction committed
Commit { version: u64 },
// Transaction rolled back
Rollback { version: u64 },
// Rows created
Create { count: u64 },
// Rows deleted
Delete { count: u64 },
// Rows updated
Update { count: u64 },
// Table created
CreateTable { name: String },
// Table dropped
DropTable { name: String, existed: bool },
// Query result.
//
// For simplicity, buffer and send the entire result as a vector instead of
// streaming it to the client. Streaming reads haven't been implemented from
// Raft either.
Query { columns: Columns, rows: Vec<Row> },
// Explain result
Explain(Plan),
}

impl StatementResult {
/// Converts the ResultSet into a row, or errors if not a query result with rows.
pub fn into_row(self) -> Result<Row> {
self.into_rows()?.next().transpose()?.ok_or(errdata!("no rows returned"))
}

/// Converts the ResultSet into a row iterator, or errors if not a query
/// result with rows.
pub fn into_rows(self) -> Result<Rows> {
if let StatementResult::Query { rows, .. } = self {
Ok(Box::new(rows.into_iter().map(Ok)))
} else {
errdata!("not a query result: {self:?}")
}
}

/// Converts the ResultSet into a value, if possible.
/// TODO: use TryFrom for this, also to primitive types via Value as TryFrom.
pub fn into_value(self) -> Result<Value> {
self.into_row()?.into_iter().next().ok_or(errdata!("no value returned"))
}
}

// TODO: remove or revisit this.
impl TryFrom<ExecutionResult> for StatementResult {
type Error = Error;

fn try_from(result: ExecutionResult) -> Result<Self> {
Ok(match result {
ExecutionResult::CreateTable { name } => StatementResult::CreateTable { name },
ExecutionResult::DropTable { name, existed } => {
StatementResult::DropTable { name, existed }
}
ExecutionResult::Delete { count } => StatementResult::Delete { count },
ExecutionResult::Insert { count } => StatementResult::Create { count },
ExecutionResult::Select { iter } => StatementResult::Query {
columns: iter.columns,
rows: iter.rows.collect::<Result<_>>()?,
},
ExecutionResult::Update { count } => StatementResult::Update { count },
})
}
}
Loading

0 comments on commit 3457a9f

Please sign in to comment.