From d0aac59f0a020e865b27601353b2e51e38af19db Mon Sep 17 00:00:00 2001 From: xudong963 Date: Wed, 30 Apr 2025 23:32:23 +0800 Subject: [PATCH] Support inferring new predicates to push down --- datafusion/expr/src/expr_rewriter/mod.rs | 18 ++++- datafusion/optimizer/src/push_down_filter.rs | 75 ++++++++++++++++++- datafusion/optimizer/src/utils.rs | 6 +- .../test_files/push_down_filter.slt | 32 ++++++++ 4 files changed, 124 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b01..0e76ce205f8d 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -131,13 +131,25 @@ pub fn normalize_sorts( } /// Recursively replace all [`Column`] expressions in a given expression tree with -/// `Column` expressions provided by the hash map argument. -pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { +/// the expressions provided by the hash map argument. +/// +/// # Arguments +/// * `expr` - The expression to transform +/// * `replace_map` - A mapping from Column to replacement expression +/// * `to_expr` - A function that converts the replacement value to an Expr +pub fn replace_col( + expr: Expr, + replace_map: &HashMap<&Column, V>, + to_expr: F, +) -> Result +where + F: Fn(&V) -> Expr, +{ expr.transform(|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + Some(replacement) => Transformed::yes(to_expr(replacement)), None => Transformed::no(expr), } } else { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c9617514e453..6618191a1d09 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -631,7 +631,10 @@ impl InferredPredicates { Ok(true) ) { - self.predicates.push(replace_col(predicate, replace_map)?); + self.predicates + .push(replace_col(predicate, replace_map, |col| { + Expr::Column((*col).clone()) + })?); } Ok(()) @@ -784,13 +787,14 @@ impl OptimizerRule for PushDownFilter { // remove duplicated filters let child_predicates = split_conjunction_owned(child_filter.predicate); - let new_predicates = parents_predicates + let mut new_predicates = parents_predicates .into_iter() .chain(child_predicates) // use IndexSet to remove dupes while preserving predicate order .collect::>() .into_iter() .collect::>(); + new_predicates = infer_predicates_from_equalities(new_predicates)?; let Some(new_predicate) = conjunction(new_predicates) else { return plan_err!("at least one expression exists"); @@ -1382,6 +1386,73 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { is_contain } +/// Infers new predicates by substituting equalities. +/// For example, with predicates `t2.b = 3` and `t1.b > t2.b`, +/// we can infer `t1.b > 3`. +fn infer_predicates_from_equalities(predicates: Vec) -> Result> { + // Map from column names to their literal values (from equality predicates) + let mut equality_map: HashMap = + HashMap::with_capacity(predicates.len()); + let mut final_predicates = Vec::with_capacity(predicates.len()); + // First pass: collect column=literal equalities + for predicate in predicates.iter() { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = predicate + { + if let Expr::Column(col) = left.as_ref() { + // Only add to map if right side is a literal + if matches!(right.as_ref(), Expr::Literal(_)) { + equality_map.insert(col.clone(), *right.clone()); + final_predicates.push(predicate.clone()); + } + } else if let Expr::Column(col) = right.as_ref() { + // Only add to map if left side is a literal + if matches!(left.as_ref(), Expr::Literal(_)) { + equality_map.insert(col.clone(), *right.clone()); + final_predicates.push(predicate.clone()); + } + } + } + } + + // If no equality mappings found, nothing to infer + if equality_map.is_empty() { + return Ok(predicates); + } + + // Second pass: apply substitutions to create new predicates + for predicate in predicates { + // Skip equality predicates we already used for mapping + if final_predicates.contains(&predicate) { + continue; + } + + // Try to replace columns with their literal values + let mut columns_in_expr = HashSet::new(); + expr_to_columns(&predicate, &mut columns_in_expr)?; + + // Create a combined replacement map for all columns in this predicate + let replace_map: HashMap<_, _> = columns_in_expr + .iter() + .filter_map(|col| equality_map.get(col).map(|lit| (col, lit))) + .collect(); + + if replace_map.is_empty() { + final_predicates.push(predicate); + continue; + } + // Apply all substitutions at once to get the fully substituted predicate + let new_pred = replace_col(predicate, &replace_map, |e| (*e).clone())?; + + final_predicates.push(new_pred); + } + + Ok(final_predicates) +} + #[cfg(test)] mod tests { use std::any::Any; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41c40ec06d65..ebd3c5cd79f2 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -59,7 +59,7 @@ pub(crate) fn replace_qualified_name( let replace_map: HashMap<&Column, &Column> = cols.iter().zip(alias_cols.iter()).collect(); - replace_col(expr, &replace_map) + replace_col(expr, &replace_map, |col| Expr::Column((*col).clone())) } /// Log the plan in debug/tracing mode after some part of the optimizer runs @@ -136,7 +136,9 @@ fn evaluate_expr_with_null_column<'a>( .map(|column| (column, &null_column)) .collect::>(); - let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?; + let replaced_predicate = replace_col(predicate, &join_cols_to_replace, |col| { + Expr::Column((*col).clone()) + })?; let coerced_predicate = coerce(replaced_predicate, &input_schema)?; create_physical_expr(&coerced_predicate, &input_schema, &execution_props)? .evaluate(&input_batch) diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 67965146e76b..06fc4b940efe 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -259,3 +259,35 @@ logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8 statement ok drop table t; + +statement ok +create table t1(a int, b int) as values(1, 2), (2, 3), (3 ,4); + +statement ok +create table t2(a int, b int) as values (1, 2), (2, 4), (4, 5); + +query TT +explain select + * +from + t1 + join t2 on t1.a = t2.a + and t1.b between t2.b + and t2.b + 2 +where + t2.b = 3 +---- +logical_plan +01)Inner Join: t1.a = t2.a +02)--Projection: t1.a, t1.b +03)----Filter: __common_expr_4 >= Int64(3) AND __common_expr_4 <= Int64(5) +04)------Projection: CAST(t1.b AS Int64) AS __common_expr_4, t1.a, t1.b +05)--------TableScan: t1 projection=[a, b] +06)--Filter: t2.b = Int32(3) +07)----TableScan: t2 projection=[a, b] + +statement ok +drop table t1; + +statement ok +drop table t2;