@@ -164,8 +164,6 @@ def _get_tiling_from_configs(
164164 self , gmm_tiling_configs , m : int , k : int , n : int , num_groups : int
165165 ):
166166 key = (m , k , n , num_groups )
167- if gmm_tiling_configs is None :
168- return (8 , 1024 , 1024 ) # Default fallback when configs not loaded
169167 return gmm_tiling_configs .get (key , (8 , 1024 , 1024 )) # Default fallback
170168
171169 def _detect_device_capabilities (self ):
@@ -207,32 +205,6 @@ def __call__(self, inputs, router_logits=None, gmm_tiling_configs=None):
207205 def _expert_parallel_forward_with_shard_map (
208206 self , inputs , router_logits , gmm_tiling_configs
209207 ):
210- # 预先计算静态 tiling 参数(在 shard_map 外部)
211- total_tokens , hidden_dim = inputs .shape
212- m , k = total_tokens , hidden_dim
213- n_gate = self .intermediate_dim
214- n_down = hidden_dim
215-
216- # 获取最优 tiling 配置
217- optimal_tiling_gate = self ._get_tiling_from_configs (
218- gmm_tiling_configs , m , k , n_gate , self .num_experts
219- )
220- optimal_tiling_down = self ._get_tiling_from_configs (
221- gmm_tiling_configs , m , n_gate , n_down , self .num_experts
222- )
223-
224- # 转换为静态整数参数(在动态 m 值上使用最大值作为安全的静态值)
225- static_tiling_gate = (
226- min (optimal_tiling_gate [0 ], 16384 ), # 设置合理的最大值
227- optimal_tiling_gate [1 ],
228- optimal_tiling_gate [2 ],
229- )
230- static_tiling_down = (
231- min (optimal_tiling_down [0 ], 16384 ),
232- optimal_tiling_down [1 ],
233- optimal_tiling_down [2 ],
234- )
235-
236208 def _internal_moe_computation (
237209 hidden_states ,
238210 router_logits ,
@@ -279,16 +251,15 @@ def _internal_moe_computation(
279251 else :
280252 local_group_sizes = group_sizes
281253
282- # GMM (使用预先计算的静态 tiling)
283- intermediate_output = self ._gmm_compute_with_static_tiling (
254+ # GMM
255+ intermediate_output = self ._gmm_compute_with_sharded_weights (
284256 x ,
285257 local_group_sizes ,
286258 selected_experts ,
287259 w0_weights ,
288260 w1_weights ,
289261 wo_weights ,
290- static_tiling_gate ,
291- static_tiling_down ,
262+ gmm_tiling_configs ,
292263 )
293264
294265 # EP Combine
@@ -328,26 +299,44 @@ def _internal_moe_computation(
328299 self .wo .value ,
329300 )
330301
331- def _gmm_compute_with_static_tiling (
302+ def _gmm_compute_with_sharded_weights (
332303 self ,
333304 x ,
334305 local_group_sizes ,
335306 selected_experts ,
336307 w0_kernel ,
337308 w1_kernel ,
338309 wo_kernel ,
339- static_tiling_gate ,
340- static_tiling_down ,
310+ gmm_tiling_configs ,
341311 ):
342312 if x .shape [0 ] == 0 :
343313 empty_output = jnp .zeros (
344314 (0 , wo_kernel .shape [- 1 ]), dtype = x .dtype
345315 ) # (0, hidden_dim)
346316 return empty_output
347317
348- # 直接使用预先计算好的静态 tiling 参数
349- tiling_gate = static_tiling_gate
350- tiling_down = static_tiling_down
318+ m , k = x .shape [0 ], x .shape [1 ]
319+ n_gate = w0_kernel .shape [2 ]
320+ n_down = wo_kernel .shape [2 ]
321+
322+ optimal_tiling_gate = self ._get_tiling_from_configs (
323+ gmm_tiling_configs , m , k , n_gate , self .num_experts
324+ )
325+ optimal_tiling_down = self ._get_tiling_from_configs (
326+ gmm_tiling_configs , m , n_gate , n_down , self .num_experts
327+ )
328+
329+ # Convert to Python integers for static tiling parameters
330+ tiling_gate = (
331+ optimal_tiling_gate [0 ],
332+ optimal_tiling_gate [1 ],
333+ optimal_tiling_gate [2 ],
334+ )
335+ tiling_down = (
336+ optimal_tiling_down [0 ],
337+ optimal_tiling_down [1 ],
338+ optimal_tiling_down [2 ],
339+ )
351340 # gate
352341 layer_w0 = gmm (
353342 lhs = x ,
@@ -390,33 +379,13 @@ def _single_device_forward(self, inputs, router_logits, gmm_tiling_configs):
390379
391380 top_k_weights = top_k_weights / jnp .sum (top_k_weights , axis = - 1 , keepdims = True )
392381
393- # 为单设备也预先计算静态 tiling 参数
394- total_tokens , hidden_dim = inputs .shape
395- m , k = total_tokens , hidden_dim
396- n_gate = self .intermediate_dim
397- n_down = hidden_dim
398-
399- optimal_tiling_gate = self ._get_tiling_from_configs (
400- gmm_tiling_configs , m , k , n_gate , self .num_experts
401- )
402- optimal_tiling_down = self ._get_tiling_from_configs (
403- gmm_tiling_configs , m , n_gate , n_down , self .num_experts
382+ return self ._single_device_forward_impl (
383+ inputs , top_k_indices , top_k_weights , gmm_tiling_configs
404384 )
405385
406- static_tiling_gate = (
407- min (optimal_tiling_gate [0 ], 16384 ),
408- optimal_tiling_gate [1 ],
409- optimal_tiling_gate [2 ],
410- )
411- static_tiling_down = (
412- min (optimal_tiling_down [0 ], 16384 ),
413- optimal_tiling_down [1 ],
414- optimal_tiling_down [2 ],
415- )
416-
417- return self ._single_device_forward_impl (inputs , top_k_indices , top_k_weights )
418-
419- def _single_device_forward_impl (self , inputs , top_k_indices , top_k_weights ):
386+ def _single_device_forward_impl (
387+ self , inputs , top_k_indices , top_k_weights , gmm_tiling_configs
388+ ):
420389 num_tokens = inputs .shape [0 ] * (inputs .shape [1 ] if inputs .ndim > 1 else 1 )
421390 inputs_flat = inputs .reshape (num_tokens , - 1 )
422391
0 commit comments