@@ -4,6 +4,7 @@ use std::cmp;
4
4
use libc:: c_uint;
5
5
use rustc_abi:: { BackendRepr , HasDataLayout , Primitive , Reg , RegKind , Size } ;
6
6
use rustc_codegen_ssa:: MemFlags ;
7
+ use rustc_codegen_ssa:: common:: TypeKind ;
7
8
use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
8
9
use rustc_codegen_ssa:: mir:: place:: { PlaceRef , PlaceValue } ;
9
10
use rustc_codegen_ssa:: traits:: * ;
@@ -331,20 +332,37 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
331
332
let args =
332
333
if self . c_variadic { & self . args [ ..self . fixed_count as usize ] } else { & self . args } ;
333
334
335
+ let adjust_ty = |ty| {
336
+ // TODO(help wanted): make this heuristic more selective. For now every non-unwinding, non-variadic `Conv::C` signature on an x86_64 target is treated as a possible AMX intrinsic declaration.
337
+ let probably_unadjusted = self . conv == Conv :: C && !self . can_unwind && !self . c_variadic ;
338
+ let probably_amx_intrinsic = probably_unadjusted && cx. tcx . sess . target . arch == "x86_64" ;
339
+ // For such signatures, replace a 256-element vector of 32-bit integers (`i32x256`) with the `x86amx` type; any other `ty` falls through and is returned unchanged.
340
+ if probably_amx_intrinsic
341
+ && cx. type_kind ( ty) == TypeKind :: Vector
342
+ && cx. vector_length ( ty) == 256
343
+ {
344
+ let element_ty = cx. element_type ( ty) ;
345
+ if cx. type_kind ( element_ty) == TypeKind :: Integer && cx. int_width ( element_ty) == 32 {
346
+ return cx. type_x86amx ( ) ;
347
+ }
348
+ }
349
+ ty
350
+ } ;
351
+
334
352
// This capacity calculation is approximate.
335
353
let mut llargument_tys = Vec :: with_capacity (
336
354
self . args . len ( ) + if let PassMode :: Indirect { .. } = self . ret . mode { 1 } else { 0 } ,
337
355
) ;
338
356
339
- let llreturn_ty = match & self . ret . mode {
357
+ let llreturn_ty = adjust_ty ( match & self . ret . mode {
340
358
PassMode :: Ignore => cx. type_void ( ) ,
341
359
PassMode :: Direct ( _) | PassMode :: Pair ( ..) => self . ret . layout . immediate_llvm_type ( cx) ,
342
360
PassMode :: Cast { cast, pad_i32 : _ } => cast. llvm_type ( cx) ,
343
361
PassMode :: Indirect { .. } => {
344
362
llargument_tys. push ( cx. type_ptr ( ) ) ;
345
363
cx. type_void ( )
346
364
}
347
- } ;
365
+ } ) ;
348
366
349
367
for arg in args {
350
368
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +406,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388
406
cast. llvm_type ( cx)
389
407
}
390
408
} ;
391
- llargument_tys. push ( llarg_ty) ;
409
+ llargument_tys. push ( adjust_ty ( llarg_ty) ) ;
392
410
}
393
411
394
412
if self . c_variadic {
0 commit comments