 import triton_kernels.swiglu
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
 from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
 from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import routing
-from triton_kernels.tensor import convert_layout
+from triton_kernels.topk import topk
+from triton_kernels.tensor import convert_layout, make_ragged_tensor_metadata
 from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 
@@ -31,6 +31,32 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True
     return out_glu * (x_linear + 1)
 
 
+def compute_routing(logits, n_expts_act, n_expts_tot):
+    sparse_logits = topk(logits, n_expts_act, apply_softmax=True)
+
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum,
+        dispatch_indx.shape[0]
+    )
+
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+
+    rdata = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.batch_sizes,
+        n_expts_tot,
+        n_expts_act,
+        ragged_batch_metadata
+    )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
+
+    return rdata, gather_indx, scatter_indx
+
+
 def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
     if x.numel() == 0:
         return x
@@ -41,8 +67,9 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert
 
     with record_function("wg"):
         logits = matmul_ogs(x, wg, bg, precision_config=pcg)
+
     with record_function("routing"):
-        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)
+        rdata, gather_indx, scatter_indx = compute_routing(logits, experts_per_token, num_experts)
 
     if fused_act:
         assert interleaved, "Fused activation requires interleaved weights"
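
For reviewers who want to poke at the new helper outside of moe(), below is a minimal sketch of calling compute_routing on random gating logits. It assumes triton_kernels is installed and a CUDA device is available, and that float16 logits of shape (n_tokens, n_expts_tot) are acceptable to topk; the shapes, dtype, and print at the end are illustrative and not taken from this PR.

import torch

# Illustrative only: exercise compute_routing in isolation
# (assumes triton_kernels is importable and a CUDA device is present).
n_tokens, n_expts_tot, n_expts_act = 16, 128, 4
logits = torch.randn(n_tokens, n_expts_tot, device="cuda", dtype=torch.float16)

# Returns the same (rdata, gather_indx, scatter_indx) triple that the old
# routing() call produced and that matmul_ogs consumes downstream.
rdata, gather_indx, scatter_indx = compute_routing(logits, n_expts_act, n_expts_tot)
print(rdata, gather_indx, scatter_indx)

Each of the n_tokens rows is routed to n_expts_act experts, matching the experts_per_token / num_experts arguments passed inside moe().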