29 changes: 22 additions & 7 deletions datafusion/core/src/physical_plan/sorts/builder.rs
@@ -19,6 +19,7 @@ use crate::common::Result;
use arrow::compute::interleave;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_execution::memory_pool::MemoryReservation;

#[derive(Debug, Copy, Clone, Default)]
struct BatchCursor {
@@ -37,6 +38,9 @@ pub struct BatchBuilder {
/// Maintain a list of [`RecordBatch`] and their corresponding stream
batches: Vec<(usize, RecordBatch)>,

/// Accounts for memory used by buffered batches
reservation: MemoryReservation,

/// The current [`BatchCursor`] for each stream
cursors: Vec<BatchCursor>,

@@ -47,23 +51,31 @@ pub struct BatchBuilder {

impl BatchBuilder {
/// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size`
pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self {
pub fn new(
schema: SchemaRef,
stream_count: usize,
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
Self {
schema,
batches: Vec::with_capacity(stream_count * 2),
cursors: vec![BatchCursor::default(); stream_count],
indices: Vec::with_capacity(batch_size),
reservation,
}
}

/// Append a new batch in `stream_idx`
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
self.reservation.try_grow(batch.get_array_memory_size())?;
let batch_idx = self.batches.len();
self.batches.push((stream_idx, batch));
self.cursors[stream_idx] = BatchCursor {
batch_idx,
row_idx: 0,
}
};
Ok(())
}

/// Append the next row from `stream_idx`
@@ -119,14 +131,17 @@ impl BatchBuilder {
// We can therefore drop all but the last batch for each stream
let mut batch_idx = 0;
let mut retained = 0;
self.batches.retain(|(stream_idx, _)| {
self.batches.retain(|(stream_idx, batch)| {
let stream_cursor = &mut self.cursors[*stream_idx];
let retain = stream_cursor.batch_idx == batch_idx;
batch_idx += 1;

if retain {
stream_cursor.batch_idx = retained;
retained += 1;
match retain {
true => {
stream_cursor.batch_idx = retained;
retained += 1;
}
false => self.reservation.shrink(batch.get_array_memory_size()),
}
retain
});
9 changes: 7 additions & 2 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -21,6 +21,7 @@ use arrow::datatypes::ArrowNativeTypeOp;
use arrow::row::{Row, Rows};
use arrow_array::types::ByteArrayType;
use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray};
use datafusion_execution::memory_pool::MemoryReservation;
use std::cmp::Ordering;

/// A [`Cursor`] for [`Rows`]
@@ -29,6 +30,9 @@ pub struct RowCursor {
num_rows: usize,

rows: Rows,

#[allow(dead_code)]
Review comment (Contributor): I think it would help to note here in comments why the code needs to keep around a field that is never read (dead_code). I think it is to keep the reservation around long enough?

reservation: MemoryReservation,
}
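In response to the review note above, here is a minimal sketch (not part of the PR) of how the field could be documented; the doc-comment wording is mine, and it only uses types the diff already imports:

use arrow::row::Rows;
use datafusion_execution::memory_pool::MemoryReservation;

pub struct RowCursor {
    cur_row: usize,
    num_rows: usize,

    rows: Rows,

    /// Accounts for the memory backing `rows` in the pool. The field is never
    /// read, but it must live as long as the cursor so the reserved memory is
    /// released only when the cursor (and its `rows`) is dropped.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}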

impl std::fmt::Debug for RowCursor {
Expand All @@ -41,12 +45,13 @@ impl std::fmt::Debug for RowCursor {
}

impl RowCursor {
/// Create a new SortKeyCursor
pub fn new(rows: Rows) -> Self {
/// Create a new SortKeyCursor from `rows` and the associated `reservation`
pub fn new(rows: Rows, reservation: MemoryReservation) -> Self {
Self {
cur_row: 0,
num_rows: rows.num_rows(),
rows,
reservation,
}
}

30 changes: 20 additions & 10 deletions datafusion/core/src/physical_plan/sorts/merge.rs
@@ -28,6 +28,7 @@ use crate::physical_plan::{
use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::*;
use datafusion_execution::memory_pool::MemoryReservation;
use futures::Stream;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
@@ -39,13 +40,14 @@ macro_rules! primitive_merge_helper {
}

macro_rules! merge_helper {
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $reservation:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
$tracking_metrics,
$batch_size,
$reservation,
)));
}};
}
@@ -57,27 +59,35 @@ pub(crate) fn streaming_merge(
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
// Special case single column comparisons with optimized cursor implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size)
data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, reservation),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, reservation)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, reservation)
_ => {}
}
}

let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
let streams = RowCursorStream::try_new(
schema.as_ref(),
expressions,
streams,
reservation.split_empty(),
)?;

Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
schema,
metrics,
batch_size,
reservation,
)))
}

@@ -148,11 +158,12 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
schema: SchemaRef,
metrics: BaselineMetrics,
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
let stream_count = streams.partitions();

Self {
in_progress: BatchBuilder::new(schema, stream_count, batch_size),
in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
streams,
metrics,
aborted: false,
@@ -181,8 +192,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
Some(Err(e)) => Poll::Ready(Err(e)),
Some(Ok((cursor, batch))) => {
self.cursors[idx] = Some(cursor);
self.in_progress.push_batch(idx, batch);
Poll::Ready(Ok(()))
Poll::Ready(self.in_progress.push_batch(idx, batch))
}
}
}
64 changes: 41 additions & 23 deletions datafusion/core/src/physical_plan/sorts/sort.rs
@@ -55,6 +55,9 @@ use tempfile::NamedTempFile;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::task;

/// How much memory to reserve for performing in-memory sorts
Review comment (Contributor), suggested change:
- /// How much memory to reserve for performing in-memory sorts
+ /// How much memory to reserve for performing in-memory sorts prior to spill

const EXTERNAL_SORTER_MERGE_RESERVATION: usize = 10 * 1024 * 1024;
Review comment (Contributor Author): I'm not a massive fan of this, but it somewhat patches around the issue that once we initiate a merge we can't then spill.

Review comment (Contributor): The problem with this approach is that even 10MB may not be enough to correctly merge the batches prior to spilling, so some queries that today would succeed (though exceed their memory limits) might fail.

It seems to me better approaches (as follow-on PRs) would be:

  1. Make this a config parameter so users can avoid the error by reserving more memory up front if needed (a hypothetical sketch follows this thread).
  2. Teach SortExec how to write more (smaller) spill files if it doesn't have enough memory to merge the in-memory batches.

However, given that the behavior on master today is to simply ignore the reservation and exceed the memory limit, this behavior seems better than before. I suggest we merge this PR as is and file a follow-on ticket for the improved behavior.
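A purely hypothetical sketch of follow-on idea 1 above: expose the merge headroom as a knob instead of hard-coding 10MB. The struct and field names here do not exist in DataFusion; they are invented only to illustrate the suggestion.

/// Hypothetical options for ExternalSorter (illustration only)
pub struct ExternalSorterOptions {
    /// Bytes reserved up front so that the merge performed while spilling
    /// still has memory available once the main reservation is exhausted
    pub sort_merge_reservation_bytes: usize,
}

impl Default for ExternalSorterOptions {
    fn default() -> Self {
        Self {
            // Matches EXTERNAL_SORTER_MERGE_RESERVATION above (10 MiB)
            sort_merge_reservation_bytes: 10 * 1024 * 1024,
        }
    }
}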


struct ExternalSorterMetrics {
/// metrics
baseline: BaselineMetrics,
@@ -94,8 +97,10 @@ struct ExternalSorter {
expr: Arc<[PhysicalSortExpr]>,
metrics: ExternalSorterMetrics,
fetch: Option<usize>,
/// Reservation for in_mem_batches
reservation: MemoryReservation,
partition_id: usize,
/// Reservation for in memory sorting of batches
Review comment (Contributor), suggested change:
- /// Reservation for in memory sorting of batches
+ /// Reservation for in memory sorting of batches, prior to spilling.
+ /// Without this reservation, when the memory budget is exhausted
+ /// it might not be possible to merge the in memory batches as part
+ /// of spilling.

merge_reservation: MemoryReservation,
runtime: Arc<RuntimeEnv>,
batch_size: usize,
}
@@ -115,6 +120,12 @@ impl ExternalSorter {
.with_can_spill(true)
.register(&runtime.memory_pool);

let mut merge_reservation =
MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]"))
.register(&runtime.memory_pool);

merge_reservation.resize(EXTERNAL_SORTER_MERGE_RESERVATION);
Review comment (Contributor Author): I take it as a positive sign that this was required to make the spill tests pass; without it the merge would exceed the memory limit and fail.


Self {
schema,
in_mem_batches: vec![],
Expand All @@ -124,7 +135,7 @@ impl ExternalSorter {
metrics,
fetch,
reservation,
partition_id,
merge_reservation,
runtime,
batch_size,
}
@@ -189,12 +200,10 @@ impl ExternalSorter {
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
self.reservation.split_empty(),
)
} else if !self.in_mem_batches.is_empty() {
let result = self.in_mem_sort_stream(self.metrics.baseline.clone());
// Report to the memory manager we are no longer using memory
self.reservation.free();
result
self.in_mem_sort_stream(self.metrics.baseline.clone())
} else {
Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
}
Expand Down Expand Up @@ -238,6 +247,9 @@ impl ExternalSorter {
return Ok(());
}

// Release the memory reserved for merge
self.merge_reservation.free();

self.in_mem_batches = self
.in_mem_sort_stream(self.metrics.baseline.intermediate())?
Review comment (Contributor): I double checked that in_mem_sort_stream correctly respects self.reservation 👍

.try_collect()
@@ -249,7 +261,11 @@
.map(|x| x.get_array_memory_size())
.sum();

self.reservation.resize(size);
// Reserve headroom for next sort
self.merge_reservation
.resize(EXTERNAL_SORTER_MERGE_RESERVATION);

self.reservation.try_resize(size)?;
self.in_mem_batches_sorted = true;
Ok(())
}
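To make the reservation handling in this hunk easier to follow, here is a minimal sketch (not part of the PR) of the sequence it implements: the merge headroom is freed before the in-memory sort runs so the sort can use that memory, then the headroom is re-reserved and the data reservation is resized to the bytes actually buffered. The helper name and parameters are invented for illustration; the MemoryReservation methods (resize, try_resize) are the ones used in the diff.

use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

const EXTERNAL_SORTER_MERGE_RESERVATION: usize = 10 * 1024 * 1024;

/// Hypothetical helper: called after the buffered batches have been sorted in
/// memory, with `sorted_bytes` the sum of `get_array_memory_size()` over the
/// sorted batches.
fn rebalance_reservations(
    data_reservation: &mut MemoryReservation,
    merge_reservation: &mut MemoryReservation,
    sorted_bytes: usize,
) -> Result<()> {
    // Reserve headroom again for the next merge (it was freed before the sort)
    merge_reservation.resize(EXTERNAL_SORTER_MERGE_RESERVATION);
    // Resize the data reservation to match what is now buffered, propagating
    // an error if the pool cannot satisfy the request
    data_reservation.try_resize(sorted_bytes)?;
    Ok(())
}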
@@ -262,9 +278,8 @@ impl ExternalSorter {
assert_ne!(self.in_mem_batches.len(), 0);
if self.in_mem_batches.len() == 1 {
let batch = self.in_mem_batches.remove(0);
let stream = self.sort_batch_stream(batch, metrics)?;
self.in_mem_batches.clear();
return Ok(stream);
let reservation = self.reservation.take();
return self.sort_batch_stream(batch, metrics, reservation);
}

// If less than 1MB of in-memory data, concatenate and sort in place
@@ -274,14 +289,19 @@
// Concatenate memory batches together and sort
let batch = concat_batches(&self.schema, &self.in_mem_batches)?;
self.in_mem_batches.clear();
return self.sort_batch_stream(batch, metrics);
self.reservation.try_resize(batch.get_array_memory_size())?;
let reservation = self.reservation.take();
return self.sort_batch_stream(batch, metrics, reservation);
}

let streams = std::mem::take(&mut self.in_mem_batches)
.into_iter()
.map(|batch| {
let metrics = self.metrics.baseline.intermediate();
Ok(spawn_buffered(self.sort_batch_stream(batch, metrics)?, 1))
let reservation =
self.reservation.split(batch.get_array_memory_size())?;
let input = self.sort_batch_stream(batch, metrics, reservation)?;
Ok(spawn_buffered(input, 1))
})
.collect::<Result<_>>()?;

@@ -293,30 +313,25 @@
&self.expr,
metrics,
self.batch_size,
self.merge_reservation.split_empty(),
)
}

fn sort_batch_stream(
&self,
batch: RecordBatch,
metrics: BaselineMetrics,
reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
let schema = batch.schema();

let mut reservation =
MemoryConsumer::new(format!("sort_batch_stream{}", self.partition_id))
.register(&self.runtime.memory_pool);

// TODO: This should probably be try_grow (#5885)
reservation.resize(batch.get_array_memory_size());

let fetch = self.fetch;
let expressions = self.expr.clone();
let stream = futures::stream::once(futures::future::lazy(move |_| {
let sorted = sort_batch(&batch, &expressions, fetch)?;
metrics.record_output(sorted.num_rows());
drop(batch);
reservation.free();
drop(reservation);
Ok(sorted)
}));
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
@@ -723,7 +738,8 @@ mod tests {
#[tokio::test]
async fn test_sort_spill() -> Result<()> {
// trigger spill there will be 4 batches with 5.5KB for each
let config = RuntimeConfig::new().with_memory_limit(12288, 1.0);
let config = RuntimeConfig::new()
.with_memory_limit(EXTERNAL_SORTER_MERGE_RESERVATION + 12288, 1.0);
let runtime = Arc::new(RuntimeEnv::new(config)?);
let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime);
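Rough arithmetic behind this limit, as I read the test: four buffered batches at roughly 5.5KB each is about 22KB, which exceeds the 12288-byte data budget and forces a spill; the limit is raised by EXTERNAL_SORTER_MERGE_RESERVATION to cover the 10MB headroom the sorter now reserves up front, leaving the same 12288 bytes for buffered data.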

@@ -805,8 +821,10 @@
];

for (fetch, expect_spillage) in test_options {
let config = RuntimeConfig::new()
.with_memory_limit(avg_batch_size * (partitions - 1), 1.0);
let config = RuntimeConfig::new().with_memory_limit(
EXTERNAL_SORTER_MERGE_RESERVATION + avg_batch_size * (partitions - 1),
1.0,
);
let runtime = Arc::new(RuntimeEnv::new(config)?);
let session_ctx =
SessionContext::with_config_rt(SessionConfig::new(), runtime);