Skip to content

Commit e4090ae

Browse files
committed
Add bypass for x86amx
1 parent 9b01f48 commit e4090ae

File tree

5 files changed

+40
-4
lines changed

5 files changed

+40
-4
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,25 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
368368
}
369369

370370
match self.type_kind(llvm_ty) {
371+
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
372+
let element_count = self.vector_length(rust_ty);
373+
let element_ty = self.element_type(rust_ty);
374+
375+
let element_size_bits = match self.type_kind(element_ty) {
376+
TypeKind::Half => 16,
377+
TypeKind::Float => 32,
378+
TypeKind::Double => 64,
379+
TypeKind::FP128 => 128,
380+
TypeKind::Integer => self.int_width(element_ty),
381+
TypeKind::Pointer => self.int_width(self.isize_ty()),
382+
_ => bug!(
383+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
384+
),
385+
};
386+
let vector_size_bits = element_size_bits * element_count as u64;
387+
388+
vector_size_bits == 8192
389+
}
371390
TypeKind::BFloat => rust_ty == self.type_i16(),
372391
TypeKind::Vector => {
373392
let llvm_element_count = self.vector_length(llvm_ty) as u64;

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
492492
let width = size.bits();
493493

494494
if oop == OverflowOp::Sub && !signed {
495-
let sub = self.sub(lhs, rhs);
496-
let cmp = self.icmp(IntPredicate::IntULT, lhs, rhs);
497-
return (sub, cmp);
498-
}
495+
let sub = self.sub(lhs, rhs);
496+
let cmp = self.icmp(IntPredicate::IntULT, lhs, rhs);
497+
return (sub, cmp);
498+
}
499499

500500
let oop_str = match oop {
501501
OverflowOp::Add => "add",
@@ -1641,6 +1641,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16411641
}
16421642

16431643
match self.type_kind(llvm_ty) {
1644+
TypeKind::X86_AMX => {
1645+
if is_argument {
1646+
self.call_intrinsic("llvm.x86.cast.vector.to.tile", &[rust_ty], &[val])
1647+
} else {
1648+
self.call_intrinsic("llvm.x86.cast.tile.to.vector", &[rust_ty], &[val])
1649+
}
1650+
}
16441651
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
16451652
if is_argument {
16461653
self.trunc_int_to_i1_vector(val, dest_ty)

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ impl<'ll> CodegenCx<'ll, '_> {
925925
let t_isize = self.type_isize();
926926
let t_metadata = self.type_metadata();
927927
let t_token = self.type_token();
928+
let x86amx = self.type_x86amx();
928929

929930
ifn!("llvm.wasm.get.exception", fn(t_token) -> ptr);
930931
ifn!("llvm.wasm.get.ehselector", fn(t_token) -> t_i32);
@@ -1038,6 +1039,9 @@ impl<'ll> CodegenCx<'ll, '_> {
10381039
ifn!("llvm.masked.gather", fn(1, t_i32, same_width_vector(0, i1), 0) -> 0);
10391040
ifn!("llvm.masked.scatter", fn(0, 1, t_i32, same_width_vector(0, i1)) -> void);
10401041

1042+
ifn!("llvm.x86.cast.vector.to.tile", fn(0) -> x86amx);
1043+
ifn!("llvm.x86.cast.tile.to.vector", fn(x86amx) -> 0);
1044+
10411045
bug!("Unknown intrinsic: `{base_name}`")
10421046
}
10431047

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,8 @@ unsafe extern "C" {
10841084
pub(crate) fn LLVMTokenTypeInContext(C: &Context) -> &Type;
10851085
pub(crate) fn LLVMMetadataTypeInContext(C: &Context) -> &Type;
10861086

1087+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1088+
10871089
// Operations on all values
10881090
pub(crate) fn LLVMTypeOf(Val: &Value) -> &Type;
10891091
pub(crate) fn LLVMGetValueName2(Val: &Value, Length: *mut size_t) -> *const c_char;

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
6666
unsafe { llvm::LLVMMetadataTypeInContext(self.llcx()) }
6767
}
6868

69+
pub(crate) fn type_x86amx(&self) -> &'ll Type {
70+
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
71+
}
72+
6973
///x Creates an integer type with the given number of bits, e.g., i24
7074
pub(crate) fn type_ix(&self, num_bits: u64) -> &'ll Type {
7175
unsafe { llvm::LLVMIntTypeInContext(self.llcx(), num_bits as c_uint) }

0 commit comments

Comments
 (0)