Skip to content

Commit f2a8b07

Browse files
Fix Possible Congestion Scenario in SortPreservingMergeExec (#12302)
* Ready Update merge_fuzz.rs Update merge.rs Update merge.rs Update merge.rs Update merge_fuzz.rs Update merge.rs Update merge_fuzz.rs Update merge.rs Counter is not enough Termination logic with counter * Add comments * Increase threshold * Debug * Remove merge.rs changes * Use waker * Update merge.rs * Simplify the test * Update merge_fuzz.rs * Update merge.rs * Addresses the latest review * fix clippy * Use VecDeque for rotation
1 parent aed84c2 commit f2a8b07

File tree

3 files changed

+213
-14
lines changed

3 files changed

+213
-14
lines changed

datafusion/core/tests/fuzz_cases/merge_fuzz.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
//! Fuzz Test for various corner cases merging streams of RecordBatches
19+
1920
use std::sync::Arc;
2021

2122
use arrow::{

datafusion/physical-plan/src/sorts/merge.rs

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,23 @@
1818
//! Merge that deals with an arbitrary size of streaming inputs.
1919
//! This is an order-preserving merge.
2020
21+
use std::collections::VecDeque;
22+
use std::pin::Pin;
23+
use std::sync::Arc;
24+
use std::task::{ready, Context, Poll};
25+
2126
use crate::metrics::BaselineMetrics;
2227
use crate::sorts::builder::BatchBuilder;
2328
use crate::sorts::cursor::{Cursor, CursorValues};
2429
use crate::sorts::stream::PartitionedStream;
2530
use crate::RecordBatchStream;
31+
2632
use arrow::datatypes::SchemaRef;
2733
use arrow::record_batch::RecordBatch;
2834
use datafusion_common::Result;
2935
use datafusion_execution::memory_pool::MemoryReservation;
36+
3037
use futures::Stream;
31-
use std::pin::Pin;
32-
use std::sync::Arc;
33-
use std::task::{ready, Context, Poll};
3438

3539
/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
3640
type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
@@ -86,7 +90,7 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
8690
/// been updated
8791
loser_tree_adjusted: bool,
8892

89-
/// target batch size
93+
/// Target batch size
9094
batch_size: usize,
9195

9296
/// Cursors for each input partition. `None` means the input is exhausted
@@ -97,6 +101,12 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
97101

98102
/// number of rows produced
99103
produced: usize,
104+
105+
/// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
106+
/// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
107+
/// vector to ensure the next iteration starts with a different partition, preventing the same partition
108+
/// from being continuously polled.
109+
uninitiated_partitions: VecDeque<usize>,
100110
}
101111

102112
impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -121,6 +131,7 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
121131
batch_size,
122132
fetch,
123133
produced: 0,
134+
uninitiated_partitions: (0..stream_count).collect(),
124135
}
125136
}
126137

@@ -154,14 +165,36 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
154165
if self.aborted {
155166
return Poll::Ready(None);
156167
}
157-
// try to initialize the loser tree
168+
// Once all partitions have set their corresponding cursors for the loser tree,
169+
// we skip the following block. Until then, this function may be called multiple
170+
// times and can return Poll::Pending if any partition returns Poll::Pending.
158171
if self.loser_tree.is_empty() {
159-
// Ensure all non-exhausted streams have a cursor from which
160-
// rows can be pulled
161-
for i in 0..self.streams.partitions() {
162-
if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
163-
self.aborted = true;
164-
return Poll::Ready(Some(Err(e)));
172+
let remaining_partitions = self.uninitiated_partitions.clone();
173+
for i in remaining_partitions {
174+
match self.maybe_poll_stream(cx, i) {
175+
Poll::Ready(Err(e)) => {
176+
self.aborted = true;
177+
return Poll::Ready(Some(Err(e)));
178+
}
179+
Poll::Pending => {
180+
// If a partition returns Poll::Pending, to avoid continuously polling it
181+
// and potentially increasing upstream buffer sizes, we move it to the
182+
// back of the polling queue.
183+
if let Some(front) = self.uninitiated_partitions.pop_front() {
184+
// This pop_front can never return `None`.
185+
self.uninitiated_partitions.push_back(front);
186+
}
187+
// This function could remain in a pending state, so we manually wake it here.
188+
// However, this approach can be investigated further to find a more natural way
189+
// to avoid disrupting the runtime scheduler.
190+
cx.waker().wake_by_ref();
191+
return Poll::Pending;
192+
}
193+
_ => {
194+
// If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
195+
// we remove this partition from the queue so it is not polled again.
196+
self.uninitiated_partitions.retain(|idx| *idx != i);
197+
}
165198
}
166199
}
167200
self.init_loser_tree();

datafusion/physical-plan/src/sorts/sort_preserving_merge.rs

Lines changed: 168 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ impl ExecutionPlan for SortPreservingMergeExec {
300300

301301
#[cfg(test)]
302302
mod tests {
303+
use std::fmt::Formatter;
304+
use std::pin::Pin;
305+
use std::sync::Mutex;
306+
use std::task::{Context, Poll};
307+
use std::time::Duration;
303308

304309
use super::*;
305310
use crate::coalesce_partitions::CoalescePartitionsExec;
@@ -310,16 +315,23 @@ mod tests {
310315
use crate::stream::RecordBatchReceiverStream;
311316
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
312317
use crate::test::{self, assert_is_pending, make_partition};
313-
use crate::{collect, common};
318+
use crate::{collect, common, ExecutionMode};
314319

315320
use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray};
316321
use arrow::compute::SortOptions;
317322
use arrow::datatypes::{DataType, Field, Schema};
318323
use arrow::record_batch::RecordBatch;
319-
use datafusion_common::{assert_batches_eq, assert_contains};
324+
use arrow_schema::SchemaRef;
325+
use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError};
326+
use datafusion_common_runtime::SpawnedTask;
320327
use datafusion_execution::config::SessionConfig;
328+
use datafusion_execution::RecordBatchStream;
329+
use datafusion_physical_expr::expressions::Column;
330+
use datafusion_physical_expr::EquivalenceProperties;
331+
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
321332

322-
use futures::{FutureExt, StreamExt};
333+
use futures::{FutureExt, Stream, StreamExt};
334+
use tokio::time::timeout;
323335

324336
#[tokio::test]
325337
async fn test_merge_interleave() {
@@ -1141,4 +1153,157 @@ mod tests {
11411153
collected.as_slice()
11421154
);
11431155
}
1156+
1157+
/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
1158+
/// partition is exhausted from the start, and if it is polled more than one, it panics.
1159+
#[derive(Debug, Clone)]
1160+
struct CongestedExec {
1161+
schema: Schema,
1162+
cache: PlanProperties,
1163+
congestion_cleared: Arc<Mutex<bool>>,
1164+
}
1165+
1166+
impl CongestedExec {
1167+
fn compute_properties(schema: SchemaRef) -> PlanProperties {
1168+
let columns = schema
1169+
.fields
1170+
.iter()
1171+
.enumerate()
1172+
.map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
1173+
.collect::<Vec<_>>();
1174+
let mut eq_properties = EquivalenceProperties::new(schema);
1175+
eq_properties.add_new_orderings(vec![columns
1176+
.iter()
1177+
.map(|expr| {
1178+
PhysicalSortExpr::new(Arc::clone(expr), SortOptions::default())
1179+
})
1180+
.collect::<Vec<_>>()]);
1181+
let mode = ExecutionMode::Unbounded;
1182+
PlanProperties::new(eq_properties, Partitioning::Hash(columns, 3), mode)
1183+
}
1184+
}
1185+
1186+
impl ExecutionPlan for CongestedExec {
1187+
fn name(&self) -> &'static str {
1188+
Self::static_name()
1189+
}
1190+
fn as_any(&self) -> &dyn Any {
1191+
self
1192+
}
1193+
fn properties(&self) -> &PlanProperties {
1194+
&self.cache
1195+
}
1196+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1197+
vec![]
1198+
}
1199+
fn with_new_children(
1200+
self: Arc<Self>,
1201+
_: Vec<Arc<dyn ExecutionPlan>>,
1202+
) -> Result<Arc<dyn ExecutionPlan>> {
1203+
Ok(self)
1204+
}
1205+
fn execute(
1206+
&self,
1207+
partition: usize,
1208+
_context: Arc<TaskContext>,
1209+
) -> Result<SendableRecordBatchStream> {
1210+
Ok(Box::pin(CongestedStream {
1211+
schema: Arc::new(self.schema.clone()),
1212+
none_polled_once: false,
1213+
congestion_cleared: Arc::clone(&self.congestion_cleared),
1214+
partition,
1215+
}))
1216+
}
1217+
}
1218+
1219+
impl DisplayAs for CongestedExec {
1220+
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1221+
match t {
1222+
DisplayFormatType::Default | DisplayFormatType::Verbose => {
1223+
write!(f, "CongestedExec",).unwrap()
1224+
}
1225+
}
1226+
Ok(())
1227+
}
1228+
}
1229+
1230+
/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
1231+
/// partition is exhausted from the start, and if it is polled more than once, it panics.
1232+
#[derive(Debug)]
1233+
pub struct CongestedStream {
1234+
schema: SchemaRef,
1235+
none_polled_once: bool,
1236+
congestion_cleared: Arc<Mutex<bool>>,
1237+
partition: usize,
1238+
}
1239+
1240+
impl Stream for CongestedStream {
1241+
type Item = Result<RecordBatch>;
1242+
fn poll_next(
1243+
mut self: Pin<&mut Self>,
1244+
_cx: &mut Context<'_>,
1245+
) -> Poll<Option<Self::Item>> {
1246+
match self.partition {
1247+
0 => {
1248+
if self.none_polled_once {
1249+
panic!("Exhausted stream is polled more than one")
1250+
} else {
1251+
self.none_polled_once = true;
1252+
Poll::Ready(None)
1253+
}
1254+
}
1255+
1 => {
1256+
let cleared = self.congestion_cleared.lock().unwrap();
1257+
if *cleared {
1258+
Poll::Ready(None)
1259+
} else {
1260+
Poll::Pending
1261+
}
1262+
}
1263+
2 => {
1264+
let mut cleared = self.congestion_cleared.lock().unwrap();
1265+
*cleared = true;
1266+
Poll::Ready(None)
1267+
}
1268+
_ => unreachable!(),
1269+
}
1270+
}
1271+
}
1272+
1273+
impl RecordBatchStream for CongestedStream {
1274+
fn schema(&self) -> SchemaRef {
1275+
Arc::clone(&self.schema)
1276+
}
1277+
}
1278+
1279+
#[tokio::test]
1280+
async fn test_spm_congestion() -> Result<()> {
1281+
let task_ctx = Arc::new(TaskContext::default());
1282+
let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1283+
let source = CongestedExec {
1284+
schema: schema.clone(),
1285+
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
1286+
congestion_cleared: Arc::new(Mutex::new(false)),
1287+
};
1288+
let spm = SortPreservingMergeExec::new(
1289+
vec![PhysicalSortExpr::new(
1290+
Arc::new(Column::new("c1", 0)),
1291+
SortOptions::default(),
1292+
)],
1293+
Arc::new(source),
1294+
);
1295+
let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
1296+
1297+
let result = timeout(Duration::from_secs(3), spm_task.join()).await;
1298+
match result {
1299+
Ok(Ok(Ok(_batches))) => Ok(()),
1300+
Ok(Ok(Err(e))) => Err(e),
1301+
Ok(Err(_)) => Err(DataFusionError::Execution(
1302+
"SortPreservingMerge task panicked or was cancelled".to_string(),
1303+
)),
1304+
Err(_) => Err(DataFusionError::Execution(
1305+
"SortPreservingMerge caused a deadlock".to_string(),
1306+
)),
1307+
}
1308+
}
11441309
}

0 commit comments

Comments
 (0)