Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve formatting of binary expressions #3884

Merged
merged 4 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/expected-plans/q19.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
Projection: lineitem.l_extendedprice, lineitem.l_discount
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
Inner Join: lineitem.l_partkey = part.p_partkey
Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 More clearly!

TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15) AND part.p_size >= Int32(1)
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
11 changes: 8 additions & 3 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,7 @@ mod tests {
let expr = col("c1")
.lt(lit(1))
.and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3))));
let expected_expr = "c1_min < Int32(1) AND c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max";
let expected_expr = "c1_min < Int32(1) AND (c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut required_columns)?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down Expand Up @@ -1561,7 +1561,9 @@ mod tests {
list: vec![lit(1), lit(2), lit(3)],
negated: true,
};
let expected_expr = "c1_min != Int32(1) OR Int32(1) != c1_max AND c1_min != Int32(2) OR Int32(2) != c1_max AND c1_min != Int32(3) OR Int32(3) != c1_max";
let expected_expr = "(c1_min != Int32(1) OR Int32(1) != c1_max) \
AND (c1_min != Int32(2) OR Int32(2) != c1_max) \
AND (c1_min != Int32(3) OR Int32(3) != c1_max)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down Expand Up @@ -1633,7 +1635,10 @@ mod tests {
],
negated: true,
};
let expected_expr = "CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64)";
let expected_expr =
"(CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64)) \
AND (CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64)) \
AND (CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64))";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
Expand Down
58 changes: 54 additions & 4 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_common::Result;
use datafusion_common::{plan_err, Column};
use datafusion_common::{DataFusionError, ScalarValue};
use std::fmt;
use std::fmt::Write;
use std::fmt::{Display, Formatter, Write};
use std::hash::{BuildHasher, Hash, Hasher};
use std::ops::Not;
use std::sync::Arc;
Expand Down Expand Up @@ -265,6 +265,58 @@ impl BinaryExpr {
pub fn new(left: Box<Expr>, op: Operator, right: Box<Expr>) -> Self {
Self { left, op, right }
}

/// Get the operator precedence
/// use https://www.postgresql.org/docs/7.0/operators.htm#AEN2026 as a reference
pub fn precedence(&self) -> u8 {
match self.op {
Operator::Or => 5,
Operator::And => 10,
Operator::Like | Operator::NotLike => 19,
Operator::NotEq
| Operator::Eq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq => 20,
Operator::Plus | Operator::Minus => 30,
Operator::Multiply | Operator::Divide | Operator::Modulo => 40,
_ => 0,
}
}
}

impl Display for BinaryExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
// Put parentheses around child binary expressions so that we can see the difference
// between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
// based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
// equivalent and the parentheses are not necessary.

fn write_child(
f: &mut Formatter<'_>,
expr: &Expr,
precedence: u8,
) -> fmt::Result {
match expr {
Expr::BinaryExpr(child) => {
let p = child.precedence();
if p == 0 || p < precedence {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if p == 0 || p < precedence {
if p <= precedence {

Is the same?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point

write!(f, "({})", child)?;
} else {
write!(f, "{}", child)?;
}
}
_ => write!(f, "{}", expr)?,
}
Ok(())
}

let precedence = self.precedence();
write_child(f, self.left.as_ref(), precedence)?;
write!(f, " {} ", self.op)?;
write_child(f, self.right.as_ref(), precedence)
}
}

/// CASE expression
Expand Down Expand Up @@ -717,9 +769,7 @@ impl fmt::Debug for Expr {
negated: false,
} => write!(f, "{:?} IN ({:?})", expr, subquery),
Expr::ScalarSubquery(subquery) => write!(f, "({:?})", subquery),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{:?} {} {:?}", left, op, right)
}
Expr::BinaryExpr(expr) => write!(f, "{}", expr),
Expr::Sort {
expr,
asc,
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ mod test {
)?;

let expected = vec![
(9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
(3, "a + Int32(1)Int32(1)a"),
Expand Down Expand Up @@ -671,8 +671,8 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * Int32(1) + test.c)]]\
\n Projection: test.a * Int32(1) - test.b AS test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Int32(1) - test.b)Int32(1)

this is a little confused, should we change it to 🤔

(Int32(1) - test.bInt32(1))

Maybe we can file another issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is related to #3786 ?

\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ mod tests {
let expected = "\
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: test.a * Int32(2) + test.c * Int32(3) = Int64(1)\
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/reduce_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ mod tests {
.build()?;

let expected = vec![
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -936,7 +936,7 @@ mod tests {
.build()?;

let expected = vec![
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/subquery_filter_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ mod tests {
.build()?;

let expected = "Projection: test.b [b:UInt32]\
\n Filter: test.a = UInt32(1) OR test.b IN (<subquery>) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Filter: (test.a = UInt32(1) OR test.b IN (<subquery>)) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [c:UInt32]\
\n Projection: sq1.c [c:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3413,7 +3413,7 @@ mod tests {
#[test]
fn select_binary_expr_nested() {
let sql = "SELECT (age + salary)/2 from person";
let expected = "Projection: person.age + person.salary / Int64(2)\
let expected = "Projection: (person.age + person.salary) / Int64(2)\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down Expand Up @@ -3848,7 +3848,7 @@ mod tests {
fn select_where_nullif_division() {
let sql = "SELECT c3/(c4+c5) \
FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1";
let expected = "Projection: aggregate_test_100.c3 / aggregate_test_100.c4 + aggregate_test_100.c5\
let expected = "Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5)\
\n Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1)\
\n TableScan: aggregate_test_100";
quick_test(sql, expected);
Expand Down