Skip to content

Commit e192735

Browse files
committed
Add unit tests and revise the block-level kernel to handle invalid input
Signed-off-by: Christina Zhang <[email protected]>
1 parent ce0d761 commit e192735

File tree

6 files changed

+122
-38
lines changed

6 files changed

+122
-38
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
130130
{
131131
if (laneIdx < params.mTopK)
132132
{
133-
int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
134-
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
133+
auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
134+
if (expertIdx != -1)
135+
{
136+
int offset = warpIdx * MaxNumExperts + expertIdx;
137+
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
138+
}
139+
else
140+
{
141+
params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1};
142+
}
135143
}
136144
}
137145
}

cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32)
217217
/*numExperts=*/32, /*topK=*/8,
218218
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
219219
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
220-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
220+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
221221
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
222222
this->runTest(param);
223223
};
@@ -228,7 +228,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72)
228228
/*numExperts=*/72, /*topK=*/6,
229229
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
230230
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
231-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
231+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
232232
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
233233
this->runTest(param);
234234
};
@@ -239,7 +239,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
239239
/*numExperts=*/384, /*topK=*/8,
240240
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
241241
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
242-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
242+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
243243
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
244244
this->runTest(param);
245245
};
@@ -250,7 +250,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
250250
/*numExperts=*/256, /*topK=*/8,
251251
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
252252
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
253-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
253+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
254254
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
255255
this->runTest(param);
256256
};
@@ -261,7 +261,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
261261
/*numExperts=*/256, /*topK=*/8,
262262
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192,
263263
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
264-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
264+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
265265
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
266266
this->runTest(param);
267267
};
@@ -272,7 +272,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
272272
/*numExperts=*/384, /*topK=*/8,
273273
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
274274
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
275-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
275+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false,
276276
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
277277
this->runTest(param);
278278
};
@@ -283,7 +283,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
283283
/*numExperts=*/256, /*topK=*/8,
284284
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
285285
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
286-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
286+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
287287
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
288288
this->runTest(param);
289289
};
@@ -294,7 +294,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
294294
/*numExperts=*/256, /*topK=*/8,
295295
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
296296
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
297-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
297+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
298298
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
299299
this->runTest(param);
300300
};
@@ -305,7 +305,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
305305
/*numExperts=*/384, /*topK=*/8,
306306
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
307307
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
308-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
308+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
309309
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
310310
this->runTest(param);
311311
};
@@ -316,7 +316,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
316316
/*numExperts=*/256, /*topK=*/8,
317317
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
318318
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
319-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
319+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
320320
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
321321
this->runTest(param);
322322
};
@@ -327,7 +327,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
327327
/*numExperts=*/384, /*topK=*/8,
328328
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
329329
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
330-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
330+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
331331
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
332332
this->runTest(param);
333333
};
@@ -338,7 +338,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
338338
/*numExperts=*/256, /*topK=*/2,
339339
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
340340
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
341-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
341+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
342342
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
343343
this->runTest(param);
344344
};
@@ -349,7 +349,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
349349
/*numExperts=*/256, /*topK=*/2,
350350
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
351351
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
352-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
352+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
353353
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
354354
this->runTest(param);
355355
};
@@ -360,7 +360,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
360360
/*numExperts=*/256, /*topK=*/2,
361361
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
362362
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
363-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
363+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
364364
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
365365
this->runTest(param);
366366
};
@@ -371,7 +371,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8)
371371
/*numExperts=*/32, /*topK=*/8,
372372
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
373373
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
374-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
374+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
375375
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
376376
this->runTest(param);
377377
};

cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
127127
// convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked
128128
for (int ie = 0; ie < param.topK; ++ie)
129129
{
130+
// Mark the first (topK / 2 + 1) top-k entries as invalid by setting their expert index to -1
131+
if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1)
132+
{
133+
expIdx[ie].idx = -1;
134+
}
135+
130136
PackedType si{static_cast<T>(expIdx[ie].score), expIdx[ie].idx};
131137
reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si;
132138
if (param.useTopKAsInput)
@@ -198,7 +204,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization)
198204
/*numExperts=*/128, /*topK=*/8,
199205
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
200206
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
201-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
207+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
202208
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
203209
this->runTest(param);
204210
};
@@ -209,7 +215,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertPara
209215
/*numExperts=*/128, /*topK=*/8,
210216
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
211217
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
212-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
218+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
219+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
220+
this->runTest(param);
221+
};
222+
223+
TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput)
224+
{
225+
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4,
226+
/*numExperts=*/128, /*topK=*/8,
227+
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
228+
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
229+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
213230
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
214231
this->runTest(param);
215232
};
@@ -220,7 +237,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
220237
/*numExperts=*/128, /*topK=*/8,
221238
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192,
222239
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
223-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
240+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
224241
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
225242
this->runTest(param);
226243
};
@@ -231,7 +248,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
231248
/*numExperts=*/128, /*topK=*/8,
232249
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256,
233250
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
234-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
251+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
252+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
253+
this->runTest(param);
254+
};
255+
256+
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput)
257+
{
258+
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100,
259+
/*numExperts=*/128, /*topK=*/8,
260+
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256,
261+
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
262+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
235263
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
236264
this->runTest(param);
237265
};
@@ -242,7 +270,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
242270
/*numExperts=*/128, /*topK=*/8,
243271
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
244272
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
245-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
273+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
246274
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
247275
this->runTest(param);
248276
};
@@ -264,7 +292,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
264292
/*numExperts=*/128, /*topK=*/4,
265293
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
266294
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
267-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
295+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
296+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
297+
this->runTest(param);
298+
};
299+
300+
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4)
301+
{
302+
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200,
303+
/*numExperts=*/128, /*topK=*/4,
304+
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
305+
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
306+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
268307
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
269308
this->runTest(param);
270309
};
@@ -275,7 +314,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
275314
/*numExperts=*/128, /*topK=*/4,
276315
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/8,
277316
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
278-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
317+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
279318
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
280319
this->runTest(param);
281320
};
@@ -286,7 +325,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
286325
/*numExperts=*/128, /*topK=*/4,
287326
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8,
288327
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
289-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
328+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
290329
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
291330
this->runTest(param);
292331
};
@@ -297,7 +336,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
297336
/*numExperts=*/128, /*topK=*/4,
298337
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
299338
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
300-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
339+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
301340
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
302341
this->runTest(param);
303342
};
@@ -308,7 +347,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN)
308347
/*numExperts=*/512, /*topK=*/10,
309348
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
310349
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
311-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
350+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
312351
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
313352
this->runTest(param);
314353
};
@@ -319,7 +358,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN)
319358
/*numExperts=*/512, /*topK=*/10,
320359
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
321360
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
322-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
361+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
323362
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
324363
this->runTest(param);
325364
};
@@ -330,7 +369,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN)
330369
/*numExperts=*/512, /*topK=*/10,
331370
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
332371
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
333-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
372+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
373+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
374+
this->runTest(param);
375+
};
376+
377+
TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput)
378+
{
379+
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
380+
/*numExperts=*/512, /*topK=*/10,
381+
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
382+
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
383+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
334384
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
335385
this->runTest(param);
336386
};

0 commit comments

Comments
 (0)