@@ -328,32 +328,31 @@ def _gmm_compute_with_sharded_weights(
328328 gmm_tiling_configs , m , n_gate , n_down , self .num_experts
329329 )
330330
331- # Use JAX operations for tiling parameters (cannot use int() on tracers)
332- # tiling_gate = (
333- # jnp.minimum(optimal_tiling_gate[0], m),
334- # jnp.minimum(optimal_tiling_gate[1], k),
335- # jnp.minimum(optimal_tiling_gate[2], n_gate),
336- # )
337- # tiling_down = (
338- # jnp.minimum(optimal_tiling_down[0], m),
339- # jnp.minimum(optimal_tiling_down[1], n_gate),
340- # jnp.minimum(optimal_tiling_down[2], n_down),
341- # )
331+ static_tiling_gate = (
332+ int (optimal_tiling_gate [0 ]),
333+ int (optimal_tiling_gate [1 ]),
334+ int (optimal_tiling_gate [2 ]),
335+ )
336+ static_tiling_down = (
337+ int (optimal_tiling_down [0 ]),
338+ int (optimal_tiling_down [1 ]),
339+ int (optimal_tiling_down [2 ]),
340+ )
342341 # gate
343342 layer_w0 = gmm (
344343 lhs = x ,
345344 rhs = w0_kernel ,
346345 group_sizes = local_group_sizes ,
347346 preferred_element_type = self .dtype ,
348- tiling = optimal_tiling_gate ,
347+ tiling = static_tiling_gate ,
349348 )
350349 # up
351350 layer_w1 = gmm (
352351 lhs = x ,
353352 rhs = w1_kernel ,
354353 group_sizes = local_group_sizes ,
355354 preferred_element_type = self .dtype ,
356- tiling = optimal_tiling_gate ,
355+ tiling = static_tiling_gate ,
357356 )
358357
359358 # activation
@@ -366,7 +365,7 @@ def _gmm_compute_with_sharded_weights(
366365 rhs = wo_kernel ,
367366 group_sizes = local_group_sizes ,
368367 preferred_element_type = self .dtype ,
369- tiling = optimal_tiling_down ,
368+ tiling = static_tiling_down ,
370369 )
371370
372371 return intermediate_output
0 commit comments