4 changes: 0 additions & 4 deletions lightllm/common/basemodel/basemodel.py
@@ -290,10 +290,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.req_manager = self.req_manager

infer_state.mem_index = model_input.mem_indexes
infer_state.kv_buffer_shapedtype = (
(model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
self.data_type,
)
infer_state.microbatch_index = microbatch_index
infer_state.dist_group = dist_group_manager.get_group(microbatch_index)

169 changes: 167 additions & 2 deletions lightllm/common/basemodel/infer_struct.py
@@ -1,12 +1,16 @@
import torch
import triton
import collections
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager
from lightllm.distributed import CustomProcessGroup
from typing import Tuple, Any, Optional
from typing import Tuple, Any, Optional, List
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj
from .batch_objs import ModelInput
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_global_dp_rank


class InferStateInfo:
@@ -36,7 +40,6 @@ def __init__(self):
self.req_manager: ReqManager = None

self.mem_index: torch.Tensor = None
self.kv_buffer_shapedtype: Tuple[Any, Any] = None

self.is_token_healing: bool = False
self.return_all_prompt_logics: bool = False
@@ -69,6 +72,18 @@ def __init__(self):
# ... used as the input; no other model or scenario uses it
self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

# In single-node multi-dp mode, if the data across dps becomes imbalanced during the prefill
# phase, the inference data can be redistributed across the dps: it is all-to-all'ed to the
# respective dps before the attention computation and all-to-all'ed back afterwards. This keeps
# the amount of data handled by each dp roughly even and improves prefill compute efficiency.
# The variables below are only used in this scenario; in the normal case they are unused.
self.need_dp_prefill_balance: bool = False
self.dp_origin_lens: List[int] = None
self.dp_handle_lens: List[int] = None
# self.dp_input_lens: torch.Tensor = None
self.dp_output_split_sizes: List[List[int]] = None
self.dp_input_split_sizes: List[List[int]] = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
(
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
for mark, obj in zip(marks_array, multi_objs):
obj["_prefill_"] = mark > 0
return

def prefill_dp_balance(self, input_ids: torch.Tensor):
"""
During prefill in dp mode, rebalance and redistribute the input data across the dps to reduce
the prefill performance degradation caused by heavily imbalanced per-dp workloads.
"""
assert self.is_prefill
import torch.distributed as dist

Review comment (medium): The import statement "import torch.distributed as dist" is local to the prefill_dp_balance method. It is better to move it to the top of the file for consistency and to avoid repeated import overhead. The same applies to the other local imports in _all_to_all_balance_get and _all_to_all_unbalance_get.
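A minimal sketch of the suggested change (not part of the PR; it assumes the lazy imports are not there to work around a circular import between infer_struct.py and cache_tensor_manager.py):

# top of lightllm/common/basemodel/infer_struct.py -- hoist the repeated lazy imports to module level
import torch
import torch.distributed as dist
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


class InferStateInfo:
    def prefill_dp_balance(self, input_ids: torch.Tensor):
        # dist is now resolved from module scope instead of being imported on every call
        ...

    def _all_to_all_balance_get(self, data: torch.Tensor):
        # dist and g_cache_manager are likewise resolved from module scope
        ...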


self.need_dp_prefill_balance = True

args = get_env_start_args()

dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
input_len.fill_(len(input_ids))
dist.all_gather_into_tensor(
output_tensor=dp_input_lens,
input_tensor=input_len,
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
dp_input_lens = dp_input_lens.detach().cpu()
self.dp_origin_lens = dp_input_lens.tolist()
sum_input_len = dp_input_lens.sum().item()
dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
for i in range(sum_input_len % args.dp):
dp_handle_lens[i] += 1

self.dp_handle_lens = dp_handle_lens.copy()

dest_dp_inputs = [[] for _ in range(args.dp)]
# split each dp's original input into the segments that each dp will handle after redistribution
origin_datas = collections.deque()
for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):

Review comment (medium): Calling .numpy() inside a loop can be inefficient, as it may cause a GPU-to-CPU synchronization on each iteration. It is better to call it once before the loop and iterate over the resulting numpy array.
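A minimal sketch of that suggestion, as a drop-in fragment for the loop header (illustrative only; dp_input_lens has already been moved to the CPU a few lines above):

dp_input_lens_np = dp_input_lens.numpy()  # convert once, before iterating
for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens_np):
    handle_len = dp_handle_lens[origin_dp_index]
    ...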

handle_len = dp_handle_lens[origin_dp_index]
if origin_dp_input_len > handle_len:
origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
dp_handle_lens[origin_dp_index] = 0
dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
else:
dp_handle_lens[origin_dp_index] -= origin_dp_input_len
dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))

for dest_dp_index in range(args.dp):
need_size = dp_handle_lens[dest_dp_index]
if need_size == 0:
continue
while len(origin_datas) != 0:
origin_data = origin_datas.popleft()
origin_dp_index, start, end = origin_data
if end - start > need_size:
dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
origin_datas.appendleft((origin_dp_index, start + need_size, end))
break
else:
dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
need_size -= end - start
if need_size == 0:
break

dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
for origin_dp_index, start, end in dest_dp_data:
dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
for origin_dp_index, start, end in dest_dp_data:
dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start

self.dp_input_split_sizes = dp_input_split_sizes
self.dp_output_split_sizes = dp_output_split_sizes

new_input_ids = self._all_to_all_balance_get(input_ids)
if hasattr(self, "position_ids") and self.position_ids is not None:
# the deepseekv2 mla model needs to keep the original position_ids to reduce communication volume
self._unbalance_position_ids = self.position_ids

self.position_ids = self._all_to_all_balance_get(self.position_ids)
if hasattr(self, "position_cos") and self.position_cos is not None:
# the deepseekv2 mla model needs to keep the original position_cos to reduce communication volume
self._unbalance_position_cos = self.position_cos

self.position_cos = self._all_to_all_balance_get(self.position_cos)
if hasattr(self, "position_sin") and self.position_sin is not None:
# the deepseekv2 mla model needs to keep the original position_sin to reduce communication volume
self._unbalance_position_sin = self.position_sin

self.position_sin = self._all_to_all_balance_get(self.position_sin)

return new_input_ids

def _all_to_all_balance_get(self, data: torch.Tensor):
dp_rank = get_global_dp_rank()
import torch.distributed as dist
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

old_shape = data.shape
data = data.view(-1)

origin_len = self.dp_origin_lens[dp_rank]
assert data.shape[0] % origin_len == 0
scale_size = data.shape[0] // origin_len
handle_len = self.dp_handle_lens[dp_rank]

dest_data = g_cache_manager.alloc_tensor(
shape=(handle_len * scale_size,),
data_type=data.dtype,
device="cuda",
is_graph_out=False,
microbatch_index=self.microbatch_index,
)
dist.all_to_all_single(
output=dest_data.view(-1),
input=data.view(-1),
output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
return dest_data.view(-1, *old_shape[1:])

def _all_to_all_unbalance_get(self, data: torch.Tensor):
dp_rank = get_global_dp_rank()
import torch.distributed as dist
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

old_shape = data.shape
data = data.view(-1)

handle_len = self.dp_handle_lens[dp_rank]
scale_size = data.shape[0] // handle_len
assert data.shape[0] % handle_len == 0
origin_len = self.dp_origin_lens[dp_rank]
origin_data = g_cache_manager.alloc_tensor(
shape=(origin_len * scale_size,),
data_type=data.dtype,
device="cuda",
is_graph_out=False,
microbatch_index=self.microbatch_index,
)
dist.all_to_all_single(
output=origin_data.view(-1),
input=data,
output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
return origin_data.view(-1, *old_shape[1:])
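The redistribution bookkeeping that prefill_dp_balance introduces above can be exercised in isolation. Below is a minimal, CPU-only sketch (not part of the PR; the helper name compute_balance_split_sizes is illustrative) that mirrors the split-size computation and prints the matrices that would be handed to dist.all_to_all_single:

import collections
from typing import List, Tuple


def compute_balance_split_sizes(dp_input_lens: List[int]) -> Tuple[List[List[int]], List[List[int]]]:
    """Mirror of the split-size bookkeeping in prefill_dp_balance (no distributed calls)."""
    dp = len(dp_input_lens)
    total = sum(dp_input_lens)
    # Balanced target length per dp; the remainder is spread over the first dps.
    dp_handle_lens = [total // dp + (1 if i < total % dp else 0) for i in range(dp)]

    dest_dp_inputs = [[] for _ in range(dp)]  # per destination dp: (origin_dp, start, end) segments
    leftovers = collections.deque()  # segments exceeding their origin dp's target length
    remaining = dp_handle_lens.copy()
    for origin_dp, origin_len in enumerate(dp_input_lens):
        target = remaining[origin_dp]
        if origin_len > target:
            dest_dp_inputs[origin_dp].append((origin_dp, 0, target))
            leftovers.append((origin_dp, target, origin_len))
            remaining[origin_dp] = 0
        else:
            dest_dp_inputs[origin_dp].append((origin_dp, 0, origin_len))
            remaining[origin_dp] -= origin_len

    # Hand the leftover segments to the dps that still have spare capacity.
    for dest_dp in range(dp):
        need = remaining[dest_dp]
        while need > 0 and leftovers:
            origin_dp, start, end = leftovers.popleft()
            take = min(need, end - start)
            dest_dp_inputs[dest_dp].append((origin_dp, start, start + take))
            if start + take < end:
                leftovers.appendleft((origin_dp, start + take, end))
            need -= take

    output_split = [[0] * dp for _ in range(dp)]  # dp_output_split_sizes[dest][origin]
    input_split = [[0] * dp for _ in range(dp)]  # dp_input_split_sizes[origin][dest]
    for dest_dp, segments in enumerate(dest_dp_inputs):
        for origin_dp, start, end in segments:
            output_split[dest_dp][origin_dp] += end - start
            input_split[origin_dp][dest_dp] += end - start
    return output_split, input_split


# Example: dp=4 with imbalanced prefill lengths [10, 2, 4, 8]; every dp ends up handling 6 tokens.
out_split, in_split = compute_balance_split_sizes([10, 2, 4, 8])
print(out_split)  # [[6, 0, 0, 0], [4, 2, 0, 0], [0, 0, 4, 2], [0, 0, 0, 6]]
print(in_split)   # [[6, 4, 0, 0], [0, 2, 0, 0], [0, 0, 4, 0], [0, 0, 2, 6]]

Each row of in_split sums to that dp's original length and each row of out_split sums to its balanced length, which is exactly the pairing dist.all_to_all_single expects for its input_split_sizes and output_split_sizes arguments.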
@@ -44,13 +44,12 @@ def _bind_rotary_emb_fwd(self):
def _get_qkv(
self, input, infer_state: InferStateInfo, layer_weight
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
torch.mm(
cache_kv = torch.mm(
input.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

if self.use_qk_norm_:
q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
k = cache_kv[:, 0 : self.tp_k_head_num_, :]
@@ -30,16 +30,6 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")

def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
cache_kv = self.alloc_tensor(
shape=infer_state.kv_buffer_shapedtype[0],
dtype=infer_state.kv_buffer_shapedtype[1],
device="cuda",
is_graph_out=False,
microbatch_index=infer_state.microbatch_index,
)
return cache_kv

def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
raise Exception("need to impl")

6 changes: 6 additions & 0 deletions lightllm/distributed/communication_op.py
@@ -36,6 +36,7 @@
get_global_rank,
get_current_rank_in_dp,
create_new_group_for_current_dp,
create_dp_special_inter_group,
)
from lightllm.utils.device_utils import get_device_sm_count
from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -62,6 +63,11 @@ def __init__(self):
self.custom_gather = None
self.dp_world_size = get_dp_world_size()
self.device_group = create_new_group_for_current_dp("nccl")
if get_env_start_args().enable_dp_prefill_balance:
self.dp_prefill_balance_group = create_dp_special_inter_group("nccl")
else:
self.dp_prefill_balance_group = None

self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo")

def init_custom_reduce(self) -> None:
5 changes: 1 addition & 4 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -47,10 +47,7 @@ def _get_qkv(
self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
) -> Tuple[torch.Tensor, torch.Tensor]:
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
return q, cache_kv

def _context_attention_kernel(
77 changes: 50 additions & 27 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -143,14 +143,6 @@ def _bind_attention(self):
Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
)

def _pre_cache_kv(
self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
) -> torch.Tensor:
# when q_lora_rank is not None, q_a_proj and kv_a_proj_with_mqa are fused
if self.q_lora_rank is None:
return super()._pre_cache_kv(infer_state, layer_weight)
return None

def _get_qkv(
self,
input: torch.Tensor,
@@ -161,8 +153,7 @@ def _get_qkv(

if self.q_lora_rank is None:
q = layer_weight.q_weight_.mm(input)
cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
else:
q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -203,8 +194,25 @@ def _tpsp_get_qkv(

input = input.view(-1, self.embed_dim_)
q = layer_weight.q_weight_.mm(input)
cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
rmsnorm_forward(
cache_kv[:, :, : self.kv_lora_rank],
weight=layer_weight.kv_a_layernorm_.weight,
eps=self.eps_,
out=cache_kv[:, :, : self.kv_lora_rank],
)
rotary_emb_fwd(
q_rope,
cache_kv[:, :, self.kv_lora_rank :],
infer_state.position_cos,
infer_state.position_sin,
)
if infer_state.need_dp_prefill_balance:
q = infer_state._all_to_all_unbalance_get(data=q)
cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
return q, cache_kv
else:
input = input.view(-1, self.embed_dim_)
qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input)
@@ -217,25 +225,33 @@
all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False)
qkv = gather_qkv[0 : len(infer_state.position_cos), :]

if infer_state.need_dp_prefill_balance:
qkv = infer_state._all_to_all_unbalance_get(data=qkv)
position_cos = infer_state._unbalance_position_cos
position_sin = infer_state._unbalance_position_sin
else:
position_cos = infer_state.position_cos
position_sin = infer_state.position_sin

q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
q = layer_weight.q_b_proj_.mm(q)
cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
rmsnorm_forward(
cache_kv[:, :, : self.kv_lora_rank],
weight=layer_weight.kv_a_layernorm_.weight,
eps=self.eps_,
out=cache_kv[:, :, : self.kv_lora_rank],
)
rotary_emb_fwd(
q_rope,
cache_kv[:, :, self.kv_lora_rank :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv
q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
rmsnorm_forward(
cache_kv[:, :, : self.kv_lora_rank],
weight=layer_weight.kv_a_layernorm_.weight,
eps=self.eps_,
out=cache_kv[:, :, : self.kv_lora_rank],
)
rotary_emb_fwd(
q_rope,
cache_kv[:, :, self.kv_lora_rank :],
position_cos,
position_sin,
)
return q, cache_kv

def _get_o(
self, input: torch.Tensor, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
@@ -248,13 +264,20 @@ def _get_o(
def _tpsp_get_o(
self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
) -> torch.Tensor:
if infer_state.need_dp_prefill_balance:
input = infer_state._all_to_all_balance_get(data=input)

if input.shape[2] == self.kv_lora_rank:
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)

input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)
dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])
e_o_tensor = o_tensor[len(infer_state.position_cos) :, :]
if e_o_tensor.shape[0] > 0:
e_o_tensor.fill_(0)

if self.tp_world_size_ > 1:
sp_token_num = o_tensor.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device)