Make VNNI more robust #1001

Merged · merged 26 commits · Jan 17, 2025

Changes from all commits
30 changes: 30 additions & 0 deletions include/TPP/Transforms/Utils/DLTIUtils.h
@@ -0,0 +1,30 @@
//===- DLTIUtils.h ---------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TPP_TRANSFORMS_UTILS_DLTIUTILS_H
#define TPP_TRANSFORMS_UTILS_DLTIUTILS_H

#include "mlir/Dialect/DLTI/DLTI.h"

namespace llvm {
class StringRef;
} // namespace llvm

namespace mlir {
namespace dlti {
namespace utils {

// Perform a DLTI query using string keys.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringRef> keys,
                           bool emitError = false);

} // namespace utils
} // namespace dlti
} // namespace mlir

#endif // TPP_TRANSFORMS_UTILS_DLTIUTILS_H
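For context, a minimal usage sketch of the new helper, mirroring the call site in TileConsumerAndFuseProducers.cpp further down; the "CPU"/"tile_size" keys come from that diff, while `op` and the fallback of 32 are illustrative assumptions:

// Sketch: read an integer "tile_size" hint for device "CPU" via DLTI,
// keeping a default when the enclosing module carries no such entry.
int64_t tile = 32; // assumed fallback
auto tileValue = mlir::dlti::utils::query(op, {"CPU", "tile_size"});
if (succeeded(tileValue))
  if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(*tileValue))
    tile = intAttr.getInt();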
27 changes: 16 additions & 11 deletions include/TPP/Transforms/Utils/VNNIUtils.h
@@ -15,11 +15,12 @@

 namespace mlir {
 class Type;
-class MemRefType;
+class ShapedType;
 class OpOperand;
 class AffineDimExpr;
 class AffineMap;
 class VectorType;
+class Operation;
 
 namespace linalg {
 class LinalgOp;
@@ -35,21 +36,25 @@ enum class VnniOperandRank {
   BRGEMM_OUTS = 3
 };
 
-// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8.
-std::optional<int64_t> getVnniBlockingFactor(Type type);
+// Return the VNNI blocking factor if it can be determined for the given type,
+// or zero otherwise.
+// Optionally, an operation can be provided to give access to DLTI.
+unsigned getVnniBlockingFactor(Type type, Operation *op = nullptr);
 
-// Return true if the memref is in VNNI layout with rank `expectedRank`.
-bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref);
-
-// Return true if the vector is in VNNI layout with rank `expectedRank`.
-bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector);
+// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
+                    std::optional<unsigned> blockingFactor = std::nullopt);
 
-bool isInVnniLayout(int64_t expectedRank, VectorType vector);
+// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
+                    std::optional<unsigned> blockingFactor = std::nullopt);
 
-// Return true if the operation is in VNNI layout.
+// Return true if the linalg operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
 bool isInVnniLayout(linalg::LinalgOp linalgOp,
-                    std::optional<int64_t> blockingFactor = std::nullopt);
+                    std::optional<unsigned> blockingFactor = std::nullopt);
 
 } // namespace utils
 } // namespace vnni
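As background for these overloads (an illustration, not part of the diff): VNNI packing splits the reduction dimension K by the blocking factor and makes that factor the innermost dimension. With the bf16 factor of 2 from the old comment, a flat [K, N] operand memref<64x32xbf16> becomes the rank-3 form [K/2, N, 2], i.e. memref<32x32x2xbf16>. A sketch of the constrained check, where `shapedType` stands in for such an operand:

// Accept the operand only if it is rank-3 VNNI with innermost dim == 2.
bool ok = vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
                                      shapedType, /*blockingFactor=*/2);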
@@ -1068,8 +1068,10 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
     if (failed(stridesOnOutput) || stridesOnOutput->back() != 1)
       return failure();
     // Adjust ldo based on the VNNI factor.
-    unaryInfo.ldo = stridesOnOutput->front() /
-                    *vnni::utils::getVnniBlockingFactor(out.getType());
+    auto vnniFactor =
+        vnni::utils::getVnniBlockingFactor(out.getType(), transposeOp);
+    assert(vnniFactor && "Failed to get VNNI blocking factor");
+    unaryInfo.ldo = stridesOnOutput->front() / vnniFactor;
     auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
         rewriter.getContext(), xsmm::UnaryFlags::NONE));
     xsmm::UnaryKindAttr kind =
@@ -1112,7 +1114,7 @@ struct ConvertGenericToVnniMatmulLikeOp
     // Take the whole reduction dim size. Account for the VNNI factor (ensured
     // by the earlier check) that splits the K dim in the shape.
-    std::optional<int64_t> vnniFactor =
-        vnni::utils::getVnniBlockingFactor(bufferB.getType());
+    unsigned vnniFactor =
+        vnni::utils::getVnniBlockingFactor(bufferB.getType(), genericOp);
     if (!vnniFactor)
       return rewriter.notifyMatchFailure(genericOp,
                                          "failed to determine VNNI factor");
@@ -99,7 +99,8 @@ convertTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
   if (vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE,
                                   outType)) {
     // Adjust ldo based on vnni factor
-    auto vnniFactor = *vnni::utils::getVnniBlockingFactor(outType);
+    auto vnniFactor = vnni::utils::getVnniBlockingFactor(outType, transposeOp);
+    assert(vnniFactor && "Failed to get VNNI blocking factor");
     unaryInfo.ldo = unaryInfo.ldo / vnniFactor;
   } else {
     std::swap(unaryInfo.m, unaryInfo.n);
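A worked example of the ldo adjustment, with assumed values (not from the diff):

// For a VNNI output [K/2, N, 2] = memref<32x64x2xbf16>, the row stride is
// 64 * 2 = 128 elements, while XSMM counts the leading dimension in
// non-VNNI columns:
unsigned vnniFactor = 2;            // bf16 blocking factor
int64_t ldo = 128 / vnniFactor;     // = 64, the N extent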
10 changes: 6 additions & 4 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
@@ -455,10 +455,12 @@ LogicalResult GemmOp::verify() {
     auto memref = dyn_cast<MemRefType>(memrefOperands[idx].getType());
     assert(memref && (memref.getRank() == 2 || memref.getRank() == 3));
 
-    if (memref.getRank() == 3 &&
-        !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
-                                     memref)) {
-      return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
+    if (memref.getRank() == 3) {
+      if (memref.getShape().back() % 2 != 0 ||
+          !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
+                                       memref)) {
+        return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
+      }
     }
   }
   return success();
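The rationale for the added evenness test, with assumed example types: every supported blocking factor (2 for BF16, 4 for BF8, per the old comment in VNNIUtils.h) is even, so an odd innermost dimension can be rejected before the layout check runs:

// memref<32x64x2xbf16>: innermost dim 2 is even -> full VNNI check decides.
// memref<32x64x3xbf16>: innermost dim 3 is odd  -> rejected immediately.
bool cannotBeVnni = memref.getShape().back() % 2 != 0;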
10 changes: 7 additions & 3 deletions lib/TPP/Dialect/Xsmm/XsmmVerify.cpp
@@ -71,21 +71,25 @@ static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) {
                               : vnni::utils::VnniOperandRank::GEMM;
 
   // VNNI flags must be consistent with the memref shapes.
+  auto vnniFactor = vnni::utils::getVnniBlockingFactor(operandA, gemmOp);
+
   ArrayAttr flags = dispatchOp->getFlags();
   for (auto flag : flags) {
     int64_t gemmFlag = cast<IntegerAttr>(flag).getInt();
     if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_A) &&
-        !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) {
+        !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA,
+                                     vnniFactor)) {
       return gemmOp.emitOpError(
           "expect VNNI layout for operand A or invalid VNNI_A flags");
     }
     if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_B) &&
-        !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) {
+        !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB,
+                                     vnniFactor)) {
       return gemmOp.emitOpError(
           "expect VNNI layout for operand B or invalid VNNI_B flags");
     }
     if (gemmFlag == static_cast<int64_t>(xsmm::GemmFlags::VNNI_C) &&
-        !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) {
+        !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC, vnniFactor)) {
       return gemmOp.emitOpError(
           "expect VNNI layout for operand C or invalid VNNI_C flags");
     }
6 changes: 3 additions & 3 deletions lib/TPP/IR/MatcherUtils.cpp
@@ -40,8 +40,8 @@ getIteratorPos(linalg::LinalgOp linalgOp, AffineMap indexingMap,
 std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
                                      SmallVectorImpl<Value> *operands) {
   bool hasBatch = false;
-  auto blockingFactor =
-      vnni::utils::getVnniBlockingFactor(linalgOp->getOperands()[0].getType());
+  auto blockingFactor = vnni::utils::getVnniBlockingFactor(
+      linalgOp->getOperands()[0].getType(), linalgOp);
   if (!blockingFactor)
     return std::make_pair(false, hasBatch);
 
@@ -115,7 +115,7 @@ std::pair<bool, bool> isMatmulVnniOp(linalg::GenericOp linalgOp,
 
   // At this point, the operation is a valid matmul contraction.
   // Finally, ensure that it is in VNNI layout.
-  bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp, *blockingFactor);
+  bool isVnniMatmul = vnni::utils::isInVnniLayout(linalgOp);
   return std::make_pair(isVnniMatmul, hasBatch);
 }

25 changes: 4 additions & 21 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
@@ -8,6 +8,7 @@

#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/DLTIUtils.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -457,28 +458,10 @@ static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) {
   int64_t tile = 32;
 
   // Check if a tile size hint is associated to the IR via DLTI.
-  auto deriveFromDLTI = [&](ModuleOp moduleOp) {
-    if (!moduleOp)
-      return;
-    TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
-    if (!sysSpec)
-      return;
-    auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU");
-    auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
-    if (!deviceSpec)
-      return;
-    auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size");
-    DataLayoutEntryInterface entry =
-        (*deviceSpec).getSpecForIdentifier(tileSizeId);
-    if (!entry)
-      return;
-    Attribute value = entry.getValue();
-    if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
+  auto tileValue = dlti::utils::query(linalgOp, {"CPU", "tile_size"});
+  if (succeeded(tileValue))
+    if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*tileValue))
       tile = intAttr.getInt();
-    // TODO: might want to print a warning if tile_size exists as a key but the
-    // associated attribute has an unexpected type.
-  };
-  deriveFromDLTI(linalgOp->getParentOfType<mlir::ModuleOp>());
 
   SmallVector<int64_t, 4> loopsRange = linalgOp.getStaticLoopRanges();
   if (loopsRange[dim] == ShapedType::kDynamic)
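For reference, the kind of module-level hint this query picks up; the values are assumptions and the DLTI attribute syntax varies across MLIR versions, so treat this as a sketch:

// module attributes {dlti.target_system_spec = #dlti.target_system_spec<
//     "CPU" : #dlti.target_device_spec<
//         #dlti.dl_entry<"tile_size", 16 : i32>>>} {
//   ...
// }
// With such an entry, getTileForDim() uses 16 instead of the default of 32.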
12 changes: 7 additions & 5 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -333,19 +333,19 @@ mlir::linalgx::packVNNIMatmulOp(RewriterBase &rewriter,

   OpOperand &operandB = matmulOp->getOpOperand(1);
   auto blockingFactor =
-      vnni::utils::getVnniBlockingFactor(operandB.get().getType());
+      vnni::utils::getVnniBlockingFactor(operandB.get().getType(), matmulOp);
   if (!blockingFactor) {
     return rewriter.notifyMatchFailure(matmulOp,
                                        "unsupported blocking factor for type");
   }
 
-  if (vnni::utils::isInVnniLayout(matmulOp, *blockingFactor)) {
+  if (vnni::utils::isInVnniLayout(matmulOp)) {
     return rewriter.notifyMatchFailure(matmulOp, "already packed to VNNI");
   }
 
   Location loc = matmulOp.getLoc();
   SmallVector<OpFoldResult> tilesOnSmallK = {
-      rewriter.getI64IntegerAttr(*blockingFactor)};
+      rewriter.getI64IntegerAttr(blockingFactor)};
   SmallVector<std::pair<Value, unsigned>> kOperands;
   matmulOp.mapIterationSpaceDimToAllOperandDims(dims->k.back(), kOperands);
   if (kOperands.size() != 2)
@@ -409,12 +409,14 @@ mlir::linalgx::packVNNIBRGemmOp(RewriterBase &rewriter,
 
   Value operandB = brgemmOp.getInputs()[1];
   // Blocking factor on the `k` dimension.
-  auto blockingFactor = vnni::utils::getVnniBlockingFactor(operandB.getType());
+  auto blockingFactor =
+      vnni::utils::getVnniBlockingFactor(operandB.getType(), brgemmOp);
   if (!blockingFactor) {
     return rewriter.notifyMatchFailure(brgemmOp,
                                        "unsupported blocking factor for type");
   }
-  SmallVector<OpFoldResult> tilesOnK = {rewriter.getI64IntegerAttr(2)};
+  SmallVector<OpFoldResult> tilesOnK = {
+      rewriter.getI64IntegerAttr(blockingFactor)};
 
   Location loc = brgemmOp.getLoc();
   // Reshape input A.
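The practical effect of replacing the hard-coded 2, sketched with assumed shapes: the K tile now follows the element type's blocking factor, so 4-wide element types pack correctly as well.

// factor 2 (bf16): B operand [B, K, N] = tensor<8x64x32xbf16>
//   packs to [B, K/2, N, 2] = tensor<8x32x32x2xbf16>
// factor 4 (the "BF8" case from the old comment): the same operand with a
//   4-wide element type packs to [B, K/4, N, 4], which the previous
//   getI64IntegerAttr(2) mis-tiled.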
1 change: 1 addition & 0 deletions lib/TPP/Transforms/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(TPPTransformsUtils
   BuilderUtils.cpp
+  DLTIUtils.cpp
   TensorInit.cpp
   TensorInitFloat.cpp
   TensorInitInt.cpp
32 changes: 32 additions & 0 deletions lib/TPP/Transforms/Utils/DLTIUtils.cpp
@@ -0,0 +1,32 @@
//===- DLTIUtils.cpp --------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/DLTIUtils.h"

namespace mlir {
namespace dlti {
namespace utils {

FailureOr<Attribute> query(Operation *op, ArrayRef<StringRef> keys,
                           bool emitError) {
  if (!op)
    return failure();

  auto ctx = op->getContext();
  SmallVector<DataLayoutEntryKey> entryKeys;
  for (auto &key : keys) {
    auto entry = StringAttr::get(ctx, key);
    entryKeys.push_back(entry);
  }

  return dlti::query(op, entryKeys, emitError);
}

} // namespace utils
} // namespace dlti
} // namespace mlir