diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs
index e98fd14cfbb0..b540b411a852 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -68,7 +68,7 @@ pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
 /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
 /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
 /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
-pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
+pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
     /// Returns the physical expression as [`Any`] so that it can be
     /// downcast to a specific implementation.
     fn as_any(&self) -> &dyn Any;
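Note: the `Any` supertrait added above is what enables the `Arc::downcast` call that appears later in this diff in `handle_child_pushdown_result`. With `Any` as a supertrait, an `Arc<dyn PhysicalExpr>` coerces to `Arc<dyn Any + Send + Sync>` through trait-object upcasting (stable since Rust 1.86), so no `as_any()` round trip is needed. A minimal self-contained sketch of the pattern, using illustrative names rather than DataFusion's actual types:

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

// Stand-in for `PhysicalExpr`: `Any` as a supertrait allows upcasting
// `Arc<dyn Expr>` to `Arc<dyn Any + Send + Sync>` at the call site.
trait Expr: Any + Send + Sync + Debug {}

#[derive(Debug)]
struct Literal(i64);
impl Expr for Literal {}

fn main() {
    let e: Arc<dyn Expr> = Arc::new(Literal(42));
    // The upcast happens implicitly in the argument position.
    let lit = Arc::downcast::<Literal>(e).expect("expected a Literal");
    assert_eq!(lit.0, 42);
}
```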
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 383769173d7c..65b9a54f9ae6 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -329,7 +329,6 @@ impl JoinLeftData {
 /// Note this structure includes a [`OnceAsync`] that is used to coordinate the
 /// loading of the left side with the processing in each output stream.
 /// Therefore it can not be [`Clone`]
-#[derive(Debug)]
 pub struct HashJoinExec {
     /// left (build) side which gets hashed
     pub left: Arc<dyn ExecutionPlan>,
@@ -350,7 +349,7 @@ pub struct HashJoinExec {
     ///
     /// Each output stream waits on the `OnceAsync` to signal the completion of
     /// the hash table creation.
-    left_fut: OnceAsync<JoinLeftData>,
+    left_fut: Arc<OnceAsync<JoinLeftData>>,
     /// Shared the `RandomState` for the hashing algorithm
     random_state: RandomState,
     /// Partitioning mode to use
@@ -366,7 +365,29 @@ pub struct HashJoinExec {
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
     /// Dynamic filter for pushing down to the probe side
-    dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
+    dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
+}
+
+impl fmt::Debug for HashJoinExec {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("HashJoinExec")
+            .field("left", &self.left)
+            .field("right", &self.right)
+            .field("on", &self.on)
+            .field("filter", &self.filter)
+            .field("join_type", &self.join_type)
+            .field("join_schema", &self.join_schema)
+            .field("left_fut", &self.left_fut)
+            .field("random_state", &self.random_state)
+            .field("mode", &self.mode)
+            .field("metrics", &self.metrics)
+            .field("projection", &self.projection)
+            .field("column_indices", &self.column_indices)
+            .field("null_equality", &self.null_equality)
+            .field("cache", &self.cache)
+            // Explicitly exclude dynamic_filter to avoid runtime state differences in tests
+            .finish()
+    }
 }
 
 impl HashJoinExec {
@@ -413,8 +434,6 @@ impl HashJoinExec {
             projection.as_ref(),
         )?;
 
-        let dynamic_filter = Self::create_dynamic_filter(&on);
-
         Ok(HashJoinExec {
             left,
             right,
@@ -430,12 +449,13 @@ impl HashJoinExec {
             column_indices,
             null_equality,
             cache,
-            dynamic_filter,
+            dynamic_filter: None,
         })
     }
 
     fn create_dynamic_filter(on: &JoinOn) -> Arc<DynamicFilterPhysicalExpr> {
-        // Extract the right-side keys from the `on` clauses
+        // Extract the right-side keys (probe side keys) from the `on` clauses
+        // The dynamic filter is built from build-side values (left side) and applied to the probe side (right side)
        let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect();
         // Initialize with a placeholder expression (true) that will be updated when the hash table is built
         Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true)))
@@ -686,11 +706,14 @@ impl DisplayAs for HashJoinExec {
                     .map(|(c1, c2)| format!("({c1}, {c2})"))
                     .collect::<Vec<String>>()
                     .join(", ");
-                let dynamic_filter_display = match self.dynamic_filter.current() {
-                    Ok(current) if current != lit(true) => {
-                        format!(", filter=[{current}]")
-                    }
-                    _ => "".to_string(),
+                let dynamic_filter_display = match self.dynamic_filter.as_ref() {
+                    Some(dynamic_filter) => match dynamic_filter.current() {
+                        Ok(current) if current != lit(true) => {
+                            format!(", filter=[{current}]")
+                        }
+                        _ => "".to_string(),
+                    },
+                    None => "".to_string(),
                 };
                 write!(
                     f,
@@ -794,7 +817,7 @@ impl ExecutionPlan for HashJoinExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut new_join = HashJoinExec::try_new(
+        let new_join = HashJoinExec::try_new(
             Arc::clone(&children[0]),
             Arc::clone(&children[1]),
             self.on.clone(),
@@ -804,8 +827,6 @@ impl ExecutionPlan for HashJoinExec {
             self.mode,
             self.null_equality,
         )?;
-        // Preserve the dynamic filter if it exists
-        new_join.dynamic_filter = Arc::clone(&self.dynamic_filter);
         Ok(Arc::new(new_join))
     }
 
@@ -818,7 +839,7 @@ impl ExecutionPlan for HashJoinExec {
             filter: self.filter.clone(),
             join_type: self.join_type,
             join_schema: Arc::clone(&self.join_schema),
-            left_fut: OnceAsync::default(),
+            left_fut: Arc::new(OnceAsync::default()),
             random_state: self.random_state.clone(),
             mode: self.mode,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -826,7 +847,7 @@ impl ExecutionPlan for HashJoinExec {
             column_indices: self.column_indices.clone(),
             null_equality: self.null_equality,
             cache: self.cache.clone(),
-            dynamic_filter: Self::create_dynamic_filter(&self.on),
+            dynamic_filter: None,
         }))
     }
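Note: the hand-written `Debug` impl above replaces `#[derive(Debug)]` only so that `dynamic_filter` stays out of the output; its contents change while the join executes, which would make plan snapshots in tests nondeterministic. A self-contained sketch of the same skip-a-field pattern (all names here are illustrative, not DataFusion's):

```rust
use std::fmt;
use std::sync::Mutex;

struct Exec {
    name: String,
    // Stand-in for state that mutates at execution time; printing it
    // would make the debug output unstable across runs.
    runtime_state: Mutex<Vec<u8>>,
}

impl fmt::Debug for Exec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Exec")
            .field("name", &self.name)
            // `runtime_state` is deliberately omitted.
            .finish()
    }
}

fn main() {
    let exec = Exec { name: "join".into(), runtime_state: Mutex::default() };
    assert_eq!(format!("{exec:?}"), r#"Exec { name: "join" }"#);
}
```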
@@ -886,7 +907,8 @@ impl ExecutionPlan for HashJoinExec {
                     need_produce_result_in_final(self.join_type),
                     self.right().output_partitioning().partition_count(),
                     enable_dynamic_filter_pushdown
-                        .then_some(Arc::clone(&self.dynamic_filter)),
+                        .then_some(self.dynamic_filter.clone())
+                        .flatten(),
                     on_right.clone(),
                 ))
             })?,
@@ -906,7 +928,8 @@ impl ExecutionPlan for HashJoinExec {
                 need_produce_result_in_final(self.join_type),
                 1,
                 enable_dynamic_filter_pushdown
-                    .then_some(Arc::clone(&self.dynamic_filter)),
+                    .then_some(self.dynamic_filter.clone())
+                    .flatten(),
                 on_right.clone(),
             ))
         }
@@ -1050,8 +1073,7 @@ impl ExecutionPlan for HashJoinExec {
             && config.optimizer.enable_dynamic_filter_pushdown
         {
             // Add actual dynamic filter to right side (probe side)
-            let dynamic_filter =
-                Arc::clone(&self.dynamic_filter) as Arc<dyn PhysicalExpr>;
+            let dynamic_filter = Self::create_dynamic_filter(&self.on);
             right_child = right_child.with_self_filter(dynamic_filter);
         }
@@ -1078,7 +1100,40 @@ impl ExecutionPlan for HashJoinExec {
                 child_pushdown_result,
             ));
         }
-        Ok(FilterPushdownPropagation::if_any(child_pushdown_result))
+
+        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
+        assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children
+        let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child
+        // We expect 0 or 1 self filters
+        if let Some(filter) = right_child_self_filters.first() {
+            // Note that we don't check PushedDownPredicate::discriminant because even if nothing said
+            // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating
+            let predicate = Arc::clone(&filter.predicate);
+            if let Ok(dynamic_filter) =
+                Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
+            {
+                // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
+                let new_node = Arc::new(HashJoinExec {
+                    left: Arc::clone(&self.left),
+                    right: Arc::clone(&self.right),
+                    on: self.on.clone(),
+                    filter: self.filter.clone(),
+                    join_type: self.join_type,
+                    join_schema: Arc::clone(&self.join_schema),
+                    left_fut: Arc::clone(&self.left_fut),
+                    random_state: self.random_state.clone(),
+                    mode: self.mode,
+                    metrics: ExecutionPlanMetricsSet::new(),
+                    projection: self.projection.clone(),
+                    column_indices: self.column_indices.clone(),
+                    null_equality: self.null_equality,
+                    cache: self.cache.clone(),
+                    dynamic_filter: Some(dynamic_filter),
+                });
+                result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
+            }
+        }
+        Ok(result)
     }
 }
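Note: the `then_some(...).flatten()` chains in the `execute` hunks above gate the now-optional `dynamic_filter` field behind the pushdown flag, so the filter is forwarded only when dynamic filter pushdown is enabled and a filter was actually installed by `handle_child_pushdown_result`. A tiny standalone sketch of the idiom:

```rust
// `enabled.then_some(opt)` produces `Option<Option<T>>`; `flatten`
// collapses it, so the result is `Some` only when both conditions hold.
fn gate<T>(enabled: bool, opt: Option<T>) -> Option<T> {
    enabled.then_some(opt).flatten()
}

fn main() {
    assert_eq!(gate(true, Some(7)), Some(7));
    assert_eq!(gate(true, None::<i32>), None);
    assert_eq!(gate(false, Some(7)), None);
}
```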
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt
index c5009c248c31..c999aa71fe5b 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -286,5 +286,37 @@ explain select a from t where CAST(a AS string) = '0123';
 physical_plan
 DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8View) = 0123
 
+# Test dynamic filter pushdown with swapped join inputs (issue #17196)
+# Create tables with different sizes to force join input swapping
+statement ok
+copy (select i as k from generate_series(1, 100) t(i)) to 'test_files/scratch/push_down_filter/small_table.parquet';
+
+statement ok
+copy (select i as k, i as v from generate_series(1, 1000) t(i)) to 'test_files/scratch/push_down_filter/large_table.parquet';
+
+statement ok
+create external table small_table stored as parquet location 'test_files/scratch/push_down_filter/small_table.parquet';
+
+statement ok
+create external table large_table stored as parquet location 'test_files/scratch/push_down_filter/large_table.parquet';
+
+# Test that dynamic filter is applied to the correct table after join input swapping
+# The small_table should be the build side, large_table should be the probe side with dynamic filter
+query TT
+explain select * from small_table join large_table on small_table.k = large_table.k where large_table.v >= 50;
+----
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=8192
+02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(k@0, k@0)]
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/small_table.parquet]]}, projection=[k], file_type=parquet
+04)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/large_table.parquet]]}, projection=[k, v], file_type=parquet, predicate=v@1 >= 50 AND DynamicFilterPhysicalExpr [ true ], pruning_predicate=v_null_count@1 != row_count@2 AND v_max@0 >= 50, required_guarantees=[]
+
+statement ok
+drop table small_table;
+
+statement ok
+drop table large_table;
+
 statement ok
 drop table t;