Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,10 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely the better plan 👍 @DhamoPS

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb I have checked the #2858. Even though, it would help in handling of disjunctive predicates, it does not solve the problem of #78. I understand that we need to write these rules in optimizer.rs, so that it would be applicable for DATAFRAME API plans as well.
I would convert my fix into optimizer rule as suggested. CrossJoins must be converted to InnerJoins if there is one or more common predicates between those tables.

" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
68 changes: 68 additions & 0 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,74 @@ async fn tpch_q17_correlated() -> Result<()> {
Ok(())
}

// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
#[tokio::test]
async fn tpch_q19_pull_predicates_to_innerjoin_simplified() -> Result<()> {
let ctx = SessionContext::new();

register_tpch_csv(&ctx, "part").await?;
register_tpch_csv(&ctx, "lineitem").await?;

let partsupp = r#"63700,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff"#;
register_tpch_csv_data(&ctx, "partsupp", partsupp).await?;

let sql = r#"
select
p_partkey,
sum(l_extendedprice),
avg(l_discount),
count(distinct ps_suppkey)
from
lineitem,
part,
partsupp
where
(
p_partkey = l_partkey
and p_brand = 'Brand#12'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My question about disjunctions can be rephrased as "what if this and was an or? would we still try to lift it out into a join predicate?"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that scenario, we don't lift the predicate up. I have tested the behavior by running the modified tests.

and p_partkey = ps_partkey
)
or
(
ps_partkey = p_partkey
and p_brand = 'Brand#23'
and p_partkey = l_partkey
)

group by p_partkey
;"#;

// assert plan
let plan = ctx.create_logical_plan(sql).unwrap();
debug!("input:\n{}", plan.display_indent());

let plan = ctx.optimize(&plan).unwrap();
let actual = format!("{}", plan.display_indent());
let expected = r#"Projection: #part.p_partkey, #SUM(lineitem.l_extendedprice), #AVG(lineitem.l_discount), #COUNT(DISTINCT partsupp.ps_suppkey)
Aggregate: groupBy=[[#part.p_partkey]], aggr=[[SUM(#lineitem.l_extendedprice), AVG(#lineitem.l_discount), COUNT(DISTINCT #partsupp.ps_suppkey)]]
Inner Join: #part.p_partkey = #partsupp.ps_partkey
Inner Join: #lineitem.l_partkey = #part.p_partkey
TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount]
Filter: #part.p_brand = Utf8("Brand#12") OR #part.p_brand = Utf8("Brand#23")
TableScan: part projection=[p_partkey, p_brand], partial_filters=[#part.p_brand = Utf8("Brand#12") OR #part.p_brand = Utf8("Brand#23")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
.to_string();
assert_eq!(actual, expected);

// assert data
let results = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----------+-------------------------------+--------------------------+-------------------------------------+",
"| p_partkey | SUM(lineitem.l_extendedprice) | AVG(lineitem.l_discount) | COUNT(DISTINCT partsupp.ps_suppkey) |",
"+-----------+-------------------------------+--------------------------+-------------------------------------+",
"| 63700 | 13309.6 | 0.1 | 1 |",
"+-----------+-------------------------------+--------------------------+-------------------------------------+",
];
assert_batches_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn tpch_q20_correlated() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
42 changes: 41 additions & 1 deletion datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion_expr::utils::{
COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
and, or, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
WindowFrame, WindowFrameUnits,
};
use datafusion_expr::{
Expand Down Expand Up @@ -2453,6 +2453,17 @@ fn remove_join_expressions(
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
}
},
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::Or => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can pull out ors the same way we would part of an and - altering a disjunction will affect the outcome. i.e. if it was l.partkey=s.partkey or l.price < 5.00 and we extract the first part into a join clause, we will now only see ones with prices < 5, not ones that matched or were less than 5. A contrived example I'm sure, but I think I'd have to see a counter-example working before I am confident we didn't alter the results.

I think a better place to solve this might be with an optimizer rule, instead of in the planner - in an optimizer rule we could always just return Err() if the situation isn't perfectly to our liking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I agree that OR predicates need to be handled differently.
OR predicates are handled in below code.

        ` Operator::Or => {
            let mut left_join_keys = vec![];
            let mut right_join_keys = vec![];

            extract_possible_join_keys(left, &mut left_join_keys)?;
            extract_possible_join_keys(right, &mut right_join_keys)?;

            intersect( accum, &left_join_keys, &right_join_keys)
        }`            

In case of OR predicates, we pull the predicates if both left and right child are having same JOIN predicates.
insect() function would ensure that we don't pull the all OR predicates.

In the example given above, l.partkey=s.partkey or l.price < 5.00, the left child of OR predicate would be l.partkey=s.partkey and right child of OR predicate would be ` l.price < 5.00'. There is no common predicates on left and right child of OR expr and so these predicates are not pulled to Join Predicates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not pull the predicates from disjunctions unless both left and right child of OR expr are having common conjunctions, since it will affect the query results. The idea for this fix is, "If there are common predicates in left and right child of OR expression then we should move those predicates to Join predicates. In case of Q19, it has common join predicate in all 3 OR expressions and so that Join predicate is pulled up to JOIN node.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I think this kind of rewrite should be done in an SQL optimizer pass so that it will apply to any query plan (e.g. that came from the DataFrame API) rather than only SQL

let l = remove_join_expressions(left, join_columns)?;
let r = remove_join_expressions(right, join_columns)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(or(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
}
}
_ => Ok(Some(expr.clone())),
},
Expand Down Expand Up @@ -2526,6 +2537,25 @@ fn extract_join_keys(
}
}

fn intersect(
accum: &mut Vec<(Column, Column)>,
vec1: & Vec<(Column, Column)>,
vec2: & Vec<(Column, Column)>,
) -> Result<()> {

for x1 in vec1.iter() {
for x2 in vec2.iter() {
if x1.0 == x2.0 && x1.1 == x2.1
|| x1.1 == x2.0 && x1.0 == x2.1
{
accum.push((x1.0.clone(), x1.1.clone()));
}
}
}

Ok(())
}

/// Extract join keys from a WHERE clause
fn extract_possible_join_keys(
expr: &Expr,
Expand All @@ -2543,6 +2573,16 @@ fn extract_possible_join_keys(
Operator::And => {
extract_possible_join_keys(left, accum)?;
extract_possible_join_keys(right, accum)
},
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::Or => {
let mut left_join_keys = vec![];
let mut right_join_keys = vec![];

extract_possible_join_keys(left, &mut left_join_keys)?;
extract_possible_join_keys(right, &mut right_join_keys)?;

intersect( accum, &left_join_keys, &right_join_keys)
}
_ => Ok(()),
},
Expand Down