Skip to content

Commit 0cc8032

Browse files
committed
gmm tuning
1 parent 1ad5223 commit 0cc8032

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -328,32 +328,31 @@ def _gmm_compute_with_sharded_weights(
328328
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
329329
)
330330

331-
# Use JAX operations for tiling parameters (cannot use int() on tracers)
332-
# tiling_gate = (
333-
# jnp.minimum(optimal_tiling_gate[0], m),
334-
# jnp.minimum(optimal_tiling_gate[1], k),
335-
# jnp.minimum(optimal_tiling_gate[2], n_gate),
336-
# )
337-
# tiling_down = (
338-
# jnp.minimum(optimal_tiling_down[0], m),
339-
# jnp.minimum(optimal_tiling_down[1], n_gate),
340-
# jnp.minimum(optimal_tiling_down[2], n_down),
341-
# )
331+
static_tiling_gate = (
332+
int(optimal_tiling_gate[0]),
333+
int(optimal_tiling_gate[1]),
334+
int(optimal_tiling_gate[2]),
335+
)
336+
static_tiling_down = (
337+
int(optimal_tiling_down[0]),
338+
int(optimal_tiling_down[1]),
339+
int(optimal_tiling_down[2]),
340+
)
342341
# gate
343342
layer_w0 = gmm(
344343
lhs=x,
345344
rhs=w0_kernel,
346345
group_sizes=local_group_sizes,
347346
preferred_element_type=self.dtype,
348-
tiling=optimal_tiling_gate,
347+
tiling=static_tiling_gate,
349348
)
350349
# up
351350
layer_w1 = gmm(
352351
lhs=x,
353352
rhs=w1_kernel,
354353
group_sizes=local_group_sizes,
355354
preferred_element_type=self.dtype,
356-
tiling=optimal_tiling_gate,
355+
tiling=static_tiling_gate,
357356
)
358357

359358
# activation
@@ -366,7 +365,7 @@ def _gmm_compute_with_sharded_weights(
366365
rhs=wo_kernel,
367366
group_sizes=local_group_sizes,
368367
preferred_element_type=self.dtype,
369-
tiling=optimal_tiling_down,
368+
tiling=static_tiling_down,
370369
)
371370

372371
return intermediate_output

python/sgl_jax/srt/managers/tp_worker_overlap_thread.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,5 +248,8 @@ def run_precompile(self):
248248
def run_gmm_auto_tune(self):
249249
self.worker.run_gmm_auto_tune()
250250

251+
def get_gmm_tiling_configs(self):
252+
return self.worker.get_gmm_tiling_configs()
253+
251254
def __delete__(self):
252255
self.input_queue.put((None, None, None, None))

0 commit comments

Comments
 (0)