@@ -49,8 +49,10 @@ use std::sync::Arc;
4949mod group_values;
5050mod no_grouping;
5151mod order;
52+ mod priority_queue;
5253mod row_hash;
5354
55+ use crate :: physical_plan:: aggregates:: priority_queue:: GroupedPriorityQueueAggregateStream ;
5456pub use datafusion_expr:: AggregateFunction ;
5557use datafusion_physical_expr:: aggregate:: is_order_sensitive;
5658pub use datafusion_physical_expr:: expressions:: create_aggregate_expr;
@@ -228,14 +230,16 @@ impl PartialEq for PhysicalGroupBy {
228230
229231enum StreamType {
230232 AggregateStream ( AggregateStream ) ,
231- GroupedHashAggregateStream ( GroupedHashAggregateStream ) ,
233+ GroupedHash ( GroupedHashAggregateStream ) ,
234+ GroupedPriorityQueue ( GroupedPriorityQueueAggregateStream ) ,
232235}
233236
234237impl From < StreamType > for SendableRecordBatchStream {
235238 fn from ( stream : StreamType ) -> Self {
236239 match stream {
237240 StreamType :: AggregateStream ( stream) => Box :: pin ( stream) ,
238- StreamType :: GroupedHashAggregateStream ( stream) => Box :: pin ( stream) ,
241+ StreamType :: GroupedHash ( stream) => Box :: pin ( stream) ,
242+ StreamType :: GroupedPriorityQueue ( stream) => Box :: pin ( stream) ,
239243 }
240244 }
241245}
@@ -265,6 +269,8 @@ pub struct AggregateExec {
265269 pub ( crate ) filter_expr : Vec < Option < Arc < dyn PhysicalExpr > > > ,
266270 /// (ORDER BY clause) expression for each aggregate expression
267271 pub ( crate ) order_by_expr : Vec < Option < LexOrdering > > ,
272+ /// Set if the output of this aggregation is truncated by a upstream sort/limit clause
273+ pub ( crate ) limit : Option < usize > ,
268274 /// Input plan, could be a partial aggregate or the input to the aggregate
269275 pub ( crate ) input : Arc < dyn ExecutionPlan > ,
270276 /// Schema after the aggregate is applied
@@ -670,6 +676,7 @@ impl AggregateExec {
670676 metrics : ExecutionPlanMetricsSet :: new ( ) ,
671677 aggregation_ordering,
672678 required_input_ordering,
679+ limit : None ,
673680 } )
674681 }
675682
@@ -718,15 +725,29 @@ impl AggregateExec {
718725 partition : usize ,
719726 context : Arc < TaskContext > ,
720727 ) -> Result < StreamType > {
728+ // no group by at all
721729 if self . group_by . expr . is_empty ( ) {
722- Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
730+ return Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
723731 self , context, partition,
724- ) ?) )
725- } else {
726- Ok ( StreamType :: GroupedHashAggregateStream (
727- GroupedHashAggregateStream :: new ( self , context, partition) ?,
728- ) )
732+ ) ?) ) ;
733+ }
734+
735+ // grouping by an expression that has a sort/limit upstream
736+ let is_minmax =
737+ GroupedPriorityQueueAggregateStream :: get_minmax_desc ( self ) . is_some ( ) ;
738+ if self . limit . is_some ( ) && is_minmax {
739+ println ! ( "Using limited priority queue aggregation" ) ;
740+ return Ok ( StreamType :: GroupedPriorityQueue (
741+ GroupedPriorityQueueAggregateStream :: new (
742+ self , context, partition, self . limit ,
743+ ) ?,
744+ ) ) ;
729745 }
746+
747+ // grouping by something else and we need to just materialize all results
748+ Ok ( StreamType :: GroupedHash ( GroupedHashAggregateStream :: new (
749+ self , context, partition,
750+ ) ?) )
730751 }
731752}
732753
@@ -1149,7 +1170,7 @@ fn evaluate(
11491170}
11501171
11511172/// Evaluates expressions against a record batch.
1152- fn evaluate_many (
1173+ pub fn evaluate_many (
11531174 expr : & [ Vec < Arc < dyn PhysicalExpr > > ] ,
11541175 batch : & RecordBatch ,
11551176) -> Result < Vec < Vec < ArrayRef > > > {
@@ -1172,7 +1193,17 @@ fn evaluate_optional(
11721193 . collect :: < Result < Vec < _ > > > ( )
11731194}
11741195
1175- fn evaluate_group_by (
1196+ /// Evaluate a group by expression against a `RecordBatch`
1197+ ///
1198+ /// Arguments:
1199+ /// `group_by`: the expression to evaluate
1200+ /// `batch`: the `RecordBatch` to evaluate against
1201+ ///
1202+ /// Returns: A Vec of Vecs of Array of results
1203+ /// The outer Vect appears to be for grouping sets
1204+ /// The inner Vect contains the results per expression
1205+ /// The inner-inner Array contains the results per row
1206+ pub fn evaluate_group_by (
11761207 group_by : & PhysicalGroupBy ,
11771208 batch : & RecordBatch ,
11781209) -> Result < Vec < Vec < ArrayRef > > > {
@@ -1841,10 +1872,10 @@ mod tests {
18411872 assert ! ( matches!( stream, StreamType :: AggregateStream ( _) ) ) ;
18421873 }
18431874 1 => {
1844- assert ! ( matches!( stream, StreamType :: GroupedHashAggregateStream ( _) ) ) ;
1875+ assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
18451876 }
18461877 2 => {
1847- assert ! ( matches!( stream, StreamType :: GroupedHashAggregateStream ( _) ) ) ;
1878+ assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
18481879 }
18491880 _ => panic ! ( "Unknown version: {version}" ) ,
18501881 }
0 commit comments