Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga spv-out] Ensure loops generated by SPIRV backend are bounded #7080

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ pub enum PollError {
By @cwfitzgerald in [#6942](https://github.com/gfx-rs/wgpu/pull/6942).
By @cwfitzgerald in [#7030](https://github.com/gfx-rs/wgpu/pull/7030).

#### Naga

##### Ensure loops generated by SPIR-V and HLSL Naga backends are bounded

Make sure that all loops in shaders generated by these Naga backends are bounded
to avoid undefined behaviour due to infinite loops. Note that this may have a
performance cost. As with the existing implementation for the MSL backend, this
can be disabled by using `Device::create_shader_module_trusted()`.

By @jamienicol in [#6929](https://github.com/gfx-rs/wgpu/pull/6929) and [#7080](https://github.com/gfx-rs/wgpu/pull/7080).

### New Features

#### General
Expand Down
153 changes: 153 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,155 @@ impl Writer {
}

impl BlockContext<'_> {
/// Generates code to ensure that a loop is bounded. Should be called immediately
/// after adding the OpLoopMerge instruction to `block`. This function will
/// [`consume()`](crate::back::spv::Function::consume) `block` and append its
/// instructions to a new [`Block`], which will be returned to the caller for it to
/// be consumed prior to writing the loop body.
///
/// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
/// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
/// declare the required variables.
///
/// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
/// of why this is required.
fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
    // Gather the type and constant IDs needed for the counter. The counter is a
    // vec2<u32> (initialized to zero) used to simulate a single 64-bit integer.
    let uint_type_id = self.writer.get_uint_type_id();
    let uint2_type_id = self.writer.get_uint2_type_id();
    let uint2_ptr_type_id = self
        .writer
        .get_uint2_pointer_type_id(spirv::StorageClass::Function);
    let bool_type_id = self.writer.get_bool_type_id();
    let bool2_type_id = self.writer.get_bool2_type_id();
    let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
    let zero_uint2_const_id = self.writer.get_constant_composite(
        LookupType::Local(LocalType::Numeric(NumericType::Vector {
            size: crate::VectorSize::Bi,
            scalar: crate::Scalar::U32,
        })),
        &[zero_uint_const_id, zero_uint_const_id],
    );
    let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
    let max_uint_const_id = self
        .writer
        .get_constant_scalar(crate::Literal::U32(u32::MAX));
    let max_uint2_const_id = self.writer.get_constant_composite(
        LookupType::Local(LocalType::Numeric(NumericType::Vector {
            size: crate::VectorSize::Bi,
            scalar: crate::Scalar::U32,
        })),
        &[max_uint_const_id, max_uint_const_id],
    );

    // Declare the counter variable, naming it "loop_bound" in debug builds so it
    // is recognizable in disassembly / debuggers.
    let loop_counter_var_id = self.gen_id();
    if self.writer.flags.contains(WriterFlags::DEBUG) {
        self.writer
            .debugs
            .push(Instruction::name(loop_counter_var_id, "loop_bound"));
    }
    let var = super::LocalVariable {
        id: loop_counter_var_id,
        instruction: Instruction::variable(
            uint2_ptr_type_id,
            loop_counter_var_id,
            spirv::StorageClass::Function,
            Some(zero_uint2_const_id),
        ),
    };
    // Recorded separately from ordinary locals; Function::to_words() emits these
    // at the start of the function's first block as SPIR-V requires.
    self.function.force_loop_bounding_vars.push(var);

    let break_if_block = self.gen_id();

    // Terminate the caller's block (which ends with OpLoopMerge) with a branch
    // into the new bound-check block we build below.
    self.function
        .consume(block, Instruction::branch(break_if_block));
    block = Block::new(break_if_block);

    // Load the current loop counter value from its variable. We use a vec2<u32> to
    // simulate a 64-bit counter.
    let load_id = self.gen_id();
    block.body.push(Instruction::load(
        uint2_type_id,
        load_id,
        loop_counter_var_id,
        None,
    ));

    // If both the high and low u32s have reached u32::MAX then break. ie
    // if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
    let eq_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool2_type_id,
        eq_id,
        max_uint2_const_id,
        load_id,
    ));
    let all_eq_id = self.gen_id();
    block.body.push(Instruction::relational(
        spirv::Op::All,
        bool_type_id,
        all_eq_id,
        eq_id,
    ));

    // Structured selection: either branch to the loop's merge block (breaking out
    // of the loop) or fall through to the block that increments the counter.
    let inc_counter_block_id = self.gen_id();
    block.body.push(Instruction::selection_merge(
        inc_counter_block_id,
        spirv::SelectionControl::empty(),
    ));
    self.function.consume(
        block,
        Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
    );
    block = Block::new(inc_counter_block_id);

    // To simulate a 64-bit counter we always increment the low u32, and increment
    // the high u32 when the low u32 overflows. ie
    // counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
    // Component 0 holds the high word, component 1 the low word.
    let low_id = self.gen_id();
    block.body.push(Instruction::composite_extract(
        uint_type_id,
        low_id,
        load_id,
        &[1],
    ));
    let low_overflow_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool_type_id,
        low_overflow_id,
        low_id,
        max_uint_const_id,
    ));
    // carry = low == u32::MAX ? 1 : 0
    let carry_bit_id = self.gen_id();
    block.body.push(Instruction::select(
        uint_type_id,
        carry_bit_id,
        low_overflow_id,
        one_uint_const_id,
        zero_uint_const_id,
    ));
    // increment = vec2(carry, 1)
    let increment_id = self.gen_id();
    block.body.push(Instruction::composite_construct(
        uint2_type_id,
        increment_id,
        &[carry_bit_id, one_uint_const_id],
    ));
    let result_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IAdd,
        uint2_type_id,
        result_id,
        load_id,
        increment_id,
    ));
    // Store the incremented counter back for the next iteration.
    block
        .body
        .push(Instruction::store(loop_counter_var_id, result_id, None));

    // Returned to the caller, to be consumed with a branch to the loop body.
    block
}

/// Cache an expression for a value.
pub(super) fn cache_expression_value(
&mut self,
Expand Down Expand Up @@ -2558,6 +2707,10 @@ impl BlockContext<'_> {
continuing_id,
spirv::SelectionControl::NONE,
));

if self.force_loop_bounding {
block = self.write_force_bounded_loop_instructions(block, merge_id);
}
self.function.consume(block, Instruction::branch(body_id));

// We can ignore the `BlockExitDisposition` returned here because,
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ struct Function {
signature: Option<Instruction>,
parameters: Vec<FunctionArgument>,
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
/// List of local variables used as counters to ensure that all loops are bounded.
force_loop_bounding_vars: Vec<LocalVariable>,

/// A map taking an expression that yields a composite value (array, matrix)
/// to the temporary variables we have spilled it to, if any. Spilling
Expand Down Expand Up @@ -726,6 +728,8 @@ struct BlockContext<'w> {

/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
expression_constness: ExpressionConstnessTracker,

force_loop_bounding: bool,
}

impl BlockContext<'_> {
Expand Down Expand Up @@ -779,6 +783,7 @@ pub struct Writer {
flags: WriterFlags,
bounds_check_policies: BoundsCheckPolicies,
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
force_loop_bounding: bool,
void_type: Word,
//TODO: convert most of these into vectors, addressable by handle indices
lookup_type: crate::FastHashMap<LookupType, Word>,
Expand Down Expand Up @@ -882,6 +887,10 @@ pub struct Options<'a> {
/// Dictates the way workgroup variables should be zero initialized
pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,

/// If set, loops will have code injected into them, forcing the compiler
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,

pub debug_info: Option<DebugInfo<'a>>,
}

Expand All @@ -900,6 +909,7 @@ impl Default for Options<'_> {
capabilities: None,
bounds_check_policies: BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
force_loop_bounding: true,
debug_info: None,
}
}
Expand Down
33 changes: 33 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl Function {
for local_var in self.variables.values() {
local_var.instruction.to_words(sink);
}
for local_var in self.force_loop_bounding_vars.iter() {
local_var.instruction.to_words(sink);
}
for internal_var in self.spilled_composites.values() {
internal_var.instruction.to_words(sink);
}
Expand Down Expand Up @@ -71,6 +74,7 @@ impl Writer {
flags: options.flags,
bounds_check_policies: options.bounds_check_policies,
zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
force_loop_bounding: options.force_loop_bounding,
void_type,
lookup_type: crate::FastHashMap::default(),
lookup_function: crate::FastHashMap::default(),
Expand Down Expand Up @@ -111,6 +115,7 @@ impl Writer {
flags: self.flags,
bounds_check_policies: self.bounds_check_policies,
zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
force_loop_bounding: self.force_loop_bounding,
capabilities_available: take(&mut self.capabilities_available),
binding_map: take(&mut self.binding_map),

Expand Down Expand Up @@ -267,6 +272,14 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the type ID for `vec2<u32>`, registering the type if necessary.
pub(super) fn get_uint2_type_id(&mut self) -> Word {
    let vec2u = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::U32,
    };
    self.get_type_id(LocalType::Numeric(vec2u).into())
}

pub(super) fn get_uint3_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
Expand All @@ -283,6 +296,17 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the type ID for a function/private/etc. pointer (per `class`) to
/// `vec2<u32>`, registering the type if necessary.
pub(super) fn get_uint2_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
    let pointee = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::U32,
    };
    self.get_type_id(LocalType::LocalPointer { base: pointee, class }.into())
}

pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
let local_type = LocalType::LocalPointer {
base: NumericType::Vector {
Expand All @@ -299,6 +323,14 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the type ID for `vec2<bool>`, registering the type if necessary.
pub(super) fn get_bool2_type_id(&mut self) -> Word {
    let vec2b = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::BOOL,
    };
    self.get_type_id(LocalType::Numeric(vec2b).into())
}

pub(super) fn get_bool3_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
Expand Down Expand Up @@ -839,6 +871,7 @@ impl Writer {

// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
force_loop_bounding: self.force_loop_bounding,
writer: self,
expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
Expand Down
50 changes: 37 additions & 13 deletions naga/tests/out/spv/6220-break-from-loop.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 26
; Bound: 46
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
Expand All @@ -13,31 +13,55 @@ OpMemoryModel Logical GLSL450
%8 = OpConstant %3 4
%9 = OpConstant %3 1
%11 = OpTypePointer Function %3
%18 = OpTypeBool
%17 = OpTypeInt 32 0
%18 = OpTypeVector %17 2
%19 = OpTypePointer Function %18
%20 = OpTypeBool
%21 = OpTypeVector %20 2
%22 = OpConstant %17 0
%23 = OpConstantComposite %18 %22 %22
%24 = OpConstant %17 1
%25 = OpConstant %17 4294967295
%26 = OpConstantComposite %18 %25 %25
%5 = OpFunction %2 None %6
%4 = OpLabel
%10 = OpVariable %11 Function %7
%27 = OpVariable %19 Function %23
OpBranch %12
%12 = OpLabel
OpBranch %13
%13 = OpLabel
OpLoopMerge %14 %16 None
OpBranch %28
%28 = OpLabel
%29 = OpLoad %18 %27
%30 = OpIEqual %21 %26 %29
%31 = OpAll %20 %30
OpSelectionMerge %32 None
OpBranchConditional %31 %14 %32
%32 = OpLabel
%33 = OpCompositeExtract %17 %29 1
%34 = OpIEqual %20 %33 %25
%35 = OpSelect %17 %34 %24 %22
%36 = OpCompositeConstruct %18 %35 %24
%37 = OpIAdd %18 %29 %36
OpStore %27 %37
OpBranch %15
%15 = OpLabel
%17 = OpLoad %3 %10
%19 = OpSLessThan %18 %17 %8
OpSelectionMerge %20 None
OpBranchConditional %19 %20 %21
%21 = OpLabel
%38 = OpLoad %3 %10
%39 = OpSLessThan %20 %38 %8
OpSelectionMerge %40 None
OpBranchConditional %39 %40 %41
%41 = OpLabel
OpBranch %14
%20 = OpLabel
OpBranch %22
%22 = OpLabel
%40 = OpLabel
OpBranch %42
%42 = OpLabel
OpBranch %14
%16 = OpLabel
%24 = OpLoad %3 %10
%25 = OpIAdd %3 %24 %9
OpStore %10 %25
%44 = OpLoad %3 %10
%45 = OpIAdd %3 %44 %9
OpStore %10 %45
OpBranch %13
%14 = OpLabel
OpReturn
Expand Down
Loading