diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index d5d12a6f6..8978f3cba 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -28,7 +28,6 @@ class LinalgOp; namespace vnni { namespace utils { - enum class VnniOperandRank { TRANSPOSE = 3, GEMM = 3, @@ -36,6 +35,9 @@ enum class VnniOperandRank { BRGEMM_OUTS = 3 }; +// Returns true if the libxsmm target architecture supports AMX instructions. +bool hasAMX(); + // Return the VNNI blocking factor if it can be determined for the given type or // zero, otherwise. // Optionally, an operation can be provided to give access to DLTI. diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index b9eefa786..cdf84a79b 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -22,6 +22,7 @@ #include "TPP/Dialect/Perf/PerfOps.h" #include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" +#include "TPP/Transforms/Utils/VNNIUtils.h" #include "mlir/Transforms/Passes.h" #include @@ -187,7 +188,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, pm.addPass(createPrintIRPass()); // Lower to LLVM - pm.addPass(createConvertVectorToLLVMPass()); + ConvertVectorToLLVMPassOptions options; + options.amx = vnni::utils::hasAMX() ? true : false; + pm.addPass(createConvertVectorToLLVMPass(options)); pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createConvertSCFToCFPass()); if (defParallel) diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 87f290e25..1e14002b0 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -22,6 +22,12 @@ namespace mlir { namespace vnni { namespace utils { +// Returns true if the libxsmm target architecture supports AMX instructions. 
+bool hasAMX() { + return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX512_SPR) && + (libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT); +} + unsigned getVnniBlockingFactor(Type type, Operation *op) { unsigned blockingFactor = 0; diff --git a/test/Integration/tpp-run-amx-feature-initialization.mlir b/test/Integration/tpp-run-amx-feature-initialization.mlir new file mode 100644 index 000000000..b1de9f593 --- /dev/null +++ b/test/Integration/tpp-run-amx-feature-initialization.mlir @@ -0,0 +1,19 @@ +// RUN: not --crash tpp-run %s -e entry -entry-point-result=void -mattr=amx-bf16 2>&1 | FileCheck %s --check-prefix=CHECK-AMX-BF16 +// RUN: not --crash env LIBXSMM_TARGET=spr tpp-run %s -e entry -entry-point-result=void -mattr=amx-bf16 2>&1 | FileCheck %s --check-prefix=CHECK-AMX-BF16-SETUP + +// Tests for unsuccessful compilation implying AMX pipeline was not initialized +// CHECK-AMX-BF16: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast + +// Tests for successful compilation implying AMX pipeline was initialized properly. 
+// CHECK-AMX-BF16-SETUP-NOT: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast +func.func @entry(%arg0: memref<16x32xbf16>, + %arg1: memref<16x32xbf16>, + %arg2: memref<16x16xf32>) { + %0 = arith.constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16> + %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16> + %3 = amx.tile_zero : !amx.tile<16x16xf32> + %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> + amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, !amx.tile<16x16xf32> + return +} \ No newline at end of file diff --git a/tools/tpp-run/tpp-run.cpp b/tools/tpp-run/tpp-run.cpp index 7db7f81c2..6ccf3d909 100644 --- a/tools/tpp-run/tpp-run.cpp +++ b/tools/tpp-run/tpp-run.cpp @@ -23,6 +23,7 @@ #include "llvm/Target/TargetOptions.h" #include "TPP/Transforms/Utils/TensorInit.h" +#include "libxsmm.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -270,6 +271,8 @@ int main(int argc, char **argv) { if (failed(validateInput())) return 1; + // Initialize the underlying platform (libxsmm) + libxsmm_init(); // Initialize the LLVM machinery llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget();