Skip to content

Commit

Permalink
Get VNNI factor from DLTI
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 13, 2025
1 parent 0afd6a3 commit 9e89fd2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
7 changes: 5 additions & 2 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class OpOperand;
class AffineDimExpr;
class AffineMap;
class VectorType;
class Operation;

namespace linalg {
class LinalgOp;
Expand All @@ -35,8 +36,10 @@ enum class VnniOperandRank {
BRGEMM_OUTS = 3
};

// Return the VNNI blocking factor: 2 for BF16 and 4 for BF8.
std::optional<int64_t> getVnniBlockingFactor(Type type);
// Return the VNNI blocking factor.
// Optionally, operation can be provided to give access to DLTI.
std::optional<int64_t> getVnniBlockingFactor(Type type,
Operation *op = nullptr);

// Return true if the memref is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOperandRank expectedRank, MemRefType memref);
Expand Down
33 changes: 31 additions & 2 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Utils/VNNIUtils.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
Expand All @@ -20,10 +21,38 @@ namespace mlir {
namespace vnni {
namespace utils {

std::optional<int64_t> getVnniBlockingFactor(Type type) {
std::optional<int64_t> getVnniBlockingFactor(Type type, Operation *op) {
auto elementType = getElementTypeOrSelf(type);
if (elementType.isBF16())
if (elementType.isBF16()) {
// Check if a VNNI factor hint is associated to the IR via DLTI.
auto deriveVnniFromDLTI = [&]() -> std::optional<int64_t> {
if (!op)
return std::nullopt;
ModuleOp moduleOp = op->getParentOfType<mlir::ModuleOp>();
if (!moduleOp)
return std::nullopt;
TargetSystemSpecInterface sysSpec = moduleOp.getTargetSystemSpec();
if (!sysSpec)
return std::nullopt;
auto deviceId = StringAttr::get(moduleOp->getContext(), "CPU");
auto deviceSpec = sysSpec.getDeviceSpecForDeviceID(deviceId);
if (!deviceSpec)
return std::nullopt;
auto tileSizeId = StringAttr::get(moduleOp->getContext(), "vnni");
DataLayoutEntryInterface entry =
(*deviceSpec).getSpecForIdentifier(tileSizeId);
if (!entry)
return std::nullopt;
Attribute value = entry.getValue();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(value))
return intAttr.getInt();
return std::nullopt;
};
if (auto vnniFactor = deriveVnniFromDLTI())
return *vnniFactor;

return libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16);
}
return std::nullopt;
}

Expand Down

0 comments on commit 9e89fd2

Please sign in to comment.