
Conversation

huanghua1994
Collaborator

Description

This is a draft PR to share some work in progress and start a discussion.

Recently we used TE/JAX's grouped_gemm() interface for inference of a MoE model. Nsys shows a GPU bubble while grouped_gemm() copies the group_sizes array from device to host. This was a known issue when we designed the grouped_gemm() interface: its performance impact on training and the inference prefill stage is relatively small, but it cannot be ignored in the inference decode stage. This draft aims to partially address the bubble issue.

Our target model uses MLP-MoE, i.e., each expert is an MLP layer. After fusing GEMMs, each MLP-MoE layer needs two grouped_gemm() calls with the same group_sizes array. This PR allows issuing an async D2H copy of the group_sizes array before entering grouped_gemm(), so that grouped_gemm() can reuse the already-downloaded group_sizes. We have validated the correctness of this implementation on our target model.
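To make the intended usage concrete, here is a minimal sketch of a fused MLP-MoE layer that issues the copy once and reuses it in both grouped GEMMs. Only the function names and the new flag come from this PR; the import path and the simplified grouped_gemm() signature are assumptions for illustration:

```python
# Minimal sketch, not the exact TE/JAX API: the import path and the
# simplified grouped_gemm() signature below are assumptions.
import jax
from transformer_engine.jax.cpp_extensions import (
    grouped_gemm,                   # existing grouped GEMM entry point
    grouped_gemm_copy_group_sizes,  # async D2H copy of group_sizes (this PR)
)

def mlp_moe_layer(tokens, w_up, w_down, group_sizes):
    # Start the async D2H copy of group_sizes once, before the grouped GEMMs.
    grouped_gemm_copy_group_sizes(group_sizes)

    # Both grouped GEMMs of the fused MLP share the same group_sizes, so they
    # can reuse the already-downloaded host copy instead of each blocking on
    # its own D2H copy.
    h = grouped_gemm(tokens, w_up, group_sizes, use_async_d2h_group_sizes=True)
    h = jax.nn.gelu(h)
    return grouped_gemm(h, w_down, group_sizes, use_async_d2h_group_sizes=True)
```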

This PR does not solve the issue of grouped_gemm() breaking CUDA graph capture, since the async copy mode still needs to call cudaEventSynchronize(). Furthermore, in our implementation for the target model, the D2H memcpy does not overlap with the operations that copy and dispatch tokens to experts, because those JAX-native operations are captured and executed in a CUDA graph, while the async D2H copy does not support CUDA graphs.

@phu0ngng @mingxu1067 Please let me know your comments and suggestions. Much appreciated!

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added GroupedGemmCopySizesPrimitive for asynchronously copying group_sizes from GPU to host
  • Added an optional argument use_async_d2h_group_sizes to grouped_gemm(); it defaults to False, so the original code path is used by default

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

@phu0ngng phu0ngng left a comment

I think it is a good improvement for now.

We should probably provide a GroupedLayerNormMLP VJP op that encloses the grouped_gemm_copy_group_sizes function and the use_async_d2h_group_sizes option, so that we don't expose these two to users, as they can be pretty bug-prone.
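For illustration only, a rough sketch of the shape such an enclosing op could take; grouped_gemm, grouped_gemm_copy_group_sizes, and use_async_d2h_group_sizes are the names from this PR, while the wrapper's name, its signature, and the inline layernorm are placeholders:

```python
# Sketch of a fused op that hides the group_sizes copy and the async flag.
# Placeholder code, not the TE/JAX API: grouped_gemm and
# grouped_gemm_copy_group_sizes are assumed imported from TE/JAX.
import jax
import jax.numpy as jnp

def _layernorm(x, gamma, beta, eps=1e-6):
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mu) * jax.lax.rsqrt(var + eps) * gamma + beta

def grouped_layernorm_mlp(tokens, gamma, beta, w_up, w_down, group_sizes):
    x = _layernorm(tokens, gamma, beta)
    # Issue the D2H copy once, internally; users never see the knob.
    grouped_gemm_copy_group_sizes(group_sizes)
    h = grouped_gemm(x, w_up, group_sizes, use_async_d2h_group_sizes=True)
    h = jax.nn.gelu(h)
    return grouped_gemm(h, w_down, group_sizes, use_async_d2h_group_sizes=True)
```

Presumably the real op would be a custom-VJP primitive so the backward grouped GEMMs can reuse the same downloaded group_sizes as well.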

"supported number ", max_num_gemms, " to be downloaded in advance.");
host_num_gemms = num_gemms;
// Wait for current compute stream to finish
cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
Collaborator

@mingxu1067 could you check if this causes the same stream sync issue as last time, when we used compute_stream(0) instead of the stream given by XLA?

Comment on lines +291 to +294
auto init = [&]() {
NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
};
Collaborator

If this causes any issues, we could consider moving this allocation into the FFI prepare phase.
