Skip to content

Commit d89e798

Browse files
committed
Prereq6 for async drop - templated coroutine processing and layout
1 parent f9e319c commit d89e798

File tree

12 files changed

+270
-39
lines changed

12 files changed

+270
-39
lines changed

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,7 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
721721
_ => unreachable!(),
722722
};
723723

724-
let coroutine_layout =
725-
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
724+
let coroutine_layout = cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.args).unwrap();
726725

727726
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
728727
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,8 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
174174
DIFlags::FlagZero,
175175
),
176176
|cx, coroutine_type_di_node| {
177-
let coroutine_layout = cx
178-
.tcx
179-
.coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
180-
.unwrap();
177+
let coroutine_layout =
178+
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args).unwrap();
181179

182180
let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
183181
coroutine_type_and_layout.variants

compiler/rustc_middle/src/arena.rs

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ macro_rules! arena_types {
99
($macro:path) => (
1010
$macro!([
1111
[] layout: rustc_abi::LayoutData<rustc_abi::FieldIdx, rustc_abi::VariantIdx>,
12+
[] proxy_coroutine_layout: rustc_middle::mir::CoroutineLayout<'tcx>,
1213
[] fn_abi: rustc_target::callconv::FnAbi<'tcx, rustc_middle::ty::Ty<'tcx>>,
1314
// AdtDef are interned and compared by address
1415
[decode] adt_def: rustc_middle::ty::AdtDefData,

compiler/rustc_middle/src/query/mod.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,13 @@ rustc_queries! {
549549
desc { |tcx| "elaborating drops for `{}`", tcx.def_path_str(key) }
550550
}
551551

552+
query templated_mir_drops_elaborated_and_const_checked(ty: Ty<'tcx>)
553+
-> &'tcx Steal<mir::Body<'tcx>>
554+
{
555+
no_hash
556+
desc { |tcx| "elaborating drops for templated mir `{}`", ty }
557+
}
558+
552559
query mir_for_ctfe(
553560
key: DefId
554561
) -> &'tcx mir::Body<'tcx> {
@@ -614,6 +621,11 @@ rustc_queries! {
614621
feedable
615622
}
616623

624+
/// MIR for templated coroutine after our optimization passes have run.
625+
query templated_optimized_mir(ty: Ty<'tcx>) -> &'tcx mir::Body<'tcx> {
626+
desc { |tcx| "optimizing templated MIR for `{}`", ty }
627+
}
628+
617629
/// Scans through a function's MIR after MIR optimizations, to prepare the
618630
/// information needed by codegen when `-Cinstrument-coverage` is active.
619631
///
@@ -1310,7 +1322,11 @@ rustc_queries! {
13101322
/// Generates a MIR body for the shim.
13111323
query mir_shims(key: ty::InstanceKind<'tcx>) -> &'tcx mir::Body<'tcx> {
13121324
arena_cache
1313-
desc { |tcx| "generating MIR shim for `{}`", tcx.def_path_str(key.def_id()) }
1325+
desc {
1326+
|tcx| "generating MIR shim for `{}`, instance={:?}",
1327+
tcx.def_path_str(key.def_id()),
1328+
key
1329+
}
13141330
}
13151331

13161332
/// The `symbol_name` query provides the symbol name for calling a

compiler/rustc_middle/src/ty/layout.rs

+50-15
Original file line numberDiff line numberDiff line change
@@ -922,23 +922,58 @@ where
922922
i,
923923
),
924924

925-
ty::Coroutine(def_id, args) => match this.variants {
926-
Variants::Empty => unreachable!(),
927-
Variants::Single { index } => TyMaybeWithLayout::Ty(
928-
args.as_coroutine()
929-
.state_tys(def_id, tcx)
930-
.nth(index.as_usize())
931-
.unwrap()
932-
.nth(i)
933-
.unwrap(),
934-
),
935-
Variants::Multiple { tag, tag_field, .. } => {
936-
if i == tag_field {
937-
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
925+
ty::Coroutine(def_id, args) => {
926+
// layout of `async_drop_in_place<T>::{closure}` in case,
927+
// when T is a coroutine, contains this internal coroutine's ref
928+
if tcx.is_templated_coroutine(def_id) {
929+
fn find_impl_coroutine<'tcx>(
930+
tcx: TyCtxt<'tcx>,
931+
mut cor_ty: Ty<'tcx>,
932+
) -> Ty<'tcx> {
933+
let mut ty = cor_ty;
934+
loop {
935+
if let ty::Coroutine(def_id, args) = ty.kind() {
936+
cor_ty = ty;
937+
if tcx.is_templated_coroutine(*def_id) {
938+
ty = args.first().unwrap().expect_ty();
939+
continue;
940+
} else {
941+
return cor_ty;
942+
}
943+
} else {
944+
return cor_ty;
945+
}
946+
}
947+
}
948+
let arg_cor_ty = args.first().unwrap().expect_ty();
949+
if arg_cor_ty.is_coroutine() {
950+
assert!(i == 0);
951+
let impl_cor_ty = find_impl_coroutine(tcx, arg_cor_ty);
952+
return TyMaybeWithLayout::Ty(Ty::new_mut_ref(
953+
tcx,
954+
tcx.lifetimes.re_static,
955+
impl_cor_ty,
956+
));
938957
}
939-
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
940958
}
941-
},
959+
match this.variants {
960+
Variants::Empty => unreachable!(),
961+
Variants::Single { index } => TyMaybeWithLayout::Ty(
962+
args.as_coroutine()
963+
.state_tys(def_id, tcx)
964+
.nth(index.as_usize())
965+
.unwrap()
966+
.nth(i)
967+
.unwrap(),
968+
),
969+
Variants::Multiple { tag, tag_field, .. } => {
970+
if i == tag_field {
971+
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
972+
}
973+
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
974+
}
975+
}
976+
}
942977

943978
ty::Tuple(tys) => TyMaybeWithLayout::Ty(tys[i]),
944979

compiler/rustc_middle/src/ty/mod.rs

+66-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use rustc_hir::LangItem;
3737
use rustc_hir::def::{CtorKind, CtorOf, DefKind, DocLinkResMap, LifetimeRes, Res};
3838
use rustc_hir::def_id::{CrateNum, DefId, DefIdMap, LocalDefId, LocalDefIdMap};
3939
use rustc_index::IndexVec;
40+
use rustc_index::bit_set::BitMatrix;
4041
use rustc_macros::{
4142
Decodable, Encodable, HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable,
4243
extension,
@@ -100,7 +101,7 @@ pub use self::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeV
100101
use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason};
101102
use crate::metadata::ModChild;
102103
use crate::middle::privacy::EffectiveVisibilities;
103-
use crate::mir::{Body, CoroutineLayout};
104+
use crate::mir::{Body, CoroutineLayout, CoroutineSavedLocal, CoroutineSavedTy, SourceInfo};
104105
use crate::query::{IntoQueryParam, Providers};
105106
use crate::ty;
106107
pub use crate::ty::diagnostics::*;
@@ -1742,7 +1743,7 @@ impl<'tcx> TyCtxt<'tcx> {
17421743
| ty::InstanceKind::FnPtrAddrShim(..)
17431744
| ty::InstanceKind::AsyncDropGlueCtorShim(..) => self.mir_shims(instance),
17441745
// async drop glue should be processed specifically, as a templated coroutine
1745-
ty::InstanceKind::AsyncDropGlue(_, _ty) => todo!(),
1746+
ty::InstanceKind::AsyncDropGlue(_, ty) => self.templated_optimized_mir(ty),
17461747
}
17471748
}
17481749

@@ -1853,16 +1854,17 @@ impl<'tcx> TyCtxt<'tcx> {
18531854
self.def_kind(trait_def_id) == DefKind::TraitAlias
18541855
}
18551856

1856-
/// Returns layout of a coroutine. Layout might be unavailable if the
1857+
/// Returns layout of a non-templated coroutine. Layout might be unavailable if the
18571858
/// coroutine is tainted by errors.
18581859
///
18591860
/// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
18601861
/// e.g. `args.as_coroutine().kind_ty()`.
1861-
pub fn coroutine_layout(
1862+
pub fn ordinary_coroutine_layout(
18621863
self,
18631864
def_id: DefId,
18641865
coroutine_kind_ty: Ty<'tcx>,
18651866
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1867+
debug_assert_ne!(Some(def_id), self.lang_items().async_drop_in_place_poll_fn());
18661868
let mir = self.optimized_mir(def_id);
18671869
// Regular coroutine
18681870
if coroutine_kind_ty.is_unit() {
@@ -1892,6 +1894,66 @@ impl<'tcx> TyCtxt<'tcx> {
18921894
}
18931895
}
18941896

1897+
/// Returns layout of a templated coroutine. Layout might be unavailable if the
1898+
/// coroutine is tainted by errors. Atm, the only templated coroutine is
1899+
/// `async_drop_in_place<T>::{closure}` returned from `async fn async_drop_in_place<T>(..)`.
1900+
pub fn templated_coroutine_layout(self, ty: Ty<'tcx>) -> Option<&'tcx CoroutineLayout<'tcx>> {
1901+
self.templated_optimized_mir(ty).coroutine_layout_raw()
1902+
}
1903+
1904+
/// Returns layout of a templated (or not) coroutine. Layout might be unavailable if the
1905+
/// coroutine is tainted by errors.
1906+
pub fn coroutine_layout(
1907+
self,
1908+
def_id: DefId,
1909+
args: GenericArgsRef<'tcx>,
1910+
) -> Option<&'tcx CoroutineLayout<'tcx>> {
1911+
if Some(def_id) == self.lang_items().async_drop_in_place_poll_fn() {
1912+
fn find_impl_coroutine<'tcx>(tcx: TyCtxt<'tcx>, mut cor_ty: Ty<'tcx>) -> Ty<'tcx> {
1913+
let mut ty = cor_ty;
1914+
loop {
1915+
if let ty::Coroutine(def_id, args) = ty.kind() {
1916+
cor_ty = ty;
1917+
if tcx.is_templated_coroutine(*def_id) {
1918+
ty = args.first().unwrap().expect_ty();
1919+
continue;
1920+
} else {
1921+
return cor_ty;
1922+
}
1923+
} else {
1924+
return cor_ty;
1925+
}
1926+
}
1927+
}
1928+
// layout of `async_drop_in_place<T>::{closure}` in case,
1929+
// when T is a coroutine, contains this internal coroutine's ref
1930+
let arg_cor_ty = args.first().unwrap().expect_ty();
1931+
if arg_cor_ty.is_coroutine() {
1932+
let impl_cor_ty = find_impl_coroutine(self, arg_cor_ty);
1933+
let impl_ref = Ty::new_mut_ref(self, self.lifetimes.re_static, impl_cor_ty);
1934+
let span = self.def_span(def_id);
1935+
let source_info = SourceInfo::outermost(span);
1936+
let proxy_layout = CoroutineLayout {
1937+
field_tys: [CoroutineSavedTy {
1938+
ty: impl_ref,
1939+
source_info,
1940+
ignore_for_traits: true,
1941+
}]
1942+
.into(),
1943+
field_names: [None].into(),
1944+
variant_fields: [IndexVec::from([CoroutineSavedLocal::ZERO])].into(),
1945+
variant_source_info: [source_info].into(),
1946+
storage_conflicts: BitMatrix::new(1, 1),
1947+
};
1948+
return Some(self.arena.alloc(proxy_layout));
1949+
} else {
1950+
self.templated_coroutine_layout(Ty::new_coroutine(self, def_id, args))
1951+
}
1952+
} else {
1953+
self.ordinary_coroutine_layout(def_id, args.as_coroutine().kind_ty())
1954+
}
1955+
}
1956+
18951957
/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
18961958
/// If it implements no trait, returns `None`.
18971959
pub fn trait_id_of_impl(self, def_id: DefId) -> Option<DefId> {

compiler/rustc_middle/src/ty/sty.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
7878
#[inline]
7979
fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
8080
// FIXME requires optimized MIR
81-
FIRST_VARIANT
82-
..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
81+
FIRST_VARIANT..tcx.coroutine_layout(def_id, self.args).unwrap().variant_fields.next_index()
8382
}
8483

8584
/// The discriminant for the given variant. Panics if the `variant_index` is
@@ -139,10 +138,14 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
139138
def_id: DefId,
140139
tcx: TyCtxt<'tcx>,
141140
) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
142-
let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
141+
let layout = tcx.coroutine_layout(def_id, self.args).unwrap();
143142
layout.variant_fields.iter().map(move |variant| {
144143
variant.iter().map(move |field| {
145-
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
144+
if tcx.is_templated_coroutine(def_id) {
145+
layout.field_tys[*field].ty
146+
} else {
147+
ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
148+
}
146149
})
147150
})
148151
}

compiler/rustc_mir_dataflow/src/value_analysis.rs

+3
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ impl<'tcx> Map<'tcx> {
407407
if exclude.contains(local) {
408408
continue;
409409
}
410+
if decl.ty.is_templated_coroutine(tcx) {
411+
continue;
412+
}
410413

411414
// Create a place for the local.
412415
debug_assert!(self.locals[local].is_none());

compiler/rustc_mir_transform/src/known_panics_lint.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,11 @@ impl CanConstProp {
888888
};
889889
for (local, val) in cpv.can_const_prop.iter_enumerated_mut() {
890890
let ty = body.local_decls[local].ty;
891-
if ty.is_union() {
891+
if ty.is_templated_coroutine(tcx) {
892+
// No const propagation for templated coroutine (AsyncDropGlue)
893+
*val = ConstPropMode::NoPropagation;
894+
continue;
895+
} else if ty.is_union() {
892896
// Unions are incompatible with the current implementation of
893897
// const prop because Rust has no concept of an active
894898
// variant of a union

compiler/rustc_mir_transform/src/lib.rs

+46-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use rustc_middle::mir::{
3030
MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, START_BLOCK,
3131
SourceInfo, Statement, StatementKind, TerminatorKind,
3232
};
33-
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
33+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
3434
use rustc_middle::util::Providers;
3535
use rustc_middle::{bug, query, span_bug};
3636
use rustc_mir_build::builder::build_mir;
@@ -214,9 +214,11 @@ pub fn provide(providers: &mut Providers) {
214214
mir_const_qualif,
215215
mir_promoted,
216216
mir_drops_elaborated_and_const_checked,
217+
templated_mir_drops_elaborated_and_const_checked,
217218
mir_for_ctfe,
218219
mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses,
219220
optimized_mir,
221+
templated_optimized_mir,
220222
is_mir_available,
221223
is_ctfe_mir_available: is_mir_available,
222224
mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable,
@@ -517,6 +519,21 @@ fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &
517519
tcx.alloc_steal_mir(body)
518520
}
519521

522+
/// mir_drops_elaborated_and_const_checked simplified analog for templated coroutine
523+
fn templated_mir_drops_elaborated_and_const_checked<'tcx>(
524+
tcx: TyCtxt<'tcx>,
525+
ty: Ty<'tcx>,
526+
) -> &'tcx Steal<Body<'tcx>> {
527+
let ty::Coroutine(def_id, _) = ty.kind() else {
528+
bug!();
529+
};
530+
assert!(ty.is_templated_coroutine(tcx));
531+
532+
let instance = ty::InstanceKind::AsyncDropGlue(*def_id, ty);
533+
let body = tcx.mir_shims(instance).clone();
534+
tcx.alloc_steal_mir(body)
535+
}
536+
520537
// Made public so that `mir_drops_elaborated_and_const_checked` can be overridden
521538
// by custom rustc drivers, running all the steps by themselves. See #114628.
522539
pub fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -720,6 +737,11 @@ fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> {
720737
tcx.arena.alloc(inner_optimized_mir(tcx, did))
721738
}
722739

740+
/// Optimize the templated MIR and prepare it for codegen.
741+
fn templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> &'tcx Body<'tcx> {
742+
tcx.arena.alloc(inner_templated_optimized_mir(tcx, ty))
743+
}
744+
723745
fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
724746
if tcx.is_constructor(did.to_def_id()) {
725747
// There's no reason to run all of the MIR passes on constructors when
@@ -764,6 +786,29 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
764786
body
765787
}
766788

789+
fn inner_templated_optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Body<'tcx> {
790+
debug!("about to call templated_mir_drops_elaborated...");
791+
let body = tcx.templated_mir_drops_elaborated_and_const_checked(ty).steal();
792+
let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst);
793+
794+
if body.tainted_by_errors.is_some() {
795+
return body;
796+
}
797+
798+
// If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable
799+
// predicates, it will shrink the MIR to a single `unreachable` terminator.
800+
// More generally, if MIR is a lone `unreachable`, there is nothing to optimize.
801+
if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind
802+
&& body.basic_blocks[START_BLOCK].statements.is_empty()
803+
{
804+
return body;
805+
}
806+
807+
run_optimization_passes(tcx, &mut body);
808+
809+
body
810+
}
811+
767812
/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for
768813
/// constant evaluation once all generic parameters become known.
769814
fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> {

0 commit comments

Comments
 (0)