Skip to content

Commit 4a857f0

Browse files
authored
bug: remove busy-wait while sort is ongoing (#16322)
* bug: remove busy-wait while sort is ongoing (#16321) * bug: make CongestedStream a correct Stream implementation * Make test_spm_congestion independent of the exact poll order
1 parent e1716f9 commit 4a857f0

File tree

2 files changed

+88
-43
lines changed

2 files changed

+88
-43
lines changed

datafusion/physical-plan/src/sorts/merge.rs

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
//! Merge that deals with an arbitrary size of streaming inputs.
1919
//! This is an order-preserving merge.
2020
21-
use std::collections::VecDeque;
2221
use std::pin::Pin;
2322
use std::sync::Arc;
2423
use std::task::{ready, Context, Poll};
@@ -143,11 +142,8 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
143142
/// number of rows produced
144143
produced: usize,
145144

146-
/// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
147-
/// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
148-
/// vector to ensure the next iteration starts with a different partition, preventing the same partition
149-
/// from being continuously polled.
150-
uninitiated_partitions: VecDeque<usize>,
145+
/// This vector contains the indices of the partitions that have not started emitting yet.
146+
uninitiated_partitions: Vec<usize>,
151147
}
152148

153149
impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -216,36 +212,50 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
216212
// Once all partitions have set their corresponding cursors for the loser tree,
217213
// we skip the following block. Until then, this function may be called multiple
218214
// times and can return Poll::Pending if any partition returns Poll::Pending.
215+
219216
if self.loser_tree.is_empty() {
220-
while let Some(&partition_idx) = self.uninitiated_partitions.front() {
217+
// Manual indexing since we're iterating over the vector and shrinking it in the loop
218+
let mut idx = 0;
219+
while idx < self.uninitiated_partitions.len() {
220+
let partition_idx = self.uninitiated_partitions[idx];
221221
match self.maybe_poll_stream(cx, partition_idx) {
222222
Poll::Ready(Err(e)) => {
223223
self.aborted = true;
224224
return Poll::Ready(Some(Err(e)));
225225
}
226226
Poll::Pending => {
227-
// If a partition returns Poll::Pending, to avoid continuously polling it
228-
// and potentially increasing upstream buffer sizes, we move it to the
229-
// back of the polling queue.
230-
self.uninitiated_partitions.rotate_left(1);
231-
232-
// This function could remain in a pending state, so we manually wake it here.
233-
// However, this approach can be investigated further to find a more natural way
234-
// to avoid disrupting the runtime scheduler.
235-
cx.waker().wake_by_ref();
236-
return Poll::Pending;
227+
// The polled stream is pending which means we're already set up to
228+
// be woken when necessary
229+
// Try the next stream
230+
idx += 1;
237231
}
238232
_ => {
239-
// If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
240-
// we remove this partition from the queue so it is not polled again.
241-
self.uninitiated_partitions.pop_front();
233+
// The polled stream is ready
234+
// Remove it from uninitiated_partitions
235+
// Don't bump idx here, since a new element will have taken its
236+
// place which we'll try in the next loop iteration
237+
// swap_remove will change the partition poll order, but that shouldn't
238+
// make a difference since we're waiting for all streams to be ready.
239+
self.uninitiated_partitions.swap_remove(idx);
242240
}
243241
}
244242
}
245243

246-
// Claim the memory for the uninitiated partitions
247-
self.uninitiated_partitions.shrink_to_fit();
248-
self.init_loser_tree();
244+
if self.uninitiated_partitions.is_empty() {
245+
// If there are no more uninitiated partitions, set up the loser tree and continue
246+
// to the next phase.
247+
248+
// Claim the memory for the uninitiated partitions
249+
self.uninitiated_partitions.shrink_to_fit();
250+
self.init_loser_tree();
251+
} else {
252+
// There are still uninitiated partitions so return pending.
253+
// We only get here if we've polled all uninitiated streams and at least one of them
254+
// returned pending itself. That means we will be woken as soon as one of the
255+
// streams would like to be polled again.
256+
// There is no need to reschedule ourselves eagerly.
257+
return Poll::Pending;
258+
}
249259
}
250260

251261
// NB timer records time taken on drop, so there are no

datafusion/physical-plan/src/sorts/sort_preserving_merge.rs

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,11 @@ impl ExecutionPlan for SortPreservingMergeExec {
378378

379379
#[cfg(test)]
380380
mod tests {
381+
use std::collections::HashSet;
381382
use std::fmt::Formatter;
382383
use std::pin::Pin;
383384
use std::sync::Mutex;
384-
use std::task::{Context, Poll};
385+
use std::task::{ready, Context, Poll, Waker};
385386
use std::time::Duration;
386387

387388
use super::*;
@@ -1285,13 +1286,50 @@ mod tests {
12851286
"#);
12861287
}
12871288

1289+
#[derive(Debug)]
1290+
struct CongestionState {
1291+
wakers: Vec<Waker>,
1292+
unpolled_partitions: HashSet<usize>,
1293+
}
1294+
1295+
#[derive(Debug)]
1296+
struct Congestion {
1297+
congestion_state: Mutex<CongestionState>,
1298+
}
1299+
1300+
impl Congestion {
1301+
fn new(partition_count: usize) -> Self {
1302+
Congestion {
1303+
congestion_state: Mutex::new(CongestionState {
1304+
wakers: vec![],
1305+
unpolled_partitions: (0usize..partition_count).collect(),
1306+
}),
1307+
}
1308+
}
1309+
1310+
fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
1311+
let mut state = self.congestion_state.lock().unwrap();
1312+
1313+
state.unpolled_partitions.remove(&partition);
1314+
1315+
if state.unpolled_partitions.is_empty() {
1316+
state.wakers.iter().for_each(|w| w.wake_by_ref());
1317+
state.wakers.clear();
1318+
Poll::Ready(())
1319+
} else {
1320+
state.wakers.push(cx.waker().clone());
1321+
Poll::Pending
1322+
}
1323+
}
1324+
}
1325+
12881326
/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
12891327
/// partition is exhausted from the start, and if it is polled more than once, it panics.
12901328
#[derive(Debug, Clone)]
12911329
struct CongestedExec {
12921330
schema: Schema,
12931331
cache: PlanProperties,
1294-
congestion_cleared: Arc<Mutex<bool>>,
1332+
congestion: Arc<Congestion>,
12951333
}
12961334

12971335
impl CongestedExec {
@@ -1346,7 +1384,7 @@ mod tests {
13461384
Ok(Box::pin(CongestedStream {
13471385
schema: Arc::new(self.schema.clone()),
13481386
none_polled_once: false,
1349-
congestion_cleared: Arc::clone(&self.congestion_cleared),
1387+
congestion: Arc::clone(&self.congestion),
13501388
partition,
13511389
}))
13521390
}
@@ -1373,39 +1411,30 @@ mod tests {
13731411
pub struct CongestedStream {
13741412
schema: SchemaRef,
13751413
none_polled_once: bool,
1376-
congestion_cleared: Arc<Mutex<bool>>,
1414+
congestion: Arc<Congestion>,
13771415
partition: usize,
13781416
}
13791417

13801418
impl Stream for CongestedStream {
13811419
type Item = Result<RecordBatch>;
13821420
fn poll_next(
13831421
mut self: Pin<&mut Self>,
1384-
_cx: &mut Context<'_>,
1422+
cx: &mut Context<'_>,
13851423
) -> Poll<Option<Self::Item>> {
13861424
match self.partition {
13871425
0 => {
1426+
let _ = self.congestion.check_congested(self.partition, cx);
13881427
if self.none_polled_once {
1389-
panic!("Exhausted stream is polled more than one")
1428+
panic!("Exhausted stream is polled more than once")
13901429
} else {
13911430
self.none_polled_once = true;
13921431
Poll::Ready(None)
13931432
}
13941433
}
1395-
1 => {
1396-
let cleared = self.congestion_cleared.lock().unwrap();
1397-
if *cleared {
1398-
Poll::Ready(None)
1399-
} else {
1400-
Poll::Pending
1401-
}
1402-
}
1403-
2 => {
1404-
let mut cleared = self.congestion_cleared.lock().unwrap();
1405-
*cleared = true;
1434+
_ => {
1435+
ready!(self.congestion.check_congested(self.partition, cx));
14061436
Poll::Ready(None)
14071437
}
1408-
_ => unreachable!(),
14091438
}
14101439
}
14111440
}
@@ -1420,10 +1449,16 @@ mod tests {
14201449
async fn test_spm_congestion() -> Result<()> {
14211450
let task_ctx = Arc::new(TaskContext::default());
14221451
let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1452+
let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
1453+
let &partition_count = match properties.output_partitioning() {
1454+
Partitioning::RoundRobinBatch(partitions) => partitions,
1455+
Partitioning::Hash(_, partitions) => partitions,
1456+
Partitioning::UnknownPartitioning(partitions) => partitions,
1457+
};
14231458
let source = CongestedExec {
14241459
schema: schema.clone(),
1425-
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
1426-
congestion_cleared: Arc::new(Mutex::new(false)),
1460+
cache: properties,
1461+
congestion: Arc::new(Congestion::new(partition_count)),
14271462
};
14281463
let spm = SortPreservingMergeExec::new(
14291464
[PhysicalSortExpr::new_default(Arc::new(Column::new(

0 commit comments

Comments
 (0)