Skip to content

Commit 783c45b

Browse files
jackkleemanalamb
andauthored
Add case expr simplifiers for literal comparisons (#17743)
* Add case expr simplifiers for literal comparisons * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Andrew Lamb <[email protected]> * Avoid expr clones --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent e566e97 commit 783c45b

File tree

2 files changed

+251
-2
lines changed

2 files changed

+251
-2
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1399,6 +1399,41 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
13991399
// Rules for Case
14001400
//
14011401

1402+
// Inline a comparison to a literal with the case statement into the `THEN` clauses.
1403+
// which can enable further simplifications
1404+
// CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
1405+
Expr::BinaryExpr(BinaryExpr {
1406+
left,
1407+
op: op @ (Eq | NotEq),
1408+
right,
1409+
}) if is_case_with_literal_outputs(&left) && is_lit(&right) => {
1410+
let case = into_case(*left)?;
1411+
Transformed::yes(Expr::Case(Case {
1412+
expr: None,
1413+
when_then_expr: case
1414+
.when_then_expr
1415+
.into_iter()
1416+
.map(|(when, then)| {
1417+
(
1418+
when,
1419+
Box::new(Expr::BinaryExpr(BinaryExpr {
1420+
left: then,
1421+
op,
1422+
right: right.clone(),
1423+
})),
1424+
)
1425+
})
1426+
.collect(),
1427+
else_expr: case.else_expr.map(|els| {
1428+
Box::new(Expr::BinaryExpr(BinaryExpr {
1429+
left: els,
1430+
op,
1431+
right,
1432+
}))
1433+
}),
1434+
}))
1435+
}
1436+
14021437
// CASE WHEN true THEN A ... END --> A
14031438
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
14041439
Expr::Case(Case {
@@ -1447,7 +1482,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14471482
when_then_expr,
14481483
else_expr,
14491484
}) if !when_then_expr.is_empty()
1450-
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
1485+
// The rewrite is O(n²) in general so limit to small number of when-thens that can be true
1486+
&& (when_then_expr.len() < 3 // small number of input whens
1487+
// or all thens are literal bools and a small number of them are true
1488+
|| (when_then_expr.iter().all(|(_, then)| is_bool_lit(then))
1489+
&& when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3))
14511490
&& info.is_boolean_type(&when_then_expr[0].1)? =>
14521491
{
14531492
// String disjunction of all the when predicates encountered so far. Not nullable.
@@ -1471,6 +1510,55 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14711510
// Do a first pass at simplification
14721511
out_expr.rewrite(self)?
14731512
}
1513+
// CASE
1514+
// WHEN X THEN true
1515+
// WHEN Y THEN true
1516+
// WHEN Z THEN false
1517+
// ...
1518+
// ELSE true
1519+
// END
1520+
//
1521+
// --->
1522+
//
1523+
// NOT(CASE
1524+
// WHEN X THEN false
1525+
// WHEN Y THEN false
1526+
// WHEN Z THEN true
1527+
// ...
1528+
// ELSE false
1529+
// END)
1530+
//
1531+
// Note: the rationale for this rewrite is that the case can then be further
1532+
// simplified into a small number of ANDs and ORs
1533+
Expr::Case(Case {
1534+
expr: None,
1535+
when_then_expr,
1536+
else_expr,
1537+
}) if !when_then_expr.is_empty()
1538+
&& when_then_expr
1539+
.iter()
1540+
.all(|(_, then)| is_bool_lit(then)) // all thens are literal bools
1541+
// This simplification is only helpful if we end up with a small number of true thens
1542+
&& when_then_expr
1543+
.iter()
1544+
.filter(|(_, then)| is_false(then))
1545+
.count()
1546+
< 3
1547+
&& else_expr.as_deref().is_none_or(is_bool_lit) =>
1548+
{
1549+
Transformed::yes(
1550+
Expr::Case(Case {
1551+
expr: None,
1552+
when_then_expr: when_then_expr
1553+
.into_iter()
1554+
.map(|(when, then)| (when, Box::new(Expr::Not(then))))
1555+
.collect(),
1556+
else_expr: else_expr
1557+
.map(|else_expr| Box::new(Expr::Not(else_expr))),
1558+
})
1559+
.not(),
1560+
)
1561+
}
14741562
Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
14751563
match udf.simplify(args, info)? {
14761564
ExprSimplifyResult::Original(args) => {
@@ -3465,6 +3553,142 @@ mod tests {
34653553
);
34663554
}
34673555

3556+
#[test]
3557+
fn simplify_literal_case_equality() {
3558+
// CASE WHEN c2 != false THEN "ok" ELSE "not_ok"
3559+
let simple_case = Expr::Case(Case::new(
3560+
None,
3561+
vec![(
3562+
Box::new(col("c2_non_null").not_eq(lit(false))),
3563+
Box::new(lit("ok")),
3564+
)],
3565+
Some(Box::new(lit("not_ok"))),
3566+
));
3567+
3568+
// CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok"
3569+
// -->
3570+
// CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok"
3571+
// -->
3572+
// CASE WHEN c2 != false THEN true ELSE false
3573+
// -->
3574+
// c2
3575+
assert_eq!(
3576+
simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)),
3577+
col("c2_non_null"),
3578+
);
3579+
3580+
// CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok"
3581+
// -->
3582+
// NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok")
3583+
// -->
3584+
// NOT(CASE WHEN c2 != false THEN true ELSE false)
3585+
// -->
3586+
// NOT(c2)
3587+
assert_eq!(
3588+
simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)),
3589+
not(col("c2_non_null")),
3590+
);
3591+
3592+
let complex_case = Expr::Case(Case::new(
3593+
None,
3594+
vec![
3595+
(
3596+
Box::new(col("c1").eq(lit("inboxed"))),
3597+
Box::new(lit("pending")),
3598+
),
3599+
(
3600+
Box::new(col("c1").eq(lit("scheduled"))),
3601+
Box::new(lit("pending")),
3602+
),
3603+
(
3604+
Box::new(col("c1").eq(lit("completed"))),
3605+
Box::new(lit("completed")),
3606+
),
3607+
(
3608+
Box::new(col("c1").eq(lit("paused"))),
3609+
Box::new(lit("paused")),
3610+
),
3611+
(Box::new(col("c2")), Box::new(lit("running"))),
3612+
(
3613+
Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))),
3614+
Box::new(lit("backing-off")),
3615+
),
3616+
],
3617+
Some(Box::new(lit("ready"))),
3618+
));
3619+
3620+
assert_eq!(
3621+
simplify(binary_expr(
3622+
complex_case.clone(),
3623+
Operator::Eq,
3624+
lit("completed"),
3625+
)),
3626+
not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and(
3627+
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3628+
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3629+
)
3630+
);
3631+
3632+
assert_eq!(
3633+
simplify(binary_expr(
3634+
complex_case.clone(),
3635+
Operator::NotEq,
3636+
lit("completed"),
3637+
)),
3638+
distinct_from(col("c1").eq(lit("completed")), lit(true))
3639+
.or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3640+
.or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))))
3641+
);
3642+
3643+
assert_eq!(
3644+
simplify(binary_expr(
3645+
complex_case.clone(),
3646+
Operator::Eq,
3647+
lit("running"),
3648+
)),
3649+
not_distinct_from(col("c2"), lit(true)).and(
3650+
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3651+
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3652+
.and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3653+
.and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3654+
)
3655+
);
3656+
3657+
assert_eq!(
3658+
simplify(binary_expr(
3659+
complex_case.clone(),
3660+
Operator::Eq,
3661+
lit("ready"),
3662+
)),
3663+
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3664+
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3665+
.and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
3666+
.and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
3667+
.and(distinct_from(col("c2"), lit(true)))
3668+
.and(distinct_from(
3669+
col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3670+
lit(true)
3671+
))
3672+
);
3673+
3674+
assert_eq!(
3675+
simplify(binary_expr(
3676+
complex_case.clone(),
3677+
Operator::NotEq,
3678+
lit("ready"),
3679+
)),
3680+
not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
3681+
.or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
3682+
.or(not_distinct_from(col("c1").eq(lit("completed")), lit(true)))
3683+
.or(not_distinct_from(col("c1").eq(lit("paused")), lit(true)))
3684+
.or(not_distinct_from(col("c2"), lit(true)))
3685+
.or(not_distinct_from(
3686+
col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
3687+
lit(true)
3688+
))
3689+
);
3690+
}
3691+
34683692
#[test]
34693693
fn simplify_expr_case_when_then_else() {
34703694
// CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true

datafusion/optimizer/src/simplify_expressions/utils.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use datafusion_common::{internal_err, Result, ScalarValue};
2222
use datafusion_expr::{
2323
expr::{Between, BinaryExpr, InList},
2424
expr_fn::{and, bitwise_and, bitwise_or, or},
25-
Expr, Like, Operator,
25+
Case, Expr, Like, Operator,
2626
};
2727

2828
pub static POWS_OF_TEN: [i128; 38] = [
@@ -265,6 +265,31 @@ pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
265265
}
266266
}
267267

268+
pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
269+
match expr {
270+
Expr::Case(Case {
271+
expr: None,
272+
when_then_expr,
273+
else_expr,
274+
}) => {
275+
when_then_expr.iter().all(|(_, then)| is_lit(then))
276+
&& else_expr.as_deref().is_none_or(is_lit)
277+
}
278+
_ => false,
279+
}
280+
}
281+
282+
pub fn into_case(expr: Expr) -> Result<Case> {
283+
match expr {
284+
Expr::Case(case) => Ok(case),
285+
_ => internal_err!("Expected case, got {expr:?}"),
286+
}
287+
}
288+
289+
pub fn is_lit(expr: &Expr) -> bool {
290+
matches!(expr, Expr::Literal(_, _))
291+
}
292+
268293
/// negate a Not clause
269294
/// input is the clause to be negated.(args of Not clause)
270295
/// For BinaryExpr, use the negation of op instead.

0 commit comments

Comments
 (0)