Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions crates/spirv-std/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, Ident, TokenTree};
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{GenericParam, Token};

/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
/// `spirv_std::image::Image<...>` type.
Expand Down Expand Up @@ -311,3 +314,54 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
sample_param_permutations::gen_sample_param_permutations(item)
}

#[proc_macro_attribute]
pub fn spirv_vector(attr: TokenStream, item: TokenStream) -> TokenStream {
spirv_vector_impl(attr.into(), item.into())
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

fn spirv_vector_impl(
attr: proc_macro2::TokenStream,
item: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
// Whenever we'll properly resolve the crate symbol, replace this.
let spirv_std = quote!(spirv_std);

// Defer all validation to our codegen backend. Rather than erroring here, emit garbage.
let item = syn::parse2::<syn::ItemStruct>(item)?;
let ident = &item.ident;
let gens = &item.generics.params;
let gen_refs = &item
.generics
.params
.iter()
.map(|p| match p {
GenericParam::Lifetime(p) => p.lifetime.to_token_stream(),
GenericParam::Type(p) => p.ident.to_token_stream(),
GenericParam::Const(p) => p.ident.to_token_stream(),
})
.collect::<Punctuated<_, Token![,]>>();
let where_clause = &item.generics.where_clause;
let element = item
.fields
.iter()
.next()
.ok_or_else(|| syn::Error::new(item.span(), "Vector ZST not allowed"))?
.ty
.to_token_stream();
let count = item.fields.len();

Ok(quote! {
#[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1(#attr))]
#item

unsafe impl<#gens> #spirv_std::ScalarOrVector for #ident<#gen_refs> #where_clause {
type Scalar = #element;
const N: core::num::NonZeroUsize = core::num::NonZeroUsize::new(#count).unwrap();
}

unsafe impl<#gens> #spirv_std::Vector<#element, #count> for #ident<#gen_refs> #where_clause {}
})
}
2 changes: 1 addition & 1 deletion crates/spirv-std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@
/// Public re-export of the `spirv-std-macros` crate.
#[macro_use]
pub extern crate spirv_std_macros as macros;
pub use macros::spirv;
pub use macros::{debug_printf, debug_printfln};
pub use macros::{spirv, spirv_vector};

pub mod arch;
pub mod byte_addressable_buffer;
Expand Down
10 changes: 6 additions & 4 deletions crates/spirv-std/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use glam::{Vec3Swizzles, Vec4Swizzles};

/// Abstract trait representing a SPIR-V vector type.
///
/// To implement this trait, your struct must be marked with:
/// To derive this trait, mark your struct with:
/// ```no_run
/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)]
/// #[spirv_std::spirv_vector]
/// # #[derive(Copy, Clone, Default)]
/// # struct Bla(f32, f32);
/// ```
///
Expand All @@ -32,8 +33,8 @@ use glam::{Vec3Swizzles, Vec4Swizzles};
///
/// # Example
/// ```no_run
/// #[spirv_std::spirv_vector]
/// #[derive(Copy, Clone, Default)]
/// #[cfg_attr(target_arch = "spirv", rust_gpu::vector::v1)]
/// struct MyColor {
/// r: f32,
/// b: f32,
Expand All @@ -44,7 +45,8 @@ use glam::{Vec3Swizzles, Vec4Swizzles};
///
/// # Safety
/// * Must only be implemented on types that the spirv codegen emits as valid `OpTypeVector`. This includes all structs
/// marked with `#[rust_gpu::vector::v1]`, like [`glam`]'s non-SIMD "scalar" vector types.
/// marked with `#[rust_gpu::vector::v1]`, which `#[spirv_std::spirv_vector]` expands into or [`glam`]'s non-SIMD
/// "scalar" vector types use directly.
/// * `VectorOrScalar::DIM == N`, since const equality is behind rustc feature `associated_const_equality`
// Note(@firestar99) I would like to have these two generics be associated types instead. Doesn't make much sense for
// a vector type to implement this interface multiple times with different Scalar types or N, after all.
Expand Down
33 changes: 33 additions & 0 deletions tests/compiletests/ui/glam/invalid_vector_type_macro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// build-fail

use core::num::NonZeroU32;
use spirv_std::glam::Vec2;
use spirv_std::spirv;

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct FewerFields {
_v: f32,
}

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct TooManyFields {
_x: f32,
_y: f32,
_z: f32,
_w: f32,
_v: f32,
}

// wrong member types fails too early

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct DifferentTypes {
_x: f32,
_y: u32,
}

#[spirv(fragment)]
pub fn entry(_: FewerFields, _: TooManyFields, #[spirv(flat)] _: DifferentTypes) {}
20 changes: 20 additions & 0 deletions tests/compiletests/ui/glam/invalid_vector_type_macro.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
error: `#[spirv(vector)]` must have 2, 3 or 4 members
--> $DIR/invalid_vector_type_macro.rs:9:1
|
LL | pub struct FewerFields {
| ^^^^^^^^^^^^^^^^^^^^^^

error: `#[spirv(vector)]` must have 2, 3 or 4 members
--> $DIR/invalid_vector_type_macro.rs:15:1
|
LL | pub struct TooManyFields {
| ^^^^^^^^^^^^^^^^^^^^^^^^

error: `#[spirv(vector)]` member types must all be the same
--> $DIR/invalid_vector_type_macro.rs:27:1
|
LL | pub struct DifferentTypes {
| ^^^^^^^^^^^^^^^^^^^^^^^^^

error: aborting due to 3 previous errors

42 changes: 42 additions & 0 deletions tests/compiletests/ui/glam/invalid_vector_type_macro2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// build-fail
// normalize-stderr-test "\S*/crates/spirv-std/src/" -> "$$SPIRV_STD_SRC/"

use core::num::NonZeroU32;
use spirv_std::glam::Vec2;
use spirv_std::spirv;

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct ZstVector;

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct NotVectorField {
_x: Vec2,
_y: Vec2,
}

#[spirv_std::spirv_vector]
#[derive(Copy, Clone)]
pub struct NotVectorField2 {
_x: NonZeroU32,
_y: NonZeroU32,
}

impl Default for NotVectorField2 {
fn default() -> Self {
Self {
_x: NonZeroU32::new(1).unwrap(),
_y: NonZeroU32::new(1).unwrap(),
}
}
}

#[spirv(fragment)]
pub fn entry(
// workaround to ZST loading
#[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
_: NotVectorField,
#[spirv(flat)] _: NotVectorField2,
) {
}
113 changes: 113 additions & 0 deletions tests/compiletests/ui/glam/invalid_vector_type_macro2.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
error: Vector ZST not allowed
--> $DIR/invalid_vector_type_macro2.rs:9:1
|
LL | / #[derive(Copy, Clone, Default)]
LL | | pub struct ZstVector;
| |_____________________^

error[E0412]: cannot find type `ZstVector` in this scope
--> $DIR/invalid_vector_type_macro2.rs:38:67
|
LL | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
| ^^^^^^^^^ not found in this scope

error: unknown argument to spirv attribute
--> $DIR/invalid_vector_type_macro2.rs:38:13
|
LL | #[spirv(storage_class, descriptor_set = 0, binding = 0)] _: &(ZstVector, i32),
| ^^^^^^^^^^^^^

error[E0277]: the trait bound `Vec2: Scalar` is not satisfied
--> $DIR/invalid_vector_type_macro2.rs:15:9
|
LL | _x: Vec2,
| ^^^^ the trait `Scalar` is not implemented for `Vec2`
|
= help: the following other types implement trait `Scalar`:
bool
f32
f64
i16
i32
i64
i8
u16
and 3 others
note: required by a bound in `spirv_std::ScalarOrVector::Scalar`
--> $SPIRV_STD_SRC/scalar_or_vector.rs:18:18
|
LL | type Scalar: Scalar;
| ^^^^^^ required by this bound in `ScalarOrVector::Scalar`

error[E0277]: the trait bound `Vec2: Scalar` is not satisfied
--> $DIR/invalid_vector_type_macro2.rs:12:1
|
LL | #[spirv_std::spirv_vector]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `Vec2`
|
= help: the following other types implement trait `Scalar`:
bool
f32
f64
i16
i32
i64
i8
u16
and 3 others
note: required by a bound in `Vector`
--> $SPIRV_STD_SRC/vector.rs:56:28
|
LL | pub unsafe trait Vector<T: Scalar, const N: usize>: ScalarOrVector<Scalar = T> {}
| ^^^^^^ required by this bound in `Vector`
= note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `NonZero<u32>: Scalar` is not satisfied
--> $DIR/invalid_vector_type_macro2.rs:22:9
|
LL | _x: NonZeroU32,
| ^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero<u32>`
|
= help: the following other types implement trait `Scalar`:
bool
f32
f64
i16
i32
i64
i8
u16
and 3 others
note: required by a bound in `spirv_std::ScalarOrVector::Scalar`
--> $SPIRV_STD_SRC/scalar_or_vector.rs:18:18
|
LL | type Scalar: Scalar;
| ^^^^^^ required by this bound in `ScalarOrVector::Scalar`

error[E0277]: the trait bound `NonZero<u32>: Scalar` is not satisfied
--> $DIR/invalid_vector_type_macro2.rs:19:1
|
LL | #[spirv_std::spirv_vector]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Scalar` is not implemented for `NonZero<u32>`
|
= help: the following other types implement trait `Scalar`:
bool
f32
f64
i16
i32
i64
i8
u16
and 3 others
note: required by a bound in `Vector`
--> $SPIRV_STD_SRC/vector.rs:56:28
|
LL | pub unsafe trait Vector<T: Scalar, const N: usize>: ScalarOrVector<Scalar = T> {}
| ^^^^^^ required by this bound in `Vector`
= note: this error originates in the attribute macro `spirv_std::spirv_vector` (in Nightly builds, run with -Z macro-backtrace for more info)

error: aborting due to 7 previous errors

Some errors have detailed explanations: E0277, E0412.
For more information about an error, try `rustc --explain E0277`.
33 changes: 33 additions & 0 deletions tests/compiletests/ui/glam/spirv_vector_macro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// build-pass
// only-vulkan1.2
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformShuffleRelative,+ext:SPV_KHR_vulkan_memory_model
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""

use spirv_std::arch::subgroup_shuffle_up;
use spirv_std::glam::Vec3;
use spirv_std::spirv;

#[spirv_std::spirv_vector]
#[derive(Copy, Clone, Default)]
pub struct MyColor {
pub r: f32,
pub g: f32,
pub b: f32,
}

#[spirv(compute(threads(32)))]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &Vec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut MyColor,
) {
let color = MyColor {
r: input.x,
g: input.y,
b: input.z,
};
// any function that accepts a `VectorOrScalar` would do
*output = subgroup_shuffle_up(color, 5);
}
Loading
Loading