diff --git a/CHANGELOG.md b/CHANGELOG.md index 880c8cf73a2..d216c9be2c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,12 @@ Bottom level categories: - Expanded documentation of `QuerySet`, `QueryType`, and `resolve_query_set()` describing how to use queries. By @kpreid in [#8776](https://github.com/gfx-rs/wgpu/pull/8776). +### Changes + +#### Naga + +- Prevent UB from incorrectly using ray queries on HLSL. By @Vecvec in [#8763](https://github.com/gfx-rs/wgpu/pull/8763). + ## v28.0.0 (2025-12-17) ### Major Changes diff --git a/naga/src/back/hlsl/keywords.rs b/naga/src/back/hlsl/keywords.rs index 13f48cef8b5..a5ae4085af5 100644 --- a/naga/src/back/hlsl/keywords.rs +++ b/naga/src/back/hlsl/keywords.rs @@ -935,4 +935,6 @@ pub static RESERVED_CASE_INSENSITIVE_SET: RacyLock = pub const RESERVED_PREFIXES: &[&str] = &[ "__dynamic_buffer_offsets", super::help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER, + super::writer::RAY_QUERY_TRACKER_VARIABLE_PREFIX, + super::writer::INTERNAL_PREFIX, ]; diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 0230434a742..799beae6fcf 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -543,6 +543,9 @@ pub struct Options { /// If set, loops will have code injected into them, forcing the compiler /// to think the number of iterations is bounded. pub force_loop_bounding: bool, + /// If set, ray queries will get a variable to track their state to prevent + /// misuse. 
+ pub ray_query_initialization_tracking: bool, } impl Default for Options { @@ -560,6 +563,7 @@ impl Default for Options { zero_initialize_workgroup_memory: true, restrict_indexing: true, force_loop_bounding: true, + ray_query_initialization_tracking: true, } } } diff --git a/naga/src/back/hlsl/ray.rs b/naga/src/back/hlsl/ray.rs index 24eae4b573f..0d69756362d 100644 --- a/naga/src/back/hlsl/ray.rs +++ b/naga/src/back/hlsl/ray.rs @@ -1,9 +1,35 @@ +use alloc::{ + format, + string::{String, ToString}, + vec, + vec::Vec, +}; use core::fmt::Write; -use crate::back::hlsl::BackendResult; +use crate::{ + back::{hlsl::BackendResult, Baked, Level}, + Handle, +}; use crate::{RayQueryIntersection, TypeInner}; impl super::Writer<'_, W> { + // https://sakibsaikia.github.io/graphics/2022/01/04/Nan-Checks-In-HLSL.html suggests that isnan may not work, unsure if this has changed. + fn write_not_finite(&mut self, expr: &str) -> BackendResult { + self.write_contains_flags(&format!("asuint({expr})"), 0x7f800000) + } + + fn write_nan(&mut self, expr: &str) -> BackendResult { + write!(self.out, "(")?; + self.write_not_finite(expr)?; + write!(self.out, " && ((asuint({expr}) & 0x7fffff) != 0))")?; + Ok(()) + } + + fn write_contains_flags(&mut self, expr: &str, flags: u32) -> BackendResult { + write!(self.out, "(({expr} & {flags}) == {flags})")?; + Ok(()) + } + // constructs hlsl RayDesc from wgsl RayDesc pub(super) fn write_ray_desc_from_ray_desc_constructor_function( &mut self, @@ -34,60 +60,81 @@ impl super::Writer<'_, W> { vertex_return: false, }, )?; - writeln!(self.out, " rq) {{")?; + write!(self.out, " rq, ")?; + self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?; + writeln!(self.out, " rq_tracker) {{")?; write!(self.out, " ")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " ret = (")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; writeln!(self.out, ")0;")?; - writeln!(self.out, " ret.kind 
= rq.CommittedStatus();")?; + let mut extra_level = Level(0); + if self.options.ray_query_initialization_tracking { + // *Technically*, `CommittedStatus` is valid as long as the ray query is initialized, but the metal backend + // doesn't support this function unless it has finished traversal, so to encourage portable behaviour we + // disallow it here too. + write!(self.out, " if (")?; + self.write_contains_flags( + "rq_tracker", + crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + )?; + writeln!(self.out, ") {{")?; + extra_level = extra_level.next(); + } writeln!( self.out, - " if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{" + " {extra_level}ret.kind = rq.CommittedStatus();" )?; - writeln!(self.out, " ret.t = rq.CommittedRayT();")?; writeln!( self.out, - " ret.instance_custom_data = rq.CommittedInstanceID();" + " {extra_level}if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{" )?; + writeln!(self.out, " {extra_level}ret.t = rq.CommittedRayT();")?; writeln!( self.out, - " ret.instance_index = rq.CommittedInstanceIndex();" + " {extra_level}ret.instance_custom_data = rq.CommittedInstanceID();" )?; writeln!( self.out, - " ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();" + " {extra_level}ret.instance_index = rq.CommittedInstanceIndex();" )?; writeln!( self.out, - " ret.geometry_index = rq.CommittedGeometryIndex();" + " {extra_level}ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();" )?; writeln!( self.out, - " ret.primitive_index = rq.CommittedPrimitiveIndex();" + " {extra_level}ret.geometry_index = rq.CommittedGeometryIndex();" )?; writeln!( self.out, - " if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{" + " {extra_level}ret.primitive_index = rq.CommittedPrimitiveIndex();" )?; writeln!( self.out, - " ret.barycentrics = rq.CommittedTriangleBarycentrics();" + " {extra_level}if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{" )?; writeln!( self.out, - " ret.front_face = 
rq.CommittedTriangleFrontFace();" + " {extra_level}ret.barycentrics = rq.CommittedTriangleBarycentrics();" )?; - writeln!(self.out, " }}")?; writeln!( self.out, - " ret.object_to_world = rq.CommittedObjectToWorld4x3();" + " {extra_level}ret.front_face = rq.CommittedTriangleFrontFace();" )?; + writeln!(self.out, " {extra_level}}}")?; writeln!( self.out, - " ret.world_to_object = rq.CommittedWorldToObject4x3();" + " {extra_level}ret.object_to_world = rq.CommittedObjectToWorld4x3();" )?; - writeln!(self.out, " }}")?; + writeln!( + self.out, + " {extra_level}ret.world_to_object = rq.CommittedWorldToObject4x3();" + )?; + writeln!(self.out, " {extra_level}}}")?; + if self.options.ray_query_initialization_tracking { + writeln!(self.out, " }}")?; + } writeln!(self.out, " return ret;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; @@ -105,70 +152,413 @@ impl super::Writer<'_, W> { vertex_return: false, }, )?; - writeln!(self.out, " rq) {{")?; + write!(self.out, " rq, ")?; + self.write_value_type(module, &TypeInner::Scalar(crate::Scalar::U32))?; + writeln!(self.out, " rq_tracker) {{")?; write!(self.out, " ")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; write!(self.out, " ret = (")?; self.write_type(module, module.special_types.ray_intersection.unwrap())?; writeln!(self.out, ")0;")?; - writeln!(self.out, " CANDIDATE_TYPE kind = rq.CandidateType();")?; + let mut extra_level = Level(0); + if self.options.ray_query_initialization_tracking { + write!(self.out, " if (")?; + self.write_contains_flags("rq_tracker", crate::back::RayQueryPoint::PROCEED.bits())?; + write!(self.out, " && !")?; + self.write_contains_flags( + "rq_tracker", + crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + )?; + writeln!(self.out, ") {{")?; + extra_level = extra_level.next(); + } + writeln!( + self.out, + " {extra_level}CANDIDATE_TYPE kind = rq.CandidateType();" + )?; writeln!( self.out, - " if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{" + " {extra_level}if (kind 
== CANDIDATE_NON_OPAQUE_TRIANGLE) {{" )?; writeln!( self.out, - " ret.kind = {};", + " {extra_level}ret.kind = {};", RayQueryIntersection::Triangle as u32 )?; - writeln!(self.out, " ret.t = rq.CandidateTriangleRayT();")?; writeln!( self.out, - " ret.barycentrics = rq.CandidateTriangleBarycentrics();" + " {extra_level}ret.t = rq.CandidateTriangleRayT();" + )?; + writeln!( + self.out, + " {extra_level}ret.barycentrics = rq.CandidateTriangleBarycentrics();" )?; writeln!( self.out, - " ret.front_face = rq.CandidateTriangleFrontFace();" + " {extra_level}ret.front_face = rq.CandidateTriangleFrontFace();" )?; - writeln!(self.out, " }} else {{")?; + writeln!(self.out, " {extra_level}}} else {{")?; writeln!( self.out, - " ret.kind = {};", + " {extra_level}ret.kind = {};", RayQueryIntersection::Aabb as u32 )?; - writeln!(self.out, " }}")?; + writeln!(self.out, " {extra_level}}}")?; writeln!( self.out, - " ret.instance_custom_data = rq.CandidateInstanceID();" + " {extra_level}ret.instance_custom_data = rq.CandidateInstanceID();" )?; writeln!( self.out, - " ret.instance_index = rq.CandidateInstanceIndex();" + " {extra_level}ret.instance_index = rq.CandidateInstanceIndex();" )?; writeln!( self.out, - " ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();" + " {extra_level}ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();" )?; writeln!( self.out, - " ret.geometry_index = rq.CandidateGeometryIndex();" + " {extra_level}ret.geometry_index = rq.CandidateGeometryIndex();" )?; writeln!( self.out, - " ret.primitive_index = rq.CandidatePrimitiveIndex();" + " {extra_level}ret.primitive_index = rq.CandidatePrimitiveIndex();" )?; writeln!( self.out, - " ret.object_to_world = rq.CandidateObjectToWorld4x3();" + " {extra_level}ret.object_to_world = rq.CandidateObjectToWorld4x3();" )?; writeln!( self.out, - " ret.world_to_object = rq.CandidateWorldToObject4x3();" + " {extra_level}ret.world_to_object = rq.CandidateWorldToObject4x3();" )?; + if 
self.options.ray_query_initialization_tracking { + writeln!(self.out, " }}")?; + } writeln!(self.out, " return ret;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; Ok(()) } + + #[expect(clippy::too_many_arguments)] + pub(super) fn write_initialize_function( + &mut self, + module: &crate::Module, + mut level: Level, + query: Handle, + acceleration_structure: Handle, + descriptor: Handle, + rq_tracker: &str, + func_ctx: &crate::back::FunctionCtx<'_>, + ) -> BackendResult { + let base_level = level; + + // This prevents variables flowing down a level and causing compile errors. + writeln!(self.out, "{level}{{")?; + level = level.next(); + write!(self.out, "{level}")?; + self.write_type( + module, + module + .special_types + .ray_desc + .expect("should have been generated"), + )?; + write!(self.out, " naga_desc = ")?; + self.write_expr(module, descriptor, func_ctx)?; + writeln!(self.out, ";")?; + + if self.options.ray_query_initialization_tracking { + // Validate ray extents https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#ray-extents + + // just for convenience + writeln!(self.out, "{level}float naga_tmin = naga_desc.tmin;")?; + writeln!(self.out, "{level}float naga_tmax = naga_desc.tmax;")?; + writeln!(self.out, "{level}float3 naga_origin = naga_desc.origin;")?; + writeln!(self.out, "{level}float3 naga_dir = naga_desc.dir;")?; + writeln!(self.out, "{level}uint naga_flags = naga_desc.flags;")?; + write!( + self.out, + "{level}bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !" 
+ )?; + self.write_nan("naga_tmin")?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_tmax_valid = !")?; + self.write_nan("naga_tmax")?; + writeln!(self.out, ";")?; + // Unlike Vulkan, DX12 seems to treat only NaN components of the origin and direction as invalid + write!(self.out, "{level}bool naga_origin_valid = !any(")?; + self.write_nan("naga_origin")?; + writeln!(self.out, ");")?; + write!(self.out, "{level}bool naga_dir_valid = !any(")?; + self.write_nan("naga_dir")?; + writeln!(self.out, ");")?; + write!(self.out, "{level}bool naga_contains_opaque = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_OPAQUE.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_no_opaque = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::FORCE_NO_OPAQUE.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_cull_opaque = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::CULL_OPAQUE.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_cull_no_opaque = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::CULL_NO_OPAQUE.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_cull_front = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::CULL_FRONT_FACING.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_cull_back = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::CULL_BACK_FACING.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_skip_triangles = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_TRIANGLES.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_contains_skip_aabbs = ")?; + self.write_contains_flags("naga_flags", crate::RayFlag::SKIP_AABBS.bits())?; + writeln!(self.out, ";")?; + // A textified version of the same in the spirv writer + fn 
less_than_two_true(mut bools: Vec<&str>) -> Result { + assert!(bools.len() > 1, "Must have multiple booleans!"); + let mut final_expr = String::new(); + while let Some(last_bool) = bools.pop() { + for &bool in &bools { + if !final_expr.is_empty() { + final_expr.push_str("||"); + } + write!(final_expr, " ({last_bool} && {bool}) ")?; + } + } + Ok(final_expr) + } + writeln!( + self.out, + "{level}bool naga_contains_skip_triangles_aabbs = {};", + less_than_two_true(vec![ + "naga_contains_skip_triangles", + "naga_contains_skip_aabbs" + ])? + )?; + writeln!( + self.out, + "{level}bool naga_contains_skip_triangles_cull = {};", + less_than_two_true(vec![ + "naga_contains_skip_triangles", + "naga_contains_cull_back", + "naga_contains_cull_front" + ])? + )?; + writeln!( + self.out, + "{level}bool naga_contains_multiple_opaque = {};", + less_than_two_true(vec![ + "naga_contains_opaque", + "naga_contains_no_opaque", + "naga_contains_cull_opaque", + "naga_contains_cull_no_opaque" + ])? + )?; + writeln!( + self.out, + "{level}if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) {{" + )?; + level = level.next(); + writeln!( + self.out, + "{level}{rq_tracker} = {rq_tracker} | {};", + crate::back::RayQueryPoint::INITIALIZED.bits() + )?; + } + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + write!(self.out, ".TraceRayInline(")?; + self.write_expr(module, acceleration_structure, func_ctx)?; + writeln!( + self.out, + ", naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc));" + )?; + if self.options.ray_query_initialization_tracking { + writeln!(self.out, "{base_level} }}")?; + } + writeln!(self.out, "{base_level}}}")?; + Ok(()) + } + + pub(super) fn write_proceed( + &mut self, + module: &crate::Module, + mut level: Level, + query: Handle, + result: Handle, + rq_tracker: &str, + func_ctx: 
&crate::back::FunctionCtx<'_>, + ) -> BackendResult { + let base_level = level; + write!(self.out, "{level}")?; + let name = Baked(result).to_string(); + writeln!(self.out, "bool {name} = false;")?; + // This prevents variables flowing down a level and causing compile errors. + if self.options.ray_query_initialization_tracking { + writeln!(self.out, "{level}{{")?; + level = level.next(); + write!(self.out, "{level}bool naga_has_initialized = ")?; + self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?; + writeln!(self.out, ";")?; + write!(self.out, "{level}bool naga_has_finished = ")?; + self.write_contains_flags( + rq_tracker, + crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + )?; + writeln!(self.out, ";")?; + writeln!( + self.out, + "{level}if (naga_has_initialized && !naga_has_finished) {{" + )?; + level = level.next(); + } + + write!(self.out, "{level}{name} = ")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".Proceed();")?; + + if self.options.ray_query_initialization_tracking { + writeln!( + self.out, + "{level}{rq_tracker} = {rq_tracker} | {};", + crate::back::RayQueryPoint::PROCEED.bits() + )?; + writeln!( + self.out, + "{level}if (!{name}) {{ {rq_tracker} = {rq_tracker} | {}; }}", + crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits() + )?; + writeln!(self.out, "{base_level}}}}}")?; + } + + self.named_expressions.insert(result, name); + + Ok(()) + } + + pub(super) fn write_generate_intersection( + &mut self, + module: &crate::Module, + mut level: Level, + query: Handle, + hit_t: Handle, + rq_tracker: &str, + func_ctx: &crate::back::FunctionCtx<'_>, + ) -> BackendResult { + let base_level = level; + if self.options.ray_query_initialization_tracking { + write!(self.out, "{level}if (")?; + self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?; + write!(self.out, " && !")?; + self.write_contains_flags( + rq_tracker, + 
crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + )?; + writeln!(self.out, ") {{")?; + level = level.next(); + write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".CandidateType();")?; + write!(self.out, "{level}float naga_tmin = ")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".RayTMin();")?; + write!(self.out, "{level}float naga_tcurrentmax = ")?; + self.write_expr(module, query, func_ctx)?; + // This gets initialized to tmax and is updated after each intersection is committed so is valid to call. + // Note: there is a bug in DXC's spirv backend that makes this technically UB in spirv, but HLSL backend + // is intended for DXIL, so it should be fine (hopefully). + writeln!(self.out, ".CommittedRayT();")?; + write!( + self.out, + "{level}if ((naga_kind == CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <=" + )?; + self.write_expr(module, hit_t, func_ctx)?; + write!(self.out, ") && (")?; + self.write_expr(module, hit_t, func_ctx)?; + writeln!(self.out, " <= naga_tcurrentmax)) {{")?; + level = level.next(); + } + + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + write!(self.out, ".CommitProceduralPrimitiveHit(")?; + self.write_expr(module, hit_t, func_ctx)?; + writeln!(self.out, ");")?; + if self.options.ray_query_initialization_tracking { + writeln!(self.out, "{base_level}}}}}")?; + } + Ok(()) + } + pub(super) fn write_confirm_intersection( + &mut self, + module: &crate::Module, + mut level: Level, + query: Handle, + rq_tracker: &str, + func_ctx: &crate::back::FunctionCtx<'_>, + ) -> BackendResult { + let base_level = level; + if self.options.ray_query_initialization_tracking { + write!(self.out, "{level}if (")?; + self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::PROCEED.bits())?; + write!(self.out, " && !")?; + self.write_contains_flags( + rq_tracker, + crate::back::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + )?; + 
writeln!(self.out, ") {{")?; + level = level.next(); + write!(self.out, "{level}CANDIDATE_TYPE naga_kind = ")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".CandidateType();")?; + writeln!( + self.out, + "{level}if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{" + )?; + level = level.next(); + } + + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?; + if self.options.ray_query_initialization_tracking { + writeln!(self.out, "{base_level}}}}}")?; + } + Ok(()) + } + + pub(super) fn write_terminate( + &mut self, + module: &crate::Module, + mut level: Level, + query: Handle, + rq_tracker: &str, + func_ctx: &crate::back::FunctionCtx<'_>, + ) -> BackendResult { + let base_level = level; + if self.options.ray_query_initialization_tracking { + write!(self.out, "{level}if (")?; + // RayQuery::Abort() can be called any time after RayQuery::TraceRayInline() has been called. + // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#rayquery-abort + self.write_contains_flags(rq_tracker, crate::back::RayQueryPoint::INITIALIZED.bits())?; + writeln!(self.out, ") {{")?; + level = level.next(); + } + + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".Abort();")?; + + if self.options.ray_query_initialization_tracking { + writeln!(self.out, "{base_level}}}")?; + } + + Ok(()) + } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 95feb4d9d98..773c124d01f 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -46,6 +46,9 @@ pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64"; pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str = "nagaTextureSampleBaseClampToEdge"; pub(crate) const IMAGE_LOAD_EXTERNAL_FUNCTION: &str = "nagaTextureLoadExternal"; +pub(crate) const RAY_QUERY_TRACKER_VARIABLE_PREFIX: &str = "naga_query_init_tracker_for_"; +/// Prefix for 
variables in a naga statement +pub(crate) const INTERNAL_PREFIX: &str = "naga_"; enum Index { Expression(Handle), @@ -1664,9 +1667,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_array_size(module, base, size)?; } - match module.types[local.ty].inner { + let is_ray_query = match module.types[local.ty].inner { // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed - TypeInner::RayQuery { .. } => {} + TypeInner::RayQuery { .. } => true, _ => { write!(self.out, " = ")?; // Write the local initializer if needed @@ -1676,10 +1679,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // Zero initialize local variables self.write_default_init(module, local.ty)?; } + false } - } + }; // Finish the local with `;` and add a newline (only for readability) - writeln!(self.out, ";")? + writeln!(self.out, ";")?; + // If it's a ray query, we also want a tracker variable + if is_ray_query { + write!(self.out, "{}", back::INDENT)?; + self.write_value_type(module, &TypeInner::Scalar(Scalar::U32))?; + writeln!( + self.out, + " {RAY_QUERY_TRACKER_VARIABLE_PREFIX}{} = 0;", + self.names[&func_ctx.name_key(handle)] + )?; + } } if !func.local_variables.is_empty() { @@ -2564,50 +2578,77 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } => { self.write_switch(module, func_ctx, level, selector, cases)?; } - Statement::RayQuery { query, ref fun } => match *fun { - RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } => { - write!(self.out, "{level}")?; - self.write_expr(module, query, func_ctx)?; - write!(self.out, ".TraceRayInline(")?; - self.write_expr(module, acceleration_structure, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, descriptor, func_ctx)?; - write!(self.out, ".flags, ")?; - self.write_expr(module, descriptor, func_ctx)?; - write!(self.out, ".cull_mask, ")?; - write!(self.out, "RayDescFromRayDesc_(")?; - self.write_expr(module, 
descriptor, func_ctx)?; - writeln!(self.out, "));")?; - } - RayQueryFunction::Proceed { result } => { - write!(self.out, "{level}")?; - let name = Baked(result).to_string(); - write!(self.out, "const bool {name} = ")?; - self.named_expressions.insert(result, name); - self.write_expr(module, query, func_ctx)?; - writeln!(self.out, ".Proceed();")?; - } - RayQueryFunction::GenerateIntersection { hit_t } => { - write!(self.out, "{level}")?; - self.write_expr(module, query, func_ctx)?; - write!(self.out, ".CommitProceduralPrimitiveHit(")?; - self.write_expr(module, hit_t, func_ctx)?; - writeln!(self.out, ");")?; - } - RayQueryFunction::ConfirmIntersection => { - write!(self.out, "{level}")?; - self.write_expr(module, query, func_ctx)?; - writeln!(self.out, ".CommitNonOpaqueTriangleHit();")?; - } - RayQueryFunction::Terminate => { - write!(self.out, "{level}")?; - self.write_expr(module, query, func_ctx)?; - writeln!(self.out, ".Abort();")?; + Statement::RayQuery { query, ref fun } => { + // There are three possibilities for a ptr to be: + // 1. A variable + // 2. A function argument + // 3. part of a struct + // + // 2 and 3 are not possible, a ray query (in naga IR) + // is not allowed to be passed into a function, and + // all languages disallow it in a struct (you get fun results if + // you try it :) ). + // + // Therefore, the ray query expression must be a variable. 
+ let crate::Expression::LocalVariable(query_var) = func_ctx.expressions[query] + else { + unreachable!() + }; + + let tracker_expr_name = format!( + "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}", + self.names[&func_ctx.name_key(query_var)] + ); + + match *fun { + RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + self.write_initialize_function( + module, + level, + query, + acceleration_structure, + descriptor, + &tracker_expr_name, + func_ctx, + )?; + } + RayQueryFunction::Proceed { result } => { + self.write_proceed( + module, + level, + query, + result, + &tracker_expr_name, + func_ctx, + )?; + } + RayQueryFunction::GenerateIntersection { hit_t } => { + self.write_generate_intersection( + module, + level, + query, + hit_t, + &tracker_expr_name, + func_ctx, + )?; + } + RayQueryFunction::ConfirmIntersection => { + self.write_confirm_intersection( + module, + level, + query, + &tracker_expr_name, + func_ctx, + )?; + } + RayQueryFunction::Terminate => { + self.write_terminate(module, level, query, &tracker_expr_name, func_ctx)?; + } } - }, + } Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); @@ -4275,14 +4316,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")? 
} Expression::RayQueryGetIntersection { query, committed } => { + // For reasoning, see write_stmt + let Expression::LocalVariable(query_var) = func_ctx.expressions[query] else { + unreachable!() + }; + + let tracker_expr_name = format!( + "{RAY_QUERY_TRACKER_VARIABLE_PREFIX}{}", + self.names[&func_ctx.name_key(query_var)] + ); + if committed { write!(self.out, "GetCommittedIntersection(")?; self.write_expr(module, query, func_ctx)?; - write!(self.out, ")")?; + write!(self.out, ", {tracker_expr_name})")?; } else { write!(self.out, "GetCandidateIntersection(")?; self.write_expr(module, query, func_ctx)?; - write!(self.out, ")")?; + write!(self.out, ", {tracker_expr_name})")?; } } // Not supported yet diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 91fca9e42b3..8e1dcef803f 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -61,6 +61,26 @@ impl core::fmt::Display for Baked { } } +bitflags::bitflags! { + /// How far through a ray query are we + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + #[cfg_attr( + not(any(hlsl_out, spv_out)), + allow( + dead_code, + reason = "shared helpers can be dead if none of the enabled backends need it" + ) + )] + pub(super) struct RayQueryPoint: u32 { + /// Ray query has been successfully initialized. + const INITIALIZED = 1 << 0; + /// Proceed has been called on ray query. + const PROCEED = 1 << 1; + /// Proceed has returned false (have finished traversal). + const FINISHED_TRAVERSAL = 1 << 2; + } +} + /// Specifies the values of pipeline-overridable constants in the shader module. /// /// If an `@id` attribute was specified on the declaration, diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 9b8fbe618ba..d90d1569b52 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -895,19 +895,6 @@ bitflags::bitflags! { } } -bitflags::bitflags! 
{ - /// How far through a ray query are we - #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub(super) struct RayQueryPoint: u32 { - /// Ray query has been successfully initialized. - const INITIALIZED = 1 << 0; - /// Proceed has been called on ray query. - const PROCEED = 1 << 1; - /// Proceed has returned false (have finished traversal). - const FINISHED_TRAVERSAL = 1 << 2; - } -} - #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index b65c9daa32d..e9a5b835240 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -5,10 +5,10 @@ Generating SPIR-V for ray query operations. use alloc::{vec, vec::Vec}; use super::{ - Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType, - Writer, + Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, + LookupRayQueryFunction, NumericType, Writer, }; -use crate::{arena::Handle, back::spv::LookupRayQueryFunction}; +use crate::{arena::Handle, back::RayQueryPoint}; /// helper function to check if a particular flag is set in a u32. 
fn write_ray_flags_contains_flags( @@ -190,13 +190,13 @@ impl Writer { self, &mut block, loaded_ray_query_tracker_id, - super::RayQueryPoint::PROCEED.bits(), + RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, loaded_ray_query_tracker_id, - super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let proceed_finished_correct_id = if is_committed { finished_proceed_id @@ -998,9 +998,8 @@ impl Writer { tmax_id, )); - let const_initialized = self.get_constant_scalar(crate::Literal::U32( - super::RayQueryPoint::INITIALIZED.bits(), - )); + let const_initialized = + self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::INITIALIZED.bits())); valid_block .body .push(Instruction::store(init_tracker_id, const_initialized, None)); @@ -1089,7 +1088,7 @@ impl Writer { self, &mut block, initialized_tracker_id, - super::RayQueryPoint::INITIALIZED.bits(), + RayQueryPoint::INITIALIZED.bits(), ); block.body.push(Instruction::selection_merge( @@ -1116,10 +1115,10 @@ impl Writer { .push(Instruction::store(proceeded_id, has_proceeded, None)); let add_flag_finished = self.get_constant_scalar(crate::Literal::U32( - (super::RayQueryPoint::PROCEED | super::RayQueryPoint::FINISHED_TRAVERSAL).bits(), + (RayQueryPoint::PROCEED | RayQueryPoint::FINISHED_TRAVERSAL).bits(), )); let add_flag_continuing = - self.get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::PROCEED.bits())); + self.get_constant_scalar(crate::Literal::U32(RayQueryPoint::PROCEED.bits())); let add_flags_id = self.id_gen.next(); valid_block.body.push(Instruction::select( @@ -1226,13 +1225,13 @@ impl Writer { self, &mut block, initialized_tracker_id, - super::RayQueryPoint::PROCEED.bits(), + RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, - super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + RayQueryPoint::FINISHED_TRAVERSAL.bits(), 
); // Can't find anything to suggest double calling this function is invalid. @@ -1501,13 +1500,13 @@ impl Writer { self, &mut block, initialized_tracker_id, - super::RayQueryPoint::PROCEED.bits(), + RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, - super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); // Although it seems strange to call this twice, I (Vecvec) can't find anything to suggest double calling this function is invalid. let not_finished_id = self.id_gen.next(); @@ -1673,13 +1672,13 @@ impl Writer { self, &mut block, initialized_tracker_id, - super::RayQueryPoint::PROCEED.bits(), + RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, - super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let correct_finish_id = if is_committed { @@ -1825,14 +1824,14 @@ impl Writer { self, &mut block, initialized_tracker_id, - super::RayQueryPoint::PROCEED.bits(), + RayQueryPoint::PROCEED.bits(), ); let finished_proceed_id = write_ray_flags_contains_flags( self, &mut block, initialized_tracker_id, - super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + RayQueryPoint::FINISHED_TRAVERSAL.bits(), ); let not_finished_id = self.id_gen.next(); diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 3e607dcf778..47e8b308ad9 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1140,9 +1140,9 @@ impl Writer { .writer .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function); let tracker_id = context.gen_id(); - let tracker_init_id = context - .writer - .get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::empty().bits())); + let tracker_init_id = context.writer.get_constant_scalar(crate::Literal::U32( + crate::back::RayQueryPoint::empty().bits(), + )); let 
tracker_instruction = Instruction::variable( ptr_u32_type_id, tracker_id, diff --git a/naga/tests/in/wgsl/ray-query-no-init-tracking.toml b/naga/tests/in/wgsl/ray-query-no-init-tracking.toml index 4ae5b151277..e7e09a4bbc4 100644 --- a/naga/tests/in/wgsl/ray-query-no-init-tracking.toml +++ b/naga/tests/in/wgsl/ray-query-no-init-tracking.toml @@ -11,8 +11,7 @@ zero_initialize_workgroup_memory = false shader_model = "V6_5" fake_missing_bindings = true zero_initialize_workgroup_memory = true -# Not yet implemented -# ray_query_initialization_tracking = false +ray_query_initialization_tracking = false [spv] version = [1, 4] diff --git a/naga/tests/out/hlsl/wgsl-aliased-ray-query.hlsl b/naga/tests/out/hlsl/wgsl-aliased-ray-query.hlsl index 0b31e87ff89..55221e7f971 100644 --- a/naga/tests/out/hlsl/wgsl-aliased-ray-query.hlsl +++ b/naga/tests/out/hlsl/wgsl-aliased-ray-query.hlsl @@ -49,24 +49,26 @@ RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 return ret; } -RayIntersection GetCandidateIntersection(RayQuery rq) { +RayIntersection GetCandidateIntersection(RayQuery rq, uint rq_tracker) { RayIntersection ret = (RayIntersection)0; - CANDIDATE_TYPE kind = rq.CandidateType(); - if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { - ret.kind = 1; - ret.t = rq.CandidateTriangleRayT(); - ret.barycentrics = rq.CandidateTriangleBarycentrics(); - ret.front_face = rq.CandidateTriangleFrontFace(); - } else { - ret.kind = 3; + if (((rq_tracker & 2) == 2) && !((rq_tracker & 4) == 4)) { + CANDIDATE_TYPE kind = rq.CandidateType(); + if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + ret.kind = 1; + ret.t = rq.CandidateTriangleRayT(); + ret.barycentrics = rq.CandidateTriangleBarycentrics(); + ret.front_face = rq.CandidateTriangleFrontFace(); + } else { + ret.kind = 3; + } + ret.instance_custom_data = rq.CandidateInstanceID(); + ret.instance_index = rq.CandidateInstanceIndex(); + ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); + 
ret.geometry_index = rq.CandidateGeometryIndex(); + ret.primitive_index = rq.CandidatePrimitiveIndex(); + ret.object_to_world = rq.CandidateObjectToWorld4x3(); + ret.world_to_object = rq.CandidateWorldToObject4x3(); } - ret.instance_custom_data = rq.CandidateInstanceID(); - ret.instance_index = rq.CandidateInstanceIndex(); - ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); - ret.geometry_index = rq.CandidateGeometryIndex(); - ret.primitive_index = rq.CandidatePrimitiveIndex(); - ret.object_to_world = rq.CandidateObjectToWorld4x3(); - ret.world_to_object = rq.CandidateWorldToObject4x3(); return ret; } @@ -74,20 +76,59 @@ RayIntersection GetCandidateIntersection(RayQuery rq) { void main_candidate() { RayQuery rq_1; + uint naga_query_init_tracker_for_rq_1 = 0; float3 pos = (0.0).xxx; float3 dir = float3(0.0, 1.0, 0.0); - rq_1.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir))); - RayIntersection intersection = GetCandidateIntersection(rq_1); + { + RayDesc_ naga_desc = ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir); + float naga_tmin = naga_desc.tmin; + float naga_tmax = naga_desc.tmax; + float3 naga_origin = naga_desc.origin; + float3 naga_dir = naga_desc.dir; + uint naga_flags = naga_desc.flags; + bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !(((asuint(naga_tmin) & 2139095040) == 2139095040) && ((asuint(naga_tmin) & 0x7fffff) != 0)); + bool naga_tmax_valid = !(((asuint(naga_tmax) & 2139095040) == 2139095040) && ((asuint(naga_tmax) & 0x7fffff) != 0)); + bool naga_origin_valid = !any((((asuint(naga_origin) & 2139095040) == 2139095040) && ((asuint(naga_origin) & 0x7fffff) != 0))); + bool naga_dir_valid = !any((((asuint(naga_dir) & 2139095040) == 2139095040) && ((asuint(naga_dir) & 0x7fffff) != 0))); + bool naga_contains_opaque = ((naga_flags & 1) == 
1); + bool naga_contains_no_opaque = ((naga_flags & 2) == 2); + bool naga_contains_cull_opaque = ((naga_flags & 64) == 64); + bool naga_contains_cull_no_opaque = ((naga_flags & 128) == 128); + bool naga_contains_cull_front = ((naga_flags & 32) == 32); + bool naga_contains_cull_back = ((naga_flags & 16) == 16); + bool naga_contains_skip_triangles = ((naga_flags & 256) == 256); + bool naga_contains_skip_aabbs = ((naga_flags & 512) == 512); + bool naga_contains_skip_triangles_aabbs = (naga_contains_skip_aabbs && naga_contains_skip_triangles) ; + bool naga_contains_skip_triangles_cull = (naga_contains_cull_front && naga_contains_skip_triangles) || (naga_contains_cull_front && naga_contains_cull_back) || (naga_contains_cull_back && naga_contains_skip_triangles) ; + bool naga_contains_multiple_opaque = (naga_contains_cull_no_opaque && naga_contains_opaque) || (naga_contains_cull_no_opaque && naga_contains_no_opaque) || (naga_contains_cull_no_opaque && naga_contains_cull_opaque) || (naga_contains_cull_opaque && naga_contains_opaque) || (naga_contains_cull_opaque && naga_contains_no_opaque) || (naga_contains_no_opaque && naga_contains_opaque) ; + if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) { + naga_query_init_tracker_for_rq_1 = naga_query_init_tracker_for_rq_1 | 1; + rq_1.TraceRayInline(acc_struct, naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc)); + } + } + RayIntersection intersection = GetCandidateIntersection(rq_1, naga_query_init_tracker_for_rq_1); if ((intersection.kind == 3u)) { - rq_1.CommitProceduralPrimitiveHit(10.0); + if (((naga_query_init_tracker_for_rq_1 & 2) == 2) && !((naga_query_init_tracker_for_rq_1 & 4) == 4)) { + CANDIDATE_TYPE naga_kind = rq_1.CandidateType(); + float naga_tmin = rq_1.RayTMin(); + float naga_tcurrentmax = rq_1.CommittedRayT(); + if ((naga_kind == 
CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <=10.0) && (10.0 <= naga_tcurrentmax)) { + rq_1.CommitProceduralPrimitiveHit(10.0); + }} return; } else { if ((intersection.kind == 1u)) { - rq_1.CommitNonOpaqueTriangleHit(); + if (((naga_query_init_tracker_for_rq_1 & 2) == 2) && !((naga_query_init_tracker_for_rq_1 & 4) == 4)) { + CANDIDATE_TYPE naga_kind = rq_1.CandidateType(); + if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + rq_1.CommitNonOpaqueTriangleHit(); + }} return; } else { - rq_1.Abort(); + if (((naga_query_init_tracker_for_rq_1 & 1) == 1)) { + rq_1.Abort(); + } return; } } diff --git a/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl index 68be09b6b01..418cbbd79f1 100644 --- a/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl +++ b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl @@ -59,7 +59,7 @@ RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 return ret; } -RayIntersection GetCommittedIntersection(RayQuery rq) { +RayIntersection GetCommittedIntersection(RayQuery rq, uint rq_tracker) { RayIntersection ret = (RayIntersection)0; ret.kind = rq.CommittedStatus(); if( rq.CommittedStatus() == COMMITTED_NOTHING) {} else { @@ -82,13 +82,18 @@ RayIntersection GetCommittedIntersection(RayQuery rq) { RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructure acs) { RayQuery rq_1; + uint naga_query_init_tracker_for_rq_1 = 0; - rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir))); + { + RayDesc_ naga_desc = ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir); + rq_1.TraceRayInline(acs, naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc)); + } uint2 loop_bound = uint2(4294967295u, 4294967295u); while(true) { if (all(loop_bound == uint2(0u, 0u))) { break; } 
loop_bound -= uint2(loop_bound.y == 0u, 1u); - const bool _e9 = rq_1.Proceed(); + bool _e9 = false; + _e9 = rq_1.Proceed(); if (_e9) { } else { break; @@ -96,7 +101,7 @@ RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructu { } } - const RayIntersection rayintersection = GetCommittedIntersection(rq_1); + const RayIntersection rayintersection = GetCommittedIntersection(rq_1, naga_query_init_tracker_for_rq_1); return rayintersection; } @@ -120,7 +125,7 @@ void main() return; } -RayIntersection GetCandidateIntersection(RayQuery rq) { +RayIntersection GetCandidateIntersection(RayQuery rq, uint rq_tracker) { RayIntersection ret = (RayIntersection)0; CANDIDATE_TYPE kind = rq.CandidateType(); if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { @@ -145,11 +150,15 @@ RayIntersection GetCandidateIntersection(RayQuery rq) { void main_candidate() { RayQuery rq; + uint naga_query_init_tracker_for_rq = 0; float3 pos_2 = (0.0).xxx; float3 dir_2 = float3(0.0, 1.0, 0.0); - rq.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2))); - RayIntersection intersection_1 = GetCandidateIntersection(rq); + { + RayDesc_ naga_desc = ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2); + rq.TraceRayInline(acc_struct, naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc)); + } + RayIntersection intersection_1 = GetCandidateIntersection(rq, naga_query_init_tracker_for_rq); if ((intersection_1.kind == 3u)) { rq.CommitProceduralPrimitiveHit(10.0); return; diff --git a/naga/tests/out/hlsl/wgsl-ray-query.hlsl b/naga/tests/out/hlsl/wgsl-ray-query.hlsl index 68be09b6b01..cb0e2a723ec 100644 --- a/naga/tests/out/hlsl/wgsl-ray-query.hlsl +++ b/naga/tests/out/hlsl/wgsl-ray-query.hlsl @@ -59,22 +59,24 @@ RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 return ret; } 
-RayIntersection GetCommittedIntersection(RayQuery rq) { +RayIntersection GetCommittedIntersection(RayQuery rq, uint rq_tracker) { RayIntersection ret = (RayIntersection)0; - ret.kind = rq.CommittedStatus(); - if( rq.CommittedStatus() == COMMITTED_NOTHING) {} else { - ret.t = rq.CommittedRayT(); - ret.instance_custom_data = rq.CommittedInstanceID(); - ret.instance_index = rq.CommittedInstanceIndex(); - ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex(); - ret.geometry_index = rq.CommittedGeometryIndex(); - ret.primitive_index = rq.CommittedPrimitiveIndex(); - if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) { - ret.barycentrics = rq.CommittedTriangleBarycentrics(); - ret.front_face = rq.CommittedTriangleFrontFace(); + if (((rq_tracker & 4) == 4)) { + ret.kind = rq.CommittedStatus(); + if( rq.CommittedStatus() == COMMITTED_NOTHING) {} else { + ret.t = rq.CommittedRayT(); + ret.instance_custom_data = rq.CommittedInstanceID(); + ret.instance_index = rq.CommittedInstanceIndex(); + ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CommittedGeometryIndex(); + ret.primitive_index = rq.CommittedPrimitiveIndex(); + if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) { + ret.barycentrics = rq.CommittedTriangleBarycentrics(); + ret.front_face = rq.CommittedTriangleFrontFace(); + } + ret.object_to_world = rq.CommittedObjectToWorld4x3(); + ret.world_to_object = rq.CommittedWorldToObject4x3(); } - ret.object_to_world = rq.CommittedObjectToWorld4x3(); - ret.world_to_object = rq.CommittedWorldToObject4x3(); } return ret; } @@ -82,13 +84,48 @@ RayIntersection GetCommittedIntersection(RayQuery rq) { RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructure acs) { RayQuery rq_1; + uint naga_query_init_tracker_for_rq_1 = 0; - rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, 
RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir))); + { + RayDesc_ naga_desc = ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir); + float naga_tmin = naga_desc.tmin; + float naga_tmax = naga_desc.tmax; + float3 naga_origin = naga_desc.origin; + float3 naga_dir = naga_desc.dir; + uint naga_flags = naga_desc.flags; + bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !(((asuint(naga_tmin) & 2139095040) == 2139095040) && ((asuint(naga_tmin) & 0x7fffff) != 0)); + bool naga_tmax_valid = !(((asuint(naga_tmax) & 2139095040) == 2139095040) && ((asuint(naga_tmax) & 0x7fffff) != 0)); + bool naga_origin_valid = !any((((asuint(naga_origin) & 2139095040) == 2139095040) && ((asuint(naga_origin) & 0x7fffff) != 0))); + bool naga_dir_valid = !any((((asuint(naga_dir) & 2139095040) == 2139095040) && ((asuint(naga_dir) & 0x7fffff) != 0))); + bool naga_contains_opaque = ((naga_flags & 1) == 1); + bool naga_contains_no_opaque = ((naga_flags & 2) == 2); + bool naga_contains_cull_opaque = ((naga_flags & 64) == 64); + bool naga_contains_cull_no_opaque = ((naga_flags & 128) == 128); + bool naga_contains_cull_front = ((naga_flags & 32) == 32); + bool naga_contains_cull_back = ((naga_flags & 16) == 16); + bool naga_contains_skip_triangles = ((naga_flags & 256) == 256); + bool naga_contains_skip_aabbs = ((naga_flags & 512) == 512); + bool naga_contains_skip_triangles_aabbs = (naga_contains_skip_aabbs && naga_contains_skip_triangles) ; + bool naga_contains_skip_triangles_cull = (naga_contains_cull_front && naga_contains_skip_triangles) || (naga_contains_cull_front && naga_contains_cull_back) || (naga_contains_cull_back && naga_contains_skip_triangles) ; + bool naga_contains_multiple_opaque = (naga_contains_cull_no_opaque && naga_contains_opaque) || (naga_contains_cull_no_opaque && naga_contains_no_opaque) || (naga_contains_cull_no_opaque && naga_contains_cull_opaque) || (naga_contains_cull_opaque && naga_contains_opaque) || (naga_contains_cull_opaque 
&& naga_contains_no_opaque) || (naga_contains_no_opaque && naga_contains_opaque) ; + if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) { + naga_query_init_tracker_for_rq_1 = naga_query_init_tracker_for_rq_1 | 1; + rq_1.TraceRayInline(acs, naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc)); + } + } uint2 loop_bound = uint2(4294967295u, 4294967295u); while(true) { if (all(loop_bound == uint2(0u, 0u))) { break; } loop_bound -= uint2(loop_bound.y == 0u, 1u); - const bool _e9 = rq_1.Proceed(); + bool _e9 = false; + { + bool naga_has_initialized = ((naga_query_init_tracker_for_rq_1 & 1) == 1); + bool naga_has_finished = ((naga_query_init_tracker_for_rq_1 & 4) == 4); + if (naga_has_initialized && !naga_has_finished) { + _e9 = rq_1.Proceed(); + naga_query_init_tracker_for_rq_1 = naga_query_init_tracker_for_rq_1 | 2; + if (!_e9) { naga_query_init_tracker_for_rq_1 = naga_query_init_tracker_for_rq_1 | 4; } + }} if (_e9) { } else { break; @@ -96,7 +133,7 @@ RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructu { } } - const RayIntersection rayintersection = GetCommittedIntersection(rq_1); + const RayIntersection rayintersection = GetCommittedIntersection(rq_1, naga_query_init_tracker_for_rq_1); return rayintersection; } @@ -120,24 +157,26 @@ void main() return; } -RayIntersection GetCandidateIntersection(RayQuery rq) { +RayIntersection GetCandidateIntersection(RayQuery rq, uint rq_tracker) { RayIntersection ret = (RayIntersection)0; - CANDIDATE_TYPE kind = rq.CandidateType(); - if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { - ret.kind = 1; - ret.t = rq.CandidateTriangleRayT(); - ret.barycentrics = rq.CandidateTriangleBarycentrics(); - ret.front_face = rq.CandidateTriangleFrontFace(); - } else { - ret.kind = 3; + if (((rq_tracker & 2) == 2) && !((rq_tracker & 4) == 4)) { + CANDIDATE_TYPE kind = 
rq.CandidateType(); + if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + ret.kind = 1; + ret.t = rq.CandidateTriangleRayT(); + ret.barycentrics = rq.CandidateTriangleBarycentrics(); + ret.front_face = rq.CandidateTriangleFrontFace(); + } else { + ret.kind = 3; + } + ret.instance_custom_data = rq.CandidateInstanceID(); + ret.instance_index = rq.CandidateInstanceIndex(); + ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CandidateGeometryIndex(); + ret.primitive_index = rq.CandidatePrimitiveIndex(); + ret.object_to_world = rq.CandidateObjectToWorld4x3(); + ret.world_to_object = rq.CandidateWorldToObject4x3(); } - ret.instance_custom_data = rq.CandidateInstanceID(); - ret.instance_index = rq.CandidateInstanceIndex(); - ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); - ret.geometry_index = rq.CandidateGeometryIndex(); - ret.primitive_index = rq.CandidatePrimitiveIndex(); - ret.object_to_world = rq.CandidateObjectToWorld4x3(); - ret.world_to_object = rq.CandidateWorldToObject4x3(); return ret; } @@ -145,20 +184,59 @@ RayIntersection GetCandidateIntersection(RayQuery rq) { void main_candidate() { RayQuery rq; + uint naga_query_init_tracker_for_rq = 0; float3 pos_2 = (0.0).xxx; float3 dir_2 = float3(0.0, 1.0, 0.0); - rq.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2))); - RayIntersection intersection_1 = GetCandidateIntersection(rq); + { + RayDesc_ naga_desc = ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2); + float naga_tmin = naga_desc.tmin; + float naga_tmax = naga_desc.tmax; + float3 naga_origin = naga_desc.origin; + float3 naga_dir = naga_desc.dir; + uint naga_flags = naga_desc.flags; + bool naga_tmin_valid = (naga_tmin >= 0.0) && (naga_tmin <= naga_tmax) && !(((asuint(naga_tmin) & 2139095040) == 2139095040) && 
((asuint(naga_tmin) & 0x7fffff) != 0)); + bool naga_tmax_valid = !(((asuint(naga_tmax) & 2139095040) == 2139095040) && ((asuint(naga_tmax) & 0x7fffff) != 0)); + bool naga_origin_valid = !any((((asuint(naga_origin) & 2139095040) == 2139095040) && ((asuint(naga_origin) & 0x7fffff) != 0))); + bool naga_dir_valid = !any((((asuint(naga_dir) & 2139095040) == 2139095040) && ((asuint(naga_dir) & 0x7fffff) != 0))); + bool naga_contains_opaque = ((naga_flags & 1) == 1); + bool naga_contains_no_opaque = ((naga_flags & 2) == 2); + bool naga_contains_cull_opaque = ((naga_flags & 64) == 64); + bool naga_contains_cull_no_opaque = ((naga_flags & 128) == 128); + bool naga_contains_cull_front = ((naga_flags & 32) == 32); + bool naga_contains_cull_back = ((naga_flags & 16) == 16); + bool naga_contains_skip_triangles = ((naga_flags & 256) == 256); + bool naga_contains_skip_aabbs = ((naga_flags & 512) == 512); + bool naga_contains_skip_triangles_aabbs = (naga_contains_skip_aabbs && naga_contains_skip_triangles) ; + bool naga_contains_skip_triangles_cull = (naga_contains_cull_front && naga_contains_skip_triangles) || (naga_contains_cull_front && naga_contains_cull_back) || (naga_contains_cull_back && naga_contains_skip_triangles) ; + bool naga_contains_multiple_opaque = (naga_contains_cull_no_opaque && naga_contains_opaque) || (naga_contains_cull_no_opaque && naga_contains_no_opaque) || (naga_contains_cull_no_opaque && naga_contains_cull_opaque) || (naga_contains_cull_opaque && naga_contains_opaque) || (naga_contains_cull_opaque && naga_contains_no_opaque) || (naga_contains_no_opaque && naga_contains_opaque) ; + if (naga_tmin_valid && naga_tmax_valid && naga_origin_valid && naga_dir_valid && !(naga_contains_skip_triangles_aabbs || naga_contains_skip_triangles_cull || naga_contains_multiple_opaque)) { + naga_query_init_tracker_for_rq = naga_query_init_tracker_for_rq | 1; + rq.TraceRayInline(acc_struct, naga_desc.flags, naga_desc.cull_mask, RayDescFromRayDesc_(naga_desc)); + } + } + 
RayIntersection intersection_1 = GetCandidateIntersection(rq, naga_query_init_tracker_for_rq); if ((intersection_1.kind == 3u)) { - rq.CommitProceduralPrimitiveHit(10.0); + if (((naga_query_init_tracker_for_rq & 2) == 2) && !((naga_query_init_tracker_for_rq & 4) == 4)) { + CANDIDATE_TYPE naga_kind = rq.CandidateType(); + float naga_tmin = rq.RayTMin(); + float naga_tcurrentmax = rq.CommittedRayT(); + if ((naga_kind == CANDIDATE_PROCEDURAL_PRIMITIVE) && (naga_tmin <=10.0) && (10.0 <= naga_tcurrentmax)) { + rq.CommitProceduralPrimitiveHit(10.0); + }} return; } else { if ((intersection_1.kind == 1u)) { - rq.CommitNonOpaqueTriangleHit(); + if (((naga_query_init_tracker_for_rq & 2) == 2) && !((naga_query_init_tracker_for_rq & 4) == 4)) { + CANDIDATE_TYPE naga_kind = rq.CandidateType(); + if (naga_kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + rq.CommitNonOpaqueTriangleHit(); + }} return; } else { - rq.Abort(); + if (((naga_query_init_tracker_for_rq & 1) == 1)) { + rq.Abort(); + } return; } } diff --git a/tests/tests/wgpu-gpu/ray_tracing/shader.rs b/tests/tests/wgpu-gpu/ray_tracing/shader.rs index db880854de9..8b53292c1a9 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/shader.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/shader.rs @@ -116,7 +116,7 @@ static PREVENT_INVALID_RAY_QUERY_CALLS: GpuTestConfiguration = GpuTestConfigurat // Otherwise, mistakes in the generated code won't be caught. 
.instance_flags(InstanceFlags::GPU_BASED_VALIDATION) // not yet implemented in directx12 - .skip(FailureCase::backend(Backends::DX12 | Backends::METAL)), + .skip(FailureCase::backend(Backends::METAL)), ) .run_sync(prevent_invalid_ray_query_calls); diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 26ee47707ba..b22aa088828 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -291,14 +291,22 @@ impl super::Device { != layout.naga_options.zero_initialize_workgroup_memory || stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing || stage.module.runtime_checks.force_loop_bounding - != layout.naga_options.force_loop_bounding; - // Note: ray query initialization tracking not yet implemented + != layout.naga_options.force_loop_bounding + || stage + .module + .runtime_checks + .ray_query_initialization_tracking + != layout.naga_options.ray_query_initialization_tracking; let mut temp_options; let naga_options = if needs_temp_options { temp_options = layout.naga_options.clone(); temp_options.zero_initialize_workgroup_memory = stage.zero_initialize_workgroup_memory; temp_options.restrict_indexing = stage.module.runtime_checks.bounds_checks; temp_options.force_loop_bounding = stage.module.runtime_checks.force_loop_bounding; + temp_options.ray_query_initialization_tracking = stage + .module + .runtime_checks + .ray_query_initialization_tracking; &temp_options } else { &layout.naga_options @@ -1488,6 +1496,7 @@ impl crate::Device for super::Device { sampler_buffer_binding_map, external_texture_binding_map, force_loop_bounding: true, + ray_query_initialization_tracking: true, }, }) }