Skip to content

Commit

Permalink
Make VNNI more robust (#1001)
Browse files Browse the repository at this point in the history
Improves VNNI validation and lit tests to be target independent w.r.t.
VNNI blocking factor.

VNNI infrastructure is extended with DLTI support that allows overriding
the target's default VNNI factor.
This allows fixing the VNNI shape for lit testing to ensure that the
general pass logic works fine. This is primarily applied to tests that use
prepacked VNNI shapes.

When possible, tests' checks are generalized to account for varying
packing factor.
Test shapes are adjusted to prevent failure due to incompatible VNNI
packing sizes compared to dimension size used for verification.

BF16 integration tests are now disabled on non-x86 systems to avoid
runtime mismatch between prepacked shapes and microkernel requirements.
  • Loading branch information
adam-smnk authored Jan 17, 2025
1 parent 7e6b24d commit 815ce3f
Show file tree
Hide file tree
Showing 32 changed files with 1,012 additions and 674 deletions.
30 changes: 30 additions & 0 deletions include/TPP/Transforms/Utils/DLTIUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- DLTIUtils.h -----------------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Convenience helpers for querying DLTI (Data Layout and Target Info)
// attributes using plain string keys instead of pre-built attribute keys.
//
//===----------------------------------------------------------------------===//

#ifndef TPP_TRANSFORMS_UTILS_DLTIUTILS_H
#define TPP_TRANSFORMS_UTILS_DLTIUTILS_H

#include "mlir/Dialect/DLTI/DLTI.h"

namespace llvm {
class StringRef;
} // namespace llvm

namespace mlir {
namespace dlti {
namespace utils {

// Perform a DLTI-query using string keys.
// Each string in `keys` is converted to a StringAttr and the resulting key
// chain is resolved against the DLTI specification visible from `op`.
// Returns the resolved attribute, or failure if `op` is null or the lookup
// fails. If `emitError` is true, a diagnostic is emitted on lookup failure.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringRef> keys,
                           bool emitError = false);

} // namespace utils
} // namespace dlti
} // namespace mlir

#endif // TPP_TRANSFORMS_UTILS_DLTIUTILS_H
27 changes: 16 additions & 11 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

namespace mlir {
class Type;
class MemRefType;
class ShapedType;
class OpOperand;
class AffineDimExpr;
class AffineMap;
class VectorType;
class Operation;

namespace linalg {
class LinalgOp;
Expand All @@ -35,21 +36,25 @@ enum class VnniOperandRank {
BRGEMM_OUTS = 3
};

// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8.
std::optional<int64_t> getVnniBlockingFactor(Type type);
// Return the VNNI blocking factor if it can be determined for the given
// type, or zero otherwise.
// Optionally, an operation can be provided to give access to DLTI.
unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr);

// Return true if the memref is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref);

// Return true if the vector is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector);
// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
std::optional<unsigned> blockingFactor = std::nullopt);

bool isInVnniLayout(int64_t expectedRank, VectorType vector);
// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
std::optional<unsigned> blockingFactor = std::nullopt);

// Return true if the operation is in VNNI layout.
// Return true if the linalg operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);
std::optional<unsigned> blockingFactor = std::nullopt);

} // namespace utils
} // namespace vnni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,10 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
if (failed(stridesOnOutput) || stridesOnOutput->back() != 1)
return failure();
// Adjust ldo based on the VNNI factor.
unaryInfo.ldo = stridesOnOutput->front() /
*vnni::utils::getVnniBlockingFactor(out.getType());
auto vnniFactor =
vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp);
assert(vnniFactor && "Failed to get VNNI blocking factor");
unaryInfo.ldo = stridesOnOutput->front() / vnniFactor;
auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
rewriter.getContext(), xsmm::UnaryFlags::NONE));
xsmm::UnaryKindAttr kind =
Expand Down Expand Up @@ -1112,7 +1114,7 @@ struct ConvertGenericToVnniMatmulLikeOp
// Take the whole reduction dim size. Account for the VNNI factor (ensured
// by the earlier check) that splits the K dim in the shape.
std::optional<int64_t> vnniFactor =
vnni::utils::getVnniBlockingFactor(bufferB.getType());
vnni::utils::getVnniBlockingFactor(bufferB.getType(), genericOp);
if (!vnniFactor)
return rewriter.notifyMatchFailure(genericOp,
"failed to determine VNNI factor");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE,
outType)) {
// Adjust ldo based on vnni factor
auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType);
auto vnniFactor = vnni::utils::getVnniBlockingFactor(outType, transposeOp);
assert(vnniFactor && "Failed to get VNNI blocking factor");
unaryInfo.ldo = unaryInfo.ldo / vnniFactor;
} else {
std::swap(unaryInfo.m, unaryInfo.n);
Expand Down
10 changes: 6 additions & 4 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,12 @@ LogicalResult GemmOp::verify() {
auto memref = dyn_cast<MemRefType>(memrefOperands[idx].getType());
assert(memref && (memref.getRank() == 2 || memref.getRank() == 3));

if (memref.getRank() == 3 &&
!vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
memref)) {
return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
if (memref.getRank() == 3) {
if (memref.getShape().back() % 2 != 0 ||
!vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
memref)) {
return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
}
}
}
return success();
Expand Down
10 changes: 7 additions & 3 deletions lib/TPP/Dialect/Xsmm/XsmmVerify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,25 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) {
: vnni::utils::VnniOperandRank::GEMM;

// VNNI flags must be consistent with the memref shapes.
auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp);

ArrayAttr flags = dispatchOp->getFlags();
for (auto flag : flags) {
int64_t gemmFlag = cast<IntegerAttr>(flag).getInt();
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_A) &&
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) {
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA,
vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand A or invalid VNNI_A flags");
}
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_B) &&
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) {
!vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB,
vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand B or invalid VNNI_B flags");
}
if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_C) &&
!vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) {
!vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) {
return gemmOp.emitOpError(
"expect VNNI layout for operand C or invalid VNNI_C flags");
}
Expand Down
6 changes: 3 additions & 3 deletions lib/TPP/IR/MatcherUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap,
std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
SmallVectorImpl<Value> *operands) {
bool hasBatch = false;
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(linalgOp->getOperands()[0].getType());
auto blockingFactor = vnni::utils::getVnniBlockingFactor(
linalgOp->getOperands()[0].getType(), linalgOp);
if (!blockingFactor)
return std::make_pair(false, hasBatch);

Expand Down Expand Up @@ -115,7 +115,7 @@ std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,

// At this point, the operation is a valid matmul contraction.
// Finally, ensure that it is in VNNI layout.
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp);
return std::make_pair(isVnniMatmul, hasBatch);
}

Expand Down
25 changes: 4 additions & 21 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/DLTIUtils.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand Down Expand Up @@ -457,28 +458,10 @@ static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) {
int64_t tile = 32;

// Check if a tile size hint is associated to the IR via DLTI.
auto deriveFromDLTI = [&](ModuleOp moduleOp) {
if (!moduleOp)
return;
TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
if (!sysSpec)
return;
auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU");
auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
if (!deviceSpec)
return;
auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size");
DataLayoutEntryInterface entry =
(*deviceSpec).getSpecForIdentifier(tileSizeId);
if (!entry)
return;
Attribute value = entry.getValue();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
auto tileValue = dlti::utils::query(linalgOp, {"CPU", "tile_size"});
if (succeeded(tileValue))
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*tileValue))
tile = intAttr.getInt();
// TODO: might want to print a warning if tile_size exists as a key but the
// associated attribute has an unexpected type.
};
deriveFromDLTI(linalgOp->getParentOfType<mlir::ModuleOp>());

SmallVector<int64_t, 4> loopsRange = linalgOp.getStaticLoopRanges();
if (loopsRange[dim] == ShapedType::kDynamic)
Expand Down
12 changes: 7 additions & 5 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,19 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter,

OpOperand &operandB = matmulOp->getOpOperand(1);
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(operandB.get().getType());
vnni::utils::getVnniBlockingFactor(operandB.get().getType(), matmulOp);
if (!blockingFactor) {
return rewriter.notifyMatchFailure(matmulOp,
"unsupported blocking factor for type");
}

if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) {
if (vnni::utils::isInVnniLayout(matmulOp)) {
return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI");
}

Location loc = matmulOp.getLoc();
SmallVector<OpFoldResult> tilesOnSmallK = {
rewriter.getI64IntegerAttr(*blockingFactor)};
rewriter.getI64IntegerAttr(blockingFactor)};
SmallVector<std::pair<Value, unsigned>> kOperands;
matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands);
if (kOperands.size() != 2)
Expand Down Expand Up @@ -409,12 +409,14 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,

Value operandB = brgemmOp.getInputs()[1];
// Blocking factor on the `k` dimension.
auto blockingFactor = vnni::utils::getVnniBlockingFactor(operandB.getType());
auto blockingFactor =
vnni::utils::getVnniBlockingFactor(operandB.getType(), brgemmOp);
if (!blockingFactor) {
return rewriter.notifyMatchFailure(brgemmOp,
"unsupported blocking factor for type");
}
SmallVector<OpFoldResult> tilesOnK = {rewriter.getI64IntegerAttr(2)};
SmallVector<OpFoldResult> tilesOnK = {
rewriter.getI64IntegerAttr(blockingFactor)};

Location loc = brgemmOp.getLoc();
// Reshape input A.
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_library(TPPTransformsUtils
BuilderUtils.cpp
DLTIUtils.cpp
TensorInit.cpp
TensorInitFloat.cpp
TensorInitInt.cpp
Expand Down
32 changes: 32 additions & 0 deletions lib/TPP/Transforms/Utils/DLTIUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- DLTIUtils.cpp ---------------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/DLTIUtils.h"

namespace mlir {
namespace dlti {
namespace utils {

// Resolve a DLTI entry reachable from `op` using plain string keys.
// Returns failure when no operation is supplied or the lookup fails; when
// `emitError` is set, a diagnostic is emitted on a failed lookup.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringRef> keys,
                           bool emitError) {
  // Nothing to anchor the DLTI lookup to.
  if (!op)
    return failure();

  // Convert each string key into the attribute form the DLTI query expects.
  MLIRContext *ctx = op->getContext();
  SmallVector<DataLayoutEntryKey> entryKeys;
  entryKeys.reserve(keys.size());
  for (StringRef key : keys)
    entryKeys.push_back(StringAttr::get(ctx, key));

  return dlti::query(op, entryKeys, emitError);
}

} // namespace utils
} // namespace dlti
} // namespace mlir
Loading

0 comments on commit 815ce3f

Please sign in to comment.