Skip to content

Commit

Permalink
[naga msl-out hlsl-out] Improve workaround for infinite loops causing…
Browse files Browse the repository at this point in the history
… undefined behaviour (#6929)

Co-authored-by: Teodor Tanasoaia <[email protected]>
  • Loading branch information
jamienicol and teoxoy authored Jan 31, 2025
1 parent ad194a8 commit 4e7d892
Show file tree
Hide file tree
Showing 20 changed files with 223 additions and 95 deletions.
4 changes: 4 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ pub struct Options {
pub zero_initialize_workgroup_memory: bool,
/// Should we restrict indexing of vectors, matrices and arrays?
pub restrict_indexing: bool,
/// If set, loops will have code injected into them, forcing the compiler
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,
}

impl Default for Options {
Expand All @@ -302,6 +305,7 @@ impl Default for Options {
dynamic_storage_buffer_offsets_targets: std::collections::BTreeMap::new(),
zero_initialize_workgroup_memory: true,
restrict_indexing: true,
force_loop_bounding: true,
}
}
}
Expand Down
50 changes: 44 additions & 6 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.clear();
}

/// Generates statements to be inserted immediately before and at the very
/// start of the body of each loop, to defeat infinite loop reasoning.
/// The 0th item of the returned tuple should be inserted immediately prior
/// to the loop and the 1st item should be inserted at the very start of
/// the loop body.
///
/// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
fn gen_force_bounded_loop_statements(
&mut self,
level: back::Level,
) -> Option<(String, String)> {
if !self.options.force_loop_bounding {
return None;
}

let loop_bound_name = self.namer.call("loop_bound");
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
let level = level.next();
let max = u32::MAX;
let break_and_inc = format!(
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
);

Some((decl, break_and_inc))
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
Expand Down Expand Up @@ -2162,12 +2189,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ref continuing,
break_if,
} => {
let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
let gate_name = (!continuing.is_empty() || break_if.is_some())
.then(|| self.namer.call("loop_init"));

if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
if let Some(ref gate_name) = gate_name {
writeln!(self.out, "{level}bool {gate_name} = true;")?;
}

self.continue_ctx.enter_loop();
writeln!(self.out, "{level}while(true) {{")?;
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}
let l2 = level.next();
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(self.out, "{level}while(true) {{")?;
if let Some(gate_name) = gate_name {
writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
let l3 = l2.next();
for sta in continuing.iter() {
Expand All @@ -2182,13 +2221,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
writeln!(self.out, "{l2}}}")?;
writeln!(self.out, "{l2}{gate_name} = false;")?;
} else {
writeln!(self.out, "{level}while(true) {{")?;
}

for sta in body.iter() {
self.write_stmt(module, sta, func_ctx, l2)?;
}

writeln!(self.out, "{level}}}")?;
self.continue_ctx.exit_loop();
}
Expand Down
117 changes: 59 additions & 58 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,6 @@ pub struct Writer<W> {
/// Set of (struct type, struct field index) denoting which fields require
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

/// Name of the force-bounded-loop macro.
///
/// See `emit_force_bounded_loop_macro` for details.
force_bounded_loop_macro_name: String,
}

impl crate::Scalar {
Expand Down Expand Up @@ -601,7 +596,7 @@ struct ExpressionContext<'a> {
/// accesses. These may need to be cached in temporary variables. See
/// `index::find_checked_indexes` for details.
guarded_indices: HandleSet<crate::Expression>,
/// See [`Writer::emit_force_bounded_loop_macro`] for details.
/// See [`Writer::gen_force_bounded_loop_statements`] for details.
force_loop_bounding: bool,
}

Expand Down Expand Up @@ -685,7 +680,6 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
force_bounded_loop_macro_name: String::default(),
}
}

Expand All @@ -696,17 +690,11 @@ impl<W: Write> Writer<W> {
self.out
}

/// Define a macro to invoke at the bottom of each loop body, to
/// defeat MSL infinite loop reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
/// macro to be invoked at the end of each loop body in the generated MSL,
/// to ensure that the MSL compiler's optimizations do not remove bounds
/// checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
/// prior macro definition, since macros aren't block-scoped.
/// Generates statements to be inserted immediately before and at the very
/// start of the body of each loop, to defeat MSL infinite loop reasoning.
/// The 0th item of the returned tuple should be inserted immediately prior
/// to the loop and the 1st item should be inserted at the very start of
/// the loop body.
///
/// # What is this trying to solve?
///
Expand Down Expand Up @@ -774,7 +762,8 @@ impl<W: Write> Writer<W> {
/// but which in fact generates no instructions. Unfortunately, inline
/// assembly is not handled correctly by some Metal device drivers.
///
/// Instead, we add the following code to the bottom of every loop:
/// A previously used approach was to add the following code to the bottom
/// of every loop:
///
/// ```ignore
/// if (volatile bool unpredictable = false; unpredictable)
Expand All @@ -785,37 +774,47 @@ impl<W: Write> Writer<W> {
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
/// it must assume that the `break` might be reached, and hence that the
/// loop is not unbounded. This prevents the range analysis impact described
/// above.
/// above. Unfortunately this prevented the compiler from making important,
/// and safe, optimizations such as loop unrolling and was observed to
/// significantly hurt performance.
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
/// runtime, in every iteration of the loop. There's no part of the system
/// that has a global enough view to be sure that `unpredictable` is true,
/// and remove it from the code. Adding the branch also affects
/// optimization: for example, it's impossible to unroll this loop. This
/// transformation has been observed to significantly hurt performance.
/// Our current approach declares a counter before every loop and
/// increments it every iteration, breaking after 2^64 iterations:
///
/// ```ignore
/// uint2 loop_bound = uint2(0);
/// while (true) {
/// if (metal::all(loop_bound == uint2(4294967295))) { break; }
/// loop_bound += uint2(loop_bound.y == 4294967295, 1);
/// }
/// ```
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
/// This convinces the compiler that the loop is finite and therefore may
/// execute, whilst at the same time allowing optimizations such as loop
/// unrolling. Furthermore the 64-bit counter is large enough it seems
/// implausible that it would affect the execution of any shader.
///
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
if !self.force_bounded_loop_macro_name.is_empty() {
return Ok(());
/// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
fn gen_force_bounded_loop_statements(
&mut self,
level: back::Level,
context: &StatementContext,
) -> Option<(String, String)> {
if !context.expression.force_loop_bounding {
return None;
}

self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
writeln!(
self.out,
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
self.force_bounded_loop_macro_name,
loop_bounded_volatile_name,
loop_bounded_volatile_name,
)?;
let loop_bound_name = self.namer.call("loop_bound");
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
let level = level.next();
let max = u32::MAX;
let break_and_inc = format!(
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
);

Ok(())
Some((decl, break_and_inc))
}

fn put_call_parameters(
Expand Down Expand Up @@ -3201,10 +3200,23 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
let force_loop_bound_statements =
self.gen_force_bounded_loop_statements(level, context);
let gate_name = (!continuing.is_empty() || break_if.is_some())
.then(|| self.namer.call("loop_init"));

if let Some((ref decl, _)) = force_loop_bound_statements {
writeln!(self.out, "{decl}")?;
}
if let Some(ref gate_name) = gate_name {
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(self.out, "{level}while(true) {{",)?;
}

writeln!(self.out, "{level}while(true) {{",)?;
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
writeln!(self.out, "{break_and_inc}")?;
}
if let Some(ref gate_name) = gate_name {
let lif = level.next();
let lcontinuing = lif.next();
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
Expand All @@ -3218,19 +3230,9 @@ impl<W: Write> Writer<W> {
}
writeln!(self.out, "{lif}}}")?;
writeln!(self.out, "{lif}{gate_name} = false;")?;
} else {
writeln!(self.out, "{level}while(true) {{",)?;
}
self.put_block(level.next(), body, context)?;
if context.expression.force_loop_bounding {
self.emit_force_bounded_loop_macro()?;
writeln!(
self.out,
"{}{}",
level.next(),
self.force_bounded_loop_macro_name
)?;
}

writeln!(self.out, "{level}}}")?;
}
crate::Statement::Break => {
Expand Down Expand Up @@ -3724,7 +3726,6 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
self.force_bounded_loop_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/boids.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
vPos = _e8;
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
vVel = _e14;
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
if (!loop_init) {
uint _e91 = i;
i = (_e91 + 1u);
Expand Down
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/break-if.hlsl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
void breakIfEmpty()
{
uint2 loop_bound = uint2(0u, 0u);
bool loop_init = true;
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
if (!loop_init) {
if (true) {
break;
Expand All @@ -17,8 +20,11 @@ void breakIfEmptyBody(bool a)
bool b = (bool)0;
bool c = (bool)0;

uint2 loop_bound_1 = uint2(0u, 0u);
bool loop_init_1 = true;
while(true) {
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
if (!loop_init_1) {
b = a;
bool _e2 = b;
Expand All @@ -38,8 +44,11 @@ void breakIf(bool a_1)
bool d = (bool)0;
bool e = (bool)0;

uint2 loop_bound_2 = uint2(0u, 0u);
bool loop_init_2 = true;
while(true) {
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
if (!loop_init_2) {
bool _e5 = e;
if ((a_1 == _e5)) {
Expand All @@ -58,8 +67,11 @@ void breakIfSeparateVariable()
{
uint counter = 0u;

uint2 loop_bound_3 = uint2(0u, 0u);
bool loop_init_3 = true;
while(true) {
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
if (!loop_init_3) {
uint _e5 = counter;
if ((_e5 == 5u)) {
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/out/hlsl/collatz.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ uint collatz_iterations(uint n_base)
uint i = 0u;

n = n_base;
uint2 loop_bound = uint2(0u, 0u);
while(true) {
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
uint _e4 = n;
if ((_e4 > 1u)) {
} else {
Expand Down
Loading

0 comments on commit 4e7d892

Please sign in to comment.