Skip to content

Commit

Permalink
Address comments
Browse files — browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 16, 2025
1 parent 76bcc85 commit e7d8a96
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 79 deletions.
19 changes: 7 additions & 12 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace mlir {
class Type;
class MemRefType;
class ShapedType;
class OpOperand;
class AffineDimExpr;
class AffineMap;
Expand All @@ -37,26 +37,21 @@ enum class VnniOperandRank {
};

// Return the VNNI blocking factor.
// Optionally, operation can be provided to give access to DLTI.
// Optionally, an operation can be provided to give access to DLTI.
std::optional<int64_t> getVnniBlockingFactor(Type type,
Operation *op = nullptr);

// Return true if the memref is in VNNI layout with rank `expectedRank`.
// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref,
bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor = std::nullopt);

// Return true if the vector is in VNNI layout with rank `expectedRank`.
// Return true if the shaped type is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector,
bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor = std::nullopt);

// Return true if the vector is in VNNI layout with rank `expectedRank`.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(int64_t expectedRank, VectorType vector,
std::optional<int64_t> blockingFactor = std::nullopt);

// Return true if the operation is in VNNI layout.
// Return true if the linalg operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor = std::nullopt);
Expand Down
3 changes: 1 addition & 2 deletions lib/TPP/Dialect/Xsmm/XsmmOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,7 @@ LogicalResult GemmOp::verify() {
assert(memref && (memref.getRank() == 2 || memref.getRank() == 3));

if (memref.getRank() == 3) {
auto vnniFactor = vnni::utils::getVnniBlockingFactor(memref);
if (!vnniFactor || (*vnniFactor) % 2 != 0 ||
if (memref.getShape().back() % 2 != 0 ||
!vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM,
memref)) {
return emitOpError() << "expect VNNI layout for operand: " << actualIdx;
Expand Down
25 changes: 4 additions & 21 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "TPP/Passes.h"
#include "TPP/Transforms/Transforms.h"
#include "TPP/Transforms/Utils/DLTIUtils.h"
#include "TPP/Transforms/Utils/TransformUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand Down Expand Up @@ -457,28 +458,10 @@ static int64_t getTileForDim(linalg::LinalgOp linalgOp, unsigned dim) {
int64_t tile = 32;

// Check if a tile size hint is associated to the IR via DLTI.
auto deriveFromDLTI = [&](ModuleOp moduleOp) {
if (!moduleOp)
return;
TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
if (!sysSpec)
return;
auto deviceId = StringAttr::get(linalgOp->getContext(), "CPU");
auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
if (!deviceSpec)
return;
auto tileSizeId = StringAttr::get(linalgOp->getContext(), "tile_size");
DataLayoutEntryInterface entry =
(*deviceSpec).getSpecForIdentifier(tileSizeId);
if (!entry)
return;
Attribute value = entry.getValue();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
auto tileValue = dlti::utils::query(linalgOp, {"CPU", "tile_size"});
if (succeeded(tileValue))
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*tileValue))
tile = intAttr.getInt();
// TODO: might want to print a warning if tile_size exists as a key but the
// associated attribute has an unexpected type.
};
deriveFromDLTI(linalgOp->getParentOfType<mlir::ModuleOp>());

SmallVector<int64_t, 4> loopsRange = linalgOp.getStaticLoopRanges();
if (loopsRange[dim] == ShapedType::kDynamic)
Expand Down
53 changes: 9 additions & 44 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "TPP/Transforms/Utils/DLTIUtils.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
Expand All @@ -25,52 +26,16 @@ std::optional<int64_t> getVnniBlockingFactor(Type type, Operation *op) {
auto elementType = getElementTypeOrSelf(type);
if (elementType.isBF16()) {
// Check if a VNNI factor hint is associated to the IR via DLTI.
auto deriveVnniFromDLTI = [&]() -> std::optional<int64_t> {
if (!op)
return std::nullopt;
ModuleOp moduleOp = op->getParentOfType<mlir::ModuleOp>();
if (!moduleOp)
return std::nullopt;
TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
if (!sysSpec)
return std::nullopt;
auto deviceId = StringAttr::get(moduleOp->getContext(), "CPU");
auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
if (!deviceSpec)
return std::nullopt;
auto vnniId = StringAttr::get(moduleOp->getContext(), "vnni");
DataLayoutEntryInterface entry =
(*deviceSpec).getSpecForIdentifier(vnniId);
if (!entry)
return std::nullopt;
Attribute value = entry.getValue();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
auto vnniValue = dlti::utils::query(op, {"CPU", "vnni"});
if (succeeded(vnniValue))
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*vnniValue))
return intAttr.getInt();
return std::nullopt;
};
if (auto vnniFactor = deriveVnniFromDLTI())
return *vnniFactor;

return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16);
}
return std::nullopt;
}

// Until we have a better way to express the VNNI layout (see: #563), it is up
// to the caller to specify the expected rank in the VNNI layout, as the rank
// depends on the operation being handled.
//
// Returns true only when `memref` has exactly `expectedRank` dimensions, a
// bf16 element type, and — when `blockingFactor` is provided — an innermost
// dimension equal to that factor.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref,
                    std::optional<int64_t> blockingFactor) {
  const bool rankMatches =
      memref.getRank() == static_cast<int64_t>(expectedRank);
  const bool isBf16 = memref.getElementType().isBF16();
  if (!rankMatches || !isBf16)
    return false;

  // No blocking-factor constraint requested, or the innermost dim matches it.
  return !blockingFactor || memref.getShape().back() == *blockingFactor;
}

bool isInVnniLayout(linalg::LinalgOp linalgOp,
std::optional<int64_t> blockingFactor) {
// Narrow down type operations - VNNI only applies to contractions.
Expand Down Expand Up @@ -142,18 +107,18 @@ bool isInVnniLayout(linalg::LinalgOp linalgOp,
return true;
}

bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector,
bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor) {
return isInVnniLayout(static_cast<int64_t>(expectedRank), vector,
return isInVnniLayout(static_cast<int64_t>(expectedRank), shape,
blockingFactor);
}

bool isInVnniLayout(int64_t expectedRank, VectorType vector,
bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
std::optional<int64_t> blockingFactor) {
if (vector.getRank() != expectedRank || !vector.getElementType().isBF16())
if (shape.getRank() != expectedRank || !shape.getElementType().isBF16())
return false;

if (blockingFactor && vector.getShape().back() != *blockingFactor)
if (blockingFactor && shape.getShape().back() != *blockingFactor)
return false;

return true;
Expand Down

0 comments on commit e7d8a96

Please sign in to comment.