Skip to content

Commit 1a68b6f

Browse files
relocate upvars to Unresumed state and make coroutine prefix trivial
Co-authored-by: Dario Nieuwenhuis <[email protected]>
1 parent 01e2fff commit 1a68b6f

File tree

55 files changed

+995
-430
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+995
-430
lines changed

Diff for: compiler/rustc_borrowck/src/lib.rs

+20-4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use std::ops::Deref;
2323
use consumers::{BodyWithBorrowckFacts, ConsumerOptions};
2424
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
2525
use rustc_data_structures::graph::dominators::Dominators;
26+
use rustc_data_structures::unord::UnordMap;
2627
use rustc_errors::Diag;
2728
use rustc_hir as hir;
2829
use rustc_hir::def_id::LocalDefId;
@@ -287,6 +288,7 @@ fn do_mir_borrowck<'tcx>(
287288
regioncx: &regioncx,
288289
used_mut: Default::default(),
289290
used_mut_upvars: SmallVec::new(),
291+
local_from_upvars: UnordMap::default(),
290292
borrow_set: &borrow_set,
291293
upvars: &[],
292294
local_names: IndexVec::from_elem(None, &promoted_body.local_decls),
@@ -313,6 +315,12 @@ fn do_mir_borrowck<'tcx>(
313315
}
314316
}
315317

318+
let mut local_from_upvars = UnordMap::default();
319+
for (field, &local) in body.local_upvar_map.iter_enumerated() {
320+
let Some(local) = local else { continue };
321+
local_from_upvars.insert(local, field);
322+
}
323+
debug!(?local_from_upvars, "dxf");
316324
let mut mbcx = MirBorrowckCtxt {
317325
infcx: &infcx,
318326
param_env,
@@ -328,6 +336,7 @@ fn do_mir_borrowck<'tcx>(
328336
regioncx: &regioncx,
329337
used_mut: Default::default(),
330338
used_mut_upvars: SmallVec::new(),
339+
local_from_upvars,
331340
borrow_set: &borrow_set,
332341
upvars: tcx.closure_captures(def),
333342
local_names,
@@ -563,6 +572,9 @@ struct MirBorrowckCtxt<'a, 'infcx, 'tcx> {
563572
/// If the function we're checking is a closure, then we'll need to report back the list of
564573
/// mutable upvars that have been used. This field keeps track of them.
565574
used_mut_upvars: SmallVec<[FieldIdx; 8]>,
575+
/// Since upvars are moved to real locals, we need to map mutations to the locals back to
576+
/// the upvars, so that used_mut_upvars is up-to-date.
577+
local_from_upvars: UnordMap<Local, FieldIdx>,
566578
/// Region inference context. This contains the results from region inference and lets us e.g.
567579
/// find out which CFG points are contained in each borrow region.
568580
regioncx: &'a RegionInferenceContext<'tcx>,
@@ -2218,10 +2230,12 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
22182230
// If the local may have been initialized, and it is now currently being
22192231
// mutated, then it is justified to be annotated with the `mut`
22202232
// keyword, since the mutation may be a possible reassignment.
2221-
if is_local_mutation_allowed != LocalMutationIsAllowed::Yes
2222-
&& self.is_local_ever_initialized(local, state).is_some()
2223-
{
2224-
self.used_mut.insert(local);
2233+
if !matches!(is_local_mutation_allowed, LocalMutationIsAllowed::Yes) {
2234+
if self.is_local_ever_initialized(local, state).is_some() {
2235+
self.used_mut.insert(local);
2236+
} else if let Some(&field) = self.local_from_upvars.get(&local) {
2237+
self.used_mut_upvars.push(field);
2238+
}
22252239
}
22262240
}
22272241
RootPlace {
@@ -2239,6 +2253,8 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
22392253
projection: place_projection,
22402254
}) {
22412255
self.used_mut_upvars.push(field);
2256+
} else if let Some(&field) = self.local_from_upvars.get(&place_local) {
2257+
self.used_mut_upvars.push(field);
22422258
}
22432259
}
22442260
}

Diff for: compiler/rustc_borrowck/src/type_check/mod.rs

+12-12
Original file line numberDiff line numberDiff line change
@@ -812,15 +812,15 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
812812
}),
813813
};
814814
}
815-
ty::Coroutine(_, args) => {
815+
ty::Coroutine(_def_id, args) => {
816816
// Only prefix fields (upvars and current state) are
817817
// accessible without a variant index.
818-
return match args.as_coroutine().prefix_tys().get(field.index()) {
819-
Some(ty) => Ok(*ty),
820-
None => Err(FieldAccessError::OutOfRange {
821-
field_count: args.as_coroutine().prefix_tys().len(),
822-
}),
823-
};
818+
let upvar_tys = args.as_coroutine().upvar_tys();
819+
if let Some(ty) = upvar_tys.get(field.index()) {
820+
return Ok(*ty);
821+
} else {
822+
return Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() });
823+
}
824824
}
825825
ty::Tuple(tys) => {
826826
return match tys.get(field.index()) {
@@ -1838,11 +1838,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
18381838
// It doesn't make sense to look at a field beyond the prefix;
18391839
// these require a variant index, and are not initialized in
18401840
// aggregate rvalues.
1841-
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
1842-
Some(ty) => Ok(*ty),
1843-
None => Err(FieldAccessError::OutOfRange {
1844-
field_count: args.as_coroutine().prefix_tys().len(),
1845-
}),
1841+
let upvar_tys = &args.as_coroutine().upvar_tys();
1842+
if let Some(ty) = upvar_tys.get(field_index.as_usize()) {
1843+
Ok(*ty)
1844+
} else {
1845+
Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() })
18461846
}
18471847
}
18481848
AggregateKind::CoroutineClosure(_, args) => {

Diff for: compiler/rustc_codegen_cranelift/src/base.rs

+3
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,9 @@ fn codegen_stmt<'tcx>(
891891
let variant_dest = lval.downcast_variant(fx, variant_index);
892892
(variant_index, variant_dest, active_field_index)
893893
}
894+
mir::AggregateKind::Coroutine(_def_id, _args) => {
895+
(FIRST_VARIANT, lval.downcast_variant(fx, FIRST_VARIANT), None)
896+
}
894897
_ => (FIRST_VARIANT, lval, None),
895898
};
896899
if active_field_index.is_some() {

Diff for: compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ use rustc_hir::def_id::{DefId, LOCAL_CRATE};
1313
use rustc_middle::bug;
1414
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1515
use rustc_middle::ty::{
16-
self, AdtKind, CoroutineArgsExt, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt,
17-
Visibility,
16+
self, AdtKind, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt, Visibility,
1817
};
1918
use rustc_session::config::{self, DebugInfo, Lto};
2019
use rustc_span::symbol::Symbol;
@@ -1124,7 +1123,7 @@ fn build_upvar_field_di_nodes<'ll, 'tcx>(
11241123
closure_or_coroutine_di_node: &'ll DIType,
11251124
) -> SmallVec<&'ll DIType> {
11261125
let (&def_id, up_var_tys) = match closure_or_coroutine_ty.kind() {
1127-
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().prefix_tys()),
1126+
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().upvar_tys()),
11281127
ty::Closure(def_id, args) => (def_id, args.as_closure().upvar_tys()),
11291128
ty::CoroutineClosure(def_id, args) => (def_id, args.as_coroutine_closure().upvar_tys()),
11301129
_ => {

Diff for: compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs

-2
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
672672
let coroutine_layout =
673673
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
674674

675-
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
676675
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
677676
let variant_count = (variant_range.start.as_u32()..variant_range.end.as_u32()).len();
678677

@@ -707,7 +706,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
707706
coroutine_type_and_layout,
708707
coroutine_type_di_node,
709708
coroutine_layout,
710-
common_upvar_names,
711709
);
712710

713711
let span = coroutine_layout.variant_source_info[variant_index].span;

Diff for: compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs

+2-31
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@ use std::borrow::Cow;
33
use rustc_codegen_ssa::debuginfo::type_names::{compute_debuginfo_type_name, cpp_like_debuginfo};
44
use rustc_codegen_ssa::debuginfo::{tag_base_type, wants_c_like_enum_debuginfo};
55
use rustc_hir::def::CtorKind;
6-
use rustc_index::IndexSlice;
76
use rustc_middle::bug;
87
use rustc_middle::mir::CoroutineLayout;
98
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
109
use rustc_middle::ty::{self, AdtDef, CoroutineArgs, CoroutineArgsExt, Ty, VariantDef};
11-
use rustc_span::Symbol;
1210
use rustc_target::abi::{FieldIdx, TagEncoding, VariantIdx, Variants};
1311

1412
use super::type_map::{DINodeCreationResult, UniqueTypeId};
@@ -263,7 +261,6 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
263261
coroutine_type_and_layout: TyAndLayout<'tcx>,
264262
coroutine_type_di_node: &'ll DIType,
265263
coroutine_layout: &CoroutineLayout<'tcx>,
266-
common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
267264
) -> &'ll DIType {
268265
let variant_name = CoroutineArgs::variant_name(variant_index);
269266
let unique_type_id = UniqueTypeId::for_enum_variant_struct_type(
@@ -274,11 +271,6 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
274271

275272
let variant_layout = coroutine_type_and_layout.for_variant(cx, variant_index);
276273

277-
let coroutine_args = match coroutine_type_and_layout.ty.kind() {
278-
ty::Coroutine(_, args) => args.as_coroutine(),
279-
_ => unreachable!(),
280-
};
281-
282274
type_map::build_type_with_children(
283275
cx,
284276
type_map::stub(
@@ -292,7 +284,7 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
292284
),
293285
|cx, variant_struct_type_di_node| {
294286
// Fields that just belong to this variant/state
295-
let state_specific_fields: SmallVec<_> = (0..variant_layout.fields.count())
287+
(0..variant_layout.fields.count())
296288
.map(|field_index| {
297289
let coroutine_saved_local = coroutine_layout.variant_fields[variant_index]
298290
[FieldIdx::from_usize(field_index)];
@@ -314,28 +306,7 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
314306
type_di_node(cx, field_type),
315307
)
316308
})
317-
.collect();
318-
319-
// Fields that are common to all states
320-
let common_fields: SmallVec<_> = coroutine_args
321-
.prefix_tys()
322-
.iter()
323-
.zip(common_upvar_names)
324-
.enumerate()
325-
.map(|(index, (upvar_ty, upvar_name))| {
326-
build_field_di_node(
327-
cx,
328-
variant_struct_type_di_node,
329-
upvar_name.as_str(),
330-
cx.size_and_align_of(upvar_ty),
331-
coroutine_type_and_layout.fields.offset(index),
332-
DIFlags::FlagZero,
333-
type_di_node(cx, upvar_ty),
334-
)
335-
})
336-
.collect();
337-
338-
state_specific_fields.into_iter().chain(common_fields).collect()
309+
.collect()
339310
},
340311
|cx| build_generic_type_param_di_nodes(cx, coroutine_type_and_layout.ty),
341312
)

Diff for: compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs

-4
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
160160
)
161161
};
162162

163-
let common_upvar_names =
164-
cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
165-
166163
// Build variant struct types
167164
let variant_struct_type_di_nodes: SmallVec<_> = variants
168165
.indices()
@@ -190,7 +187,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
190187
coroutine_type_and_layout,
191188
coroutine_type_di_node,
192189
coroutine_layout,
193-
common_upvar_names,
194190
),
195191
source_info,
196192
}

Diff for: compiler/rustc_codegen_ssa/src/mir/operand.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
1010
use rustc_middle::mir::{self, ConstValue};
1111
use rustc_middle::ty::Ty;
1212
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
13-
use tracing::debug;
13+
use tracing::{debug, instrument};
1414

1515
use super::place::{PlaceRef, PlaceValue};
1616
use super::{FunctionCx, LocalRef};
@@ -551,13 +551,12 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
551551
}
552552

553553
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
554+
#[instrument(level = "debug", skip(self, bx), ret)]
554555
fn maybe_codegen_consume_direct(
555556
&mut self,
556557
bx: &mut Bx,
557558
place_ref: mir::PlaceRef<'tcx>,
558559
) -> Option<OperandRef<'tcx, Bx::Value>> {
559-
debug!("maybe_codegen_consume_direct(place_ref={:?})", place_ref);
560-
561560
match self.locals[place_ref.local] {
562561
LocalRef::Operand(mut o) => {
563562
// Moves out of scalar and scalar pair fields are trivial.
@@ -600,13 +599,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
600599
}
601600
}
602601

602+
#[instrument(level = "debug", skip(self, bx), ret)]
603603
pub fn codegen_consume(
604604
&mut self,
605605
bx: &mut Bx,
606606
place_ref: mir::PlaceRef<'tcx>,
607607
) -> OperandRef<'tcx, Bx::Value> {
608-
debug!("codegen_consume(place_ref={:?})", place_ref);
609-
610608
let ty = self.monomorphized_place_ty(place_ref);
611609
let layout = bx.cx().layout_of(ty);
612610

@@ -625,13 +623,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
625623
bx.load_operand(place)
626624
}
627625

626+
#[instrument(level = "debug", skip(self, bx), ret)]
628627
pub fn codegen_operand(
629628
&mut self,
630629
bx: &mut Bx,
631630
operand: &mir::Operand<'tcx>,
632631
) -> OperandRef<'tcx, Bx::Value> {
633-
debug!("codegen_operand(operand={:?})", operand);
634-
635632
match *operand {
636633
mir::Operand::Copy(ref place) | mir::Operand::Move(ref place) => {
637634
self.codegen_consume(bx, place.as_ref())

Diff for: compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+3
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
130130
let variant_dest = dest.project_downcast(bx, variant_index);
131131
(variant_index, variant_dest, active_field_index)
132132
}
133+
mir::AggregateKind::Coroutine(_, _) => {
134+
(FIRST_VARIANT, dest.project_downcast(bx, FIRST_VARIANT), None)
135+
}
133136
_ => (FIRST_VARIANT, dest, None),
134137
};
135138
if active_field_index.is_some() {

Diff for: compiler/rustc_const_eval/src/interpret/step.rs

+3
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
294294
let variant_dest = self.project_downcast(dest, variant_index)?;
295295
(variant_index, variant_dest, active_field_index)
296296
}
297+
mir::AggregateKind::Coroutine(_def_id, _args) => {
298+
(FIRST_VARIANT, self.project_downcast(dest, FIRST_VARIANT)?, None)
299+
}
297300
mir::AggregateKind::RawPtr(..) => {
298301
// Pointers don't have "fields" in the normal sense, so the
299302
// projection-based code below would either fail in projection

Diff for: compiler/rustc_middle/src/mir/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,9 @@ pub struct Body<'tcx> {
368368
/// If `-Cinstrument-coverage` is not active, or if an individual function
369369
/// is not eligible for coverage, then this should always be `None`.
370370
pub function_coverage_info: Option<Box<coverage::FunctionCoverageInfo>>,
371+
372+
/// Coroutine local-upvar map
373+
pub local_upvar_map: IndexVec<FieldIdx, Option<Local>>,
371374
}
372375

373376
impl<'tcx> Body<'tcx> {
@@ -411,6 +414,7 @@ impl<'tcx> Body<'tcx> {
411414
tainted_by_errors,
412415
coverage_info_hi: None,
413416
function_coverage_info: None,
417+
local_upvar_map: IndexVec::new(),
414418
};
415419
body.is_polymorphic = body.has_non_region_param();
416420
body
@@ -442,6 +446,7 @@ impl<'tcx> Body<'tcx> {
442446
tainted_by_errors: None,
443447
coverage_info_hi: None,
444448
function_coverage_info: None,
449+
local_upvar_map: IndexVec::new(),
445450
};
446451
body.is_polymorphic = body.has_non_region_param();
447452
body

Diff for: compiler/rustc_middle/src/mir/patch.rs

+4
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,14 @@ impl<'tcx> MirPatch<'tcx> {
155155
ty: Ty<'tcx>,
156156
span: Span,
157157
local_info: LocalInfo<'tcx>,
158+
immutable: bool,
158159
) -> Local {
159160
let index = self.next_local;
160161
self.next_local += 1;
161162
let mut new_decl = LocalDecl::new(ty, span);
163+
if immutable {
164+
new_decl = new_decl.immutable();
165+
}
162166
**new_decl.local_info.as_mut().assert_crate_local() = local_info;
163167
self.new_locals.push(new_decl);
164168
Local::new(index)

Diff for: compiler/rustc_middle/src/ty/layout.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -891,9 +891,10 @@ where
891891
),
892892
Variants::Multiple { tag, tag_field, .. } => {
893893
if i == tag_field {
894-
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
894+
TyMaybeWithLayout::TyAndLayout(tag_layout(tag))
895+
} else {
896+
TyMaybeWithLayout::Ty(args.as_coroutine().upvar_tys()[i])
895897
}
896-
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
897898
}
898899
},
899900

Diff for: compiler/rustc_middle/src/ty/sty.rs

-7
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,6 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
148148
})
149149
})
150150
}
151-
152-
/// This is the types of the fields of a coroutine which are not stored in a
153-
/// variant.
154-
#[inline]
155-
fn prefix_tys(self) -> &'tcx List<Ty<'tcx>> {
156-
self.upvar_tys()
157-
}
158151
}
159152

160153
#[derive(Debug, Copy, Clone, HashStable, TypeFoldable, TypeVisitable)]

Diff for: compiler/rustc_mir_build/src/build/custom/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub(super) fn build_custom_mir<'tcx>(
6262
pass_count: 0,
6363
coverage_info_hi: None,
6464
function_coverage_info: None,
65+
local_upvar_map: IndexVec::new(),
6566
};
6667

6768
body.local_decls.push(LocalDecl::new(return_ty, return_ty_span));

0 commit comments

Comments
 (0)