@@ -191,17 +191,15 @@ impl EnforceDistribution {
191191impl PhysicalOptimizerRule for EnforceDistribution {
192192 fn optimize (
193193 & self ,
194- plan : Arc < dyn ExecutionPlan > ,
194+ mut plan : Arc < dyn ExecutionPlan > ,
195195 config : & ConfigOptions ,
196196 ) -> Result < Arc < dyn ExecutionPlan > > {
197197 let top_down_join_key_reordering = config. optimizer . top_down_join_key_reordering ;
198198
199199 let adjusted = if top_down_join_key_reordering {
200200 // Run a top-down process to adjust input key ordering recursively
201- let plan_requirements = PlanWithKeyRequirements :: new ( plan) ;
202- let adjusted =
203- plan_requirements. transform_down_old ( & adjust_input_keys_ordering) ?;
204- adjusted. plan
201+ plan. transform_down_with_payload ( & mut adjust_input_keys_ordering, None ) ?;
202+ plan
205203 } else {
206204 // Run a bottom-up process
207205 plan. transform_up_old ( & |plan| {
@@ -269,12 +267,15 @@ impl PhysicalOptimizerRule for EnforceDistribution {
269267/// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements
270268/// 5) For other types of operators, by default, pushdown the parent requirements to children.
271269///
270+ type RequiredKeyOrdering = Option < Vec < Arc < dyn PhysicalExpr > > > ;
271+
272272fn adjust_input_keys_ordering (
273- requirements : PlanWithKeyRequirements ,
274- ) -> Result < Transformed < PlanWithKeyRequirements > > {
275- let parent_required = requirements. required_key_ordering . clone ( ) ;
276- let plan_any = requirements. plan . as_any ( ) ;
277- let transformed = if let Some ( HashJoinExec {
273+ plan : & mut Arc < dyn ExecutionPlan > ,
274+ required_key_ordering : RequiredKeyOrdering ,
275+ ) -> Result < ( TreeNodeRecursion , Vec < RequiredKeyOrdering > ) > {
276+ let parent_required = required_key_ordering. unwrap_or_default ( ) . clone ( ) ;
277+ let plan_any = plan. as_any ( ) ;
278+ if let Some ( HashJoinExec {
278279 left,
279280 right,
280281 on,
@@ -299,13 +300,15 @@ fn adjust_input_keys_ordering(
299300 * null_equals_null,
300301 ) ?) as Arc < dyn ExecutionPlan > )
301302 } ;
302- Some ( reorder_partitioned_join_keys (
303- requirements . plan . clone ( ) ,
303+ let ( new_plan , request_key_ordering ) = reorder_partitioned_join_keys (
304+ plan. clone ( ) ,
304305 & parent_required,
305306 on,
306307 vec ! [ ] ,
307308 & join_constructor,
308- ) ?)
309+ ) ?;
310+ * plan = new_plan;
311+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
309312 }
310313 PartitionMode :: CollectLeft => {
311314 let new_right_request = match join_type {
@@ -323,30 +326,28 @@ fn adjust_input_keys_ordering(
323326 } ;
324327
325328 // Push down requirements to the right side
326- Some ( PlanWithKeyRequirements {
327- plan : requirements. plan . clone ( ) ,
328- required_key_ordering : vec ! [ ] ,
329- request_key_ordering : vec ! [ None , new_right_request] ,
330- } )
329+ Ok ( ( TreeNodeRecursion :: Continue , vec ! [ None , new_right_request] ) )
331330 }
332331 PartitionMode :: Auto => {
333332 // Cannot satisfy; clear the current requirements and generate new empty requirements
334- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
333+ Ok ( (
334+ TreeNodeRecursion :: Continue ,
335+ vec ! [ None ; plan. children( ) . len( ) ] ,
336+ ) )
335337 }
336338 }
337339 } else if let Some ( CrossJoinExec { left, .. } ) =
338340 plan_any. downcast_ref :: < CrossJoinExec > ( )
339341 {
340342 let left_columns_len = left. schema ( ) . fields ( ) . len ( ) ;
341343 // Push down requirements to the right side
342- Some ( PlanWithKeyRequirements {
343- plan : requirements. plan . clone ( ) ,
344- required_key_ordering : vec ! [ ] ,
345- request_key_ordering : vec ! [
344+ Ok ( (
345+ TreeNodeRecursion :: Continue ,
346+ vec ! [
346347 None ,
347348 shift_right_required( & parent_required, left_columns_len) ,
348349 ] ,
349- } )
350+ ) )
350351 } else if let Some ( SortMergeJoinExec {
351352 left,
352353 right,
@@ -368,26 +369,38 @@ fn adjust_input_keys_ordering(
368369 * null_equals_null,
369370 ) ?) as Arc < dyn ExecutionPlan > )
370371 } ;
371- Some ( reorder_partitioned_join_keys (
372- requirements . plan . clone ( ) ,
372+ let ( new_plan , request_key_ordering ) = reorder_partitioned_join_keys (
373+ plan. clone ( ) ,
373374 & parent_required,
374375 on,
375376 sort_options. clone ( ) ,
376377 & join_constructor,
377- ) ?)
378+ ) ?;
379+ * plan = new_plan;
380+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
378381 } else if let Some ( aggregate_exec) = plan_any. downcast_ref :: < AggregateExec > ( ) {
379382 if !parent_required. is_empty ( ) {
380383 match aggregate_exec. mode ( ) {
381- AggregateMode :: FinalPartitioned => Some ( reorder_aggregate_keys (
382- requirements. plan . clone ( ) ,
383- & parent_required,
384- aggregate_exec,
385- ) ?) ,
386- _ => Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) ) ,
384+ AggregateMode :: FinalPartitioned => {
385+ let ( new_plan, request_key_ordering) = reorder_aggregate_keys (
386+ plan. clone ( ) ,
387+ & parent_required,
388+ aggregate_exec,
389+ ) ?;
390+ * plan = new_plan;
391+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
392+ }
393+ _ => Ok ( (
394+ TreeNodeRecursion :: Continue ,
395+ vec ! [ None ; plan. children( ) . len( ) ] ,
396+ ) ) ,
387397 }
388398 } else {
389399 // Keep everything unchanged
390- None
400+ Ok ( (
401+ TreeNodeRecursion :: Continue ,
402+ vec ! [ None ; plan. children( ) . len( ) ] ,
403+ ) )
391404 }
392405 } else if let Some ( proj) = plan_any. downcast_ref :: < ProjectionExec > ( ) {
393406 let expr = proj. expr ( ) ;
@@ -396,34 +409,33 @@ fn adjust_input_keys_ordering(
396409 // Construct a mapping from new name to the original Column
397410 let new_required = map_columns_before_projection ( & parent_required, expr) ;
398411 if new_required. len ( ) == parent_required. len ( ) {
399- Some ( PlanWithKeyRequirements {
400- plan : requirements. plan . clone ( ) ,
401- required_key_ordering : vec ! [ ] ,
402- request_key_ordering : vec ! [ Some ( new_required. clone( ) ) ] ,
403- } )
412+ Ok ( (
413+ TreeNodeRecursion :: Continue ,
414+ vec ! [ Some ( new_required. clone( ) ) ] ,
415+ ) )
404416 } else {
405417 // Cannot satisfy; clear the current requirements and generate new empty requirements
406- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
418+ Ok ( (
419+ TreeNodeRecursion :: Continue ,
420+ vec ! [ None ; plan. children( ) . len( ) ] ,
421+ ) )
407422 }
408423 } else if plan_any. downcast_ref :: < RepartitionExec > ( ) . is_some ( )
409424 || plan_any. downcast_ref :: < CoalescePartitionsExec > ( ) . is_some ( )
410425 || plan_any. downcast_ref :: < WindowAggExec > ( ) . is_some ( )
411426 {
412- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
427+ Ok ( (
428+ TreeNodeRecursion :: Continue ,
429+ vec ! [ None ; plan. children( ) . len( ) ] ,
430+ ) )
413431 } else {
414432 // By default, push down the parent requirements to children
415- let children_len = requirements. plan . children ( ) . len ( ) ;
416- Some ( PlanWithKeyRequirements {
417- plan : requirements. plan . clone ( ) ,
418- required_key_ordering : vec ! [ ] ,
419- request_key_ordering : vec ! [ Some ( parent_required. clone( ) ) ; children_len] ,
420- } )
421- } ;
422- Ok ( if let Some ( transformed) = transformed {
423- Transformed :: Yes ( transformed)
424- } else {
425- Transformed :: No ( requirements)
426- } )
433+ let children_len = plan. children ( ) . len ( ) ;
434+ Ok ( (
435+ TreeNodeRecursion :: Continue ,
436+ vec ! [ Some ( parent_required. clone( ) ) ; children_len] ,
437+ ) )
438+ }
427439}
428440
429441fn reorder_partitioned_join_keys < F > (
@@ -432,7 +444,7 @@ fn reorder_partitioned_join_keys<F>(
432444 on : & [ ( Column , Column ) ] ,
433445 sort_options : Vec < SortOptions > ,
434446 join_constructor : & F ,
435- ) -> Result < PlanWithKeyRequirements >
447+ ) -> Result < ( Arc < dyn ExecutionPlan > , Vec < RequiredKeyOrdering > ) >
436448where
437449 F : Fn ( ( Vec < ( Column , Column ) > , Vec < SortOptions > ) ) -> Result < Arc < dyn ExecutionPlan > > ,
438450{
@@ -455,35 +467,29 @@ where
455467 new_sort_options. push ( sort_options[ new_positions[ idx] ] )
456468 }
457469
458- Ok ( PlanWithKeyRequirements {
459- plan : join_constructor ( ( new_join_on, new_sort_options) ) ?,
460- required_key_ordering : vec ! [ ] ,
461- request_key_ordering : vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
462- } )
470+ Ok ( (
471+ join_constructor ( ( new_join_on, new_sort_options) ) ?,
472+ vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
473+ ) )
463474 } else {
464- Ok ( PlanWithKeyRequirements {
465- plan : join_plan,
466- required_key_ordering : vec ! [ ] ,
467- request_key_ordering : vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
468- } )
475+ Ok ( ( join_plan, vec ! [ Some ( left_keys) , Some ( right_keys) ] ) )
469476 }
470477 } else {
471- Ok ( PlanWithKeyRequirements {
472- plan : join_plan,
473- required_key_ordering : vec ! [ ] ,
474- request_key_ordering : vec ! [
478+ Ok ( (
479+ join_plan,
480+ vec ! [
475481 Some ( join_key_pairs. left_keys) ,
476482 Some ( join_key_pairs. right_keys) ,
477483 ] ,
478- } )
484+ ) )
479485 }
480486}
481487
482488fn reorder_aggregate_keys (
483489 agg_plan : Arc < dyn ExecutionPlan > ,
484490 parent_required : & [ Arc < dyn PhysicalExpr > ] ,
485491 agg_exec : & AggregateExec ,
486- ) -> Result < PlanWithKeyRequirements > {
492+ ) -> Result < ( Arc < dyn ExecutionPlan > , Vec < RequiredKeyOrdering > ) > {
487493 let output_columns = agg_exec
488494 . group_by ( )
489495 . expr ( )
@@ -501,11 +507,15 @@ fn reorder_aggregate_keys(
501507 || !agg_exec. group_by ( ) . null_expr ( ) . is_empty ( )
502508 || physical_exprs_equal ( & output_exprs, parent_required)
503509 {
504- Ok ( PlanWithKeyRequirements :: new ( agg_plan) )
510+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
511+ Ok ( ( agg_plan, request_key_ordering) )
505512 } else {
506513 let new_positions = expected_expr_positions ( & output_exprs, parent_required) ;
507514 match new_positions {
508- None => Ok ( PlanWithKeyRequirements :: new ( agg_plan) ) ,
515+ None => {
516+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
517+ Ok ( ( agg_plan, request_key_ordering) )
518+ }
509519 Some ( positions) => {
510520 let new_partial_agg = if let Some ( agg_exec) =
511521 agg_exec. input ( ) . as_any ( ) . downcast_ref :: < AggregateExec > ( )
@@ -577,11 +587,13 @@ fn reorder_aggregate_keys(
577587 . push ( ( Arc :: new ( Column :: new ( name, idx) ) as _ , name. clone ( ) ) )
578588 }
579589 // TODO merge adjacent Projections if there are any
580- Ok ( PlanWithKeyRequirements :: new ( Arc :: new (
581- ProjectionExec :: try_new ( proj_exprs, new_final_agg) ?,
582- ) ) )
590+ let new_plan =
591+ Arc :: new ( ProjectionExec :: try_new ( proj_exprs, new_final_agg) ?) ;
592+ let request_key_ordering = vec ! [ None ; new_plan. children( ) . len( ) ] ;
593+ Ok ( ( new_plan, request_key_ordering) )
583594 } else {
584- Ok ( PlanWithKeyRequirements :: new ( agg_plan) )
595+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
596+ Ok ( ( agg_plan, request_key_ordering) )
585597 }
586598 }
587599 }
@@ -1539,93 +1551,6 @@ struct JoinKeyPairs {
15391551 right_keys : Vec < Arc < dyn PhysicalExpr > > ,
15401552}
15411553
1542- #[ derive( Debug , Clone ) ]
1543- struct PlanWithKeyRequirements {
1544- plan : Arc < dyn ExecutionPlan > ,
1545- /// Parent required key ordering
1546- required_key_ordering : Vec < Arc < dyn PhysicalExpr > > ,
1547- /// The request key ordering to children
1548- request_key_ordering : Vec < Option < Vec < Arc < dyn PhysicalExpr > > > > ,
1549- }
1550-
1551- impl PlanWithKeyRequirements {
1552- fn new ( plan : Arc < dyn ExecutionPlan > ) -> Self {
1553- let children_len = plan. children ( ) . len ( ) ;
1554- PlanWithKeyRequirements {
1555- plan,
1556- required_key_ordering : vec ! [ ] ,
1557- request_key_ordering : vec ! [ None ; children_len] ,
1558- }
1559- }
1560-
1561- fn children ( & self ) -> Vec < PlanWithKeyRequirements > {
1562- let plan_children = self . plan . children ( ) ;
1563- assert_eq ! ( plan_children. len( ) , self . request_key_ordering. len( ) ) ;
1564- plan_children
1565- . into_iter ( )
1566- . zip ( self . request_key_ordering . clone ( ) )
1567- . map ( |( child, required) | {
1568- let from_parent = required. unwrap_or_default ( ) ;
1569- let length = child. children ( ) . len ( ) ;
1570- PlanWithKeyRequirements {
1571- plan : child,
1572- required_key_ordering : from_parent,
1573- request_key_ordering : vec ! [ None ; length] ,
1574- }
1575- } )
1576- . collect ( )
1577- }
1578- }
1579-
1580- impl TreeNode for PlanWithKeyRequirements {
1581- fn apply_children < F > ( & self , f : & mut F ) -> Result < TreeNodeRecursion >
1582- where
1583- F : FnMut ( & Self ) -> Result < TreeNodeRecursion > ,
1584- {
1585- self . children ( ) . iter ( ) . for_each_till_continue ( f)
1586- }
1587-
1588- fn map_children < F > ( self , transform : F ) -> Result < Self >
1589- where
1590- F : FnMut ( Self ) -> Result < Self > ,
1591- {
1592- let children = self . children ( ) ;
1593- if !children. is_empty ( ) {
1594- let new_children: Result < Vec < _ > > =
1595- children. into_iter ( ) . map ( transform) . collect ( ) ;
1596-
1597- let children_plans = new_children?
1598- . into_iter ( )
1599- . map ( |child| child. plan )
1600- . collect :: < Vec < _ > > ( ) ;
1601- let new_plan = with_new_children_if_necessary ( self . plan , children_plans) ?;
1602- Ok ( PlanWithKeyRequirements {
1603- plan : new_plan. into ( ) ,
1604- required_key_ordering : self . required_key_ordering ,
1605- request_key_ordering : self . request_key_ordering ,
1606- } )
1607- } else {
1608- Ok ( self )
1609- }
1610- }
1611-
1612- fn transform_children < F > ( & mut self , f : & mut F ) -> Result < TreeNodeRecursion >
1613- where
1614- F : FnMut ( & mut Self ) -> Result < TreeNodeRecursion > ,
1615- {
1616- let mut children = self . children ( ) ;
1617- if !children. is_empty ( ) {
1618- let tnr = children. iter_mut ( ) . for_each_till_continue ( f) ?;
1619- let children_plans = children. into_iter ( ) . map ( |c| c. plan ) . collect ( ) ;
1620- self . plan =
1621- with_new_children_if_necessary ( self . plan . clone ( ) , children_plans) ?. into ( ) ;
1622- Ok ( tnr)
1623- } else {
1624- Ok ( TreeNodeRecursion :: Continue )
1625- }
1626- }
1627- }
1628-
16291554/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
16301555#[ cfg( feature = "parquet" ) ]
16311556#[ cfg( test) ]
0 commit comments