@@ -127,6 +127,12 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
127127            //  convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked
128128            for  (int  ie = 0 ; ie < param.topK ; ++ie)
129129            {
130+                 //  Mark the first topK / 2 + 1 entries (one more than half) as invalid by setting their expert index to -1
131+                 if  (param.hasInvalidTopKInput  && ie < param.topK  / 2  + 1 )
132+                 {
133+                     expIdx[ie].idx  = -1 ;
134+                 }
135+ 
130136                PackedType si{static_cast <T>(expIdx[ie].score ), expIdx[ie].idx };
131137                reinterpret_cast <PackedType*>(bufferCast<int8_t >(*this ->mPtrTopKPackedHost ))[it * param.topK  + ie] = si;
132138                if  (param.useTopKAsInput )
@@ -198,7 +204,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization)
198204        /* numExperts=*/ 128 , /* topK=*/ 8 ,
199205        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
200206        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
201-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
207+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false , 
202208        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
203209    this ->runTest (param);
204210};
@@ -209,7 +215,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertPara
209215        /* numExperts=*/ 128 , /* topK=*/ 8 ,
210216        /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
211217        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
212-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
218+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
219+         /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
220+     this ->runTest (param);
221+ };
222+ 
223+ TYPED_TEST (RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput)
224+ {
225+     RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 4 ,
226+         /* numExperts=*/ 128 , /* topK=*/ 8 ,
227+         /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 192 ,
228+         /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
229+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true ,
213230        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
214231    this ->runTest (param);
215232};
@@ -220,7 +237,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
220237        /* numExperts=*/ 128 , /* topK=*/ 8 ,
221238        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 192 ,
222239        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
223-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
240+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false , 
224241        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
225242    this ->runTest (param);
226243};
@@ -231,7 +248,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
231248        /* numExperts=*/ 128 , /* topK=*/ 8 ,
232249        /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 256 ,
233250        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
234-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
251+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
252+         /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
253+     this ->runTest (param);
254+ };
255+ 
256+ TYPED_TEST (RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput)
257+ {
258+     RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 100 ,
259+         /* numExperts=*/ 128 , /* topK=*/ 8 ,
260+         /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 256 ,
261+         /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
262+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true ,
235263        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
236264    this ->runTest (param);
237265};
@@ -242,7 +270,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
242270        /* numExperts=*/ 128 , /* topK=*/ 8 ,
243271        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
244272        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
245-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
273+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false , 
246274        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
247275    this ->runTest (param);
248276};
@@ -264,7 +292,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
264292        /* numExperts=*/ 128 , /* topK=*/ 4 ,
265293        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
266294        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
267-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
295+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
296+         /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
297+     this ->runTest (param);
298+ };
299+ 
300+ TYPED_TEST (RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4)
301+ {
302+     RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 200 ,
303+         /* numExperts=*/ 128 , /* topK=*/ 4 ,
304+         /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
305+         /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
306+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true ,
268307        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
269308    this ->runTest (param);
270309};
@@ -275,7 +314,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
275314        /* numExperts=*/ 128 , /* topK=*/ 4 ,
276315        /* expertParallelization=*/ 2 , /* expertParallelizationId=*/ 1 , /* tileTokensDim=*/ 8 ,
277316        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
278-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
317+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false ,
279318        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
280319    this ->runTest (param);
281320};
@@ -286,7 +325,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
286325        /* numExperts=*/ 128 , /* topK=*/ 4 ,
287326        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 8 ,
288327        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
289-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
328+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false ,
290329        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
291330    this ->runTest (param);
292331};
@@ -297,7 +336,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
297336        /* numExperts=*/ 128 , /* topK=*/ 4 ,
298337        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
299338        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
300-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,
339+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true ,  /* hasInvalidTopKInput= */ true , 
301340        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   8 );
302341    this ->runTest (param);
303342};
@@ -308,7 +347,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN)
308347        /* numExperts=*/ 512 , /* topK=*/ 10 ,
309348        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
310349        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
311-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
350+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false , 
312351        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
313352    this ->runTest (param);
314353};
@@ -319,7 +358,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN)
319358        /* numExperts=*/ 512 , /* topK=*/ 10 ,
320359        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
321360        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
322-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
361+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,  /* hasInvalidTopKInput= */ false , 
323362        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   9 );
324363    this ->runTest (param);
325364};
@@ -330,7 +369,18 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN)
330369        /* numExperts=*/ 512 , /* topK=*/ 10 ,
331370        /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
332371        /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
333-         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false ,
372+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ false , /* hasInvalidTopKInput=*/ false ,
373+         /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   8 );
374+     this ->runTest (param);
375+ };
376+ 
377+ TYPED_TEST (RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput)
378+ {
379+     RoutingKernelTestParam param (RoutingMethodType::Renormalize, /* numTokens=*/ 1000 ,
380+         /* numExperts=*/ 512 , /* topK=*/ 10 ,
381+         /* expertParallelization=*/ 1 , /* expertParallelizationId=*/ 0 , /* tileTokensDim=*/ 256 ,
382+         /* paddingLog2=*/ 3 , /* localExpertsStrideLog2=*/ 0 ,
383+         /* usePdl=*/ true , /* getExpWeights=*/ true , /* useTopKAsInput=*/ true , /* hasInvalidTopKInput=*/ true ,
334384        /* nGroup*/   0 , /* topkGroup*/   0 , /* routedScalingFactor*/   1 .0f , /* requiredComputeCapability*/   8 );
335385    this ->runTest (param);
336386};
0 commit comments