Skip to content

Commit 0ea0fb5

Browse files
committed
Add auto-bitcasts from/to x86amx and i32x256 for AMX intrinsics
1 parent 3ef8e64 commit 0ea0fb5

File tree

5 files changed

+35
-5
lines changed

5 files changed

+35
-5
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

+21-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::cmp;
44
use libc::c_uint;
55
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
66
use rustc_codegen_ssa::MemFlags;
7+
use rustc_codegen_ssa::common::TypeKind;
78
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
89
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
910
use rustc_codegen_ssa::traits::*;
@@ -331,20 +332,37 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
331332
let args =
332333
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };
333334

335+
let adjust_ty = |ty| {
336+
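// FIXME: this heuristic is too broad: it matches every non-unwinding, non-variadic C-ABI function on x86_64. Narrow it to actual AMX intrinsics (help wanted).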
337+
let probably_unadjusted = self.conv == Conv::C && !self.can_unwind && !self.c_variadic;
338+
let probably_amx_intrinsic = probably_unadjusted && cx.tcx.sess.target.arch == "x86_64";
339+
// For (probable) x86_64 AMX intrinsics, replace a 256-element vector of 32-bit integers (`i32x256`) with the `x86amx` type.
340+
if probably_amx_intrinsic
341+
&& cx.type_kind(ty) == TypeKind::Vector
342+
&& cx.vector_length(ty) == 256
343+
{
344+
let element_ty = cx.element_type(ty);
345+
if cx.type_kind(element_ty) == TypeKind::Integer && cx.int_width(element_ty) == 32 {
346+
return cx.type_x86amx();
347+
}
348+
}
349+
ty
350+
};
351+
334352
// This capacity calculation is approximate.
335353
let mut llargument_tys = Vec::with_capacity(
336354
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
337355
);
338356

339-
let llreturn_ty = match &self.ret.mode {
357+
let llreturn_ty = adjust_ty(match &self.ret.mode {
340358
PassMode::Ignore => cx.type_void(),
341359
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
342360
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
343361
PassMode::Indirect { .. } => {
344362
llargument_tys.push(cx.type_ptr());
345363
cx.type_void()
346364
}
347-
};
365+
});
348366

349367
for arg in args {
350368
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +406,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388406
cast.llvm_type(cx)
389407
}
390408
};
391-
llargument_tys.push(llarg_ty);
409+
llargument_tys.push(adjust_ty(llarg_ty));
392410
}
393411

394412
if self.c_variadic {

compiler/rustc_codegen_llvm/src/builder.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14351435
if let Some(fn_abi) = fn_abi {
14361436
fn_abi.apply_attrs_callsite(self, call);
14371437
}
1438-
call
1438+
1439+
if self.cx.type_kind(self.cx.val_ty(call)) == TypeKind::X86_AMX {
1440+
self.bitcast(call, self.cx.type_vector(self.cx.type_i32(), 256))
1441+
} else {
1442+
call
1443+
}
14391444
}
14401445

14411446
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,9 @@ unsafe extern "C" {
10551055
pub(crate) fn LLVMPointerTypeInContext(C: &Context, AddressSpace: c_uint) -> &Type;
10561056
pub(crate) fn LLVMVectorType(ElementType: &Type, ElementCount: c_uint) -> &Type;
10571057

1058+
// Special x86 type (`x86amx`) used for AMX.
1059+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1060+
10581061
pub(crate) fn LLVMGetElementType(Ty: &Type) -> &Type;
10591062
pub(crate) fn LLVMGetVectorSize(VectorTy: &Type) -> c_uint;
10601063

compiler/rustc_codegen_llvm/src/type_.rs

+4
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
154154
)
155155
}
156156
}
157+
158+
pub(crate) fn type_x86amx(&self) -> &'ll Type {
159+
unsafe { llvm::LLVMX86AMXTypeInContext(self.llcx()) }
160+
}
157161
}
158162

159163
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {

compiler/rustc_target/src/target_features.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ pub fn all_rust_features() -> impl Iterator<Item = (&'static str, Stability)> {
786786
// certain size to have their "proper" ABI on each architecture.
787787
// Note that they must be kept sorted by vector size.
788788
const X86_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] =
789-
&[(128, "sse"), (256, "avx"), (512, "avx512f")]; // FIXME: might need changes for AVX10.
789+
&[(128, "sse"), (256, "avx"), (512, "avx512f"), (8192, "amx-tile")];
790790
const AARCH64_FEATURES_FOR_CORRECT_VECTOR_ABI: &'static [(u64, &'static str)] = &[(128, "neon")];
791791

792792
// We might want to add "helium" too.

0 commit comments

Comments
 (0)