Skip to content

Commit ce822f1

Browse files
committed
scalar or vector: add #[derive(ScalarOrVectorComposite)]
1 parent 3455bb3 commit ce822f1

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
mod debug_printf;
7575
mod image;
7676
mod sample_param_permutations;
77+
mod scalar_or_vector_composite;
7778

7879
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
7980
use proc_macro::TokenStream;
@@ -309,3 +310,10 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
309310
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
310311
sample_param_permutations::gen_sample_param_permutations(item)
311312
}
313+
314+
#[proc_macro_derive(ScalarOrVectorComposite)]
315+
pub fn derive_scalar_or_vector_composite(item: TokenStream) -> TokenStream {
316+
scalar_or_vector_composite::derive(item.into())
317+
.unwrap_or_else(syn::Error::into_compile_error)
318+
.into()
319+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use proc_macro2::TokenStream;
2+
use quote::{ToTokens, quote};
3+
use syn::punctuated::Punctuated;
4+
use syn::{Fields, FieldsNamed, FieldsUnnamed, GenericParam, Token};
5+
6+
pub fn derive(item: TokenStream) -> syn::Result<TokenStream> {
7+
// Whenever we'll properly resolve the crate symbol, replace this.
8+
let spirv_std = quote!(spirv_std);
9+
10+
// Defer all validation to our codegen backend. Rather than erroring here, emit garbage.
11+
let item = syn::parse2::<syn::ItemStruct>(item)?;
12+
let struct_ident = &item.ident;
13+
let gens = &item.generics.params;
14+
let gen_refs = &item
15+
.generics
16+
.params
17+
.iter()
18+
.map(|p| match p {
19+
GenericParam::Lifetime(p) => p.lifetime.to_token_stream(),
20+
GenericParam::Type(p) => p.ident.to_token_stream(),
21+
GenericParam::Const(p) => p.ident.to_token_stream(),
22+
})
23+
.collect::<Punctuated<_, Token![,]>>();
24+
let where_clause = &item.generics.where_clause;
25+
26+
let content =
27+
match item.fields {
28+
Fields::Named(FieldsNamed { named, .. }) => {
29+
let content = named.iter().map(|f| {
30+
let ident = &f.ident;
31+
quote!(#ident: #spirv_std::ScalarOrVectorComposite::transform(self.#ident, f))
32+
}).collect::<Punctuated<_, Token![,]>>();
33+
quote!(Self { #content })
34+
}
35+
Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
36+
let content = (0..unnamed.len())
37+
.map(|i| {
38+
let i = syn::Index::from(i);
39+
quote!(#spirv_std::ScalarOrVectorComposite::transform(self.#i, f))
40+
})
41+
.collect::<Punctuated<_, Token![,]>>();
42+
quote!(Self(#content))
43+
}
44+
Fields::Unit => quote!(Self),
45+
};
46+
47+
Ok(quote! {
48+
impl<#gens> #spirv_std::ScalarOrVectorComposite for #struct_ident<#gen_refs> #where_clause {
49+
fn transform<F: #spirv_std::ScalarOrVectorTransform>(self, f: &mut F) -> Self {
50+
#content
51+
}
52+
}
53+
})
54+
}

crates/spirv-std/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
/// Public re-export of the `spirv-std-macros` crate.
8888
#[macro_use]
8989
pub extern crate spirv_std_macros as macros;
90+
pub use macros::ScalarOrVectorComposite;
9091
pub use macros::spirv;
9192

9293
pub mod arch;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+GroupNonUniformShuffle,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model
3+
// normalize-stderr-test "OpLine .*\n" -> ""
4+
5+
use glam::*;
6+
use spirv_std::ScalarOrVectorComposite;
7+
use spirv_std::arch::*;
8+
use spirv_std::spirv;
9+
10+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
11+
pub struct MyStruct {
12+
a: f32,
13+
b: UVec3,
14+
c: Nested,
15+
d: Zst,
16+
}
17+
18+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
19+
pub struct Nested(i32);
20+
21+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
22+
pub struct Zst;
23+
24+
#[spirv(compute(threads(32)))]
25+
pub fn main(
26+
#[spirv(local_invocation_index)] inv_id: UVec3,
27+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut UVec3,
28+
) {
29+
unsafe {
30+
let my_struct = MyStruct {
31+
a: 1.,
32+
b: inv_id,
33+
c: Nested(-42),
34+
d: Zst,
35+
};
36+
37+
let mut out = UVec3::ZERO;
38+
out += subgroup_broadcast(my_struct, 19).b;
39+
out += subgroup_broadcast_first(my_struct).b;
40+
out += subgroup_shuffle(my_struct, 2).b;
41+
out += subgroup_shuffle_xor(my_struct, 4).b;
42+
out += subgroup_shuffle_up(my_struct, 5).b;
43+
out += subgroup_shuffle_down(my_struct, 7).b;
44+
*output = out;
45+
}
46+
}

0 commit comments

Comments
 (0)