Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ SamplerDescriptor {
- Using both the wgpu command encoding APIs and `CommandEncoder::as_hal_mut` on the same encoder will now result in a panic.
- Allow `include_spirv!` and `include_spirv_raw!` macros to be used in constants and statics. By @clarfonthey in [#8250](https://github.com/gfx-rs/wgpu/pull/8250).

#### Naga

- Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390).

### Bug Fixes

#### naga
Expand Down
3 changes: 3 additions & 0 deletions naga-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub struct SpirvOutParameters {
pub separate_entry_points: bool,
#[serde(deserialize_with = "deserialize_binding_map")]
pub binding_map: naga::back::spv::BindingMap,
pub ray_query_initialization_tracking: bool,
pub use_storage_input_output_16: bool,
}
impl Default for SpirvOutParameters {
Expand All @@ -126,6 +127,7 @@ impl Default for SpirvOutParameters {
force_point_size: false,
clamp_frag_depth: false,
separate_entry_points: false,
ray_query_initialization_tracking: true,
use_storage_input_output_16: true,
binding_map: naga::back::spv::BindingMap::default(),
}
Expand Down Expand Up @@ -159,6 +161,7 @@ impl SpirvOutParameters {
binding_map: self.binding_map.clone(),
zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill,
force_loop_bounding: true,
ray_query_initialization_tracking: true,
debug_info,
use_storage_input_output_16: self.use_storage_input_output_16,
}
Expand Down
46 changes: 28 additions & 18 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ impl Writer {
));

let clamp_id = self.id_gen.next();
body.push(Instruction::ext_inst(
body.push(Instruction::ext_inst_gl_op(
self.gl450_ext_inst_id,
spirv::GLOp::FClamp,
float_type_id,
Expand Down Expand Up @@ -1026,15 +1026,15 @@ impl BlockContext<'_> {
};

let max_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
max_op,
result_type_id,
max_id,
&[arg0_id, arg1_id],
));

MathOp::Custom(Instruction::ext_inst(
MathOp::Custom(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
min_op,
result_type_id,
Expand Down Expand Up @@ -1068,7 +1068,7 @@ impl BlockContext<'_> {
arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
}

MathOp::Custom(Instruction::ext_inst(
MathOp::Custom(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FClamp,
result_type_id,
Expand Down Expand Up @@ -1282,7 +1282,7 @@ impl BlockContext<'_> {
&self.temp_list,
));

MathOp::Custom(Instruction::ext_inst(
MathOp::Custom(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FMix,
result_type_id,
Expand Down Expand Up @@ -1339,15 +1339,15 @@ impl BlockContext<'_> {
};

let lsb_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FindILsb,
result_type_id,
lsb_id,
&[arg0_id],
));

MathOp::Custom(Instruction::ext_inst(
MathOp::Custom(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
result_type_id,
Expand Down Expand Up @@ -1388,7 +1388,7 @@ impl BlockContext<'_> {
};

let msb_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
if width != 4 {
spirv::GLOp::FindILsb
Expand Down Expand Up @@ -1445,7 +1445,7 @@ impl BlockContext<'_> {

// o = min(offset, w)
let offset_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
u32_type,
Expand All @@ -1465,7 +1465,7 @@ impl BlockContext<'_> {

// c = min(count, tmp)
let count_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
u32_type,
Expand Down Expand Up @@ -1495,7 +1495,7 @@ impl BlockContext<'_> {

// o = min(offset, w)
let offset_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
u32_type,
Expand All @@ -1515,7 +1515,7 @@ impl BlockContext<'_> {

// c = min(count, tmp)
let count_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
u32_type,
Expand Down Expand Up @@ -1610,7 +1610,7 @@ impl BlockContext<'_> {
};

block.body.push(match math_op {
MathOp::Ext(op) => Instruction::ext_inst(
MathOp::Ext(op) => Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
op,
result_type_id,
Expand All @@ -1621,7 +1621,13 @@ impl BlockContext<'_> {
});
id
}
crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
crate::Expression::LocalVariable(variable) => {
if let Some(rq_tracker) = self.function.ray_query_tracker_variables.get(&variable) {
self.ray_query_tracker_expr
.insert(expr_handle, rq_tracker.id);
}
self.function.variables[&variable].id
}
crate::Expression::Load { pointer } => {
self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
}
Expand Down Expand Up @@ -1772,6 +1778,10 @@ impl BlockContext<'_> {
crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
crate::Expression::RayQueryGetIntersection { query, committed } => {
let query_id = self.cached[query];
let init_tracker_id = *self
.ray_query_tracker_expr
.get(&query)
.expect("not a cached ray query");
let func_id = self
.writer
.write_ray_query_get_intersection_function(committed, self.ir_module);
Expand All @@ -1782,7 +1792,7 @@ impl BlockContext<'_> {
intersection_type_id,
id,
func_id,
&[query_id],
&[query_id, init_tracker_id],
));
id
}
Expand Down Expand Up @@ -2008,7 +2018,7 @@ impl BlockContext<'_> {
let max_const_id = maybe_splat_const(self.writer, max_const_id);

let clamp_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FClamp,
expr_type_id,
Expand Down Expand Up @@ -2671,7 +2681,7 @@ impl BlockContext<'_> {
});

let clamp_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
clamp_op,
wide_vector_type_id,
Expand Down Expand Up @@ -2765,7 +2775,7 @@ impl BlockContext<'_> {
let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit));

let clamp_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
clamp_op,
result_type_id,
Expand Down
6 changes: 3 additions & 3 deletions naga/src/back/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ impl BlockContext<'_> {
// and negative values in a single instruction: negative values of
// `input_id` get treated as very large positive values.
let restricted_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
type_id,
Expand Down Expand Up @@ -580,7 +580,7 @@ impl BlockContext<'_> {
// and negative values in a single instruction: negative values of
// `coordinates` get treated as very large positive values.
let restricted_coordinates_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
coordinates.type_id,
Expand Down Expand Up @@ -923,7 +923,7 @@ impl BlockContext<'_> {

// Clamp the coords to the calculated margins
let clamped_coords_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::NClamp,
vec2f_type_id,
Expand Down
2 changes: 1 addition & 1 deletion naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ impl BlockContext<'_> {
// One or the other of the index or length is dynamic, so emit code for
// BoundsCheckPolicy::Restrict.
let restricted_index_id = self.gen_id();
block.body.push(Instruction::ext_inst(
block.body.push(Instruction::ext_inst_gl_op(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
self.writer.get_u32_type_id(),
Expand Down
14 changes: 12 additions & 2 deletions naga/src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,28 @@ impl super::Instruction {
instruction
}

pub(super) fn ext_inst(
pub(super) fn ext_inst_gl_op(
set_id: Word,
op: spirv::GLOp,
result_type_id: Word,
id: Word,
operands: &[Word],
) -> Self {
Self::ext_inst(set_id, op as u32, result_type_id, id, operands)
}

pub(super) fn ext_inst(
set_id: Word,
op: u32,
result_type_id: Word,
id: Word,
operands: &[Word],
) -> Self {
let mut instruction = Self::new(Op::ExtInst);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(set_id);
instruction.add_operand(op as u32);
instruction.add_operand(op);
for operand in operands {
instruction.add_operand(*operand)
}
Expand Down
49 changes: 47 additions & 2 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ struct Function {
signature: Option<Instruction>,
parameters: Vec<FunctionArgument>,
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
/// Map from a local variable that is a ray query to its u32 tracker.
ray_query_tracker_variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
/// List of local variables used as a counters to ensure that all loops are bounded.
force_loop_bounding_vars: Vec<LocalVariable>,

Expand Down Expand Up @@ -445,6 +447,16 @@ struct LookupFunctionType {
return_type_id: Word,
}

#[derive(Debug, PartialEq, Clone, Hash, Eq)]
enum LookupRayQueryFunction {
Initialize,
Proceed,
GenerateIntersection,
ConfirmIntersection,
GetVertexPositions { committed: bool },
GetIntersection { committed: bool },
}

#[derive(Debug)]
enum Dimension {
Scalar,
Expand Down Expand Up @@ -685,6 +697,10 @@ struct BlockContext<'w> {
expression_constness: ExpressionConstnessTracker,

force_loop_bounding: bool,

/// Hash from an expression whose type is a ray query / pointer to a ray query to its tracker.
/// Note: this is sparse, so can't be a handle vec
ray_query_tracker_expr: crate::FastHashMap<Handle<crate::Expression>, Word>,
}

impl BlockContext<'_> {
Expand Down Expand Up @@ -741,6 +757,7 @@ pub struct Writer {
/// The set of spirv extensions used.
extensions_used: crate::FastIndexSet<&'static str>,

debug_strings: Vec<Instruction>,
debugs: Vec<Instruction>,
annotations: Vec<Instruction>,
flags: WriterFlags,
Expand Down Expand Up @@ -773,12 +790,15 @@ pub struct Writer {
// Just a temporary list of SPIR-V ids
temp_list: Vec<Word>,

ray_get_committed_intersection_function: Option<Word>,
ray_get_candidate_intersection_function: Option<Word>,
ray_query_functions: crate::FastHashMap<LookupRayQueryFunction, Word>,

/// F16 I/O polyfill manager for handling `f16` input/output variables
/// when `StorageInputOutput16` capability is not available.
io_f16_polyfills: f16_polyfill::F16IoPolyfill,

/// Non semantic debug printf extension `OpExtInstImport`
debug_printf: Option<Word>,
pub(crate) ray_query_initialization_tracking: bool,
}

bitflags::bitflags! {
Expand Down Expand Up @@ -810,6 +830,26 @@ bitflags::bitflags! {
///
/// [`BuiltIn::FragDepth`]: crate::BuiltIn::FragDepth
const CLAMP_FRAG_DEPTH = 0x10;

/// Instead of silently failing if the arguments to generate a ray query are
/// invalid, uses debug printf extension to print to the command line
///
/// Note: VK_KHR_shader_non_semantic_info must be enabled. This will have no
/// effect if `options.ray_query_initialization_tracking` is set to false.
const PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL = 0x20;
}
}

bitflags::bitflags! {
/// How far through a ray query are we
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(super) struct RayQueryPoint: u32 {
/// Ray query has been successfully initialized.
const INITIALIZED = 1 << 0;
/// Proceed has been called on ray query.
const PROCEED = 1 << 1;
/// Proceed has returned false (have finished traversal).
const FINISHED_TRAVERSAL = 1 << 2;
}
}

Expand Down Expand Up @@ -867,6 +907,10 @@ pub struct Options<'a> {
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,

/// if set, ray queries will get a variable to track their state to prevent
/// misuse.
pub ray_query_initialization_tracking: bool,

/// Whether to use the `StorageInputOutput16` capability for `f16` shader I/O.
/// When false, `f16` I/O is polyfilled using `f32` types with conversions.
pub use_storage_input_output_16: bool,
Expand All @@ -891,6 +935,7 @@ impl Default for Options<'_> {
bounds_check_policies: BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
force_loop_bounding: true,
ray_query_initialization_tracking: true,
use_storage_input_output_16: true,
debug_info: None,
}
Expand Down
Loading