Skip to content

Commit

Permalink
Enable vector to amx code generation and execution.
Browse files Browse the repository at this point in the history
Adds support to check and enable amx-bf16 feature using libxsmm platform setup API.
Updates default pipeline to enable vector to amx lowering based on target feature.
  • Loading branch information
shahidact committed Feb 13, 2025
1 parent fecd114 commit 400ad9b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 2 deletions.
4 changes: 3 additions & 1 deletion include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class LinalgOp;

namespace vnni {
namespace utils {

// Named rank constants used by the VNNI utilities.
// NOTE(review): presumably each value is the expected tensor rank of a
// VNNI-packed operand for the named operation kind - confirm against the users
// of this enum.
// Several enumerators intentionally share the value 3; only BRGEMM_INS is 4.
enum class VnniOperandRank {
  TRANSPOSE = 3,
  GEMM = 3,
  BRGEMM_INS = 4,
  BRGEMM_OUTS = 3,
};

// Returns true if the current target architecture is treated as supporting AMX
// instructions. The decision is based on the target architecture id reported
// by libxsmm: it must be at least LIBXSMM_X86_AVX512_SPR and below
// LIBXSMM_X86_ALLFEAT.
bool hasAMX();

// Return the VNNI blocking factor if it can be determined for the given type,
// or zero otherwise.
// Optionally, an operation can be provided to give access to DLTI.
Expand Down
5 changes: 4 additions & 1 deletion lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "TPP/Dialect/Perf/PerfOps.h"
#include "TPP/Dialect/Xsmm/XsmmDialect.h"
#include "TPP/PassUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Transforms/Passes.h"

#include <string>
Expand Down Expand Up @@ -187,7 +188,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createPrintIRPass());

// Lower to LLVM
pm.addPass(createConvertVectorToLLVMPass());
ConvertVectorToLLVMPassOptions options;
options.amx = vnni::utils::hasAMX() ? true : false;
pm.addPass(createConvertVectorToLLVMPass(options));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createConvertSCFToCFPass());
if (defParallel)
Expand Down
6 changes: 6 additions & 0 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ namespace mlir {
namespace vnni {
namespace utils {

// Returns true if the current target architecture is treated as supporting AMX
// instructions, i.e. the libxsmm target architecture id lies in the half-open
// range [LIBXSMM_X86_AVX512_SPR, LIBXSMM_X86_ALLFEAT).
bool hasAMX() {
  // Query the architecture id once so that both bounds are checked against the
  // same value and the lookup is not repeated.
  const auto archId = libxsmm_get_target_archid();
  return archId >= LIBXSMM_X86_AVX512_SPR && archId < LIBXSMM_X86_ALLFEAT;
}

unsigned getVnniBlockingFactor(Type type, Operation *op) {
unsigned blockingFactor = 0;

Expand Down
19 changes: 19 additions & 0 deletions test/Integration/tpp-run-amx-feature-initialization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: not --crash tpp-run %s -e entry -entry-point-result=void -mattr=amx-bf16 2>&1 | FileCheck %s --check-prefix=CHECK-AMX-BF16
// RUN: not --crash env LIBXSMM_TARGET=spr tpp-run %s -e entry -entry-point-result=void -mattr=amx-bf16 2>&1 | FileCheck %s --check-prefix=CHECK-AMX-BF16-SETUP

// Tests for unsuccessful compilation, implying the AMX pipeline was not initialized.
// CHECK-AMX-BF16: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast

// Tests for successful compilation, implying the AMX pipeline was initialized properly.
// CHECK-AMX-BF16-SETUP-NOT: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
// Entry point exercised by the run lines above. Loads one 16x32 bf16 tile from
// each of the first two arguments at offset [0, 0], feeds them together with a
// zeroed 16x16 f32 tile to amx.tile_mulf, and stores the resulting tile into
// the third argument at offset [0, 0].
func.func @entry(%arg0: memref<16x32xbf16>,
                 %arg1: memref<16x32xbf16>,
                 %arg2: memref<16x16xf32>) {
  // Shared zero index used for every tile load/store offset.
  %c0 = arith.constant 0 : index
  %lhs = amx.tile_load %arg0[%c0, %c0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
  %rhs = amx.tile_load %arg1[%c0, %c0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
  %acc = amx.tile_zero : !amx.tile<16x16xf32>
  %res = amx.tile_mulf %lhs, %rhs, %acc : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
  amx.tile_store %arg2[%c0, %c0], %res : memref<16x16xf32>, !amx.tile<16x16xf32>
  return
}
3 changes: 3 additions & 0 deletions tools/tpp-run/tpp-run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/Target/TargetOptions.h"

#include "TPP/Transforms/Utils/TensorInit.h"
#include "libxsmm.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -270,6 +271,8 @@ int main(int argc, char **argv) {
if (failed(validateInput()))
return 1;

// Initialize the underlying platform
libxsmm_init();
// Initialize the LLVM machinery
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
Expand Down

0 comments on commit 400ad9b

Please sign in to comment.