Skip to content

Commit 0eb2adb

Browse files
Add async bound modifier to enable async Fn bounds
1 parent cdaa12e commit 0eb2adb

File tree

12 files changed

+199
-59
lines changed

12 files changed

+199
-59
lines changed

compiler/rustc_ast/src/ast.rs

+26-3
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,16 @@ pub use crate::node_id::{NodeId, CRATE_NODE_ID, DUMMY_NODE_ID};
291291
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
292292
pub struct TraitBoundModifiers {
293293
pub constness: BoundConstness,
294+
pub asyncness: BoundAsyncness,
294295
pub polarity: BoundPolarity,
295296
}
296297

297298
impl TraitBoundModifiers {
298-
pub const NONE: Self =
299-
Self { constness: BoundConstness::Never, polarity: BoundPolarity::Positive };
299+
pub const NONE: Self = Self {
300+
constness: BoundConstness::Never,
301+
asyncness: BoundAsyncness::Normal,
302+
polarity: BoundPolarity::Positive,
303+
};
300304
}
301305

302306
/// The AST represents all type param bounds as types.
@@ -2562,6 +2566,25 @@ impl BoundConstness {
25622566
}
25632567
}
25642568

2569+
/// The asyncness of a trait bound.
2570+
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
2571+
#[derive(HashStable_Generic)]
2572+
pub enum BoundAsyncness {
2573+
/// `Type: Trait`
2574+
Normal,
2575+
/// `Type: async Trait`
2576+
Async(Span),
2577+
}
2578+
2579+
impl BoundAsyncness {
2580+
pub fn as_str(self) -> &'static str {
2581+
match self {
2582+
Self::Normal => "",
2583+
Self::Async(_) => "async",
2584+
}
2585+
}
2586+
}
2587+
25652588
#[derive(Clone, Encodable, Decodable, Debug)]
25662589
pub enum FnRetTy {
25672590
/// Returns type is not specified.
@@ -3300,7 +3323,7 @@ mod size_asserts {
33003323
static_assert_size!(ForeignItem, 96);
33013324
static_assert_size!(ForeignItemKind, 24);
33023325
static_assert_size!(GenericArg, 24);
3303-
static_assert_size!(GenericBound, 72);
3326+
static_assert_size!(GenericBound, 88);
33043327
static_assert_size!(Generics, 40);
33053328
static_assert_size!(Impl, 136);
33063329
static_assert_size!(Item, 136);

compiler/rustc_ast_lowering/src/expr.rs

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
100100
ParenthesizedGenericArgs::Err,
101101
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
102102
None,
103+
// Method calls can't have bound modifiers
104+
None,
103105
));
104106
let receiver = self.lower_expr(receiver);
105107
let args =

compiler/rustc_ast_lowering/src/item.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
343343
let itctx = ImplTraitContext::Universal;
344344
let (generics, (trait_ref, lowered_ty)) =
345345
self.lower_generics(ast_generics, *constness, id, &itctx, |this| {
346-
let constness = match *constness {
347-
Const::Yes(span) => BoundConstness::Maybe(span),
348-
Const::No => BoundConstness::Never,
346+
let modifiers = TraitBoundModifiers {
347+
constness: match *constness {
348+
Const::Yes(span) => BoundConstness::Maybe(span),
349+
Const::No => BoundConstness::Never,
350+
},
351+
asyncness: BoundAsyncness::Normal,
352+
// we don't use this in bound lowering
353+
polarity: BoundPolarity::Positive,
349354
};
350355

351356
let trait_ref = trait_ref.as_ref().map(|trait_ref| {
352357
this.lower_trait_ref(
353-
constness,
358+
modifiers,
354359
trait_ref,
355360
&ImplTraitContext::Disallowed(ImplTraitPosition::Trait),
356361
)

compiler/rustc_ast_lowering/src/lib.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ struct LoweringContext<'a, 'hir> {
131131
allow_gen_future: Lrc<[Symbol]>,
132132
allow_async_iterator: Lrc<[Symbol]>,
133133
allow_for_await: Lrc<[Symbol]>,
134+
allow_async_fn_traits: Lrc<[Symbol]>,
134135

135136
/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
136137
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@@ -176,6 +177,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
176177
[sym::gen_future].into()
177178
},
178179
allow_for_await: [sym::async_iterator].into(),
180+
allow_async_fn_traits: [sym::async_fn_traits].into(),
179181
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
180182
// interact with `gen`/`async gen` blocks
181183
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
@@ -1311,7 +1313,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
13111313
span: t.span,
13121314
},
13131315
itctx,
1314-
ast::BoundConstness::Never,
1316+
TraitBoundModifiers::NONE,
13151317
);
13161318
let bounds = this.arena.alloc_from_iter([bound]);
13171319
let lifetime_bound = this.elided_dyn_bound(t.span);
@@ -1426,7 +1428,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
14261428
itctx,
14271429
// Still, don't pass along the constness here; we don't want to
14281430
// synthesize any host effect args, it'd only cause problems.
1429-
ast::BoundConstness::Never,
1431+
TraitBoundModifiers {
1432+
constness: BoundConstness::Never,
1433+
..*modifiers
1434+
},
14301435
))
14311436
}
14321437
BoundPolarity::Maybe(_) => None,
@@ -2019,7 +2024,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
20192024
) -> hir::GenericBound<'hir> {
20202025
match tpb {
20212026
GenericBound::Trait(p, modifiers) => hir::GenericBound::Trait(
2022-
self.lower_poly_trait_ref(p, itctx, modifiers.constness.into()),
2027+
self.lower_poly_trait_ref(p, itctx, *modifiers),
20232028
self.lower_trait_bound_modifiers(*modifiers),
20242029
),
20252030
GenericBound::Outlives(lifetime) => {
@@ -2192,7 +2197,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
21922197

21932198
fn lower_trait_ref(
21942199
&mut self,
2195-
constness: ast::BoundConstness,
2200+
modifiers: ast::TraitBoundModifiers,
21962201
p: &TraitRef,
21972202
itctx: &ImplTraitContext,
21982203
) -> hir::TraitRef<'hir> {
@@ -2202,7 +2207,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
22022207
&p.path,
22032208
ParamMode::Explicit,
22042209
itctx,
2205-
Some(constness),
2210+
Some(modifiers),
22062211
) {
22072212
hir::QPath::Resolved(None, path) => path,
22082213
qpath => panic!("lower_trait_ref: unexpected QPath `{qpath:?}`"),
@@ -2215,11 +2220,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
22152220
&mut self,
22162221
p: &PolyTraitRef,
22172222
itctx: &ImplTraitContext,
2218-
constness: ast::BoundConstness,
2223+
modifiers: ast::TraitBoundModifiers,
22192224
) -> hir::PolyTraitRef<'hir> {
22202225
let bound_generic_params =
22212226
self.lower_lifetime_binder(p.trait_ref.ref_id, &p.bound_generic_params);
2222-
let trait_ref = self.lower_trait_ref(constness, &p.trait_ref, itctx);
2227+
let trait_ref = self.lower_trait_ref(modifiers, &p.trait_ref, itctx);
22232228
hir::PolyTraitRef { bound_generic_params, trait_ref, span: self.lower_span(p.span) }
22242229
}
22252230

compiler/rustc_ast_lowering/src/path.rs

+78-9
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ use super::{GenericArgsCtor, LifetimeRes, ParenthesizedGenericArgs};
66
use super::{ImplTraitContext, LoweringContext, ParamMode};
77

88
use rustc_ast::{self as ast, *};
9+
use rustc_data_structures::sync::Lrc;
910
use rustc_hir as hir;
1011
use rustc_hir::def::{DefKind, PartialRes, Res};
12+
use rustc_hir::def_id::DefId;
1113
use rustc_hir::GenericArg;
1214
use rustc_middle::span_bug;
1315
use rustc_span::symbol::{kw, sym, Ident};
14-
use rustc_span::{BytePos, Span, DUMMY_SP};
16+
use rustc_span::{BytePos, DesugaringKind, Span, Symbol, DUMMY_SP};
1517

1618
use smallvec::{smallvec, SmallVec};
1719

@@ -24,8 +26,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
2426
p: &Path,
2527
param_mode: ParamMode,
2628
itctx: &ImplTraitContext,
27-
// constness of the impl/bound if this is a trait path
28-
constness: Option<ast::BoundConstness>,
29+
// modifiers of the impl/bound if this is a trait path
30+
modifiers: Option<ast::TraitBoundModifiers>,
2931
) -> hir::QPath<'hir> {
3032
let qself_position = qself.as_ref().map(|q| q.position);
3133
let qself = qself.as_ref().map(|q| self.lower_ty(&q.ty, itctx));
@@ -35,10 +37,27 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
3537
let base_res = partial_res.base_res();
3638
let unresolved_segments = partial_res.unresolved_segments();
3739

40+
let mut res = self.lower_res(base_res);
41+
42+
// When we have an `async` kw on a bound, map the trait it resolves to.
43+
let mut bound_modifier_allowed_features = None;
44+
if let Some(TraitBoundModifiers { asyncness: BoundAsyncness::Async(_), .. }) = modifiers {
45+
if let Res::Def(DefKind::Trait, def_id) = res {
46+
if let Some((async_def_id, features)) = self.map_trait_to_async_trait(def_id) {
47+
res = Res::Def(DefKind::Trait, async_def_id);
48+
bound_modifier_allowed_features = Some(features);
49+
} else {
50+
panic!();
51+
}
52+
} else {
53+
panic!();
54+
}
55+
}
56+
3857
let path_span_lo = p.span.shrink_to_lo();
3958
let proj_start = p.segments.len() - unresolved_segments;
4059
let path = self.arena.alloc(hir::Path {
41-
res: self.lower_res(base_res),
60+
res,
4261
segments: self.arena.alloc_from_iter(p.segments[..proj_start].iter().enumerate().map(
4362
|(i, segment)| {
4463
let param_mode = match (qself_position, param_mode) {
@@ -77,7 +96,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
7796
parenthesized_generic_args,
7897
itctx,
7998
// if this is the last segment, add constness to the trait path
80-
if i == proj_start - 1 { constness } else { None },
99+
if i == proj_start - 1 { modifiers.map(|m| m.constness) } else { None },
100+
bound_modifier_allowed_features.clone(),
81101
)
82102
},
83103
)),
@@ -88,6 +108,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
88108
),
89109
});
90110

111+
if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
112+
path.span = self.mark_span_with_reason(
113+
DesugaringKind::BoundModifier,
114+
path.span,
115+
Some(bound_modifier_allowed_features),
116+
);
117+
}
118+
91119
// Simple case, either no projections, or only fully-qualified.
92120
// E.g., `std::mem::size_of` or `<I as Iterator>::Item`.
93121
if unresolved_segments == 0 {
@@ -125,6 +153,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
125153
ParenthesizedGenericArgs::Err,
126154
itctx,
127155
None,
156+
None,
128157
));
129158
let qpath = hir::QPath::TypeRelative(ty, hir_segment);
130159

@@ -166,6 +195,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
166195
ParenthesizedGenericArgs::Err,
167196
&ImplTraitContext::Disallowed(ImplTraitPosition::Path),
168197
None,
198+
None,
169199
)
170200
})),
171201
span: self.lower_span(p.span),
@@ -180,6 +210,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
180210
parenthesized_generic_args: ParenthesizedGenericArgs,
181211
itctx: &ImplTraitContext,
182212
constness: Option<ast::BoundConstness>,
213+
// Additional features ungated with a bound modifier like `async`.
214+
// This is passed down to the implicit associated type binding in
215+
// parenthesized bounds.
216+
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
183217
) -> hir::PathSegment<'hir> {
184218
debug!("path_span: {:?}, lower_path_segment(segment: {:?})", path_span, segment);
185219
let (mut generic_args, infer_args) = if let Some(generic_args) = segment.args.as_deref() {
@@ -188,9 +222,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
188222
self.lower_angle_bracketed_parameter_data(data, param_mode, itctx)
189223
}
190224
GenericArgs::Parenthesized(data) => match parenthesized_generic_args {
191-
ParenthesizedGenericArgs::ParenSugar => {
192-
self.lower_parenthesized_parameter_data(data, itctx)
193-
}
225+
ParenthesizedGenericArgs::ParenSugar => self
226+
.lower_parenthesized_parameter_data(
227+
data,
228+
itctx,
229+
bound_modifier_allowed_features,
230+
),
194231
ParenthesizedGenericArgs::Err => {
195232
// Suggest replacing parentheses with angle brackets `Trait(params...)` to `Trait<params...>`
196233
let sub = if !data.inputs.is_empty() {
@@ -357,6 +394,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
357394
&mut self,
358395
data: &ParenthesizedArgs,
359396
itctx: &ImplTraitContext,
397+
bound_modifier_allowed_features: Option<Lrc<[Symbol]>>,
360398
) -> (GenericArgsCtor<'hir>, bool) {
361399
// Switch to `PassThrough` mode for anonymous lifetimes; this
362400
// means that we permit things like `&Ref<T>`, where `Ref` has
@@ -392,7 +430,19 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
392430
FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
393431
};
394432
let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
395-
let binding = self.assoc_ty_binding(sym::Output, output_ty.span, output_ty);
433+
434+
// If we have a bound like `async Fn() -> T`, make sure that we mark the
435+
// `Output = T` associated type bound with the right feature gates.
436+
let mut output_span = output_ty.span;
437+
if let Some(bound_modifier_allowed_features) = bound_modifier_allowed_features {
438+
output_span = self.mark_span_with_reason(
439+
DesugaringKind::BoundModifier,
440+
output_span,
441+
Some(bound_modifier_allowed_features),
442+
);
443+
}
444+
let binding = self.assoc_ty_binding(sym::Output, output_span, output_ty);
445+
396446
(
397447
GenericArgsCtor {
398448
args,
@@ -429,4 +479,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
429479
kind,
430480
}
431481
}
482+
483+
/// When a bound is annotated with `async`, it signals to lowering that the trait
484+
/// that the bound refers to should be mapped to the "async" flavor of the trait.
485+
///
486+
/// This only needs to be done until we unify `AsyncFn` and `Fn` traits into one
487+
/// that is generic over `async`ness, if that's ever possible, or modify the
488+
/// lowering of `async Fn()` bounds to desugar to another trait like `LendingFn`.
489+
fn map_trait_to_async_trait(&self, def_id: DefId) -> Option<(DefId, Lrc<[Symbol]>)> {
490+
let lang_items = self.tcx.lang_items();
491+
if Some(def_id) == lang_items.fn_trait() {
492+
Some((lang_items.async_fn_trait()?, self.allow_async_fn_traits.clone()))
493+
} else if Some(def_id) == lang_items.fn_mut_trait() {
494+
Some((lang_items.async_fn_mut_trait()?, self.allow_async_fn_traits.clone()))
495+
} else if Some(def_id) == lang_items.fn_once_trait() {
496+
Some((lang_items.async_fn_once_trait()?, self.allow_async_fn_traits.clone()))
497+
} else {
498+
None
499+
}
500+
}
432501
}

compiler/rustc_ast_pretty/src/pprust/state.rs

+16-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod item;
88
use crate::pp::Breaks::{Consistent, Inconsistent};
99
use crate::pp::{self, Breaks};
1010
use crate::pprust::state::expr::FixupContext;
11+
use ast::TraitBoundModifiers;
1112
use rustc_ast::attr::AttrIdGenerator;
1213
use rustc_ast::ptr::P;
1314
use rustc_ast::token::{self, BinOpToken, CommentKind, Delimiter, Nonterminal, Token, TokenKind};
@@ -1590,18 +1591,28 @@ impl<'a> State<'a> {
15901591
}
15911592

15921593
match bound {
1593-
GenericBound::Trait(tref, modifier) => {
1594-
match modifier.constness {
1594+
GenericBound::Trait(
1595+
tref,
1596+
TraitBoundModifiers { constness, asyncness, polarity },
1597+
) => {
1598+
match constness {
15951599
ast::BoundConstness::Never => {}
15961600
ast::BoundConstness::Always(_) | ast::BoundConstness::Maybe(_) => {
1597-
self.word_space(modifier.constness.as_str());
1601+
self.word_space(constness.as_str());
15981602
}
15991603
}
16001604

1601-
match modifier.polarity {
1605+
match asyncness {
1606+
ast::BoundAsyncness::Normal => {}
1607+
ast::BoundAsyncness::Async(_) => {
1608+
self.word_space(asyncness.as_str());
1609+
}
1610+
}
1611+
1612+
match polarity {
16021613
ast::BoundPolarity::Positive => {}
16031614
ast::BoundPolarity::Negative(_) | ast::BoundPolarity::Maybe(_) => {
1604-
self.word(modifier.polarity.as_str());
1615+
self.word(polarity.as_str());
16051616
}
16061617
}
16071618

compiler/rustc_expand/src/build.rs

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ impl<'a> ExtCtxt<'a> {
141141
} else {
142142
ast::BoundConstness::Never
143143
},
144+
asyncness: ast::BoundAsyncness::Normal,
144145
},
145146
)
146147
}

0 commit comments

Comments
 (0)