Skip to content

Commit 3557af1

Browse files
tom-pollak and joao-alex-cunha
authored and committed
fix: port moe routing to new triton_kernels API
The `triton_kernels.routing` module was deprecated and removed in triton commit 30ede52aa (triton-lang/triton#8375). Replaced the deprecated `routing()` call with new primitives in `compute_routing()`. --- Upgrade to `triton>=3.5`: `triton_kernels` HEAD depends on `tl.target_info()`, which is not available in 3.4.
1 parent 0a9ec7f commit 3557af1

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

gpt_oss/triton/moe.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import triton_kernels.swiglu
66
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
77
from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
8-
from triton_kernels.matmul_ogs import matmul_ogs
8+
from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
99
from triton_kernels.numerics import InFlexData
10-
from triton_kernels.routing import routing
11-
from triton_kernels.tensor import convert_layout
10+
from triton_kernels.topk import topk
11+
from triton_kernels.tensor import convert_layout, make_ragged_tensor_metadata
1212
from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
1313
from triton_kernels.tensor import wrap_torch_tensor, FP4
1414

@@ -31,6 +31,32 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True
3131
return out_glu * (x_linear + 1)
3232

3333

34+
def compute_routing(logits, n_expts_act, n_expts_tot):
35+
sparse_logits = topk(logits, n_expts_act, apply_softmax=True)
36+
37+
dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
38+
combine_indx = sparse_logits.mask_metadata.row_sorted_indx
39+
40+
ragged_batch_metadata = make_ragged_tensor_metadata(
41+
sparse_logits.mask_metadata.col_sum,
42+
dispatch_indx.shape[0]
43+
)
44+
45+
gate_scal = sparse_logits.vals.flatten()[combine_indx]
46+
47+
rdata = RoutingData(
48+
gate_scal,
49+
ragged_batch_metadata.batch_sizes,
50+
n_expts_tot,
51+
n_expts_act,
52+
ragged_batch_metadata
53+
)
54+
gather_indx = GatherIndx(combine_indx, dispatch_indx)
55+
scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
56+
57+
return rdata, gather_indx, scatter_indx
58+
59+
3460
def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
3561
if x.numel() == 0:
3662
return x
@@ -41,8 +67,9 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert
4167

4268
with record_function("wg"):
4369
logits = matmul_ogs(x, wg, bg, precision_config=pcg)
70+
4471
with record_function("routing"):
45-
rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)
72+
rdata, gather_indx, scatter_indx = compute_routing(logits, experts_per_token, num_experts)
4673

4774
if fused_act:
4875
assert interleaved, "Fused activation requires interleaved weights"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ requires-python = ">=3.12"
2424
version = "0.0.8"
2525

2626
[project.optional-dependencies]
27-
triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"]
27+
triton = ["triton>=3.5", "safetensors>=0.5.3", "torch>=2.7.0"]
2828
torch = ["safetensors>=0.5.3", "torch>=2.7.0"]
2929
metal = ["numpy", "tqdm", "safetensors", "torch"]
3030
test = ["pytest>=8.4.1", "httpx>=0.28.1"]

0 commit comments

Comments
 (0)