@@ -379,9 +379,6 @@ def _forward(
379379 forward_batch : ForwardBatch ,
380380 logits_metadata : LogitsMetadata ,
381381 ):
382- # 预先计算 GMM tiling 参数并设置到 forward_batch 中
383- self ._compute_and_set_gmm_tiling (forward_batch )
384-
385382 cache_miss_count = 0
386383 import jax ._src .test_util as jtu
387384
@@ -394,67 +391,11 @@ def _forward(
394391
395392 return output , cache_miss_count
396393
397- def _compute_and_set_gmm_tiling (self , forward_batch : ForwardBatch ):
398- """预先计算 GMM tiling 参数并设置到 forward_batch"""
399- try :
400- # 获取模型配置
401- hidden_size = getattr (self .model_config , "hidden_size" , 2048 )
402- intermediate_size = getattr (self .model_config , "moe_intermediate_size" , 768 )
403- num_experts = getattr (self .model_config , "num_experts" , 128 )
404- num_experts_per_tok = getattr (self .model_config , "num_experts_per_tok" , 8 )
405-
406- # 计算 tiling 参数
407- static_tiling_gate , static_tiling_down = self .compute_gmm_tiling_for_batch (
408- forward_batch ,
409- hidden_size ,
410- intermediate_size ,
411- num_experts ,
412- num_experts_per_tok ,
413- )
414-
415- # 设置到 forward_batch
416- forward_batch .static_tiling_gate = static_tiling_gate
417- forward_batch .static_tiling_down = static_tiling_down
418-
419- except Exception as e :
420- # 出现任何错误时使用默认值
421- forward_batch .static_tiling_gate = (512 , 1024 , 1024 )
422- forward_batch .static_tiling_down = (512 , 1024 , 1024 )
423-
424394 def _set_kv_cache_after_forward (self , layers_kv_fused , forward_batch : ForwardBatch ):
425395 start_idx = forward_batch .token_to_kv_pool .start_layer
426396 end_idx = start_idx + len (layers_kv_fused )
427397 forward_batch .token_to_kv_pool .kv_buffer [start_idx :end_idx ] = layers_kv_fused
428398
def compute_gmm_tiling_for_batch(
    self,
    forward_batch: "ForwardBatch",
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
    num_experts_per_tok: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]:
    """Look up the GMM tilings for the gate and down projections of a batch.

    Builds two lookup keys of the form ``m{m}_k{k}_n{n}_g{g}`` where ``m`` is
    the summed ``forward_batch.seq_lens`` multiplied by
    ``num_experts_per_tok``, and ``g`` is ``num_experts``. The gate key uses
    ``k=hidden_size, n=intermediate_size``; the down key swaps the two.

    Only exact key matches are used, and both keys must be present in
    ``forward_batch.gmm_tiling_configs``; otherwise the default tiling
    ``(512, 1024, 1024)`` is returned for both projections.

    Args:
        forward_batch: Batch providing ``seq_lens`` and, optionally, a
            ``gmm_tiling_configs`` mapping from key string to tiling.
        hidden_size: Model hidden dimension.
        intermediate_size: MoE intermediate dimension.
        num_experts: Number of experts (the ``g`` component of the key).
        num_experts_per_tok: Experts selected per token (top-k).

    Returns:
        A ``(gate_tiling, down_tiling)`` pair.
    """
    default_tiling = (512, 1024, 1024)

    # Actual m value once the expert top-k fan-out is taken into account.
    total_tokens = forward_batch.seq_lens.sum()
    m_actual = int(total_tokens * num_experts_per_tok)

    gate_key = f"m{m_actual}_k{hidden_size}_n{intermediate_size}_g{num_experts}"
    down_key = f"m{m_actual}_k{intermediate_size}_n{hidden_size}_g{num_experts}"
    # Lazy %-args: this runs on every forward call, so avoid formatting the
    # message when the INFO level is disabled.
    logger.info("gate_key: %s, down_key: %s", gate_key, down_key)

    # A batch without the attribute is treated the same as an empty mapping.
    tiling_configs = getattr(forward_batch, "gmm_tiling_configs", None)
    if tiling_configs:
        # Exact match only; no nearest-size fallback.
        gate_tiling = tiling_configs.get(gate_key)
        down_tiling = tiling_configs.get(down_key)
        if gate_tiling and down_tiling:
            return gate_tiling, down_tiling
    else:
        logger.warning("No GMM tiling configs found in forward_batch")

    logger.warning("No GMM tiling found for key: %s or %s", gate_key, down_key)
    return default_tiling, default_tiling
457-
458399 def forward_idle (
459400 self ,
460401 forward_batch : ForwardBatch ,
0 commit comments