Skip to content

Commit 881da16

Browse files
committed
coop: Rework MulAdd into a statement, implement Load/Store
1 parent 4f614bc commit 881da16

35 files changed

+758
-305
lines changed

naga/src/back/dot/mod.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,33 @@ impl StatementGraph {
403403
},
404404
}
405405
}
406+
S::Cooperation { target, ref op } => {
407+
self.dependencies.push((id, target, "target"));
408+
match *op {
409+
crate::CooperativeOperation::LoadStore {
410+
store,
411+
pointer,
412+
stride,
413+
row_major: _,
414+
} => {
415+
self.dependencies.push((id, pointer, "pointer"));
416+
if let Some(stride) = stride {
417+
self.dependencies.push((id, stride, "stride"));
418+
}
419+
if store {
420+
"Store"
421+
} else {
422+
"Load"
423+
}
424+
}
425+
crate::CooperativeOperation::MultiplyAdd { a, b, c } => {
426+
self.dependencies.push((id, a, "a"));
427+
self.dependencies.push((id, b, "b"));
428+
self.dependencies.push((id, c, "c"));
429+
"MultiplyAdd"
430+
}
431+
}
432+
}
406433
};
407434
// Set the last node to the merge node
408435
last_node = merge_id;
@@ -742,12 +769,6 @@ fn write_function_expressions(
742769
let ty = if committed { "Committed" } else { "Candidate" };
743770
(format!("get{ty}HitVertexPositions").into(), 4)
744771
}
745-
E::MulAdd { a, b, c } => {
746-
edges.insert("a", a);
747-
edges.insert("b", b);
748-
edges.insert("c", c);
749-
("MulAdd".into(), 6)
750-
}
751772
};
752773

753774
// give uniform expressions an outline

naga/src/back/glsl/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,6 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> {
28052805
}
28062806
writeln!(self.out, ");")?;
28072807
}
2808+
Statement::Cooperation { .. } => unimplemented!(),
28082809
}
28092810

28102811
Ok(())
@@ -4341,8 +4342,7 @@ impl<'a, W: Write> Writer<'a, W> {
43414342
}
43424343
// not supported yet
43434344
Expression::RayQueryGetIntersection { .. }
4344-
| Expression::RayQueryVertexPositions { .. }
4345-
| Expression::MulAdd { .. } => unreachable!(),
4345+
| Expression::RayQueryVertexPositions { .. } => unreachable!(),
43464346
}
43474347

43484348
Ok(())

naga/src/back/hlsl/writer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2747,6 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
27472747
}
27482748
writeln!(self.out, ");")?;
27492749
}
2750+
Statement::Cooperation { .. } => unimplemented!(),
27502751
}
27512752

27522753
Ok(())
@@ -4275,7 +4276,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
42754276
}
42764277
}
42774278
// Not supported yet
4278-
Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => {
4279+
Expression::RayQueryVertexPositions { .. } => {
42794280
unreachable!()
42804281
}
42814282
// Nothing to do here, since call expression already cached

naga/src/back/mod.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,18 +311,6 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314-
impl crate::TypeInner {
315-
/// Returns true if this is a handle to a type rather than the type directly.
316-
pub const fn is_handle(&self) -> bool {
317-
match *self {
318-
crate::TypeInner::Image { .. }
319-
| crate::TypeInner::Sampler { .. }
320-
| crate::TypeInner::AccelerationStructure { .. } => true,
321-
_ => false,
322-
}
323-
}
324-
}
325-
326314
impl crate::Statement {
327315
/// Returns true if the statement directly terminates the current block.
328316
///

naga/src/back/msl/writer.rs

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -543,14 +543,6 @@ impl crate::Scalar {
543543
}
544544
}
545545

546-
impl crate::CooperativeScalar {
547-
const fn to_msl_name(self) -> &'static str {
548-
match self {
549-
Self::F32 => "float",
550-
}
551-
}
552-
}
553-
554546
const fn separate(need_separator: bool) -> &'static str {
555547
if need_separator {
556548
","
@@ -2842,13 +2834,6 @@ impl<W: Write> Writer<W> {
28422834
}
28432835
write!(self.out, "}}")?;
28442836
}
2845-
crate::Expression::MulAdd { a, b, c } => {
2846-
self.put_expression(a, context, false)?;
2847-
write!(self.out, " * ")?;
2848-
self.put_expression(b, context, false)?;
2849-
write!(self.out, " + ")?;
2850-
self.put_expression(c, context, false)?;
2851-
}
28522837
}
28532838
Ok(())
28542839
}
@@ -4230,6 +4215,65 @@ impl<W: Write> Writer<W> {
42304215
}
42314216
writeln!(self.out, ");")?;
42324217
}
4218+
crate::Statement::Cooperation { target, ref op } => {
4219+
write!(self.out, "{level}")?;
4220+
match *op {
4221+
crate::CooperativeOperation::LoadStore {
4222+
store,
4223+
pointer,
4224+
stride,
4225+
row_major,
4226+
} => {
4227+
let op_str = if store { "store" } else { "load" };
4228+
write!(self.out, "simdgroup_{op_str}(")?;
4229+
self.put_expression(target, &context.expression, true)?;
4230+
write!(self.out, ", ")?;
4231+
self.put_expression(pointer, &context.expression, true)?;
4232+
if stride.is_some() || row_major {
4233+
write!(self.out, ", ")?;
4234+
match stride {
4235+
Some(expression) => {
4236+
self.put_expression(expression, &context.expression, true)?;
4237+
}
4238+
None => {
4239+
let default_stride =
4240+
match *context.expression.resolve_type(target) {
4241+
crate::TypeInner::CooperativeMatrix {
4242+
columns,
4243+
rows,
4244+
..
4245+
} => {
4246+
if row_major {
4247+
columns as u32
4248+
} else {
4249+
rows as u32
4250+
}
4251+
}
4252+
_ => 0,
4253+
};
4254+
write!(self.out, "{default_stride}")?;
4255+
}
4256+
}
4257+
}
4258+
if row_major {
4259+
let matrix_origin = "0";
4260+
let transpose = true;
4261+
write!(self.out, ", {matrix_origin}, {transpose}")?;
4262+
}
4263+
}
4264+
crate::CooperativeOperation::MultiplyAdd { a, b, c } => {
4265+
write!(self.out, "simdgroup_multiply_accumulate(")?;
4266+
self.put_expression(target, &context.expression, true)?;
4267+
write!(self.out, ", ")?;
4268+
self.put_expression(a, &context.expression, true)?;
4269+
write!(self.out, ", ")?;
4270+
self.put_expression(b, &context.expression, true)?;
4271+
write!(self.out, ", ")?;
4272+
self.put_expression(c, &context.expression, true)?;
4273+
}
4274+
}
4275+
writeln!(self.out, ");")?;
4276+
}
42334277
}
42344278
}
42354279

naga/src/back/pipeline_constants.rs

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -633,15 +633,6 @@ fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut E
633633
} => {
634634
adjust(query);
635635
}
636-
Expression::MulAdd {
637-
ref mut a,
638-
ref mut b,
639-
ref mut c,
640-
} => {
641-
adjust(a);
642-
adjust(b);
643-
adjust(c);
644-
}
645636
}
646637
}
647638

@@ -844,6 +835,34 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
844835
crate::RayQueryFunction::Terminate => {}
845836
}
846837
}
838+
Statement::Cooperation {
839+
ref mut target,
840+
ref mut op,
841+
} => {
842+
adjust(target);
843+
match *op {
844+
crate::CooperativeOperation::LoadStore {
845+
store: _,
846+
ref mut pointer,
847+
ref mut stride,
848+
row_major: _,
849+
} => {
850+
adjust(pointer);
851+
if let Some(ref mut stride) = *stride {
852+
adjust(stride);
853+
}
854+
}
855+
crate::CooperativeOperation::MultiplyAdd {
856+
ref mut a,
857+
ref mut b,
858+
ref mut c,
859+
} => {
860+
adjust(a);
861+
adjust(b);
862+
adjust(c);
863+
}
864+
}
865+
}
847866
Statement::Break
848867
| Statement::Continue
849868
| Statement::Kill

naga/src/back/spv/block.rs

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,17 +1805,6 @@ impl BlockContext<'_> {
18051805
)?;
18061806
self.write_ray_query_return_vertex_position(query, block, committed)
18071807
}
1808-
crate::Expression::MulAdd { a, b, c } => {
1809-
let id = self.gen_id();
1810-
block.body.push(Instruction::coop_mul_add(
1811-
result_type_id,
1812-
id,
1813-
self.cached[a],
1814-
self.cached[b],
1815-
self.cached[c],
1816-
));
1817-
id
1818-
}
18191808
};
18201809

18211810
self.cached[expr_handle] = id;
@@ -3677,6 +3666,51 @@ impl BlockContext<'_> {
36773666
} => {
36783667
self.write_subgroup_gather(mode, argument, result, &mut block)?;
36793668
}
3669+
Statement::Cooperation { target, ref op } => {
3670+
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3671+
match *op {
3672+
crate::CooperativeOperation::LoadStore {
3673+
store,
3674+
pointer,
3675+
stride,
3676+
row_major,
3677+
} => {
3678+
let layout = if row_major {
3679+
spirv::CooperativeMatrixLayout::RowMajorKHR
3680+
} else {
3681+
spirv::CooperativeMatrixLayout::ColumnMajorKHR
3682+
};
3683+
let layout_id = self.get_index_constant(layout as u32);
3684+
let stride_id = stride.map(|exp| self.cached[exp]);
3685+
block.body.push(if store {
3686+
Instruction::coop_store(
3687+
result_type_id,
3688+
self.cached[target],
3689+
self.cached[pointer],
3690+
layout_id,
3691+
stride_id,
3692+
)
3693+
} else {
3694+
Instruction::coop_load(
3695+
result_type_id,
3696+
self.cached[target],
3697+
self.cached[pointer],
3698+
layout_id,
3699+
stride_id,
3700+
)
3701+
});
3702+
}
3703+
crate::CooperativeOperation::MultiplyAdd { a, b, c } => {
3704+
block.body.push(Instruction::coop_mul_add(
3705+
result_type_id,
3706+
self.cached[target],
3707+
self.cached[a],
3708+
self.cached[b],
3709+
self.cached[c],
3710+
));
3711+
}
3712+
}
3713+
}
36803714
}
36813715
}
36823716

naga/src/back/spv/instructions.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,42 @@ impl super::Instruction {
12471247
}
12481248

12491249
// Cooperative operations
1250+
pub(super) fn coop_load(
1251+
result_type_id: Word,
1252+
id: Word,
1253+
pointer_id: Word,
1254+
layout_id: Word,
1255+
stride_id: Option<Word>,
1256+
) -> Self {
1257+
let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR);
1258+
instruction.set_type(result_type_id);
1259+
instruction.set_result(id);
1260+
instruction.add_operand(pointer_id);
1261+
instruction.add_operand(layout_id);
1262+
if let Some(stride_id) = stride_id {
1263+
instruction.add_operand(stride_id);
1264+
}
1265+
1266+
instruction
1267+
}
1268+
pub(super) fn coop_store(
1269+
result_type_id: Word,
1270+
id: Word,
1271+
pointer_id: Word,
1272+
layout_id: Word,
1273+
stride_id: Option<Word>,
1274+
) -> Self {
1275+
let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR);
1276+
instruction.set_type(result_type_id);
1277+
instruction.set_result(id);
1278+
instruction.add_operand(pointer_id);
1279+
instruction.add_operand(layout_id);
1280+
if let Some(stride_id) = stride_id {
1281+
instruction.add_operand(stride_id);
1282+
}
1283+
1284+
instruction
1285+
}
12501286
pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self {
12511287
let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR);
12521288
instruction.set_type(result_type_id);

naga/src/back/spv/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ enum CooperativeType {
344344
Matrix {
345345
columns: crate::CooperativeSize,
346346
rows: crate::CooperativeSize,
347-
scalar: crate::CooperativeScalar,
347+
scalar: crate::Scalar,
348348
role: crate::CooperativeRole,
349349
},
350350
}

naga/src/back/spv/writer.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,6 @@ impl Writer {
374374
})
375375
}
376376

377-
pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word {
378-
match scalar {
379-
crate::CooperativeScalar::F32 => self.get_f32_type_id(),
380-
}
381-
}
382-
383377
pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
384378
let f32_id = self.get_f32_type_id();
385379
self.get_pointer_type_id(f32_id, class)
@@ -1384,7 +1378,8 @@ impl Writer {
13841378
scalar,
13851379
role,
13861380
} => {
1387-
let scalar_id = self.get_cooperative_type_id(scalar);
1381+
let scalar_id =
1382+
self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar)));
13881383
let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
13891384
let columns_id = self.get_index_constant(columns as u32);
13901385
let rows_id = self.get_index_constant(rows as u32);

0 commit comments

Comments
 (0)