Skip to content

Commit 1a90208

Browse files
committed
scalar or vector: derive enums and improve enum docs
1 parent 4dd8c60 commit 1a90208

File tree

7 files changed

+377
-28
lines changed

7 files changed

+377
-28
lines changed
Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
use proc_macro2::TokenStream;
22
use quote::{ToTokens, quote};
33
use syn::punctuated::Punctuated;
4-
use syn::{Fields, FieldsNamed, FieldsUnnamed, GenericParam, Token};
4+
use syn::{
5+
Data, DataStruct, DataUnion, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, GenericParam,
6+
Token,
7+
};
58

69
pub fn derive(item: TokenStream) -> syn::Result<TokenStream> {
710
// Whenever we'll properly resolve the crate symbol, replace this.
811
let spirv_std = quote!(spirv_std);
912

1013
// Defer all validation to our codegen backend. Rather than erroring here, emit garbage.
11-
let item = syn::parse2::<syn::ItemStruct>(item)?;
12-
let struct_ident = &item.ident;
14+
let item = syn::parse2::<DeriveInput>(item)?;
15+
let content = match &item.data {
16+
Data::Enum(_) => derive_enum(&spirv_std, &item),
17+
Data::Struct(data) => derive_struct(&spirv_std, data),
18+
Data::Union(DataUnion { union_token, .. }) => {
19+
Err(syn::Error::new_spanned(union_token, "Union not supported"))
20+
}
21+
}?;
22+
23+
let ident = &item.ident;
1324
let gens = &item.generics.params;
1425
let gen_refs = &item
1526
.generics
@@ -23,33 +34,57 @@ pub fn derive(item: TokenStream) -> syn::Result<TokenStream> {
2334
.collect::<Punctuated<_, Token![,]>>();
2435
let where_clause = &item.generics.where_clause;
2536

26-
let content =
27-
match item.fields {
28-
Fields::Named(FieldsNamed { named, .. }) => {
29-
let content = named.iter().map(|f| {
30-
let ident = &f.ident;
31-
quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f))
32-
}).collect::<Punctuated<_, Token![,]>>();
33-
quote!(Self { #content })
34-
}
35-
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
36-
let content = (0..unnamed.len())
37-
.map(|i| {
38-
let i = syn::Index::from(i);
39-
quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f))
40-
})
41-
.collect::<Punctuated<_, Token![,]>>();
42-
quote!(Self(#content))
43-
}
44-
Fields::Unit => quote!(Self),
45-
};
46-
4737
Ok(quote! {
48-
impl<#gens> #spirv_std::ScalarOrVectorComposite for #struct_ident<#gen_refs> #where_clause {
38+
impl<#gens> #spirv_std::ScalarOrVectorComposite for #ident<#gen_refs> #where_clause {
4939
#[inline]
5040
fn transform<F: #spirv_std::ScalarOrVectorTransform>(self, f: &mut F) -> Self {
5141
#content
5242
}
5343
}
5444
})
5545
}
46+
47+
pub fn derive_struct(spirv_std: &TokenStream, data: &DataStruct) -> syn::Result<TokenStream> {
48+
Ok(match &data.fields {
49+
Fields::Named(FieldsNamed { named, .. }) => {
50+
let content = named
51+
.iter()
52+
.map(|f| {
53+
let ident = &f.ident;
54+
quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f))
55+
})
56+
.collect::<Punctuated<_, Token![,]>>();
57+
quote!(Self { #content })
58+
}
59+
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
60+
let content = (0..unnamed.len())
61+
.map(|i| {
62+
let i = syn::Index::from(i);
63+
quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f))
64+
})
65+
.collect::<Punctuated<_, Token![,]>>();
66+
quote!(Self(#content))
67+
}
68+
Fields::Unit => quote!(Self),
69+
})
70+
}
71+
72+
pub fn derive_enum(spirv_std: &TokenStream, item: &DeriveInput) -> syn::Result<TokenStream> {
73+
let mut attributes = item.attrs.iter().filter(|a| a.path().is_ident("repr"));
74+
let repr = match (attributes.next(), attributes.next()) {
75+
(None, _) => Err(syn::Error::new_spanned(
76+
item,
77+
"Missing #[repr(...)] attribute",
78+
)),
79+
(Some(repr), None) => Ok(repr),
80+
(Some(_), Some(_)) => Err(syn::Error::new_spanned(
81+
item,
82+
"Multiple #[repr(...)] attributes found",
83+
)),
84+
}?;
85+
let prim = &repr.meta.require_list()?.tokens;
86+
Ok(quote! {
87+
#spirv_std::assert_is_integer::<#prim>();
88+
<Self as From<#prim>>::from(#spirv_std::ScalarOrVectorComposite::transform(<Self as Into<#prim>>::into(self), f))
89+
})
90+
}

crates/spirv-std/src/scalar.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,7 @@ impl_scalar! {
117117
impl Float for f64;
118118
impl Scalar for bool;
119119
}
120+
121+
/// used by `ScalarOrVector` derive when working with enums
122+
#[inline]
123+
pub fn assert_is_integer<T: Integer>() {}

crates/spirv-std/src/scalar_or_vector.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,23 @@ pub unsafe trait ScalarOrVector: ScalarOrVectorComposite + Default {
3232
/// data to other threads.
3333
///
3434
/// To derive `#[derive(VectorOrScalarComposite)]` on a struct, all members must also implement
35-
/// `VectorOrScalarComposite`. To derive it on an enum, the enum must have `#[repr(N)]` where `N` is an [`Integer`].
36-
/// Additionally, you must derive `num_enum::FromPrimitive` and `num_enum::ToPrimitive`, which requires the enum to be
37-
/// either exhaustive, implement [`Default`] or a variant of the enum to have the `#[num_enum(default)]` attribute.
35+
/// `VectorOrScalarComposite`.
36+
///
37+
/// To derive it on an enum, the enum must implement `From<N>` and `Into<N>` where `N` is defined by the `#[repr(N)]`
38+
/// attribute on the enum and is an [`Integer`], like `u32`.
39+
/// Note that some [safe subgroup operations] may return an "undefined result", so your `From<N>` must gracefully handle
40+
/// arbitrary bit patterns being passed to it. While panicking is legal, it is discouraged as it may result in
41+
/// unexpected control flow.
42+
/// To implement these conversion traits, we recommend [`FromPrimitive`] and [`IntoPrimitive`] from the [`num_enum`]
43+
/// crate. [`FromPrimitive`] requires that either the enum is exhaustive, or you provide it with a variant to default
44+
/// to, by either implementing [`Default`] or marking a variant with `#[num_enum(default)]`. Note to disable default
45+
/// features on the [`num_enum`] crate, or it won't compile on SPIR-V.
3846
///
3947
/// [`Integer`]: crate::Integer
48+
/// [subgroup operations]: crate::arch::subgroup_shuffle
49+
/// [`FromPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.FromPrimitive.html
50+
/// [`IntoPrimitive`]: https://docs.rs/num_enum/latest/num_enum/derive.IntoPrimitive.html
51+
/// [`num_enum`]: https://crates.io/crates/num_enum
4052
pub trait ScalarOrVectorComposite: Copy + Send + Sync + 'static {
4153
/// Transform the individual [`Scalar`] and [`Vector`] values of this type to a different value.
4254
///
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffle,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_enum::disassembly
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
6+
use glam::*;
7+
use spirv_std::ScalarOrVectorComposite;
8+
use spirv_std::arch::*;
9+
use spirv_std::spirv;
10+
11+
#[repr(u32)]
12+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
13+
pub enum MyEnum {
14+
#[default]
15+
A,
16+
B,
17+
C,
18+
}
19+
20+
impl From<u32> for MyEnum {
21+
#[inline]
22+
fn from(value: u32) -> Self {
23+
match value {
24+
0 => Self::A,
25+
1 => Self::B,
26+
2 => Self::C,
27+
_ => Self::default(),
28+
}
29+
}
30+
}
31+
32+
impl From<MyEnum> for u32 {
33+
#[inline]
34+
fn from(value: MyEnum) -> Self {
35+
value as u32
36+
}
37+
}
38+
39+
/// this should be 3 `subgroup_shuffle` instructions, with all calls inlined
40+
fn disassembly(my_struct: MyEnum, id: u32) -> MyEnum {
41+
subgroup_shuffle(my_struct, id)
42+
}
43+
44+
#[spirv(compute(threads(32)))]
45+
pub fn main(
46+
#[spirv(local_invocation_index)] inv_id: UVec3,
47+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut MyEnum,
48+
) {
49+
unsafe {
50+
let my_enum = MyEnum::from(inv_id.x % 3);
51+
*output = disassembly(my_enum, 5);
52+
}
53+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %2
3+
%5 = OpFunctionParameter %2
4+
%6 = OpLabel
5+
%8 = OpGroupNonUniformShuffle %2 %9 %4 %5
6+
OpNoLine
7+
OpSelectionMerge %10 None
8+
OpSwitch %8 %11 0 %12 1 %13 2 %14
9+
%11 = OpLabel
10+
OpBranch %10
11+
%12 = OpLabel
12+
OpBranch %10
13+
%13 = OpLabel
14+
OpBranch %10
15+
%14 = OpLabel
16+
OpBranch %10
17+
%10 = OpLabel
18+
%15 = OpPhi %2 %16 %11 %16 %12 %17 %13 %18 %14
19+
OpReturnValue %15
20+
OpFunctionEnd
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// build-fail
2+
// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/"
3+
// normalize-stderr-test "\.rs:\d+:\d+" -> ".rs:"
4+
// normalize-stderr-test "(\n)\d* *([ -])([\|\+\-\=])" -> "$1 $2$3"
5+
6+
use glam::*;
7+
use spirv_std::ScalarOrVectorComposite;
8+
use spirv_std::arch::*;
9+
use spirv_std::spirv;
10+
11+
macro_rules! enum_repr_from {
12+
($ident:ident, $repr:ty) => {
13+
impl From<$repr> for $ident {
14+
#[inline]
15+
fn from(value: $repr) -> Self {
16+
match value {
17+
0 => Self::A,
18+
1 => Self::B,
19+
2 => Self::C,
20+
_ => Self::default(),
21+
}
22+
}
23+
}
24+
25+
impl From<$ident> for $repr {
26+
#[inline]
27+
fn from(value: $ident) -> Self {
28+
value as $repr
29+
}
30+
}
31+
};
32+
}
33+
34+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
35+
pub enum NoRepr {
36+
#[default]
37+
A,
38+
B,
39+
C,
40+
}
41+
42+
#[repr(u32)]
43+
#[repr(u16)]
44+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
45+
pub enum TwoRepr {
46+
#[default]
47+
A,
48+
B,
49+
C,
50+
}
51+
52+
#[repr(C)]
53+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
54+
pub enum CRepr {
55+
#[default]
56+
A,
57+
B,
58+
C,
59+
}
60+
61+
#[repr(i32)]
62+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
63+
pub enum NoFrom {
64+
#[default]
65+
A,
66+
B,
67+
C,
68+
}
69+
70+
#[repr(i32)]
71+
#[derive(Copy, Clone, Default, ScalarOrVectorComposite)]
72+
pub enum WrongFrom {
73+
#[default]
74+
A,
75+
B,
76+
C,
77+
}
78+
79+
enum_repr_from!(WrongFrom, u32);
80+
81+
#[repr(i32)]
82+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
83+
pub enum NoDefault {
84+
A,
85+
B,
86+
C,
87+
}
88+
89+
enum_repr_from!(NoDefault, i32);

0 commit comments

Comments
 (0)