1717
1818//! Eliminate common sub-expression.
1919
20+ use std:: collections:: hash_map:: Entry ;
2021use std:: collections:: { BTreeSet , HashMap } ;
2122use std:: sync:: Arc ;
2223
@@ -35,16 +36,42 @@ use datafusion_expr::expr::Alias;
3536use datafusion_expr:: logical_plan:: { Aggregate , LogicalPlan , Projection , Window } ;
3637use datafusion_expr:: { col, Expr , ExprSchemable } ;
3738
38- /// A map from expression's identifier to tuple including
39- ///
40- /// key == Identifier created with only the current node (& subtree)
41- ///
42- /// values:
43- /// - the expression itself (cloned)
44- /// - counter
45- /// - DataType of this expression.
46- /// - symbol used as the identifier in the alias
47- type ExprSet = HashMap < Identifier , ( Expr , usize , DataType , Identifier ) > ;
39+ /// Set of expressions generated by the [`ExprIdentifierVisitor`]
40+ /// and consumed by the [`CommonSubexprRewriter`].
41+ #[ derive( Default ) ]
42+ struct ExprSet {
43+ /// A map from expression's identifier (stringified expr) to tuple including:
44+ /// - the expression itself (cloned)
45+ /// - counter
46+ /// - DataType of this expression.
47+ /// - symbol used as the identifier in the alias.
48+ map : HashMap < Identifier , ( Expr , usize , DataType , Identifier ) > ,
49+ }
50+
51+ impl ExprSet {
52+ fn expr_identifier ( expr : & Expr ) -> Identifier {
53+ format ! ( "{expr}" )
54+ }
55+
56+ fn get ( & self , key : & Identifier ) -> Option < & ( Expr , usize , DataType , Identifier ) > {
57+ self . map . get ( key)
58+ }
59+
60+ fn entry (
61+ & mut self ,
62+ key : Identifier ,
63+ ) -> Entry < ' _ , Identifier , ( Expr , usize , DataType , Identifier ) > {
64+ self . map . entry ( key)
65+ }
66+ }
67+
68+ impl From < Vec < ( Identifier , ( Expr , usize , DataType , Identifier ) ) > > for ExprSet {
69+ fn from ( entries : Vec < ( Identifier , ( Expr , usize , DataType , Identifier ) ) > ) -> Self {
70+ let mut expr_set = Self :: default ( ) ;
71+ let _ = entries. into_iter ( ) . map ( |( k, v) | expr_set. map . insert ( k, v) ) ;
72+ expr_set
73+ }
74+ }
4875
4976/// Identifier for each subexpression.
5077///
@@ -134,7 +161,7 @@ impl CommonSubexprEliminate {
134161 config : & dyn OptimizerConfig ,
135162 ) -> Result < LogicalPlan > {
136163 let mut window_exprs = vec ! [ ] ;
137- let mut expr_set = ExprSet :: new ( ) ;
164+ let mut expr_set = ExprSet :: default ( ) ;
138165
139166 // Get all window expressions inside the consecutive window operators.
140167 // Consecutive window expressions may refer to same complex expression.
@@ -206,7 +233,7 @@ impl CommonSubexprEliminate {
206233 input,
207234 ..
208235 } = aggregate;
209- let mut expr_set = ExprSet :: new ( ) ;
236+ let mut expr_set = ExprSet :: default ( ) ;
210237
211238 // build expr_set, with groupby and aggr
212239 let input_schema = Arc :: clone ( input. schema ( ) ) ;
@@ -226,7 +253,7 @@ impl CommonSubexprEliminate {
226253 let new_group_expr = pop_expr ( & mut new_expr) ?;
227254
228255 // create potential projection on top
229- let mut expr_set = ExprSet :: new ( ) ;
256+ let mut expr_set = ExprSet :: default ( ) ;
230257 let new_input_schema = Arc :: clone ( new_input. schema ( ) ) ;
231258 populate_expr_set (
232259 & new_aggr_expr,
@@ -277,9 +304,7 @@ impl CommonSubexprEliminate {
277304 agg_exprs. push ( expr. alias ( & name) ) ;
278305 proj_exprs. push ( Expr :: Column ( Column :: from_name ( name) ) ) ;
279306 } else {
280- let id = ExprIdentifierVisitor :: < ' static > :: expr_identifier (
281- & expr_rewritten,
282- ) ;
307+ let id = ExprSet :: expr_identifier ( & expr_rewritten) ;
283308 let out_name =
284309 expr_rewritten. to_field ( & new_input_schema) ?. qualified_name ( ) ;
285310 agg_exprs. push ( expr_rewritten. alias ( & id) ) ;
@@ -313,7 +338,7 @@ impl CommonSubexprEliminate {
313338 let inputs = plan. inputs ( ) ;
314339 let input = inputs[ 0 ] ;
315340 let input_schema = Arc :: clone ( input. schema ( ) ) ;
316- let mut expr_set = ExprSet :: new ( ) ;
341+ let mut expr_set = ExprSet :: default ( ) ;
317342
318343 // Visit expr list and build expr identifier to occurring count map (`expr_set`).
319344 populate_expr_set ( & expr, input_schema, & mut expr_set, ExprMask :: Normal ) ?;
@@ -572,10 +597,6 @@ enum VisitRecord {
572597}
573598
574599impl ExprIdentifierVisitor < ' _ > {
575- fn expr_identifier ( expr : & Expr ) -> Identifier {
576- format ! ( "{expr}" )
577- }
578-
579600 /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
580601 /// before it.
581602 fn pop_enter_mark ( & mut self ) -> ( usize , Identifier ) {
@@ -619,21 +640,22 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
619640
620641 // Skip exprs that should not be recognized.
621642 if self . expr_mask . ignores ( expr) {
622- let curr_expr_identifier = Self :: expr_identifier ( expr) ;
643+ let curr_expr_identifier = ExprSet :: expr_identifier ( expr) ;
623644 self . visit_stack
624645 . push ( VisitRecord :: ExprItem ( curr_expr_identifier) ) ;
625646 return Ok ( TreeNodeRecursion :: Continue ) ;
626647 }
627- let curr_expr_identifier = Self :: expr_identifier ( expr) ;
628- let desc = format ! ( "{curr_expr_identifier}{sub_expr_identifier}" ) ;
648+ let curr_expr_identifier = ExprSet :: expr_identifier ( expr) ;
649+ let alias_symbol = format ! ( "{curr_expr_identifier}{sub_expr_identifier}" ) ;
629650
630- self . visit_stack . push ( VisitRecord :: ExprItem ( desc. clone ( ) ) ) ;
651+ self . visit_stack
652+ . push ( VisitRecord :: ExprItem ( alias_symbol. clone ( ) ) ) ;
631653
632654 let data_type = expr. get_type ( & self . input_schema ) ?;
633655
634656 self . expr_set
635657 . entry ( curr_expr_identifier)
636- . or_insert_with ( || ( expr. clone ( ) , 0 , data_type, desc ) )
658+ . or_insert_with ( || ( expr. clone ( ) , 0 , data_type, alias_symbol ) )
637659 . 1 += 1 ;
638660 Ok ( TreeNodeRecursion :: Continue )
639661 }
@@ -666,12 +688,6 @@ struct CommonSubexprRewriter<'a> {
666688 affected_id : & ' a mut BTreeSet < Identifier > ,
667689}
668690
669- impl CommonSubexprRewriter < ' _ > {
670- fn expr_identifier ( expr : & Expr ) -> Identifier {
671- format ! ( "{expr}" )
672- }
673- }
674-
675691impl TreeNodeRewriter for CommonSubexprRewriter < ' _ > {
676692 type Node = Expr ;
677693
@@ -683,7 +699,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
683699 return Ok ( Transformed :: new ( expr, false , TreeNodeRecursion :: Jump ) ) ;
684700 }
685701
686- let curr_id = & Self :: expr_identifier ( & expr) ;
702+ let curr_id = & ExprSet :: expr_identifier ( & expr) ;
687703
688704 // lookup previously visited expression
689705 match self . expr_set . get ( curr_id) {
@@ -997,7 +1013,7 @@ mod test {
9971013 let table_scan = test_table_scan ( ) . unwrap ( ) ;
9981014 let affected_id: BTreeSet < Identifier > =
9991015 [ "c+a" . to_string ( ) , "b+a" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
1000- let expr_set_1 = [
1016+ let expr_set_1 = vec ! [
10011017 (
10021018 "c+a" . to_string( ) ,
10031019 ( col( "c" ) + col( "a" ) , 1 , DataType :: UInt32 , "c+a" . to_string( ) ) ,
@@ -1007,9 +1023,8 @@ mod test {
10071023 ( col( "b" ) + col( "a" ) , 1 , DataType :: UInt32 , "b+a" . to_string( ) ) ,
10081024 ) ,
10091025 ]
1010- . into_iter ( )
1011- . collect ( ) ;
1012- let expr_set_2 = [
1026+ . into ( ) ;
1027+ let expr_set_2 = vec ! [
10131028 (
10141029 "c+a" . to_string( ) ,
10151030 ( col( "c+a" ) , 1 , DataType :: UInt32 , "c+a" . to_string( ) ) ,
@@ -1019,8 +1034,7 @@ mod test {
10191034 ( col( "b+a" ) , 1 , DataType :: UInt32 , "b+a" . to_string( ) ) ,
10201035 ) ,
10211036 ]
1022- . into_iter ( )
1023- . collect ( ) ;
1037+ . into ( ) ;
10241038 let project =
10251039 build_common_expr_project_plan ( table_scan, affected_id. clone ( ) , & expr_set_1)
10261040 . unwrap ( ) ;
@@ -1046,7 +1060,7 @@ mod test {
10461060 [ "test1.c+test1.a" . to_string ( ) , "test1.b+test1.a" . to_string ( ) ]
10471061 . into_iter ( )
10481062 . collect ( ) ;
1049- let expr_set_1 = [
1063+ let expr_set_1 = vec ! [
10501064 (
10511065 "test1.c+test1.a" . to_string( ) ,
10521066 (
@@ -1066,9 +1080,8 @@ mod test {
10661080 ) ,
10671081 ) ,
10681082 ]
1069- . into_iter ( )
1070- . collect ( ) ;
1071- let expr_set_2 = [
1083+ . into ( ) ;
1084+ let expr_set_2 = vec ! [
10721085 (
10731086 "test1.c+test1.a" . to_string( ) ,
10741087 (
@@ -1088,8 +1101,7 @@ mod test {
10881101 ) ,
10891102 ) ,
10901103 ]
1091- . into_iter ( )
1092- . collect ( ) ;
1104+ . into ( ) ;
10931105 let project =
10941106 build_common_expr_project_plan ( join, affected_id. clone ( ) , & expr_set_1)
10951107 . unwrap ( ) ;
0 commit comments