Skip to content

Commit f978e11

Browse files
committed
gmm tuning
1 parent 61b7ed6 commit f978e11

File tree

1 file changed

+32
-63
lines changed
  • python/sgl_jax/srt/layers

1 file changed

+32
-63
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 32 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ def _get_tiling_from_configs(
164164
self, gmm_tiling_configs, m: int, k: int, n: int, num_groups: int
165165
):
166166
key = (m, k, n, num_groups)
167-
if gmm_tiling_configs is None:
168-
return (8, 1024, 1024) # Default fallback when configs not loaded
169167
return gmm_tiling_configs.get(key, (8, 1024, 1024)) # Default fallback
170168

171169
def _detect_device_capabilities(self):
@@ -207,32 +205,6 @@ def __call__(self, inputs, router_logits=None, gmm_tiling_configs=None):
207205
def _expert_parallel_forward_with_shard_map(
208206
self, inputs, router_logits, gmm_tiling_configs
209207
):
210-
# 预先计算静态 tiling 参数(在 shard_map 外部)
211-
total_tokens, hidden_dim = inputs.shape
212-
m, k = total_tokens, hidden_dim
213-
n_gate = self.intermediate_dim
214-
n_down = hidden_dim
215-
216-
# 获取最优 tiling 配置
217-
optimal_tiling_gate = self._get_tiling_from_configs(
218-
gmm_tiling_configs, m, k, n_gate, self.num_experts
219-
)
220-
optimal_tiling_down = self._get_tiling_from_configs(
221-
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
222-
)
223-
224-
# 转换为静态整数参数(在动态 m 值上使用最大值作为安全的静态值)
225-
static_tiling_gate = (
226-
min(optimal_tiling_gate[0], 16384), # 设置合理的最大值
227-
optimal_tiling_gate[1],
228-
optimal_tiling_gate[2],
229-
)
230-
static_tiling_down = (
231-
min(optimal_tiling_down[0], 16384),
232-
optimal_tiling_down[1],
233-
optimal_tiling_down[2],
234-
)
235-
236208
def _internal_moe_computation(
237209
hidden_states,
238210
router_logits,
@@ -279,16 +251,15 @@ def _internal_moe_computation(
279251
else:
280252
local_group_sizes = group_sizes
281253

282-
# GMM (使用预先计算的静态 tiling)
283-
intermediate_output = self._gmm_compute_with_static_tiling(
254+
# GMM
255+
intermediate_output = self._gmm_compute_with_sharded_weights(
284256
x,
285257
local_group_sizes,
286258
selected_experts,
287259
w0_weights,
288260
w1_weights,
289261
wo_weights,
290-
static_tiling_gate,
291-
static_tiling_down,
262+
gmm_tiling_configs,
292263
)
293264

294265
# EP Combine
@@ -328,26 +299,44 @@ def _internal_moe_computation(
328299
self.wo.value,
329300
)
330301

331-
def _gmm_compute_with_static_tiling(
302+
def _gmm_compute_with_sharded_weights(
332303
self,
333304
x,
334305
local_group_sizes,
335306
selected_experts,
336307
w0_kernel,
337308
w1_kernel,
338309
wo_kernel,
339-
static_tiling_gate,
340-
static_tiling_down,
310+
gmm_tiling_configs,
341311
):
342312
if x.shape[0] == 0:
343313
empty_output = jnp.zeros(
344314
(0, wo_kernel.shape[-1]), dtype=x.dtype
345315
) # (0, hidden_dim)
346316
return empty_output
347317

348-
# 直接使用预先计算好的静态 tiling 参数
349-
tiling_gate = static_tiling_gate
350-
tiling_down = static_tiling_down
318+
m, k = x.shape[0], x.shape[1]
319+
n_gate = w0_kernel.shape[2]
320+
n_down = wo_kernel.shape[2]
321+
322+
optimal_tiling_gate = self._get_tiling_from_configs(
323+
gmm_tiling_configs, m, k, n_gate, self.num_experts
324+
)
325+
optimal_tiling_down = self._get_tiling_from_configs(
326+
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
327+
)
328+
329+
# Repack the looked-up tilings as explicit 3-tuples for use as static tiling parameters (no type conversion happens here)
330+
tiling_gate = (
331+
optimal_tiling_gate[0],
332+
optimal_tiling_gate[1],
333+
optimal_tiling_gate[2],
334+
)
335+
tiling_down = (
336+
optimal_tiling_down[0],
337+
optimal_tiling_down[1],
338+
optimal_tiling_down[2],
339+
)
351340
# gate
352341
layer_w0 = gmm(
353342
lhs=x,
@@ -390,33 +379,13 @@ def _single_device_forward(self, inputs, router_logits, gmm_tiling_configs):
390379

391380
top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True)
392381

393-
# 为单设备也预先计算静态 tiling 参数
394-
total_tokens, hidden_dim = inputs.shape
395-
m, k = total_tokens, hidden_dim
396-
n_gate = self.intermediate_dim
397-
n_down = hidden_dim
398-
399-
optimal_tiling_gate = self._get_tiling_from_configs(
400-
gmm_tiling_configs, m, k, n_gate, self.num_experts
401-
)
402-
optimal_tiling_down = self._get_tiling_from_configs(
403-
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
382+
return self._single_device_forward_impl(
383+
inputs, top_k_indices, top_k_weights, gmm_tiling_configs
404384
)
405385

406-
static_tiling_gate = (
407-
min(optimal_tiling_gate[0], 16384),
408-
optimal_tiling_gate[1],
409-
optimal_tiling_gate[2],
410-
)
411-
static_tiling_down = (
412-
min(optimal_tiling_down[0], 16384),
413-
optimal_tiling_down[1],
414-
optimal_tiling_down[2],
415-
)
416-
417-
return self._single_device_forward_impl(inputs, top_k_indices, top_k_weights)
418-
419-
def _single_device_forward_impl(self, inputs, top_k_indices, top_k_weights):
386+
def _single_device_forward_impl(
387+
self, inputs, top_k_indices, top_k_weights, gmm_tiling_configs
388+
):
420389
num_tokens = inputs.shape[0] * (inputs.shape[1] if inputs.ndim > 1 else 1)
421390
inputs_flat = inputs.reshape(num_tokens, -1)
422391

0 commit comments

Comments
 (0)