1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: sync:: Arc ;
19-
2018use crate :: analyzer:: AnalyzerRule ;
2119
20+ use crate :: utils:: NamePreserver ;
2221use datafusion_common:: config:: ConfigOptions ;
23- use datafusion_common:: tree_node:: {
24- Transformed , TransformedResult , TreeNode , TreeNodeRewriter ,
25- } ;
22+ use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
2623use datafusion_common:: Result ;
27- use datafusion_expr:: expr:: { AggregateFunction , AggregateFunctionDefinition , InSubquery } ;
28- use datafusion_expr:: expr_rewriter:: rewrite_preserving_name;
29- use datafusion_expr:: utils:: COUNT_STAR_EXPANSION ;
30- use datafusion_expr:: Expr :: ScalarSubquery ;
31- use datafusion_expr:: {
32- aggregate_function, expr, lit, Aggregate , Expr , Filter , LogicalPlan ,
33- LogicalPlanBuilder , Projection , Sort , Subquery ,
24+ use datafusion_expr:: expr:: {
25+ AggregateFunction , AggregateFunctionDefinition , WindowFunction ,
3426} ;
27+ use datafusion_expr:: utils:: COUNT_STAR_EXPANSION ;
28+ use datafusion_expr:: { lit, Expr , LogicalPlan , WindowFunctionDefinition } ;
3529
3630/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
3731///
@@ -47,181 +41,62 @@ impl CountWildcardRule {
4741
4842impl AnalyzerRule for CountWildcardRule {
4943 fn analyze ( & self , plan : LogicalPlan , _: & ConfigOptions ) -> Result < LogicalPlan > {
50- plan. transform_down ( & analyze_internal) . data ( )
44+ plan. transform_down_with_subqueries ( & analyze_internal)
45+ . data ( )
5146 }
5247
5348 fn name ( & self ) -> & str {
5449 "count_wildcard_rule"
5550 }
5651}
5752
58- fn analyze_internal ( plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
59- let mut rewriter = CountWildcardRewriter { } ;
60- match plan {
61- LogicalPlan :: Window ( window) => {
62- let window_expr = window
63- . window_expr
64- . iter ( )
65- . map ( |expr| rewrite_preserving_name ( expr. clone ( ) , & mut rewriter) )
66- . collect :: < Result < Vec < _ > > > ( ) ?;
67-
68- Ok ( Transformed :: yes (
69- LogicalPlanBuilder :: from ( ( * window. input ) . clone ( ) )
70- . window ( window_expr) ?
71- . build ( ) ?,
72- ) )
73- }
74- LogicalPlan :: Aggregate ( agg) => {
75- let aggr_expr = agg
76- . aggr_expr
77- . iter ( )
78- . map ( |expr| rewrite_preserving_name ( expr. clone ( ) , & mut rewriter) )
79- . collect :: < Result < Vec < _ > > > ( ) ?;
80-
81- Ok ( Transformed :: yes ( LogicalPlan :: Aggregate (
82- Aggregate :: try_new ( agg. input . clone ( ) , agg. group_expr , aggr_expr) ?,
83- ) ) )
84- }
85- LogicalPlan :: Sort ( Sort { expr, input, fetch } ) => {
86- let sort_expr = expr
87- . iter ( )
88- . map ( |expr| rewrite_preserving_name ( expr. clone ( ) , & mut rewriter) )
89- . collect :: < Result < Vec < _ > > > ( ) ?;
90- Ok ( Transformed :: yes ( LogicalPlan :: Sort ( Sort {
91- expr : sort_expr,
92- input,
93- fetch,
94- } ) ) )
95- }
96- LogicalPlan :: Projection ( projection) => {
97- let projection_expr = projection
98- . expr
99- . iter ( )
100- . map ( |expr| rewrite_preserving_name ( expr. clone ( ) , & mut rewriter) )
101- . collect :: < Result < Vec < _ > > > ( ) ?;
102- Ok ( Transformed :: yes ( LogicalPlan :: Projection (
103- Projection :: try_new ( projection_expr, projection. input ) ?,
104- ) ) )
105- }
106- LogicalPlan :: Filter ( Filter {
107- predicate, input, ..
108- } ) => {
109- let predicate = rewrite_preserving_name ( predicate, & mut rewriter) ?;
110- Ok ( Transformed :: yes ( LogicalPlan :: Filter ( Filter :: try_new (
111- predicate, input,
112- ) ?) ) )
113- }
114-
115- _ => Ok ( Transformed :: no ( plan) ) ,
116- }
53+ fn is_wildcard ( expr : & Expr ) -> bool {
54+ matches ! ( expr, Expr :: Wildcard { qualifier: None } )
11755}
11856
119- struct CountWildcardRewriter { }
120-
121- impl TreeNodeRewriter for CountWildcardRewriter {
122- type Node = Expr ;
123-
124- fn f_up ( & mut self , old_expr : Expr ) -> Result < Transformed < Expr > > {
125- Ok ( match old_expr. clone ( ) {
126- Expr :: WindowFunction ( expr:: WindowFunction {
127- fun :
128- expr:: WindowFunctionDefinition :: AggregateFunction (
129- aggregate_function:: AggregateFunction :: Count ,
130- ) ,
131- args,
132- partition_by,
133- order_by,
134- window_frame,
135- null_treatment,
136- } ) if args. len ( ) == 1 => match args[ 0 ] {
137- Expr :: Wildcard { qualifier : None } => {
138- Transformed :: yes ( Expr :: WindowFunction ( expr:: WindowFunction {
139- fun : expr:: WindowFunctionDefinition :: AggregateFunction (
140- aggregate_function:: AggregateFunction :: Count ,
141- ) ,
142- args : vec ! [ lit( COUNT_STAR_EXPANSION ) ] ,
143- partition_by,
144- order_by,
145- window_frame,
146- null_treatment,
147- } ) )
148- }
57+ fn is_count_star_aggregate ( aggregate_function : & AggregateFunction ) -> bool {
58+ matches ! (
59+ & aggregate_function. func_def,
60+ AggregateFunctionDefinition :: BuiltIn (
61+ datafusion_expr:: aggregate_function:: AggregateFunction :: Count ,
62+ )
63+ ) && aggregate_function. args . len ( ) == 1
64+ && is_wildcard ( & aggregate_function. args [ 0 ] )
65+ }
14966
150- _ => Transformed :: no ( old_expr) ,
151- } ,
152- Expr :: AggregateFunction ( AggregateFunction {
153- func_def :
154- AggregateFunctionDefinition :: BuiltIn (
155- aggregate_function:: AggregateFunction :: Count ,
156- ) ,
157- args,
158- distinct,
159- filter,
160- order_by,
161- null_treatment,
162- } ) if args. len ( ) == 1 => match args[ 0 ] {
163- Expr :: Wildcard { qualifier : None } => {
164- Transformed :: yes ( Expr :: AggregateFunction ( AggregateFunction :: new (
165- aggregate_function:: AggregateFunction :: Count ,
166- vec ! [ lit( COUNT_STAR_EXPANSION ) ] ,
167- distinct,
168- filter,
169- order_by,
170- null_treatment,
171- ) ) )
172- }
173- _ => Transformed :: no ( old_expr) ,
174- } ,
67+ fn is_count_star_window_aggregate ( window_function : & WindowFunction ) -> bool {
68+ matches ! (
69+ & window_function. fun,
70+ WindowFunctionDefinition :: AggregateFunction (
71+ datafusion_expr:: aggregate_function:: AggregateFunction :: Count ,
72+ )
73+ ) && window_function. args . len ( ) == 1
74+ && is_wildcard ( & window_function. args [ 0 ] )
75+ }
17576
176- ScalarSubquery ( Subquery {
177- subquery,
178- outer_ref_columns,
179- } ) => subquery
180- . as_ref ( )
181- . clone ( )
182- . transform_down ( & analyze_internal) ?
183- . update_data ( |new_plan| {
184- ScalarSubquery ( Subquery {
185- subquery : Arc :: new ( new_plan) ,
186- outer_ref_columns,
187- } )
188- } ) ,
189- Expr :: InSubquery ( InSubquery {
190- expr,
191- subquery,
192- negated,
193- } ) => subquery
194- . subquery
195- . as_ref ( )
196- . clone ( )
197- . transform_down ( & analyze_internal) ?
198- . update_data ( |new_plan| {
199- Expr :: InSubquery ( InSubquery :: new (
200- expr,
201- Subquery {
202- subquery : Arc :: new ( new_plan) ,
203- outer_ref_columns : subquery. outer_ref_columns ,
204- } ,
205- negated,
206- ) )
207- } ) ,
208- Expr :: Exists ( expr:: Exists { subquery, negated } ) => subquery
209- . subquery
210- . as_ref ( )
211- . clone ( )
212- . transform_down ( & analyze_internal) ?
213- . update_data ( |new_plan| {
214- Expr :: Exists ( expr:: Exists {
215- subquery : Subquery {
216- subquery : Arc :: new ( new_plan) ,
217- outer_ref_columns : subquery. outer_ref_columns ,
218- } ,
219- negated,
220- } )
221- } ) ,
222- _ => Transformed :: no ( old_expr) ,
223- } )
224- }
77+ fn analyze_internal ( plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
78+ let name_preserver = NamePreserver :: new ( & plan) ;
79+ plan. map_expressions ( |expr| {
80+ let original_name = name_preserver. save ( & expr) ?;
81+ let transformed_expr = expr. transform_up ( & |expr| match expr {
82+ Expr :: WindowFunction ( mut window_function)
83+ if is_count_star_window_aggregate ( & window_function) =>
84+ {
85+ window_function. args = vec ! [ lit( COUNT_STAR_EXPANSION ) ] ;
86+ Ok ( Transformed :: yes ( Expr :: WindowFunction ( window_function) ) )
87+ }
88+ Expr :: AggregateFunction ( mut aggregate_function)
89+ if is_count_star_aggregate ( & aggregate_function) =>
90+ {
91+ aggregate_function. args = vec ! [ lit( COUNT_STAR_EXPANSION ) ] ;
92+ Ok ( Transformed :: yes ( Expr :: AggregateFunction (
93+ aggregate_function,
94+ ) ) )
95+ }
96+ _ => Ok ( Transformed :: no ( expr) ) ,
97+ } ) ?;
98+ transformed_expr. map_data ( |data| original_name. restore ( data) )
99+ } )
225100}
226101
227102#[ cfg( test) ]
@@ -233,9 +108,10 @@ mod tests {
233108 use datafusion_expr:: expr:: Sort ;
234109 use datafusion_expr:: {
235110 col, count, exists, expr, in_subquery, lit, logical_plan:: LogicalPlanBuilder ,
236- max, out_ref_col, scalar_subquery, wildcard, AggregateFunction , Expr ,
111+ max, out_ref_col, scalar_subquery, sum , wildcard, AggregateFunction , Expr ,
237112 WindowFrame , WindowFrameBound , WindowFrameUnits , WindowFunctionDefinition ,
238113 } ;
114+ use std:: sync:: Arc ;
239115
240116 fn assert_plan_eq ( plan : & LogicalPlan , expected : & str ) -> Result < ( ) > {
241117 assert_analyzed_plan_eq_display_indent (
@@ -381,6 +257,17 @@ mod tests {
381257 assert_plan_eq ( & plan, expected)
382258 }
383259
260+ #[ test]
261+ fn test_count_wildcard_on_non_count_aggregate ( ) -> Result < ( ) > {
262+ let table_scan = test_table_scan ( ) ?;
263+ let err = LogicalPlanBuilder :: from ( table_scan)
264+ . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ sum( wildcard( ) ) ] )
265+ . unwrap_err ( )
266+ . to_string ( ) ;
267+ assert ! ( err. contains( "Error during planning: No function matches the given name and argument types 'SUM(Null)'." ) , "{err}" ) ;
268+ Ok ( ( ) )
269+ }
270+
384271 #[ test]
385272 fn test_count_wildcard_on_nesting ( ) -> Result < ( ) > {
386273 let table_scan = test_table_scan ( ) ?;
0 commit comments