Skip to content

Commit d098bf3

Browse files
committed
gmm tuning v3
1 parent 4766817 commit d098bf3

File tree

1 file changed

+0
-59
lines changed

1 file changed

+0
-59
lines changed

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,6 @@ def _forward(
379379
forward_batch: ForwardBatch,
380380
logits_metadata: LogitsMetadata,
381381
):
382-
# 预先计算 GMM tiling 参数并设置到 forward_batch 中
383-
self._compute_and_set_gmm_tiling(forward_batch)
384-
385382
cache_miss_count = 0
386383
import jax._src.test_util as jtu
387384

@@ -394,67 +391,11 @@ def _forward(
394391

395392
return output, cache_miss_count
396393

397-
def _compute_and_set_gmm_tiling(self, forward_batch: ForwardBatch):
398-
"""预先计算 GMM tiling 参数并设置到 forward_batch"""
399-
try:
400-
# 获取模型配置
401-
hidden_size = getattr(self.model_config, "hidden_size", 2048)
402-
intermediate_size = getattr(self.model_config, "moe_intermediate_size", 768)
403-
num_experts = getattr(self.model_config, "num_experts", 128)
404-
num_experts_per_tok = getattr(self.model_config, "num_experts_per_tok", 8)
405-
406-
# 计算 tiling 参数
407-
static_tiling_gate, static_tiling_down = self.compute_gmm_tiling_for_batch(
408-
forward_batch,
409-
hidden_size,
410-
intermediate_size,
411-
num_experts,
412-
num_experts_per_tok,
413-
)
414-
415-
# 设置到 forward_batch
416-
forward_batch.static_tiling_gate = static_tiling_gate
417-
forward_batch.static_tiling_down = static_tiling_down
418-
419-
except Exception as e:
420-
# 出现任何错误时使用默认值
421-
forward_batch.static_tiling_gate = (512, 1024, 1024)
422-
forward_batch.static_tiling_down = (512, 1024, 1024)
423-
424394
def _set_kv_cache_after_forward(self, layers_kv_fused, forward_batch: ForwardBatch):
425395
start_idx = forward_batch.token_to_kv_pool.start_layer
426396
end_idx = start_idx + len(layers_kv_fused)
427397
forward_batch.token_to_kv_pool.kv_buffer[start_idx:end_idx] = layers_kv_fused
428398

429-
def compute_gmm_tiling_for_batch(
430-
self,
431-
forward_batch: ForwardBatch,
432-
hidden_size: int,
433-
intermediate_size: int,
434-
num_experts: int,
435-
num_experts_per_tok: int,
436-
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]:
437-
total_tokens = forward_batch.seq_lens.sum()
438-
439-
# 计算考虑 expert topk 的实际 m 值
440-
m_actual = int(total_tokens * num_experts_per_tok)
441-
442-
# 构造高效的字符串 key
443-
gate_key = f"m{m_actual}_k{hidden_size}_n{intermediate_size}_g{num_experts}"
444-
down_key = f"m{m_actual}_k{intermediate_size}_n{hidden_size}_g{num_experts}"
445-
logger.info(f"gate_key: {gate_key}, down_key: {down_key}")
446-
if forward_batch.gmm_tiling_configs:
447-
# 只做精确匹配
448-
gate_tiling = forward_batch.gmm_tiling_configs.get(gate_key, None)
449-
down_tiling = forward_batch.gmm_tiling_configs.get(down_key, None)
450-
451-
if gate_tiling and down_tiling:
452-
return gate_tiling, down_tiling
453-
else:
454-
logger.warning("No GMM tiling configs found in forward_batch")
455-
logger.warning(f"No GMM tiling found for key: {gate_key} or {down_key}")
456-
return (512, 1024, 1024), (512, 1024, 1024)
457-
458399
def forward_idle(
459400
self,
460401
forward_batch: ForwardBatch,

0 commit comments

Comments
 (0)