Skip to content

Commit

Permalink
Enable vector to amx code generation and execution using libxsmm plat…
Browse files Browse the repository at this point in the history
…form setup call
  • Loading branch information
shahidact committed Feb 4, 2025
1 parent fecd114 commit b4f0b97
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class LinalgOp;

namespace vnni {
namespace utils {

bool hasAMX();
enum class VnniOperandRank {
TRANSPOSE = 3,
GEMM = 3,
Expand Down
5 changes: 4 additions & 1 deletion lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "TPP/Dialect/Perf/PerfOps.h"
#include "TPP/Dialect/Xsmm/XsmmDialect.h"
#include "TPP/PassUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Transforms/Passes.h"

#include <string>
Expand Down Expand Up @@ -187,7 +188,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createPrintIRPass());

// Lower to LLVM
pm.addPass(createConvertVectorToLLVMPass());
ConvertVectorToLLVMPassOptions options;
options.amx = vnni::utils::hasAMX() ? true : false;
pm.addPass(createConvertVectorToLLVMPass(options));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createConvertSCFToCFPass());
if (defParallel)
Expand Down
5 changes: 5 additions & 0 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ namespace mlir {
namespace vnni {
namespace utils {

bool hasAMX() {
return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX512_SPR) &&
(libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
}

unsigned getVnniBlockingFactor(Type type, Operation *op) {
unsigned blockingFactor = 0;

Expand Down
2 changes: 2 additions & 0 deletions tools/tpp-run/tpp-run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/Target/TargetOptions.h"

#include "TPP/Transforms/Utils/TensorInit.h"
#include "libxsmm.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -271,6 +272,7 @@ int main(int argc, char **argv) {
return 1;

// Initialize the LLVM machinery
libxsmm_init();
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
Expand Down

0 comments on commit b4f0b97

Please sign in to comment.