5858//! borrowing from the outer closure, and we simply peel off a `deref` projection 
5959//! from them. This second body is stored alongside the first body, and optimized 
6060//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`, 
61- //! we use this "by move" body instead. 
62- 
63- use  itertools:: Itertools ; 
61+ //! we use this "by-move" body instead. 
62+ //! 
63+ //! ## How does this work? 
64+ //! 
65+ //! This pass essentially remaps the body of the (child) closure of the coroutine-closure 
66+ //! to take the set of upvars of the parent closure by value. This at least requires 
67+ //! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure 
68+ //! captures something by value; however, it may also require renumbering field indices 
69+ //! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine 
70+ //! to split one field capture into two. 
6471
65- use  rustc_data_structures:: unord:: UnordSet ; 
72+ use  rustc_data_structures:: unord:: UnordMap ; 
6673use  rustc_hir as  hir; 
74+ use  rustc_middle:: hir:: place:: { PlaceBase ,  Projection ,  ProjectionKind } ; 
6775use  rustc_middle:: mir:: visit:: MutVisitor ; 
6876use  rustc_middle:: mir:: { self ,  dump_mir,  MirPass } ; 
6977use  rustc_middle:: ty:: { self ,  InstanceDef ,  Ty ,  TyCtxt ,  TypeVisitableExt } ; 
70- use  rustc_target:: abi:: FieldIdx ; 
78+ use  rustc_target:: abi:: { FieldIdx ,   VariantIdx } ; 
7179
7280pub  struct  ByMoveBody ; 
7381
@@ -116,32 +124,116 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
116124            . tuple_fields ( ) 
117125            . len ( ) ; 
118126
119-         let  mut  by_ref_fields = UnordSet :: default ( ) ; 
120-         for  ( idx,  ( coroutine_capture,  parent_capture) )  in  tcx
127+         let  mut  field_remapping = UnordMap :: default ( ) ; 
128+ 
129+         // One parent capture may correspond to several child captures if we end up 
130+         // refining the set of captures via edition-2021 precise captures. We want to 
131+         // match up any number of child captures with one parent capture, so we keep 
132+         // peeking off this `Peekable` until the child doesn't match anymore. 
133+         let  mut  parent_captures =
134+             tcx. closure_captures ( parent_def_id) . iter ( ) . copied ( ) . enumerate ( ) . peekable ( ) ; 
135+         // Make sure we use every field at least once, b/c why are we capturing something 
136+         // if it's not used in the inner coroutine. 
137+         let  mut  field_used_at_least_once = false ; 
138+ 
139+         for  ( child_field_idx,  child_capture)  in  tcx
121140            . closure_captures ( coroutine_def_id) 
122141            . iter ( ) 
142+             . copied ( ) 
123143            // By construction we capture all the args first. 
124144            . skip ( num_args) 
125-             . zip_eq ( tcx. closure_captures ( parent_def_id) ) 
126145            . enumerate ( ) 
127146        { 
128-             // This upvar is captured by-move from the parent closure, but by-ref 
129-             // from the inner async block. That means that it's being borrowed from 
130-             // the outer closure body -- we need to change the coroutine to take the 
131-             // upvar by value. 
132-             if  coroutine_capture. is_by_ref ( )  && !parent_capture. is_by_ref ( )  { 
133-                 assert_ne ! ( 
134-                     coroutine_kind, 
135-                     ty:: ClosureKind :: FnOnce , 
136-                     "`FnOnce` coroutine-closures return coroutines that capture from \  
137- 
147+             loop  { 
148+                 let  Some ( & ( parent_field_idx,  parent_capture) )  = parent_captures. peek ( )  else  { 
149+                     bug ! ( "we ran out of parent captures!" ) 
150+                 } ; 
151+ 
152+                 let  PlaceBase :: Upvar ( parent_base)  = parent_capture. place . base  else  { 
153+                     bug ! ( "expected capture to be an upvar" ) ; 
154+                 } ; 
155+                 let  PlaceBase :: Upvar ( child_base)  = child_capture. place . base  else  { 
156+                     bug ! ( "expected capture to be an upvar" ) ; 
157+                 } ; 
158+ 
159+                 assert ! ( 
160+                     child_capture. place. projections. len( )  >= parent_capture. place. projections. len( ) 
138161                ) ; 
139-                 by_ref_fields. insert ( FieldIdx :: from_usize ( num_args + idx) ) ; 
162+                 // A parent matches a child they share the same prefix of projections. 
163+                 // The child may have more, if it is capturing sub-fields out of 
164+                 // something that is captured by-move in the parent closure. 
165+                 if  parent_base. var_path . hir_id  != child_base. var_path . hir_id 
166+                     || !std:: iter:: zip ( 
167+                         & child_capture. place . projections , 
168+                         & parent_capture. place . projections , 
169+                     ) 
170+                     . all ( |( child,  parent) | child. kind  == parent. kind ) 
171+                 { 
172+                     // Make sure the field was used at least once. 
173+                     assert ! ( 
174+                         field_used_at_least_once, 
175+                         "we captured {parent_capture:#?} but it was not used in the child coroutine?" 
176+                     ) ; 
177+                     field_used_at_least_once = false ; 
178+                     // Skip this field. 
179+                     let  _ = parent_captures. next ( ) . unwrap ( ) ; 
180+                     continue ; 
181+                 } 
182+ 
183+                 // Store this set of additional projections (fields and derefs). 
184+                 // We need to re-apply them later. 
185+                 let  child_precise_captures =
186+                     & child_capture. place . projections [ parent_capture. place . projections . len ( ) ..] ; 
187+ 
188+                 // If the parent captures by-move, and the child captures by-ref, then we 
189+                 // need to peel an additional `deref` off of the body of the child. 
190+                 let  needs_deref = child_capture. is_by_ref ( )  && !parent_capture. is_by_ref ( ) ; 
191+                 if  needs_deref { 
192+                     assert_ne ! ( 
193+                         coroutine_kind, 
194+                         ty:: ClosureKind :: FnOnce , 
195+                         "`FnOnce` coroutine-closures return coroutines that capture from \  
196+ 
197+                     ) ; 
198+                 } 
199+ 
200+                 // Finally, store the type of the parent's captured place. We need 
201+                 // this when building the field projection in the MIR body later on. 
202+                 let  mut  parent_capture_ty = parent_capture. place . ty ( ) ; 
203+                 parent_capture_ty = match  parent_capture. info . capture_kind  { 
204+                     ty:: UpvarCapture :: ByValue  => parent_capture_ty, 
205+                     ty:: UpvarCapture :: ByRef ( kind)  => Ty :: new_ref ( 
206+                         tcx, 
207+                         tcx. lifetimes . re_erased , 
208+                         parent_capture_ty, 
209+                         kind. to_mutbl_lossy ( ) , 
210+                     ) , 
211+                 } ; 
212+ 
213+                 field_remapping. insert ( 
214+                     FieldIdx :: from_usize ( child_field_idx + num_args) , 
215+                     ( 
216+                         FieldIdx :: from_usize ( parent_field_idx + num_args) , 
217+                         parent_capture_ty, 
218+                         needs_deref, 
219+                         child_precise_captures, 
220+                     ) , 
221+                 ) ; 
222+ 
223+                 field_used_at_least_once = true ; 
224+                 break ; 
140225            } 
226+         } 
227+ 
228+         // Pop the last parent capture 
229+         if  field_used_at_least_once { 
230+             let  _ = parent_captures. next ( ) . unwrap ( ) ; 
231+         } 
232+         assert_eq ! ( parent_captures. next( ) ,  None ,  "leftover parent captures?" ) ; 
141233
142-              // Make sure we're actually talking about the same capture. 
143-             // FIXME(async_closures): We could look at the `hir::Upvar` instead? 
144-             assert_eq ! ( coroutine_capture . place . ty ( ) ,  parent_capture . place . ty ( ) ) ; 
234+         if  coroutine_kind == ty :: ClosureKind :: FnOnce   { 
235+             assert_eq ! ( field_remapping . len ( ) ,  tcx . closure_captures ( parent_def_id ) . len ( ) ) ; 
236+             return ; 
145237        } 
146238
147239        let  by_move_coroutine_ty = tcx
@@ -157,7 +249,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
157249            ) ; 
158250
159251        let  mut  by_move_body = body. clone ( ) ; 
160-         MakeByMoveBody  {  tcx,  by_ref_fields ,  by_move_coroutine_ty } . visit_body ( & mut  by_move_body) ; 
252+         MakeByMoveBody  {  tcx,  field_remapping ,  by_move_coroutine_ty } . visit_body ( & mut  by_move_body) ; 
161253        dump_mir ( tcx,  false ,  "coroutine_by_move" ,  & 0 ,  & by_move_body,  |_,  _| Ok ( ( ) ) ) ; 
162254        by_move_body. source  = mir:: MirSource :: from_instance ( InstanceDef :: CoroutineKindShim  { 
163255            coroutine_def_id :  coroutine_def_id. to_def_id ( ) , 
@@ -168,7 +260,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
168260
169261struct  MakeByMoveBody < ' tcx >  { 
170262    tcx :  TyCtxt < ' tcx > , 
171-     by_ref_fields :   UnordSet < FieldIdx > , 
263+     field_remapping :   UnordMap < FieldIdx ,   ( FieldIdx ,   Ty < ' tcx > ,   bool ,   & ' tcx   [ Projection < ' tcx > ] ) > , 
172264    by_move_coroutine_ty :  Ty < ' tcx > , 
173265} 
174266
@@ -183,24 +275,59 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
183275        context :  mir:: visit:: PlaceContext , 
184276        location :  mir:: Location , 
185277    )  { 
278+         // Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a 
279+         // field projection. If this is in `field_remapping`, then it must not be an 
280+         // arg from calling the closure, but instead an upvar. 
186281        if  place. local  == ty:: CAPTURE_STRUCT_LOCAL 
187-             && let  Some ( ( & mir:: ProjectionElem :: Field ( idx,  ty ) ,  projection) )  =
282+             && let  Some ( ( & mir:: ProjectionElem :: Field ( idx,  _ ) ,  projection) )  =
188283                place. projection . split_first ( ) 
189-             && self . by_ref_fields . contains ( & idx) 
284+             && let  Some ( & ( remapped_idx,  remapped_ty,  needs_deref,  additional_projections) )  =
285+                 self . field_remapping . get ( & idx) 
190286        { 
191-             let  ( begin,  end)  = projection. split_first ( ) . unwrap ( ) ; 
192-             // FIXME(async_closures): I'm actually a bit surprised to see that we always 
193-             // initially deref the by-ref upvars. If this is not actually true, then we 
194-             // will at least get an ICE that explains why this isn't true :^) 
195-             assert_eq ! ( * begin,  mir:: ProjectionElem :: Deref ) ; 
196-             // Peel one ref off of the ty. 
197-             let  peeled_ty = ty. builtin_deref ( true ) . unwrap ( ) . ty ; 
287+             // As noted before, if the parent closure captures a field by value, and 
288+             // the child captures a field by ref, then for the by-move body we're 
289+             // generating, we also are taking that field by value. Peel off a deref, 
290+             // since a layer of reffing has now become redundant. 
291+             let  final_deref = if  needs_deref { 
292+                 let  Some ( ( mir:: ProjectionElem :: Deref ,  projection) )  = projection. split_first ( ) 
293+                 else  { 
294+                     bug ! ( 
295+                         "There should be at least a single deref for an upvar local initialization, found {projection:#?}" 
296+                     ) ; 
297+                 } ; 
298+                 // There may be more derefs, since we may also implicitly reborrow 
299+                 // a captured mut pointer. 
300+                 projection
301+             }  else  { 
302+                 projection
303+             } ; 
304+ 
305+             // The only thing that should be left is a deref, if the parent captured 
306+             // an upvar by-ref. 
307+             std:: assert_matches:: assert_matches!( final_deref,  [ ]  | [ mir:: ProjectionElem :: Deref ] ) ; 
308+ 
309+             // For all of the additional projections that come out of precise capturing, 
310+             // re-apply these projections. 
311+             let  additional_projections =
312+                 additional_projections. iter ( ) . map ( |elem| match  elem. kind  { 
313+                     ProjectionKind :: Deref  => mir:: ProjectionElem :: Deref , 
314+                     ProjectionKind :: Field ( idx,  VariantIdx :: ZERO )  => { 
315+                         mir:: ProjectionElem :: Field ( idx,  elem. ty ) 
316+                     } 
317+                     _ => unreachable ! ( "precise captures only through fields and derefs" ) , 
318+                 } ) ; 
319+ 
320+             // We start out with an adjusted field index (and ty), representing the 
321+             // upvar that we get from our parent closure. We apply any of the additional 
322+             // projections to make sure that to the rest of the body of the closure, the 
323+             // place looks the same, and then apply that final deref if necessary. 
198324            * place = mir:: Place  { 
199325                local :  place. local , 
200326                projection :  self . tcx . mk_place_elems_from_iter ( 
201-                     [ mir:: ProjectionElem :: Field ( idx ,  peeled_ty ) ] 
327+                     [ mir:: ProjectionElem :: Field ( remapped_idx ,  remapped_ty ) ] 
202328                        . into_iter ( ) 
203-                         . chain ( end. iter ( ) . copied ( ) ) , 
329+                         . chain ( additional_projections) 
330+                         . chain ( final_deref. iter ( ) . copied ( ) ) , 
204331                ) , 
205332            } ; 
206333        } 
0 commit comments