@@ -378,10 +378,11 @@ impl ExecutionPlan for SortPreservingMergeExec {
378378
379379#[ cfg( test) ]
380380mod tests {
381+ use std:: collections:: HashSet ;
381382 use std:: fmt:: Formatter ;
382383 use std:: pin:: Pin ;
383384 use std:: sync:: Mutex ;
384- use std:: task:: { Context , Poll } ;
385+ use std:: task:: { ready , Context , Poll , Waker } ;
385386 use std:: time:: Duration ;
386387
387388 use super :: * ;
@@ -1285,13 +1286,50 @@ mod tests {
12851286 "# ) ;
12861287 }
12871288
/// Mutable bookkeeping shared by every partition stream of the congestion test.
#[derive(Debug)]
struct CongestionState {
    /// Wakers of the tasks that were told to wait (`Poll::Pending`).
    wakers: Vec<Waker>,
    /// Partitions that have not called `check_congested` yet.
    unpolled_partitions: HashSet<usize>,
}

/// Gate that stays closed until every partition has been polled at least once.
#[derive(Debug)]
struct Congestion {
    congestion_state: Mutex<CongestionState>,
}

impl Congestion {
    /// Creates a gate that waits for partitions `0..partition_count`.
    ///
    /// With `partition_count == 0` there is nothing to wait for, so the first
    /// call to `check_congested` is immediately ready.
    fn new(partition_count: usize) -> Self {
        Congestion {
            congestion_state: Mutex::new(CongestionState {
                wakers: vec![],
                unpolled_partitions: (0usize..partition_count).collect(),
            }),
        }
    }

    /// Records that `partition` has been polled and reports whether the gate is open.
    ///
    /// Returns `Poll::Pending` (after registering the waker of `cx`) while at
    /// least one partition is still unpolled. Returns `Poll::Ready(())` once all
    /// partitions have been seen, waking every task that was parked earlier.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.congestion_state.lock().unwrap();
        state.unpolled_partitions.remove(&partition);

        if !state.unpolled_partitions.is_empty() {
            // Best-effort de-duplication: a stream that is polled repeatedly while
            // congested should not grow the waker list on every poll.
            let waker = cx.waker();
            if !state.wakers.iter().any(|parked| parked.will_wake(waker)) {
                state.wakers.push(waker.clone());
            }
            return Poll::Pending;
        }

        // Move the wakers out and release the lock *before* waking, so a waker
        // that synchronously re-enters `check_congested` cannot hit a held lock.
        let parked = std::mem::take(&mut state.wakers);
        drop(state);
        parked.into_iter().for_each(Waker::wake);
        Poll::Ready(())
    }
}
1325+
12881326 /// Partitions after the 1st stay pending until every partition has been polled. The 1st
12891327 /// partition is exhausted from the start, and if it is polled more than once, it panics.
12901328 #[ derive( Debug , Clone ) ]
12911329 struct CongestedExec {
12921330 schema : Schema ,
12931331 cache : PlanProperties ,
1294- congestion_cleared : Arc < Mutex < bool > > ,
1332+ congestion : Arc < Congestion > ,
12951333 }
12961334
12971335 impl CongestedExec {
@@ -1346,7 +1384,7 @@ mod tests {
13461384 Ok ( Box :: pin ( CongestedStream {
13471385 schema : Arc :: new ( self . schema . clone ( ) ) ,
13481386 none_polled_once : false ,
1349- congestion_cleared : Arc :: clone ( & self . congestion_cleared ) ,
1387+ congestion : Arc :: clone ( & self . congestion ) ,
13501388 partition,
13511389 } ) )
13521390 }
@@ -1373,39 +1411,30 @@ mod tests {
13731411 pub struct CongestedStream {
13741412 schema : SchemaRef ,
13751413 none_polled_once : bool ,
1376- congestion_cleared : Arc < Mutex < bool > > ,
1414+ congestion : Arc < Congestion > ,
13771415 partition : usize ,
13781416 }
13791417
// Stream behaviour depends on the partition index: partition 0 ends on its first poll and
// panics if polled again; any other partition ends once `check_congested` reports ready.
13801418 impl Stream for CongestedStream {
13811419 type Item = Result < RecordBatch > ;
13821420 fn poll_next (
13831421 mut self : Pin < & mut Self > ,
1384- _cx : & mut Context < ' _ > ,
1422+ cx : & mut Context < ' _ > ,
13851423 ) -> Poll < Option < Self :: Item > > {
13861424 match self . partition {
13871425 0 => {
// Register this partition as polled; the result is discarded, so partition 0 never waits.
1426+ let _ = self . congestion . check_congested ( self . partition , cx) ;
13881427 if self . none_polled_once {
1389- panic ! ( "Exhausted stream is polled more than one " )
1428+ panic ! ( "Exhausted stream is polled more than once " )
13891429 } else {
13901430 self . none_polled_once = true ;
13911431 Poll :: Ready ( None )
13921432 }
13931433 }
1395- 1 => {
1396- let cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1397- if * cleared {
1398- Poll :: Ready ( None )
1399- } else {
1400- Poll :: Pending
1401- }
1402- }
1403- 2 => {
1404- let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1405- * cleared = true ;
1434+ _ => {
// `ready!` returns `Poll::Pending` from `poll_next` while the congestion check is pending.
1435+ ready ! ( self . congestion. check_congested( self . partition, cx) ) ;
14061436 Poll :: Ready ( None )
14071437 }
1408- _ => unreachable ! ( ) ,
14091438 }
14101439 }
14111440 }
@@ -1420,10 +1449,16 @@ mod tests {
14201449 async fn test_spm_congestion ( ) -> Result < ( ) > {
14211450 let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
14221451 let schema = Schema :: new ( vec ! [ Field :: new( "c1" , DataType :: UInt64 , false ) ] ) ;
1452+ let properties = CongestedExec :: compute_properties ( Arc :: new ( schema. clone ( ) ) ) ;
1453+ let & partition_count = match properties. output_partitioning ( ) {
1454+ Partitioning :: RoundRobinBatch ( partitions) => partitions,
1455+ Partitioning :: Hash ( _, partitions) => partitions,
1456+ Partitioning :: UnknownPartitioning ( partitions) => partitions,
1457+ } ;
14231458 let source = CongestedExec {
14241459 schema : schema. clone ( ) ,
1425- cache : CongestedExec :: compute_properties ( Arc :: new ( schema . clone ( ) ) ) ,
1426- congestion_cleared : Arc :: new ( Mutex :: new ( false ) ) ,
1460+ cache : properties ,
1461+ congestion : Arc :: new ( Congestion :: new ( partition_count ) ) ,
14271462 } ;
14281463 let spm = SortPreservingMergeExec :: new (
14291464 [ PhysicalSortExpr :: new_default ( Arc :: new ( Column :: new (
0 commit comments