
Intel AMX tile configuration


Example input kernel

To illustrate TPP-MLIR's AMX tile configuration process, let's start with a high-level input kernel:

func.func @entry(%arg0: tensor<1024x1024xbf16>,
                 %arg1: tensor<1024x1024xbf16>,
                 %arg2: tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<1024x1024xbf16>, tensor<1024x1024xbf16>)
                     outs(%arg2 : tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16>
  return %0 : tensor<1024x1024xbf16>
}

NOTES:

  • The AMX configuration is currently supported for the XSMM dialect's BRGEMM operation using the VNNI format.
  • The relevant transformations work only with bf16.
  • The presented lowering steps are based on the DefaultTppPasses pipeline. Certain passes are omitted for brevity, as they are not required for a simple GEMM kernel. The full TPP pipeline (which stops before lowering to the LLVM dialect) can be run using: tpp-opt -default-tpp-passes

Kernel preprocessing

First, tensor-level transformations are applied using the following passes:

tpp-opt -tpp-mapping -lower-packs-unpacks -canonicalize -cse

Then the kernel can be bufferized to lower the ops to the memref abstraction:

tpp-opt -bufferize

The above transformations produce the following entry-point memref kernel:

#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
module {
  func.func @entry(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xbf16>) {
    // Block pack matrix A.
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32x32x32xbf16>
    scf.forall (%arg3, %arg4) in (32, 32) {
      %0 = affine.apply #map(%arg3)
      %1 = affine.apply #map(%arg4)
      %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
      %subview_1 = memref.subview %alloc[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
      linalg.copy ins(%subview : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) outs(%subview_1 : memref<32x32xbf16, strided<[32, 1], offset: ?>>)
    }
    // Block pack plus VNNI pack matrix B.
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x32x16x32x2xbf16>
    scf.forall (%arg3, %arg4) in (32, 32) {
      %0 = affine.apply #map(%arg4)
      %1 = affine.apply #map(%arg3)
      %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
      %subview_1 = memref.subview %alloc_0[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>
      %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[1024, 1], offset: ?>> into memref<16x2x32xbf16, strided<[2048, 1024, 1], offset: ?>>
      linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[2048, 1024, 1], offset: ?>>) outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] 
    }
    // BRGEMM in VNNI format.
    scf.forall (%arg3, %arg4) in (32, 32) {
      %0 = affine.apply #map(%arg3)
      %1 = affine.apply #map(%arg4)
      %subview = memref.subview %alloc[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
      %subview_1 = memref.subview %alloc_0[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
      %subview_2 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
      linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} ins(%subview, %subview_1 : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_2 : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) {
      ^bb0(%in: bf16, %in_3: bf16, %out: bf16):
        %2 = arith.mulf %in, %in_3 : bf16
        %3 = arith.addf %out, %2 : bf16
        linalg.yield %3 : bf16
      }
    }
    // Cleanup temporary buffers.
    memref.dealloc %alloc : memref<32x32x32x32xbf16>
    memref.dealloc %alloc_0 : memref<32x32x16x32x2xbf16>
    return
  }
}
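
The B-matrix packing above combines a regular block packing with a VNNI-2 relayout: the memref.expand_shape plus linalg.transpose pair moves pairs of K elements into the innermost dimension, turning each 32x32 [K, N] block into a 16x32x2 block. As a rough illustration (plain NumPy, not TPP-MLIR code), the relayout of a single block looks as follows:

import numpy as np

# Stand-in data for one 32x32 bf16 block of B (float32 used for simplicity).
K, N = 32, 32
b_block = np.arange(K * N, dtype=np.float32).reshape(K, N)

# Equivalent of memref.expand_shape: [K, N] -> [K/2, 2, N].
expanded = b_block.reshape(K // 2, 2, N)
# Equivalent of linalg.transpose with permutation [0, 2, 1]: [K/2, 2, N] -> [K/2, N, 2].
vnni = expanded.transpose(0, 2, 1)

assert vnni.shape == (16, 32, 2)
# Element (k, n) of the original block now lives at (k // 2, n, k % 2).
k, n = 5, 7
assert vnni[k // 2, n, k % 2] == b_block[k, n]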

Lowering to XSMM dialect

Next, the preprocessed memref-level kernel can be lowered into XSMM operations.

For better readability, a simpler input kernel with prepacked arguments is used:

#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
module {
  func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
    scf.forall (%arg3, %arg4) in (32, 32) {
      %0 = affine.apply #map(%arg3)
      %1 = affine.apply #map(%arg4)
      %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
      %subview_1 = memref.subview %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
      %subview_2 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
      linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} ins(%subview, %subview_1 : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_2 : memref<32x32xbf16, strided<[1024, 1], offset: ?>>) {
      ^bb0(%in: bf16, %in_3: bf16, %out: bf16):
        %2 = arith.mulf %in, %in_3 : bf16
        %3 = arith.addf %out, %2 : bf16
        linalg.yield %3 : bf16
      }
    }
    return
  }
}
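
The linalg.generic above represents a batch-reduce GEMM (BRGEMM) in VNNI format: the output tile C accumulates the products of 32 A blocks with 32 B blocks, where B is read through its VNNI layout. A purely illustrative NumPy sketch of that computation (not TPP-MLIR or libxsmm code; float32 stands in for bf16):

import numpy as np

batches, M, N, K = 32, 32, 32, 32
a = np.random.rand(batches, M, K).astype(np.float32)             # packed A tile, [b, M, K]
b = np.random.rand(batches, K, N).astype(np.float32)             # logical B tile, [b, K, N]
b_vnni = b.reshape(batches, K // 2, 2, N).transpose(0, 1, 3, 2)  # VNNI layout, [b, K/2, N, 2]
c = np.zeros((M, N), dtype=np.float32)

for bi in range(batches):          # reduction over the block (batch) dimension
    for k in range(K):             # reduction over K
        # Row k of B[bi] is read through the VNNI layout at [k // 2, :, k % 2].
        c += np.outer(a[bi, :, k], b_vnni[bi, k // 2, :, k % 2])

# Same result as a plain batch-reduce matmul on the unpacked B.
assert np.allclose(c, np.einsum("bmk,bkn->mn", a, b), rtol=1e-4)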

The following passes can be applied to lower to XSMM and clean up the IR:

tpp-opt -linalg-lowering -convert-forall-to-parallel

which produces:

#map = affine_map<(d0) -> (d0 * 32)>
module {
  func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
    %c32_i64 = arith.constant 32 : i64
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) {
      %0 = affine.apply #map(%arg3)
      %1 = affine.apply #map(%arg4)
      %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
      %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
      %subview_1 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
      %2 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b) data_type = bf16
      xsmm.brgemm(data_type = bf16, %2, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
      scf.reduce 
    }
    return
  }
}
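
Note how each XSMM operation is split into a dispatch, which depends only on shapes, strides, and flags and returns an i64 kernel handle, and an invocation that runs that handle on concrete buffers. This split is what later allows the dispatches to be hoisted out of the loops. A rough Python sketch of the pattern (illustrative only; the names are assumptions, not the libxsmm API, and the VNNI layout of B is ignored for brevity):

import numpy as np

_kernels = {}  # handle -> kernel metadata (here: just the problem sizes)

def brgemm_dispatch(m, n, k):
    """Stand-in for JIT-compiling a BRGEMM kernel for fixed sizes."""
    handle = len(_kernels) + 1
    _kernels[handle] = (m, n, k)
    return handle

def brgemm_invoke(handle, a, b, c, batch):
    """Stand-in for running a precompiled kernel: C += sum_i A[i] @ B[i]."""
    m, n, k = _kernels[handle]
    for i in range(batch):
        c += a[i, :m, :k] @ b[i, :k, :n]

handle = brgemm_dispatch(32, 32, 32)       # shape-only work, reusable for every tile

a = np.random.rand(32, 32, 32)
b = np.random.rand(32, 32, 32)
c = np.zeros((32, 32))
brgemm_invoke(handle, a, b, c, batch=32)   # per-tile call inside the loop nest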

2D parallelism

The parallel loop around the XSMM call is further tiled using:

tpp-opt -scf-parallel-loop-tiling-pass=parallel-loop-tile-sizes=4,8 -canonicalize

The tile sizes are chosen somewhat arbitrarily to highlight the creation of two nested serial loops within the parallel loop.
Resulting IR:

#map = affine_map<(d0) -> (d0 * 32)>
module {
  func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
    %c32_i64 = arith.constant 32 : i64
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
      scf.for %arg5 = %c0 to %c4 step %c1 {
        scf.for %arg6 = %c0 to %c8 step %c1 {
          %0 = arith.addi %arg5, %arg3 : index
          %1 = arith.addi %arg6, %arg4 : index
          %2 = affine.apply #map(%0)
          %3 = affine.apply #map(%1)
          %subview = memref.subview %arg0[%0, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
          %subview_0 = memref.subview %arg1[%1, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
          %subview_1 = memref.subview %arg2[%2, %3] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
          %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b) data_type = bf16
          xsmm.brgemm(data_type = bf16, %4, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
        }
      }
      scf.reduce 
    }
    return
  }
}

Generating AMX configuration

The AMX configuration can now be added to the above kernel containing XSMM operations.
NOTE: AMX config is currently applicable only to xsmm.(fused_)brgemm.

The config ops can be inserted using:

tpp-opt -intel-amx-tile-config-insertion-pass

The function dispatch ops can then be hoisted out of the loops using:

tpp-opt -loop-invariant-code-motion -intel-amx-tile-config-hoisting-pass

The transformations produce the following kernel:

#map = affine_map<(d0) -> (d0 * 32)>
module {
  func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
    %c32_i64 = arith.constant 32 : i64
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %0 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_reset_tileconfig) data_type = bf16
    %1 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_setup_tileconfig) data_type = bf16
    %2 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b, no_reset_tileconfig, no_setup_tileconfig) data_type = bf16
    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
      %alloca = memref.alloca() : memref<64xi8>
      "xsmm.IntelAMXtileConfig"(%0, %alloca) : (i64, memref<64xi8>) -> ()
      scf.for %arg5 = %c0 to %c4 step %c1 {
        %3 = arith.addi %arg5, %arg3 : index
        %4 = affine.apply #map(%3)
        %subview = memref.subview %arg0[%3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        scf.for %arg6 = %c0 to %c8 step %c1 {
          %5 = arith.addi %arg6, %arg4 : index
          %6 = affine.apply #map(%5)
          %subview_0 = memref.subview %arg1[%5, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
          %subview_1 = memref.subview %arg2[%4, %6] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
          xsmm.brgemm(data_type = bf16, %2, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
        }
      }
      "xsmm.IntelAMXtileConfig"(%1, %alloca) : (i64, memref<64xi8>) -> ()
      scf.reduce 
    }
    return
  }
}

The AMX insertion pass materializes explicit configuration ops (xsmm.IntelAMXtileConfig) while also ensuring that the other XSMM operations do not modify the tile config themselves (via the no_reset_tileconfig and no_setup_tileconfig flags). For more details, see the libxsmm documentation.
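
The effect is that each parallel chunk configures the AMX tiles once, runs all of its BRGEMMs without touching the configuration, and resets it at the end of the chunk. A small Python sketch (illustrative only) of the resulting call counts for the 4x8 tiling above:

config_calls = 0
brgemm_calls = 0
for i0 in range(0, 32, 4):           # scf.parallel over the 32x32 iteration space
    for j0 in range(0, 32, 8):       # one parallel chunk per (i0, j0)
        config_calls += 1            # xsmm.IntelAMXtileConfig (setup)
        for di in range(4):          # inner serial loops
            for dj in range(8):
                brgemm_calls += 1    # xsmm.brgemm with no_setup/no_reset flags
        config_calls += 1            # xsmm.IntelAMXtileConfig (reset)

print(config_calls, brgemm_calls)    # 64 tile-config calls vs. 1024 BRGEMM calls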

Lowering XSMM operations

Finally, the XSMM operations can be lowered to libxsmm function calls (plus cleanup for clarity) using:

tpp-opt -convert-xsmm-to-func -canonicalize -cse

resulting in:

#map = affine_map<(d0) -> (d0 * 32)>
module {
  func.func private @xsmm_intel_amx_tile_config_invoke(i64, i64, !llvm.ptr, index)
  func.func private @xsmm_brgemm_invoke(i64, i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64)
  func.func private @xsmm_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
  func.func private @xsmm_intel_amx_tile_config_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
  func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
    %c2240_i64 = arith.constant 2240 : i64
    %c2176_i64 = arith.constant 2176 : i64
    %c32_i64 = arith.constant 32 : i64
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c2_i64 = arith.constant 2 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c2112_i64 = arith.constant 2112 : i64
    %0 = call @xsmm_intel_amx_tile_config_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2112_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
    %1 = call @xsmm_intel_amx_tile_config_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2176_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
    %2 = call @xsmm_brgemm_dispatch(%c2_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %c1024_i64, %c2240_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64
    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c32, %c32) step (%c4, %c8) {
      %alloca = memref.alloca() : memref<64xi8>
      %intptr = memref.extract_aligned_pointer_as_index %alloca : memref<64xi8> -> index
      %3 = arith.index_cast %intptr : index to i64
      %4 = llvm.inttoptr %3 : i64 to !llvm.ptr
      func.call @xsmm_intel_amx_tile_config_invoke(%c2_i64, %0, %4, %c0) : (i64, i64, !llvm.ptr, index) -> ()
      scf.for %arg5 = %c0 to %c4 step %c1 {
        %5 = arith.addi %arg5, %arg3 : index
        %6 = affine.apply #map(%5)
        %subview = memref.subview %arg0[%5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
        scf.for %arg6 = %c0 to %c8 step %c1 {
          %7 = arith.addi %arg6, %arg4 : index
          %8 = affine.apply #map(%7)
          %subview_0 = memref.subview %arg1[%7, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
          %subview_1 = memref.subview %arg2[%6, %8] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
          %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index
          %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index
          %9 = arith.index_cast %intptr_2 : index to i64
          %10 = llvm.inttoptr %9 : i64 to !llvm.ptr
          %base_buffer_3, %offset_4, %sizes_5:4, %strides_6:4 = memref.extract_strided_metadata %subview_0 : memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
          %intptr_7 = memref.extract_aligned_pointer_as_index %subview_0 : memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index
          %11 = arith.index_cast %intptr_7 : index to i64
          %12 = llvm.inttoptr %11 : i64 to !llvm.ptr
          %base_buffer_8, %offset_9, %sizes_10:2, %strides_11:2 = memref.extract_strided_metadata %subview_1 : memref<32x32xbf16, strided<[1024, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index
          %intptr_12 = memref.extract_aligned_pointer_as_index %subview_1 : memref<32x32xbf16, strided<[1024, 1], offset: ?>> -> index
          %13 = arith.index_cast %intptr_12 : index to i64
          %14 = llvm.inttoptr %13 : i64 to !llvm.ptr
          func.call @xsmm_brgemm_invoke(%c2_i64, %2, %10, %offset, %12, %offset_4, %14, %offset_9, %c32_i64) : (i64, i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> ()
        }
      }
      func.call @xsmm_intel_amx_tile_config_invoke(%c2_i64, %1, %4, %c0) : (i64, i64, !llvm.ptr, index) -> ()
      scf.reduce 
    }
    return
  }
}