Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 224 additions & 1 deletion datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,39 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
// Rules for Case
//

// CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
Expr::BinaryExpr(BinaryExpr {
left,
op: op @ (Eq | NotEq),
right,
}) if is_case_with_literal_outputs(&left) && is_lit(&right) => {
let case = as_case(&left)?;
Transformed::yes(Expr::Case(Case {
expr: None,
when_then_expr: case
.when_then_expr
.iter()
.map(|(when, then)| {
(
when.clone(),
Box::new(Expr::BinaryExpr(BinaryExpr {
left: then.clone(),
op,
right: right.clone(),
})),
)
})
.collect(),
else_expr: case.else_expr.as_ref().map(|els| {
Box::new(Expr::BinaryExpr(BinaryExpr {
left: els.clone(),
op,
right,
}))
}),
}))
}

// CASE WHEN true THEN A ... END --> A
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
Expr::Case(Case {
Expand Down Expand Up @@ -1447,7 +1480,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
// The rewrite is O(n²) in general so limit to small number of when-thens that can be true
&& (when_then_expr.len() < 3 // small number of input whens
// or all thens are literal bools and a small number of them are true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would be valuable to do this rewrite regardless of the terms as long as all the arguments are boolean literals (why still limit to three true cases?)

Copy link
Contributor Author

@jackkleeman jackkleeman Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the comment about it being O(n^2) is very legit, and this is also true in the literal bool case; the thens are literal bools but we dont know anything about the whens and whether they might be reducible. Indeed, in my target use case, the thens are not reducible at all and we definitely end up with On^2 boolean expressions. The intent of this check is to maintain the '<3 when then' limitation, but just be cleverer about how we count when thens, ignoring those where the then is a literal false

We could, if we wanted, split this out into a separate rewrite that removes when_thens that have false thens. However, this is not as trivial as the when=false trim PR, and I actually think it could lead to longer case statements - you would have to fold in the when of the trimmed branch into all the following whens, as they would not have matched if the trimmed when had matched. Else is also complex in this case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense -- let's keep the check then

|| (when_then_expr.iter().all(|(_, then)| is_bool_lit(then))
&& when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3))
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// String disjunction of all the when predicates encountered so far. Not nullable.
Expand All @@ -1471,6 +1508,56 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
// Do a first pass at simplification
out_expr.rewrite(self)?
}
// CASE
// WHEN X THEN true
// WHEN Y THEN true
// WHEN Z THEN false
// ...
// ELSE true
// END
//
// --->
//
// NOT(CASE
// WHEN X THEN false
// WHEN Y THEN false
// WHEN Z THEN true
// ...
// ELSE false
// END)
//
// Note: the rationale for this rewrite is that the case can then be further
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why rewrite adding a NOT make it simpler / easier to simplify?

In recent days we have done several other CASE simplifications related to constants such as

We have one more open (perhaps you can help nudge it along) which I think would result in simplifying the expressions

Once #17628 is complete, do you think this rewrite would still add value?

Copy link
Contributor Author

@jackkleeman jackkleeman Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets say we have:

CASE
  WHEN col_1 THEN true
  WHEN col_2 THEN false
...
  WHEN col_99 THEN true
  ELSE true
END

none of the WHEN can be trimmed, so I am not sure those open prs will help. the naive disaggregation, if we removed the <3 limit will look like:

col1 OR (not col_1 and col_2) or (not col_1 and not col_2 and col_3)... 

which is a very long expression indeed (O(n^2) in the number of true branches). instead, if we invert, we can write an expression like this:

not(not col_1 and col_2) -> col_1 or not(col_2)

which is O(n^2) in the number of false branches (a much smaller number)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great example, maybe we can add it to the comments (as a follow on PR)

// simplified into a small number of ANDs and ORs
Expr::Case(Case {
expr: None,
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr
.iter()
.all(|(_, then)| is_bool_lit(then)) // all thens are literal bools
// This simplification is only helpful if we end up with a small number of true thens
&& when_then_expr
.iter()
.filter(|(_, then)| is_false(then))
.count()
< 3
&& else_expr.as_deref().is_none_or(is_bool_lit) =>
{
Transformed::yes(
Expr::Case(Case {
expr: None,
when_then_expr: when_then_expr
.iter()
.map(|(when, then)| {
(when.clone(), Box::new(then.clone().not()))
})
.collect(),
else_expr: else_expr.map(|else_expr| Box::new(else_expr.not())),
})
.not(),
)
}
Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
match udf.simplify(args, info)? {
ExprSimplifyResult::Original(args) => {
Expand Down Expand Up @@ -3465,6 +3552,142 @@ mod tests {
);
}

#[test]
fn simplify_literal_case_equality() {
// CASE WHEN c2 != false THEN "ok" ELSE "not_ok"
let simple_case = Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok")),
)],
Some(Box::new(lit("not_ok"))),
));

// CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok"
// -->
// CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok"
// -->
// CASE WHEN c2 != false THEN true ELSE false
// -->
// c2
assert_eq!(
simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)),
col("c2_non_null"),
);

// CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok"
// -->
// NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok")
// -->
// NOT(CASE WHEN c2 != false THEN true ELSE false)
// -->
// NOT(c2)
assert_eq!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very nice

simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)),
not(col("c2_non_null")),
);

let complex_case = Expr::Case(Case::new(
None,
vec![
(
Box::new(col("c1").eq(lit("inboxed"))),
Box::new(lit("pending")),
),
(
Box::new(col("c1").eq(lit("scheduled"))),
Box::new(lit("pending")),
),
(
Box::new(col("c1").eq(lit("completed"))),
Box::new(lit("completed")),
),
(
Box::new(col("c1").eq(lit("paused"))),
Box::new(lit("paused")),
),
(Box::new(col("c2")), Box::new(lit("running"))),
(
Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))),
Box::new(lit("backing-off")),
),
],
Some(Box::new(lit("ready"))),
));

assert_eq!(
simplify(binary_expr(
complex_case.clone(),
Operator::Eq,
lit("completed"),
)),
not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and(
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
)
);

assert_eq!(
simplify(binary_expr(
complex_case.clone(),
Operator::NotEq,
lit("completed"),
)),
distinct_from(col("c1").eq(lit("completed")), lit(true))
.or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
.or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))))
);

assert_eq!(
simplify(binary_expr(
complex_case.clone(),
Operator::Eq,
lit("running"),
)),
not_distinct_from(col("c2"), lit(true)).and(
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
.and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
.and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
)
);

assert_eq!(
simplify(binary_expr(
complex_case.clone(),
Operator::Eq,
lit("ready"),
)),
distinct_from(col("c1").eq(lit("inboxed")), lit(true))
.and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
.and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
.and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
.and(distinct_from(col("c2"), lit(true)))
.and(distinct_from(
col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
lit(true)
))
);

assert_eq!(
simplify(binary_expr(
complex_case.clone(),
Operator::NotEq,
lit("ready"),
)),
not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
.or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
.or(not_distinct_from(col("c1").eq(lit("completed")), lit(true)))
.or(not_distinct_from(col("c1").eq(lit("paused")), lit(true)))
.or(not_distinct_from(col("c2"), lit(true)))
.or(not_distinct_from(
col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
lit(true)
))
);
}

#[test]
fn simplify_expr_case_when_then_else() {
// CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
Expand Down
27 changes: 26 additions & 1 deletion datafusion/optimizer/src/simplify_expressions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{
expr::{Between, BinaryExpr, InList},
expr_fn::{and, bitwise_and, bitwise_or, or},
Expr, Like, Operator,
Case, Expr, Like, Operator,
};

pub static POWS_OF_TEN: [i128; 38] = [
Expand Down Expand Up @@ -265,6 +265,31 @@ pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
}
}

pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
match expr {
Expr::Case(Case {
expr: None,
when_then_expr,
else_expr,
}) => {
when_then_expr.iter().all(|(_, then)| is_lit(then))
&& else_expr.as_deref().is_none_or(is_lit)
}
_ => false,
}
}

pub fn as_case(expr: &Expr) -> Result<&Case> {
match expr {
Expr::Case(case) => Ok(case),
_ => internal_err!("Expected case, got {expr:?}"),
}
}

pub fn is_lit(expr: &Expr) -> bool {
matches!(expr, Expr::Literal(_, _))
}

/// negate a Not clause
/// input is the clause to be negated.(args of Not clause)
/// For BinaryExpr, use the negation of op instead.
Expand Down