@@ -8,6 +8,7 @@ use crate::deriving::generic::ty::*;
88use crate :: deriving:: generic:: * ;
99use crate :: deriving:: { path_local, path_std} ;
1010
11+ /// Expands a `#[derive(PartialEq)]` attribute into an implementation for the target item.
1112pub ( crate ) fn expand_deriving_partial_eq (
1213 cx : & ExtCtxt < ' _ > ,
1314 span : Span ,
@@ -16,62 +17,6 @@ pub(crate) fn expand_deriving_partial_eq(
1617 push : & mut dyn FnMut ( Annotatable ) ,
1718 is_const : bool ,
1819) {
19- fn cs_eq ( cx : & ExtCtxt < ' _ > , span : Span , substr : & Substructure < ' _ > ) -> BlockOrExpr {
20- let base = true ;
21- let expr = cs_fold (
22- true , // use foldl
23- cx,
24- span,
25- substr,
26- |cx, fold| match fold {
27- CsFold :: Single ( field) => {
28- let [ other_expr] = & field. other_selflike_exprs [ ..] else {
29- cx. dcx ( )
30- . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
31- } ;
32-
33- // We received arguments of type `&T`. Convert them to type `T` by stripping
34- // any leading `&`. This isn't necessary for type checking, but
35- // it results in better error messages if something goes wrong.
36- //
37- // Note: for arguments that look like `&{ x }`, which occur with packed
38- // structs, this would cause expressions like `{ self.x } == { other.x }`,
39- // which isn't valid Rust syntax. This wouldn't break compilation because these
40- // AST nodes are constructed within the compiler. But it would mean that code
41- // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42- // syntax, which would be suboptimal. So we wrap these in parens, giving
43- // `({ self.x }) == ({ other.x })`, which is valid syntax.
44- let convert = |expr : & P < Expr > | {
45- if let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) =
46- & expr. kind
47- {
48- if let ExprKind :: Block ( ..) = & inner. kind {
49- // `&{ x }` form: remove the `&`, add parens.
50- cx. expr_paren ( field. span , inner. clone ( ) )
51- } else {
52- // `&x` form: remove the `&`.
53- inner. clone ( )
54- }
55- } else {
56- expr. clone ( )
57- }
58- } ;
59- cx. expr_binary (
60- field. span ,
61- BinOpKind :: Eq ,
62- convert ( & field. self_expr ) ,
63- convert ( other_expr) ,
64- )
65- }
66- CsFold :: Combine ( span, expr1, expr2) => {
67- cx. expr_binary ( span, BinOpKind :: And , expr1, expr2)
68- }
69- CsFold :: Fieldless => cx. expr_bool ( span, base) ,
70- } ,
71- ) ;
72- BlockOrExpr :: new_expr ( expr)
73- }
74-
7520 let structural_trait_def = TraitDef {
7621 span,
7722 path : path_std ! ( marker:: StructuralPartialEq ) ,
@@ -97,7 +42,9 @@ pub(crate) fn expand_deriving_partial_eq(
9742 ret_ty: Path ( path_local!( bool ) ) ,
9843 attributes: thin_vec![ cx. attr_word( sym:: inline, span) ] ,
9944 fieldless_variants_strategy: FieldlessVariantsStrategy :: Unify ,
100- combine_substructure: combine_substructure( Box :: new( |a, b, c| cs_eq( a, b, c) ) ) ,
45+ combine_substructure: combine_substructure( Box :: new( |a, b, c| {
46+ BlockOrExpr :: new_expr( get_substructure_equality_expr( a, b, c) )
47+ } ) ) ,
10148 } ] ;
10249
10350 let trait_def = TraitDef {
@@ -113,3 +60,138 @@ pub(crate) fn expand_deriving_partial_eq(
11360 } ;
11461 trait_def. expand ( cx, mitem, item, push)
11562}
63+
64+ /// Generates the equality expression for a struct or enum variant when deriving `PartialEq`.
65+ ///
66+ /// This function constructs an expression that compares all fields of a struct or enum variant for equality.
67+ /// It groups scalar and compound types separately, combining their comparisons efficiently:
68+ /// - If there are no fields, returns `true` (fieldless types are always equal to themselves).
69+ /// - Scalar fields are compared first for efficiency, then compound fields.
70+ /// - If only one group is non-empty, returns its comparison directly.
71+ /// - Otherwise, returns a conjunction (logical AND) of both groups' comparisons.
72+ ///
73+ /// For enums with discriminants, compares the discriminant first, then the rest of the fields.
74+ ///
75+ /// # Panics
76+ ///
77+ /// If called on static or all-fieldless enums/structs, which should not occur during derive expansion.
78+ fn get_substructure_equality_expr (
79+ cx : & ExtCtxt < ' _ > ,
80+ span : Span ,
81+ substructure : & Substructure < ' _ > ,
82+ ) -> P < Expr > {
83+ /// Combines the accumulated comparison expression with the next field's comparison using logical AND.
84+ ///
85+ /// If this is the first field, initializes the accumulator. Otherwise, chains with logical AND.
86+ fn combine ( cx : & ExtCtxt < ' _ > , span : Span , acc : & mut Option < P < Expr > > , elem : P < Expr > ) {
87+ let Some ( lhs) = acc. take ( ) else {
88+ * acc = Some ( elem) ;
89+ return ;
90+ } ;
91+ * acc = Some ( cx. expr_binary ( span, BinOpKind :: And , lhs, elem) ) ;
92+ }
93+
94+ use SubstructureFields :: * ;
95+ match substructure. fields {
96+ EnumMatching ( .., fields) | Struct ( .., fields) => {
97+ if fields. is_empty ( ) {
98+ // Fieldless structs or enum variants are always equal to themselves.
99+ return cx. expr_bool ( span, true ) ;
100+ }
101+
102+ let mut scalar_ty_cmp = None ;
103+ let mut compound_ty_cmp = None ;
104+ // Compare scalar and compound types separately for efficiency.
105+ for field in fields {
106+ let is_scalar = field. is_scalar ;
107+ let field_span = field. span ;
108+ let rhs = get_field_equality_expr ( cx, field) ;
109+
110+ if is_scalar {
111+ // Combine scalar field comparisons first (cheaper to evaluate).
112+ combine ( cx, field_span, & mut scalar_ty_cmp, rhs) ;
113+ continue ;
114+ }
115+ // Combine compound (non-scalar) field comparisons.
116+ combine ( cx, field_span, & mut compound_ty_cmp, rhs) ;
117+ }
118+
119+ // If only one group (scalar or compound) has fields, return its comparison directly.
120+ if scalar_ty_cmp. is_some ( ) ^ compound_ty_cmp. is_some ( ) {
121+ return scalar_ty_cmp. or ( compound_ty_cmp) . unwrap ( ) ;
122+ }
123+
124+ // If both groups are non-empty, require all fields to be equal.
125+ // Scalar fields are compared first for performance.
126+ return cx. expr_binary (
127+ span,
128+ BinOpKind :: And ,
129+ scalar_ty_cmp. unwrap ( ) ,
130+ compound_ty_cmp. unwrap ( ) ,
131+ ) ;
132+ }
133+ EnumDiscr ( disc, match_expr) => {
134+ let lhs = get_field_equality_expr ( cx, disc) ;
135+ let Some ( match_expr) = match_expr else {
136+ return lhs;
137+ } ;
138+ // Compare the discriminant first (cheaper), then the rest of the fields.
139+ return cx. expr_binary ( disc. span , BinOpKind :: And , lhs, match_expr. clone ( ) ) ;
140+ }
141+ StaticEnum ( ..) => cx. dcx ( ) . span_bug (
142+ span,
143+ "unexpected static enum encountered during `derive(PartialEq)` expansion" ,
144+ ) ,
145+ StaticStruct ( ..) => cx. dcx ( ) . span_bug (
146+ span,
147+ "unexpected static struct encountered during `derive(PartialEq)` expansion" ,
148+ ) ,
149+ AllFieldlessEnum ( ..) => cx. dcx ( ) . span_bug (
150+ span,
151+ "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion" ,
152+ ) ,
153+ }
154+ }
155+
156+ /// Generates an equality comparison expression for a single struct or enum field.
157+ ///
158+ /// This function produces an AST expression that compares the `self` and `other` values for a field using `==`.
159+ /// It removes any leading references from both sides for readability.
160+ /// If the field is a block expression, it is wrapped in parentheses to ensure valid syntax.
161+ ///
162+ /// # Panics
163+ ///
164+ /// Panics if there are not exactly two arguments to compare (should be `self` and `other`).
165+ fn get_field_equality_expr ( cx : & ExtCtxt < ' _ > , field : & FieldInfo ) -> P < Expr > {
166+ let [ rhs] = & field. other_selflike_exprs [ ..] else {
167+ cx. dcx ( ) . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
168+ } ;
169+
170+ cx. expr_binary (
171+ field. span ,
172+ BinOpKind :: Eq ,
173+ wrap_block_expr ( cx, peel_refs ( & field. self_expr ) ) ,
174+ wrap_block_expr ( cx, peel_refs ( rhs) ) ,
175+ )
176+ }
177+
178+ /// Removes all leading immutable references from an expression.
179+ ///
180+ /// This is used to strip away any number of leading `&` from an expression (e.g., `&&&T` becomes `T`).
181+ /// Only removes immutable references; mutable references are preserved.
182+ fn peel_refs ( mut expr : & P < Expr > ) -> P < Expr > {
183+ while let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) = & expr. kind {
184+ expr = & inner;
185+ }
186+ expr. clone ( )
187+ }
188+
189+ /// Wraps a block expression in parentheses to ensure valid AST in macro expansion output.
190+ ///
191+ /// If the given expression is a block, it is wrapped in parentheses; otherwise, it is returned unchanged.
192+ fn wrap_block_expr ( cx : & ExtCtxt < ' _ > , expr : P < Expr > ) -> P < Expr > {
193+ if matches ! ( & expr. kind, ExprKind :: Block ( ..) ) {
194+ return cx. expr_paren ( expr. span , expr) ;
195+ }
196+ expr
197+ }
0 commit comments