Skip to content

Commit 21d2ef2

Browse files
leeliu103zhanglx13
andauthored
[AMD] Improve shared layout for Wmma's operands (#7319)
Swizzling was always disabled for WMMA's B operand; it should be disabled only when the k dimension is not contiguous. vectorSize, perPhase, and maxPhase are now all determined using a heuristic approach. --------- Co-authored-by: Lixun Zhang <[email protected]>
1 parent fa73f39 commit 21d2ef2

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -286,25 +286,10 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
286286
}
287287

288288
// ---- begin WMMA ----
289-
if (mlir::isa<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
290-
if (dotOpEnc.getOpIdx() == 0) {
291-
const int numBanks = 32;
292-
const int bankBitWidth = 32;
293-
294-
// number of inner dimension rows per one pattern repeat
295-
int innerDimLength = shape[order[0]];
296-
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
297-
298-
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
299-
int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
300-
int maxPhase = 16 / perPhase;
301-
302-
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
303-
} else {
304-
// Do not swizzle in case k dimension is not innermost.
305-
// In this case accesses will go in different banks even without swizzling.
306-
return get(context, 1, 1, 1, order, CTALayout);
307-
}
289+
if (auto wmmaEnc = mlir::dyn_cast<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
290+
return wmmaEnc.composeSharedLayoutForOperand(
291+
CTALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(),
292+
typeWidthInBit, needTrans);
308293
}
309294

310295

@@ -1230,6 +1215,13 @@ Row |
12301215
Type elemType, int kWidth, int kDim, int opIdx) const;
12311216
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
12321217
static SmallVector<unsigned> getMNKDimPerInstr();
1218+
1219+
// Returns a swizzled shared layout matching this WMMA layout for the
1220+
// dot operand at the given |operandIdx| with |operandShape|.
1221+
SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
1222+
CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
1223+
ArrayRef<unsigned> sharedOrder, unsigned kWidth,
1224+
unsigned elemBitWidth, bool needTrans) const;
12331225
}];
12341226
}
12351227

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,42 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
20312031
return {16, 16, 16};
20322032
}
20332033

2034+
// Composes the swizzled shared layout for the WMMA dot operand selected by
// |operandIdx|.
//
// When the K dimension is not the fastest-varying one in |sharedOrder|, an
// unswizzled layout (vec = perPhase = maxPhase = 1) is returned. Otherwise
// vec, perPhase and maxPhase are derived heuristically from |kWidth|,
// |elemBitWidth| and the length of the innermost dimension of |operandShape|.
//
// NOTE(review): the K index is hard-coded as 1 for operand 0 and 0 otherwise,
// which presumably assumes rank-2 operands -- confirm for batched shapes.
// NOTE(review): |needTrans| is not read anywhere in this function.
SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand(
    CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
    ArrayRef<unsigned> sharedOrder, unsigned kWidth, unsigned elemBitWidth,
    bool needTrans) const {
  // Position of K within the operand: 1 for operand 0, 0 for any other index.
  int kAxis = 1;
  if (operandIdx != 0)
    kAxis = 0;

  // K is not contiguous: skip swizzling, since accesses will go to different
  // banks even without it.
  if (sharedOrder[0] != kAxis)
    return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
                                           ctaLayout);

  // The max vectorization size for ds_load is 128 bits, so cap the per-access
  // width there before converting it back to an element count.
  const unsigned maxLoadBits = 128u;
  const unsigned loadBits = std::min(kWidth * elemBitWidth, maxLoadBits);
  int vec = loadBits / elemBitWidth;

  // Bits covered by one full row of banks: 32 banks, each 32 bits wide.
  const int bankCount = 32;
  const int bitsPerBank = 32;
  const int bankRowBits = bankCount * bitsPerBank;

  // Number of inner-dimension rows per one pattern repeat.
  int innerLen = operandShape[sharedOrder[0]];
  int elemsPerBankRow = bankRowBits / elemBitWidth;
  int rowsPerPhase = std::max(1, elemsPerBankRow / innerLen);

  // For both RDNA3 and RDNA4 the M/N dimension of wmma is 16; this is the max
  // number of rows that can be accessed at the same time.
  int instrRows = getMNKDimPerInstr()[0];

  // Bound the phase count by both the simultaneously accessed rows and the
  // number of vectors in one inner row, never dropping below 1.
  int phaseCount = std::min(instrRows / rowsPerPhase, innerLen / vec);
  if (phaseCount < 1)
    phaseCount = 1;

  return SwizzledSharedEncodingAttr::get(getContext(), vec, rowsPerPhase,
                                         phaseCount, sharedOrder, ctaLayout);
}
2069+
20342070
//===----------------------------------------------------------------------===//
20352071
// Mma encoding
20362072
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)