Skip to content

Commit b89d532

Browse files
committed
feat: trait_variant::make supports rewriting of the original trait.
1 parent f1e171e commit b89d532

File tree

3 files changed

+89
-55
lines changed

3 files changed

+89
-55
lines changed

trait-variant/examples/variant.rs

+14
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,18 @@ where
4343
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
4444
}
4545

46+
#[trait_variant::make(Send + Sync)]
47+
pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize>
48+
where
49+
Y: Sync,
50+
{
51+
const CONST: usize = 3;
52+
type F;
53+
type A<const ANOTHER_CONST: u8>;
54+
type B<T: Display>: FromIterator<T>;
55+
56+
async fn take(&self, s: S);
57+
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
58+
}
59+
4660
fn main() {}

trait-variant/src/lib.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ mod variant;
1414
/// fn` and/or `-> impl Trait` return types.
1515
///
1616
/// ```
17-
/// #[trait_variant::make(IntFactory: Send)]
18-
/// trait LocalIntFactory {
17+
/// #[trait_variant::make(Send)]
18+
/// trait IntFactory {
1919
/// async fn make(&self) -> i32;
2020
/// fn stream(&self) -> impl Iterator<Item = i32>;
2121
/// fn call(&self) -> u32;
2222
/// }
2323
/// ```
2424
///
25-
/// The above example causes a second trait called `IntFactory` to be created:
25+
/// The above example causes the trait to be rewritten as:
2626
///
2727
/// ```
2828
/// # use core::future::Future;
@@ -35,6 +35,19 @@ mod variant;
3535
///
3636
/// Note that ordinary methods such as `call` are not affected.
3737
///
38+
/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async
39+
/// fn` and/or `-> impl Trait` return types.
40+
///
41+
/// ```
42+
/// #[trait_variant::make(IntFactory: Send)]
43+
/// trait LocalIntFactory {
44+
/// async fn make(&self) -> i32;
45+
/// fn stream(&self) -> impl Iterator<Item = i32>;
46+
/// fn call(&self) -> u32;
47+
/// }
48+
/// ```
49+
///
50+
/// The example causes a second trait called `IntFactory` to be created.
3851
/// Implementers of the trait can choose to implement the variant instead of the
3952
/// original trait. The macro creates a blanket impl which ensures that any type
4053
/// which implements the variant also implements the original trait.

trait-variant/src/variant.rs

+59-52
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,33 @@ impl Parse for Attrs {
3232
}
3333
}
3434

35-
struct MakeVariant {
36-
name: Ident,
37-
#[allow(unused)]
38-
colon: Token![:],
39-
bounds: Punctuated<TraitBound, Plus>,
35+
enum MakeVariant {
36+
// Creates a variant of a trait under a new name with additional bounds while preserving the original trait.
37+
Create {
38+
name: Ident,
39+
_colon: Token![:],
40+
bounds: Punctuated<TraitBound, Plus>,
41+
},
42+
// Rewrites the original trait into a new trait with additional bounds.
43+
Rewrite {
44+
bounds: Punctuated<TraitBound, Plus>,
45+
},
4046
}
4147

4248
impl Parse for MakeVariant {
4349
fn parse(input: ParseStream) -> Result<Self> {
44-
Ok(Self {
45-
name: input.parse()?,
46-
colon: input.parse()?,
47-
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
48-
})
50+
let variant = if input.peek(Ident) && input.peek2(Token![:]) {
51+
MakeVariant::Create {
52+
name: input.parse()?,
53+
_colon: input.parse()?,
54+
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
55+
}
56+
} else {
57+
MakeVariant::Rewrite {
58+
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
59+
}
60+
};
61+
Ok(variant)
4962
}
5063
}
5164

@@ -56,43 +69,51 @@ pub fn make(
5669
let attrs = parse_macro_input!(attr as Attrs);
5770
let item = parse_macro_input!(item as ItemTrait);
5871

59-
let maybe_allow_async_lint = if attrs
60-
.variant
61-
.bounds
62-
.iter()
63-
.any(|b| b.path.segments.last().unwrap().ident == "Send")
64-
{
65-
quote! { #[allow(async_fn_in_trait)] }
66-
} else {
67-
quote! {}
68-
};
72+
match attrs.variant {
73+
MakeVariant::Create { name, bounds, .. } => {
74+
let maybe_allow_async_lint = if bounds
75+
.iter()
76+
.any(|b| b.path.segments.last().unwrap().ident == "Send")
77+
{
78+
quote! { #[allow(async_fn_in_trait)] }
79+
} else {
80+
quote! {}
81+
};
6982

70-
let variant = mk_variant(&attrs, &item);
71-
let blanket_impl = mk_blanket_impl(&attrs, &item);
83+
let variant = mk_variant(&name, bounds, &item);
84+
let blanket_impl = mk_blanket_impl(&name, &item);
7285

73-
quote! {
74-
#maybe_allow_async_lint
75-
#item
86+
quote! {
87+
#maybe_allow_async_lint
88+
#item
7689

77-
#variant
90+
#variant
7891

79-
#blanket_impl
92+
#blanket_impl
93+
}
94+
.into()
95+
}
96+
MakeVariant::Rewrite { bounds, .. } => {
97+
let variant = mk_variant(&item.ident, bounds, &item);
98+
quote! {
99+
#variant
100+
}
101+
.into()
102+
}
80103
}
81-
.into()
82104
}
83105

84-
fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
85-
let MakeVariant {
86-
ref name,
87-
colon: _,
88-
ref bounds,
89-
} = attrs.variant;
90-
let bounds: Vec<_> = bounds
106+
fn mk_variant(
107+
variant: &Ident,
108+
with_bounds: Punctuated<TraitBound, Plus>,
109+
tr: &ItemTrait,
110+
) -> TokenStream {
111+
let bounds: Vec<_> = with_bounds
91112
.into_iter()
92113
.map(|b| TypeParamBound::Trait(b.clone()))
93114
.collect();
94115
let variant = ItemTrait {
95-
ident: name.clone(),
116+
ident: variant.clone(),
96117
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
97118
items: tr
98119
.items
@@ -104,21 +125,8 @@ fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
104125
quote! { #variant }
105126
}
106127

128+
// Transforms a one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
107129
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
108-
// #[make_variant(SendIntFactory: Send)]
109-
// trait IntFactory {
110-
// async fn make(&self, x: u32, y: &str) -> i32;
111-
// fn stream(&self) -> impl Iterator<Item = i32>;
112-
// fn call(&self) -> u32;
113-
// }
114-
//
115-
// becomes:
116-
//
117-
// trait SendIntFactory: Send {
118-
// fn make(&self, x: u32, y: &str) -> impl ::core::future::Future<Output = i32> + Send;
119-
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
120-
// fn call(&self) -> u32;
121-
// }
122130
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
123131
return item.clone();
124132
};
@@ -160,9 +168,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
160168
})
161169
}
162170

163-
fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
171+
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
164172
let orig = &tr.ident;
165-
let variant = &attrs.variant.name;
166173
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
167174
let items = tr
168175
.items

0 commit comments

Comments
 (0)