Skip to content

Commit 2c5b029

Browse files
committed
Auto merge of rust-lang#117703 - compiler-errors:recursive-async, r=<try>
Support async recursive calls (as long as they have indirection) Before rust-lang#101692, we stored coroutine witness types directly inside of the coroutine. That means that a coroutine could not contain itself (as a witness field) without creating a cycle in the type representation of the coroutine, which we detected with the `OpaqueTypeExpander`, which is used to detect cycles when expanding opaque types after that are inferred to contain themselves. After `-Zdrop-tracking-mir` was stabilized, we no longer store these generator witness fields directly, but instead behind a def-id based query. That means there is no technical obstacle in the compiler preventing coroutines from containing themselves per se, other than the fact that for a coroutine to have a non-infinite layout, it must contain itself wrapped in a layer of allocation indirection (like a `Box`). This means that it should be valid for this code to work: ``` async fn async_fibonacci(i: u32) -> u32 { if i == 0 || i == 1 { i } else { Box::pin(async_fibonacci(i - 1)).await + Box::pin(async_fibonacci(i - 2)).await } } ``` Whereas previously, you'd need to coerce the future to `Pin<Box<dyn Future<Output = ...>>` before `await`ing it, to prevent the async's desugared coroutine from containing itself across as await point. This PR does two things: 1. Remove the behavior from `OpaqueTypeExpander` where it intentionally fetches and walks through the coroutine's witness fields. This was kept around after `-Zdrop-tracking-mir` was stabilized so we would not be introducing new stable behavior, and to preserve the much better diagnostics of async recursion compared to a layout error. 2. Reworks the way we report layout errors having to do with coroutines, to make up for the diagnostic regressions introduced by (1.). We actually do even better now, pointing out the call sites of the recursion!
2 parents 0828c15 + 19eb35e commit 2c5b029

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+310
-189
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
792792
// debuggers and debugger extensions expect it to be called `__awaitee`. They use
793793
// this name to identify what is being awaited by a suspended async functions.
794794
let awaitee_ident = Ident::with_dummy_span(sym::__awaitee);
795-
let (awaitee_pat, awaitee_pat_hid) =
796-
self.pat_ident_binding_mode(span, awaitee_ident, hir::BindingAnnotation::MUT);
795+
let (awaitee_pat, awaitee_pat_hid) = self.pat_ident_binding_mode(
796+
gen_future_span,
797+
awaitee_ident,
798+
hir::BindingAnnotation::MUT,
799+
);
797800

798801
let task_context_ident = Ident::with_dummy_span(sym::_task_context);
799802

compiler/rustc_hir/src/hir.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,15 @@ pub enum CoroutineKind {
15291529
Coroutine,
15301530
}
15311531

1532+
impl CoroutineKind {
1533+
pub fn is_fn_like(self) -> bool {
1534+
matches!(
1535+
self,
1536+
CoroutineKind::Async(CoroutineSource::Fn) | CoroutineKind::Gen(CoroutineSource::Fn)
1537+
)
1538+
}
1539+
}
1540+
15321541
impl fmt::Display for CoroutineKind {
15331542
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15341543
match self {

compiler/rustc_hir_analysis/src/check/check.rs

+29-23
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use rustc_middle::middle::stability::EvalResult;
1818
use rustc_middle::traits::{DefiningAnchor, ObligationCauseCode};
1919
use rustc_middle::ty::fold::BottomUpFolder;
2020
use rustc_middle::ty::layout::{LayoutError, MAX_SIMD_LANES};
21-
use rustc_middle::ty::util::{Discr, IntTypeExt};
21+
use rustc_middle::ty::util::{Discr, InspectCoroutineFields, IntTypeExt};
2222
use rustc_middle::ty::GenericArgKind;
2323
use rustc_middle::ty::{
2424
self, AdtDef, ParamEnv, RegionKind, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable,
@@ -216,13 +216,12 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
216216
return;
217217
}
218218

219-
let args = GenericArgs::identity_for_item(tcx, item.owner_id);
220219
let span = tcx.def_span(item.owner_id.def_id);
221220

222221
if tcx.type_of(item.owner_id.def_id).instantiate_identity().references_error() {
223222
return;
224223
}
225-
if check_opaque_for_cycles(tcx, item.owner_id.def_id, args, span, &origin).is_err() {
224+
if check_opaque_for_cycles(tcx, item.owner_id.def_id, span).is_err() {
226225
return;
227226
}
228227

@@ -233,19 +232,36 @@ fn check_opaque(tcx: TyCtxt<'_>, id: hir::ItemId) {
233232
pub(super) fn check_opaque_for_cycles<'tcx>(
234233
tcx: TyCtxt<'tcx>,
235234
def_id: LocalDefId,
236-
args: GenericArgsRef<'tcx>,
237235
span: Span,
238-
origin: &hir::OpaqueTyOrigin,
239236
) -> Result<(), ErrorGuaranteed> {
240-
if tcx.try_expand_impl_trait_type(def_id.to_def_id(), args).is_err() {
241-
let reported = match origin {
242-
hir::OpaqueTyOrigin::AsyncFn(..) => async_opaque_type_cycle_error(tcx, span),
243-
_ => opaque_type_cycle_error(tcx, def_id, span),
244-
};
245-
Err(reported)
246-
} else {
247-
Ok(())
237+
let args = GenericArgs::identity_for_item(tcx, def_id);
238+
239+
// First, try to look at any opaque expansion cycles, considering coroutine fields
240+
// (even though these aren't necessarily true errors).
241+
if tcx
242+
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::Yes)
243+
.is_err()
244+
{
245+
// Look for true opaque expansion cycles, but ignore coroutines.
246+
// This will give us any true errors. Coroutines are only problematic
247+
// if they cause layout computation errors.
248+
if tcx
249+
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::No)
250+
.is_err()
251+
{
252+
let reported = opaque_type_cycle_error(tcx, def_id, span);
253+
return Err(reported);
254+
}
255+
256+
// And also look for cycle errors in the layout of coroutines.
257+
if let Err(&LayoutError::Cycle(guar)) =
258+
tcx.layout_of(tcx.param_env(def_id).and(Ty::new_opaque(tcx, def_id.to_def_id(), args)))
259+
{
260+
return Err(guar);
261+
}
248262
}
263+
264+
Ok(())
249265
}
250266

251267
/// Check that the concrete type behind `impl Trait` actually implements `Trait`.
@@ -1324,16 +1340,6 @@ pub(super) fn check_mod_item_types(tcx: TyCtxt<'_>, module_def_id: LocalModDefId
13241340
}
13251341
}
13261342

1327-
fn async_opaque_type_cycle_error(tcx: TyCtxt<'_>, span: Span) -> ErrorGuaranteed {
1328-
struct_span_err!(tcx.sess, span, E0733, "recursion in an `async fn` requires boxing")
1329-
.span_label(span, "recursive `async fn`")
1330-
.note("a recursive `async fn` must be rewritten to return a boxed `dyn Future`")
1331-
.note(
1332-
"consider using the `async_recursion` crate: https://crates.io/crates/async_recursion",
1333-
)
1334-
.emit()
1335-
}
1336-
13371343
/// Emit an error for recursive opaque types.
13381344
///
13391345
/// If this is a return `impl Trait`, find the item's return expressions and point at them. For

compiler/rustc_middle/src/query/keys.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub trait Key: Sized {
4040
None
4141
}
4242

43-
fn ty_adt_id(&self) -> Option<DefId> {
43+
fn ty_def_id(&self) -> Option<DefId> {
4444
None
4545
}
4646
}
@@ -406,9 +406,10 @@ impl<'tcx> Key for Ty<'tcx> {
406406
DUMMY_SP
407407
}
408408

409-
fn ty_adt_id(&self) -> Option<DefId> {
410-
match self.kind() {
409+
fn ty_def_id(&self) -> Option<DefId> {
410+
match *self.kind() {
411411
ty::Adt(adt, _) => Some(adt.did()),
412+
ty::Coroutine(def_id, ..) => Some(def_id),
412413
_ => None,
413414
}
414415
}
@@ -452,6 +453,10 @@ impl<'tcx, T: Key> Key for ty::ParamEnvAnd<'tcx, T> {
452453
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
453454
self.value.default_span(tcx)
454455
}
456+
457+
fn ty_def_id(&self) -> Option<DefId> {
458+
self.value.ty_def_id()
459+
}
455460
}
456461

457462
impl Key for Symbol {
@@ -550,7 +555,7 @@ impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
550555
DUMMY_SP
551556
}
552557

553-
fn ty_adt_id(&self) -> Option<DefId> {
558+
fn ty_def_id(&self) -> Option<DefId> {
554559
match self.1.value.kind() {
555560
ty::Adt(adt, _) => Some(adt.did()),
556561
_ => None,

compiler/rustc_middle/src/query/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,8 @@ rustc_queries! {
13901390
) -> Result<ty::layout::TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
13911391
depth_limit
13921392
desc { "computing layout of `{}`", key.value }
1393+
// we emit our own error during query cycle handling
1394+
cycle_delay_bug
13931395
}
13941396

13951397
/// Compute a `FnAbi` suitable for indirect calls, i.e. to `fn` pointers.

compiler/rustc_middle/src/query/plumbing.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub struct DynamicQuery<'tcx, C: QueryCache> {
5353
fn(tcx: TyCtxt<'tcx>, key: &C::Key, index: SerializedDepNodeIndex) -> bool,
5454
pub hash_result: HashResult<C::Value>,
5555
pub value_from_cycle_error:
56-
fn(tcx: TyCtxt<'tcx>, cycle: &[QueryInfo], guar: ErrorGuaranteed) -> C::Value,
56+
fn(tcx: TyCtxt<'tcx>, cycle_error: &CycleError, guar: ErrorGuaranteed) -> C::Value,
5757
pub format_value: fn(&C::Value) -> String,
5858
}
5959

compiler/rustc_middle/src/ty/layout.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ pub enum LayoutError<'tcx> {
215215
SizeOverflow(Ty<'tcx>),
216216
NormalizationFailure(Ty<'tcx>, NormalizationError<'tcx>),
217217
ReferencesError(ErrorGuaranteed),
218-
Cycle,
218+
Cycle(ErrorGuaranteed),
219219
}
220220

221221
impl<'tcx> LayoutError<'tcx> {
@@ -226,7 +226,7 @@ impl<'tcx> LayoutError<'tcx> {
226226
Unknown(_) => middle_unknown_layout,
227227
SizeOverflow(_) => middle_values_too_big,
228228
NormalizationFailure(_, _) => middle_cannot_be_normalized,
229-
Cycle => middle_cycle,
229+
Cycle(_) => middle_cycle,
230230
ReferencesError(_) => middle_layout_references_error,
231231
}
232232
}
@@ -240,7 +240,7 @@ impl<'tcx> LayoutError<'tcx> {
240240
NormalizationFailure(ty, e) => {
241241
E::NormalizationFailure { ty, failure_ty: e.get_type_for_failure() }
242242
}
243-
Cycle => E::Cycle,
243+
Cycle(_) => E::Cycle,
244244
ReferencesError(_) => E::ReferencesError,
245245
}
246246
}
@@ -261,7 +261,7 @@ impl<'tcx> fmt::Display for LayoutError<'tcx> {
261261
t,
262262
e.get_type_for_failure()
263263
),
264-
LayoutError::Cycle => write!(f, "a cycle occurred during layout computation"),
264+
LayoutError::Cycle(_) => write!(f, "a cycle occurred during layout computation"),
265265
LayoutError::ReferencesError(_) => write!(f, "the type has an unknown layout"),
266266
}
267267
}
@@ -333,7 +333,7 @@ impl<'tcx> SizeSkeleton<'tcx> {
333333
Err(err @ LayoutError::Unknown(_)) => err,
334334
// We can't extract SizeSkeleton info from other layout errors
335335
Err(
336-
e @ LayoutError::Cycle
336+
e @ LayoutError::Cycle(_)
337337
| e @ LayoutError::SizeOverflow(_)
338338
| e @ LayoutError::NormalizationFailure(..)
339339
| e @ LayoutError::ReferencesError(_),

compiler/rustc_middle/src/ty/util.rs

+26-9
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ impl<'tcx> TyCtxt<'tcx> {
711711
check_recursion: false,
712712
expand_coroutines: false,
713713
tcx: self,
714+
inspect_coroutine_fields: InspectCoroutineFields::No,
714715
};
715716
val.fold_with(&mut visitor)
716717
}
@@ -721,6 +722,7 @@ impl<'tcx> TyCtxt<'tcx> {
721722
self,
722723
def_id: DefId,
723724
args: GenericArgsRef<'tcx>,
725+
inspect_coroutine_fields: InspectCoroutineFields,
724726
) -> Result<Ty<'tcx>, Ty<'tcx>> {
725727
let mut visitor = OpaqueTypeExpander {
726728
seen_opaque_tys: FxHashSet::default(),
@@ -731,6 +733,7 @@ impl<'tcx> TyCtxt<'tcx> {
731733
check_recursion: true,
732734
expand_coroutines: true,
733735
tcx: self,
736+
inspect_coroutine_fields,
734737
};
735738

736739
let expanded_type = visitor.expand_opaque_ty(def_id, args).unwrap();
@@ -747,9 +750,13 @@ impl<'tcx> TyCtxt<'tcx> {
747750
match def_kind {
748751
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
749752
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
750-
rustc_hir::CoroutineKind::Async(..) => "async closure",
751-
rustc_hir::CoroutineKind::Coroutine => "coroutine",
752-
rustc_hir::CoroutineKind::Gen(..) => "gen closure",
753+
hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "async fn",
754+
hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "async block",
755+
hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "async closure",
756+
hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "gen fn",
757+
hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "gen block",
758+
hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "gen closure",
759+
hir::CoroutineKind::Coroutine => "coroutine",
753760
},
754761
_ => def_kind.descr(def_id),
755762
}
@@ -765,9 +772,9 @@ impl<'tcx> TyCtxt<'tcx> {
765772
match def_kind {
766773
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "a",
767774
DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
768-
rustc_hir::CoroutineKind::Async(..) => "an",
769-
rustc_hir::CoroutineKind::Coroutine => "a",
770-
rustc_hir::CoroutineKind::Gen(..) => "a",
775+
hir::CoroutineKind::Async(..) => "an",
776+
hir::CoroutineKind::Coroutine => "a",
777+
hir::CoroutineKind::Gen(..) => "a",
771778
},
772779
_ => def_kind.article(),
773780
}
@@ -812,6 +819,13 @@ struct OpaqueTypeExpander<'tcx> {
812819
/// recursion, and 'false' otherwise to avoid unnecessary work.
813820
check_recursion: bool,
814821
tcx: TyCtxt<'tcx>,
822+
inspect_coroutine_fields: InspectCoroutineFields,
823+
}
824+
825+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
826+
pub enum InspectCoroutineFields {
827+
No,
828+
Yes,
815829
}
816830

817831
impl<'tcx> OpaqueTypeExpander<'tcx> {
@@ -853,9 +867,11 @@ impl<'tcx> OpaqueTypeExpander<'tcx> {
853867
let expanded_ty = match self.expanded_cache.get(&(def_id, args)) {
854868
Some(expanded_ty) => *expanded_ty,
855869
None => {
856-
for bty in self.tcx.coroutine_hidden_types(def_id) {
857-
let hidden_ty = bty.instantiate(self.tcx, args);
858-
self.fold_ty(hidden_ty);
870+
if matches!(self.inspect_coroutine_fields, InspectCoroutineFields::Yes) {
871+
for bty in self.tcx.coroutine_hidden_types(def_id) {
872+
let hidden_ty = bty.instantiate(self.tcx, args);
873+
self.fold_ty(hidden_ty);
874+
}
859875
}
860876
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
861877
self.expanded_cache.insert((def_id, args), expanded_ty);
@@ -1433,6 +1449,7 @@ pub fn reveal_opaque_types_in_bounds<'tcx>(
14331449
check_recursion: false,
14341450
expand_coroutines: false,
14351451
tcx,
1452+
inspect_coroutine_fields: InspectCoroutineFields::No,
14361453
};
14371454
val.fold_with(&mut visitor)
14381455
}

0 commit comments

Comments
 (0)