[llama4] enable expert parallel on the same device mesh as tp (tp2ep) #1269
base: main
Conversation
Thank you for the PR! I'll take a look.
Thank you very much for the PR! I think the idea sounds very interesting.
I have some high-level questions:
- Compared with the PR you refer to ([MoE][PoC] Expert Parallel: tp and tp2ep #731), this tp2ep implementation is all-to-all based rather than all-gather / reduce-scatter based. Do you have any idea which is more efficient, assuming both are correct? (See the sketch after this list for a rough picture of the alternative.)
- Personally I think the implementation itself is a bit too intrusive into the model code, whereas torchtitan tries to avoid that (https://github.com/pytorch/torchtitan/blob/main/README.md?plain=1#L38). Do you think there is a chance you could make it cleaner?
- Do you have any test results showing that your implementation is correct, e.g. loss curves compared with single-device training?
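For reference on the first question, here is a minimal sketch of what the all-gather / reduce-scatter style of tp2ep could look like. It is illustrative only, not code from either PR: `ep_group` and `local_experts` are assumed names, `local_experts` is assumed to zero out tokens not routed to this rank's experts, and the import comes from a private PyTorch module whose path may change across versions.

```python
# Illustrative only: an all-gather / reduce-scatter token path for comparison
# with the all-to-all path used in this PR. `ep_group` and `local_experts`
# are assumed names, not identifiers from torchtitan.
import torch
from torch.distributed._functional_collectives import (
    all_gather_tensor,
    reduce_scatter_tensor,
)

def moe_allgather_forward(local_tokens: torch.Tensor, local_experts, ep_group) -> torch.Tensor:
    # Replicate every rank's token shard so each rank can run its own subset
    # of experts over the full set of tokens.
    full_tokens = all_gather_tensor(local_tokens, gather_dim=0, group=ep_group)
    # Each rank produces partial outputs: zeros for tokens its experts did not process.
    partial_out = local_experts(full_tokens)
    # Sum the partial outputs across EP ranks and return to the sharded layout.
    return reduce_scatter_tensor(partial_out, reduceOp="sum", scatter_dim=0, group=ep_group)
```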
"moe": | ||
PrepareModuleInputOutput( | ||
input_layouts=(Shard(1), ), | ||
desired_input_layouts=(Shard(1), ), |
If I understand correctly, the input to the router is sharded. Then this might break the semantics / correctness of the load-balancing algorithm, given that the update to self.tokens_per_expert is local to each EP rank.
https://github.com/pytorch/torchtitan/pull/1269/files#diff-87cc24d85c768f0b3d1f5c54cca39dc9de52ee20e8f601814c3200722901aee5R293
Thank you for pointing out this issue. We need an all_reduce across all EP groups.
Fixed in b87aa1e
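A minimal sketch of the kind of synchronization being described, assuming `tokens_per_expert` is the local per-expert token-count tensor and `ep_group` is the EP process group; this is illustrative, not the exact change in b87aa1e:

```python
# Illustrative sketch: make the load-balancing statistics global by summing
# the per-expert token counts over all EP ranks. `ep_group` is an assumed name.
import torch
import torch.distributed as dist

def sync_tokens_per_expert(tokens_per_expert: torch.Tensor, ep_group: dist.ProcessGroup) -> torch.Tensor:
    # In-place sum across EP ranks so every rank sees the same global histogram.
    dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.SUM, group=ep_group)
    return tokens_per_expert
```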
This PR is built on top of the concept introduced in #731.
In this implementation, the input to the MoE module is sharded along the seqlen dimension rather than being replicated. After gathering tokens from different EP ranks using all_to_all_single_autograd, the output tokens remain sharded along the seqlen dimension.
To activate this feature, set enable_tp2ep = true in the configuration file.
cc @tianyu-l
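For readers unfamiliar with the dispatch pattern, here is a minimal sketch of how tokens sharded along seqlen can be exchanged between EP ranks with all_to_all_single_autograd. The names `ep_group`, `input_splits`, and `output_splits` are placeholders rather than the PR's actual code, the split-size computation is omitted, and the import path is a private PyTorch module that may change across versions.

```python
# Illustrative sketch of the all-to-all dispatch/combine step described above.
# `input_splits[i]` is the number of local tokens routed to EP rank i, and
# `output_splits[i]` is the number of tokens this rank will receive from rank i.
import torch
from torch.distributed._functional_collectives import all_to_all_single_autograd

def dispatch_tokens(tokens: torch.Tensor, input_splits, output_splits, ep_group) -> torch.Tensor:
    # Exchange token rows between EP ranks; the autograd-aware variant lets
    # gradients flow back through the collective during the backward pass.
    return all_to_all_single_autograd(tokens, output_splits, input_splits, ep_group)
```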