diff --git a/crates/spirv-std/macros/src/lib.rs b/crates/spirv-std/macros/src/lib.rs index d8ecf7b0cf..8e13cd1258 100644 --- a/crates/spirv-std/macros/src/lib.rs +++ b/crates/spirv-std/macros/src/lib.rs @@ -80,6 +80,9 @@ use proc_macro::TokenStream; use proc_macro2::{Delimiter, Group, Ident, TokenTree}; use quote::{ToTokens, TokenStreamExt, format_ident, quote}; use spirv_std_types::spirv_attr_version::spirv_attr_with_version; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{GenericParam, Token}; /// A macro for creating SPIR-V `OpTypeImage` types. Always produces a /// `spirv_std::image::Image<...>` type. @@ -311,3 +314,54 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream { pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream { sample_param_permutations::gen_sample_param_permutations(item) } + +#[proc_macro_attribute] +pub fn spirv_vector(attr: TokenStream, item: TokenStream) -> TokenStream { + spirv_vector_impl(attr.into(), item.into()) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn spirv_vector_impl( + attr: proc_macro2::TokenStream, + item: proc_macro2::TokenStream, +) -> syn::Result { + // Whenever we'll properly resolve the crate symbol, replace this. + let spirv_std = quote!(spirv_std); + + // Defer all validation to our codegen backend. Rather than erroring here, emit garbage. + let item = syn::parse2::(item)?; + let ident = &item.ident; + let gens = &item.generics.params; + let gen_refs = &item + .generics + .params + .iter() + .map(|p| match p { + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(p) => p.ident.to_token_stream(), + GenericParam::Const(p) => p.ident.to_token_stream(), + }) + .collect::>(); + let where_clause = &item.generics.where_clause; + let element = item + .fields + .iter() + .next() + .ok_or_else(|| syn::Error::new(item.span(), "Vector ZST not allowed"))? + .ty + .to_token_stream(); + let count = item.fields.len(); + + Ok(quote! { + #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1(#attr))] + #item + + unsafe impl<#gens> #spirv_std::ScalarOrVector for #ident<#gen_refs> #where_clause { + type Scalar = #element; + const N: core::num::NonZeroUsize = core::num::NonZeroUsize::new(#count).unwrap(); + } + + unsafe impl<#gens> #spirv_std::Vector<#element, #count> for #ident<#gen_refs> #where_clause {} + }) +} diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 2c85dc9af0..13bd406cd1 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -87,8 +87,8 @@ /// Public re-export of the `spirv-std-macros` crate. #[macro_use] pub extern crate spirv_std_macros as macros; -pub use macros::spirv; pub use macros::{debug_printf, debug_printfln}; +pub use macros::{spirv, spirv_vector}; pub mod arch; pub mod byte_addressable_buffer; diff --git a/crates/spirv-std/src/vector.rs b/crates/spirv-std/src/vector.rs index 0389424df0..471d918b22 100644 --- a/crates/spirv-std/src/vector.rs +++ b/crates/spirv-std/src/vector.rs @@ -7,9 +7,10 @@ use glam::{Vec3Swizzles, Vec4Swizzles}; /// Abstract trait representing a SPIR-V vector type. /// -/// To implement this trait, your struct must be marked with: +/// To derive this trait, mark your struct with: /// ```no_run -/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)] +/// #[spirv_std::spirv_vector] +/// # #[derive(Copy, Clone, Default)] /// # struct Bla(f32, f32); /// ``` /// @@ -32,8 +33,8 @@ use glam::{Vec3Swizzles, Vec4Swizzles}; /// /// # Example /// ```no_run +/// #[spirv_std::spirv_vector] /// #[derive(Copy, Clone, Default)] -/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)] /// struct MyColor { /// r: f32, /// b: f32, @@ -44,7 +45,8 @@ use glam::{Vec3Swizzles, Vec4Swizzles}; /// /// # Safety /// * Must only be implemented on types that the spirv codegen emits as valid `OpTypeVector`. This includes all structs -/// marked with `#[rust_gpu::vector::v1]`, like [`glam`]'s non-SIMD "scalar" vector types. +/// marked with `#[rust_gpu::vector::v1]`, which `#[spirv_std::spirv_vector]` expands into or [`glam`]'s non-SIMD +/// "scalar" vector types use directly. /// * `VectorOrScalar::DIM == N`, since const equality is behind rustc feature `associated_const_equality` // Note(@firestar99) I would like to have these two generics be associated types instead. Doesn't make much sense for // a vector type to implement this interface multiple times with different Scalar types or N, after all. diff --git a/tests/compiletests/ui/glam/invalid_vector_type_macro.rs b/tests/compiletests/ui/glam/invalid_vector_type_macro.rs new file mode 100644 index 0000000000..6492f90932 --- /dev/null +++ b/tests/compiletests/ui/glam/invalid_vector_type_macro.rs @@ -0,0 +1,33 @@ +// build-fail + +use core::num::NonZeroU32; +use spirv_std::glam::Vec2; +use spirv_std::spirv; + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct FewerFields { + _v: f32, +} + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct TooManyFields { + _x: f32, + _y: f32, + _z: f32, + _w: f32, + _v: f32, +} + +// wrong member types fails too early + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct DifferentTypes { + _x: f32, + _y: u32, +} + +#[spirv(fragment)] +pub fn entry(_: FewerFields, _: TooManyFields, #[spirv(flat)] _: DifferentTypes) {} diff --git a/tests/compiletests/ui/glam/invalid_vector_type_macro.stderr b/tests/compiletests/ui/glam/invalid_vector_type_macro.stderr new file mode 100644 index 0000000000..1cefd3c0bd --- /dev/null +++ b/tests/compiletests/ui/glam/invalid_vector_type_macro.stderr @@ -0,0 +1,20 @@ +error: `#[spirv(vector)]` must have 2, 3 or 4 members + --> $DIR/invalid_vector_type_macro.rs:9:1 + | +LL | pub struct FewerFields { + | ^^^^^^^^^^^^^^^^^^^^^^ + +error: `#[spirv(vector)]` must have 2, 3 or 4 members + --> $DIR/invalid_vector_type_macro.rs:15:1 + | +LL | pub struct TooManyFields { + | ^^^^^^^^^^^^^^^^^^^^^^^^ + +error: `#[spirv(vector)]` member types must all be the same + --> $DIR/invalid_vector_type_macro.rs:27:1 + | +LL | pub struct DifferentTypes { + | ^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: aborting due to 3 previous errors + diff --git a/tests/compiletests/ui/glam/invalid_vector_type_macro2.rs b/tests/compiletests/ui/glam/invalid_vector_type_macro2.rs new file mode 100644 index 0000000000..8a122f040d --- /dev/null +++ b/tests/compiletests/ui/glam/invalid_vector_type_macro2.rs @@ -0,0 +1,42 @@ +// build-fail +// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/" + +use core::num::NonZeroU32; +use spirv_std::glam::Vec2; +use spirv_std::spirv; + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct ZstVector; + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct NotVectorField { + _x: Vec2, + _y: Vec2, +} + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone)] +pub struct NotVectorField2 { + _x: NonZeroU32, + _y: NonZeroU32, +} + +impl Default for NotVectorField2 { + fn default() -> Self { + Self { + _x: NonZeroU32::new(1).unwrap(), + _y: NonZeroU32::new(1).unwrap(), + } + } +} + +#[spirv(fragment)] +pub fn entry( + // workaround to ZST loading + #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32), + _: NotVectorField, + #[spirv(flat)] _: NotVectorField2, +) { +} diff --git a/tests/compiletests/ui/glam/invalid_vector_type_macro2.stderr b/tests/compiletests/ui/glam/invalid_vector_type_macro2.stderr new file mode 100644 index 0000000000..73ce18d9ae --- /dev/null +++ b/tests/compiletests/ui/glam/invalid_vector_type_macro2.stderr @@ -0,0 +1,113 @@ +error: Vector ZST not allowed + --> $DIR/invalid_vector_type_macro2.rs:9:1 + | +LL | / #[derive(Copy, Clone, Default)] +LL | | pub struct ZstVector; + | |_____________________^ + +error[E0412]: cannot find type `ZstVector` in this scope + --> $DIR/invalid_vector_type_macro2.rs:38:67 + | +LL | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32), + | ^^^^^^^^^ not found in this scope + +error: unknown argument to spirv attribute + --> $DIR/invalid_vector_type_macro2.rs:38:13 + | +LL | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32), + | ^^^^^^^^^^^^^ + +error[E0277]: the trait bound `Vec2: Scalar` is not satisfied + --> $DIR/invalid_vector_type_macro2.rs:15:9 + | +LL | _x: Vec2, + | ^^^^ the trait `Scalar` is not implemented for `Vec2` + | + = help: the following other types implement trait `Scalar`: + bool + f32 + f64 + i16 + i32 + i64 + i8 + u16 + and 3 others +note: required by a bound in `spirv_std::ScalarOrVector::Scalar` + --> $SPIRV_STD_SRC/scalar_or_vector.rs:18:18 + | +LL | type Scalar: Scalar; + | ^^^^^^ required by this bound in `ScalarOrVector::Scalar` + +error[E0277]: the trait bound `Vec2: Scalar` is not satisfied + --> $DIR/invalid_vector_type_macro2.rs:12:1 + | +LL | #[spirv_std::spirv_vector] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `Vec2` + | + = help: the following other types implement trait `Scalar`: + bool + f32 + f64 + i16 + i32 + i64 + i8 + u16 + and 3 others +note: required by a bound in `Vector` + --> $SPIRV_STD_SRC/vector.rs:56:28 + | +LL | pub unsafe trait Vector: ScalarOrVector {} + | ^^^^^^ required by this bound in `Vector` + = note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `NonZero: Scalar` is not satisfied + --> $DIR/invalid_vector_type_macro2.rs:22:9 + | +LL | _x: NonZeroU32, + | ^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero` + | + = help: the following other types implement trait `Scalar`: + bool + f32 + f64 + i16 + i32 + i64 + i8 + u16 + and 3 others +note: required by a bound in `spirv_std::ScalarOrVector::Scalar` + --> $SPIRV_STD_SRC/scalar_or_vector.rs:18:18 + | +LL | type Scalar: Scalar; + | ^^^^^^ required by this bound in `ScalarOrVector::Scalar` + +error[E0277]: the trait bound `NonZero: Scalar` is not satisfied + --> $DIR/invalid_vector_type_macro2.rs:19:1 + | +LL | #[spirv_std::spirv_vector] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero` + | + = help: the following other types implement trait `Scalar`: + bool + f32 + f64 + i16 + i32 + i64 + i8 + u16 + and 3 others +note: required by a bound in `Vector` + --> $SPIRV_STD_SRC/vector.rs:56:28 + | +LL | pub unsafe trait Vector: ScalarOrVector {} + | ^^^^^^ required by this bound in `Vector` + = note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: aborting due to 7 previous errors + +Some errors have detailed explanations: E0277, E0412. +For more information about an error, try `rustc --explain E0277`. diff --git a/tests/compiletests/ui/glam/spirv_vector_macro.rs b/tests/compiletests/ui/glam/spirv_vector_macro.rs new file mode 100644 index 0000000000..90a118b3aa --- /dev/null +++ b/tests/compiletests/ui/glam/spirv_vector_macro.rs @@ -0,0 +1,33 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" + +use spirv_std::arch::subgroup_shuffle_up; +use spirv_std::glam::Vec3; +use spirv_std::spirv; + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct MyColor { + pub r: f32, + pub g: f32, + pub b: f32, +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &Vec3, + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut MyColor, +) { + let color = MyColor { + r: input.x, + g: input.y, + b: input.z, + }; + // any function that accepts a `VectorOrScalar` would do + *output = subgroup_shuffle_up(color, 5); +} diff --git a/tests/compiletests/ui/glam/spirv_vector_macro.stderr b/tests/compiletests/ui/glam/spirv_vector_macro.stderr new file mode 100644 index 0000000000..7bd692745a --- /dev/null +++ b/tests/compiletests/ui/glam/spirv_vector_macro.stderr @@ -0,0 +1,53 @@ +; SPIR-V +; Version: 1.5 +; Generator: rspirv +; Bound: 31 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformShuffleRelative +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 32 1 1 +OpName %2 "input" +OpName %3 "output" +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 NonWritable +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 Binding 1 +OpDecorate %3 DescriptorSet 0 +%7 = OpTypeFloat 32 +%8 = OpTypeVector %7 3 +%6 = OpTypeStruct %8 +%9 = OpTypePointer StorageBuffer %6 +%10 = OpTypeVoid +%11 = OpTypeFunction %10 +%12 = OpTypePointer StorageBuffer %8 +%2 = OpVariable %9 StorageBuffer +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 0 +%3 = OpVariable %9 StorageBuffer +%15 = OpTypePointer StorageBuffer %7 +%16 = OpConstant %13 1 +%17 = OpConstant %13 2 +%18 = OpConstant %13 3 +%19 = OpConstant %13 5 +%1 = OpFunction %10 None %11 +%20 = OpLabel +%21 = OpInBoundsAccessChain %12 %2 %14 +%22 = OpInBoundsAccessChain %12 %3 %14 +%23 = OpInBoundsAccessChain %15 %21 %14 +%24 = OpLoad %7 %23 +%25 = OpInBoundsAccessChain %15 %21 %16 +%26 = OpLoad %7 %25 +%27 = OpInBoundsAccessChain %15 %21 %17 +%28 = OpLoad %7 %27 +%29 = OpCompositeConstruct %8 %24 %26 %28 +%30 = OpGroupNonUniformShuffleUp %8 %18 %29 %19 +OpStore %22 %30 +OpNoLine +OpReturn +OpFunctionEnd diff --git a/tests/compiletests/ui/glam/spirv_vector_macro_generic.rs b/tests/compiletests/ui/glam/spirv_vector_macro_generic.rs new file mode 100644 index 0000000000..6f03091d98 --- /dev/null +++ b/tests/compiletests/ui/glam/spirv_vector_macro_generic.rs @@ -0,0 +1,34 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" + +use spirv_std::Scalar; +use spirv_std::arch::subgroup_shuffle_up; +use spirv_std::glam::Vec3; +use spirv_std::spirv; + +#[spirv_std::spirv_vector] +#[derive(Copy, Clone, Default)] +pub struct Vec { + pub x: T, + pub y: T, + pub z: T, +} + +#[spirv(compute(threads(32)))] +pub fn main( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &Vec, + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut Vec, +) { + let vec = Vec { + x: input.x as f32, + y: input.y as f32, + z: input.z as f32, + }; + // any function that accepts a `VectorOrScalar` would do + *output = subgroup_shuffle_up(vec, 5); +} diff --git a/tests/compiletests/ui/glam/spirv_vector_macro_generic.stderr b/tests/compiletests/ui/glam/spirv_vector_macro_generic.stderr new file mode 100644 index 0000000000..7f3ed8d9f1 --- /dev/null +++ b/tests/compiletests/ui/glam/spirv_vector_macro_generic.stderr @@ -0,0 +1,63 @@ +; SPIR-V +; Version: 1.5 +; Generator: rspirv +; Bound: 39 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformShuffleRelative +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main" %2 %3 +OpExecutionMode %1 LocalSize 32 1 1 +OpName %2 "input" +OpName %3 "output" +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %7 Block +OpMemberDecorate %7 0 Offset 0 +OpDecorate %2 NonWritable +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 Binding 1 +OpDecorate %3 DescriptorSet 0 +%8 = OpTypeInt 32 1 +%9 = OpTypeVector %8 3 +%6 = OpTypeStruct %9 +%10 = OpTypePointer StorageBuffer %6 +%11 = OpTypeFloat 32 +%12 = OpTypeVector %11 3 +%7 = OpTypeStruct %12 +%13 = OpTypePointer StorageBuffer %7 +%14 = OpTypeVoid +%15 = OpTypeFunction %14 +%16 = OpTypePointer StorageBuffer %9 +%2 = OpVariable %10 StorageBuffer +%17 = OpTypeInt 32 0 +%18 = OpConstant %17 0 +%19 = OpTypePointer StorageBuffer %12 +%3 = OpVariable %13 StorageBuffer +%20 = OpTypePointer StorageBuffer %8 +%21 = OpConstant %17 1 +%22 = OpConstant %17 2 +%23 = OpConstant %17 3 +%24 = OpConstant %17 5 +%1 = OpFunction %14 None %15 +%25 = OpLabel +%26 = OpInBoundsAccessChain %16 %2 %18 +%27 = OpInBoundsAccessChain %19 %3 %18 +%28 = OpInBoundsAccessChain %20 %26 %18 +%29 = OpLoad %8 %28 +%30 = OpConvertSToF %11 %29 +%31 = OpInBoundsAccessChain %20 %26 %21 +%32 = OpLoad %8 %31 +%33 = OpConvertSToF %11 %32 +%34 = OpInBoundsAccessChain %20 %26 %22 +%35 = OpLoad %8 %34 +%36 = OpConvertSToF %11 %35 +%37 = OpCompositeConstruct %12 %30 %33 %36 +%38 = OpGroupNonUniformShuffleUp %12 %23 %37 %24 +OpStore %27 %38 +OpNoLine +OpReturn +OpFunctionEnd