Skip to content

Commit 964dee0

Browse files
committed
refactor: encapsulate most of the logic within ExprSet, and delineate the expr_identifier from the alias symbol
1 parent 87eb784 commit 964dee0

File tree

1 file changed

+57
-45
lines changed

1 file changed

+57
-45
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Eliminate common sub-expression.
1919
20+
use std::collections::hash_map::Entry;
2021
use std::collections::{BTreeSet, HashMap};
2122
use std::sync::Arc;
2223

@@ -35,16 +36,42 @@ use datafusion_expr::expr::Alias;
3536
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
3637
use datafusion_expr::{col, Expr, ExprSchemable};
3738

38-
/// A map from expression's identifier to tuple including
39-
///
40-
/// key == Identifier created with only the current node (& subtree)
41-
///
42-
/// values:
43-
/// - the expression itself (cloned)
44-
/// - counter
45-
/// - DataType of this expression.
46-
/// - symbol used as the identifier in the alias
47-
type ExprSet = HashMap<Identifier, (Expr, usize, DataType, Identifier)>;
39+
/// Set of expressions populated by the [`ExprIdentifierVisitor`]
40+
/// and looked up by the [`CommonSubexprRewriter`].
41+
#[derive(Default)]
42+
struct ExprSet {
43+
/// Maps an expression's identifier (the expression formatted via `Display`) to a tuple of:
44+
/// - the expression itself (cloned)
45+
/// - counter: how many times the visitor has recorded this identifier
46+
/// - DataType of this expression, resolved against the input schema.
47+
/// - symbol used as the identifier in the alias (identifier followed by sub-expression identifiers).
48+
map: HashMap<Identifier, (Expr, usize, DataType, Identifier)>,
49+
}
50+
51+
impl ExprSet {
52+
fn expr_identifier(expr: &Expr) -> Identifier {
53+
format!("{expr}")
54+
}
55+
56+
fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> {
57+
self.map.get(key)
58+
}
59+
60+
fn entry(
61+
&mut self,
62+
key: Identifier,
63+
) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> {
64+
self.map.entry(key)
65+
}
66+
}
67+
68+
impl From<Vec<(Identifier, (Expr, usize, DataType, Identifier))>> for ExprSet {
    /// Builds an [`ExprSet`] from `(identifier, entry)` pairs.
    ///
    /// Each entry is `(expression, counter, data type, alias symbol)`.
    /// If the same identifier appears more than once, the last pair wins.
    fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self {
        // Collect eagerly. The previous `entries.into_iter().map(|..| insert(..))`
        // built a lazy iterator that was never consumed, so no pair was ever
        // inserted and the resulting set was always empty.
        Self {
            map: entries.into_iter().collect(),
        }
    }
}
4875

4976
/// Identifier for each subexpression.
5077
///
@@ -134,7 +161,7 @@ impl CommonSubexprEliminate {
134161
config: &dyn OptimizerConfig,
135162
) -> Result<LogicalPlan> {
136163
let mut window_exprs = vec![];
137-
let mut expr_set = ExprSet::new();
164+
let mut expr_set = ExprSet::default();
138165

139166
// Get all window expressions inside the consecutive window operators.
140167
// Consecutive window expressions may refer to same complex expression.
@@ -206,7 +233,7 @@ impl CommonSubexprEliminate {
206233
input,
207234
..
208235
} = aggregate;
209-
let mut expr_set = ExprSet::new();
236+
let mut expr_set = ExprSet::default();
210237

211238
// build expr_set, with groupby and aggr
212239
let input_schema = Arc::clone(input.schema());
@@ -226,7 +253,7 @@ impl CommonSubexprEliminate {
226253
let new_group_expr = pop_expr(&mut new_expr)?;
227254

228255
// create potential projection on top
229-
let mut expr_set = ExprSet::new();
256+
let mut expr_set = ExprSet::default();
230257
let new_input_schema = Arc::clone(new_input.schema());
231258
populate_expr_set(
232259
&new_aggr_expr,
@@ -277,9 +304,7 @@ impl CommonSubexprEliminate {
277304
agg_exprs.push(expr.alias(&name));
278305
proj_exprs.push(Expr::Column(Column::from_name(name)));
279306
} else {
280-
let id = ExprIdentifierVisitor::<'static>::expr_identifier(
281-
&expr_rewritten,
282-
);
307+
let id = ExprSet::expr_identifier(&expr_rewritten);
283308
let out_name =
284309
expr_rewritten.to_field(&new_input_schema)?.qualified_name();
285310
agg_exprs.push(expr_rewritten.alias(&id));
@@ -313,7 +338,7 @@ impl CommonSubexprEliminate {
313338
let inputs = plan.inputs();
314339
let input = inputs[0];
315340
let input_schema = Arc::clone(input.schema());
316-
let mut expr_set = ExprSet::new();
341+
let mut expr_set = ExprSet::default();
317342

318343
// Visit expr list and build expr identifier to occurring count map (`expr_set`).
319344
populate_expr_set(&expr, input_schema, &mut expr_set, ExprMask::Normal)?;
@@ -572,10 +597,6 @@ enum VisitRecord {
572597
}
573598

574599
impl ExprIdentifierVisitor<'_> {
575-
fn expr_identifier(expr: &Expr) -> Identifier {
576-
format!("{expr}")
577-
}
578-
579600
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
580601
/// before it.
581602
fn pop_enter_mark(&mut self) -> (usize, Identifier) {
@@ -619,21 +640,22 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
619640

620641
// skip exprs that should not be recognized.
621642
if self.expr_mask.ignores(expr) {
622-
let curr_expr_identifier = Self::expr_identifier(expr);
643+
let curr_expr_identifier = ExprSet::expr_identifier(expr);
623644
self.visit_stack
624645
.push(VisitRecord::ExprItem(curr_expr_identifier));
625646
return Ok(TreeNodeRecursion::Continue);
626647
}
627-
let curr_expr_identifier = Self::expr_identifier(expr);
628-
let desc = format!("{curr_expr_identifier}{sub_expr_identifier}");
648+
let curr_expr_identifier = ExprSet::expr_identifier(expr);
649+
let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");
629650

630-
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
651+
self.visit_stack
652+
.push(VisitRecord::ExprItem(alias_symbol.clone()));
631653

632654
let data_type = expr.get_type(&self.input_schema)?;
633655

634656
self.expr_set
635657
.entry(curr_expr_identifier)
636-
.or_insert_with(|| (expr.clone(), 0, data_type, desc))
658+
.or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
637659
.1 += 1;
638660
Ok(TreeNodeRecursion::Continue)
639661
}
@@ -666,12 +688,6 @@ struct CommonSubexprRewriter<'a> {
666688
affected_id: &'a mut BTreeSet<Identifier>,
667689
}
668690

669-
impl CommonSubexprRewriter<'_> {
670-
fn expr_identifier(expr: &Expr) -> Identifier {
671-
format!("{expr}")
672-
}
673-
}
674-
675691
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
676692
type Node = Expr;
677693

@@ -683,7 +699,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
683699
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
684700
}
685701

686-
let curr_id = &Self::expr_identifier(&expr);
702+
let curr_id = &ExprSet::expr_identifier(&expr);
687703

688704
// lookup previously visited expression
689705
match self.expr_set.get(curr_id) {
@@ -997,7 +1013,7 @@ mod test {
9971013
let table_scan = test_table_scan().unwrap();
9981014
let affected_id: BTreeSet<Identifier> =
9991015
["c+a".to_string(), "b+a".to_string()].into_iter().collect();
1000-
let expr_set_1 = [
1016+
let expr_set_1 = vec![
10011017
(
10021018
"c+a".to_string(),
10031019
(col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
@@ -1007,9 +1023,8 @@ mod test {
10071023
(col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
10081024
),
10091025
]
1010-
.into_iter()
1011-
.collect();
1012-
let expr_set_2 = [
1026+
.into();
1027+
let expr_set_2 = vec![
10131028
(
10141029
"c+a".to_string(),
10151030
(col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
@@ -1019,8 +1034,7 @@ mod test {
10191034
(col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
10201035
),
10211036
]
1022-
.into_iter()
1023-
.collect();
1037+
.into();
10241038
let project =
10251039
build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1)
10261040
.unwrap();
@@ -1046,7 +1060,7 @@ mod test {
10461060
["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
10471061
.into_iter()
10481062
.collect();
1049-
let expr_set_1 = [
1063+
let expr_set_1 = vec![
10501064
(
10511065
"test1.c+test1.a".to_string(),
10521066
(
@@ -1066,9 +1080,8 @@ mod test {
10661080
),
10671081
),
10681082
]
1069-
.into_iter()
1070-
.collect();
1071-
let expr_set_2 = [
1083+
.into();
1084+
let expr_set_2 = vec![
10721085
(
10731086
"test1.c+test1.a".to_string(),
10741087
(
@@ -1088,8 +1101,7 @@ mod test {
10881101
),
10891102
),
10901103
]
1091-
.into_iter()
1092-
.collect();
1104+
.into();
10931105
let project =
10941106
build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1)
10951107
.unwrap();

0 commit comments

Comments
 (0)