|
1 | 1 | import torch |
| 2 | +import triton |
| 3 | +import collections |
2 | 4 | from lightllm.common.mem_manager import MemoryManager |
3 | 5 | from lightllm.common.req_manager import ReqManager |
4 | 6 | from lightllm.distributed import CustomProcessGroup |
5 | | -from typing import Tuple, Any, Optional |
| 7 | +from typing import Tuple, Any, Optional, List |
6 | 8 | from .triton_kernel.gen_prefill_params import gen_prefill_params |
7 | 9 | from .triton_kernel.gen_decode_params import gen_decode_params |
8 | 10 | from .triton_kernel.multimodal_emb import mark_multimodal_obj |
9 | 11 | from .batch_objs import ModelInput |
| 12 | +from lightllm.utils.envs_utils import get_env_start_args |
| 13 | +from lightllm.utils.dist_utils import get_global_dp_rank |
10 | 14 |
|
11 | 15 |
|
12 | 16 | class InferStateInfo: |
@@ -36,7 +40,6 @@ def __init__(self): |
36 | 40 | self.req_manager: ReqManager = None |
37 | 41 |
|
38 | 42 | self.mem_index: torch.Tensor = None |
39 | | - self.kv_buffer_shapedtype: Tuple[Any, Any] = None |
40 | 43 |
|
41 | 44 | self.is_token_healing: bool = False |
42 | 45 | self.return_all_prompt_logics: bool = False |
@@ -69,6 +72,18 @@ def __init__(self): |
69 | 72 | # 的输入会用到,其他模型和场景都不会用到 |
70 | 73 | self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None |
71 | 74 |
|
| 75 | + # 在单节点多dp的运行模式下,在进行prefill的阶段,如果出现了dp之间数据不平衡的现象, |
| 76 | + # 可以将推理的数据,进行重新分配到各个dp,在做 att 之前,重新 all to all 到各自的 |
| 77 | + # dp,计算完成后,再 all to all 回去,这样可以使,各个dp 间处理的数据比较均衡,提升 |
| 78 | + # prefill时候的计算效率。下面的变量,都是在这种场景下才会被使用的变量,普通情况下 |
| 79 | + # 下面的变量不会被使用。 |
| 80 | + self.need_dp_prefill_balance: bool = False |
| 81 | + self.dp_origin_lens: List[int] = None |
| 82 | + self.dp_handle_lens: List[int] = None |
| 83 | + # self.dp_input_lens: torch.Tensor = None |
| 84 | + self.dp_output_split_sizes: List[List[int]] = None |
| 85 | + self.dp_input_split_sizes: List[List[int]] = None |
| 86 | + |
72 | 87 | def init_some_extra_state(self, model, input_ids: torch.Tensor): |
73 | 88 | if self.is_prefill: |
74 | 89 | ( |
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor): |
123 | 138 | for mark, obj in zip(marks_array, multi_objs): |
124 | 139 | obj["_prefill_"] = mark > 0 |
125 | 140 | return |
| 141 | + |
    def prefill_dp_balance(self, input_ids: torch.Tensor):
        """
        During prefill in multi-dp mode, redistribute the input tokens across dp ranks so that
        each rank handles a near-equal share. This mitigates the prefill performance loss that
        occurs when the per-dp input lengths are highly uneven.

        Side effects: sets need_dp_prefill_balance, dp_origin_lens, dp_handle_lens,
        dp_input_split_sizes and dp_output_split_sizes; rebalances position_ids /
        position_cos / position_sin in place (keeping the unbalanced originals in
        _unbalance_position_* attributes) when those attributes exist.

        Returns the rebalanced input_ids for this rank.
        """
        assert self.is_prefill
        import torch.distributed as dist

        self.need_dp_prefill_balance = True

        args = get_env_start_args()

        # Gather every dp rank's input length so all ranks can compute the same plan.
        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
        input_len.fill_(len(input_ids))
        dist.all_gather_into_tensor(
            output_tensor=dp_input_lens,
            input_tensor=input_len,
            group=self.dist_group.dp_prefill_balance_group,
            async_op=False,
        )
        dp_input_lens = dp_input_lens.detach().cpu()
        self.dp_origin_lens = dp_input_lens.tolist()
        sum_input_len = dp_input_lens.sum().item()
        # Target per-rank workload: total tokens split as evenly as possible,
        # with the first (sum % dp) ranks taking one extra token.
        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
        for i in range(sum_input_len % args.dp):
            dp_handle_lens[i] += 1

        # Keep a copy: dp_handle_lens is consumed (decremented) by the planning below.
        self.dp_handle_lens = dp_handle_lens.copy()

        dest_dp_inputs = [[] for _ in range(args.dp)]
        # Assign each dp rank's original input and its post-balancing share.
        # Segments are (origin_dp_index, start, end) ranges over the origin rank's tokens.
        origin_datas = collections.deque()
        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
            handle_len = dp_handle_lens[origin_dp_index]
            if origin_dp_input_len > handle_len:
                # Rank has surplus tokens: keep its own quota locally, queue the rest
                # for redistribution to under-loaded ranks.
                origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
                dp_handle_lens[origin_dp_index] = 0
                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
            else:
                # Rank is at or under quota: keep everything locally; remaining quota
                # will be filled from other ranks' surplus below.
                dp_handle_lens[origin_dp_index] -= origin_dp_input_len
                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))

        # Greedily hand surplus segments (FIFO) to ranks that still have unfilled quota,
        # splitting a segment when it is larger than the remaining need.
        for dest_dp_index in range(args.dp):
            need_size = dp_handle_lens[dest_dp_index]
            if need_size == 0:
                continue
            while len(origin_datas) != 0:
                origin_data = origin_datas.popleft()
                origin_dp_index, start, end = origin_data
                if end - start > need_size:
                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
                    break
                else:
                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
                    need_size -= end - start
                    if need_size == 0:
                        break

        # Convert the segment plan into per-rank all_to_all split-size matrices:
        # dp_output_split_sizes[d][o] = tokens rank d receives from rank o,
        # dp_input_split_sizes[o][d] = tokens rank o sends to rank d.
        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
            for origin_dp_index, start, end in dest_dp_data:
                dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
            for origin_dp_index, start, end in dest_dp_data:
                dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start

        self.dp_input_split_sizes = dp_input_split_sizes
        self.dp_output_split_sizes = dp_output_split_sizes

        new_input_ids = self._all_to_all_balance_get(input_ids)
        if hasattr(self, "position_ids") and self.position_ids is not None:
            # deepseekv2 (MLA) models need to keep the original position_ids
            # to reduce communication volume later.
            self._unbalance_position_ids = self.position_ids

            self.position_ids = self._all_to_all_balance_get(self.position_ids)
        if hasattr(self, "position_cos") and self.position_cos is not None:
            # deepseekv2 (MLA) models need to keep the original position_cos
            # to reduce communication volume later.
            self._unbalance_position_cos = self.position_cos

            self.position_cos = self._all_to_all_balance_get(self.position_cos)
        if hasattr(self, "position_sin") and self.position_sin is not None:
            # deepseekv2 (MLA) models need to keep the original position_sin
            # to reduce communication volume later.
            self._unbalance_position_sin = self.position_sin

            self.position_sin = self._all_to_all_balance_get(self.position_sin)

        return new_input_ids
| 232 | + |
| 233 | + def _all_to_all_balance_get(self, data: torch.Tensor): |
| 234 | + dp_rank = get_global_dp_rank() |
| 235 | + import torch.distributed as dist |
| 236 | + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager |
| 237 | + |
| 238 | + old_shape = data.shape |
| 239 | + data = data.view(-1) |
| 240 | + |
| 241 | + origin_len = self.dp_origin_lens[dp_rank] |
| 242 | + assert data.shape[0] % origin_len == 0 |
| 243 | + scale_size = data.shape[0] // origin_len |
| 244 | + handle_len = self.dp_handle_lens[dp_rank] |
| 245 | + |
| 246 | + dest_data = g_cache_manager.alloc_tensor( |
| 247 | + shape=(handle_len * scale_size,), |
| 248 | + data_type=data.dtype, |
| 249 | + device="cuda", |
| 250 | + is_graph_out=False, |
| 251 | + microbatch_index=self.microbatch_index, |
| 252 | + ) |
| 253 | + dist.all_to_all_single( |
| 254 | + output=dest_data.view(-1), |
| 255 | + input=data.view(-1), |
| 256 | + output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]], |
| 257 | + input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]], |
| 258 | + group=self.dist_group.dp_prefill_balance_group, |
| 259 | + async_op=False, |
| 260 | + ) |
| 261 | + return dest_data.view(-1, *old_shape[1:]) |
| 262 | + |
| 263 | + def _all_to_all_unbalance_get(self, data: torch.Tensor): |
| 264 | + dp_rank = get_global_dp_rank() |
| 265 | + import torch.distributed as dist |
| 266 | + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager |
| 267 | + |
| 268 | + old_shape = data.shape |
| 269 | + data = data.view(-1) |
| 270 | + |
| 271 | + handle_len = self.dp_handle_lens[dp_rank] |
| 272 | + scale_size = data.shape[0] // handle_len |
| 273 | + assert data.shape[0] % handle_len == 0 |
| 274 | + origin_len = self.dp_origin_lens[dp_rank] |
| 275 | + origin_data = g_cache_manager.alloc_tensor( |
| 276 | + shape=(origin_len * scale_size,), |
| 277 | + data_type=data.dtype, |
| 278 | + device="cuda", |
| 279 | + is_graph_out=False, |
| 280 | + microbatch_index=self.microbatch_index, |
| 281 | + ) |
| 282 | + dist.all_to_all_single( |
| 283 | + output=origin_data.view(-1), |
| 284 | + input=data, |
| 285 | + output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]], |
| 286 | + input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]], |
| 287 | + group=self.dist_group.dp_prefill_balance_group, |
| 288 | + async_op=False, |
| 289 | + ) |
| 290 | + return origin_data.view(-1, *old_shape[1:]) |
0 commit comments