Skip to content

Commit ab2ce10

Browse files
committed
add support for CooperativeMatrix
fix test
1 parent 15e705a commit ab2ce10

16 files changed

Lines changed: 381 additions & 112 deletions

File tree

Cargo.lock

Lines changed: 8 additions & 46 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 91 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
55
use crate::codegen_cx::CodegenCx;
6+
use crate::maybe_pqp_cg_ssa::traits::ConstCodegenMethods as _;
67
use crate::spirv_type::SpirvType;
78
use itertools::Itertools;
89
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
@@ -807,6 +808,48 @@ fn trans_intrinsic_type<'tcx>(
807808
args: GenericArgsRef<'tcx>,
808809
intrinsic_type_attr: IntrinsicType,
809810
) -> Result<Word, ErrorGuaranteed> {
811+
trait FromScalarInt: Sized {
812+
fn from_scalar_int(n: ScalarInt) -> Option<Self>;
813+
}
814+
815+
impl FromScalarInt for u32 {
816+
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
817+
Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
818+
}
819+
}
820+
821+
impl FromScalarInt for Dim {
822+
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
823+
Dim::from_u32(u32::from_scalar_int(n)?)
824+
}
825+
}
826+
827+
impl FromScalarInt for ImageFormat {
828+
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
829+
ImageFormat::from_u32(u32::from_scalar_int(n)?)
830+
}
831+
}
832+
833+
fn const_int_value<'tcx, P: FromScalarInt>(
834+
cx: &CodegenCx<'tcx>,
835+
const_: Const<'tcx>,
836+
) -> Result<P, ErrorGuaranteed> {
837+
let ty::Value {
838+
ty: const_ty,
839+
valtree: const_val,
840+
} = const_.to_value();
841+
assert!(const_ty.is_integral());
842+
const_val
843+
.try_to_scalar()
844+
.and_then(|scalar| scalar.try_to_scalar_int().ok())
845+
.and_then(P::from_scalar_int)
846+
.ok_or_else(|| {
847+
cx.tcx
848+
.dcx()
849+
.err(format!("invalid value for const generic: {const_}"))
850+
})
851+
}
852+
810853
match intrinsic_type_attr {
811854
IntrinsicType::GenericImageType => {
812855
// see SpirvType::sizeof
@@ -870,48 +913,6 @@ fn trans_intrinsic_type<'tcx>(
870913
// let image_format: spirv::ImageFormat =
871914
// type_from_variant_discriminant(cx, args.const_at(6));
872915

873-
trait FromScalarInt: Sized {
874-
fn from_scalar_int(n: ScalarInt) -> Option<Self>;
875-
}
876-
877-
impl FromScalarInt for u32 {
878-
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
879-
Some(n.try_to_bits(Size::from_bits(32)).ok()?.try_into().unwrap())
880-
}
881-
}
882-
883-
impl FromScalarInt for Dim {
884-
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
885-
Dim::from_u32(u32::from_scalar_int(n)?)
886-
}
887-
}
888-
889-
impl FromScalarInt for ImageFormat {
890-
fn from_scalar_int(n: ScalarInt) -> Option<Self> {
891-
ImageFormat::from_u32(u32::from_scalar_int(n)?)
892-
}
893-
}
894-
895-
fn const_int_value<'tcx, P: FromScalarInt>(
896-
cx: &CodegenCx<'tcx>,
897-
const_: Const<'tcx>,
898-
) -> Result<P, ErrorGuaranteed> {
899-
let ty::Value {
900-
ty: const_ty,
901-
valtree: const_val,
902-
} = const_.to_value();
903-
assert!(const_ty.is_integral());
904-
const_val
905-
.try_to_scalar()
906-
.and_then(|scalar| scalar.try_to_scalar_int().ok())
907-
.and_then(P::from_scalar_int)
908-
.ok_or_else(|| {
909-
cx.tcx
910-
.dcx()
911-
.err(format!("invalid value for Image const generic: {const_}"))
912-
})
913-
}
914-
915916
let dim = const_int_value(cx, args.const_at(1))?;
916917
let depth = const_int_value(cx, args.const_at(2))?;
917918
let arrayed = const_int_value(cx, args.const_at(3))?;
@@ -941,6 +942,54 @@ fn trans_intrinsic_type<'tcx>(
941942
Ok(SpirvType::AccelerationStructureKhr.def(span, cx))
942943
}
943944
IntrinsicType::RayQueryKhr => Ok(SpirvType::RayQueryKhr.def(span, cx)),
945+
IntrinsicType::CooperativeMatrixKhr => {
946+
if ty.size != Size::from_bytes(4) {
947+
return Err(cx.tcx.dcx().err("cooperative_matrix type must have size 4"));
948+
}
949+
950+
// Generic arg 0: component type T
951+
let component_type = match args.type_at(0).kind() {
952+
TyKind::Float(FloatTy::F32) => SpirvType::Float(32).def(span, cx),
953+
TyKind::Float(FloatTy::F64) => SpirvType::Float(64).def(span, cx),
954+
TyKind::Int(IntTy::I8) => SpirvType::Integer(8, true).def(span, cx),
955+
TyKind::Int(IntTy::I16) => SpirvType::Integer(16, true).def(span, cx),
956+
TyKind::Int(IntTy::I32) => SpirvType::Integer(32, true).def(span, cx),
957+
TyKind::Int(IntTy::I64) => SpirvType::Integer(64, true).def(span, cx),
958+
TyKind::Uint(UintTy::U8) => SpirvType::Integer(8, false).def(span, cx),
959+
TyKind::Uint(UintTy::U16) => SpirvType::Integer(16, false).def(span, cx),
960+
TyKind::Uint(UintTy::U32) => SpirvType::Integer(32, false).def(span, cx),
961+
TyKind::Uint(UintTy::U64) => SpirvType::Integer(64, false).def(span, cx),
962+
_ => {
963+
return Err(cx.tcx.dcx().span_err(
964+
span,
965+
"unsupported component type for #[spirv(cooperative_matrix)]: \
966+
must be f32, f64, i8, i16, i32, i64, u8, u16, u32, or u64",
967+
));
968+
}
969+
};
970+
971+
// Const generic 1: USE (MatrixA=0, MatrixB=1, MatrixAccumulator=2)
972+
// Const generic 2: ROWS
973+
// Const generic 3: COLS
974+
let use_val: u32 = const_int_value(cx, args.const_at(1))?;
975+
let rows_val: u32 = const_int_value(cx, args.const_at(2))?;
976+
let cols_val: u32 = const_int_value(cx, args.const_at(3))?;
977+
978+
// Scope: Subgroup = 3
979+
let scope = cx.const_u32(3).def_cx(cx);
980+
let rows = cx.const_u32(rows_val).def_cx(cx);
981+
let columns = cx.const_u32(cols_val).def_cx(cx);
982+
let use_ = cx.const_u32(use_val).def_cx(cx);
983+
984+
Ok(SpirvType::CooperativeMatrixKhr {
985+
component_type,
986+
scope,
987+
rows,
988+
columns,
989+
use_,
990+
}
991+
.def(span, cx))
992+
}
944993
IntrinsicType::SampledImage => {
945994
// see SpirvType::sizeof
946995
if ty.size != Size::from_bytes(4) {

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ pub enum IntrinsicType {
6969
TypedBuffer,
7070
Matrix,
7171
Vector,
72+
CooperativeMatrixKhr,
7273
}
7374

7475
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
443443
self.fatal("cannot memset acceleration structure")
444444
}
445445
SpirvType::RayQueryKhr => self.fatal("cannot memset ray query"),
446+
SpirvType::CooperativeMatrixKhr { .. } => {
447+
self.fatal("cannot memset cooperative matrix")
448+
}
446449
}
447450
}
448451

@@ -500,6 +503,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
500503
self.fatal("cannot memset acceleration structure")
501504
}
502505
SpirvType::RayQueryKhr => self.fatal("cannot memset ray query"),
506+
SpirvType::CooperativeMatrixKhr { .. } => {
507+
self.fatal("cannot memset cooperative matrix")
508+
}
503509
}
504510
}
505511

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
44
use super::Builder;
55
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
66
use crate::spirv_type::SpirvType;
7+
use rspirv::spirv::Capability;
78
use rspirv::spirv::{Decoration, Word};
89
use rustc_abi::{Align, Size};
9-
use rspirv::spirv::Capability;
1010
use rustc_codegen_ssa::traits::BuilderMethods;
1111
use rustc_errors::ErrorGuaranteed;
1212
use rustc_middle::ty::Ty;

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ pub struct InstructionTable {
3232

3333
impl InstructionTable {
3434
pub fn new() -> Self {
35-
let table = rspirv::grammar::INSTRUCTION_TABLE.iter()
35+
let table = rspirv::grammar::INSTRUCTION_TABLE
36+
.iter()
3637
.map(|inst| (inst.opname, inst))
3738
.collect();
3839
Self { table }
@@ -1574,10 +1575,7 @@ pub const IMAGE_OPERANDS: &[(&str, ImageOperands)] = &[
15741575
("Sample", ImageOperands::SAMPLE),
15751576
("MinLod", ImageOperands::MIN_LOD),
15761577
("MakeTexelAvailable", ImageOperands::MAKE_TEXEL_AVAILABLE),
1577-
(
1578-
"MakeTexelAvailableKHR",
1579-
ImageOperands::MAKE_TEXEL_AVAILABLE,
1580-
),
1578+
("MakeTexelAvailableKHR", ImageOperands::MAKE_TEXEL_AVAILABLE),
15811579
("MakeTexelVisible", ImageOperands::MAKE_TEXEL_VISIBLE),
15821580
("MakeTexelVisibleKHR", ImageOperands::MAKE_TEXEL_VISIBLE),
15831581
("NonPrivateTexel", ImageOperands::NON_PRIVATE_TEXEL),
@@ -1660,15 +1658,9 @@ pub const MEMORY_ACCESS: &[(&str, MemoryAccess)] = &[
16601658
MemoryAccess::MAKE_POINTER_AVAILABLE,
16611659
),
16621660
("MakePointerVisible", MemoryAccess::MAKE_POINTER_VISIBLE),
1663-
(
1664-
"MakePointerVisibleKHR",
1665-
MemoryAccess::MAKE_POINTER_VISIBLE,
1666-
),
1661+
("MakePointerVisibleKHR", MemoryAccess::MAKE_POINTER_VISIBLE),
16671662
("NonPrivatePointer", MemoryAccess::NON_PRIVATE_POINTER),
1668-
(
1669-
"NonPrivatePointerKHR",
1670-
MemoryAccess::NON_PRIVATE_POINTER,
1671-
),
1663+
("NonPrivatePointerKHR", MemoryAccess::NON_PRIVATE_POINTER),
16721664
];
16731665
pub const KERNEL_PROFILING_INFO: &[(&str, KernelProfilingInfo)] = &[
16741666
("None", KernelProfilingInfo::NONE),

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,8 @@ impl<'tcx> CodegenCx<'tcx> {
619619
| SpirvType::SampledImage { .. }
620620
| SpirvType::InterfaceBlock { .. }
621621
| SpirvType::AccelerationStructureKhr
622-
| SpirvType::RayQueryKhr => {
622+
| SpirvType::RayQueryKhr
623+
| SpirvType::CooperativeMatrixKhr { .. } => {
623624
let result = self.undef(ty);
624625
self.zombie_no_span(
625626
result.def_cx(self),

0 commit comments

Comments
 (0)