[JAX][Draft] Async issuing D2H memcpy for grouped_gemm group_sizes array #2213
Description
This is a draft PR to share some work in progress and start a discussion.

Recently we used TE/JAX's `grouped_gemm()` interface for a MoE model's inference. Nsys shows a GPU bubble while `grouped_gemm()` copies the `group_sizes` array from device to host. This was a known issue when we designed the `grouped_gemm()` interface. Its performance impact on training and on the inference prefill stage is relatively small, but it cannot be ignored in the inference decode stage. This draft aims to partially address the bubble.

Our target model uses MLP-MoE, i.e., each expert is an MLP layer. After fusing GEMMs, each MLP-MoE layer needs two `grouped_gemm()` calls with the same `group_sizes` array. This PR allows issuing an async D2H copy of the `group_sizes` array before entering `grouped_gemm()`, so that `grouped_gemm()` can reuse the already-downloaded `group_sizes`. We have validated the correctness of this implementation on our target model.

This PR does not solve the problem of `grouped_gemm()` breaking CUDA graph capture, since the async copy mode still has to call `cudaEventSynchronize()`. Furthermore, in our implementation for the target model, the D2H memcpy does not overlap with the other operations that copy and dispatch tokens to experts, because those JAX-native operations are captured and executed in a CUDA graph, while the async D2H copy does not support CUDA graph.

@phu0ngng @mingxu1067 Please let me know your comments and suggestions. Much appreciated!
Type of change
Changes
- Added `GroupedGemmCopySizesPrimitive` for async copying of `group_sizes` from GPU to host.
- Added a new `use_async_d2h_group_sizes` argument to `grouped_gemm()`; the default value is `False`, so the original code path is used unless the caller opts in (see the short example after this list).
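A short illustration of the new argument (the `grouped_gemm()` signature is abbreviated; only `use_async_d2h_group_sizes` is the addition from this PR):

```python
# Default path: grouped_gemm() performs its own blocking D2H copy of group_sizes.
out = grouped_gemm(x, w, group_sizes)

# Opt-in path: reuse the group_sizes already downloaded by the async copy
# issued earlier via GroupedGemmCopySizesPrimitive.
out = grouped_gemm(x, w, group_sizes, use_async_d2h_group_sizes=True)
```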
Checklist: