Skip to content

[AMD] Improve shared layout for Wmma's operands #7319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -286,25 +286,10 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
}

// ---- begin WMMA ----
if (mlir::isa<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
if (dotOpEnc.getOpIdx() == 0) {
const int numBanks = 32;
const int bankBitWidth = 32;

// number of inner dimension rows per one pattern repeat
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;

int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
int maxPhase = 16 / perPhase;

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {
// Do not swizzle in case k dimension is not innermost.
// In this case accesses will go in different banks even without swizzling.
return get(context, 1, 1, 1, order, CTALayout);
}
if (auto wmmaEnc = mlir::dyn_cast<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
return wmmaEnc.composeSharedLayoutForOperand(
CTALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(),
typeWidthInBit, needTrans);
}


Expand Down Expand Up @@ -1230,6 +1215,13 @@ Row |
Type elemType, int kWidth, int kDim, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
static SmallVector<unsigned> getMNKDimPerInstr();

// Returns a swizzled shared layout matching this WMMA layout for the
// dot operand at the given |operandIdx| with |operandShape|.
//
// Parameters:
//   ctaLayout    - CTA layout forwarded unchanged into the returned encoding.
//   operandIdx   - 0 selects dimension 1 as the K dimension; any other value
//                  selects dimension 0.
//   operandShape - shape of the operand; only the extent of the
//                  fastest-varying dimension (|sharedOrder[0]|) is read.
//   sharedOrder  - dimension order of the shared buffer, forwarded unchanged
//                  into the returned encoding.
//   kWidth       - used together with |elemBitWidth| to derive the swizzle
//                  vector size, which is capped at 128 bits.
//   elemBitWidth - bit width of one operand element.
//   needTrans    - NOTE(review): not consulted by the current definition in
//                  Dialect.cpp; confirm whether that is intentional.
//
// When the K dimension is not the fastest-varying one, the returned encoding
// has vec = perPhase = maxPhase = 1, i.e. no swizzling is applied.
SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
ArrayRef<unsigned> sharedOrder, unsigned kWidth,
unsigned elemBitWidth, bool needTrans) const;
}];
}

Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,42 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
return {16, 16, 16};
}

// Composes the swizzled shared encoding for a WMMA dot operand.
//
// The K dimension is dimension 1 for operand 0 and dimension 0 otherwise.
// If K is not the fastest-varying dimension of |sharedOrder|, an unswizzled
// encoding (vec = perPhase = maxPhase = 1) is returned. Otherwise the swizzle
// parameters are derived from |kWidth|, |elemBitWidth| and the extent of the
// fastest-varying dimension of |operandShape|.
//
// |ctaLayout| and |sharedOrder| are passed through to the result unchanged.
// |needTrans| is accepted for interface compatibility but is not read here.
SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand(
    CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
    ArrayRef<unsigned> sharedOrder, unsigned kWidth, unsigned elemBitWidth,
    bool needTrans) const {
  const int kAxis = (operandIdx == 0) ? 1 : 0;
  if (sharedOrder[0] != kAxis) {
    // K is not contiguous: accesses already land in different banks, so no
    // swizzling is requested.
    return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
                                           ctaLayout);
  }

  // Bank geometry used for the swizzle pattern: 32 banks of 32 bits each.
  const int bankCount = 32;
  const int bitsPerBank = 32;
  const int bitsPerBankRow = bankCount * bitsPerBank;

  // A single ds_load can vectorize at most 128 bits, so clamp the kWidth-wide
  // access to that many bits before converting back to an element count.
  const unsigned requestedBits = kWidth * elemBitWidth;
  const unsigned loadBits = (requestedBits < 128u) ? requestedBits : 128u;
  const int vec = loadBits / elemBitWidth;

  // Extent of the contiguous (fastest-varying) dimension and the number of
  // elements that fit into one full row of banks.
  const int contigExtent = operandShape[sharedOrder[0]];
  const int elemsPerBankRow = bitsPerBankRow / elemBitWidth;

  // Rows of the contiguous dimension covered by one bank row; at least one.
  const int perPhase = std::max(1, elemsPerBankRow / contigExtent);

  // For both RDNA3 and RDNA4 the M/N dimension of a wmma instruction is 16,
  // which bounds how many rows can be accessed at the same time.
  const int instrRows = getMNKDimPerInstr()[0];
  const int phasesByRows = instrRows / perPhase;
  const int phasesByVectors = contigExtent / vec;
  const int maxPhase = std::max(1, std::min(phasesByRows, phasesByVectors));

  return SwizzledSharedEncodingAttr::get(getContext(), vec, perPhase, maxPhase,
                                         sharedOrder, ctaLayout);
}

//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
Expand Down
Loading