@@ -86,34 +86,36 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
86
86
87
87
// FIXME: We create an explicit constructor taking all options to work around (WAR) a stubgen issue in TRT-LLM.
88
88
BatchedGemmOptions (gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX,
89
- int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC,
90
- tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
91
- bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
92
- bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit,
93
- bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
94
- gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK,
95
- tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK,
96
- int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
97
- int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
98
- std::optional<int32_t > sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
99
- int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN,
100
- gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
101
- bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA,
102
- bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize,
103
- gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector<int > batchedM, std::vector<int > batchedN,
104
- BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl,
105
- bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp,
106
- int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
89
+ int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA,
90
+ tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit,
91
+ bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits,
92
+ int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
93
+ bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
94
+ bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA,
95
+ gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
96
+ int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp,
97
+ int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages,
98
+ int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
99
+ bool outputDebugTensors, bool patchF2fp, std::optional<int32_t > sfBlockSizeA, tg::SfLayout sfLayoutA,
100
+ tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK,
101
+ int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput,
102
+ bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA,
103
+ bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
104
+ bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, bool clampBeforeAct,
105
+ std::vector<int > batchedM, std::vector<int > batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
106
+ int numTokens, RouteImpl routeImpl, std::optional<RouteImpl> routeSfsImpl, bool gridWaitForPrimaryRouting,
107
+ bool fusedAct, bool useTmaOobOpt)
107
108
: gemmGatedAct::GemmGatedActOptions(
108
- gemm::GemmOptions (allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, dtypeA,
109
- dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs,
110
- epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA,
111
- gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB,
112
- hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM,
113
- mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
114
- numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
115
- sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN,
116
- tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
109
+ gemm::GemmOptions (allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType,
110
+ dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit,
111
+ enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN,
112
+ gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA,
113
+ gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m,
114
+ mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps, numRegsCopySfLdsSttm,
115
+ numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
116
+ numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
117
+ outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
118
+ splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8,
117
119
useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore,
118
120
useTwoTmaLoadWarps, useTwoMmaWarps, useUnrollLoop2xForMma, worldSize),
119
121
actType, clampBeforeAct)
@@ -124,11 +126,9 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
124
126
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
125
127
, mIsStaticBatch(isStaticBatch)
126
128
, mNumBatches(numBatches)
127
- , mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp)
128
- , mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp)
129
- , mNumRegsCastAWarps(numRegsCastAWarps)
130
129
, mNumTokens(numTokens)
131
130
, mRouteImpl(routeImpl)
131
+ , mRouteSfsImpl(routeSfsImpl)
132
132
, mUseTmaOobOpt(useTmaOobOpt)
133
133
{
134
134
}
@@ -148,16 +148,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
148
148
bool mIsStaticBatch {true };
149
149
// Number of Gemm batches.
150
150
int mNumBatches ;
151
- // Number of registers per thread for non-epilogue warps
152
- int mNumRegsPerThreadNonEpilogueWarp {0 };
153
- // Number of registers per thread for epilogue warps
154
- int mNumRegsPerThreadEpilogueWarp {0 };
155
- // Number of registers for the cast A warps.
156
- int mNumRegsCastAWarps {0 };
157
151
// Total number of tokens.
158
152
int mNumTokens {32 };
159
153
// Whether to load the input tokens and do routing.
160
154
RouteImpl mRouteImpl {RouteImpl::NoRoute};
155
+ // Routing logic for scaling factors. If not specified, mRouteImpl is used.
156
+ std::optional<RouteImpl> mRouteSfsImpl {std::nullopt };
161
157
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
162
158
// BatchedGemm/KernelParamsDecl.h.
163
159
bool mUseTmaOobOpt {false };
@@ -255,6 +251,24 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
255
251
" E2m1 is not supported with DeepSeek FP8" );
256
252
}
257
253
254
+ if (options.mRouteSfsImpl .has_value () && options.mRouteSfsImpl .value () != options.mRouteImpl )
255
+ {
256
+ TLLM_CHECK_ERROR (options.mRouteSfsImpl .value () == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
257
+ " RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma" );
258
+ }
259
+ else if (!options.mRouteSfsImpl .has_value ())
260
+ {
261
+ if (updateOptions)
262
+ {
263
+ options.mRouteSfsImpl = options.mRouteImpl ;
264
+ }
265
+ else
266
+ {
267
+ TLLM_LOG_ERROR (" RouteSfsImpl must be specified" );
268
+ return false ;
269
+ }
270
+ }
271
+
258
272
if (batchM)
259
273
{
260
274
if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4)
@@ -299,20 +313,23 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
299
313
}
300
314
}
301
315
302
- if (doesRouteImplUseTma (options.mRouteImpl ))
316
+ if (doesRouteImplUseTma (options.mRouteSfsImpl . value () ))
303
317
{
304
318
TLLM_CHECK_ERROR (!batchM, " UTMALDG.GATHER4 only supported for batch N." );
305
319
306
320
if (tg::mmaKindIsBlockFmt (options.mMmaKind ))
307
321
{
308
322
auto dtypeRoute = batchM ? options.mDtypeA : options.mDtypeB ;
309
- TLLM_CHECK_ERROR (options.mTileK % tg::dtypeNumEltsPerSf (dtypeRoute) == 0 ,
310
- " tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
311
323
TLLM_CHECK_ERROR (options.mTileK % (tg::dtypeNumEltsPerSf (dtypeRoute) * 16 ) == 0 ,
312
324
" tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)." );
313
325
}
314
326
}
315
327
328
+ if (options.mClusterDimX > 1 )
329
+ {
330
+ TLLM_CHECK_ERROR (!batchM, " 2CTA Gemm currently only supports batch N." );
331
+ }
332
+
316
333
if (!batchM || doesRouteImplUseNoRoute (options.mRouteImpl ))
317
334
{
318
335
TLLM_CHECK_ERROR (options.mSfLayoutA == tg::SfLayout::R128c4,
@@ -336,6 +353,13 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
336
353
TLLM_CHECK_ERROR (options.mK % options.mTileK == 0 , " K must be a multiple of TileK" );
337
354
}
338
355
356
+ if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute)
357
+ {
358
+ TLLM_CHECK_ERROR (false ,
359
+ " 2CTA BatchedGemm does not support routing along M dimension. To support it, "
360
+ " change the input routing data layout to be padded to clusterDimX size." );
361
+ }
362
+
339
363
return isValid;
340
364
}
341
365
@@ -359,6 +383,7 @@ struct BatchedGemmConfig
359
383
char const * mHash {nullptr };
360
384
#else
361
385
trtllm::gen::CudaRunner* mCudaRunner {nullptr };
386
+ int32_t mInstanceIdx {0 };
362
387
#endif
363
388
364
389
BatchedGemmOptions mOptions ;
@@ -379,11 +404,10 @@ inline std::string dumpOptions(BatchedGemmOptions const& options)
379
404
ss << " mIsStaticBatch=" << options.mIsStaticBatch << " ," << std::endl;
380
405
ss << " mNumTokens=" << options.mNumTokens << " ," << std::endl;
381
406
ss << " mRouteImpl=batchedGemm::RouteImpl(" << static_cast <int32_t >(options.mRouteImpl ) << " )," << std::endl;
407
+ ss << " mRouteSfsImpl={batchedGemm::RouteImpl(" << static_cast <int32_t >(options.mRouteSfsImpl .value ()) << " )},"
408
+ << std::endl;
382
409
ss << " mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << " ," << std::endl;
383
410
ss << " mFusedAct=" << options.mFusedAct << " ," << std::endl;
384
- ss << " mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << " ," << std::endl;
385
- ss << " mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << " ," << std::endl;
386
- ss << " mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << " ," << std::endl;
387
411
ss << " mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
388
412
return ss.str ();
389
413
}
0 commit comments