Skip to content

Commit aeb6c84

Browse files
committed
rfc: optional skipping partial aggregation
1 parent a721be1 commit aeb6c84

File tree

9 files changed

+671
-4
lines changed

9 files changed

+671
-4
lines changed

datafusion/common/src/config.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,15 @@ config_namespace! {
324324

325325
/// Should DataFusion keep the columns used for partition_by in the output RecordBatches
326326
pub keep_partition_by_columns: bool, default = false
327+
328+
/// Aggregation ratio (number of distinct groups / number of input rows)
329+
/// threshold for skipping partial aggregation. If the value is greater
330+
/// then partial aggregation will skip aggregation for further input
331+
pub skip_partial_aggregation_probe_ratio_threshold: f64, default = 0.8
332+
333+
/// Number of input rows partial aggregation partition should process, before
334+
/// aggregation ratio check and trying to switch to skipping aggregation mode
335+
pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000
327336
}
328337
}
329338

datafusion/expr/src/groups_accumulator.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Vectorized [`GroupsAccumulator`]
1919
2020
use arrow_array::{ArrayRef, BooleanArray};
21-
use datafusion_common::Result;
21+
use datafusion_common::{not_impl_err, Result};
2222

2323
/// Describes how many rows should be emitted during grouping.
2424
#[derive(Debug, Clone, Copy)]
@@ -158,6 +158,24 @@ pub trait GroupsAccumulator: Send {
158158
total_num_groups: usize,
159159
) -> Result<()>;
160160

161+
/// Converts input batch to intermediate aggregate state,
162+
/// without grouping (each input row considered as a separate
163+
/// group).
164+
fn convert_to_state(
165+
&self,
166+
_values: &[ArrayRef],
167+
_opt_filter: Option<&BooleanArray>,
168+
) -> Result<Vec<ArrayRef>> {
169+
not_impl_err!("Input batch conversion to state not implemented")
170+
}
171+
172+
/// Returns `true` is groups accumulator supports input batch
173+
/// to intermediate aggregate state conversion (`convert_to_state`
174+
/// method is implemented).
175+
fn convert_to_state_supported(&self) -> bool {
176+
false
177+
}
178+
161179
/// Amount of memory used to store the state of this accumulator,
162180
/// in bytes. This function is called once per batch, so it should
163181
/// be `O(n)` to compute, not `O(num_groups)`

datafusion/functions-aggregate/src/count.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use arrow::{
3333
};
3434

3535
use arrow::{
36-
array::{Array, BooleanArray, Int64Array, PrimitiveArray},
36+
array::{Array, BooleanArray, Int64Array, Int64Builder, PrimitiveArray},
3737
buffer::BooleanBuffer,
3838
};
3939
use datafusion_common::{
@@ -433,6 +433,49 @@ impl GroupsAccumulator for CountGroupsAccumulator {
433433
Ok(vec![Arc::new(counts) as ArrayRef])
434434
}
435435

436+
fn convert_to_state(
437+
&self,
438+
values: &[ArrayRef],
439+
opt_filter: Option<&BooleanArray>,
440+
) -> Result<Vec<ArrayRef>> {
441+
let values = &values[0];
442+
443+
let state_array = match (values.logical_nulls(), opt_filter) {
444+
(Some(nulls), None) => {
445+
let mut builder = Int64Builder::with_capacity(values.len());
446+
nulls
447+
.into_iter()
448+
.for_each(|is_valid| builder.append_value(is_valid as i64));
449+
builder.finish()
450+
}
451+
(Some(nulls), Some(filter)) => {
452+
let mut builder = Int64Builder::with_capacity(values.len());
453+
nulls.into_iter().zip(filter.iter()).for_each(
454+
|(is_valid, filter_value)| {
455+
builder.append_value(
456+
(is_valid && filter_value.is_some_and(|val| val)) as i64,
457+
)
458+
},
459+
);
460+
builder.finish()
461+
}
462+
(None, Some(filter)) => {
463+
let mut builder = Int64Builder::with_capacity(values.len());
464+
filter.into_iter().for_each(|filter_value| {
465+
builder.append_value(filter_value.is_some_and(|val| val) as i64)
466+
});
467+
builder.finish()
468+
}
469+
(None, None) => Int64Array::from_value(1, values.len()),
470+
};
471+
472+
Ok(vec![Arc::new(state_array)])
473+
}
474+
475+
fn convert_to_state_supported(&self) -> bool {
476+
true
477+
}
478+
436479
fn size(&self) -> usize {
437480
self.counts.capacity() * std::mem::size_of::<usize>()
438481
}

datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
20+
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder};
2121
use arrow::datatypes::ArrowPrimitiveType;
2222
use arrow::datatypes::DataType;
2323
use datafusion_common::Result;
@@ -134,6 +134,46 @@ where
134134
self.update_batch(values, group_indices, opt_filter, total_num_groups)
135135
}
136136

137+
fn convert_to_state(
138+
&self,
139+
values: &[ArrayRef],
140+
opt_filter: Option<&BooleanArray>,
141+
) -> Result<Vec<ArrayRef>> {
142+
let values = values[0].as_primitive::<T>();
143+
let mut state = PrimitiveBuilder::<T>::with_capacity(values.len())
144+
.with_data_type(self.data_type.clone());
145+
146+
match opt_filter {
147+
Some(filter) => {
148+
values
149+
.iter()
150+
.zip(filter.iter())
151+
.for_each(|(val, filter_val)| match (val, filter_val) {
152+
(Some(val), Some(true)) => {
153+
let mut state_val = self.starting_value;
154+
(self.prim_fn)(&mut state_val, val);
155+
state.append_value(state_val);
156+
}
157+
(_, _) => state.append_null(),
158+
})
159+
}
160+
None => values.iter().for_each(|val| match val {
161+
Some(val) => {
162+
let mut state_val = self.starting_value;
163+
(self.prim_fn)(&mut state_val, val);
164+
state.append_value(state_val);
165+
}
166+
None => state.append_null(),
167+
}),
168+
};
169+
170+
Ok(vec![Arc::new(state.finish())])
171+
}
172+
173+
fn convert_to_state_supported(&self) -> bool {
174+
true
175+
}
176+
137177
fn size(&self) -> usize {
138178
self.values.capacity() * std::mem::size_of::<T::Native>() + self.null_state.size()
139179
}

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2398,4 +2398,189 @@ mod tests {
23982398

23992399
Ok(())
24002400
}
2401+
2402+
#[tokio::test]
2403+
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
2404+
let schema = Arc::new(Schema::new(vec![
2405+
Field::new("key", DataType::Int32, true),
2406+
Field::new("val", DataType::Int32, true),
2407+
]));
2408+
let df_schema = DFSchema::try_from(schema.clone())?;
2409+
2410+
let group_by =
2411+
PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2412+
2413+
let aggr_expr: Vec<Arc<dyn AggregateExpr>> =
2414+
vec![create_aggregate_expr_with_dfschema(
2415+
&count_udaf(),
2416+
&[col("val", &schema)?],
2417+
&[datafusion_expr::col("val")],
2418+
&[],
2419+
&[],
2420+
&df_schema,
2421+
"COUNT(val)",
2422+
false,
2423+
false,
2424+
false,
2425+
)?];
2426+
2427+
let input_data = vec![
2428+
RecordBatch::try_new(
2429+
Arc::clone(&schema),
2430+
vec![
2431+
Arc::new(Int32Array::from(vec![1, 2, 3])),
2432+
Arc::new(Int32Array::from(vec![0, 0, 0])),
2433+
],
2434+
)
2435+
.unwrap(),
2436+
RecordBatch::try_new(
2437+
Arc::clone(&schema),
2438+
vec![
2439+
Arc::new(Int32Array::from(vec![2, 3, 4])),
2440+
Arc::new(Int32Array::from(vec![0, 0, 0])),
2441+
],
2442+
)
2443+
.unwrap(),
2444+
];
2445+
2446+
let input = Arc::new(MemoryExec::try_new(
2447+
&[input_data],
2448+
Arc::clone(&schema),
2449+
None,
2450+
)?);
2451+
let aggregate_exec = Arc::new(AggregateExec::try_new(
2452+
AggregateMode::Partial,
2453+
group_by,
2454+
aggr_expr,
2455+
vec![None],
2456+
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2457+
schema,
2458+
)?);
2459+
2460+
let mut session_config = SessionConfig::default();
2461+
session_config = session_config.set(
2462+
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2463+
ScalarValue::Int64(Some(2)),
2464+
);
2465+
session_config = session_config.set(
2466+
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2467+
ScalarValue::Float64(Some(0.1)),
2468+
);
2469+
2470+
let ctx = TaskContext::default().with_session_config(session_config);
2471+
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2472+
2473+
let expected = [
2474+
"+-----+-------------------+",
2475+
"| key | COUNT(val)[count] |",
2476+
"+-----+-------------------+",
2477+
"| 1 | 1 |",
2478+
"| 2 | 1 |",
2479+
"| 3 | 1 |",
2480+
"| 2 | 1 |",
2481+
"| 3 | 1 |",
2482+
"| 4 | 1 |",
2483+
"+-----+-------------------+",
2484+
];
2485+
assert_batches_eq!(expected, &output);
2486+
2487+
Ok(())
2488+
}
2489+
2490+
#[tokio::test]
2491+
async fn test_skip_aggregation_after_threshold() -> Result<()> {
2492+
let schema = Arc::new(Schema::new(vec![
2493+
Field::new("key", DataType::Int32, true),
2494+
Field::new("val", DataType::Int32, true),
2495+
]));
2496+
let df_schema = DFSchema::try_from(schema.clone())?;
2497+
2498+
let group_by =
2499+
PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2500+
2501+
let aggr_expr: Vec<Arc<dyn AggregateExpr>> =
2502+
vec![create_aggregate_expr_with_dfschema(
2503+
&count_udaf(),
2504+
&[col("val", &schema)?],
2505+
&[datafusion_expr::col("val")],
2506+
&[],
2507+
&[],
2508+
&df_schema,
2509+
"COUNT(val)",
2510+
false,
2511+
false,
2512+
false,
2513+
)?];
2514+
2515+
let input_data = vec![
2516+
RecordBatch::try_new(
2517+
Arc::clone(&schema),
2518+
vec![
2519+
Arc::new(Int32Array::from(vec![1, 2, 3])),
2520+
Arc::new(Int32Array::from(vec![0, 0, 0])),
2521+
],
2522+
)
2523+
.unwrap(),
2524+
RecordBatch::try_new(
2525+
Arc::clone(&schema),
2526+
vec![
2527+
Arc::new(Int32Array::from(vec![2, 3, 4])),
2528+
Arc::new(Int32Array::from(vec![0, 0, 0])),
2529+
],
2530+
)
2531+
.unwrap(),
2532+
RecordBatch::try_new(
2533+
Arc::clone(&schema),
2534+
vec![
2535+
Arc::new(Int32Array::from(vec![2, 3, 4])),
2536+
Arc::new(Int32Array::from(vec![0, 0, 0])),
2537+
],
2538+
)
2539+
.unwrap(),
2540+
];
2541+
2542+
let input = Arc::new(MemoryExec::try_new(
2543+
&[input_data],
2544+
Arc::clone(&schema),
2545+
None,
2546+
)?);
2547+
let aggregate_exec = Arc::new(AggregateExec::try_new(
2548+
AggregateMode::Partial,
2549+
group_by,
2550+
aggr_expr,
2551+
vec![None],
2552+
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2553+
schema,
2554+
)?);
2555+
2556+
let mut session_config = SessionConfig::default();
2557+
session_config = session_config.set(
2558+
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2559+
ScalarValue::Int64(Some(5)),
2560+
);
2561+
session_config = session_config.set(
2562+
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2563+
ScalarValue::Float64(Some(0.1)),
2564+
);
2565+
2566+
let ctx = TaskContext::default().with_session_config(session_config);
2567+
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2568+
2569+
let expected = [
2570+
"+-----+-------------------+",
2571+
"| key | COUNT(val)[count] |",
2572+
"+-----+-------------------+",
2573+
"| 1 | 1 |",
2574+
"| 2 | 2 |",
2575+
"| 3 | 2 |",
2576+
"| 4 | 1 |",
2577+
"| 2 | 1 |",
2578+
"| 3 | 1 |",
2579+
"| 4 | 1 |",
2580+
"+-----+-------------------+",
2581+
];
2582+
assert_batches_eq!(expected, &output);
2583+
2584+
Ok(())
2585+
}
24012586
}

0 commit comments

Comments
 (0)