1717use std:: collections:: { HashMap , HashSet } ;
1818use std:: sync:: Arc ;
1919
20- use crate :: optimizer:: ApplyOrder ;
21- use crate :: { OptimizerConfig , OptimizerRule } ;
20+ use itertools:: Itertools ;
2221
2322use datafusion_common:: tree_node:: {
2423 Transformed , TransformedResult , TreeNode , TreeNodeRecursion ,
@@ -29,6 +28,7 @@ use datafusion_common::{
2928} ;
3029use datafusion_expr:: expr:: Alias ;
3130use datafusion_expr:: expr_rewriter:: replace_col;
31+ use datafusion_expr:: logical_plan:: tree_node:: unwrap_arc;
3232use datafusion_expr:: logical_plan:: {
3333 CrossJoin , Join , JoinType , LogicalPlan , TableScan , Union ,
3434} ;
@@ -38,7 +38,8 @@ use datafusion_expr::{
3838 ScalarFunctionDefinition , TableProviderFilterPushDown ,
3939} ;
4040
41- use itertools:: Itertools ;
41+ use crate :: optimizer:: ApplyOrder ;
42+ use crate :: { OptimizerConfig , OptimizerRule } ;
4243
4344/// Optimizer rule for pushing (moving) filter expressions down in a plan so
4445/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
407408 right : & LogicalPlan ,
408409 on_filter : Vec < Expr > ,
409410 is_inner_join : bool ,
410- ) -> Result < LogicalPlan > {
411+ ) -> Result < Transformed < LogicalPlan > > {
411412 let on_filter_empty = on_filter. is_empty ( ) ;
412413 // Get pushable predicates from current optimizer state
413414 let ( left_preserved, right_preserved) = lr_is_preserved ( join_plan) ?;
@@ -505,41 +506,43 @@ fn push_down_all_join(
505506 // wrap the join in a filter holding the predicates that must be kept
506507 match conjunction ( keep_predicates) {
507508 Some ( predicate) => {
508- Filter :: try_new ( predicate, Arc :: new ( plan) ) . map ( LogicalPlan :: Filter )
509+ let new_filter_plan = Filter :: try_new ( predicate, Arc :: new ( plan) ) ?;
510+ Ok ( Transformed :: yes ( LogicalPlan :: Filter ( new_filter_plan) ) )
509511 }
510- None => Ok ( plan) ,
512+ None => Ok ( Transformed :: no ( plan) ) ,
511513 }
512514}
513515
514516fn push_down_join (
515517 plan : & LogicalPlan ,
516518 join : & Join ,
517519 parent_predicate : Option < & Expr > ,
518- ) -> Result < Option < LogicalPlan > > {
519- let predicates = match parent_predicate {
520- Some ( parent_predicate) => split_conjunction_owned ( parent_predicate. clone ( ) ) ,
521- None => vec ! [ ] ,
522- } ;
520+ ) -> Result < Transformed < LogicalPlan > > {
521+ // Split the parent predicate into individual conjunctive parts.
522+ let predicates = parent_predicate
523+ . map_or_else ( Vec :: new, |pred| split_conjunction_owned ( pred. clone ( ) ) ) ;
523524
524- // Convert JOIN ON predicate to Predicates
525+ // Extract conjunctions from the JOIN's ON filter, if present.
525526 let on_filters = join
526527 . filter
527528 . as_ref ( )
528- . map ( |e| split_conjunction_owned ( e. clone ( ) ) )
529- . unwrap_or_default ( ) ;
529+ . map_or_else ( Vec :: new, |filter| split_conjunction_owned ( filter. clone ( ) ) ) ;
530530
531531 let mut is_inner_join = false ;
532532 let infer_predicates = if join. join_type == JoinType :: Inner {
533533 is_inner_join = true ;
534+
534535 // Only keep join key pairs where both sides are plain columns.
535536 let join_col_keys = join
536537 . on
537538 . iter ( )
538- . flat_map ( |( l, r) | match ( l. try_into_col ( ) , r. try_into_col ( ) ) {
539- ( Ok ( l_col) , Ok ( r_col) ) => Some ( ( l_col, r_col) ) ,
540- _ => None ,
539+ . filter_map ( |( l, r) | {
540+ let left_col = l. try_into_col ( ) . ok ( ) ?;
541+ let right_col = r. try_into_col ( ) . ok ( ) ?;
542+ Some ( ( left_col, right_col) )
541543 } )
542544 . collect :: < Vec < _ > > ( ) ;
545+
543546 // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
544547 // For inner joins, duplicate filters for joined columns so filters can be pushed down
545548 // to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
559562 . chain ( on_filters. iter ( ) )
560563 . filter_map ( |predicate| {
561564 let mut join_cols_to_replace = HashMap :: new ( ) ;
565+
562566 let columns = match predicate. to_columns ( ) {
563567 Ok ( columns) => columns,
564568 Err ( e) => return Some ( Err ( e) ) ,
@@ -596,20 +600,32 @@ fn push_down_join(
596600 } ;
597601
598602 if on_filters. is_empty ( ) && predicates. is_empty ( ) && infer_predicates. is_empty ( ) {
599- return Ok ( None ) ;
603+ return Ok ( Transformed :: no ( plan . clone ( ) ) ) ;
600604 }
601- Ok ( Some ( push_down_all_join (
605+
606+ match push_down_all_join (
602607 predicates,
603608 infer_predicates,
604609 plan,
605610 & join. left ,
606611 & join. right ,
607612 on_filters,
608613 is_inner_join,
609- ) ?) )
614+ ) {
615+ Ok ( plan) => Ok ( Transformed :: yes ( plan. data ) ) ,
616+ Err ( e) => Err ( e) ,
617+ }
610618}
611619
612620impl OptimizerRule for PushDownFilter {
621+ fn try_optimize (
622+ & self ,
623+ _plan : & LogicalPlan ,
624+ _config : & dyn OptimizerConfig ,
625+ ) -> Result < Option < LogicalPlan > > {
626+ internal_err ! ( "Should have called PushDownFilter::rewrite" )
627+ }
628+
613629 fn name ( & self ) -> & str {
614630 "push_down_filter"
615631 }
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
618634 Some ( ApplyOrder :: TopDown )
619635 }
620636
621- fn try_optimize (
637+ fn supports_rewrite ( & self ) -> bool {
638+ true
639+ }
640+
641+ fn rewrite (
622642 & self ,
623- plan : & LogicalPlan ,
643+ plan : LogicalPlan ,
624644 _config : & dyn OptimizerConfig ,
625- ) -> Result < Option < LogicalPlan > > {
645+ ) -> Result < Transformed < LogicalPlan > > {
626646 let filter = match plan {
627- LogicalPlan :: Filter ( filter) => filter,
628- // we also need to pushdown filter in Join.
629- LogicalPlan :: Join ( join) => return push_down_join ( plan, join, None ) ,
630- _ => return Ok ( None ) ,
647+ LogicalPlan :: Filter ( ref filter) => filter,
648+ LogicalPlan :: Join ( ref join) => return push_down_join ( & plan, join, None ) ,
649+ _ => return Ok ( Transformed :: no ( plan) ) ,
631650 } ;
632651
633652 let child_plan = filter. input . as_ref ( ) ;
634653 let new_plan = match child_plan {
635- LogicalPlan :: Filter ( child_filter) => {
654+ LogicalPlan :: Filter ( ref child_filter) => {
636655 let parents_predicates = split_conjunction ( & filter. predicate ) ;
637656 let set: HashSet < & & Expr > = parents_predicates. iter ( ) . collect ( ) ;
638657
@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
652671 new_predicate,
653672 child_filter. input . clone ( ) ,
654673 ) ?) ;
655- self . try_optimize ( & new_filter, _config) ?
656- . unwrap_or ( new_filter)
674+ self . rewrite ( new_filter, _config) ?. data
657675 }
658676 LogicalPlan :: Repartition ( _)
659677 | LogicalPlan :: Distinct ( _)
660678 | LogicalPlan :: Sort ( _) => {
661- // commutable
662679 let new_filter = plan. with_new_exprs (
663680 plan. expressions ( ) ,
664681 vec ! [ child_plan. inputs( ) [ 0 ] . clone( ) ] ,
665682 ) ?;
666683 child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
667684 }
668- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
685+ LogicalPlan :: SubqueryAlias ( ref subquery_alias) => {
669686 let mut replace_map = HashMap :: new ( ) ;
670687 for ( i, ( qualifier, field) ) in
671688 subquery_alias. input . schema ( ) . iter ( ) . enumerate ( )
@@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
685702 ) ?) ;
686703 child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
687704 }
688- LogicalPlan :: Projection ( projection) => {
705+ LogicalPlan :: Projection ( ref projection) => {
689706 // A projection is filter-commutable if it does not contain volatile predicates, or if it contains volatile
690707 // predicates that are not used in the filter. However, we should rewrite all predicate expressions.
691708 // collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
742759 }
743760 }
744761 }
745- None => return Ok ( None ) ,
762+ None => return Ok ( Transformed :: no ( plan ) ) ,
746763 }
747764 }
748- LogicalPlan :: Union ( union) => {
765+ LogicalPlan :: Union ( ref union) => {
749766 let mut inputs = Vec :: with_capacity ( union. inputs . len ( ) ) ;
750767 for input in & union. inputs {
751768 let mut replace_map = HashMap :: new ( ) ;
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
770787 schema : plan. schema ( ) . clone ( ) ,
771788 } )
772789 }
773- LogicalPlan :: Aggregate ( agg) => {
790+ LogicalPlan :: Aggregate ( ref agg) => {
774791 // We can push down predicates that are on the group-by expressions.
775792 let group_expr_columns = agg
776793 . group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
821838 None => new_agg,
822839 }
823840 }
824- LogicalPlan :: Join ( join) => {
825- match push_down_join ( & filter. input , join, Some ( & filter. predicate ) ) ? {
826- Some ( optimized_plan) => optimized_plan,
827- None => return Ok ( None ) ,
828- }
841+ LogicalPlan :: Join ( ref join) => {
842+ push_down_join (
843+ & unwrap_arc ( filter. clone ( ) . input ) ,
844+ join,
845+ Some ( & filter. predicate ) ,
846+ ) ?
847+ . data
829848 }
830- LogicalPlan :: CrossJoin ( cross_join) => {
849+ LogicalPlan :: CrossJoin ( ref cross_join) => {
831850 let predicates = split_conjunction_owned ( filter. predicate . clone ( ) ) ;
832851 let join = convert_cross_join_to_inner_join ( cross_join. clone ( ) ) ?;
833852 let join_plan = LogicalPlan :: Join ( join) ;
@@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
843862 vec ! [ ] ,
844863 true ,
845864 ) ?;
846- convert_to_cross_join_if_beneficial ( plan) ?
865+ convert_to_cross_join_if_beneficial ( plan. data ) ?
847866 }
848- LogicalPlan :: TableScan ( scan) => {
867+ LogicalPlan :: TableScan ( ref scan) => {
849868 let filter_predicates = split_conjunction ( & filter. predicate ) ;
850869 let results = scan
851870 . source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
892911 None => new_scan,
893912 }
894913 }
895- LogicalPlan :: Extension ( extension_plan) => {
914+ LogicalPlan :: Extension ( ref extension_plan) => {
896915 let prevent_cols =
897916 extension_plan. node . prevent_predicate_push_down_columns ( ) ;
898917
@@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
935954 None => new_extension,
936955 }
937956 }
938- _ => return Ok ( None ) ,
957+ _ => return Ok ( Transformed :: no ( plan ) ) ,
939958 } ;
940- Ok ( Some ( new_plan) )
959+
960+ Ok ( Transformed :: yes ( new_plan) )
941961 }
942962}
943963
@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
10241044
10251045#[ cfg( test) ]
10261046mod tests {
1027- use super :: * ;
10281047 use std:: any:: Any ;
10291048 use std:: fmt:: { Debug , Formatter } ;
10301049
1031- use crate :: optimizer:: Optimizer ;
1032- use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1033- use crate :: test:: * ;
1034- use crate :: OptimizerContext ;
1035-
10361050 use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
1051+ use async_trait:: async_trait;
1052+
10371053 use datafusion_common:: ScalarValue ;
10381054 use datafusion_expr:: expr:: ScalarFunction ;
10391055 use datafusion_expr:: logical_plan:: table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
10431059 Volatility ,
10441060 } ;
10451061
1046- use async_trait:: async_trait;
1062+ use crate :: optimizer:: Optimizer ;
1063+ use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1064+ use crate :: test:: * ;
1065+ use crate :: OptimizerContext ;
1066+
1067+ use super :: * ;
1068+
10471069 fn observe ( _plan : & LogicalPlan , _rule : & dyn OptimizerRule ) { }
10481070
10491071 fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -2298,9 +2320,9 @@ mod tests {
22982320 table_scan_with_pushdown_provider ( TableProviderFilterPushDown :: Inexact ) ?;
22992321
23002322 let optimized_plan = PushDownFilter :: new ( )
2301- . try_optimize ( & plan, & OptimizerContext :: new ( ) )
2323+ . rewrite ( plan, & OptimizerContext :: new ( ) )
23022324 . expect ( "failed to optimize plan" )
2303- . unwrap ( ) ;
2325+ . data ;
23042326
23052327 let expected = "\
23062328 Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
26672689 // Originally there was global state that helped avoid duplicate Filters being generated and pushed down.
26682690 // Now that the global state is removed, we need to double-check that duplicate Filters are still avoided.
26692691 let optimized_plan = PushDownFilter :: new ( )
2670- . try_optimize ( & plan, & OptimizerContext :: new ( ) ) ?
2671- . expect ( "failed to optimize plan" ) ;
2692+ . rewrite ( plan, & OptimizerContext :: new ( ) )
2693+ . expect ( "failed to optimize plan" )
2694+ . data ;
26722695 assert_optimized_plan_eq ( optimized_plan, expected)
26732696 }
26742697
0 commit comments