Intel AMX tile configuration
To illustrate TPP-MLIR's AMX tile configuration process, let's start with a high-level input kernel:
func.func @entry(%arg0: tensor<1024x1024xbf16>,
%arg1: tensor<1024x1024xbf16>,
%arg2: tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<1024x1024xbf16>, tensor<1024x1024xbf16>)
outs(%arg2 : tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16>
return %0 : tensor<1024x1024xbf16>
}
NOTES:
- The AMX configuration is currently supported for the XSMM dialect's BRGEMM operation using the VNNI format (see the packing sketch after this list).
- The relevant transformations work only with bf16.
- The presented lowering steps are based on the DefaultTppPasses pipeline. Certain passes are omitted for brevity as they are not required for a simple GEMM kernel. The full TPP pipeline (which stops before lowering to the LLVM dialect) can be run using:
tpp-opt -default-tpp-passes
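For bf16, the VNNI-2 format interleaves pairs of consecutive K-dimension elements of B so that the hardware can consume them as 32-bit lanes; in the packed buffers below this shows up as the trailing 16x32x2 dimensions of matrix B. A minimal C sketch of the packing (bf16 values treated as opaque uint16_t; the function name is illustrative, not part of TPP-MLIR):
#include <stdint.h>
#include <stddef.h>

/* Pack a row-major bf16 matrix B[K][N] (stored as uint16_t) into
 * B_vnni[K/2][N][2] so that two consecutive K elements of the same
 * column become adjacent in memory (VNNI-2 layout). K must be even.
 * For a 32x32 block this yields the 16x32x2 shape seen below. */
void pack_vnni2(const uint16_t *B, uint16_t *B_vnni, size_t K, size_t N) {
  for (size_t k = 0; k < K; k += 2)
    for (size_t n = 0; n < N; ++n) {
      B_vnni[((k / 2) * N + n) * 2 + 0] = B[(k + 0) * N + n];
      B_vnni[((k / 2) * N + n) * 2 + 1] = B[(k + 1) * N + n];
    }
}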
First, tensor-level transformations are applied using the following passes:
tpp-opt -tpp-mapping -lower-packs-unpacks -canonicalize -cse
Then the kernel can be bufferized to lower the ops to the memref abstraction:
tpp-opt -bufferize
The above transformations produce the following entry-point memref kernel:
#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
module {
func.func @entry(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xbf16>) {
// Block pack matrix A.
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32x32x32xbf16>
scf.forall (%arg3, %arg4) in (32, 32) {
%0 = affine.apply #map(%arg3)
%1 = affine.apply #map(%arg4)
%subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
%subview_1 = memref.subview %alloc[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
linalg.copy ins(%subview : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) outs(%subview_1 : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
}
// Block pack plus VNNI pack matrix B.
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x32x16x32x2xbf16>
scf.forall (%arg3, %arg4) in (32, 32) {
%0 = affine.apply #map(%arg4)
%1 = affine.apply #map(%arg3)
%subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
%subview_1 = memref.subview %alloc_0[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>
%expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[1024, 1], offset: ?>> into memref<16x2x32xbf16, strided<[2048, 1024, 1], offset: ?>>
linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[2048, 1024, 1], offset: ?>>) outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1]
}
// BRGEMM in VNNI format.
scf.forall (%arg3, %arg4) in (32, 32) {
%0 = affine.apply #map(%arg3)
%1 = affine.apply #map(%arg4)
%subview = memref.subview %alloc[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
%subview_1 = memref.subview %alloc_0[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_2 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} ins(%subview, %subview_1 : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_2 : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
%2 = arith.mulf %in, %in_3 : bf16
%3 = arith.addf %out, %2 : bf16
linalg.yield %3 : bf16
}
}
// Cleanup temporary buffers.
memref.dealloc %alloc : memref<32x32x32x32xbf16>
memref.dealloc %alloc_0 : memref<32x32x16x32x2xbf16>
return
}
}
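The linalg.generic above expresses a batch-reduce GEMM over the packed operands: each 32x32 output block accumulates 32 small 32x32x32 matrix products, with B read in its VNNI-2 layout. A scalar C reference of the intended computation (a sketch only, with bf16 arithmetic replaced by float; the function name is illustrative):
/* For one 32x32 output block: C += sum over b of A[b] * B[b], where
 * B_vnni[b][k/2][n][k%2] holds the original element B[b][k][n]. */
void brgemm_vnni2_reference(const float A[32][32][32],
                            const float B_vnni[32][16][32][2],
                            float C[32][32]) {
  for (int b = 0; b < 32; ++b)        /* batch-reduce dimension */
    for (int m = 0; m < 32; ++m)
      for (int n = 0; n < 32; ++n)
        for (int k = 0; k < 32; ++k)
          C[m][n] += A[b][m][k] * B_vnni[b][k / 2][n][k % 2];
}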
Next, the preprocessed memref-level kernel can be lowered into XSMM operations.
For better readability, a simpler input kernel with prepacked arguments is used:
#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
module {
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
scf.forall (%arg3, %arg4) in (32, 32) {
%0 = affine.apply #map(%arg3)
%1 = affine.apply #map(%arg4)
%subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
%subview_1 = memref.subview %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_2 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} ins(%subview, %subview_1 : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_2 : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) {
^bb0(%in: bf16, %in_3: bf16, %out: bf16):
%2 = arith.mulf %in, %in_3 : bf16
%3 = arith.addf %out, %2 : bf16
linalg.yield %3 : bf16
}
}
return
}
}
The following passes can be applied to lower to XSMM and clean up the IR:
tpp-opt -linalg-lowering -convert-forall-to-parallel
producing:
#map = affine_map<(d0) -> (d0 * 32)>
module {
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
%c32_i64 = arith.constant 32 : i64
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) {
%0 = affine.apply #map(%arg3)
%1 = affine.apply #map(%arg4)
%subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
%subview_0 = memref.subview %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
%2 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b) data_type = bf16
xsmm.brgemm(data_type = bf16, %2, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
scf.reduce
}
return
}
}
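The lowering splits each operation into a dispatch, which resolves (JIT-compiles) a kernel for fixed sizes, leading dimensions, strides, and flags and returns an integer handle, and an invocation, which runs that handle on concrete subviews with a batch count of 32. A toy C sketch of this dispatch/invoke pattern (names and bodies are illustrative placeholders, not the libxsmm runtime API; B is kept in plain row-major layout for brevity):
#include <stdint.h>
#include <stdio.h>

/* Toy model of the XSMM dispatch/invoke split. "Dispatch" resolves a kernel
 * once for fixed sizes and returns a handle (here simply a function
 * pointer); "invoke" runs it on concrete buffers, once per tile. */
typedef void (*brgemm_handle)(const float *A, const float *B, float *C,
                              int64_t batch);

/* Stand-in for a JIT-ed 32x32x32 batch-reduce GEMM over row-major bricks. */
static void brgemm_32x32x32(const float *A, const float *B, float *C,
                            int64_t batch) {
  for (int64_t b = 0; b < batch; ++b)
    for (int m = 0; m < 32; ++m)
      for (int n = 0; n < 32; ++n)
        for (int k = 0; k < 32; ++k)
          C[m * 32 + n] += A[b * 1024 + m * 32 + k] * B[b * 1024 + k * 32 + n];
}

/* "Dispatch": look up (or JIT-compile) a kernel for the requested shape. */
static brgemm_handle brgemm_dispatch(int m, int n, int k) {
  return (m == 32 && n == 32 && k == 32) ? brgemm_32x32x32 : NULL;
}

int main(void) {
  static float A[32 * 1024], B[32 * 1024], C[32 * 32];
  brgemm_handle h = brgemm_dispatch(32, 32, 32);  /* dispatch once */
  if (!h)
    return 1;
  for (int tile = 0; tile < 4; ++tile)            /* invoke per tile */
    h(A, B, C, /*batch=*/32);
  printf("C[0][0] = %f\n", C[0]);
  return 0;
}
Because the dispatch depends only on compile-time constants, it is a natural candidate for hoisting out of the loops, which the later passes exploit.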
The parallel loop around the XSMM call is further tiled using:
tpp-opt -scf-parallel-loop-tiling-pass=parallel-loop-tile-sizes=4,8 -canonicalize
The tile sizes are chosen somewhat arbitrarily to highlight the creation of two nested serial loops within the parallel loop.
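In plain C terms, the tiling turns the flat 32x32 parallel iteration space into outer parallel loops that step by the tile sizes and two inner serial loops that walk one tile (a structural sketch only; the body stands for the per-block BRGEMM call):
/* Structural sketch of the 4x8 parallel-loop tiling: the outer loops remain
 * parallel and advance by the tile sizes, the inner loops are serial and
 * enumerate one tile's iterations. */
void tiled_iteration_space(void (*body)(int, int)) {
  for (int i = 0; i < 32; i += 4)        /* parallel */
    for (int j = 0; j < 32; j += 8)      /* parallel */
      for (int ii = 0; ii < 4; ++ii)     /* serial */
        for (int jj = 0; jj < 8; ++jj)   /* serial */
          body(i + ii, j + jj);          /* per-block BRGEMM */
}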
Resulting IR:
#map = affine_map<(d0) -> (d0 * 32)>
module {
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
%c32_i64 = arith.constant 32 : i64
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
scf.for %arg5 = %c0 to %c4 step %c1 {
scf.for %arg6 = %c0 to %c8 step %c1 {
%0 = arith.addi %arg5, %arg3 : index
%1 = arith.addi %arg6, %arg4 : index
%2 = affine.apply #map(%0)
%3 = affine.apply #map(%1)
%subview = memref.subview %arg0[%0, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
%subview_0 = memref.subview %arg1[%1, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%2, %3] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
%4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b) data_type = bf16
xsmm.brgemm(data_type = bf16, %4, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
}
}
scf.reduce
}
return
}
}
The AMX configuration can now be added to the above kernel containing XSMM operations.
NOTE: AMX config is currently applicable only to xsmm.(fused_)brgemm.
The config ops can be inserted using:
tpp-opt -intel-amx-tile-config-insertion-pass
The function dispatch ops can then be hoisted out of the loops using:
tpp-opt -loop-invariant-code-motion -intel-amx-tile-config-hoisting-pass
The transformations produce the following kernel:
#map = affine_map<(d0) -> (d0 * 32)>
module {
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
%c32_i64 = arith.constant 32 : i64
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%0 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_reset_tileconfig) data_type = bf16
%1 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_setup_tileconfig) data_type = bf16
%2 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_reset_tileconfig, no_setup_tileconfig) data_type = bf16
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
%alloca = memref.alloca() : memref<64xi8>
"xsmm.IntelAMXtileConfig"(%0, %alloca) : (i64, memref<64xi8>) -> ()
scf.for %arg5 = %c0 to %c4 step %c1 {
%3 = arith.addi %arg5, %arg3 : index
%4 = affine.apply #map(%3)
%subview = memref.subview %arg0[%3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
scf.for %arg6 = %c0 to %c8 step %c1 {
%5 = arith.addi %arg6, %arg4 : index
%6 = affine.apply #map(%5)
%subview_0 = memref.subview %arg1[%5, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%4, %6] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
xsmm.brgemm(data_type = bf16, %2, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
}
}
"xsmm.IntelAMXtileConfig"(%1, %alloca) : (i64, memref<64xi8>) -> ()
scf.reduce
}
return
}
}
The AMX insertion pass materializes explicit configuration ops (xsmm.IntelAMXtileConfig) while also ensuring that the other XSMM operations do not modify the tile config themselves (via the no_reset_tileconfig and no_setup_tileconfig flags). After hoisting, the dispatch ops sit at the function entry, the tile setup and reset ops bracket the body of each parallel iteration, and the BRGEMM inside the serial loops reuses the same dispatched handle. For more details, see the libxsmm documentation.
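At the hardware level, the tile configuration that these ops manage is the 64-byte descriptor consumed by the LDTILECFG instruction; the memref<64xi8> alloca in the kernel above holds exactly such a buffer. A minimal C sketch using the Intel AMX intrinsics (for illustration only, assuming palette 1 with full-size tiles; the actual configuration is performed by the JIT-ed libxsmm tile-config kernels):
#include <immintrin.h>   /* compile with -mamx-tile */
#include <string.h>
#include <stdint.h>

/* 64-byte AMX tile configuration as consumed by _tile_loadconfig
 * (LDTILECFG): palette id plus per-tile row counts and row sizes. */
typedef struct __attribute__((aligned(64))) {
  uint8_t  palette_id;
  uint8_t  start_row;
  uint8_t  reserved[14];
  uint16_t colsb[16];    /* bytes per row of each tile register */
  uint8_t  rows[16];     /* rows of each tile register */
} amx_tile_config_t;

void configure_amx_tiles(void) {
  amx_tile_config_t cfg;
  memset(&cfg, 0, sizeof(cfg));
  cfg.palette_id = 1;           /* palette 1: eight 16x64-byte tiles */
  for (int t = 0; t < 8; ++t) {
    cfg.rows[t]  = 16;          /* 16 rows ... */
    cfg.colsb[t] = 64;          /* ... of 64 bytes (32 bf16 or 16 fp32) */
  }
  _tile_loadconfig(&cfg);       /* setup, analogous to the first IntelAMXtileConfig */
  /* ... tile loads and TDPBF16PS sequences issued by the JIT-ed BRGEMM ... */
  _tile_release();              /* reset, analogous to the second IntelAMXtileConfig */
}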
Finally, the XSMM operations can be lowered to libxsmm function calls (followed by cleanup passes for clarity) using:
tpp-opt -convert-xsmm-to-func -canonicalize -cse
resulting in the IR below, where the flag sets and data type of each dispatch op now appear as plain integer arguments (for example, the constant 2240 passed to xsmm_brgemm_dispatch corresponds to the vnni_b, no_reset_tileconfig, no_setup_tileconfig combination):
#map = affine_map<(d0) -> (d0 * 32)>
module {
func.func private @xsmm_intel_amx_tile_config_invoke(i64, i64, !llvm.ptr, index)
func.func private @xsmm_brgemm_invoke(i64, i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64)
func.func private @xsmm_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
func.func private @xsmm_intel_amx_tile_config_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
%c2240_i64 = arith.constant 2240 : i64
%c2176_i64 = arith.constant 2176 : i64
%c32_i64 = arith.constant 32 : i64
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c2_i64 = arith.constant 2 : i64
%c1024_i64 = arith.constant 1024 : i64
%c2112_i64 = arith.constant 2112 : i64
%0 = call @xsmm_intel_amx_tile_config_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2112_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
%1 = call @xsmm_intel_amx_tile_config_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2176_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
%2 = call @xsmm_brgemm_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2240_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
%alloca = memref.alloca() : memref<64xi8>
%intptr = memref.extract_aligned_pointer_as_index %alloca : memref<64xi8> -> index
%3 = arith.index_cast %intptr : index to i64
%4 = llvm.inttoptr %3 : i64 to !llvm.ptr
func.call @xsmm_intel_amx_tile_config_invoke(%c2_i64, %0, %4, %c0) : (i64, i64, !llvm.ptr, index) -> ()
scf.for %arg5 = %c0 to %c4 step %c1 {
%5 = arith.addi %arg5, %arg3 : index
%6 = affine.apply #map(%5)
%subview = memref.subview %arg0[%5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
scf.for %arg6 = %c0 to %c8 step %c1 {
%7 = arith.addi %arg6, %arg4 : index
%8 = affine.apply #map(%7)
%subview_0 = memref.subview %arg1[%7, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%6, %8] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
%base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index
%intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index
%9 = arith.index_cast %intptr_2 : index to i64
%10 = llvm.inttoptr %9 : i64 to !llvm.ptr
%base_buffer_3, %offset_4, %sizes_5:4, %strides_6:4 = memref.extract_strided_metadata %subview_0 : memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
%intptr_7 = memref.extract_aligned_pointer_as_index %subview_0 : memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index
%11 = arith.index_cast %intptr_7 : index to i64
%12 = llvm.inttoptr %11 : i64 to !llvm.ptr
%base_buffer_8, %offset_9, %sizes_10:2, %strides_11:2 = memref.extract_strided_metadata %subview_1 : memref<32x32xbf16, strided<[1024, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index
%intptr_12 = memref.extract_aligned_pointer_as_index %subview_1 : memref<32x32xbf16, strided<[1024, 1], offset: ?>> -> index
%13 = arith.index_cast %intptr_12 : index to i64
%14 = llvm.inttoptr %13 : i64 to !llvm.ptr
func.call @xsmm_brgemm_invoke(%c2_i64, %2, %10, %offset, %12, %offset_4, %14, %offset_9, %c32_i64) : (i64, i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> ()
}
}
func.call @xsmm_intel_amx_tile_config_invoke(%c2_i64, %1, %4, %c0) : (i64, i64, !llvm.ptr, index) -> ()
scf.reduce
}
return
}
}