Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ use crate::aggregates::{
PhysicalGroupBy,
};
use crate::common::IPCWriter;
use crate::metrics::{BaselineMetrics, RecordOutput};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
use crate::sorts::streaming_merge;
use crate::spill::read_spill_as_stream;
use crate::stream::RecordBatchStreamAdapter;
use crate::{aggregates, ExecutionPlan, PhysicalExpr};
use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
use crate::{RecordBatchStream, SendableRecordBatchStream};

use arrow::array::*;
Expand Down Expand Up @@ -117,17 +117,30 @@ struct SkipAggregationProbe {
/// Flag indicating that further updates of `SkipAggregationProbe`
/// state won't make any effect
is_locked: bool,

/// Number of rows where state was output without aggregation.
///
/// * If 0, all input rows were aggregated (should_skip was always false)
///
/// * If greater than zero, the number of rows that were output directly
/// without aggregation
skipped_aggregation_rows: metrics::Count,
}

impl SkipAggregationProbe {
fn new(probe_rows_threshold: usize, probe_ratio_threshold: f64) -> Self {
fn new(
probe_rows_threshold: usize,
probe_ratio_threshold: f64,
skipped_aggregation_rows: metrics::Count,
) -> Self {
Self {
input_rows: 0,
num_groups: 0,
probe_rows_threshold,
probe_ratio_threshold,
should_skip: false,
is_locked: false,
skipped_aggregation_rows,
}
}

Expand Down Expand Up @@ -160,6 +173,11 @@ impl SkipAggregationProbe {
self.should_skip = false;
self.is_locked = true;
}

/// Record the number of rows that were output directly without aggregation
fn record_skipped(&mut self, batch: &RecordBatch) {
self.skipped_aggregation_rows.add(batch.num_rows());
}
}

/// HashTable based Grouping Aggregator
Expand Down Expand Up @@ -473,17 +491,17 @@ impl GroupedHashAggregateStream {
.all(|acc| acc.supports_convert_to_state())
&& agg_group_by.is_single()
{
let options = &context.session_config().options().execution;
let probe_rows_threshold =
options.skip_partial_aggregation_probe_rows_threshold;
let probe_ratio_threshold =
options.skip_partial_aggregation_probe_ratio_threshold;
let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree that it is better not to add this counter to all group bys (as another baseline metric) by default 👍

.counter("skipped_aggregation_rows", partition);
Some(SkipAggregationProbe::new(
context
.session_config()
.options()
.execution
.skip_partial_aggregation_probe_rows_threshold,
context
.session_config()
.options()
.execution
.skip_partial_aggregation_probe_ratio_threshold,
probe_rows_threshold,
probe_ratio_threshold,
skipped_aggregation_rows,
))
} else {
None
Expand Down Expand Up @@ -611,6 +629,9 @@ impl Stream for GroupedHashAggregateStream {
match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
let _timer = elapsed_compute.timer();
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
probe.record_skipped(&batch);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the actual call that records the metric. The rest of the PR is comments and plumbing.

}
let states = self.transform_to_states(batch)?;
return Poll::Ready(Some(Ok(
states.record_output(&self.baseline_metrics)
Expand Down