Skip to content

Commit 7f2f7ac

Browse files
committed
Only load pin field once.
1 parent 90181a4 commit 7f2f7ac

13 files changed

+210
-290
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ struct SelfArgVisitor<'tcx> {
133133
}
134134

135135
impl<'tcx> SelfArgVisitor<'tcx> {
136-
fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self {
137-
Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } }
136+
fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
137+
Self { tcx, new_base }
138138
}
139139
}
140140

@@ -147,16 +147,14 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
147147
assert_ne!(*local, SELF_ARG);
148148
}
149149

150-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
150+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
151151
if place.local == SELF_ARG {
152152
replace_base(place, self.new_base, self.tcx);
153-
} else {
154-
self.visit_local(&mut place.local, context, location);
153+
}
155154

156-
for elem in place.projection.iter() {
157-
if let PlaceElem::Index(local) = elem {
158-
assert_ne!(local, SELF_ARG);
159-
}
155+
for elem in place.projection.iter() {
156+
if let PlaceElem::Index(local) = elem {
157+
assert_ne!(local, SELF_ARG);
160158
}
161159
}
162160
}
@@ -516,32 +514,51 @@ fn make_aggregate_adt<'tcx>(
516514

517515
#[tracing::instrument(level = "trace", skip(tcx, body))]
518516
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
519-
let coroutine_ty = body.local_decls.raw[1].ty;
517+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
520518

521519
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
522520

523521
// Replace the by value coroutine argument
524-
body.local_decls.raw[1].ty = ref_coroutine_ty;
522+
body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
525523

526524
// Add a deref to accesses of the coroutine state
527-
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
525+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
528526
}
529527

530528
#[tracing::instrument(level = "trace", skip(tcx, body))]
531529
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
532-
let ref_coroutine_ty = body.local_decls.raw[1].ty;
530+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
531+
532+
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
533533

534534
let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
535535
let pin_adt_ref = tcx.adt_def(pin_did);
536536
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
537537
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
538538

539539
// Replace the by ref coroutine argument
540-
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
540+
body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
541+
542+
let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
541543

542544
// Add the Pin field access to accesses of the coroutine state
543-
SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty))
544-
.visit_body(body);
545+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
546+
547+
let source_info = SourceInfo::outermost(body.span);
548+
body.basic_blocks_mut()[START_BLOCK].statements.insert(
549+
0,
550+
Statement::new(
551+
source_info,
552+
StatementKind::Assign(Box::new((
553+
unpinned_local.into(),
554+
Rvalue::CopyForDeref(tcx.mk_place_field(
555+
SELF_ARG.into(),
556+
FieldIdx::ZERO,
557+
ref_coroutine_ty,
558+
)),
559+
))),
560+
),
561+
);
545562
}
546563

547564
/// Transforms the `body` of the coroutine applying the following transforms:
@@ -1293,8 +1310,6 @@ fn create_coroutine_resume_function<'tcx>(
12931310
let default_block = insert_term_block(body, TerminatorKind::Unreachable);
12941311
insert_switch(body, cases, &transform, default_block);
12951312

1296-
make_coroutine_state_argument_indirect(tcx, body);
1297-
12981313
match transform.coroutine_kind {
12991314
CoroutineKind::Coroutine(_)
13001315
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
@@ -1303,7 +1318,9 @@ fn create_coroutine_resume_function<'tcx>(
13031318
}
13041319
// Iterator::next doesn't accept a pinned argument,
13051320
// unlike for all other coroutine kinds.
1306-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
1321+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1322+
make_coroutine_state_argument_indirect(tcx, body);
1323+
}
13071324
}
13081325
}
13091326

compiler/rustc_mir_transform/src/coroutine/drop.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,12 +685,13 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
685685
let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
686686
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
687687

688-
make_coroutine_state_argument_indirect(tcx, &mut body);
689-
690688
match transform.coroutine_kind {
691689
// Iterator::next doesn't accept a pinned argument,
692690
// unlike for all other coroutine kinds.
693-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
691+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
692+
make_coroutine_state_argument_indirect(tcx, &mut body);
693+
}
694+
694695
_ => {
695696
make_coroutine_state_argument_pinned(tcx, &mut body);
696697
}

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-abort.mir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: isize;
2222
let mut _18: ();
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = deref_copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb8, 3: bb11, 4: bb12, otherwise: bb13];
3133
}
3234

@@ -39,13 +41,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
3941

4042
bb2: {
4143
_0 = Poll::<()>::Pending;
42-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
44+
discriminant((*_20)) = 4;
4345
return;
4446
}
4547

4648
bb3: {
4749
StorageLive(_16);
48-
_15 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
50+
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
4951
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb6, unwind unreachable];
5052
}
5153

@@ -77,7 +79,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
7779
}
7880

7981
bb10: {
80-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb9, unwind unreachable];
82+
drop(((*_20).0: T)) -> [return: bb9, unwind unreachable];
8183
}
8284

8385
bb11: {

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-unwind.mir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: isize;
2222
let mut _18: ();
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = deref_copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb8, 2: bb15, 3: bb13, 4: bb14, otherwise: bb16];
3133
}
3234

@@ -39,13 +41,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
3941

4042
bb2: {
4143
_0 = Poll::<()>::Pending;
42-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
44+
discriminant((*_20)) = 4;
4345
return;
4446
}
4547

4648
bb3: {
4749
StorageLive(_16);
48-
_15 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
50+
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
4951
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb6, unwind: bb12];
5052
}
5153

@@ -81,11 +83,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
8183
}
8284

8385
bb11: {
84-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb10, unwind: bb9];
86+
drop(((*_20).0: T)) -> [return: bb10, unwind: bb9];
8587
}
8688

8789
bb12 (cleanup): {
88-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 2;
90+
discriminant((*_20)) = 2;
8991
resume;
9092
}
9193

tests/mir-opt/building/async_await.a-{closure#0}.coroutine_resume.0.mir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
1414
let mut _0: std::task::Poll<()>;
1515
let mut _3: ();
1616
let mut _4: u32;
17+
let mut _5: &mut {async fn body of a()};
1718

1819
bb0: {
19-
_4 = discriminant((*(_1.0: &mut {async fn body of a()})));
20+
_5 = deref_copy (_1.0: &mut {async fn body of a()});
21+
_4 = discriminant((*_5));
2022
switchInt(move _4) -> [0: bb1, 1: bb9, otherwise: bb10];
2123
}
2224

@@ -27,7 +29,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
2729

2830
bb2: {
2931
_0 = Poll::<()>::Ready(move _3);
30-
discriminant((*(_1.0: &mut {async fn body of a()}))) = 1;
32+
discriminant((*_5)) = 1;
3133
return;
3234
}
3335

tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,25 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
8686
let mut _36: ();
8787
let mut _37: ();
8888
let mut _38: u32;
89+
let mut _39: &mut {async fn body of b()};
8990
scope 1 {
90-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
91+
debug __awaitee => (((*_39) as variant#3).0: {async fn body of a()});
9192
let _17: ();
9293
scope 2 {
9394
debug result => _17;
9495
}
9596
}
9697
scope 3 {
97-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
98+
debug __awaitee => (((*_39) as variant#4).0: {async fn body of a()});
9899
let _33: ();
99100
scope 4 {
100101
debug result => _33;
101102
}
102103
}
103104

104105
bb0: {
105-
_38 = discriminant((*(_1.0: &mut {async fn body of b()})));
106+
_39 = deref_copy (_1.0: &mut {async fn body of b()});
107+
_38 = discriminant((*_39));
106108
switchInt(move _38) -> [0: bb1, 1: bb39, 3: bb37, 4: bb38, otherwise: bb40];
107109
}
108110

@@ -120,7 +122,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
120122
bb3: {
121123
StorageDead(_5);
122124
nop;
123-
(((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()}) = move _4;
125+
(((*_39) as variant#3).0: {async fn body of a()}) = move _4;
124126
goto -> bb4;
125127
}
126128

@@ -130,7 +132,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
130132
StorageLive(_10);
131133
StorageLive(_11);
132134
StorageLive(_12);
133-
_12 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
135+
_12 = &mut (((*_39) as variant#3).0: {async fn body of a()});
134136
_11 = &mut (*_12);
135137
_10 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _11) -> [return: bb5, unwind unreachable];
136138
}
@@ -176,7 +178,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
176178
StorageDead(_4);
177179
StorageDead(_19);
178180
StorageDead(_20);
179-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 3;
181+
discriminant((*_39)) = 3;
180182
return;
181183
}
182184

@@ -189,7 +191,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
189191
StorageDead(_12);
190192
StorageDead(_9);
191193
StorageDead(_8);
192-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
194+
drop((((*_39) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
193195
}
194196

195197
bb11: {
@@ -216,7 +218,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
216218
bb14: {
217219
StorageDead(_22);
218220
nop;
219-
(((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()}) = move _21;
221+
(((*_39) as variant#4).0: {async fn body of a()}) = move _21;
220222
goto -> bb15;
221223
}
222224

@@ -226,7 +228,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
226228
StorageLive(_26);
227229
StorageLive(_27);
228230
StorageLive(_28);
229-
_28 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
231+
_28 = &mut (((*_39) as variant#4).0: {async fn body of a()});
230232
_27 = &mut (*_28);
231233
_26 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _27) -> [return: bb16, unwind unreachable];
232234
}
@@ -267,7 +269,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
267269
StorageDead(_21);
268270
StorageDead(_35);
269271
StorageDead(_36);
270-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 4;
272+
discriminant((*_39)) = 4;
271273
return;
272274
}
273275

@@ -280,7 +282,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
280282
StorageDead(_28);
281283
StorageDead(_25);
282284
StorageDead(_24);
283-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()})) -> [return: bb22, unwind unreachable];
285+
drop((((*_39) as variant#4).0: {async fn body of a()})) -> [return: bb22, unwind unreachable];
284286
}
285287

286288
bb21: {
@@ -299,14 +301,14 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
299301

300302
bb23: {
301303
_0 = Poll::<()>::Ready(move _37);
302-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 1;
304+
discriminant((*_39)) = 1;
303305
return;
304306
}
305307

306308
bb24: {
307309
StorageDead(_36);
308310
StorageDead(_35);
309-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()})) -> [return: bb25, unwind unreachable];
311+
drop((((*_39) as variant#4).0: {async fn body of a()})) -> [return: bb25, unwind unreachable];
310312
}
311313

312314
bb25: {
@@ -318,7 +320,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
318320
bb26: {
319321
StorageDead(_20);
320322
StorageDead(_19);
321-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()})) -> [return: bb27, unwind unreachable];
323+
drop((((*_39) as variant#3).0: {async fn body of a()})) -> [return: bb27, unwind unreachable];
322324
}
323325

324326
bb27: {

0 commit comments

Comments
 (0)