@@ -1399,6 +1399,41 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
13991399 // Rules for Case
14001400 //
14011401
1402+ // Inline a comparison to a literal with the case statement into the `THEN` (and `ELSE`) clauses,
1403+ // which can enable further simplifications.
1404+ // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
1405+ Expr :: BinaryExpr ( BinaryExpr {
1406+ left,
1407+ op : op @ ( Eq | NotEq ) ,
1408+ right,
1409+ } ) if is_case_with_literal_outputs ( & left) && is_lit ( & right) => {
1410+ let case = into_case ( * left) ?;
1411+ Transformed :: yes ( Expr :: Case ( Case {
1412+ expr : None ,
1413+ when_then_expr : case
1414+ . when_then_expr
1415+ . into_iter ( )
1416+ . map ( |( when, then) | {
1417+ (
1418+ when,
1419+ Box :: new ( Expr :: BinaryExpr ( BinaryExpr {
1420+ left : then,
1421+ op,
1422+ right : right. clone ( ) ,
1423+ } ) ) ,
1424+ )
1425+ } )
1426+ . collect ( ) ,
1427+ else_expr : case. else_expr . map ( |els| {
1428+ Box :: new ( Expr :: BinaryExpr ( BinaryExpr {
1429+ left : els,
1430+ op,
1431+ right,
1432+ } ) )
1433+ } ) ,
1434+ } ) )
1435+ }
1436+
14021437 // CASE WHEN true THEN A ... END --> A
14031438 // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
14041439 Expr :: Case ( Case {
@@ -1447,7 +1482,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14471482 when_then_expr,
14481483 else_expr,
14491484 } ) if !when_then_expr. is_empty ( )
1450- && when_then_expr. len ( ) < 3 // The rewrite is O(n²) so limit to small number
1485+ // The rewrite is O(n²) in general so limit to small number of when-thens that can be true
1486+ && ( when_then_expr. len ( ) < 3 // small number of input whens
1487+ // or all thens are literal bools and a small number of them are true
1488+ || ( when_then_expr. iter ( ) . all ( |( _, then) | is_bool_lit ( then) )
1489+ && when_then_expr. iter ( ) . filter ( |( _, then) | is_true ( then) ) . count ( ) < 3 ) )
14511490 && info. is_boolean_type ( & when_then_expr[ 0 ] . 1 ) ? =>
14521491 {
14531492 // String disjunction of all the when predicates encountered so far. Not nullable.
@@ -1471,6 +1510,55 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14711510 // Do a first pass at simplification
14721511 out_expr. rewrite ( self ) ?
14731512 }
1513+ // CASE
1514+ // WHEN X THEN true
1515+ // WHEN Y THEN true
1516+ // WHEN Z THEN false
1517+ // ...
1518+ // ELSE true
1519+ // END
1520+ //
1521+ // --->
1522+ //
1523+ // NOT(CASE
1524+ // WHEN X THEN false
1525+ // WHEN Y THEN false
1526+ // WHEN Z THEN true
1527+ // ...
1528+ // ELSE false
1529+ // END)
1530+ //
1531+ // Note: the rationale for this rewrite is that the case can then be further
1532+ // simplified into a small number of ANDs and ORs
1533+ Expr :: Case ( Case {
1534+ expr : None ,
1535+ when_then_expr,
1536+ else_expr,
1537+ } ) if !when_then_expr. is_empty ( )
1538+ && when_then_expr
1539+ . iter ( )
1540+ . all ( |( _, then) | is_bool_lit ( then) ) // all thens are literal bools
1541+ // This simplification is only helpful if, after negation, we end up with a small number of true thens (i.e. few false thens before negation)
1542+ && when_then_expr
1543+ . iter ( )
1544+ . filter ( |( _, then) | is_false ( then) )
1545+ . count ( )
1546+ < 3
1547+ && else_expr. as_deref ( ) . is_none_or ( is_bool_lit) =>
1548+ {
1549+ Transformed :: yes (
1550+ Expr :: Case ( Case {
1551+ expr : None ,
1552+ when_then_expr : when_then_expr
1553+ . into_iter ( )
1554+ . map ( |( when, then) | ( when, Box :: new ( Expr :: Not ( then) ) ) )
1555+ . collect ( ) ,
1556+ else_expr : else_expr
1557+ . map ( |else_expr| Box :: new ( Expr :: Not ( else_expr) ) ) ,
1558+ } )
1559+ . not ( ) ,
1560+ )
1561+ }
14741562 Expr :: ScalarFunction ( ScalarFunction { func : udf, args } ) => {
14751563 match udf. simplify ( args, info) ? {
14761564 ExprSimplifyResult :: Original ( args) => {
@@ -3465,6 +3553,142 @@ mod tests {
34653553 ) ;
34663554 }
34673555
3556+ #[ test]
3557+ fn simplify_literal_case_equality ( ) {
3558+ // CASE WHEN c2 != false THEN "ok" ELSE "not_ok"
3559+ let simple_case = Expr :: Case ( Case :: new (
3560+ None ,
3561+ vec ! [ (
3562+ Box :: new( col( "c2_non_null" ) . not_eq( lit( false ) ) ) ,
3563+ Box :: new( lit( "ok" ) ) ,
3564+ ) ] ,
3565+ Some ( Box :: new ( lit ( "not_ok" ) ) ) ,
3566+ ) ) ;
3567+
3568+ // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok"
3569+ // -->
3570+ // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok"
3571+ // -->
3572+ // CASE WHEN c2 != false THEN true ELSE false
3573+ // -->
3574+ // c2
3575+ assert_eq ! (
3576+ simplify( binary_expr( simple_case. clone( ) , Operator :: Eq , lit( "ok" ) , ) ) ,
3577+ col( "c2_non_null" ) ,
3578+ ) ;
3579+
3580+ // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok"
3581+ // -->
3582+ // CASE WHEN c2 != false THEN "ok" != "ok" ELSE "not_ok" != "ok"
3583+ // -->
3584+ // CASE WHEN c2 != false THEN false ELSE true
3585+ // -->
3586+ // NOT(c2)
3587+ assert_eq ! (
3588+ simplify( binary_expr( simple_case, Operator :: NotEq , lit( "ok" ) , ) ) ,
3589+ not( col( "c2_non_null" ) ) ,
3590+ ) ;
3591+
3592+ let complex_case = Expr :: Case ( Case :: new (
3593+ None ,
3594+ vec ! [
3595+ (
3596+ Box :: new( col( "c1" ) . eq( lit( "inboxed" ) ) ) ,
3597+ Box :: new( lit( "pending" ) ) ,
3598+ ) ,
3599+ (
3600+ Box :: new( col( "c1" ) . eq( lit( "scheduled" ) ) ) ,
3601+ Box :: new( lit( "pending" ) ) ,
3602+ ) ,
3603+ (
3604+ Box :: new( col( "c1" ) . eq( lit( "completed" ) ) ) ,
3605+ Box :: new( lit( "completed" ) ) ,
3606+ ) ,
3607+ (
3608+ Box :: new( col( "c1" ) . eq( lit( "paused" ) ) ) ,
3609+ Box :: new( lit( "paused" ) ) ,
3610+ ) ,
3611+ ( Box :: new( col( "c2" ) ) , Box :: new( lit( "running" ) ) ) ,
3612+ (
3613+ Box :: new( col( "c1" ) . eq( lit( "invoked" ) ) . and( col( "c3" ) . gt( lit( 0 ) ) ) ) ,
3614+ Box :: new( lit( "backing-off" ) ) ,
3615+ ) ,
3616+ ] ,
3617+ Some ( Box :: new ( lit ( "ready" ) ) ) ,
3618+ ) ) ;
3619+
3620+ assert_eq ! (
3621+ simplify( binary_expr(
3622+ complex_case. clone( ) ,
3623+ Operator :: Eq ,
3624+ lit( "completed" ) ,
3625+ ) ) ,
3626+ not_distinct_from( col( "c1" ) . eq( lit( "completed" ) ) , lit( true ) ) . and(
3627+ distinct_from( col( "c1" ) . eq( lit( "inboxed" ) ) , lit( true ) )
3628+ . and( distinct_from( col( "c1" ) . eq( lit( "scheduled" ) ) , lit( true ) ) )
3629+ )
3630+ ) ;
3631+
3632+ assert_eq ! (
3633+ simplify( binary_expr(
3634+ complex_case. clone( ) ,
3635+ Operator :: NotEq ,
3636+ lit( "completed" ) ,
3637+ ) ) ,
3638+ distinct_from( col( "c1" ) . eq( lit( "completed" ) ) , lit( true ) )
3639+ . or( not_distinct_from( col( "c1" ) . eq( lit( "inboxed" ) ) , lit( true ) )
3640+ . or( not_distinct_from( col( "c1" ) . eq( lit( "scheduled" ) ) , lit( true ) ) ) )
3641+ ) ;
3642+
3643+ assert_eq ! (
3644+ simplify( binary_expr(
3645+ complex_case. clone( ) ,
3646+ Operator :: Eq ,
3647+ lit( "running" ) ,
3648+ ) ) ,
3649+ not_distinct_from( col( "c2" ) , lit( true ) ) . and(
3650+ distinct_from( col( "c1" ) . eq( lit( "inboxed" ) ) , lit( true ) )
3651+ . and( distinct_from( col( "c1" ) . eq( lit( "scheduled" ) ) , lit( true ) ) )
3652+ . and( distinct_from( col( "c1" ) . eq( lit( "completed" ) ) , lit( true ) ) )
3653+ . and( distinct_from( col( "c1" ) . eq( lit( "paused" ) ) , lit( true ) ) )
3654+ )
3655+ ) ;
3656+
3657+ assert_eq ! (
3658+ simplify( binary_expr(
3659+ complex_case. clone( ) ,
3660+ Operator :: Eq ,
3661+ lit( "ready" ) ,
3662+ ) ) ,
3663+ distinct_from( col( "c1" ) . eq( lit( "inboxed" ) ) , lit( true ) )
3664+ . and( distinct_from( col( "c1" ) . eq( lit( "scheduled" ) ) , lit( true ) ) )
3665+ . and( distinct_from( col( "c1" ) . eq( lit( "completed" ) ) , lit( true ) ) )
3666+ . and( distinct_from( col( "c1" ) . eq( lit( "paused" ) ) , lit( true ) ) )
3667+ . and( distinct_from( col( "c2" ) , lit( true ) ) )
3668+ . and( distinct_from(
3669+ col( "c1" ) . eq( lit( "invoked" ) ) . and( col( "c3" ) . gt( lit( 0 ) ) ) ,
3670+ lit( true )
3671+ ) )
3672+ ) ;
3673+
3674+ assert_eq ! (
3675+ simplify( binary_expr(
3676+ complex_case. clone( ) ,
3677+ Operator :: NotEq ,
3678+ lit( "ready" ) ,
3679+ ) ) ,
3680+ not_distinct_from( col( "c1" ) . eq( lit( "inboxed" ) ) , lit( true ) )
3681+ . or( not_distinct_from( col( "c1" ) . eq( lit( "scheduled" ) ) , lit( true ) ) )
3682+ . or( not_distinct_from( col( "c1" ) . eq( lit( "completed" ) ) , lit( true ) ) )
3683+ . or( not_distinct_from( col( "c1" ) . eq( lit( "paused" ) ) , lit( true ) ) )
3684+ . or( not_distinct_from( col( "c2" ) , lit( true ) ) )
3685+ . or( not_distinct_from(
3686+ col( "c1" ) . eq( lit( "invoked" ) ) . and( col( "c3" ) . gt( lit( 0 ) ) ) ,
3687+ lit( true )
3688+ ) )
3689+ ) ;
3690+ }
3691+
34683692 #[ test]
34693693 fn simplify_expr_case_when_then_else ( ) {
34703694 // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
0 commit comments