Skip to content

Commit ec5749f

Browse files
committed
Implement checked mul via Op{U,S}MulExtended (#537).
1 parent c904e5c commit ec5749f

11 files changed

Lines changed: 434 additions & 14 deletions

File tree

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,20 +1726,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
17261726
rhs: Self::Value,
17271727
) -> (Self::Value, Self::Value) {
17281728
// adopted partially from https://github.com/ziglang/zig/blob/master/src/codegen/spirv.zig
1729-
let is_add = match oop {
1730-
OverflowOp::Add => true,
1731-
OverflowOp::Sub => false,
1732-
OverflowOp::Mul => {
1733-
// NOTE(eddyb) this needs to be `undef`, not `false`/`true`, because
1734-
// we don't want the user's boolean constants to keep the zombie alive.
1735-
let bool = SpirvType::Bool.def(self.span(), self);
1736-
let overflowed = self.undef(bool);
1737-
1738-
let result = (self.mul(lhs, rhs), overflowed);
1739-
self.zombie(result.1.def(self), "checked mul is not supported yet");
1740-
return result;
1741-
}
1742-
};
17431729
let signed = match ty.kind() {
17441730
ty::Int(_) => true,
17451731
ty::Uint(_) => false,
@@ -1752,6 +1738,68 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
17521738
}
17531739
)),
17541740
};
1741+
let is_add = match oop {
1742+
OverflowOp::Add => true,
1743+
OverflowOp::Sub => false,
1744+
OverflowOp::Mul => {
1745+
let int_ty = lhs.ty;
1746+
let bits = match self.lookup_type(int_ty) {
1747+
SpirvType::Integer(width, _) => width,
1748+
other => self.fatal(format!(
1749+
"checked mul on non-integer type: {}",
1750+
other.debug(int_ty, self)
1751+
)),
1752+
};
1753+
1754+
// OpUMulExtended / OpSMulExtended produce an OpTypeStruct{T, T}
1755+
// holding the low and high halves of the full 2*bits-wide product.
1756+
// The struct is purely intermediate — we extract both halves
1757+
// immediately and never expose it to Rust code.
1758+
let pair_ty = SpirvType::Adt {
1759+
def_id: None,
1760+
size: None,
1761+
align: Align::from_bytes(0).unwrap(),
1762+
field_types: &[int_ty, int_ty],
1763+
field_offsets: &[],
1764+
field_names: None,
1765+
}
1766+
.def(self.span(), self);
1767+
1768+
let extended = if signed {
1769+
self.emit()
1770+
.s_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self))
1771+
} else {
1772+
self.emit()
1773+
.u_mul_extended(pair_ty, None, lhs.def(self), rhs.def(self))
1774+
}
1775+
.unwrap();
1776+
1777+
let low = self
1778+
.emit()
1779+
.composite_extract(int_ty, None, extended, [0].iter().cloned())
1780+
.unwrap()
1781+
.with_type(int_ty);
1782+
let high = self
1783+
.emit()
1784+
.composite_extract(int_ty, None, extended, [1].iter().cloned())
1785+
.unwrap()
1786+
.with_type(int_ty);
1787+
1788+
let overflowed = if signed {
1789+
// For signed multiplication, no overflow occurs iff the high
1790+
// half is the sign extension of the low half, i.e. the
1791+
// arithmetic-shift of `low` by `bits-1` (replicating the MSB).
1792+
let shift_amount = self.constant_int(int_ty, u128::from(bits - 1));
1793+
let expected_high = self.ashr(low, shift_amount);
1794+
self.icmp(IntPredicate::IntNE, high, expected_high)
1795+
} else {
1796+
let zero = self.constant_int(int_ty, 0);
1797+
self.icmp(IntPredicate::IntNE, high, zero)
1798+
};
1799+
1800+
return (low, overflowed);
1801+
}
1802+
};
17551803

17561804
let result = if is_add {
17571805
self.add(lhs, rhs)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Verifies that checked multiplication lowers to `OpUMulExtended` /
2+
// `OpSMulExtended` followed by the expected overflow comparison (issue #537).
3+
4+
// build-pass
5+
// compile-flags: -C llvm-args=--disassemble-fn=checked_mul::checked_mul
6+
7+
use spirv_std::spirv;
8+
9+
// Helper whose disassembly is checked by this test.
// Runs `overflowing_mul` once on the unsigned pair (`a`, `b`) and once on the
// signed pair (`c`, `d`), then XORs both wrapped products and both overflow
// flags (as 0/1) into a single `u32` so that all four values are used.
fn checked_mul(a: u32, b: u32, c: i32, d: i32) -> u32 {
10+
let (ur, uo) = a.overflowing_mul(b);
11+
let (sr, so) = c.overflowing_mul(d);
12+
ur ^ (sr as u32) ^ (uo as u32) ^ (so as u32)
13+
}
14+
15+
// Fragment entry point: feeds the two `u32` values from `u_in` and the two
// `i32` values from `s_in` to `checked_mul` and stores the result in `out`.
#[spirv(fragment)]
16+
pub fn main(
17+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] u_in: &[u32; 2],
18+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] s_in: &[i32; 2],
19+
out: &mut u32,
20+
) {
21+
*out = checked_mul(u_in[0], u_in[1], s_in[0], s_in[1]);
22+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %2
3+
%5 = OpFunctionParameter %2
4+
%6 = OpFunctionParameter %7
5+
%8 = OpFunctionParameter %7
6+
%9 = OpLabel
7+
OpLine %10 1239 4
8+
%11 = OpUMulExtended %12 %4 %5
9+
%13 = OpCompositeExtract %2 %11 0
10+
%14 = OpCompositeExtract %2 %11 1
11+
%15 = OpINotEqual %16 %14 %17
12+
OpLine %10 390 4
13+
%18 = OpSMulExtended %19 %6 %8
14+
%20 = OpCompositeExtract %7 %18 0
15+
%21 = OpCompositeExtract %7 %18 1
16+
%22 = OpShiftRightArithmetic %7 %20 %23
17+
%24 = OpINotEqual %16 %21 %22
18+
OpLine %25 12 9
19+
%26 = OpBitcast %2 %20
20+
OpLine %25 12 4
21+
%27 = OpBitwiseXor %2 %13 %26
22+
OpLine %25 12 23
23+
%28 = OpSelect %2 %15 %29 %17
24+
OpLine %25 12 4
25+
%30 = OpBitwiseXor %2 %27 %28
26+
OpLine %25 12 37
27+
%31 = OpSelect %2 %24 %29 %17
28+
OpLine %25 12 4
29+
%32 = OpBitwiseXor %2 %30 %31
30+
OpNoLine
31+
OpReturnValue %32
32+
OpFunctionEnd
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Tests that checked / overflowing / unchecked multiplication compile, including
2+
// the `unchecked_mul` precondition check that internally calls `overflowing_mul`
3+
// (see https://github.com/Rust-GPU/rust-gpu/issues/537).
4+
5+
// build-pass
6+
// compile-flags: -C target-feature=+Int8,+Int16,+Int64
7+
8+
use spirv_std::spirv;
9+
10+
// `u8` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
11+
pub fn checked_mul_u8(
12+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u8,
13+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u8,
14+
out: &mut u32,
15+
) {
16+
let (r, o) = a.overflowing_mul(*b);
17+
// Wrapped product in the low 8 bits, overflow flag (0/1) OR-ed in at bit 8.
*out = (r as u32) | ((o as u32) << 8);
18+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
19+
*out ^= v as u32;
20+
}
21+
}
22+
23+
// `u16` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
24+
pub fn checked_mul_u16(
25+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u16,
26+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u16,
27+
out: &mut u32,
28+
) {
29+
let (r, o) = a.overflowing_mul(*b);
30+
// Wrapped product in the low 16 bits, overflow flag (0/1) OR-ed in at bit 16.
*out = (r as u32) | ((o as u32) << 16);
31+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
32+
*out ^= v as u32;
33+
}
34+
}
35+
36+
// `u32` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
37+
pub fn checked_mul_u32(
38+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
39+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
40+
out: &mut u32,
41+
) {
42+
let (r, o) = a.overflowing_mul(*b);
43+
// The product already fills all 32 bits, so the overflow flag (0/1) is XOR-ed into bit 0.
*out = r ^ (o as u32);
44+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
45+
*out ^= v;
46+
}
47+
}
48+
49+
// `u64` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
50+
pub fn checked_mul_u64(
51+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u64,
52+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u64,
53+
out: &mut u32,
54+
) {
55+
let (r, o) = a.overflowing_mul(*b);
56+
// `r as u32` keeps only the low 32 bits of the 64-bit wrapped product.
*out = (r as u32) ^ (o as u32);
57+
// `checked_mul` yields `Some` only when no overflow occurred; fold its low 32 bits in too.
if let Some(v) = a.checked_mul(*b) {
58+
*out ^= v as u32;
59+
}
60+
}
61+
62+
// `i8` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
63+
pub fn checked_mul_i8(
64+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i8,
65+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i8,
66+
out: &mut u32,
67+
) {
68+
let (r, o) = a.overflowing_mul(*b);
69+
// `r as u32` sign-extends the signed product; the overflow flag (0/1) is XOR-ed in at bit 8.
*out = (r as u32) ^ ((o as u32) << 8);
70+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
71+
*out ^= v as u32;
72+
}
73+
}
74+
75+
// `i16` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
76+
pub fn checked_mul_i16(
77+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i16,
78+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i16,
79+
out: &mut u32,
80+
) {
81+
let (r, o) = a.overflowing_mul(*b);
82+
// `r as u32` sign-extends the signed product; the overflow flag (0/1) is XOR-ed in at bit 16.
*out = (r as u32) ^ ((o as u32) << 16);
83+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
84+
*out ^= v as u32;
85+
}
86+
}
87+
88+
// `i32` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
89+
pub fn checked_mul_i32(
90+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32,
91+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32,
92+
out: &mut u32,
93+
) {
94+
let (r, o) = a.overflowing_mul(*b);
95+
// Same-width cast of the wrapped product, XOR-ed with the overflow flag (0/1).
*out = (r as u32) ^ (o as u32);
96+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(*b) {
97+
*out ^= v as u32;
98+
}
99+
}
100+
101+
// `i64` case: exercises both `overflowing_mul` and `checked_mul` on `*a` and `*b`.
#[spirv(fragment)]
102+
pub fn checked_mul_i64(
103+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i64,
104+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i64,
105+
out: &mut u32,
106+
) {
107+
let (r, o) = a.overflowing_mul(*b);
108+
// `r as u32` keeps only the low 32 bits of the 64-bit wrapped product.
*out = (r as u32) ^ (o as u32);
109+
// `checked_mul` yields `Some` only when no overflow occurred; fold its low 32 bits in too.
if let Some(v) = a.checked_mul(*b) {
110+
*out ^= v as u32;
111+
}
112+
}
113+
114+
// Issue #537 specifically: `unchecked_mul`'s precondition check uses
115+
// `overflowing_mul`, which previously zombied with "checked mul is not
116+
// supported yet".
117+
// Stores the result of the `unsafe` `u32::unchecked_mul` of `*a` and `*b` in `out`.
#[spirv(fragment)]
118+
pub fn unchecked_mul_u32(
119+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
120+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
121+
out: &mut u32,
122+
) {
123+
*out = unsafe { a.unchecked_mul(*b) };
124+
}
125+
126+
// Signed counterpart: stores the `unsafe` `i32::unchecked_mul` of `*a` and `*b`,
// cast to `u32`, in `out`.
#[spirv(fragment)]
127+
pub fn unchecked_mul_i32(
128+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &i32,
129+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &i32,
130+
out: &mut u32,
131+
) {
132+
*out = unsafe { a.unchecked_mul(*b) } as u32;
133+
}
134+
135+
// The original issue used `usize::unchecked_mul()` (e.g. via `Layout::repeat`).
136+
// `usize` variant: the inputs arrive as `u32`, are cast to `usize`, multiplied
// with the `unsafe` `unchecked_mul`, and the product is cast back to `u32`.
#[spirv(fragment)]
137+
pub fn unchecked_mul_usize(
138+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
139+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
140+
out: &mut u32,
141+
) {
142+
let a = *a as usize;
143+
let b = *b as usize;
144+
*out = unsafe { a.unchecked_mul(b) } as u32;
145+
}
146+
147+
// `usize` case: the inputs arrive as `u32` and are cast to `usize` before
// exercising both `overflowing_mul` and `checked_mul`.
#[spirv(fragment)]
148+
pub fn checked_mul_usize(
149+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] a: &u32,
150+
#[spirv(descriptor_set = 0, binding = 1, storage_buffer)] b: &u32,
151+
out: &mut u32,
152+
) {
153+
let a = *a as usize;
154+
let b = *b as usize;
155+
let (r, o) = a.overflowing_mul(b);
156+
// Wrapped product cast to `u32`, XOR-ed with the overflow flag (0/1).
*out = (r as u32) ^ (o as u32);
157+
// `checked_mul` yields `Some` only when no overflow occurred; fold that value in too.
if let Some(v) = a.checked_mul(b) {
158+
*out ^= v as u32;
159+
}
160+
}

tests/difftests/tests/Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/difftests/tests/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ members = [
3434
"lang/core/ops/matrix_ops/matrix_ops-rust",
3535
"lang/core/ops/matrix_ops/matrix_ops-wgsl",
3636
"lang/core/ops/bitwise_ops/bitwise_ops-rust",
37+
"lang/core/ops/checked_mul/checked_mul-cpu",
38+
"lang/core/ops/checked_mul/checked_mul-shader",
3739
"lang/core/ops/const_fold_int/const-expr-cpu",
3840
"lang/core/ops/const_fold_int/const-expr-shader",
3941
"lang/core/ops/const_fold_int/const-fold-cpu",
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "checked_mul-cpu"
3+
edition.workspace = true
4+
5+
[lints]
6+
workspace = true
7+
8+
[dependencies]
9+
checked_mul-shader = { path = "../checked_mul-shader" }
10+
difftest.workspace = true
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use checked_mul_shader::{INPUTS, OUTPUT_LEN, PAIR_COUNT};
2+
use difftest::config::Config;
3+
4+
// CPU side of the `checked_mul` difftest.
//
// Loads the difftest `Config` from the path given as the first command-line
// argument (panics if it is missing or cannot be loaded), computes the output
// with the host's `overflowing_mul`, and writes it via `config.write_result`.
//
// Output layout, four `u32` words per input pair `i`:
//   [4*i]     wrapped unsigned product
//   [4*i + 1] unsigned overflow flag (0/1)
//   [4*i + 2] wrapped signed product, as its `u32` bit pattern
//   [4*i + 3] signed overflow flag (0/1)
fn main() {
5+
let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap();
6+
7+
let mut output = vec![0u32; OUTPUT_LEN];
8+
for i in 0..PAIR_COUNT {
9+
// Operands are stored interleaved in `INPUTS`: pair `i` is elements `2*i` and `2*i + 1`.
let a = INPUTS[2 * i];
10+
let b = INPUTS[2 * i + 1];
11+
let (ur, uo) = a.overflowing_mul(b);
12+
output[4 * i] = ur;
13+
output[4 * i + 1] = uo as u32;
14+
// The same bit patterns are reinterpreted as `i32` for the signed multiplication.
let (sr, so) = (a as i32).overflowing_mul(b as i32);
15+
output[4 * i + 2] = sr as u32;
16+
output[4 * i + 3] = so as u32;
17+
}
18+
19+
config.write_result(&output).unwrap();
20+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "checked_mul-shader"
3+
edition.workspace = true
4+
5+
[lints]
6+
workspace = true
7+
8+
# GPU deps
9+
[dependencies]
10+
spirv-std.workspace = true
11+
12+
# CPU deps (for the test harness)
13+
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
14+
difftest.workspace = true
15+
bytemuck.workspace = true

0 commit comments

Comments
 (0)