
[llama4] enable expert parallel on the same device mesh as tp (tp2ep) #1269


Open
wants to merge 8 commits into main

Conversation

@hann-wang (Contributor) commented on Jun 6, 2025

This PR is built on top of the concept introduced in #731.

In this implementation, the input to the MoE module is sharded along the seqlen dimension rather than being replicated. After gathering tokens from different EP ranks using all_to_all_single_autograd, the output tokens remain sharded along the seqlen dimension.

To activate this feature, set enable_tp2ep = true in the configuration file.
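(For illustration only: a minimal sketch of the all_to_all-based token exchange described above, not this PR's actual code. It assumes the functional collective all_to_all_single_autograd takes (input, output_split_sizes, input_split_sizes, group); ep_group, input_splits, and output_splits are placeholder names.)

```python
# Hedged sketch of the seqlen-sharded token exchange, assuming the
# torch.distributed._functional_collectives variant of all_to_all_single_autograd.
import torch
from torch.distributed._functional_collectives import all_to_all_single_autograd


def exchange_tokens(
    x: torch.Tensor,           # this rank's shard of routed tokens, [num_local_tokens, dim]
    input_splits: list[int],   # rows of x destined for each EP rank
    output_splits: list[int],  # rows received from each EP rank
    ep_group,                  # process group shared by TP and EP (tp2ep)
) -> torch.Tensor:
    # Tokens are sent to the EP ranks that own their experts; the output stays
    # sharded along seqlen, and gradients flow back through the same collective.
    return all_to_all_single_autograd(x, output_splits, input_splits, ep_group)
```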

cc @tianyu-l

@facebook-github-bot added the CLA Signed label on Jun 6, 2025
@tianyu-l (Contributor) commented on Jun 6, 2025

Thank you for the PR! I'll take a look.

@tianyu-l (Contributor) left a comment

Thank you very much for the PR! I think the idea sounds very interesting.

I have some high-level questions:

  1. Compared with the PR you refer to ([MoE][PoC] Expert Parallel: tp and tp2ep #731), this tp2ep implementation is all-to-all based rather than all-gather / reduce-scatter based. Do you have a sense of which is more efficient, assuming both are correct?
  2. Personally I think the implementation is a bit too intrusive to model code, whereas torchtitan tries to avoid that (https://github.com/pytorch/torchtitan/blob/main/README.md?plain=1#L38). Do you think you could make it cleaner?
  3. Do you have testing showing that the implementation is correct, e.g. loss curves compared with single-device training?

"moe":
PrepareModuleInputOutput(
input_layouts=(Shard(1), ),
desired_input_layouts=(Shard(1), ),
@tianyu-l (Contributor) commented on this diff:
If I understand correctly, the input to router is sharded. Then this might break the semantics / correctness of the load balancing algorithm, given the update to self.tokens_per_expert is local to each EP rank.
https://github.com/pytorch/torchtitan/pull/1269/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R293

@hann-wang (Contributor, Author) replied:
Thank you for pointing out this issue. We need an all_reduce across all EP groups.

Fixed in b87aa1e
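
(A hedged sketch of what such a fix can look like, not the exact contents of b87aa1e; ep_group is a placeholder for the shared TP/EP process group.)

```python
# Keep the load-balancing statistics global rather than per-EP-rank.
import torch
import torch.distributed as dist


def sync_tokens_per_expert(tokens_per_expert: torch.Tensor, ep_group) -> torch.Tensor:
    # Each EP rank routes only its own seqlen shard, so its per-expert token
    # histogram is partial; summing over the EP group restores the global counts
    # before they feed the load-balancing update.
    dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.SUM, group=ep_group)
    return tokens_per_expert
```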

@hann-wang (Contributor, Author) commented:

Replying to the questions above:
  1. When I mentioned [MoE][PoC] Expert Parallel: tp and tp2ep #731, I was referring to sharing the same device mesh between TP and EP. It is indeed possible to create a separate EP mesh in conjunction with TP. If the intermediate dimension of expert weights is relatively small, sharing the same device mesh between TP and EP should be feasible. The sharding of the router in the aforementioned pull request has some issues, as the top-k selection process is constrained to local experts.

  2. The tricky part of EP is that the activations are not evenly split, so we have to determine the split sizes from the top-k router output. I believe implementing a TokenDispatcher could be beneficial, but I haven't found an appropriate place to initialize it; placing it inside ExpertParallel._apply causes a torch.compile failure. (A rough sketch of the split-size computation is included below, after the loss-curve figure.)

  3. I conducted an experiment with the Llama4 debug model, switching the dataset to C4 and running 1000 steps. The loss curve for the TP2EP configuration (blue) aligns well with that of the single-GPU configuration (black).
    (Note: since torch._grouped_mm is not available on the ROCm platform, this experiment uses the cg_grouped_mm kernels in torchtitan/experiments/kernels/triton_contiguous_group_gemm.)

[Figure: loss curves over 1000 steps on C4, TP2EP (blue) vs. single-GPU baseline (black)]
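
(For illustration, a sketch of the split-size computation a TokenDispatcher might encapsulate; top_ids, num_experts, and ep_size are placeholder names, and a contiguous expert-to-rank layout is assumed. This is not code from this PR.)

```python
# Derive uneven all_to_all split sizes from the router's top-k expert indices.
import torch
import torch.distributed as dist


def compute_split_sizes(top_ids: torch.Tensor, num_experts: int, ep_size: int, ep_group):
    # Count routed tokens per expert on this rank.
    tokens_per_expert = torch.bincount(top_ids.flatten(), minlength=num_experts)
    # Fold experts into their owning EP rank (assumes num_experts % ep_size == 0
    # and experts assigned to ranks in contiguous blocks).
    input_splits = tokens_per_expert.view(ep_size, -1).sum(dim=1)
    # Peers decide how many tokens arrive here, so exchange the counts first.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=ep_group)
    return input_splits.tolist(), output_splits.tolist()
```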
