Skip to content

Commit 72144a4

Browse files
authored
[https://nvbugs/5541494] [fix] Fix missing sm100f/103a kernels and add tests (#8098)
Signed-off-by: Xiwen Yu <[email protected]>
1 parent b4e6a16 commit 72144a4

File tree

708 files changed

+10041
-5184
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

708 files changed

+10041
-5184
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,25 @@ using namespace batchedGemm::trtllm::gen;
3636

3737
static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache;
3838

39+
constexpr bool isSMCompatible(int gpuSM, SmVersion kernelSM)
40+
{
41+
if (gpuSM == 103)
42+
{
43+
return kernelSM == SmVersion::Sm100f || kernelSM == SmVersion::Sm103a;
44+
}
45+
else if (gpuSM == 100)
46+
{
47+
return kernelSM == SmVersion::Sm100f || kernelSM == SmVersion::Sm100a;
48+
}
49+
else if (gpuSM == 90)
50+
{
51+
return kernelSM == SmVersion::Sm90a;
52+
}
53+
54+
TLLM_THROW("Unexpected gpuSM %d", gpuSM);
55+
return false;
56+
}
57+
3958
std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
4059
batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
4160
{
@@ -98,6 +117,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
98117

99118
mPassingConfigIndices.clear();
100119

120+
int gpuSM = tensorrt_llm::common::getSMVersion();
101121
for (size_t i = 0; i < bmm.getNumBatchedGemmConfigs(); ++i)
102122
{
103123
auto const options = configs[i].mOptions;
@@ -108,7 +128,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
108128
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput
109129
&& (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct
110130
&& options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch
111-
&& tileSize == mOptions.tileSize)
131+
&& tileSize == mOptions.tileSize && isSMCompatible(gpuSM, configs[i].mSm))
112132
{
113133
// FIXME: Disable split-k for now.
114134
if (options.mClusterDimZ != 1)

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ struct BatchedGemmData
235235
void const* mPtrBias{nullptr};
236236

237237
// The output tensor scaling factor for Fp8 (not DeepSeek FP8) and NvFp4 quantization.
238-
// TensorRT LLM API requires a scaling factor on the device.
238+
// TensorRT-LLM API requires a scaling factor on the device.
239239
// scaleC = dequantA * dequantB * quantC,
240240
// where dequantA is global dequantization scaling factor of A
241241
// if dtypeA is FP8, it transforms the range from [-448, 448] to [-amaxA, amaxA]
@@ -250,7 +250,7 @@ struct BatchedGemmData
250250
float const* mPtrScaleC{nullptr};
251251

252252
// The output gate scale for Fp8 (not DeepSeek FP8) and NvFp4 quantization.
253-
// TensorRT LLM API requires a scaling factor on the device.
253+
// TensorRT-LLM API requires a scaling factor on the device.
254254
// scaleGate = dequantA * dequantB,
255255
// where dequantA is global dequantization scaling factor of A
256256
// if dtypeA is FP8, it transforms the range from [-448, 448] to [-amaxA, amaxA]
@@ -507,8 +507,25 @@ class BatchedGemmInterface
507507
throw std::invalid_argument("Invalid combination of options");
508508
}
509509

510-
int32_t const numCtasTile
510+
if (batchM)
511+
{
512+
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX);
513+
}
514+
else
515+
{
516+
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY);
517+
}
518+
519+
int32_t numCtasTile
511520
= batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM);
521+
if (batchM)
522+
{
523+
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY);
524+
}
525+
else
526+
{
527+
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX);
528+
}
512529
int32_t const numCtasInner = options.mNumSlicesForSplitK;
513530
return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
514531
}
@@ -531,7 +548,6 @@ class BatchedGemmInterface
531548
// Aligns the pointer to the alignment
532549
template <typename Dtype>
533550
inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const;
534-
535551
// Returns the size of the workspace buffers in bytes
536552
std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
537553

@@ -792,7 +808,9 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
792808
cuModuleUnload(cuModule);
793809
}
794810
#else
795-
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
811+
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid,
812+
/* cluster */ {},
813+
/* instanceId */ config.mInstanceIdx);
796814
#endif
797815

798816
return 0;

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,36 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
8686

8787
// FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM.
8888
BatchedGemmOptions(gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
89-
int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC,
90-
tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
91-
bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
92-
bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit,
93-
bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
94-
gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK,
95-
tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK,
96-
int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
97-
int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
98-
std::optional<int32_t> sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
99-
int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
100-
gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
101-
bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA,
102-
bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
103-
gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int> batchedM, std::vector<int> batchedN,
104-
BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
105-
bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
106-
int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
89+
int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA,
90+
tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
91+
bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits,
92+
int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
93+
bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
94+
bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA,
95+
gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
96+
int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
97+
int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages,
98+
int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
99+
bool outputDebugTensors, bool patchF2fp, std::optional<int32_t> sfBlockSizeA, tg::SfLayout sfLayoutA,
100+
tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK,
101+
int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
102+
bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
103+
bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
104+
bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, bool clampBeforeAct,
105+
std::vector<int> batchedM, std::vector<int> batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
106+
int numTokens, RouteImpl routeImpl, std::optional<RouteImpl> routeSfsImpl, bool gridWaitForPrimaryRouting,
107+
bool fusedAct, bool useTmaOobOpt)
107108
: gemmGatedAct::GemmGatedActOptions(
108-
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, dtypeA,
109-
dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs,
110-
epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
111-
gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
112-
hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM,
113-
mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
114-
numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
115-
sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN,
116-
tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
109+
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType,
110+
dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit,
111+
enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN,
112+
gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
113+
gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m,
114+
mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps, numRegsCopySfLdsSttm,
115+
numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
116+
numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
117+
outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
118+
splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
117119
useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore,
118120
useTwoTmaLoadWarps, useTwoMmaWarps, useUnrollLoop2xForMma, worldSize),
119121
actType, clampBeforeAct)
@@ -124,11 +126,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
124126
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
125127
, mIsStaticBatch(isStaticBatch)
126128
, mNumBatches(numBatches)
127-
, mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp)
128-
, mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp)
129-
, mNumRegsCastAWarps(numRegsCastAWarps)
130129
, mNumTokens(numTokens)
131130
, mRouteImpl(routeImpl)
131+
, mRouteSfsImpl(routeSfsImpl)
132132
, mUseTmaOobOpt(useTmaOobOpt)
133133
{
134134
}
@@ -148,16 +148,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
148148
bool mIsStaticBatch{true};
149149
// Number of Gemm batches.
150150
int mNumBatches;
151-
// Number of registers per thread for non-epilogue warps
152-
int mNumRegsPerThreadNonEpilogueWarp{0};
153-
// Number of registers per thread for epilogue warps
154-
int mNumRegsPerThreadEpilogueWarp{0};
155-
// Number of registers for the cast A warps.
156-
int mNumRegsCastAWarps{0};
157151
// Total number of tokens.
158152
int mNumTokens{32};
159153
// Whether load the input tokens and do routing.
160154
RouteImpl mRouteImpl{RouteImpl::NoRoute};
155+
// Routing logic for scaling factors. If not specified, mRouteImpl is used.
156+
std::optional<RouteImpl> mRouteSfsImpl{std::nullopt};
161157
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
162158
// BatchedGemm/KernelParamsDecl.h.
163159
bool mUseTmaOobOpt{false};
@@ -255,6 +251,24 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
255251
"E2m1 is not supported with DeepSeek FP8");
256252
}
257253

254+
if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl)
255+
{
256+
TLLM_CHECK_ERROR(options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
257+
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma");
258+
}
259+
else if (!options.mRouteSfsImpl.has_value())
260+
{
261+
if (updateOptions)
262+
{
263+
options.mRouteSfsImpl = options.mRouteImpl;
264+
}
265+
else
266+
{
267+
TLLM_LOG_ERROR("RouteSfsImpl must be specified");
268+
return false;
269+
}
270+
}
271+
258272
if (batchM)
259273
{
260274
if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4)
@@ -299,20 +313,23 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
299313
}
300314
}
301315

302-
if (doesRouteImplUseTma(options.mRouteImpl))
316+
if (doesRouteImplUseTma(options.mRouteSfsImpl.value()))
303317
{
304318
TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N.");
305319

306320
if (tg::mmaKindIsBlockFmt(options.mMmaKind))
307321
{
308322
auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB;
309-
TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0,
310-
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
311323
TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0,
312324
"tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA).");
313325
}
314326
}
315327

328+
if (options.mClusterDimX > 1)
329+
{
330+
TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N.");
331+
}
332+
316333
if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl))
317334
{
318335
TLLM_CHECK_ERROR(options.mSfLayoutA == tg::SfLayout::R128c4,
@@ -336,6 +353,13 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
336353
TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, "K must be a multiple of TileK");
337354
}
338355

356+
if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute)
357+
{
358+
TLLM_CHECK_ERROR(false,
359+
"2CTA BatchedGemm does not support routing along M dimension. To support it, "
360+
"change the input routing data layout to be padded to clusterDimX size.");
361+
}
362+
339363
return isValid;
340364
}
341365

@@ -359,6 +383,7 @@ struct BatchedGemmConfig
359383
char const* mHash{nullptr};
360384
#else
361385
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
386+
int32_t mInstanceIdx{0};
362387
#endif
363388

364389
BatchedGemmOptions mOptions;
@@ -379,11 +404,10 @@ inline std::string dumpOptions(BatchedGemmOptions const& options)
379404
ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl;
380405
ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
381406
ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast<int32_t>(options.mRouteImpl) << ")," << std::endl;
407+
ss << "mRouteSfsImpl={batchedGemm::RouteImpl(" << static_cast<int32_t>(options.mRouteSfsImpl.value()) << ")},"
408+
<< std::endl;
382409
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
383410
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
384-
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," << std::endl;
385-
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," << std::endl;
386-
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
387411
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
388412
return ss.str();
389413
}

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/Enums.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,24 @@ enum class TileScheduler
104104

105105
////////////////////////////////////////////////////////////////////////////////////////////////////
106106

107+
enum class CtaSwizzleType : uint32_t
108+
{
109+
// Rasterize CTAs along the M dimension.
110+
RasterizeAlongM = 0,
111+
// Rasterize CTAs along the N dimension.
112+
RasterizeAlongN,
113+
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2.
114+
ZigZagAlongM2,
115+
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2.
116+
ZigZagAlongN2,
117+
// Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4.
118+
ZigZagAlongM4,
119+
// Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4.
120+
ZigZagAlongN4,
121+
};
122+
123+
////////////////////////////////////////////////////////////////////////////////////////////////////
124+
107125
// Helper functions to check the SplitK type.
108126

109127
#define SPLIT_K_FUNCTION(Mode) \

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ struct GemmGatedActConfig
210210
char const* mHash{nullptr};
211211
#else
212212
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
213+
int32_t mInstanceIdx{0};
213214
#endif
214215

215216
GemmGatedActOptions mOptions{};

0 commit comments

Comments (0)