Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parser option for parsing SQL numeric literals as decimal #4102

Merged
merged 8 commits into from
Nov 15, 2022
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 122 additions & 19 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,24 @@ pub trait ContextProvider {
fn get_config_option(&self, variable: &str) -> Option<ScalarValue>;
}

/// SQL parser options.
///
/// Controls how the planner interprets SQL text while converting it to a
/// logical plan. `ParserOptions::default()` leaves every option disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParserOptions {
    /// When `true`, numeric literals that do not fit in an `i64` are planned
    /// as `Decimal128` literals instead of `f64` literals.
    ///
    /// The field is public so that callers outside this module can build a
    /// non-default `ParserOptions`.
    pub parse_float_as_decimal: bool,
}

/// SQL query planner.
///
/// Borrows a [`ContextProvider`] for the lifetime `'a` and carries the
/// [`ParserOptions`] it was constructed with.
pub struct SqlToRel<'a, S>
where
    S: ContextProvider,
{
    /// Context provider supplied when the planner was created.
    schema_provider: &'a S,
    /// Options that influence how SQL literals are parsed.
    options: ParserOptions,
}

fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
Expand Down Expand Up @@ -137,7 +152,15 @@ fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Create a new query planner that uses the default [`ParserOptions`].
///
/// Equivalent to calling [`SqlToRel::new_with_options`] with
/// `ParserOptions::default()`.
pub fn new(schema_provider: &'a S) -> Self {
    Self::new_with_options(schema_provider, ParserOptions::default())
}

/// Create a new query planner
pub fn new_with_options(schema_provider: &'a S, options: ParserOptions) -> Self {
SqlToRel {
schema_provider,
options,
}
}

/// Generate a logical plan from a DataFusion SQL statement
Expand Down Expand Up @@ -1699,7 +1722,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(|row| {
row.into_iter()
.map(|v| match v {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n),
SQLExpr::Value(
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
) => Ok(lit(s)),
Expand Down Expand Up @@ -1753,7 +1776,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<Expr> {
match sql {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
Expand Down Expand Up @@ -2668,6 +2691,51 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

/// Parse a numeric literal from SQL text and convert it to an `Expr::Literal`.
///
/// The literal is interpreted as follows, in order:
///
/// 1. Literals containing `E` (scientific notation) are rejected with
///    `DataFusionError::NotImplemented`.
/// 2. Literals that parse as `i64` become integer literals.
/// 3. If `parse_float_as_decimal` is disabled, the literal is parsed as `f64`.
/// 4. Otherwise the literal becomes a `Decimal128` whose precision and scale
///    are the smallest that represent the digits as written (leading zeroes
///    are ignored, trailing zeroes are kept). For example `10.01` becomes
///    `Decimal128(Some(1001), 4, 2)`.
///
/// # Errors
///
/// Returns a parser error when the literal cannot be parsed as the selected
/// type, or when it needs more than 38 digits of precision.
fn parse_sql_number(&self, n: &str) -> Result<Expr> {
    // Largest precision a `Decimal128` value can carry.
    const MAX_DECIMAL128_PRECISION: usize = 38;

    if n.contains('E') {
        // not implemented yet
        // https://github.com/apache/arrow-datafusion/issues/3448
        return Err(DataFusionError::NotImplemented(
            "sql numeric literals in scientific notation are not supported"
                .to_string(),
        ));
    }

    if let Ok(n) = n.parse::<i64>() {
        return Ok(lit(n));
    }

    if !self.options.parse_float_as_decimal {
        return n.parse::<f64>().map(lit).map_err(|_| {
            DataFusionError::from(ParserError(format!("Cannot parse {} as f64", n)))
        });
    }

    // remove leading zeroes
    let trimmed = n.trim_start_matches('0');

    // `digits` is the text handed to the i128 parser; `precision` and `scale`
    // describe the decimal type that holds exactly those digits.
    let (digits, precision, scale) = if trimmed == "." && n.len() > 1 {
        // Literals such as "0." or "000." have no digits left once the
        // leading zeroes are trimmed; they denote plain zero.
        ("0".to_string(), 1, 0)
    } else if let Some(i) = trimmed.find('.') {
        (
            trimmed.replace('.', ""),
            trimmed.len() - 1,
            trimmed.len() - i - 1,
        )
    } else {
        // No decimal point: the integer did not fit in an i64. Use the
        // number of significant digits as the precision, matching how the
        // fractional case derives its precision.
        (n.to_string(), trimmed.len(), 0)
    };

    let value = digits.parse::<i128>().map_err(|_| {
        DataFusionError::from(ParserError(format!(
            "Cannot parse {} as i128 when building decimal",
            digits
        )))
    })?;

    // Reject literals that need more digits than Decimal128 supports. This
    // also guarantees the `as u8` conversions below cannot truncate, because
    // `scale <= precision <= 38`.
    if precision > MAX_DECIMAL128_PRECISION {
        return Err(DataFusionError::from(ParserError(format!(
            "Cannot parse {} as i128 when building decimal: precision overflow",
            digits
        ))));
    }

    Ok(Expr::Literal(ScalarValue::Decimal128(
        Some(value),
        precision as u8,
        scale as u8,
    )))
}

fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
match sql_type {
SQLDataType::Array(inner_sql_type) => {
Expand Down Expand Up @@ -2919,28 +2987,40 @@ fn extract_possible_join_keys(
}
}

// Parse number in sql string, convert to Expr::Literal
fn parse_sql_number(n: &str) -> Result<Expr> {
// parse first as i64
n.parse::<i64>()
.map(lit)
// if parsing as i64 fails try f64
.or_else(|_| n.parse::<f64>().map(lit))
.map_err(|_| {
DataFusionError::from(ParserError(format!(
"Cannot parse {} as i64 or f64",
n
)))
})
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion_common::assert_contains;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
use std::any::Any;

#[test]
fn parse_decimals() {
    // Each case pairs a SQL numeric literal with the debug rendering of the
    // literal expected in the plan when floats are parsed as decimals.
    let cases = [
        ("1", "Int64(1)"),
        ("001", "Int64(1)"),
        ("0.1", "Decimal128(Some(1),1,1)"),
        ("0.01", "Decimal128(Some(1),2,2)"),
        ("1.0", "Decimal128(Some(10),2,1)"),
        ("10.01", "Decimal128(Some(1001),4,2)"),
        (
            "10000000000000000000.00",
            "Decimal128(Some(1000000000000000000000),22,2)",
        ),
    ];

    for (literal, rendered) in cases {
        let options = ParserOptions {
            parse_float_as_decimal: true,
        };
        let query = format!("SELECT {}", literal);
        let plan_text = format!("Projection: {}\n EmptyRelation", rendered);
        quick_test_with_options(&query, &plan_text, options);
    }
}

#[test]
fn select_no_relation() {
quick_test(
Expand Down Expand Up @@ -4913,8 +4993,15 @@ mod tests {
}

/// Plan `sql` using the default parser options.
fn logical_plan(sql: &str) -> Result<LogicalPlan> {
    let default_options = ParserOptions::default();
    logical_plan_with_options(sql, default_options)
}

/// Plan `sql` with the generic SQL dialect and the supplied parser options.
fn logical_plan_with_options(
    sql: &str,
    options: ParserOptions,
) -> Result<LogicalPlan> {
    let dialect = &GenericDialect {};
    logical_plan_with_dialect_and_options(sql, dialect, options)
}

fn logical_plan_with_dialect(
Expand All @@ -4927,12 +5014,28 @@ mod tests {
planner.statement_to_plan(ast.pop_front().unwrap())
}

/// Parse `sql` with `dialect` and plan its first statement using `options`.
///
/// Panics if parsing succeeds but yields no statement.
fn logical_plan_with_dialect_and_options(
    sql: &str,
    dialect: &dyn Dialect,
    options: ParserOptions,
) -> Result<LogicalPlan> {
    let provider = MockContextProvider {};
    let planner = SqlToRel::new_with_options(&provider, options);
    let mut statements = DFParser::parse_sql_with_dialect(sql, dialect)?;
    let first_statement = statements.pop_front().unwrap();
    planner.statement_to_plan(first_statement)
}

/// Plan `sql` with default options and assert that the plan's debug
/// formatting equals `expected`.
fn quick_test(sql: &str, expected: &str) {
    let rendered = format!("{:?}", logical_plan(sql).unwrap());
    assert_eq!(rendered, expected);
}

/// Plan `sql` with the given parser options and assert that the plan's debug
/// formatting equals `expected`.
fn quick_test_with_options(sql: &str, expected: &str, options: ParserOptions) {
    let rendered = format!("{:?}", logical_plan_with_options(sql, options).unwrap());
    assert_eq!(rendered, expected);
}

/// Field-less `ContextProvider` implementation that the test helpers in this
/// module hand to the planner.
struct MockContextProvider {}

impl ContextProvider for MockContextProvider {
Expand Down