@@ -53,6 +53,7 @@ pub struct TestOpener {
batch_size: Option<usize>,
schema: Option<SchemaRef>,
projection: Option<Vec<usize>>,
predicate: Option<Arc<dyn PhysicalExpr>>,
}

impl FileOpener for TestOpener {
@@ -61,6 +62,12 @@ impl FileOpener for TestOpener {
_file_meta: FileMeta,
_file: PartitionedFile,
) -> Result<FileOpenFuture> {
if let Some(predicate) = &self.predicate {

Contributor Author:
Will take this out, just leaving here for now for visualization.

I wonder what are good ways to regression test this 🤔

Contributor:
The best I can think of would be to make a custom ExecutionPlan that delays partition 0 for 1s or something and check that we still read the expected number of rows on the probe side

Contributor Author (@rkrishn7, Sep 8, 2025):
Hey @adriangb, I modified the existing partitioned test to inspect metrics from the probe side scan and verify rows are correctly being filtered out there. I think we can get away without the delay since I applied the patch to main and the test became flaky, as expected. Does that sound good to you or did you have something else in mind?

println!(
"Predicate when calling open: {}",
fmt_sql(predicate.as_ref())
);
}
let mut batches = self.batches.clone();
if let Some(batch_size) = self.batch_size {
let batch = concat_batches(&batches[0].schema(), &batches)?;
@@ -133,6 +140,7 @@ impl FileSource for TestSource {
batch_size: self.batch_size,
schema: self.schema.clone(),
projection: self.projection.clone(),
predicate: self.predicate.clone(),
})
}

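As an editor's aside on the regression-test discussion in the thread above: a rough sketch (not the test added in this PR) of the metrics-inspection approach is below. The table names, CSV paths, query, and join key are placeholders, and it assumes the session is already configured so that dynamic filter pushdown reaches the probe-side scan.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::*;

#[tokio::test]
async fn probe_side_scan_sees_dynamic_filter() -> Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical inputs: a small table (expected build side) and a larger one (probe side).
    ctx.register_csv("t_small", "small.csv", CsvReadOptions::new()).await?;
    ctx.register_csv("t_large", "large.csv", CsvReadOptions::new()).await?;

    let df = ctx
        .sql("SELECT * FROM t_large JOIN t_small ON t_large.k = t_small.k")
        .await?;
    let plan: Arc<dyn ExecutionPlan> = df.create_physical_plan().await?;
    let _batches = collect(Arc::clone(&plan), ctx.task_ctx()).await?;

    // Render the executed plan with its metrics. A real test would walk the plan
    // down to the probe-side scan node and assert that its `output_rows` metric is
    // smaller than the full table, i.e. rows were filtered out at the scan itself.
    let rendered = DisplayableExecutionPlan::with_metrics(plan.as_ref())
        .indent(true)
        .to_string();
    println!("{rendered}");
    Ok(())
}
```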
129 changes: 105 additions & 24 deletions datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
@@ -19,6 +19,8 @@
// TODO: include the link to the Dynamic Filter blog post.

use std::fmt;
use std::future::Future;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::joins::PartitionMode;
@@ -30,6 +32,7 @@ use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr};
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use futures::task::AtomicWaker;
use itertools::Itertools;
use parking_lot::Mutex;

Expand Down Expand Up @@ -119,7 +122,9 @@ struct SharedBoundsState {
/// Each element represents the column bounds computed by one partition.
bounds: Vec<PartitionBounds>,
/// Number of partitions that have reported completion.
completed_partitions: usize,
completed_partitions: Arc<AtomicUsize>,
/// Cached wakers to wake when all partitions are complete
wakers: Vec<Arc<AtomicWaker>>,
}

impl SharedBoundsAccumulator {
@@ -170,7 +175,8 @@ impl SharedBoundsAccumulator {
Self {
inner: Mutex::new(SharedBoundsState {
bounds: Vec::with_capacity(expected_calls),
completed_partitions: 0,
completed_partitions: Arc::new(AtomicUsize::new(0)),
wakers: Vec::new(),
}),
total_partitions: expected_calls,
dynamic_filter,
@@ -253,39 +259,69 @@ impl SharedBoundsAccumulator {
/// bounds from the current partition, increments the completion counter, and when all
/// partitions have reported, creates an OR'd filter from individual partition bounds.
///
/// It returns a [`BoundsWaiter`] future that can be awaited to ensure the filter has been
/// updated before proceeding. This is important to delay probe-side scans until the filter
/// is ready.
///
/// # Arguments
/// * `partition` - The partition identifier reporting its bounds
/// * `partition_bounds` - The bounds computed by this partition (if any)
///
/// # Returns
/// * `Result<()>` - Ok if successful, Err if filter update failed
/// * `Result<Option<BoundsWaiter>>` - Ok if successful, Err if filter update failed
pub(crate) fn report_partition_bounds(
&self,
partition: usize,
partition_bounds: Option<Vec<ColumnBounds>>,
) -> Result<()> {
let mut inner = self.inner.lock();
) -> Result<Option<BoundsWaiter>> {
// Scope for lock to avoid holding it across await points
let maybe_waiter = {
let mut inner = self.inner.lock();

// Store bounds in the accumulator - this runs once per partition
if let Some(bounds) = partition_bounds {
// Only push actual bounds if they exist
inner.bounds.push(PartitionBounds::new(partition, bounds));
}
// Store bounds in the accumulator - this runs once per partition
if let Some(bounds) = partition_bounds {
// Only push actual bounds if they exist
inner.bounds.push(PartitionBounds::new(partition, bounds));
}

// Increment the completion counter
// Even empty partitions must report to ensure proper termination
inner.completed_partitions += 1;
let completed = inner.completed_partitions;
let total_partitions = self.total_partitions;

// Critical synchronization point: Only update the filter when ALL partitions are complete
// Troubleshooting: If you see "completed > total_partitions", check partition
// count calculation in new_from_partition_mode() - it may not match actual execution calls
if completed == total_partitions && !inner.bounds.is_empty() {
let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?;
self.dynamic_filter.update(filter_expr)?;
}
// Increment the completion counter
// Even empty partitions must report to ensure proper termination
inner
.completed_partitions
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let completed = inner
.completed_partitions
.load(std::sync::atomic::Ordering::SeqCst);
let total_partitions = self.total_partitions;

// Critical synchronization point: Only update the filter when ALL partitions are complete
// Troubleshooting: If you see "completed > total_partitions", check partition
// count calculation in new_from_partition_mode() - it may not match actual execution calls
if completed == total_partitions {
if !inner.bounds.is_empty() {
let filter_expr =
self.create_filter_from_partition_bounds(&inner.bounds)?;
self.dynamic_filter.update(filter_expr)?;
}

// Notify any waiters that the filter is ready
for waker in inner.wakers.drain(..) {
waker.wake();
}

None
} else {
let waker = Arc::new(AtomicWaker::new());
inner.wakers.push(Arc::clone(&waker));
Some(BoundsWaiter::new(
total_partitions,
Arc::clone(&inner.completed_partitions),
waker,
))
}
};

Ok(())
Ok(maybe_waiter)
}
}
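For intuition about the filter this publishes (an illustration, not code from this diff): with two build partitions whose join-key bounds are [1, 10] and [50, 60], the expression is an OR of per-partition range conjuncts, roughly as below. The schema, the `range_filter` helper, and the literal bounds are made up for the example.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;

/// Builds `k >= min AND k <= max` for a single join key column `k`.
fn range_filter(schema: &Schema, min: i64, max: i64) -> Result<Arc<dyn PhysicalExpr>> {
    let k = col("k", schema)?;
    let ge = Arc::new(BinaryExpr::new(Arc::clone(&k), Operator::GtEq, lit(min)));
    let le = Arc::new(BinaryExpr::new(k, Operator::LtEq, lit(max)));
    Ok(Arc::new(BinaryExpr::new(ge, Operator::And, le)))
}

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("k", DataType::Int64, false)]);
    // One conjunct per partition, OR'd together:
    // (k >= 1 AND k <= 10) OR (k >= 50 AND k <= 60)
    let p0 = range_filter(&schema, 1, 10)?;
    let p1 = range_filter(&schema, 50, 60)?;
    let filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(p0, Operator::Or, p1));
    println!("{filter}");
    Ok(())
}
```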

@@ -294,3 +330,48 @@ impl fmt::Debug for SharedBoundsAccumulator {
write!(f, "SharedBoundsAccumulator")
}
}

/// Utility future to wait until all partitions have reported completion

Contributor:
This looks very similar to a Barrier -- https://docs.rs/tokio/latest/tokio/sync/struct.Barrier.html

Maybe as a follow-on PR we could simplify the code a bit by using that pre-defined structure

Contributor:
Agreed - @rkrishn7 could you refactor this to use Barrier?

Contributor Author:
Definitely, I was reaching for Barrier at first but I didn't see many tokio sync primitives being used in the crate so I decided against it. Agreed that using Barrier is preferable

Contributor:
Would you rather push to this PR or do it as a followup? Either is good to me.

Contributor Author:
I'll handle in this PR! Should get it done shortly

Contributor Author:
@alamb @adriangb Done ✅

Contributor:
I think it is looking pretty sweet now that we have the barrier in there :bowtie:

Contributor:
Yeah this is some very nice code @rkrishn7 !

/// and the dynamic filter has been updated.
#[derive(Clone)]
pub(crate) struct BoundsWaiter {
waker: Arc<AtomicWaker>,
total: usize,
completed: Arc<AtomicUsize>,
}

impl BoundsWaiter {
pub fn new(
total: usize,
completed: Arc<AtomicUsize>,
waker: Arc<AtomicWaker>,
) -> Self {
Self {
waker,
total,
completed,
}
}
}

impl Future for BoundsWaiter {
type Output = ();

fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
// Quick check to avoid registration if already complete
if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total {
return std::task::Poll::Ready(());
}

self.waker.register(cx.waker());

// Re-check after registering: the final partition may have completed (and
// called `wake`) between the fast-path check above and `register`, in which
// case that wakeup would otherwise be missed.
if self.completed.load(std::sync::atomic::Ordering::Relaxed) >= self.total {
std::task::Poll::Ready(())
} else {
std::task::Poll::Pending
}
}
}
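On the Barrier suggestion in the thread above, here is a minimal sketch of the shape that refactor could take, assuming `tokio::sync::Barrier`. This is not the code that was merged; `BoundsAccumulatorSketch`, the simplified bounds type, and the two-phase wait are illustrative stand-ins.

```rust
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::sync::Barrier;

/// Illustrative stand-in for SharedBoundsAccumulator built on tokio's Barrier.
struct BoundsAccumulatorSketch {
    barrier: Barrier,
    /// Simplified stand-in for the per-partition ColumnBounds.
    bounds: Mutex<Vec<Vec<(i64, i64)>>>,
}

impl BoundsAccumulatorSketch {
    fn new(total_partitions: usize) -> Arc<Self> {
        Arc::new(Self {
            barrier: Barrier::new(total_partitions),
            bounds: Mutex::new(Vec::new()),
        })
    }

    /// Each partition reports its bounds and then waits. `Barrier::wait` elects a
    /// single leader, which is a natural place to build and publish the OR'd filter.
    async fn report_and_wait(&self, partition_bounds: Option<Vec<(i64, i64)>>) {
        if let Some(b) = partition_bounds {
            self.bounds.lock().push(b);
        }
        // First rendezvous: every partition has reported its bounds.
        if self.barrier.wait().await.is_leader() {
            let all_bounds = self.bounds.lock();
            // Build the OR'd filter from `all_bounds` and call
            // dynamic_filter.update(...) here (omitted in this sketch).
            let _ = all_bounds.len();
        }
        // Second rendezvous: nobody starts probing until the leader has
        // published the updated filter.
        self.barrier.wait().await;
    }
}
```

The second `wait` preserves what `BoundsWaiter` guarantees today: partitions resume probing only after the filter has been updated, not merely after the last partition arrives.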
31 changes: 26 additions & 5 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
@@ -24,7 +24,7 @@ use std::sync::Arc;
use std::task::Poll;

use crate::joins::hash_join::exec::JoinLeftData;
use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator;
use crate::joins::hash_join::shared_bounds::{BoundsWaiter, SharedBoundsAccumulator};
use crate::joins::utils::{
equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut,
};
@@ -50,7 +50,7 @@ use datafusion_common::{
use datafusion_physical_expr::PhysicalExprRef;

use ahash::RandomState;
use futures::{ready, Stream, StreamExt};
use futures::{ready, FutureExt, Stream, StreamExt};

/// Represents build-side of hash join.
pub(super) enum BuildSide {
@@ -120,6 +120,8 @@ impl BuildSide {
pub(super) enum HashJoinStreamState {
/// Initial state for HashJoinStream indicating that build-side data not collected yet
WaitBuildSide,
/// Waiting for bounds to be reported by all partitions
WaitPartitionBoundsReport,
/// Indicates that build-side has been collected, and stream is ready for fetching probe-side
FetchProbeBatch,
/// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed
@@ -205,6 +207,9 @@ pub(super) struct HashJoinStream {
right_side_ordered: bool,
/// Shared bounds accumulator for coordinating dynamic filter updates (optional)
bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
/// Optional future to signal when bounds have been reported by all partitions
/// and the dynamic filter has been updated
bounds_waiter: Option<BoundsWaiter>,
}

impl RecordBatchStream for HashJoinStream {
@@ -325,6 +330,7 @@ impl HashJoinStream {
hashes_buffer,
right_side_ordered,
bounds_accumulator,
bounds_waiter: None,
}
}

@@ -339,6 +345,9 @@
HashJoinStreamState::WaitBuildSide => {
handle_state!(ready!(self.collect_build_side(cx)))
}
HashJoinStreamState::WaitPartitionBoundsReport => {
handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
}
HashJoinStreamState::FetchProbeBatch => {
handle_state!(ready!(self.fetch_probe_batch(cx)))
}
@@ -355,6 +364,17 @@ impl HashJoinStream {
}
}

fn wait_for_partition_bounds_report(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
if let Some(ref mut fut) = self.bounds_waiter {
ready!(fut.poll_unpin(cx));
}
self.state = HashJoinStreamState::FetchProbeBatch;
Poll::Ready(Ok(StatefulStreamResult::Continue))
}

/// Collects build-side data by polling `OnceFut` future from initialized build-side
///
/// Updates build-side to `Ready`, and state to `FetchProbeSide`
@@ -376,13 +396,14 @@ impl HashJoinStream {
// Dynamic filter coordination between partitions:
// Report bounds to the accumulator which will handle synchronization and filter updates
if let Some(ref bounds_accumulator) = self.bounds_accumulator {
bounds_accumulator
self.bounds_waiter = bounds_accumulator
.report_partition_bounds(self.partition, left_data.bounds.clone())?;
self.state = HashJoinStreamState::WaitPartitionBoundsReport;
} else {
self.state = HashJoinStreamState::FetchProbeBatch;
}

self.state = HashJoinStreamState::FetchProbeBatch;
self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });

Poll::Ready(Ok(StatefulStreamResult::Continue))
}

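Finally, a self-contained toy (not DataFusion code) showing the waiter pattern from `shared_bounds.rs` driven end to end: N "partitions" report completion and a single waiter future resolves only after the last report. All names and timings here are illustrative.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use futures::future::poll_fn;
use futures::task::AtomicWaker;

#[tokio::main]
async fn main() {
    let total: usize = 4;
    let completed = Arc::new(AtomicUsize::new(0));
    let waker = Arc::new(AtomicWaker::new());

    // Spawn "partitions" that report completion after doing some work.
    for _ in 0..total {
        let completed = Arc::clone(&completed);
        let waker = Arc::clone(&waker);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            if completed.fetch_add(1, Ordering::SeqCst) + 1 == total {
                waker.wake(); // last partition wakes the waiter
            }
        });
    }

    // The "probe side": wait until every partition has reported.
    poll_fn(|cx| {
        // Fast path: already complete, no registration needed.
        if completed.load(Ordering::SeqCst) >= total {
            return std::task::Poll::Ready(());
        }
        waker.register(cx.waker());
        // Re-check after registering so a wakeup that raced with registration is not lost.
        if completed.load(Ordering::SeqCst) >= total {
            std::task::Poll::Ready(())
        } else {
            std::task::Poll::Pending
        }
    })
    .await;

    println!("all {total} partitions reported; probe side may start");
}
```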