@@ -300,6 +300,11 @@ impl ExecutionPlan for SortPreservingMergeExec {
300300
301301#[ cfg( test) ]
302302mod tests {
303+ use std:: fmt:: Formatter ;
304+ use std:: pin:: Pin ;
305+ use std:: sync:: Mutex ;
306+ use std:: task:: { Context , Poll } ;
307+ use std:: time:: Duration ;
303308
304309 use super :: * ;
305310 use crate :: coalesce_partitions:: CoalescePartitionsExec ;
@@ -310,16 +315,23 @@ mod tests {
310315 use crate :: stream:: RecordBatchReceiverStream ;
311316 use crate :: test:: exec:: { assert_strong_count_converges_to_zero, BlockingExec } ;
312317 use crate :: test:: { self , assert_is_pending, make_partition} ;
313- use crate :: { collect, common} ;
318+ use crate :: { collect, common, ExecutionMode } ;
314319
315320 use arrow:: array:: { ArrayRef , Int32Array , StringArray , TimestampNanosecondArray } ;
316321 use arrow:: compute:: SortOptions ;
317322 use arrow:: datatypes:: { DataType , Field , Schema } ;
318323 use arrow:: record_batch:: RecordBatch ;
319- use datafusion_common:: { assert_batches_eq, assert_contains} ;
324+ use arrow_schema:: SchemaRef ;
325+ use datafusion_common:: { assert_batches_eq, assert_contains, DataFusionError } ;
326+ use datafusion_common_runtime:: SpawnedTask ;
320327 use datafusion_execution:: config:: SessionConfig ;
328+ use datafusion_execution:: RecordBatchStream ;
329+ use datafusion_physical_expr:: expressions:: Column ;
330+ use datafusion_physical_expr:: EquivalenceProperties ;
331+ use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
321332
322- use futures:: { FutureExt , StreamExt } ;
333+ use futures:: { FutureExt , Stream , StreamExt } ;
334+ use tokio:: time:: timeout;
323335
324336 #[ tokio:: test]
325337 async fn test_merge_interleave ( ) {
@@ -1141,4 +1153,157 @@ mod tests {
11411153 collected. as_slice( )
11421154 ) ;
11431155 }
1156+
1157+ /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
1158+ /// partition is exhausted from the start, and if it is polled more than one, it panics.
1159+ #[ derive( Debug , Clone ) ]
1160+ struct CongestedExec {
1161+ schema : Schema ,
1162+ cache : PlanProperties ,
1163+ congestion_cleared : Arc < Mutex < bool > > ,
1164+ }
1165+
1166+ impl CongestedExec {
1167+ fn compute_properties ( schema : SchemaRef ) -> PlanProperties {
1168+ let columns = schema
1169+ . fields
1170+ . iter ( )
1171+ . enumerate ( )
1172+ . map ( |( i, f) | Arc :: new ( Column :: new ( f. name ( ) , i) ) as Arc < dyn PhysicalExpr > )
1173+ . collect :: < Vec < _ > > ( ) ;
1174+ let mut eq_properties = EquivalenceProperties :: new ( schema) ;
1175+ eq_properties. add_new_orderings ( vec ! [ columns
1176+ . iter( )
1177+ . map( |expr| {
1178+ PhysicalSortExpr :: new( Arc :: clone( expr) , SortOptions :: default ( ) )
1179+ } )
1180+ . collect:: <Vec <_>>( ) ] ) ;
1181+ let mode = ExecutionMode :: Unbounded ;
1182+ PlanProperties :: new ( eq_properties, Partitioning :: Hash ( columns, 3 ) , mode)
1183+ }
1184+ }
1185+
1186+ impl ExecutionPlan for CongestedExec {
1187+ fn name ( & self ) -> & ' static str {
1188+ Self :: static_name ( )
1189+ }
1190+ fn as_any ( & self ) -> & dyn Any {
1191+ self
1192+ }
1193+ fn properties ( & self ) -> & PlanProperties {
1194+ & self . cache
1195+ }
1196+ fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
1197+ vec ! [ ]
1198+ }
1199+ fn with_new_children (
1200+ self : Arc < Self > ,
1201+ _: Vec < Arc < dyn ExecutionPlan > > ,
1202+ ) -> Result < Arc < dyn ExecutionPlan > > {
1203+ Ok ( self )
1204+ }
1205+ fn execute (
1206+ & self ,
1207+ partition : usize ,
1208+ _context : Arc < TaskContext > ,
1209+ ) -> Result < SendableRecordBatchStream > {
1210+ Ok ( Box :: pin ( CongestedStream {
1211+ schema : Arc :: new ( self . schema . clone ( ) ) ,
1212+ none_polled_once : false ,
1213+ congestion_cleared : Arc :: clone ( & self . congestion_cleared ) ,
1214+ partition,
1215+ } ) )
1216+ }
1217+ }
1218+
1219+ impl DisplayAs for CongestedExec {
1220+ fn fmt_as ( & self , t : DisplayFormatType , f : & mut Formatter ) -> std:: fmt:: Result {
1221+ match t {
1222+ DisplayFormatType :: Default | DisplayFormatType :: Verbose => {
1223+ write ! ( f, "CongestedExec" , ) . unwrap ( )
1224+ }
1225+ }
1226+ Ok ( ( ) )
1227+ }
1228+ }
1229+
1230+ /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
1231+ /// partition is exhausted from the start, and if it is polled more than once, it panics.
1232+ #[ derive( Debug ) ]
1233+ pub struct CongestedStream {
1234+ schema : SchemaRef ,
1235+ none_polled_once : bool ,
1236+ congestion_cleared : Arc < Mutex < bool > > ,
1237+ partition : usize ,
1238+ }
1239+
1240+ impl Stream for CongestedStream {
1241+ type Item = Result < RecordBatch > ;
1242+ fn poll_next (
1243+ mut self : Pin < & mut Self > ,
1244+ _cx : & mut Context < ' _ > ,
1245+ ) -> Poll < Option < Self :: Item > > {
1246+ match self . partition {
1247+ 0 => {
1248+ if self . none_polled_once {
1249+ panic ! ( "Exhausted stream is polled more than one" )
1250+ } else {
1251+ self . none_polled_once = true ;
1252+ Poll :: Ready ( None )
1253+ }
1254+ }
1255+ 1 => {
1256+ let cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1257+ if * cleared {
1258+ Poll :: Ready ( None )
1259+ } else {
1260+ Poll :: Pending
1261+ }
1262+ }
1263+ 2 => {
1264+ let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1265+ * cleared = true ;
1266+ Poll :: Ready ( None )
1267+ }
1268+ _ => unreachable ! ( ) ,
1269+ }
1270+ }
1271+ }
1272+
1273+ impl RecordBatchStream for CongestedStream {
1274+ fn schema ( & self ) -> SchemaRef {
1275+ Arc :: clone ( & self . schema )
1276+ }
1277+ }
1278+
1279+ #[ tokio:: test]
1280+ async fn test_spm_congestion ( ) -> Result < ( ) > {
1281+ let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
1282+ let schema = Schema :: new ( vec ! [ Field :: new( "c1" , DataType :: UInt64 , false ) ] ) ;
1283+ let source = CongestedExec {
1284+ schema : schema. clone ( ) ,
1285+ cache : CongestedExec :: compute_properties ( Arc :: new ( schema. clone ( ) ) ) ,
1286+ congestion_cleared : Arc :: new ( Mutex :: new ( false ) ) ,
1287+ } ;
1288+ let spm = SortPreservingMergeExec :: new (
1289+ vec ! [ PhysicalSortExpr :: new(
1290+ Arc :: new( Column :: new( "c1" , 0 ) ) ,
1291+ SortOptions :: default ( ) ,
1292+ ) ] ,
1293+ Arc :: new ( source) ,
1294+ ) ;
1295+ let spm_task = SpawnedTask :: spawn ( collect ( Arc :: new ( spm) , task_ctx) ) ;
1296+
1297+ let result = timeout ( Duration :: from_secs ( 3 ) , spm_task. join ( ) ) . await ;
1298+ match result {
1299+ Ok ( Ok ( Ok ( _batches) ) ) => Ok ( ( ) ) ,
1300+ Ok ( Ok ( Err ( e) ) ) => Err ( e) ,
1301+ Ok ( Err ( _) ) => Err ( DataFusionError :: Execution (
1302+ "SortPreservingMerge task panicked or was cancelled" . to_string ( ) ,
1303+ ) ) ,
1304+ Err ( _) => Err ( DataFusionError :: Execution (
1305+ "SortPreservingMerge caused a deadlock" . to_string ( ) ,
1306+ ) ) ,
1307+ }
1308+ }
11441309}
0 commit comments