@@ -381,7 +381,7 @@ mod tests {
381381 use std:: fmt:: Formatter ;
382382 use std:: pin:: Pin ;
383383 use std:: sync:: Mutex ;
384- use std:: task:: { Context , Poll } ;
384+ use std:: task:: { ready , Context , Poll , Waker } ;
385385 use std:: time:: Duration ;
386386
387387 use super :: * ;
@@ -1285,13 +1285,45 @@ mod tests {
12851285 "# ) ;
12861286 }
12871287
1288+ #[ derive( Debug ) ]
1289+ struct Congestion {
1290+ congestion_cleared : Mutex < Option < Vec < Waker > > > ,
1291+ }
1292+
1293+ impl Congestion {
1294+ fn new ( ) -> Self {
1295+ Congestion {
1296+ congestion_cleared : Mutex :: new ( Some ( vec ! [ ] ) ) ,
1297+ }
1298+ }
1299+
1300+ fn clear_congestion ( & self ) {
1301+ let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1302+ if let Some ( wakers) = & mut * cleared {
1303+ wakers. iter ( ) . for_each ( |w| w. wake_by_ref ( ) ) ;
1304+ * cleared = None ;
1305+ }
1306+ }
1307+
1308+ fn check_congested ( & self , cx : & mut Context < ' _ > ) -> Poll < ( ) > {
1309+ let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1310+ match & mut * cleared {
1311+ None => Poll :: Ready ( ( ) ) ,
1312+ Some ( wakers) => {
1313+ wakers. push ( cx. waker ( ) . clone ( ) ) ;
1314+ Poll :: Pending
1315+ }
1316+ }
1317+ }
1318+ }
1319+
12881320 /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
12891321 /// partition is exhausted from the start, and if it is polled more than one, it panics.
12901322 #[ derive( Debug , Clone ) ]
12911323 struct CongestedExec {
12921324 schema : Schema ,
12931325 cache : PlanProperties ,
1294- congestion_cleared : Arc < Mutex < bool > > ,
1326+ congestion : Arc < Congestion > ,
12951327 }
12961328
12971329 impl CongestedExec {
@@ -1346,7 +1378,7 @@ mod tests {
13461378 Ok ( Box :: pin ( CongestedStream {
13471379 schema : Arc :: new ( self . schema . clone ( ) ) ,
13481380 none_polled_once : false ,
1349- congestion_cleared : Arc :: clone ( & self . congestion_cleared ) ,
1381+ congestion : Arc :: clone ( & self . congestion ) ,
13501382 partition,
13511383 } ) )
13521384 }
@@ -1373,7 +1405,7 @@ mod tests {
13731405 pub struct CongestedStream {
13741406 schema : SchemaRef ,
13751407 none_polled_once : bool ,
1376- congestion_cleared : Arc < Mutex < bool > > ,
1408+ congestion : Arc < Congestion > ,
13771409 partition : usize ,
13781410 }
13791411
@@ -1393,16 +1425,11 @@ mod tests {
13931425 }
13941426 }
13951427 1 => {
1396- let cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1397- if * cleared {
1398- Poll :: Ready ( None )
1399- } else {
1400- Poll :: Pending
1401- }
1428+ ready ! ( self . congestion. check_congested( _cx) ) ;
1429+ Poll :: Ready ( None )
14021430 }
14031431 2 => {
1404- let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1405- * cleared = true ;
1432+ self . congestion . clear_congestion ( ) ;
14061433 Poll :: Ready ( None )
14071434 }
14081435 _ => unreachable ! ( ) ,
@@ -1423,7 +1450,7 @@ mod tests {
14231450 let source = CongestedExec {
14241451 schema : schema. clone ( ) ,
14251452 cache : CongestedExec :: compute_properties ( Arc :: new ( schema. clone ( ) ) ) ,
1426- congestion_cleared : Arc :: new ( Mutex :: new ( false ) ) ,
1453+ congestion : Arc :: new ( Congestion :: new ( ) ) ,
14271454 } ;
14281455 let spm = SortPreservingMergeExec :: new (
14291456 [ PhysicalSortExpr :: new_default ( Arc :: new ( Column :: new (
0 commit comments