Skip to content

Commit a165b7f

Browse files
authored
Avoid copies in CountWildcardRule via TreeNode API (#10066)
* Avoid copies in `CountWildcardRule` via TreeNode API
1 parent 4e9f2d5 commit a165b7f

File tree

2 files changed

+66
-179
lines changed

2 files changed

+66
-179
lines changed

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

Lines changed: 64 additions & 177 deletions
Original file line number | Diff line number | Diff line change
@@ -15,23 +15,17 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
2018
use crate::analyzer::AnalyzerRule;
2119

20+
use crate::utils::NamePreserver;
2221
use datafusion_common::config::ConfigOptions;
23-
use datafusion_common::tree_node::{
24-
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
25-
};
22+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
2623
use datafusion_common::Result;
27-
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
28-
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
29-
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
30-
use datafusion_expr::Expr::ScalarSubquery;
31-
use datafusion_expr::{
32-
aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
33-
LogicalPlanBuilder, Projection, Sort, Subquery,
24+
use datafusion_expr::expr::{
25+
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
3426
};
27+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
28+
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
3529

3630
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
3731
///
@@ -47,181 +41,62 @@ impl CountWildcardRule {
4741

4842
impl AnalyzerRule for CountWildcardRule {
4943
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
50-
plan.transform_down(&analyze_internal).data()
44+
plan.transform_down_with_subqueries(&analyze_internal)
45+
.data()
5146
}
5247

5348
fn name(&self) -> &str {
5449
"count_wildcard_rule"
5550
}
5651
}
5752

58-
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
59-
let mut rewriter = CountWildcardRewriter {};
60-
match plan {
61-
LogicalPlan::Window(window) => {
62-
let window_expr = window
63-
.window_expr
64-
.iter()
65-
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
66-
.collect::<Result<Vec<_>>>()?;
67-
68-
Ok(Transformed::yes(
69-
LogicalPlanBuilder::from((*window.input).clone())
70-
.window(window_expr)?
71-
.build()?,
72-
))
73-
}
74-
LogicalPlan::Aggregate(agg) => {
75-
let aggr_expr = agg
76-
.aggr_expr
77-
.iter()
78-
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
79-
.collect::<Result<Vec<_>>>()?;
80-
81-
Ok(Transformed::yes(LogicalPlan::Aggregate(
82-
Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
83-
)))
84-
}
85-
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
86-
let sort_expr = expr
87-
.iter()
88-
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
89-
.collect::<Result<Vec<_>>>()?;
90-
Ok(Transformed::yes(LogicalPlan::Sort(Sort {
91-
expr: sort_expr,
92-
input,
93-
fetch,
94-
})))
95-
}
96-
LogicalPlan::Projection(projection) => {
97-
let projection_expr = projection
98-
.expr
99-
.iter()
100-
.map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
101-
.collect::<Result<Vec<_>>>()?;
102-
Ok(Transformed::yes(LogicalPlan::Projection(
103-
Projection::try_new(projection_expr, projection.input)?,
104-
)))
105-
}
106-
LogicalPlan::Filter(Filter {
107-
predicate, input, ..
108-
}) => {
109-
let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
110-
Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
111-
predicate, input,
112-
)?)))
113-
}
114-
115-
_ => Ok(Transformed::no(plan)),
116-
}
53+
fn is_wildcard(expr: &Expr) -> bool {
54+
matches!(expr, Expr::Wildcard { qualifier: None })
11755
}
11856

119-
struct CountWildcardRewriter {}
120-
121-
impl TreeNodeRewriter for CountWildcardRewriter {
122-
type Node = Expr;
123-
124-
fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
125-
Ok(match old_expr.clone() {
126-
Expr::WindowFunction(expr::WindowFunction {
127-
fun:
128-
expr::WindowFunctionDefinition::AggregateFunction(
129-
aggregate_function::AggregateFunction::Count,
130-
),
131-
args,
132-
partition_by,
133-
order_by,
134-
window_frame,
135-
null_treatment,
136-
}) if args.len() == 1 => match args[0] {
137-
Expr::Wildcard { qualifier: None } => {
138-
Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
139-
fun: expr::WindowFunctionDefinition::AggregateFunction(
140-
aggregate_function::AggregateFunction::Count,
141-
),
142-
args: vec![lit(COUNT_STAR_EXPANSION)],
143-
partition_by,
144-
order_by,
145-
window_frame,
146-
null_treatment,
147-
}))
148-
}
57+
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
58+
matches!(
59+
&aggregate_function.func_def,
60+
AggregateFunctionDefinition::BuiltIn(
61+
datafusion_expr::aggregate_function::AggregateFunction::Count,
62+
)
63+
) && aggregate_function.args.len() == 1
64+
&& is_wildcard(&aggregate_function.args[0])
65+
}
14966

150-
_ => Transformed::no(old_expr),
151-
},
152-
Expr::AggregateFunction(AggregateFunction {
153-
func_def:
154-
AggregateFunctionDefinition::BuiltIn(
155-
aggregate_function::AggregateFunction::Count,
156-
),
157-
args,
158-
distinct,
159-
filter,
160-
order_by,
161-
null_treatment,
162-
}) if args.len() == 1 => match args[0] {
163-
Expr::Wildcard { qualifier: None } => {
164-
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
165-
aggregate_function::AggregateFunction::Count,
166-
vec![lit(COUNT_STAR_EXPANSION)],
167-
distinct,
168-
filter,
169-
order_by,
170-
null_treatment,
171-
)))
172-
}
173-
_ => Transformed::no(old_expr),
174-
},
67+
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
68+
matches!(
69+
&window_function.fun,
70+
WindowFunctionDefinition::AggregateFunction(
71+
datafusion_expr::aggregate_function::AggregateFunction::Count,
72+
)
73+
) && window_function.args.len() == 1
74+
&& is_wildcard(&window_function.args[0])
75+
}
17576

176-
ScalarSubquery(Subquery {
177-
subquery,
178-
outer_ref_columns,
179-
}) => subquery
180-
.as_ref()
181-
.clone()
182-
.transform_down(&analyze_internal)?
183-
.update_data(|new_plan| {
184-
ScalarSubquery(Subquery {
185-
subquery: Arc::new(new_plan),
186-
outer_ref_columns,
187-
})
188-
}),
189-
Expr::InSubquery(InSubquery {
190-
expr,
191-
subquery,
192-
negated,
193-
}) => subquery
194-
.subquery
195-
.as_ref()
196-
.clone()
197-
.transform_down(&analyze_internal)?
198-
.update_data(|new_plan| {
199-
Expr::InSubquery(InSubquery::new(
200-
expr,
201-
Subquery {
202-
subquery: Arc::new(new_plan),
203-
outer_ref_columns: subquery.outer_ref_columns,
204-
},
205-
negated,
206-
))
207-
}),
208-
Expr::Exists(expr::Exists { subquery, negated }) => subquery
209-
.subquery
210-
.as_ref()
211-
.clone()
212-
.transform_down(&analyze_internal)?
213-
.update_data(|new_plan| {
214-
Expr::Exists(expr::Exists {
215-
subquery: Subquery {
216-
subquery: Arc::new(new_plan),
217-
outer_ref_columns: subquery.outer_ref_columns,
218-
},
219-
negated,
220-
})
221-
}),
222-
_ => Transformed::no(old_expr),
223-
})
224-
}
77+
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
78+
let name_preserver = NamePreserver::new(&plan);
79+
plan.map_expressions(|expr| {
80+
let original_name = name_preserver.save(&expr)?;
81+
let transformed_expr = expr.transform_up(&|expr| match expr {
82+
Expr::WindowFunction(mut window_function)
83+
if is_count_star_window_aggregate(&window_function) =>
84+
{
85+
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
86+
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
87+
}
88+
Expr::AggregateFunction(mut aggregate_function)
89+
if is_count_star_aggregate(&aggregate_function) =>
90+
{
91+
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
92+
Ok(Transformed::yes(Expr::AggregateFunction(
93+
aggregate_function,
94+
)))
95+
}
96+
_ => Ok(Transformed::no(expr)),
97+
})?;
98+
transformed_expr.map_data(|data| original_name.restore(data))
99+
})
225100
}
226101

227102
#[cfg(test)]
@@ -233,9 +108,10 @@ mod tests {
233108
use datafusion_expr::expr::Sort;
234109
use datafusion_expr::{
235110
col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
236-
max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
111+
max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr,
237112
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
238113
};
114+
use std::sync::Arc;
239115

240116
fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
241117
assert_analyzed_plan_eq_display_indent(
@@ -381,6 +257,17 @@ mod tests {
381257
assert_plan_eq(&plan, expected)
382258
}
383259

260+
#[test]
261+
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
262+
let table_scan = test_table_scan()?;
263+
let err = LogicalPlanBuilder::from(table_scan)
264+
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
265+
.unwrap_err()
266+
.to_string();
267+
assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
268+
Ok(())
269+
}
270+
384271
#[test]
385272
fn test_count_wildcard_on_nesting() -> Result<()> {
386273
let table_scan = test_table_scan()?;

datafusion/optimizer/src/analyzer/function_rewrite.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
6464
let original_name = name_preserver.save(&expr)?;
6565

6666
// recursively transform the expression, applying the rewrites at each step
67-
let result = expr.transform_up(&|expr| {
67+
let transformed_expr = expr.transform_up(&|expr| {
6868
let mut result = Transformed::no(expr);
6969
for rewriter in self.function_rewrites.iter() {
7070
result = result.transform_data(|expr| {
@@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
7474
Ok(result)
7575
})?;
7676

77-
result.map_data(|expr| original_name.restore(expr))
77+
transformed_expr.map_data(|expr| original_name.restore(expr))
7878
})
7979
}
8080
}

0 commit comments

Comments
 (0)