diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 344a738fe84b..5b5bca75ddb0 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,6 +17,8 @@ //! Expression simplification API +use std::borrow::Cow; +use std::collections::HashSet; use std::ops::Not; use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; @@ -38,10 +40,11 @@ use datafusion_common::{ use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, - ScalarFunctionDefinition, Volatility, + Operator, ScalarFunctionDefinition, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -189,14 +192,15 @@ impl ExprSimplifier { .data()? .rewrite(&mut inlist_simplifier) .data()? - .rewrite(&mut shorten_in_list_simplifier) - .data()? .rewrite(&mut guarantee_rewriter) .data()? // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator) .data()? .rewrite(&mut simplifier) + .data()? 
+ // shorten inlist must run after the other inlist rules have been applied + .rewrite(&mut shorten_in_list_simplifier) .data() } @@ -1405,12 +1409,116 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Transformed::yes(lit(false)) } + // expr IN () --> false + // expr NOT IN () --> true + Expr::InList(InList { + expr, + list, + negated, + }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + Transformed::yes(lit(negated)) + } + + // null in (x, y, z) --> null + // null not in (x, y, z) --> null + Expr::InList(InList { + expr, + list: _, + negated: _, + }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), + + // expr IN ((subquery)) -> expr IN (subquery), see #5529 + Expr::InList(InList { + expr, + mut list, + negated, + }) if list.len() == 1 + && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) => + { + let Expr::ScalarSubquery(subquery) = list.remove(0) else { + unreachable!() + }; + + Transformed::yes(Expr::InSubquery(InSubquery::new( + expr, subquery, negated, + ))) + } + + // Combine multiple OR expressions into a single IN list expression if possible + // + // i.e. 
`a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { + let left = as_inlist(left.as_ref()); + let right = as_inlist(right.as_ref()); + + let lhs = left.unwrap(); + let rhs = right.unwrap(); + let lhs = lhs.into_owned(); + let rhs = rhs.into_owned(); + let mut seen: HashSet<Expr> = HashSet::new(); + let list = lhs + .list + .into_iter() + .chain(rhs.list) + .filter(|e| seen.insert(e.to_owned())) + .collect::<Vec<_>>(); + + let merged_inlist = InList { + expr: lhs.expr, + list, + negated: false, + }; + + return Ok(Transformed::yes(Expr::InList(merged_inlist))); + } + // no additional rewrites possible expr => Transformed::no(expr), }) } } +fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { + let left = as_inlist(left); + let right = as_inlist(right); + if let (Some(lhs), Some(rhs)) = (left, right) { + matches!(lhs.expr.as_ref(), Expr::Column(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_)) + && lhs.expr == rhs.expr + && !lhs.negated + && !rhs.negated + } else { + false + } +} + +/// Try to convert an expression to an in-list expression +fn as_inlist(expr: &Expr) -> Option<Cow<InList>> { + match expr { + Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), + Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + expr: left.clone(), + list: vec![*right.clone()], + negated: false, + })), + (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + expr: right.clone(), + list: vec![*left.clone()], + negated: false, + })), + _ => None, + } + } + _ => None, + } +} + #[cfg(test)] mod tests { use std::{ diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index fa1d7cfc1239..5d1cf27827a9 100644 --- 
a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -17,15 +17,13 @@ //! This module implements a rule that simplifies the values for `InList`s -use super::utils::{is_null, lit_bool_null}; use super::THRESHOLD_INLINE_INLIST; -use std::borrow::Cow; use std::collections::HashSet; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery}; +use datafusion_common::Result; +use datafusion_expr::expr::InList; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; pub(super) struct ShortenInListSimplifier {} @@ -112,66 +110,6 @@ impl TreeNodeRewriter for InListSimplifier { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> { - if let Expr::InList(InList { - expr, - mut list, - negated, - }) = expr.clone() - { - // expr IN () --> false - // expr NOT IN () --> true - if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) { - return Ok(Transformed::yes(lit(negated))); - // null in (x, y, z) --> null - // null not in (x, y, z) --> null - } else if is_null(&expr) { - return Ok(Transformed::yes(lit_bool_null())); - // expr IN ((subquery)) -> expr IN (subquery), see ##5529 - } else if list.len() == 1 - && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) - { - let Expr::ScalarSubquery(subquery) = list.remove(0) else { - unreachable!() - }; - return Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - expr, subquery, negated, - )))); - } - } - // Combine multiple OR expressions into a single IN list expression if possible - // - // i.e. 
`a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { - if *op == Operator::Or { - let left = as_inlist(left); - let right = as_inlist(right); - if let (Some(lhs), Some(rhs)) = (left, right) { - if lhs.expr.try_into_col().is_ok() - && rhs.expr.try_into_col().is_ok() - && lhs.expr == rhs.expr - && !lhs.negated - && !rhs.negated - { - let lhs = lhs.into_owned(); - let rhs = rhs.into_owned(); - let mut seen: HashSet<Expr> = HashSet::new(); - let list = lhs - .list - .into_iter() - .chain(rhs.list) - .filter(|e| seen.insert(e.to_owned())) - .collect::<Vec<_>>(); - - let merged_inlist = InList { - expr: lhs.expr, - list, - negated: false, - }; - return Ok(Transformed::yes(Expr::InList(merged_inlist))); - } - } - } - } // Simplify expressions that is guaranteed to be true or false to a literal boolean expression // // Rules: @@ -230,29 +168,6 @@ impl TreeNodeRewriter for InListSimplifier { } } -/// Try to convert an expression to an in-list expression -fn as_inlist(expr: &Expr) -> Option<Cow<InList>> { - match expr { - Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })), - _ => None, - } - } - _ => None, - } -} - /// Return the union of two inlist expressions /// maintaining the order of the elements in the two lists fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<InList> {