Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,10 +927,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::LangItem::PinNewUnchecked,
arena_vec![self; ref_mut_awaitee],
);
let get_context = self.expr_call_lang_item_fn_mut(
let get_context = self.expr(
gen_future_span,
hir::LangItem::GetContext,
arena_vec![self; task_context],
hir::ExprKind::UnsafeBinderCast(
UnsafeBinderCastKind::Unwrap,
self.arena.alloc(task_context),
None,
),
);
let call = match await_kind {
FutureKind::Future => self.expr_call_lang_item_fn(
Expand Down
10 changes: 6 additions & 4 deletions compiler/rustc_borrowck/src/places_conflict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ fn place_projection_conflict<'tcx>(
debug!("place_element_conflict: DISJOINT-OR-EQ-OPAQUE");
Overlap::EqualOrDisjoint
}
(ProjectionElem::UnwrapUnsafeBinder(_), ProjectionElem::UnwrapUnsafeBinder(_)) => {
// casts to other types may always conflict irrespective of the type being cast to.
debug!("place_element_conflict: DISJOINT-OR-EQ-OPAQUE");
Overlap::EqualOrDisjoint
}
(ProjectionElem::Field(f1, _), ProjectionElem::Field(f2, _)) => {
if f1 == f2 {
// same field (e.g., `a.y` vs. `a.y`) - recur.
Expand Down Expand Up @@ -512,6 +517,7 @@ fn place_projection_conflict<'tcx>(
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subtype(_)
| ProjectionElem::OpaqueCast { .. }
| ProjectionElem::UnwrapUnsafeBinder { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::Downcast(..),
_,
Expand All @@ -520,9 +526,5 @@ fn place_projection_conflict<'tcx>(
pi1_elem,
pi2_elem
),

(ProjectionElem::UnwrapUnsafeBinder(_), _) => {
todo!()
}
}
}
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ pub(crate) fn spanned_type_di_node<'ll, 'tcx>(
AdtKind::Enum => enums::build_enum_type_di_node(cx, unique_type_id, span),
},
ty::Tuple(_) => build_tuple_type_di_node(cx, unique_type_id),
ty::UnsafeBinder(binder) => {
return type_di_node(cx, cx.tcx.instantiate_bound_regions_with_erased(*binder));
}
_ => bug!("debuginfo: unexpected type in type_di_node(): {:?}", t),
};

Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,7 @@ language_item_table! {

// FIXME(swatinem): the following lang items are used for async lowering and
// should become obsolete eventually.
ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;
ResumeTy, sym::ResumeTy, resume_ty, Target::TyAlias, GenericRequirement::None;

Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
Expand Down
23 changes: 23 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,29 @@ impl<'tcx> Ty<'tcx> {
Ty::new_generic_adt(tcx, def_id, ty)
}

/// Creates a `unsafe<'a, 'b> &'a mut Context<'b>` [`Ty`].
pub fn new_resume_ty(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let context_did = tcx.require_lang_item(LangItem::Context, DUMMY_SP);
let context_adt_ref = tcx.adt_def(context_did);

let lt = |n| {
ty::Region::new_bound(
tcx,
ty::INNERMOST,
ty::BoundRegion { var: ty::BoundVar::from_u32(n), kind: BoundRegionKind::Anon },
)
};

let context_args = tcx.mk_args(&[lt(1).into()]);
let context_ty = Ty::new_adt(tcx, context_adt_ref, context_args);
let context_mut_ref = Ty::new_mut_ref(tcx, lt(0), context_ty);
let bound_vars = tcx.mk_bound_variable_kinds(&[
BoundVariableKind::Region(BoundRegionKind::Anon),
BoundVariableKind::Region(BoundRegionKind::Anon),
]);
Ty::new_unsafe_binder(tcx, ty::Binder::bind_with_vars(context_mut_ref, bound_vars))
}

/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
pub fn new_task_context(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let context_did = tcx.require_lang_item(LangItem::Context, DUMMY_SP);
Expand Down
130 changes: 20 additions & 110 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ struct TransformVisitor<'tcx> {
old_yield_ty: Ty<'tcx>,

old_ret_ty: Ty<'tcx>,

/// The rvalue that should be assigned to yield `resume_arg` place.
resume_rvalue: Rvalue<'tcx>,
}

impl<'tcx> TransformVisitor<'tcx> {
Expand Down Expand Up @@ -533,100 +536,6 @@ fn replace_local<'tcx>(
new_local
}

/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
/// See <https://github.com/rust-lang/rust/issues/105501>.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
let context_mut_ref = Ty::new_task_context(tcx);

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);

for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
}
}
context_mut_ref
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
bug!();
};
let [arg] = *Box::try_from(args).unwrap();
let local = arg.node.place().unwrap().local;

let arg = Rvalue::Use(arg.node);
let assign =
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
}

#[cfg_attr(not(debug_assertions), allow(unused))]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

/// Transforms the `body` of the coroutine applying the following transform:
///
/// - Remove the `resume` argument.
Expand Down Expand Up @@ -1342,12 +1251,11 @@ fn create_cases<'tcx>(

if operation == Operation::Resume {
// Move the resume argument to the destination place of the `Yield` terminator
let resume_arg = CTX_ARG;
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
point.resume_arg,
Rvalue::Use(Operand::Move(resume_arg.into())),
transform.resume_rvalue.clone(),
))),
));
}
Expand Down Expand Up @@ -1504,12 +1412,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
) && has_expandable_async_drops(tcx, body, coroutine_ty);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) {
let context_mut_ref = transform_async_context(tcx, body);
expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty);
if has_async_drops {
expand_async_drops(tcx, body, coroutine_kind, coroutine_ty);

if let Some(dumper) = MirDumper::new(tcx, "coroutine_async_drop_expand", body) {
dumper.dump_mir(body);
Expand All @@ -1522,22 +1426,27 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// This is needed because the resume argument `_2` might be live across a `yield`, in which
// case there is no `Assign` to it that the transform can turn into a store to the coroutine
// state. After the yield the slot in the coroutine state would then be uninitialized.
let resume_local = CTX_ARG;
let resume_ty = body.local_decls[resume_local].ty;
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
let resume_ty = body.local_decls[CTX_ARG].ty;
let old_resume_local = replace_local(CTX_ARG, resume_ty, body, tcx);

// When first entering the coroutine, move the resume argument into its old local
// (which is now a generator interior).
let source_info = SourceInfo::outermost(body.span);
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
let stmts = &mut body.basic_blocks.as_mut()[START_BLOCK].statements;
let resume_rvalue = if matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) {
body.local_decls[CTX_ARG].ty = Ty::new_task_context(tcx);
Rvalue::WrapUnsafeBinder(Operand::Move(CTX_ARG.into()), resume_ty)
} else {
Rvalue::Use(Operand::Move(CTX_ARG.into()))
};
stmts.insert(
0,
Statement::new(
source_info,
StatementKind::Assign(Box::new((
old_resume_local.into(),
Rvalue::Use(Operand::Move(resume_local.into())),
))),
StatementKind::Assign(Box::new((old_resume_local.into(), resume_rvalue.clone()))),
),
);

Expand Down Expand Up @@ -1580,6 +1489,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
discr_ty,
old_ret_ty,
old_yield_ty,
resume_rvalue,
};
transform.visit_body(body);

Expand Down
16 changes: 10 additions & 6 deletions compiler/rustc_mir_transform/src/coroutine/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,11 @@ pub(super) fn has_expandable_async_drops<'tcx>(
pub(super) fn expand_async_drops<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
context_mut_ref: Ty<'tcx>,
coroutine_kind: hir::CoroutineKind,
coroutine_ty: Ty<'tcx>,
) {
let resume_ty = Ty::new_resume_ty(tcx);
let context_mut_ref = Ty::new_task_context(tcx);
let dropline = gather_dropline_blocks(body);
// Clean drop and async_fut fields if potentially async drop is not expanded (stays sync)
let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
Expand Down Expand Up @@ -323,8 +324,8 @@ pub(super) fn expand_async_drops<'tcx>(

// First state-loop yield for mainline
let context_ref_place =
Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)));
let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
Place::from(body.local_decls.push(LocalDecl::new(resume_ty, source_info.span)));
let arg = Rvalue::Use(Operand::Move(CTX_ARG.into()));
body[bb].statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((context_ref_place, arg))),
Expand Down Expand Up @@ -358,8 +359,11 @@ pub(super) fn expand_async_drops<'tcx>(
let mut dropline_context_ref: Option<Place<'_>> = None;
let mut dropline_call_bb: Option<BasicBlock> = None;
if !is_dropline_bb {
let context_ref_place2: Place<'_> = Place::from(
body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)),
let context_ref_local2 =
body.local_decls.push(LocalDecl::new(resume_ty, source_info.span));
let context_ref_place2 = tcx.mk_place_elem(
context_ref_local2.into(),
PlaceElem::UnwrapUnsafeBinder(context_mut_ref),
);
let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); // `kind` replaced later to yield
let (pin_bb2, fut_pin_place2) =
Expand All @@ -385,7 +389,7 @@ pub(super) fn expand_async_drops<'tcx>(
);
dropline_transition_bb = Some(pin_bb2);
dropline_yield_bb = Some(drop_yield_block);
dropline_context_ref = Some(context_ref_place2);
dropline_context_ref = Some(context_ref_local2.into());
dropline_call_bb = Some(drop_call_bb);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ pub(super) fn build_async_drop_shim<'tcx>(
let needs_async_drop = drop_ty.needs_async_drop(tcx, typing_env);
let needs_sync_drop = !needs_async_drop && drop_ty.needs_drop(tcx, typing_env);

let resume_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, DUMMY_SP));
let resume_ty = Ty::new_adt(tcx, resume_adt, ty::List::empty());
let resume_ty = Ty::new_resume_ty(tcx);

let fn_sig = ty::Binder::dummy(tcx.mk_fn_sig(
[ty, resume_ty],
Expand Down
24 changes: 7 additions & 17 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,11 @@ fn fn_sig_for_fn_abi<'tcx>(
// with `&mut Context<'_>` which is used in codegen.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
let expected_adt =
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, DUMMY_SP));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
};
let resume_ty = Ty::new_resume_ty(tcx);
assert_eq!(resume_ty, sig.resume_ty);
}
let context_mut_ref = Ty::new_task_context(tcx);

let context_mut_ref = Ty::new_task_context(tcx);
(Some(context_mut_ref), ret_ty)
}
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
Expand All @@ -191,20 +186,15 @@ fn fn_sig_for_fn_abi<'tcx>(
// Yield type is already `Poll<Option<yield_ty>>`
let ret_ty = sig.yield_ty;

// We have to replace the `ResumeTy` that is used for type and borrow checking
// We have to replace the `{:?}` that is used for type and borrow checking
// with `&mut Context<'_>` which is used in codegen.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
let expected_adt =
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, DUMMY_SP));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
};
let resume_ty = Ty::new_resume_ty(tcx);
assert_eq!(resume_ty, sig.resume_ty);
}
let context_mut_ref = Ty::new_task_context(tcx);

let context_mut_ref = Ty::new_task_context(tcx);
(Some(context_mut_ref), ret_ty)
}
hir::CoroutineKind::Coroutine(_) => {
Expand Down
Loading
Loading