From 815ce3f9596c6798392e7cfe52dc2a31e2f4df17 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Fri, 17 Jan 2025 18:14:47 +0100
Subject: [PATCH] Make VNNI more robust (#1001)

Improves VNNI validation and lit tests to be target-independent w.r.t.
the VNNI blocking factor.

The VNNI infrastructure is extended with DLTI support that allows
overriding the target's default VNNI factor. This makes it possible to
fix the VNNI shape in lit tests and ensure that the general pass logic
works correctly. The override is primarily applied to tests that use
prepacked VNNI shapes. Where possible, the tests' checks are generalized
to account for a varying packing factor. Test shapes are adjusted so
that the VNNI packing size stays compatible with the dimension sizes
used for verification.

BF16 integration tests are now disabled on non-x86 systems to avoid a
runtime mismatch between prepacked shapes and microkernel requirements.
---
 include/TPP/Transforms/Utils/DLTIUtils.h | 30 ++
 include/TPP/Transforms/Utils/VNNIUtils.h | 27 +-
 .../ConvertLinalgToXsmm.cpp | 8 +-
 .../ConvertVectorToXsmm.cpp | 3 +-
 lib/TPP/Dialect/Xsmm/XsmmOps.cpp | 10 +-
 lib/TPP/Dialect/Xsmm/XsmmVerify.cpp | 10 +-
 lib/TPP/IR/MatcherUtils.cpp | 6 +-
 .../TileConsumerAndFuseProducers.cpp | 25 +-
 lib/TPP/Transforms/ToBlockLayoutAndBack.cpp | 12 +-
 lib/TPP/Transforms/Utils/CMakeLists.txt | 1 +
 lib/TPP/Transforms/Utils/DLTIUtils.cpp | 32 ++
 lib/TPP/Transforms/Utils/VNNIUtils.cpp | 67 ++--
 test/BF16/Integration/lit.local.cfg | 23 ++
 test/BF16/Integration/mlir-gen-bf16.mlir | 28 +-
 test/BF16/Integration/vnni-xsmm-vs-loops.mlir | 29 +-
 test/BF16/brgemm-tpp.mlir | 8 +-
 test/BF16/brgemm-vnni.mlir | 16 +-
 test/BF16/matmul-untiled-vnni.mlir | 2 +-
 test/BF16/matmul-vnni.mlir | 16 +-
 .../LinalgToXsmm/linalg-to-brgemm.mlir | 129 ++++----
 .../LinalgToXsmm/linalg-to-gemm.mlir | 311 ++++++++++--------
 .../LinalgToXsmm/linalg-to-unary.mlir | 75 +++--
 .../VectorToXsmm/vector-to-transpose.mlir | 55 ++--
 test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir | 15 +-
 .../hoist-vector-transfer-brgemm.mlir | 2 +-
 test/Integration/transpose-bf16.mlir | 19 +-
 test/Integration/vector-contract-to-fma.mlir | 4 +-
 test/Passes/DefaultPipeline/linalg.mlir | 63 ++--
 test/Passes/DefaultPipeline/vnni.mlir | 272 ++++++++-------
 test/Passes/DefaultPipeline/xsmm.mlir | 82 +++--
 test/Passes/pack-vnni.mlir | 104 ++++++
 test/Passes/xsmm-combine.mlir | 202 ++++++------
 32 files changed, 1012 insertions(+), 674 deletions(-)
 create mode 100644 include/TPP/Transforms/Utils/DLTIUtils.h
 create mode 100644 lib/TPP/Transforms/Utils/DLTIUtils.cpp
 create mode 100644 test/Passes/pack-vnni.mlir

diff --git a/include/TPP/Transforms/Utils/DLTIUtils.h b/include/TPP/Transforms/Utils/DLTIUtils.h
new file mode 100644
index 000000000..3e1e17ddc
--- /dev/null
+++ b/include/TPP/Transforms/Utils/DLTIUtils.h
@@ -0,0 +1,30 @@
+//===- DLTIUtils.h -----------------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TPP_TRANSFORMS_UTILS_DLTIUTILS_H
+#define TPP_TRANSFORMS_UTILS_DLTIUTILS_H
+
+#include "mlir/Dialect/DLTI/DLTI.h"
+
+namespace llvm {
+class StringRef;
+} // namespace llvm
+
+namespace mlir {
+namespace dlti {
+namespace utils {
+
+// Perform a DLTI query using string keys.
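+//
+// Illustrative usage (a sketch mirroring how this helper is used elsewhere
+// in this patch to read a "vnni" hint; `hint` and `factor` are placeholder
+// names, not part of the API):
+//
+//   unsigned factor = 0;
+//   auto hint = dlti::utils::query(op, {"CPU", "vnni"});
+//   if (succeeded(hint))
+//     if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*hint))
+//       factor = intAttr.getInt();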
+FailureOr query(Operation *op, ArrayRef keys, + bool emitError = false); + +} // namespace utils +} // namespace dlti +} // namespace mlir + +#endif // TPP_TRANSFORMS_UTILS_DLTIUTILS_H diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 58d1c73bd..d5d12a6f6 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -15,11 +15,12 @@ namespace mlir { class Type; -class MemRefType; +class ShapedType; class OpOperand; class AffineDimExpr; class AffineMap; class VectorType; +class Operation; namespace linalg { class LinalgOp; @@ -35,21 +36,25 @@ enum class VnniOperandRank { BRGEMM_OUTS = 3 }; -// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8. -std::optional getVnniBlockingFactor(Type type); +// Return the VNNI blocking factor if it can be determined for the given type or +// zero, otherwise. +// Optionally, an operation can be provided to give access to DLTI. +unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr); -// Return true if the memref is in VNNI layout with rank `expectedRank`. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref); - -// Return true if the vector is in VNNI layout with rank `expectedRank`. -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector); +// Return true if the shaped type is in VNNI layout with rank `expectedRank`. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, + std::optional blockingFactor = std::nullopt); -bool isInVnniLayout(int64_t expectedRank, VectorType vector); +// Return true if the shaped type is in VNNI layout with rank `expectedRank`. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +bool isInVnniLayout(int64_t expectedRank, ShapedType shape, + std::optional blockingFactor = std::nullopt); -// Return true if the operation is in VNNI layout. +// Return true if the linalg operation is in VNNI layout. // Optionally, the check can be constrained to a specific VNNI blocking factor. bool isInVnniLayout(linalg::LinalgOp linalgOp, - std::optional blockingFactor = std::nullopt); + std::optional blockingFactor = std::nullopt); } // namespace utils } // namespace vnni diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp index e1e003694..c9166341f 100644 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp @@ -1068,8 +1068,10 @@ struct ConvertVnniPacking : public OpRewritePattern { if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) return failure(); // Ajust ldo based on the VNNI factor. - unaryInfo.ldo = stridesOnOutput->front() / - *vnni::utils::getVnniBlockingFactor(out.getType()); + auto vnniFactor = + vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp); + assert(vnniFactor && "Failed to get VNNI blocking factor"); + unaryInfo.ldo = stridesOnOutput->front() / vnniFactor; auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( rewriter.getContext(), xsmm::UnaryFlags::NONE)); xsmm::UnaryKindAttr kind = @@ -1112,7 +1114,7 @@ struct ConvertGenericToVnniMatmulLikeOp // Take the whole reduction dim size. Account for the VNNI factor (ensured // by the earlier check) that splits the K dim in the shape. 
std::optional vnniFactor = - vnni::utils::getVnniBlockingFactor(bufferB.getType()); + vnni::utils::getVnniBlockingFactor(bufferB.getType(), genericOp); if (!vnniFactor) return rewriter.notifyMatchFailure(genericOp, "failed to determine VNNI factor"); diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index 04c1f6fdb..b6f2ec3ee 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -99,7 +99,8 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp, if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE, outType)) { // Adjust ldo based on vnni factor - auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType); + auto vnniFactor = vnni::utils::getVnniBlockingFactor(outType, transposeOp); + assert(vnniFactor && "Failed to get VNNI blocking factor"); unaryInfo.ldo = unaryInfo.ldo / vnniFactor; } else { std::swap(unaryInfo.m, unaryInfo.n); diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp index c1e3209a9..bf04eb496 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp @@ -455,10 +455,12 @@ LogicalResult GemmOp::verify() { auto memref = dyn_cast(memrefOperands[idx].getType()); assert(memref && (memref.getRank() == 2 || memref.getRank() == 3)); - if (memref.getRank() == 3 && - !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, - memref)) { - return emitOpError() << "expect VNNI layout for operand: " << actualIdx; + if (memref.getRank() == 3) { + if (memref.getShape().back() % 2 != 0 || + !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, + memref)) { + return emitOpError() << "expect VNNI layout for operand: " << actualIdx; + } } } return success(); diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp index 5d5abc45f..2040e8833 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp @@ -71,21 +71,25 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { : vnni::utils::VnniOperandRank::GEMM; // VNNI flags must be consistent with the memref shapes. 
+ auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp); + ArrayAttr flags = dispatchOp->getFlags(); for (auto flag : flags) { int64_t gemmFlag = cast(flag).getInt(); if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand A or invalid VNNI_A flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) { + !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB, + vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand B or invalid VNNI_B flags"); } if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) { + !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) { return gemmOp.emitOpError( "expect VNNI layout for operand C or invalid VNNI_C flags"); } diff --git a/lib/TPP/IR/MatcherUtils.cpp b/lib/TPP/IR/MatcherUtils.cpp index c74e6d5af..a0d045364 100644 --- a/lib/TPP/IR/MatcherUtils.cpp +++ b/lib/TPP/IR/MatcherUtils.cpp @@ -40,8 +40,8 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap, std::pair isMatmulVnniOp(linalg::GenericOp linalgOp, SmallVectorImpl *operands) { bool hasBatch = false; - auto blockingFactor = - vnni::utils::getVnniBlockingFactor(linalgOp->getOperands()[0].getType()); + auto blockingFactor = vnni::utils::getVnniBlockingFactor( + linalgOp->getOperands()[0].getType(), linalgOp); if (!blockingFactor) return std::make_pair(false, hasBatch); @@ -115,7 +115,7 @@ std::pair isMatmulVnniOp(linalg::GenericOp linalgOp, // At this point, the operation is a valid matmul contraction. // Finally, ensure that it is in VNNI layout. - bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor); + bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp); return std::make_pair(isVnniMatmul, hasBatch); } diff --git a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp index 540a10b22..3f24c1bdd 100644 --- a/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp +++ b/lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp @@ -8,6 +8,7 @@ #include "TPP/Passes.h" #include "TPP/Transforms/Transforms.h" +#include "TPP/Transforms/Utils/DLTIUtils.h" #include "TPP/Transforms/Utils/TransformUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -457,28 +458,10 @@ static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) { int64_t tile = 32; // Check if a tile size hint is associated to the IR via DLTI. 
- auto deriveFromDLTI = [&](ModuleOp moduleOp) { - if (!moduleOp) - return; - TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec(); - if (!sysSpec) - return; - auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU"); - auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId); - if (!deviceSpec) - return; - auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size"); - DataLayoutEntryInterface entry = - (*deviceSpec).getSpecForIdentifier(tileSizeId); - if (!entry) - return; - Attribute value = entry.getValue(); - if (auto intAttr = llvm::dyn_cast(value)) + auto tileValue = dlti::utils::query(linalgOp, {"CPU", "tile_size"}); + if (succeeded(tileValue)) + if (auto intAttr = llvm::dyn_cast(*tileValue)) tile = intAttr.getInt(); - // TODO: might want to print a warning if tile_size exists as a key but the - // associated attribute has an unexpected type. - }; - deriveFromDLTI(linalgOp->getParentOfType()); SmallVector loopsRange = linalgOp.getStaticLoopRanges(); if (loopsRange[dim] == ShapedType::kDynamic) diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp index 52c0a5f17..688c492f6 100644 --- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp +++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp @@ -333,19 +333,19 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter, OpOperand &operandB = matmulOp->getOpOperand(1); auto blockingFactor = - vnni::utils::getVnniBlockingFactor(operandB.get().getType()); + vnni::utils::getVnniBlockingFactor(operandB.get().getType(), matmulOp); if (!blockingFactor) { return rewriter.notifyMatchFailure(matmulOp, "unsupported blocking factor for type"); } - if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) { + if (vnni::utils::isInVnniLayout(matmulOp)) { return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI"); } Location loc = matmulOp.getLoc(); SmallVector tilesOnSmallK = { - rewriter.getI64IntegerAttr(*blockingFactor)}; + rewriter.getI64IntegerAttr(blockingFactor)}; SmallVector> kOperands; matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands); if (kOperands.size() != 2) @@ -409,12 +409,14 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter, Value operandB = brgemmOp.getInputs()[1]; // Blocking factor on the `k` dimension. - auto blockingFactor = vnni::utils::getVnniBlockingFactor(operandB.getType()); + auto blockingFactor = + vnni::utils::getVnniBlockingFactor(operandB.getType(), brgemmOp); if (!blockingFactor) { return rewriter.notifyMatchFailure(brgemmOp, "unsupported blocking factor for type"); } - SmallVector tilesOnK = {rewriter.getI64IntegerAttr(2)}; + SmallVector tilesOnK = { + rewriter.getI64IntegerAttr(blockingFactor)}; Location loc = brgemmOp.getLoc(); // Reshape input A. diff --git a/lib/TPP/Transforms/Utils/CMakeLists.txt b/lib/TPP/Transforms/Utils/CMakeLists.txt index 4e7e484a8..a6b0c4501 100644 --- a/lib/TPP/Transforms/Utils/CMakeLists.txt +++ b/lib/TPP/Transforms/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TPPTransformsUtils BuilderUtils.cpp + DLTIUtils.cpp TensorInit.cpp TensorInitFloat.cpp TensorInitInt.cpp diff --git a/lib/TPP/Transforms/Utils/DLTIUtils.cpp b/lib/TPP/Transforms/Utils/DLTIUtils.cpp new file mode 100644 index 000000000..8fac42db2 --- /dev/null +++ b/lib/TPP/Transforms/Utils/DLTIUtils.cpp @@ -0,0 +1,32 @@ +//===- DLTIUtils.cpp ---------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Transforms/Utils/DLTIUtils.h" + +namespace mlir { +namespace dlti { +namespace utils { + +FailureOr query(Operation *op, ArrayRef keys, + bool emitError) { + if (!op) + return failure(); + + auto ctx = op->getContext(); + SmallVector entryKeys; + for (auto &key : keys) { + auto entry = StringAttr::get(ctx, key); + entryKeys.push_back(entry); + } + + return dlti::query(op, entryKeys, emitError); +} + +} // namespace utils +} // namespace dlti +} // namespace mlir diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index dd44c247f..87f290e25 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "TPP/Transforms/Utils/VNNIUtils.h" +#include "TPP/Transforms/Utils/DLTIUtils.h" + #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -20,26 +22,30 @@ namespace mlir { namespace vnni { namespace utils { -std::optional getVnniBlockingFactor(Type type) { - auto elementType = getElementTypeOrSelf(type); - if (elementType.isBF16()) - return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); - return std::nullopt; -} +unsigned getVnniBlockingFactor(Type type, Operation *op) { + unsigned blockingFactor = 0; -// Until we have a better way to express the VNNI layout (see: #563), it is up -// to the callee to specify the expected rank in the VNNI layout as the rank -// depends on the operations we are dealing with. -bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref) { - if (memref.getRank() != static_cast(expectedRank) || - !memref.getElementType().isBF16()) { - return false; + auto elementType = getElementTypeOrSelf(type); + if (elementType.isBF16()) { + // Check if a VNNI factor hint is associated to the IR via DLTI. + auto vnniValue = dlti::utils::query(op, {"CPU", "vnni"}); + if (succeeded(vnniValue)) { + if (auto intAttr = llvm::dyn_cast(*vnniValue)) + blockingFactor = intAttr.getInt(); + } else { + blockingFactor = libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16); + } } - return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); + + // Ensure that the factor is divisible by two. + if (blockingFactor % 2 != 0) + return 0; + + return blockingFactor; } bool isInVnniLayout(linalg::LinalgOp linalgOp, - std::optional blockingFactor) { + std::optional blockingFactor) { // Narrow down type operations - VNNI only applies to contractions. if (!linalg::isaContractionOpInterface(linalgOp)) return false; @@ -96,10 +102,12 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, // - statically known // - multiple of 2 or equal to the specified factor auto vnniDimSize = typeB.getShape().back(); - if (!(vnniDimSize != ShapedType::kDynamic && - typeA.getShape().back() == vnniDimSize && - (blockingFactor ? vnniDimSize == *blockingFactor - : vnniDimSize % 2 == 0))) + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != *blockingFactor) return false; // The split reduction dimension size should also match. 
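// Note: a minimal example of the DLTI hint queried above, as attached by the
// updated lit tests in this patch to pin the VNNI blocking factor to 2
// independently of the host target:
//
//   module attributes {
//     "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
//       = #dlti.target_device_spec<"vnni" = 2 : i32>>
//   } {
//     // functions under test
//   }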
@@ -109,15 +117,24 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp, return true; } -bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) { - return isInVnniLayout(static_cast(expectedRank), vector); +bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape, + std::optional blockingFactor) { + return isInVnniLayout(static_cast(expectedRank), shape, + blockingFactor); } -bool isInVnniLayout(int64_t expectedRank, VectorType vector) { - if (vector.getRank() != expectedRank || !vector.getElementType().isBF16()) { +bool isInVnniLayout(int64_t expectedRank, ShapedType shape, + std::optional blockingFactor) { + if (shape.getRank() != expectedRank || !shape.getElementType().isBF16()) return false; - } - return vector.getShape().back() == vnni::utils::getVnniBlockingFactor(vector); + + auto vnniDim = shape.getShape().back(); + if (vnniDim == 0 || vnniDim % 2 != 0) + return false; + if (blockingFactor && vnniDim != *blockingFactor) + return false; + + return true; } } // namespace utils diff --git a/test/BF16/Integration/lit.local.cfg b/test/BF16/Integration/lit.local.cfg index 48448ac5c..5f963c2b5 100644 --- a/test/BF16/Integration/lit.local.cfg +++ b/test/BF16/Integration/lit.local.cfg @@ -19,6 +19,29 @@ def has_support(feature): return True +def is_arch(target): + # Arch detection not working on Windows + if sys.platform in ['win32']: + return False + + try: + cmd = subprocess.Popen( + ['uname', '-m'], stdout=subprocess.PIPE) + except OSError: + return False + + out = cmd.stdout.read().decode('ascii') + cmd.wait() + + return target in out + + # AVX512 and SVE should have BF16 support if not has_support('avx512') and not has_support('avx2') and not has_support('sve'): config.unsupported = True + +# Enable only on x86 +# Other targets may use different VNNI blocking scheme that is not compatible with +# prepacked shapes in some of the tests +if not is_arch('x86'): + config.unsupported = True diff --git a/test/BF16/Integration/mlir-gen-bf16.mlir b/test/BF16/Integration/mlir-gen-bf16.mlir index a0db89a6b..97035a7d1 100644 --- a/test/BF16/Integration/mlir-gen-bf16.mlir +++ b/test/BF16/Integration/mlir-gen-bf16.mlir @@ -1,28 +1,28 @@ // MLP without softmax (can't print packed version for now) -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Matmul only -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Kernel - matmul -// RUN: mlir-gen --kernel=args 
--seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 -// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 +// RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 +// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 // Kernel - fc -// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 -// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 +// RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=16 --layers=16,16 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 // BF16/VNNI execution -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 --tiles=8,8,8 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF -// GEN-MATMUL-BF16: ( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 ) +// GEN-MATMUL-BF16: ( 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17 ) -// GEN-FC-BF16: ( 12, 12, 12, 12, 12, 12, 12, 12, 12, 12 ) +// 
GEN-FC-BF16: ( 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18 ) // PERF: {{[0-9]+}}{{.?}}{{[0-9e-]+}} diff --git a/test/BF16/Integration/vnni-xsmm-vs-loops.mlir b/test/BF16/Integration/vnni-xsmm-vs-loops.mlir index 2a8419395..0f7eb99d1 100644 --- a/test/BF16/Integration/vnni-xsmm-vs-loops.mlir +++ b/test/BF16/Integration/vnni-xsmm-vs-loops.mlir @@ -1,26 +1,13 @@ -// RUN: tpp-run %s -print -seed 123 \ +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \ +// RUN: --tiles=16,16,16 --float-type=bf16 | \ +// RUN: tpp-opt --pack-vnni | \ +// RUN: tpp-run -print -seed 123 \ // RUN: -e entry -entry-point-result=void > %t.xsmm -// RUN: tpp-run %s -print -seed 123 -linalg-to-loops \ +// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=16 --layers=16,16 \ +// RUN: --tiles=16,16,16 --float-type=bf16 | \ +// RUN: tpp-opt --pack-vnni | \ +// RUN: tpp-run -print -seed 123 -linalg-to-loops \ // RUN: -e entry -entry-point-result=void > %t.loops // RUN: fpcmp -r 0.01 %t.xsmm %t.loops - -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6, d5, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> - -func.func @entry(%arg0: tensor<2x2x7x4x2xbf16>, %arg1: tensor<2x2x4x5x2xbf16>, - %arg2: tensor<2x2x7x5xbf16>) -> tensor<2x2x7x5xbf16> { - %1 = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : tensor<2x2x7x4x2xbf16>, tensor<2x2x4x5x2xbf16>) - outs(%arg2 : tensor<2x2x7x5xbf16>) { - ^bb0(%in: bf16, %in_0: bf16, %out: bf16): - %2 = arith.mulf %in, %in_0 : bf16 - %3 = arith.addf %out, %2 : bf16 - linalg.yield %3 : bf16 - } -> tensor<2x2x7x5xbf16> - return %1 : tensor<2x2x7x5xbf16> -} diff --git a/test/BF16/brgemm-tpp.mlir b/test/BF16/brgemm-tpp.mlir index 78caff1ed..7ab922e62 100644 --- a/test/BF16/brgemm-tpp.mlir +++ b/test/BF16/brgemm-tpp.mlir @@ -14,10 +14,10 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>, // CHECK-LABEL: brgemm // CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16> -// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x2xbf16> -// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [2] -// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [{{2|4}}] +// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %{{.+}} = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/brgemm-vnni.mlir b/test/BF16/brgemm-vnni.mlir index aa6d069d2..5970ebec4 100644 --- a/test/BF16/brgemm-vnni.mlir +++ b/test/BF16/brgemm-vnni.mlir @@ -14,11 +14,11 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: 
tensor<32x4x4xbf16>, // CHECK-LABEL: brgemm // CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16> -// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x2xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] -// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16> +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [{{2|4}}] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16> // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] @@ -69,10 +69,10 @@ func.func @prepacked_matmul(%pack: tensor<4x4x32x32xbf16>, %pack_0: tensor<4x4x3 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4x32x32xbf16>, %[[ARG1:.+]]: tensor<4x4x32x32xbf16>, // CHECK-SAME: %[[ARG2:.+]]: tensor<4x4x32x32xbf16> // CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]] -// CHECK-SAME: output_shape [4, 4, 32, 16, 2] : tensor<4x4x32x32xbf16> into tensor<4x4x32x16x2xbf16> -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x16x32x2xbf16> -// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY]] -// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x16x32x2xbf16> +// CHECK-SAME: output_shape{{.*}}: tensor<4x4x32x32xbf16> into tensor<4x4x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x{{16|8}}x32x{{2|4}}xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY]] +// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x{{16|8}}x32x{{2|4}}xbf16> // CHECK: {{.+}} = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/matmul-untiled-vnni.mlir b/test/BF16/matmul-untiled-vnni.mlir index 2609ca90f..7a47d9b07 100644 --- a/test/BF16/matmul-untiled-vnni.mlir +++ b/test/BF16/matmul-untiled-vnni.mlir @@ -26,7 +26,7 @@ func.func @blocked_matmul(%arg0: tensor<32x64x4x4xbf16>, %arg1: tensor<128x64x4x // CHECK: %[[ARG0:.*]]: tensor<32x64x4x4xbf16>, // CHECK: %[[ARG1:.*]]: tensor<128x64x4x4xbf16>, // CHECK: %[[ARG2:.*]]: tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> { -// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x2x4x2xbf16> +// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x{{2|1}}x4x{{2|4}}xbf16> // CHECK: linalg.generic // CHECK: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] // CHECK: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"] diff --git a/test/BF16/matmul-vnni.mlir b/test/BF16/matmul-vnni.mlir index 2d4a5ffda..24e83a8b3 100644 --- a/test/BF16/matmul-vnni.mlir +++ b/test/BF16/matmul-vnni.mlir @@ -25,17 +25,17 @@ func.func @matmul_static( // CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = 
[0, 1] inner_tiles = [32, 32] // CHECK-SAME: into %{{.+}} : tensor<512x1024xbf16> -> tensor<32x16x32x32xbf16> // CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1], [2], [3, 4]] -// CHECK-SAME: output_shape [8, 16, 32, 16, 2] : tensor<8x16x32x32xbf16> into tensor<8x16x32x16x2xbf16> -// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : tensor<32x16x16x32x2xbf16> -// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY_2]] -// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x16x32x2xbf16> +// CHECK-SAME: output_shape{{.*}}: tensor<8x16x32x32xbf16> into tensor<8x16x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> +// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY_2]] +// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> // CHECK: %{{.+}} = scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (8, 32) shared_outs(%[[ARG5:.+]] = %[[ARG2]]) // CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[ARG3]]) // CHECK: %[[APPLY_1:.+]] = affine.apply #[[MAP]](%[[ARG4]]) -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, 16, 2] [1, 1, 1, 1, 1] -// CHECK-SAME: : tensor<8x16x32x16x2xbf16> to tensor<16x32x16x2xbf16> -// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] -// CHECK-SAME: : tensor<32x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, {{16|8}}, {{2|4}}] [1, 1, 1, 1, 1] +// CHECK-SAME: : tensor<8x16x32x{{16|8}}x{{2|4}}xbf16> to tensor<16x32x{{16|8}}x{{2|4}}xbf16> +// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, {{16|8}}, 32, {{2|4}}] [1, 1, 1, 1, 1] +// CHECK-SAME: : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> to tensor<16x{{16|8}}x32x{{2|4}}xbf16> // CHECK: %[[SLICE_3:.+]] = tensor.extract_slice %[[ARG5]][%[[APPLY]], %[[APPLY_1]]] [32, 32] [1, 1] // CHECK-SAME: : tensor<256x1024xbf16> to tensor<32x32xbf16> // CHECK: %[[GEMM:.+]] = linalg.generic diff --git a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir index 647a5881d..a83787d6c 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir @@ -285,21 +285,25 @@ func.func @simple_brgemm(%arg0: memref<2x32x32xf32>, %arg1: memref<2x32x32xf32>, #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> - -func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + 
func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_interchanged @@ -316,20 +320,25 @@ func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref< #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> -func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm @@ -346,22 +355,27 @@ func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> -func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, - %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, - %arg2: memref<8x8xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [8, 8, 4, 2] - : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>) - outs(%arg2 : memref<8x8xbf16>) { - ^bb0(%in: bf16, %in_9: bf16, %out: bf16): - %11 = arith.mulf %in, %in_9 : bf16 - %12 = arith.addf %out, %11 : bf16 - linalg.yield %12 : bf16 +module attributes { + "#dlti.sys_spec" = 
#dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, + %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, + %arg2: memref<8x8xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [8, 8, 4, 2] + : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>> into memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<8x8x4x2xbf16, strided<[64, 8, 2, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>) + outs(%arg2 : memref<8x8xbf16>) { + ^bb0(%in: bf16, %in_9: bf16, %out: bf16): + %11 = arith.mulf %in, %in_9 : bf16 + %12 = arith.addf %out, %11 : bf16 + linalg.yield %12 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_strided @@ -379,20 +393,25 @@ func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], off #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d2)> -func.func @vnni_brgemm_require_transpose_on_C(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] - : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_brgemm_require_transpose_on_C(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [16, 32, 16, 2] + : memref<16x32x32xbf16> into memref<16x32x16x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<16x32x16x2xbf16>, memref<16x16x32x2xbf16>) + outs(%arg2 : memref<32x32xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // CHECK-LABEL: vnni_brgemm_require_transpose_on_C diff --git a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir index de4d88ad1..c4e1e9eca 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir @@ -149,22 +149,28 @@ func.func @mha_projection(%arg0: memref<512x8x64xf32>, %arg1: memref<64x32x512xf #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : 
memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +// Fix VNNI blocking factor for lit testing. +// Prevent mismatches due to target specific VNNI factors. +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: square_vnni_gemm @@ -179,20 +185,24 @@ func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ? #map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expanded_arg_vnni_gemm @@ -211,21 +221,26 @@ func.func @expanded_arg_vnni_gemm(%arg0: memref<64x32x2xbf16, strided<[64, 2, 1] #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Require a transpose on C, before mapping to vnni Gemm. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -239,21 +254,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> // Not VNNI layout. A factor of 5 is not VNNI. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160, 1], offset: ?>>, - %arg1: memref<32x64x5xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 5] - : memref<64x160xbf16, strided<[160, 1], offset: ?>> into memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>>, memref<32x64x5xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160, 1], offset: ?>>, + %arg1: memref<32x64x5xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 5] + : memref<64x160xbf16, strided<[160, 1], offset: ?>> into memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x5xbf16, strided<[160, 5, 1], offset: ?>>, memref<32x64x5xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -267,19 +287,24 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x160xbf16, strided<[160 #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Require a transpose on A, before mapping to vnni Gemm. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : memref<32x64x2xbf16, strided<[128, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -294,21 +319,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<32x64x2xbf16, strided<[12 // Make sure we can handle interchange on the iterators, but with the right // access patterns. -func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<32x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: vnni_gemm_interchanged @@ -325,21 +355,26 @@ func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, 
strided<[64, 1], off #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> // Not VNNI layout. The VNNI is not innermost in the access pattern for B. -func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<2x64x32xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] - : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<2x64x32xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<2x64x32xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 32, 2] + : memref<64x64xbf16, strided<[64, 1], offset: ?>> into memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x32x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<2x64x32xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: expect_not_to_match_vnni_gemm @@ -353,21 +388,26 @@ func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 8, 2] - : memref<64x16xbf16, strided<[64, 1], offset: ?>> into memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [64, 8, 2] + : memref<64x16xbf16, strided<[64, 1], offset: ?>> into memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + 
iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<64x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) + outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: non_square_vnni_gemm @@ -383,21 +423,26 @@ func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offse #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> -func.func @non_square_vnni_gemm_1(%arg0: memref<4x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 8, 2] - : memref<4x16xbf16, strided<[64, 1], offset: ?>> into memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @non_square_vnni_gemm_1(%arg0: memref<4x16xbf16, strided<[64, 1], offset: ?>>, + %arg1: memref<8x64x2xbf16>, %arg2: memref<4x64xbf16, strided<[64, 1], offset: ?>>) { + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 8, 2] + : memref<4x16xbf16, strided<[64, 1], offset: ?>> into memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<4x8x2xbf16, strided<[64, 2, 1], offset: ?>>, memref<8x64x2xbf16>) + outs(%arg2 : memref<4x64xbf16, strided<[64, 1], offset: ?>>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + return + } } // CHECK-LABEL: non_square_vnni_gemm_1 diff --git a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir b/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir index 217491ebe..0f880eea0 100644 --- a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir +++ b/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir @@ -295,14 +295,19 @@ func.func @identity_3(%arg0: memref<128x1xf32>, %arg1: memref<128x512xf32>) { // ----- -func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[512, 1], offset: ?>> - into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>, + %arg1: 
memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xbf16, strided<[512, 1], offset: ?>> + into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>) + outs(%arg1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] + return + } } // CHECK-LABEL: vnni_packing @@ -313,14 +318,19 @@ func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>> // ----- -func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xf32, strided<[512, 1], offset: ?>> - into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>, + %arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) { + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xf32, strided<[512, 1], offset: ?>> + into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>) + outs(%arg1 : memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] + return + } } // CHECK-LABEL: not_vnni_packing @@ -351,21 +361,26 @@ func.func @identity_4(%arg0: memref<1024xbf16>, %arg1: memref<128x1024xbf16>) { #map = affine_map<(d0) -> (d0 * 32)> -func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: memref<4x4x16x32x2xbf16>) { - scf.forall (%arg3, %arg4) in (4, 4) { - %0 = affine.apply #map(%arg4) - %1 = affine.apply #map(%arg3) - %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] - : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> - %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] - : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>) - outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) - permutation = [0, 2, 1] +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: memref<4x4x16x32x2xbf16>) { + scf.forall (%arg3, %arg4) in (4, 4) { + %0 = affine.apply #map(%arg4) + %1 = affine.apply #map(%arg3) + %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] + : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> + %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] + : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + %expand_shape = 
memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32] + : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> + linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>) + outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) + permutation = [0, 2, 1] + } + return } - return } // CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0 * 32)> diff --git a/test/Conversion/VectorToXsmm/vector-to-transpose.mlir b/test/Conversion/VectorToXsmm/vector-to-transpose.mlir index 57af7099c..99610d1b8 100644 --- a/test/Conversion/VectorToXsmm/vector-to-transpose.mlir +++ b/test/Conversion/VectorToXsmm/vector-to-transpose.mlir @@ -41,14 +41,19 @@ func.func @transpose_op_3d_f32(%arg0: memref<5x3x5xf32>, %arg1: memref<5x5x3xf32 // CHECK-NOT: call @xsmm_unary_invoke // ----- -func.func @vnni_packing_2d_bf16(%arg0: memref<32x32xbf16, strided<[512, 1], offset: ?>>, %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : bf16 - %c0 = arith.constant 0 : index - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[512, 1], offset: ?>> into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> - %0 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>, vector<16x2x32xbf16> - %1 = vector.transpose %0, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> - vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_2d_bf16(%arg0: memref<32x32xbf16, strided<[512, 1], offset: ?>>, %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[512, 1], offset: ?>> into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> + %0 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>, vector<16x2x32xbf16> + %1 = vector.transpose %0, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> + vector.transfer_write %1, %arg1[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + return + } } // CHECK-LABEL: func.func @vnni_packing_2d_bf16( @@ -87,20 +92,25 @@ func.func @not_vnni_packing_2d_f32(%arg0: memref<32x32xf32, strided<[512, 1], of // ----- #map = affine_map<(d0) -> (d0 * 32)> -func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memref<4x4x16x32x2xbf16>) { - %cst = arith.constant 0.000000e+00 : bf16 - %c0 = arith.constant 0 : index - scf.forall (%arg2, %arg3) in (4, 4) { - %0 = affine.apply #map(%arg3) - %1 = affine.apply #map(%arg2) - %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> - %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - %expand_shape = memref.expand_shape 
%subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> - %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>, vector<16x2x32xbf16> - %3 = vector.transpose %2, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> - vector.transfer_write %3, %subview_0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - } - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memref<4x4x16x32x2xbf16>) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + scf.forall (%arg2, %arg3) in (4, 4) { + %0 = affine.apply #map(%arg3) + %1 = affine.apply #map(%arg2) + %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> + %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape [16, 2, 32] : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> + %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>, vector<16x2x32xbf16> + %3 = vector.transpose %2, [0, 2, 1] : vector<16x2x32xbf16> to vector<16x32x2xbf16> + vector.transfer_write %3, %subview_0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<16x32x2xbf16>, memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> + } + return + } } // CHECK-LABEL: func.func @vnni_packing_2d_bf16_forall( @@ -126,4 +136,3 @@ func.func @vnni_packing_2d_bf16_forall(%arg0: memref<128x128xbf16>, %arg1: memre // CHECK-NEXT: %[[indexCast2:.*]] = arith.index_cast %[[intptr0]] // CHECK-NEXT: %[[inttoptr2:.*]] = llvm.inttoptr %[[indexCast2]] // CHECK: func.call @xsmm_unary_invoke(%[[c2_i64]], %[[dispatch]], %[[inttoptr]], %[[offset]], %[[inttoptr2]], %[[offset_1]]) - diff --git a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir b/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir index 353408670..eec22627b 100644 --- a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir +++ b/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir @@ -30,9 +30,14 @@ func.func @identity(%arg0: f32, %arg1: memref<1x1xf32>) { // ----- -func.func @gemm(%arg0: memref<3x6x2xbf16>, %arg1: memref<6x6xbf16>) { - %0 = xsmm.gemm.dispatch [6, 6, 6, 6, 6, 6] flags = (vnni_a) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x6x2xbf16>, memref<3x6x2xbf16>, memref<6x6xbf16>) -> () - return +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @gemm(%arg0: memref<3x6x2xbf16>, %arg1: memref<6x6xbf16>) { + %0 = xsmm.gemm.dispatch [6, 6, 6, 6, 6, 6] flags = (vnni_a) data_type = bf16 + xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : + (i64, memref<3x6x2xbf16>, memref<3x6x2xbf16>, memref<6x6xbf16>) -> () + return + } } diff --git a/test/Integration/hoist-vector-transfer-brgemm.mlir b/test/Integration/hoist-vector-transfer-brgemm.mlir index 37190b68c..3a1bab701 
100644 --- a/test/Integration/hoist-vector-transfer-brgemm.mlir +++ b/test/Integration/hoist-vector-transfer-brgemm.mlir @@ -1,6 +1,6 @@ // RUN: tpp-run -e entry --entry-point-result=void -print %s > %t.1 // RUN: tpp-opt %s --loop-invariant-code-motion --vectorization-pass --loop-invariant-code-motion --hoist-vector-transfer | tpp-run -e entry --entry-point-result=void -print > %t.2 -// RUN: diff %t.1 %t.2 +// RUN: diff -q %t.1 %t.2 // RUN: rm %t.1 %t.2 module { diff --git a/test/Integration/transpose-bf16.mlir b/test/Integration/transpose-bf16.mlir index d4f4472a1..862f055a1 100644 --- a/test/Integration/transpose-bf16.mlir +++ b/test/Integration/transpose-bf16.mlir @@ -3,13 +3,18 @@ // RUN: tpp-opt --default-tpp-passes="vector-to-xsmm" %s -mlir-print-ir-after=vectorization-pass 2>&1 | FileCheck %s --check-prefix=VECTOR // RUN: tpp-run --vector-to-XSMM %s -e entry -entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s --check-prefix=XSMM -func.func @entry(%arg0 : tensor<4x4xbf16>, %arg1 : tensor<2x4x2xbf16>)-> tensor<2x4x2xbf16> { - %expand_shape = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape[2, 2, 4] - : tensor<4x4xbf16> - into tensor<2x2x4xbf16> - %retval = linalg.transpose ins(%expand_shape : tensor<2x2x4xbf16>) - outs(%arg1 : tensor<2x4x2xbf16>) permutation = [0, 2, 1] - return %retval: tensor<2x4x2xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @entry(%arg0 : tensor<4x4xbf16>, %arg1 : tensor<2x4x2xbf16>)-> tensor<2x4x2xbf16> { + %expand_shape = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape[2, 2, 4] + : tensor<4x4xbf16> + into tensor<2x2x4xbf16> + %retval = linalg.transpose ins(%expand_shape : tensor<2x2x4xbf16>) + outs(%arg1 : tensor<2x4x2xbf16>) permutation = [0, 2, 1] + return %retval: tensor<2x4x2xbf16> + } } // VECTOR: vector.transfer_read diff --git a/test/Integration/vector-contract-to-fma.mlir b/test/Integration/vector-contract-to-fma.mlir index 4d03e8bb8..c13b340a8 100644 --- a/test/Integration/vector-contract-to-fma.mlir +++ b/test/Integration/vector-contract-to-fma.mlir @@ -1,9 +1,7 @@ // RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1 // RUN: tpp-opt %s --vector-contract-to-fma | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2 -// RUN: diff %t.1 %t.2 -// RUN: rm %t.1 %t.2 +// RUN: fpcmp -r 0.001 %t.1 %t.2 -// DIFF-NOT: {{.}} #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> diff --git a/test/Passes/DefaultPipeline/linalg.mlir b/test/Passes/DefaultPipeline/linalg.mlir index e32309ef1..0c7f18f78 100644 --- a/test/Passes/DefaultPipeline/linalg.mlir +++ b/test/Passes/DefaultPipeline/linalg.mlir @@ -207,36 +207,41 @@ func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: mem // CHECK-SAME: %[[ARG0:.+]]: memref<64x4x4xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<64x2x4x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<4x4xbf16> -func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, - %arg2: memref<4x4xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = 
memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [64, 4, 2, 2] - : memref<64x4x4xbf16> into memref<64x4x2x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<64x4x2x2xbf16>, memref<64x2x4x2xbf16>) - outs(%arg2 : memref<4x4xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, + %arg2: memref<4x4xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_brgemm_dispatch + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [64, 4, 2, 2] + : memref<64x4x4xbf16> into memref<64x4x2x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %arg1 : memref<64x4x2x2xbf16>, memref<64x2x4x2xbf16>) + outs(%arg2 : memref<4x4xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + return } - return } // ----- diff --git a/test/Passes/DefaultPipeline/vnni.mlir b/test/Passes/DefaultPipeline/vnni.mlir index ae54cde9e..7cebef489 100644 --- a/test/Passes/DefaultPipeline/vnni.mlir +++ b/test/Passes/DefaultPipeline/vnni.mlir @@ -8,39 +8,44 @@ // CHECK-SAME: %[[ARG0:.+]]: memref<128x1024xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<512x2048x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<128x2048xbf16>) -func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, - %arg1: tensor<512x2048x2xbf16>, - %arg2: tensor<128x2048xbf16>) -> tensor<128x2048xbf16> { - // CHECK: %[[of:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: 
%[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] - : tensor<128x1024xbf16> into tensor<128x512x2xbf16> - %result = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : tensor<128x512x2xbf16>, tensor<512x2048x2xbf16>) - outs(%arg2 : tensor<128x2048xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } -> tensor<128x2048xbf16> - - return %result : tensor<128x2048xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, + %arg1: tensor<512x2048x2xbf16>, + %arg2: tensor<128x2048xbf16>) -> tensor<128x2048xbf16> { + // CHECK: %[[of:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_gemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] + : tensor<128x1024xbf16> into tensor<128x512x2xbf16> + %result = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : tensor<128x512x2xbf16>, tensor<512x2048x2xbf16>) + outs(%arg2 : tensor<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } -> tensor<128x2048xbf16> + + return %result : tensor<128x2048xbf16> + } } // ----- @@ -53,38 +58,43 @@ func.func @matmul_tensor(%arg0: tensor<128x1024xbf16>, // CHECK-SAME: %[[ARG0:.+]]: memref<128x1024xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<512x2048x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<128x2048xbf16>) -func.func @matmul_memref(%arg0: memref<128x1024xbf16>, - %arg1: memref<512x2048x2xbf16>, - %arg2: memref<128x2048xbf16>) -> memref<128x2048xbf16> 
{ - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] - : memref<128x1024xbf16> into memref<128x512x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%expanded, %arg1 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) - outs(%arg2 : memref<128x2048xbf16>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @matmul_memref(%arg0: memref<128x1024xbf16>, + %arg1: memref<512x2048x2xbf16>, + %arg2: memref<128x2048xbf16>) -> memref<128x2048xbf16> { + // CHECK: call @xsmm_gemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 512, 2] + : memref<128x1024xbf16> into memref<128x512x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"]} + ins(%expanded, %arg1 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>) + outs(%arg2 : memref<128x2048xbf16>) { + ^bb0(%in: bf16, %in_2: bf16, %out: bf16): + %1 = arith.mulf %in, %in_2 : bf16 + %2 = arith.addf %out, %1 : bf16 + linalg.yield %2 : bf16 + } + + return %arg2 : memref<128x2048xbf16> } - - return %arg2 : memref<128x2048xbf16> } // ----- @@ -97,39 +107,44 @@ func.func @matmul_memref(%arg0: memref<128x1024xbf16>, // CHECK: %[[ARG0:.+]]: memref<4x256x512xbf16>, // CHECK: %[[ARG1:.+]]: memref<4x512x1024xbf16>, // CHECK: %[[ARG2:.+]]: memref<256x1024xbf16>) -func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x512x1024xbf16>, %arg2: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { - 
// CHECK: %[[alloc:.*]] = memref.alloc{{.*}}: memref<4x256x1024x2xbf16> - %0 = tensor.empty() : tensor<4x256x1024x2xbf16> - %1 = tensor.pack %arg1 inner_dims_pos = [1] inner_tiles = [2] into %0 : tensor<4x512x1024xbf16> -> tensor<4x256x1024x2xbf16> - - // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[alloc]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] - : tensor<4x256x512xbf16> into tensor<4x256x256x2xbf16> - %2 = linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %1 : tensor<4x256x256x2xbf16>, tensor<4x256x1024x2xbf16>) - outs(%arg2 : tensor<256x1024xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 - } -> tensor<256x1024xbf16> - - return %2 : tensor<256x1024xbf16> +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x512x1024xbf16>, %arg2: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { + // CHECK: %[[alloc:.*]] = memref.alloc{{.*}}: memref<4x256x1024x2xbf16> + %0 = tensor.empty() : tensor<4x256x1024x2xbf16> + %1 = tensor.pack %arg1 inner_dims_pos = [1] inner_tiles = [2] into %0 : tensor<4x512x1024xbf16> -> tensor<4x256x1024x2xbf16> + + // CHECK: call @xsmm_brgemm_dispatch + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[alloc]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] + : tensor<4x256x512xbf16> into tensor<4x256x256x2xbf16> + %2 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + ins(%expanded, %1 : tensor<4x256x256x2xbf16>, tensor<4x256x1024x2xbf16>) + outs(%arg2 : 
tensor<256x1024xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } -> tensor<256x1024xbf16> + + return %2 : tensor<256x1024xbf16> + } } // ----- @@ -142,34 +157,39 @@ func.func @brgemm_static_tensor(%arg0: tensor<4x256x512xbf16>, %arg1: tensor<4x5 // CHECK: %[[ARG0:.+]]: memref<4x256x512xbf16>, // CHECK: %[[ARG1:.+]]: memref<4x256x1024x2xbf16>, // CHECK: %[[ARG2:.+]]: memref<256x1024xbf16>) -func.func @brgemm_static_memref(%arg0: memref<4x256x512xbf16>, %arg1: memref<4x256x1024x2xbf16>, %arg2: memref<256x1024xbf16>) -> memref<256x1024xbf16> { - // CHECK: call @xsmm_brgemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] - %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] - : memref<4x256x512xbf16> into memref<4x256x256x2xbf16> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%expanded, %arg1 : memref<4x256x256x2xbf16>, memref<4x256x1024x2xbf16>) - outs(%arg2 : memref<256x1024xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_static_memref(%arg0: memref<4x256x512xbf16>, %arg1: memref<4x256x1024x2xbf16>, %arg2: memref<256x1024xbf16>) -> memref<256x1024xbf16> { + // CHECK: call @xsmm_brgemm_dispatch + + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[of]], %[[llvm_ptr1]], %[[of]], %[[llvm_ptr2]], %[[of]] + %expanded = memref.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [4, 256, 256, 2] + : memref<4x256x512xbf16> into memref<4x256x256x2xbf16> + linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} + 
ins(%expanded, %arg1 : memref<4x256x256x2xbf16>, memref<4x256x1024x2xbf16>) + outs(%arg2 : memref<256x1024xbf16>) { + ^bb0(%in: bf16, %in_5: bf16, %out: bf16): + %5 = arith.mulf %in, %in_5 : bf16 + %6 = arith.addf %out, %5 : bf16 + linalg.yield %6 : bf16 + } + + return %arg2 : memref<256x1024xbf16> } - - return %arg2 : memref<256x1024xbf16> } diff --git a/test/Passes/DefaultPipeline/xsmm.mlir b/test/Passes/DefaultPipeline/xsmm.mlir index bee500e22..76fbb1976 100644 --- a/test/Passes/DefaultPipeline/xsmm.mlir +++ b/test/Passes/DefaultPipeline/xsmm.mlir @@ -220,30 +220,35 @@ func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: mem // CHECK-SAME: %[[ARG0:.+]]: memref<64x4x4xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<64x2x4x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<4x4xbf16> -func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, - %arg2: memref<4x4xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, + %arg2: memref<4x4xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_brgemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x4x4xbf16> -> index + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<64x2x4x2xbf16> -> index - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<64x2x4x2xbf16> -> index + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<4x4xbf16> -> index + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %c64_i64 = arith.constant 64 : i64 - %0 = xsmm.brgemm.dispatch [4, 4, 4, 4, 4, 4, 16, 16] flags = (vnni_b) data_type = bf16 - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg1, %arg2, %c64_i64) - : (i64, memref<64x4x4xbf16>, memref<64x2x4x2xbf16>, memref<4x4xbf16>, i64) -> () + // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %c64_i64 = arith.constant 64 : i64 + %0 = xsmm.brgemm.dispatch [4, 4, 4, 4, 
4, 4, 16, 16] flags = (vnni_b) data_type = bf16 + xsmm.brgemm(data_type = bf16, %0, %arg0, %arg1, %arg2, %c64_i64) + : (i64, memref<64x4x4xbf16>, memref<64x2x4x2xbf16>, memref<4x4xbf16>, i64) -> () - return + return + } } // ----- @@ -282,28 +287,33 @@ func.func @gemm(%A: memref<4x8xf32>, // CHECK-SAME: %[[ARG0:.+]]: memref<6x10xbf16>, // CHECK-SAME: %[[ARG1:.+]]: memref<5x6x2xbf16>, // CHECK-SAME: %[[ARG2:.+]]: memref<6x6xbf16> -func.func @gemm_bf16(%arg0: memref<6x10xbf16>, %arg1: memref<5x6x2xbf16>, - %arg2: memref<6x6xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @gemm_bf16(%arg0: memref<6x10xbf16>, %arg1: memref<5x6x2xbf16>, + %arg2: memref<6x6xbf16>) { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: call @xsmm_gemm_dispatch - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr + // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] + // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr + // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] + // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr + // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] + // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 + // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %0 = xsmm.gemm.dispatch [6, 6, 10, 10, 6, 6] flags = (vnni_b) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg1, %arg2) : (i64, memref<6x10xbf16>, memref<5x6x2xbf16>, memref<6x6xbf16>) -> () + // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] + %0 = xsmm.gemm.dispatch [6, 6, 10, 10, 6, 6] flags = (vnni_b) data_type = bf16 + xsmm.gemm(data_type = bf16, %0, %arg0, %arg1, %arg2) : (i64, memref<6x10xbf16>, memref<5x6x2xbf16>, memref<6x6xbf16>) -> () - return + return + } } // ----- diff --git a/test/Passes/pack-vnni.mlir b/test/Passes/pack-vnni.mlir new file mode 100644 index 000000000..e30050712 --- /dev/null +++ b/test/Passes/pack-vnni.mlir @@ -0,0 +1,104 @@ +// RUN: tpp-opt -pack-vnni -split-input-file %s | FileCheck %s + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + func.func @brgemm_vnni_2(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) 
-> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> + +// CHECK-LABEL: @brgemm_vnni_2( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: output_shape{{.*}}: tensor<5x32x64xbf16> into tensor<5x32x32x2xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] +// CHECK-SAME: : tensor<5x64x32xbf16> -> tensor<5x32x32x2xbf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[VNNI_A]], %[[PACK]] +// CHECK-SAME: outs(%[[ARG2]] + +// ----- + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 4 : i32>> +} { + func.func @brgemm_vnni_4(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> + +// CHECK-LABEL: @brgemm_vnni_4( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: output_shape{{.*}}: tensor<5x32x64xbf16> into tensor<5x32x16x4xbf16> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] +// CHECK-SAME: : tensor<5x64x32xbf16> -> tensor<5x16x32x4xbf16> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[VNNI_A]], %[[PACK]] +// CHECK-SAME: outs(%[[ARG2]] + +// ----- + +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 0 : i32>> +} { + func.func @invalid_vnni_factor_0(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK-LABEL: @invalid_vnni_factor_0( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK-NOT: linalg.generic +// CHECK: linalg.batch_reduce_matmul + +// ----- + +// Blocking factor is expected to be divisible by 2. 
+module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 5 : i32>> +} { + func.func @invalid_vnni_factor_5(%arg0: tensor<5x32x64xbf16>, %arg1: tensor<5x64x32xbf16>, + %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16>{ + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1: tensor<5x32x64xbf16>, tensor<5x64x32xbf16>) + outs(%arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> + return %0: tensor<32x32xbf16> + } +} + +// CHECK-LABEL: @invalid_vnni_factor_5( +// CHECK-SAME: %[[ARG0:.+]]: tensor<5x32x64xbf16>, %[[ARG1:.+]]: tensor<5x64x32xbf16>, +// CHECK-SAME: %[[ARG2:.+]]: tensor<32x32xbf16> +// CHECK-NOT: linalg.generic +// CHECK: linalg.batch_reduce_matmul diff --git a/test/Passes/xsmm-combine.mlir b/test/Passes/xsmm-combine.mlir index 8edf7a2bc..e8d6ad5c5 100644 --- a/test/Passes/xsmm-combine.mlir +++ b/test/Passes/xsmm-combine.mlir @@ -133,39 +133,44 @@ func.func @none_on_binary_add(%arg0: memref<256x128xf32>) -> memref<256x512xf32> // ----- -memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} +module attributes { + "#dlti.sys_spec" = #dlti.target_system_spec<"CPU" + = #dlti.target_device_spec<"vnni" = 2 : i32>> +} { + memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} + memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} -// Bcast_col_in0 flag set on binary add -func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : bf16 - %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> - %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> - %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to 
memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = bf16, %5, %2, %subview, %subview) : (i64, memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () - scf.reduce + // Bcast_col_in0 flag set on binary add + func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %cst = arith.constant 0.000000e+00 : bf16 + %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> + %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> + %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16> + %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16 + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16> + %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 + %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16 + %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16> + scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { + %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> + %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () + xsmm.binary add(data_type = bf16, %5, %2, %subview, %subview) : (i64, memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> () + scf.reduce + } + return %alloc_1 : memref<256x512xbf16> } - return %alloc_1 : memref<256x512xbf16> } // CHECK-LABEL: func.func @bcast_col_in0_on_binary_add_bf16( @@ -176,39 +181,44 @@ func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memr // ----- -memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xbf16: 
memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64}
+module attributes {
+  "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
+    = #dlti.target_device_spec<"vnni" = 2 : i32>>
+} {
+  memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
+  memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
+  memref.global "private" constant @__constant_32xbf16: memref<32xbf16, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64}
 
-// Bcast_col_in1 flag set on binary add
-func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
-  %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
-  %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = bf16
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
-  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
+  // Bcast_col_in1 flag set on binary add
+  func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c4_i64 = arith.constant 4 : i64
+    %c8_i64 = arith.constant 8 : i64
+    %cst = arith.constant 0.000000e+00 : bf16
+    %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
+    %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
+    %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset:?>>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
+    %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
+    %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
+    %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = bf16
+    %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+      %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+      xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32xbf16, strided<[32], offset:?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
+      xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
+      scf.reduce
+    }
+    return %alloc_1 : memref<256x512xbf16>
   }
-  return %alloc_1 : memref<256x512xbf16>
 }
 
 // CHECK-LABEL: func.func @bcast_col_in1_on_binary_add_bf16(
@@ -220,39 +230,44 @@ func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memr
 
 // -----
 
-memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_32x32xbf16: memref<32x32xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
+module attributes {
+  "#dlti.sys_spec" = #dlti.target_system_spec<"CPU"
+    = #dlti.target_device_spec<"vnni" = 2 : i32>>
+} {
+  memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
+  memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
+  memref.global "private" constant @__constant_32x32xbf16: memref<32x32xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
 
-// None flag set on binary add
-func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
-  %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
-  %2 = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = bf16
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
-  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
+  // None flag set on binary add
+  func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c4_i64 = arith.constant 4 : i64
+    %c8_i64 = arith.constant 8 : i64
+    %cst = arith.constant 0.000000e+00 : bf16
+    %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
+    %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
+    %2 = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
+    %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
+    %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
+    %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = bf16
+    %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
+    scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
+      %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
+      xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
+      xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64 , memref<32x32xbf16, strided<[32, 1], offset: ?>>,memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
+      xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
+      scf.reduce
+    }
+    return %alloc_1 : memref<256x512xbf16>
   }
-  return %alloc_1 : memref<256x512xbf16>
 }
 
 // CHECK-LABEL: func.func @none_on_binary_add_bf16(
@@ -305,4 +320,3 @@ func.func @forward(%arg0: memref<256x1024xf32>) -> memref<256x1024xf32> {
 // CHECK: xsmm.fused_brgemm(data_type = f32, %[[temp2]], %[[subview_2]], %{{.*}}, %[[subview]], %{{.*}} %[[c32_i64]]) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()
 // CHECK: }
 // CHECK: return %{{.*}} : memref<256x1024xf32>
-