diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 93df0dcfd500..27b69894aaf2 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -96,6 +96,7 @@ impl OptimizerRule for EliminateCrossJoin { filter.input.as_ref(), LogicalPlan::Join(Join { join_type: JoinType::Inner, + null_equals_null: false, .. }) | LogicalPlan::CrossJoin(_) ); @@ -124,6 +125,7 @@ impl OptimizerRule for EliminateCrossJoin { plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, + null_equals_null: false, .. }) ) { @@ -268,6 +270,7 @@ fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { match child { LogicalPlan::Join(Join { join_type: JoinType::Inner, + null_equals_null: false, .. }) | LogicalPlan::CrossJoin(_) => { diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 0dae777ab5bd..ead02abada52 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::utils::split_conjunction_owned; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; // equijoin predicate @@ -67,57 +67,135 @@ impl OptimizerRule for ExtractEquijoinPredicate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Join(Join { - left, - right, - mut on, - filter: Some(expr), - join_type, - join_constraint, - schema, - null_equals_null, - }) => { - let left_schema = left.schema(); - let right_schema = right.schema(); - let (equijoin_predicates, non_equijoin_expr) = - split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; - - if !equijoin_predicates.is_empty() { - on.extend(equijoin_predicates); - Ok(Transformed::yes(LogicalPlan::Join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) - } else { - Ok(Transformed::no(LogicalPlan::Join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) - } - } + LogicalPlan::Join(join) => extract_equijoin_predicate(join), _ => Ok(Transformed::no(plan)), } } } +fn extract_equijoin_predicate(join: Join) -> Result> { + fn update_join_predicate( + join: Join, + extra_on: Vec, + filter: Option, + null_equals_null: bool, + ) -> Transformed { + if extra_on.is_empty() { + Transformed::no(LogicalPlan::Join(join)) + } else { + let mut on = join.on; + on.extend(extra_on); + Transformed::yes(LogicalPlan::Join(Join { + left: join.left, + right: join.right, + on, + filter, + join_type: join.join_type, + join_constraint: join.join_constraint, + schema: join.schema, + null_equals_null, + })) + } + } + if join.filter.is_none() { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + let expr = join.filter.as_ref().unwrap(); + + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); + + if join.on.is_empty() { + let (eq_predicates, null_equals_null, non_eq_expr) = + split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; + Ok(update_join_predicate( + join, + eq_predicates, + non_eq_expr, + null_equals_null, + )) + } else if join.null_equals_null { + let (eq_predicates, non_eq_expr) = + split_eq_and_noneq_join_predicate_nulls_eq(expr, left_schema, right_schema)?; + + Ok(update_join_predicate( + join, + eq_predicates, + non_eq_expr, + true, + )) + } else { + let (eq_predicates, non_eq_expr) = + split_eq_and_noneq_join_predicate_nulls_not_eq( + expr, + left_schema, + right_schema, + )?; + + Ok(update_join_predicate( + join, + eq_predicates, + non_eq_expr, + false, + )) + } +} + fn split_eq_and_noneq_join_predicate( - filter: Expr, + filter: &Expr, + left_schema: &DFSchema, + right_schema: &DFSchema, +) -> Result<(Vec, bool, Option)> { + let (eq, noneq) = split_eq_and_noneq_join_predicate_nulls_not_eq( + filter, + left_schema, + right_schema, + )?; + if !eq.is_empty() { + Ok((eq, false, noneq)) + } else { + let (eq, noneq) = split_eq_and_noneq_join_predicate_nulls_eq( + filter, + left_schema, + right_schema, + )?; + Ok((eq, true, noneq)) + } +} + +fn split_eq_and_noneq_join_predicate_nulls_eq( + filter: &Expr, + left_schema: &DFSchema, + right_schema: &DFSchema, +) -> Result<(Vec, Option)> { + split_eq_and_noneq_join_predicate_impl( + filter, + left_schema, + right_schema, + Operator::IsNotDistinctFrom, + ) +} + +fn split_eq_and_noneq_join_predicate_nulls_not_eq( + filter: &Expr, + left_schema: &DFSchema, + right_schema: &DFSchema, +) -> Result<(Vec, Option)> { + split_eq_and_noneq_join_predicate_impl( + filter, + left_schema, + right_schema, + Operator::Eq, + ) +} + +fn split_eq_and_noneq_join_predicate_impl( + filter: &Expr, left_schema: &DFSchema, right_schema: &DFSchema, + eq_op: Operator, ) -> Result<(Vec, Option)> { - let exprs = split_conjunction_owned(filter); + let exprs = split_conjunction(filter); let mut accum_join_keys: Vec<(Expr, Expr)> = vec![]; let mut accum_filters: Vec = vec![]; @@ -125,9 +203,9 @@ fn split_eq_and_noneq_join_predicate( match expr { Expr::BinaryExpr(BinaryExpr { ref left, - op: Operator::Eq, + op, ref right, - }) => { + }) if *op == eq_op => { let join_key_pair = find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; @@ -138,13 +216,13 @@ fn split_eq_and_noneq_join_predicate( if can_hash(&left_expr_type) && can_hash(&right_expr_type) { accum_join_keys.push((left_expr, right_expr)); } else { - accum_filters.push(expr); + accum_filters.push(expr.clone()); } } else { - accum_filters.push(expr); + accum_filters.push(expr.clone()); } } - _ => accum_filters.push(expr), + _ => accum_filters.push(expr.clone()), } } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 21fea4ad1025..1f3305f83df9 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -766,6 +766,50 @@ set datafusion.execution.target_partitions = 4; statement ok set datafusion.optimizer.repartition_joins = false; +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +# tables for join nulls equals null +statement ok +CREATE TABLE IF NOT EXISTS t1(t1_id INT NULL, t1_int INT) AS VALUES +(11, 1), +(22, 2), +(NULL, 3); + +statement ok +CREATE TABLE IF NOT EXISTS t2(t2_id INT NULL, t2_int INT) AS VALUES +(11, 3), +(33, 1), +(NULL, 5), +(NULL, 6); + + +# IS NOT DISTRINCT can be transformed into equijoin +query TT +EXPLAIN SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id IS NOT DISTINCT from t2_id +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_int, t2.t2_int +02)--Inner Join: t1.t1_id = t2.t2_id +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----TableScan: t2 projection=[t2_id, t2_int] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0, t1_int@1, t2_int@3] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id IS NOT DISTINCT from t2_id +---- +11 1 3 +NULL 3 5 +NULL 3 6 + + statement ok DROP TABLE t1;