diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 5edf6ad20..77ca299b2 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -290,10 +290,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.req_manager = self.req_manager
 
         infer_state.mem_index = model_input.mem_indexes
-        infer_state.kv_buffer_shapedtype = (
-            (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            self.data_type,
-        )
         infer_state.microbatch_index = microbatch_index
         infer_state.dist_group = dist_group_manager.get_group(microbatch_index)
 
diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py
index 51d84df4f..b966ee043 100755
--- a/lightllm/common/basemodel/infer_struct.py
+++ b/lightllm/common/basemodel/infer_struct.py
@@ -1,12 +1,16 @@
 import torch
+import triton
+import collections
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
-from typing import Tuple, Any, Optional
+from typing import Tuple, Any, Optional, List
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
 from .batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_global_dp_rank
 
 
 class InferStateInfo:
@@ -36,7 +40,6 @@ def __init__(self):
         self.req_manager: ReqManager = None
 
         self.mem_index: torch.Tensor = None
-        self.kv_buffer_shapedtype: Tuple[Any, Any] = None
 
         self.is_token_healing: bool = False
         self.return_all_prompt_logics: bool = False
@@ -69,6 +72,18 @@ def __init__(self):
        # 的输入会用到,其他模型和场景都不会用到
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
+        # In single-node multi-dp mode, the prefill inputs of the dp ranks can become unbalanced. In that
+        # case the inference data can be redistributed across the dp ranks: before the attention computation
+        # the tokens are re-dispatched to their assigned dp with an all-to-all, and after the computation
+        # they are sent back with another all-to-all. This keeps the workload of the dp ranks roughly even
+        # and improves prefill efficiency. The variables below are only used in this scenario.
+        self.need_dp_prefill_balance: bool = False
+        self.dp_origin_lens: List[int] = None
+        self.dp_handle_lens: List[int] = None
+        # self.dp_input_lens: torch.Tensor = None
+        self.dp_output_split_sizes: List[List[int]] = None
+        self.dp_input_split_sizes: List[List[int]] = None
+
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
         for mark, obj in zip(marks_array, multi_objs):
             obj["_prefill_"] = mark > 0
         return
+
+    def prefill_dp_balance(self, input_ids: torch.Tensor):
+        """
+        During prefill in dp mode, re-balance and redistribute the input data across the dp ranks, so that
+        a highly uneven amount of data per dp does not degrade prefill inference performance.
+        """
+        assert self.is_prefill
+        import torch.distributed as dist
+
+        self.need_dp_prefill_balance = True
+
+        args = get_env_start_args()
+
+        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
+        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
+        input_len.fill_(len(input_ids))
+        dist.all_gather_into_tensor(
+            output_tensor=dp_input_lens,
+            input_tensor=input_len,
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        dp_input_lens = dp_input_lens.detach().cpu()
+        self.dp_origin_lens = dp_input_lens.tolist()
+        sum_input_len = dp_input_lens.sum().item()
+        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
+        for i in range(sum_input_len % args.dp):
+            dp_handle_lens[i] += 1
+
+        self.dp_handle_lens = dp_handle_lens.copy()
+
+        dest_dp_inputs = [[] for _ in range(args.dp)]
+        # map each dp's original input to the slices it will handle after redistribution
+        origin_datas = collections.deque()
+        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
+            handle_len = dp_handle_lens[origin_dp_index]
+            if origin_dp_input_len > handle_len:
+                origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
+                dp_handle_lens[origin_dp_index] = 0
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
+            else:
+                dp_handle_lens[origin_dp_index] -= origin_dp_input_len
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))
+
+        for dest_dp_index in range(args.dp):
+            need_size = dp_handle_lens[dest_dp_index]
+            if need_size == 0:
+                continue
+            while len(origin_datas) != 0:
+                origin_data = origin_datas.popleft()
+                origin_dp_index, start, end = origin_data
+                if end - start > need_size:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
+                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
+                    break
+                else:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
+                    need_size -= end - start
+                    if need_size == 0:
+                        break
+
+        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
+        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start
+
+        self.dp_input_split_sizes = dp_input_split_sizes
+        self.dp_output_split_sizes = dp_output_split_sizes
+
+        new_input_ids = self._all_to_all_balance_get(input_ids)
+        if hasattr(self, "position_ids") and self.position_ids is not None:
+            # the deepseekv2 mla model needs to keep the original position_ids to reduce communication volume
+            self._unbalance_position_ids = self.position_ids
+
+            self.position_ids = self._all_to_all_balance_get(self.position_ids)
+        if hasattr(self, "position_cos") and self.position_cos is not None:
+            # the deepseekv2 mla model needs to keep the original position_cos to reduce communication volume
+            self._unbalance_position_cos = self.position_cos
+
+            self.position_cos = self._all_to_all_balance_get(self.position_cos)
+        if hasattr(self, "position_sin") and self.position_sin is not None:
+            # the deepseekv2 mla model needs to keep the original position_sin to reduce communication volume
+            self._unbalance_position_sin = self.position_sin
+
+            self.position_sin = self._all_to_all_balance_get(self.position_sin)
+
+        return new_input_ids
+
+    def _all_to_all_balance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        origin_len = self.dp_origin_lens[dp_rank]
+        assert data.shape[0] % origin_len == 0
+        scale_size = data.shape[0] // origin_len
+        handle_len = self.dp_handle_lens[dp_rank]
+
+        dest_data = g_cache_manager.alloc_tensor(
+            shape=(handle_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=dest_data.view(-1),
+            input=data.view(-1),
+            output_split_sizes=[e * scale_size for e in 
self.dp_output_split_sizes[dp_rank]], + input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]], + group=self.dist_group.dp_prefill_balance_group, + async_op=False, + ) + return dest_data.view(-1, *old_shape[1:]) + + def _all_to_all_unbalance_get(self, data: torch.Tensor): + dp_rank = get_global_dp_rank() + import torch.distributed as dist + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + old_shape = data.shape + data = data.view(-1) + + handle_len = self.dp_handle_lens[dp_rank] + scale_size = data.shape[0] // handle_len + assert data.shape[0] % handle_len == 0 + origin_len = self.dp_origin_lens[dp_rank] + origin_data = g_cache_manager.alloc_tensor( + shape=(origin_len * scale_size,), + data_type=data.dtype, + device="cuda", + is_graph_out=False, + microbatch_index=self.microbatch_index, + ) + dist.all_to_all_single( + output=origin_data.view(-1), + input=data, + output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]], + input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]], + group=self.dist_group.dp_prefill_balance_group, + async_op=False, + ) + return origin_data.view(-1, *old_shape[1:]) diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py index 548165e59..fefb3d162 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py @@ -44,13 +44,12 @@ def _bind_rotary_emb_fwd(self): def _get_qkv( self, input, infer_state: InferStateInfo, layer_weight ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.mm( + cache_kv = torch.mm( input.view(-1, self.embed_dim_), layer_weight.kv_weight_, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + if self.use_qk_norm_: q = q.view(-1, self.tp_q_head_num_, self.head_dim_) k = cache_kv[:, 0 : self.tp_k_head_num_, :] diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 86691b93f..7567bc644 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -30,16 +30,6 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - cache_kv = self.alloc_tensor( - shape=infer_state.kv_buffer_shapedtype[0], - dtype=infer_state.kv_buffer_shapedtype[1], - device="cuda", - is_graph_out=False, - microbatch_index=infer_state.microbatch_index, - ) - return cache_kv - def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: raise Exception("need to impl") diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index eb66ec056..d5c96f821 100644 
--- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -36,6 +36,7 @@ get_global_rank, get_current_rank_in_dp, create_new_group_for_current_dp, + create_dp_special_inter_group, ) from lightllm.utils.device_utils import get_device_sm_count from lightllm.utils.sgl_utils import HAS_SGL_KERNEL @@ -62,6 +63,11 @@ def __init__(self): self.custom_gather = None self.dp_world_size = get_dp_world_size() self.device_group = create_new_group_for_current_dp("nccl") + if get_env_start_args().enable_dp_prefill_balance: + self.dp_prefill_balance_group = create_dp_special_inter_group("nccl") + else: + self.dp_prefill_balance_group = None + self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo") def init_custom_reduce(self) -> None: diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index cf3e5e1f0..8299697f3 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -47,10 +47,7 @@ def _get_qkv( self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> Tuple[torch.Tensor, torch.Tensor]: q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) return q, cache_kv def _context_attention_kernel( diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index a9de83c8f..30d37d1df 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -143,14 +143,6 @@ def _bind_attention(self): Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self ) - def _pre_cache_kv( - self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight - ) -> torch.Tensor: - # q_lora_rank 不是None的时候,融合 q_a_proj 和 kv_a_proj_with_mqa - if self.q_lora_rank is None: - return super()._pre_cache_kv(infer_state, layer_weight) - return None - def _get_qkv( self, input: torch.Tensor, @@ -161,8 +153,7 @@ def _get_qkv( if self.q_lora_rank is None: q = layer_weight.q_weight_.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)) + cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) else: q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 @@ -203,8 +194,25 @@ def _tpsp_get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_weight_.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)) + cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, 
self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + return q, cache_kv else: input = input.view(-1, self.embed_dim_) qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input) @@ -217,25 +225,33 @@ def _tpsp_get_qkv( all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False) qkv = gather_qkv[0 : len(infer_state.position_cos), :] + if infer_state.need_dp_prefill_balance: + qkv = infer_state._all_to_all_unbalance_get(data=qkv) + position_cos = infer_state._unbalance_position_cos + position_sin = infer_state._unbalance_position_sin + else: + position_cos = infer_state.position_cos + position_sin = infer_state.position_sin + q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) - q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - weight=layer_weight.kv_a_layernorm_.weight, - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], - ) - rotary_emb_fwd( - q_rope, - cache_kv[:, :, self.kv_lora_rank :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + position_cos, + position_sin, + ) + return q, cache_kv def _get_o( self, input: torch.Tensor, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight @@ -248,6 +264,9 @@ def _get_o( def _tpsp_get_o( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) + if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) @@ -255,6 +274,10 @@ def _tpsp_get_o( dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) + e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + if e_o_tensor.shape[0] > 0: + e_o_tensor.fill_(0) + if self.tp_world_size_ > 1: sp_token_num = o_tensor.shape[0] // self.tp_world_size_ reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, 
device=input.device) diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index b1f1bee94..09efe9a36 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -87,9 +87,9 @@ def _get_qkv( # kv = kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) k = layer_weight.k_proj.mm(input) v = layer_weight.v_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv[:, 0 : self.tp_k_head_num_, :] = k.view(-1, self.tp_k_head_num_, self.head_dim_) - cache_kv[:, self.tp_k_head_num_ :, :] = v.view(-1, self.tp_v_head_num_, self.head_dim_) + cache_kv = torch.cat( + [k.view(-1, self.tp_k_head_num_, self.head_dim_), v.view(-1, self.tp_v_head_num_, self.head_dim_)], dim=1 + ) # gemma3 use qk norm q = q.view(-1, self.tp_q_head_num_, self.head_dim_) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 202df6969..711f63021 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -116,6 +116,9 @@ def tpsp_token_forward( # len(infer_state.position_sin) 获取真实输入长度 input_embdings = gather_data[0 : len(infer_state.position_sin)] + if infer_state.need_dp_prefill_balance: + input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) + return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight) def overlap_tpsp_token_forward( @@ -130,12 +133,18 @@ def overlap_tpsp_token_forward( infer_state.hook() infer_state.hook = None + if infer_state.need_dp_prefill_balance: + input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings) + logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() infer_state1.hook = None + if infer_state1.need_dp_prefill_balance: + input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1) + logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) return logics, logics1 diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index a314a28d0..99b7db5bf 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -9,6 +9,7 @@ from lightllm.utils.infer_utils import mark_cost_time from lightllm.models.llama.triton_kernel.embedding import embedding from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args class LlamaPreLayerInfer(PreLayerInferTpl): @@ -42,6 +43,9 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh def tpsp_context_forward( self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight ): + if get_env_start_args().enable_dp_prefill_balance: + input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) + input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy @@ -86,12 +90,17 @@ def overlap_tpsp_context_forward( infer_state1: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight, 
): + if get_env_start_args().enable_dp_prefill_balance: + input_ids = infer_state.prefill_dp_balance(input_ids=input_ids) input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight) from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_) + if get_env_start_args().enable_dp_prefill_balance: + input_ids1 = infer_state1.prefill_dp_balance(input_ids=input_ids1) + input_embdings1 = self.context_forward( input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight ) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 268ff8cdc..bb38c45bb 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -197,10 +197,7 @@ def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), @@ -222,10 +219,7 @@ def _tpsp_get_qkv( input = gather_input[0 : len(infer_state.position_cos), :] q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), @@ -233,6 +227,11 @@ def _tpsp_get_qkv( infer_state.position_cos, infer_state.position_sin, ) + + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + return q, cache_kv def _context_attention_flashinfer_kernel_fp8( @@ -402,10 +401,16 @@ def _get_o( def _tpsp_get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) + input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :]) + e_o_tensor = o_tensor[len(infer_state.position_cos) :, :] + if e_o_tensor.shape[0] > 0: + e_o_tensor.fill_(0) if self.tp_world_size_ > 1: sp_token_num = o_tensor.shape[0] // self.tp_world_size_ diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index aefb461ec..ce27e3ee5 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py 
+++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -29,10 +29,8 @@ def _bind_attention(self): def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): q = layer_weight.q_proj.mm(input_emb.view(-1, self.embed_dim_)) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) cache_kv = layer_weight.kv_proj.mm( input_emb.view(-1, self.embed_dim_), - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 242e35b1c..7a4b2ca81 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -2,7 +2,7 @@ import torch.functional as F import torch.distributed as dist import numpy as np - +from typing import Tuple from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.qwen.layer_weights.transformer_layer_weight import QwenTransformerLayerWeight @@ -18,10 +18,9 @@ def __init__(self, layer_num, network_config, mode=[]): def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): q = layer_weight.q_proj.mm(input_emb) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input_emb, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input_emb).view( + -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ + ) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), @@ -32,3 +31,7 @@ def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: Qwe if infer_state.logn_values is not None: q.mul_(infer_state.logn_values.view(-1, 1)) return q, cache_kv + + def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + raise Exception("not impl") diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index 63efeded8..27fe466e9 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -2,6 +2,7 @@ import torch.functional as F import torch.distributed as dist import numpy as np +from typing import Tuple from functools import partial from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton @@ -19,10 +20,7 @@ def __init__(self, layer_num, network_config, mode=[]): def _get_qkv(self, input, infer_state, layer_weight): q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) seq_len, _ = q.shape q = q.view(1, seq_len, 
-1, self.head_dim_).transpose(1, 2) self.axis_map = self.axis_map.to(q.device) @@ -32,3 +30,7 @@ def _get_qkv(self, input, infer_state, layer_weight): cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2) return new_q, cache_kv + + def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + raise Exception("not impl") diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 4ae9120e4..68891b6bf 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -31,10 +31,7 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rmsnorm_forward( q.view(-1, self.head_dim_), @@ -56,3 +53,7 @@ def _get_qkv( infer_state.position_sin, ) return q, cache_kv + + def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + raise Exception("not impl") diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index ea2000c41..45f1f59d7 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -14,7 +14,7 @@ from functools import partial from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_global_world_size -from lightllm.distributed.communication_op import all_gather_into_tensor +from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor logger = init_logger(__name__) @@ -45,10 +45,13 @@ def _bind_ffn(self): moe_mode = os.environ.get("MOE_MODE", "TP") if moe_mode == "EP": self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn_edp, self) + self._tpsp_ffn = self._tpsp_ffn_ep else: self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn, self) + self._tpsp_ffn = self._tpsp_ffn_tp else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) + self._tpsp_ffn = self._tpsp_ffn_tp def _get_qkv( self, @@ -58,10 +61,7 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rmsnorm_forward( q.view(-1, self.head_dim_), weight=layer_weight.q_norm_weight_.weight, @@ -99,10 +99,7 @@ def _tpsp_get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) - cache_kv = layer_weight.kv_proj.mm( - input, out=cache_kv.view(-1, (self.tp_k_head_num_ + 
self.tp_v_head_num_) * self.head_dim_) - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rmsnorm_forward( q.view(-1, self.head_dim_), @@ -123,6 +120,11 @@ def _tpsp_get_qkv( infer_state.position_cos, infer_state.position_sin, ) + + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + return q, cache_kv def _moe_ffn( @@ -165,6 +167,45 @@ def _moe_ffn_edp( ep_output = ep_output.view(token_num, hidden_dim) return ep_output + def _tpsp_ffn( + self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight + ): + raise Exception("need bind to real impl") + + def _tpsp_ffn_tp( + self, input: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + if self.tp_world_size_ > 1: + sp_token_num, hidden_dim = input.shape + gather_input = self.alloc_tensor( + (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device + ) + all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False) + input = gather_input + + ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) + + if self.tp_world_size_ > 1: + sp_token_num = ffn2_out.shape[0] // self.tp_world_size_ + reduce_o_tensor = self.alloc_tensor( + (sp_token_num, self.embed_dim_), dtype=ffn2_out.dtype, device=ffn2_out.device + ) + reduce_scatter_tensor( + reduce_o_tensor, ffn2_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False + ) + ffn2_out = reduce_o_tensor + return ffn2_out + + def _tpsp_ffn_ep( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + ffn2_out = self._ffn(input=input, infer_state=infer_state, layer_weight=layer_weight) + + return ffn2_out + def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 47ba79823..53171ce53 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -3,7 +3,7 @@ import torch.distributed as dist import numpy as np from functools import partial - +from typing import Tuple from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight @@ -26,10 +26,8 @@ def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight) cache_kv = layer_weight.kv_proj.mm( input.view(-1, self.embed_dim_), - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), @@ -40,6 +38,10 @@ def _get_qkv( ) return q, cache_kv + def _tpsp_get_qkv(self, input, 
infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + raise Exception("not impl") + def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: @@ -48,6 +50,10 @@ def _get_o( ) return o_tensor + def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + raise Exception("not impl") + def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index df4641a65..7135b9412 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -241,14 +241,6 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""aggressive schedule can lead to frequent prefill interruptions during decode. disabling it allows the router_max_wait_tokens parameter to work more effectively.""", ) - parser.add_argument( - "--dp_prefill_wait_step", - type=int, - default=0, - help="""dp_prefill_wait_step is used to control the pacing of dp chunked prefill mode, aiming to reduce - computational waste during prefill. However, higher values can negatively impact the - first token latency. It is generally recommended to set this value between 0 and 12.""", - ) parser.add_argument( "--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use." @@ -292,6 +284,11 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""inference backend will use TP SP Mixed running mode. only llama and deepseek v3 model supported now.""", ) + parser.add_argument( + "--enable_dp_prefill_balance", + action="store_true", + help="inference backend will use dp balance, need set --enable_tpsp_mix_mode first", + ) parser.add_argument( "--enable_prefill_microbatch_overlap", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8557be579..f73be30db 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -61,6 +61,10 @@ def signal_handler(sig, frame): def normal_or_p_d_start(args): + from lightllm.server.core.objs.start_args_type import StartArgs + + args: StartArgs = args + set_unique_server_name(args) if not args.disable_shm_warning: @@ -135,6 +139,9 @@ def normal_or_p_d_start(args): if args.diverse_mode: assert args.router_token_ratio == 0.0 + if args.enable_dp_prefill_balance: + assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1" + # mtp params check if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f7fbfb974..2c83750c6 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -310,7 +310,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len): # 就是通过模拟加长其输出token长度,来延长其在估计阶段的生命周期。max_waiting_token # 的计算是保守的,每次chuncked prefill 延迟的最大步数为两种模式之合,因为 # 这个并不会导致预估的token占用量大幅增加,所以可以放心使用。 - max_waiting_token = args.router_max_wait_tokens + args.dp_prefill_wait_step + max_waiting_token = args.router_max_wait_tokens has_out_len = self.shm_cur_output_len if self.sample_params.ignore_eos: cur_max_new_token_len = self.sample_params.max_new_tokens diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index cd3166616..69d907fff 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ 
b/lightllm/server/core/objs/start_args_type.py
@@ -49,7 +49,6 @@ class StartArgs:
     router_token_ratio: float = field(default=0.0)
     router_max_new_token_len: int = field(default=1024)
     router_max_wait_tokens: int = field(default=1)
-    dp_prefill_wait_step: int = field(default=0)
     disable_aggressive_schedule: bool = field(default=False)
     disable_dynamic_prompt_cache: bool = field(default=False)
     chunked_prefill_size: int = field(default=8192)
@@ -61,6 +60,7 @@ class StartArgs:
     enable_multimodal: bool = field(default=False)
     enable_multimodal_audio: bool = field(default=False)
     enable_tpsp_mix_mode: bool = field(default=False)
+    enable_dp_prefill_balance: bool = field(default=False)
     enable_decode_microbatch_overlap: bool = field(default=False)
     enable_prefill_microbatch_overlap: bool = field(default=False)
     cache_capacity: int = field(default=200)
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/control_state.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/control_state.py
index 4268883cc..ffd92202d 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/control_state.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/control_state.py
@@ -16,10 +16,6 @@ def __init__(self, backend: ModeBackend):
 
         self.left_decode_num = self.decode_max_step
         self.step_count = 0
-
-        # dp prefill 配平调度的延迟参数。
-        self.dp_prefill_wait_step = 0
-        self.dp_prefill_wait_max_step = get_env_start_args().dp_prefill_wait_step
         return
 
     def select_run_way(
@@ -71,17 +67,7 @@ def _normal_way(
         prefill_reqs: List[InferReq],
         decode_reqs: List[InferReq],
     ):
-        """
-        _normal_way 接口用于控制 DP 模式下进行chuncked prefill时,需要考虑各个DP的真实运行请求数量:
-        考虑 8 个 dp 的场景,如果每个 dp 执行 prefill 的请求的数量分别为: [1, 1, 0, 0, 0, 0, 0, 0], 则在运行
-        的过程中,请求数量为0的dp会pad一个fake req来参与计算,但是这会导致这些dp因为一些通信同步的原因,造成大量
-        算力浪费,实际有效率很低。
-        解决方法:
-        在判断是否可以进行 prefill 的时候,需要先考虑所有dp的请求数量是否均衡,浪费率是否在可以接受的范围,如果无法
-        接受这么高的浪费率,则可以延迟 prefill 的执行时机,直到所有dp的浪费率较低时再进行prefill, 不过延迟执行的极限
-        等待时间,受到 dp_prefill_wait_step 参数的控制。
-        """
-        use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
+        # use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
         max_decode_num = np.max(dp_decode_req_nums)
         max_prefill_num = np.max(dp_prefill_req_nums)
 
@@ -89,30 +75,15 @@ def _normal_way(
             self.left_decode_num -= 1
             return RunWay.DECODE
 
-        if use_ratio < 0.6:
-            if max_prefill_num > 0:
-                self.dp_prefill_wait_step += 1
-                if self.dp_prefill_wait_step > self.dp_prefill_wait_max_step:
-                    # prefill 一次允许进行几次 decode 操作。
-                    self.left_decode_num = self.decode_max_step
-                    self.dp_prefill_wait_step = max(0, (self.dp_prefill_wait_step - self.decode_max_step))
-                    return RunWay.PREFILL
-
+        if max_prefill_num > 0:
+            # how many decode steps are allowed after one prefill.
+            self.left_decode_num = self.decode_max_step
+            return RunWay.PREFILL
+        else:
             if max_decode_num > 0:
                 return RunWay.DECODE
             else:
                 return RunWay.PASS
-        else:
-            if max_prefill_num > 0:
-                self.dp_prefill_wait_step = 0
-                # prefill 一次允许进行几次 decode 操作。
-                self.left_decode_num = self.decode_max_step
-                return RunWay.PREFILL
-            else:
-                if max_decode_num > 0:
-                    return RunWay.DECODE
-                else:
-                    return RunWay.PASS
 
     def try_recover_paused_reqs(self) -> bool:
         return self.step_count % 100 == 0
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 8b6c2f893..1bb625db0 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -101,7 +101,7 @@ def rpc_loop(self):
                 logger.exception(str(e))
                 error_count += 1
-                if error_count >= 3:
+                if error_count >= 1:
                     logger.error("infer process error to exit")
                     os._exit(-1)
 
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 76afa479c..65ac401d4 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -220,6 +220,28 @@ def create_new_group_for_current_node(backend):
     return ans_group
 
 
+def create_dp_special_inter_group(backend):
+    """
+    Create a special kind of communication group.
+    Assume the global group is [0, 1, 2, 3, 4, 5, 6, 7], where 0,1,2,3 form one dp and
+    4,5,6,7 form another dp; groups are then created between [0,4],
+    [1,5], [2,6] and [3,7].
+    """
+    from lightllm.utils.envs_utils import get_env_start_args
+
+    args = get_env_start_args()
+    ans_group = None
+    dp_size = args.dp
+    dp_world_size = get_dp_world_size()
+    rank = get_global_rank()
+    for iter_tp_rank in range(dp_world_size):
+        ranks = list(iter_tp_rank + i * dp_world_size for i in range(dp_size))
+        device_group = dist.new_group(ranks, backend=backend)
+        if rank in ranks:
+            ans_group = device_group
+    return ans_group
+
+
 def _init_nccl_env():
     from lightllm.utils.envs_utils import get_env_start_args
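
Note on the redistribution plan built in prefill_dp_balance: the dp_output_split_sizes / dp_input_split_sizes
matrices that are later passed to torch.distributed.all_to_all_single can be computed without any GPU or
process group, which makes the planning step easy to check in isolation. The sketch below is a minimal
standalone re-derivation of that step for illustration only; the helper name compute_balance_plan and the
example lengths are hypothetical and not part of the patch.

import collections
from typing import List, Tuple


def compute_balance_plan(dp_input_lens: List[int]) -> Tuple[List[int], List[List[int]], List[List[int]]]:
    # Mirrors the planning logic in InferStateInfo.prefill_dp_balance:
    # every dp rank should end up handling roughly sum(lens) / dp tokens.
    dp = len(dp_input_lens)
    total = sum(dp_input_lens)
    handle_lens = [total // dp] * dp
    for i in range(total % dp):
        handle_lens[i] += 1

    remaining = handle_lens.copy()
    dest_slices = [[] for _ in range(dp)]  # (origin_dp, start, end) slices per destination dp
    surplus = collections.deque()          # slices that exceed their own rank's quota
    for origin, length in enumerate(dp_input_lens):
        quota = remaining[origin]
        if length > quota:
            dest_slices[origin].append((origin, 0, quota))
            surplus.append((origin, quota, length))
            remaining[origin] = 0
        else:
            dest_slices[origin].append((origin, 0, length))
            remaining[origin] -= length

    # hand surplus slices to under-loaded ranks, in order
    for dest in range(dp):
        need = remaining[dest]
        while need > 0 and surplus:
            origin, start, end = surplus.popleft()
            take = min(end - start, need)
            dest_slices[dest].append((origin, start, start + take))
            if start + take < end:
                surplus.appendleft((origin, start + take, end))
            need -= take

    # output_split[d][o]: tokens rank d receives from rank o; input_split[o][d]: tokens rank o sends to d.
    output_split = [[0] * dp for _ in range(dp)]
    input_split = [[0] * dp for _ in range(dp)]
    for dest, slices in enumerate(dest_slices):
        for origin, start, end in slices:
            output_split[dest][origin] += end - start
            input_split[origin][dest] += end - start
    return handle_lens, output_split, input_split


if __name__ == "__main__":
    # Example: four dp ranks with prefill lengths [7, 1, 0, 4] are evened out to 3 tokens each.
    handle_lens, output_split, input_split = compute_balance_plan([7, 1, 0, 4])
    print(handle_lens)   # [3, 3, 3, 3]
    print(output_split)  # [[3, 0, 0, 0], [2, 1, 0, 0], [2, 0, 0, 1], [0, 0, 0, 3]]
    print(input_split)   # [[3, 2, 2, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 1, 3]]

In the patch, rank r scales row r of each matrix by the per-token element count (scale_size) before passing
it as output_split_sizes / input_split_sizes, and _all_to_all_unbalance_get swaps the two rows so every
token is routed back to its original dp rank.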