tpsp mode support dp prefill balance. #1086
Conversation
Summary of Changes: Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the performance of LightLLM in distributed inference environments, particularly during the prefill stage when operating in data-parallel mode with the TP+SP mixed mode enabled.
Code Review
This pull request introduces a mechanism to balance the workload across data-parallel (DP) ranks during the prefill stage, specifically for the TP+SP mixed mode. This is beneficial when input sequences have highly variable lengths across DP ranks. The implementation involves redistributing input tensors using all_to_all communication. The core logic is added to InferStateInfo and is triggered from the pre_layer_infer modules. My review focuses on the performance implications of the chosen implementation strategy, particularly the frequent use of all_to_all communication within each transformer layer, and some minor maintainability improvements. While the load balancing concept is sound, the per-layer communication overhead might negate the benefits.
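For context on the mechanism, the balancing step first needs a per-rank target length before any tokens can be redistributed. Here is a minimal sketch of how such targets could be derived, assuming the per-rank prefill token counts have already been gathered; the function name `balanced_split` and its signature are hypothetical illustrations, not code from this PR:

```python
import torch

def balanced_split(dp_input_lens: torch.Tensor) -> torch.Tensor:
    # dp_input_lens[i] = number of prefill tokens currently held by dp rank i
    total = int(dp_input_lens.sum())
    dp_size = dp_input_lens.numel()
    base = total // dp_size
    # the first (total % dp_size) ranks take one extra token so the sum is preserved
    target = torch.full((dp_size,), base, dtype=dp_input_lens.dtype)
    target[: total % dp_size] += 1
    return target
```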
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
    q = infer_state._all_to_all_unbalance_get(data=q)
    cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
The code performs an _all_to_all_unbalance_get before the attention calculation and an _all_to_all_balance_get after the attention calculation (in _tpsp_get_o). This introduces two all-to-all communication steps per layer, which can be a significant performance bottleneck. This seems to defeat the purpose of balancing the load for prefill, as the attention, a compute-heavy part, would run on unbalanced data. Please clarify the reasoning behind this design. If the attention kernel does not support the balanced data layout, this should be documented with a code comment explaining the limitation and the workaround.
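To make the round trip concrete, here is a minimal sketch of the kind of variable-length redistribution such helpers would perform with torch.distributed.all_to_all_single; the helper name and the split-size bookkeeping are assumptions for illustration, not the PR's actual implementation:

```python
import torch
import torch.distributed as dist

def all_to_all_redistribute(data, in_split_sizes, out_split_sizes, group=None):
    # in_split_sizes[i]:  rows of `data` this rank sends to peer i
    # out_split_sizes[i]: rows this rank receives from peer i
    out = data.new_empty((sum(out_split_sizes), *data.shape[1:]))
    dist.all_to_all_single(
        out,
        data,
        output_split_sizes=out_split_sizes,
        input_split_sizes=in_split_sizes,
        group=group,
    )
    return out
```

Each such call is a collective across the group, so issuing two per transformer layer makes the communication count scale with model depth, which is the concern raised above.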
    ...prefill inference performance degrades
    """
    assert self.is_prefill
    import torch.distributed as dist
    self.dp_handle_lens = dp_handle_lens
    # partition each dp rank's original input and the portion it handles after rebalancing
    origin_datas = collections.deque()
    for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
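The loop above appears to begin a greedy redistribution plan built from that deque. A hypothetical, self-contained reconstruction of the idea, with invented names (`assign_spans`, `plan`) and assuming the balanced quotas sum to the total input length:

```python
import collections

def assign_spans(dp_input_lens, dp_handle_lens):
    # queue of [origin_dp_index, remaining_len] spans of the original inputs
    origin_datas = collections.deque(
        [dp_index, length] for dp_index, length in enumerate(dp_input_lens)
    )
    plan = [[] for _ in dp_handle_lens]
    # greedily hand spans to each destination rank until its quota is filled
    for dest_dp_index, need in enumerate(dp_handle_lens):
        while need > 0:
            origin_dp_index, remaining = origin_datas[0]
            take = min(need, remaining)
            plan[dest_dp_index].append((origin_dp_index, take))
            need -= take
            if take == remaining:
                origin_datas.popleft()
            else:
                origin_datas[0][1] -= take
    return plan

# e.g. assign_spans([10, 2], [6, 6]) -> [[(0, 6)], [(0, 4), (1, 2)]]
```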
def normal_or_p_d_start(args):
    from lightllm.server.core.objs.start_args_type import StartArgs
lightllm/server/api_start.py (outdated diff)
assert args.router_token_ratio == 0.0

if args.enable_dp_prefill_balance:
    assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly"
The assertion message is a bit vague. A more descriptive message would be helpful for the user, explaining that --enable_dp_prefill_balance has dependencies.
Suggested change:
- assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly"
+ assert args.enable_tpsp_mix_mode and args.dp > 1, "--enable_dp_prefill_balance requires --enable_tpsp_mix_mode and --dp > 1"
    ...ranks 0,1,2,3 form one dp and ranks 4,5,6,7 form another dp; communication
    groups are then created between the pairs [0,4], [1,5], [2,6], [3,7]
    """
    from lightllm.utils.envs_utils import get_env_start_args
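For illustration, a minimal sketch of how such cross-dp communication groups could be created, assuming world_size = dp_size × tp_size and using only torch.distributed.new_group; the function name `build_cross_dp_groups` is hypothetical:

```python
import torch.distributed as dist

def build_cross_dp_groups(world_size: int, dp_size: int):
    # with world_size=8, dp_size=2 -> tp_size=4, pairing [0,4], [1,5], [2,6], [3,7]
    tp_size = world_size // dp_size
    groups = []
    # every rank must execute every new_group call, in the same order
    for tp_index in range(tp_size):
        peers = [dp_index * tp_size + tp_index for dp_index in range(dp_size)]
        groups.append(dist.new_group(ranks=peers))
    return groups
```

Grouping ranks that share the same tensor-parallel index across dp replicas is what lets the balance/unbalance all-to-all exchange tokens without crossing tensor-parallel shards.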