tpsp mode: support dp prefill balance. #1086
Merged
Commits (16):
- 6d0376d fix tpsp get o
- 1d0da2d dp balance cache data
- 36ecad8 fix
- 8f23444 fix
- c47a76e fix
- 645baa1 fix
- b99deda remove _pre_cache_kv
- 5ac8bfa fix (hiworldwzj)
- 5d6fe53 fix (hiworldwzj)
- 08cc488 fix deepseekv2 balance (hiworldwzj)
- b75b3c1 fix (hiworldwzj)
- 6180581 improve deepseekv2 (hiworldwzj)
- 815f5ac fix
- 5e2dfc9 fix
- f68a306 fix (hiworldwzj)
- fbc1648 fix typing (hiworldwzj)
Files changed:

```diff
@@ -1,12 +1,16 @@
 import torch
 import triton
+import collections
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
-from typing import Tuple, Any, Optional
+from typing import Tuple, Any, Optional, List
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
 from .batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_global_dp_rank


 class InferStateInfo:
```
```diff
@@ -36,7 +40,6 @@ def __init__(self):
         self.req_manager: ReqManager = None

         self.mem_index: torch.Tensor = None
         self.kv_buffer_shapedtype: Tuple[Any, Any] = None

         self.is_token_healing: bool = False
         self.return_all_prompt_logics: bool = False
```
```diff
@@ -69,6 +72,18 @@ def __init__(self):
         # ... used as input for the draft model; no other model or scenario uses it
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

+        # In single-node multi-dp mode, prefill inputs can become unbalanced across dp
+        # ranks. When that happens the data can be redistributed: before attention an
+        # all-to-all moves tokens to their assigned dp ranks, and afterwards another
+        # all-to-all moves the results back, keeping the per-dp workload roughly even
+        # and improving prefill efficiency. These variables are only used in that case.
+        self.need_dp_prefill_balance: bool = False
+        self.dp_origin_lens: List[int] = None
+        self.dp_handle_lens: List[int] = None
+        # self.dp_input_lens: torch.Tensor = None
+        self.dp_output_split_sizes: List[List[int]] = None
+        self.dp_input_split_sizes: List[List[int]] = None
+
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (
```
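As a quick illustration of the even split described in the new comment block, here is a minimal standalone sketch (the lengths are made-up values, not from this PR): the total token count is divided by the dp size, and the remainder is spread one token at a time over the leading ranks.

```python
# Minimal sketch of the per-dp target ("handle") lengths: an even split of the
# total token count, with the remainder spread over the leading ranks.
# The input lengths are made-up illustration values.
dp = 4
dp_origin_lens = [9, 3, 6, 3]        # tokens currently held by each dp rank

total = sum(dp_origin_lens)          # 21
dp_handle_lens = [total // dp] * dp  # [5, 5, 5, 5]
for i in range(total % dp):          # 21 % 4 == 1, so rank 0 takes one extra token
    dp_handle_lens[i] += 1

print(dp_handle_lens)                # [6, 5, 5, 5]: per-rank load differs by at most 1
```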
```diff
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
         for mark, obj in zip(marks_array, multi_objs):
             obj["_prefill_"] = mark > 0
         return
+
+    def prefill_dp_balance(self, input_ids: torch.Tensor):
+        """
+        During prefill in dp mode, rebalance and redistribute the input data across
+        the dp ranks, to reduce the prefill performance loss caused by heavily
+        unbalanced per-dp workloads.
+        """
+        assert self.is_prefill
+        import torch.distributed as dist
+
+        self.need_dp_prefill_balance = True
+
+        args = get_env_start_args()
+
+        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
+        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
+        input_len.fill_(len(input_ids))
+        dist.all_gather_into_tensor(
+            output_tensor=dp_input_lens,
+            input_tensor=input_len,
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        dp_input_lens = dp_input_lens.detach().cpu()
+        self.dp_origin_lens = dp_input_lens.tolist()
+        sum_input_len = dp_input_lens.sum().item()
+        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
+        for i in range(sum_input_len % args.dp):
+            dp_handle_lens[i] += 1
+
+        self.dp_handle_lens = dp_handle_lens.copy()
+
+        dest_dp_inputs = [[] for _ in range(args.dp)]
+        # Split each dp rank's original input into the shares assigned after balancing.
+        origin_datas = collections.deque()
+        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
+            handle_len = dp_handle_lens[origin_dp_index]
+            if origin_dp_input_len > handle_len:
+                origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
+                dp_handle_lens[origin_dp_index] = 0
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
+            else:
+                dp_handle_lens[origin_dp_index] -= origin_dp_input_len
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))
+
+        for dest_dp_index in range(args.dp):
+            need_size = dp_handle_lens[dest_dp_index]
+            if need_size == 0:
+                continue
+            while len(origin_datas) != 0:
+                origin_data = origin_datas.popleft()
+                origin_dp_index, start, end = origin_data
+                if end - start > need_size:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
+                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
+                    break
+                else:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
+                    need_size -= end - start
+                    if need_size == 0:
+                        break
+
+        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
+        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start
+
+        self.dp_input_split_sizes = dp_input_split_sizes
+        self.dp_output_split_sizes = dp_output_split_sizes
+
+        new_input_ids = self._all_to_all_balance_get(input_ids)
+        if hasattr(self, "position_ids") and self.position_ids is not None:
+            # The special deepseekv2 mla models need to keep the original position_ids
+            # to reduce communication volume.
+            self._unbalance_position_ids = self.position_ids
+            self.position_ids = self._all_to_all_balance_get(self.position_ids)
+        if hasattr(self, "position_cos") and self.position_cos is not None:
+            # The special deepseekv2 mla models need to keep the original position_cos
+            # to reduce communication volume.
+            self._unbalance_position_cos = self.position_cos
+            self.position_cos = self._all_to_all_balance_get(self.position_cos)
+        if hasattr(self, "position_sin") and self.position_sin is not None:
+            # The special deepseekv2 mla models need to keep the original position_sin
+            # to reduce communication volume.
+            self._unbalance_position_sin = self.position_sin
+            self.position_sin = self._all_to_all_balance_get(self.position_sin)
+
+        return new_input_ids
```
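To make the greedy reassignment above concrete, here is a condensed single-process re-implementation of the same logic with made-up lengths (the `min`-based take below is equivalent to the `>` / `break` branching in the diff). It prints the two split tables that drive the all-to-all helpers; those helpers follow.

```python
# Condensed single-process walk-through of the greedy rebalance with made-up
# lengths. Each rank keeps a prefix of its own tokens up to its target length;
# surplus token ranges are queued and handed out to ranks that still need work.
import collections

dp = 4
dp_origin_lens = [9, 3, 6, 2]                 # illustration values, sum = 20
total = sum(dp_origin_lens)
need = [total // dp] * dp
for i in range(total % dp):
    need[i] += 1                              # targets: [5, 5, 5, 5]

dest_dp_inputs = [[] for _ in range(dp)]      # per rank: (origin rank, start, end)
surplus = collections.deque()
for r, n in enumerate(dp_origin_lens):
    keep = min(n, need[r])
    dest_dp_inputs[r].append((r, 0, keep))
    if n > keep:
        surplus.append((r, keep, n))          # leftover range of rank r's tokens
    need[r] -= keep

for d in range(dp):
    while need[d] and surplus:
        r, start, end = surplus.popleft()
        take = min(need[d], end - start)
        dest_dp_inputs[d].append((r, start, start + take))
        if start + take < end:
            surplus.appendleft((r, start + take, end))
        need[d] -= take

# Row d, column o: how many tokens rank d receives from rank o; the transpose
# says how many tokens rank o sends to rank d.
out_split = [[0] * dp for _ in range(dp)]
for d, segs in enumerate(dest_dp_inputs):
    for o, s, e in segs:
        out_split[d][o] += e - s
in_split = [list(col) for col in zip(*out_split)]

print(out_split)  # [[5, 0, 0, 0], [2, 3, 0, 0], [0, 0, 5, 0], [2, 0, 1, 2]]
print(in_split)   # [[5, 2, 0, 2], [0, 3, 0, 0], [0, 0, 5, 1], [0, 0, 0, 2]]
```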
```diff
+    def _all_to_all_balance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        origin_len = self.dp_origin_lens[dp_rank]
+        assert data.shape[0] % origin_len == 0
+        scale_size = data.shape[0] // origin_len
+        handle_len = self.dp_handle_lens[dp_rank]
+
+        dest_data = g_cache_manager.alloc_tensor(
+            shape=(handle_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=dest_data.view(-1),
+            input=data.view(-1),
+            output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return dest_data.view(-1, *old_shape[1:])
```
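A real run of `_all_to_all_balance_get` needs an NCCL process group, so as a sketch here is a single-process simulation of the `dist.all_to_all_single` exchange driven by the split tables from the previous example (token ids are made up, and `scale_size` is taken as 1, i.e. one element per token). Note that each receiving rank concatenates incoming chunks in source-rank order, which is why `dp_output_split_sizes[dp_rank]` is indexed by origin rank. The inverse helper follows.

```python
# Single-process simulation (no process group required) of the all_to_all_single
# call in _all_to_all_balance_get, with scale_size == 1 and made-up token ids.
# Rank r sends in_split[r][d] elements to each rank d and receives
# out_split[r][o] elements from each rank o, concatenated in source-rank order.
import torch

dp = 4
dp_origin_lens = [9, 3, 6, 2]
in_split = [[5, 2, 0, 2], [0, 3, 0, 0], [0, 0, 5, 1], [0, 0, 0, 2]]

# Globally increasing ids so each token's origin stays visible.
starts = [sum(dp_origin_lens[:r]) for r in range(dp)]
inputs = [torch.arange(starts[r], starts[r] + dp_origin_lens[r]) for r in range(dp)]

# Cut every rank's flat input into the chunks it sends to each destination.
send = []
for r in range(dp):
    chunks, off = [], 0
    for d in range(dp):
        chunks.append(inputs[r][off:off + in_split[r][d]])
        off += in_split[r][d]
    send.append(chunks)

# Each rank's balanced buffer: concatenation of what every rank sent to it.
balanced = [torch.cat([send[o][d] for o in range(dp)]) for d in range(dp)]
for d, buf in enumerate(balanced):
    print(d, buf.tolist())
# 0 [0, 1, 2, 3, 4]
# 1 [5, 6, 9, 10, 11]
# 2 [12, 13, 14, 15, 16]
# 3 [7, 8, 17, 18, 19]
```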
```diff
+    def _all_to_all_unbalance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        handle_len = self.dp_handle_lens[dp_rank]
+        scale_size = data.shape[0] // handle_len
+        assert data.shape[0] % handle_len == 0
+        origin_len = self.dp_origin_lens[dp_rank]
+        origin_data = g_cache_manager.alloc_tensor(
+            shape=(origin_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=origin_data.view(-1),
+            input=data,
+            output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return origin_data.view(-1, *old_shape[1:])
```
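`_all_to_all_unbalance_get` simply swaps the two split-size tables, which makes it the exact inverse of the balanced exchange: every rank gets back precisely the tokens it originally held, in their original order. A self-contained sanity check of that property, using the same single-process simulation and made-up sizes:

```python
# Self-contained check that swapping the split tables inverts the exchange,
# mirroring how _all_to_all_unbalance_get undoes _all_to_all_balance_get.
import torch

def all_to_all_sim(inputs, split):
    """Simulate dist.all_to_all_single on one process: rank r sends
    split[r][d] elements to rank d; receivers concatenate in rank order."""
    dp = len(inputs)
    send = []
    for r in range(dp):
        chunks, off = [], 0
        for d in range(dp):
            chunks.append(inputs[r][off:off + split[r][d]])
            off += split[r][d]
        send.append(chunks)
    return [torch.cat([send[o][d] for o in range(dp)]) for d in range(dp)]

in_split = [[5, 2, 0, 2], [0, 3, 0, 0], [0, 0, 5, 1], [0, 0, 0, 2]]
out_split = [list(col) for col in zip(*in_split)]   # transpose

origin = [torch.arange(9), torch.arange(9, 12), torch.arange(12, 18), torch.arange(18, 20)]
balanced = all_to_all_sim(origin, in_split)
restored = all_to_all_sim(balanced, out_split)      # reverse exchange
assert all(torch.equal(a, b) for a, b in zip(origin, restored))
print("round trip ok")
```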
Review comment:
The `import torch.distributed as dist` statement is local to the `prefill_dp_balance` method. It would be better to move it to the top of the file for consistency and to avoid the repeated import on every call. The same applies to the other local imports in `_all_to_all_balance_get` and `_all_to_all_unbalance_get`.
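A sketch of the change the reviewer is suggesting (hoisting the repeated local import to module scope; the two helper methods would drop their copies the same way):

```diff
+import torch.distributed as dist
 import torch
 import triton
 ...
     def prefill_dp_balance(self, input_ids: torch.Tensor):
         assert self.is_prefill
-        import torch.distributed as dist
```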