
Commit 16a41d2

hiworldwzj and wangzaijun authored

tpsp mode support dp prefill balance. (#1086)
Co-authored-by: wangzaijun <[email protected]>
1 parent e308e32 commit 16a41d2

File tree

24 files changed: +382 −135 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 4 deletions
@@ -290,10 +290,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.req_manager = self.req_manager

         infer_state.mem_index = model_input.mem_indexes
-        infer_state.kv_buffer_shapedtype = (
-            (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            self.data_type,
-        )
         infer_state.microbatch_index = microbatch_index
         infer_state.dist_group = dist_group_manager.get_group(microbatch_index)

lightllm/common/basemodel/infer_struct.py

Lines changed: 167 additions & 2 deletions
@@ -1,12 +1,16 @@
 import torch
+import triton
+import collections
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
-from typing import Tuple, Any, Optional
+from typing import Tuple, Any, Optional, List
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
 from .batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_global_dp_rank


 class InferStateInfo:
@@ -36,7 +40,6 @@ def __init__(self):
         self.req_manager: ReqManager = None

         self.mem_index: torch.Tensor = None
-        self.kv_buffer_shapedtype: Tuple[Any, Any] = None

         self.is_token_healing: bool = False
         self.return_all_prompt_logics: bool = False
@@ -69,6 +72,18 @@ def __init__(self):
         # used as input; no other model or scenario uses it
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

+        # In single-node multi-dp mode, if the prefill data becomes unbalanced across dp ranks,
+        # the inference data can be redistributed over the dp ranks: before attention it is
+        # all-to-all'ed back to its own dp rank, and after the computation it is all-to-all'ed
+        # back again. This keeps the amount of data each dp rank handles roughly even and
+        # improves prefill compute efficiency. The variables below are only used in that
+        # scenario; in the normal case they are never touched.
+        self.need_dp_prefill_balance: bool = False
+        self.dp_origin_lens: List[int] = None
+        self.dp_handle_lens: List[int] = None
+        # self.dp_input_lens: torch.Tensor = None
+        self.dp_output_split_sizes: List[List[int]] = None
+        self.dp_input_split_sizes: List[List[int]] = None
+
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
         for mark, obj in zip(marks_array, multi_objs):
             obj["_prefill_"] = mark > 0
         return
+
+    def prefill_dp_balance(self, input_ids: torch.Tensor):
+        """
+        During prefill in dp mode, re-adjust and redistribute the input data across dp ranks,
+        reducing the prefill performance loss caused by very uneven per-dp workloads.
+        """
+        assert self.is_prefill
+        import torch.distributed as dist
+
+        self.need_dp_prefill_balance = True
+
+        args = get_env_start_args()
+
+        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
+        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
+        input_len.fill_(len(input_ids))
+        dist.all_gather_into_tensor(
+            output_tensor=dp_input_lens,
+            input_tensor=input_len,
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        dp_input_lens = dp_input_lens.detach().cpu()
+        self.dp_origin_lens = dp_input_lens.tolist()
+        sum_input_len = dp_input_lens.sum().item()
+        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
+        for i in range(sum_input_len % args.dp):
+            dp_handle_lens[i] += 1
+
+        self.dp_handle_lens = dp_handle_lens.copy()
+
+        dest_dp_inputs = [[] for _ in range(args.dp)]
+        # assign each dp rank's original input and its redistributed share
+        origin_datas = collections.deque()
+        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
+            handle_len = dp_handle_lens[origin_dp_index]
+            if origin_dp_input_len > handle_len:
+                origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
+                dp_handle_lens[origin_dp_index] = 0
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
+            else:
+                dp_handle_lens[origin_dp_index] -= origin_dp_input_len
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))
+
+        for dest_dp_index in range(args.dp):
+            need_size = dp_handle_lens[dest_dp_index]
+            if need_size == 0:
+                continue
+            while len(origin_datas) != 0:
+                origin_data = origin_datas.popleft()
+                origin_dp_index, start, end = origin_data
+                if end - start > need_size:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
+                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
+                    break
+                else:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
+                    need_size -= end - start
+                    if need_size == 0:
+                        break
+
+        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
+        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start
+
+        self.dp_input_split_sizes = dp_input_split_sizes
+        self.dp_output_split_sizes = dp_output_split_sizes
+
+        new_input_ids = self._all_to_all_balance_get(input_ids)
+        if hasattr(self, "position_ids") and self.position_ids is not None:
+            # deepseekv2 mla models need to keep the original position_ids to reduce communication volume
+            self._unbalance_position_ids = self.position_ids
+
+            self.position_ids = self._all_to_all_balance_get(self.position_ids)
+        if hasattr(self, "position_cos") and self.position_cos is not None:
+            # deepseekv2 mla models need to keep the original position_cos to reduce communication volume
+            self._unbalance_position_cos = self.position_cos
+
+            self.position_cos = self._all_to_all_balance_get(self.position_cos)
+        if hasattr(self, "position_sin") and self.position_sin is not None:
+            # deepseekv2 mla models need to keep the original position_sin to reduce communication volume
+            self._unbalance_position_sin = self.position_sin
+
+            self.position_sin = self._all_to_all_balance_get(self.position_sin)
+
+        return new_input_ids
+
+    def _all_to_all_balance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        origin_len = self.dp_origin_lens[dp_rank]
+        assert data.shape[0] % origin_len == 0
+        scale_size = data.shape[0] // origin_len
+        handle_len = self.dp_handle_lens[dp_rank]
+
+        dest_data = g_cache_manager.alloc_tensor(
+            shape=(handle_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=dest_data.view(-1),
+            input=data.view(-1),
+            output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return dest_data.view(-1, *old_shape[1:])
+
+    def _all_to_all_unbalance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        handle_len = self.dp_handle_lens[dp_rank]
+        scale_size = data.shape[0] // handle_len
+        assert data.shape[0] % handle_len == 0
+        origin_len = self.dp_origin_lens[dp_rank]
+        origin_data = g_cache_manager.alloc_tensor(
+            shape=(origin_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=origin_data.view(-1),
+            input=data,
+            output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return origin_data.view(-1, *old_shape[1:])
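
To make the redistribution plan above easier to follow, here is a small self-contained sketch (not part of the commit; the dp count and token lengths are made up) that reproduces the split-size computation from prefill_dp_balance in plain Python, with the control flow slightly restructured for brevity, and checks that the resulting matrices are consistent with what dist.all_to_all_single expects.

```python
import collections

dp = 3
dp_input_lens = [7, 1, 4]                 # tokens originally held by each dp rank
total = sum(dp_input_lens)
dp_handle_lens = [total // dp] * dp
for i in range(total % dp):
    dp_handle_lens[i] += 1                # -> [4, 4, 4]

handle_left = dp_handle_lens.copy()
dest_dp_inputs = [[] for _ in range(dp)]  # per dest rank: (origin_rank, start, end) slices
origin_datas = collections.deque()        # surplus slices still to be handed out
for rank, length in enumerate(dp_input_lens):
    keep = min(length, handle_left[rank])
    dest_dp_inputs[rank].append((rank, 0, keep))      # every rank first keeps a local slice
    handle_left[rank] -= keep
    if length > keep:
        origin_datas.append((rank, keep, length))

for dest in range(dp):
    need = handle_left[dest]
    while need > 0 and origin_datas:
        rank, start, end = origin_datas.popleft()
        take = min(need, end - start)
        dest_dp_inputs[dest].append((rank, start, start + take))
        if start + take < end:
            origin_datas.appendleft((rank, start + take, end))
        need -= take

# output_split_sizes[d][o]: tokens rank d receives from rank o in the all_to_all
# input_split_sizes[o][d]:  tokens rank o sends to rank d
output_split = [[0] * dp for _ in range(dp)]
input_split = [[0] * dp for _ in range(dp)]
for dest, pieces in enumerate(dest_dp_inputs):
    for rank, start, end in pieces:
        output_split[dest][rank] += end - start
        input_split[rank][dest] += end - start

assert [sum(row) for row in output_split] == dp_handle_lens  # each rank now handles ~total/dp tokens
assert [sum(row) for row in input_split] == dp_input_lens    # every original token is sent exactly once
print(output_split)  # [[4, 0, 0], [3, 1, 0], [0, 0, 4]]
print(input_split)   # [[4, 3, 0], [0, 1, 0], [0, 0, 4]]
```

In the real code the same plan is reused for tensors with more than one element per token: _all_to_all_balance_get flattens the tensor and multiplies every split size by scale_size, the number of elements per token, so one plan computed from token counts serves input_ids, the position tensors, and hidden states alike.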

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 3 additions & 4 deletions
@@ -44,13 +44,12 @@ def _bind_rotary_emb_fwd(self):
     def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
         q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.mm(
+        cache_kv = torch.mm(
             input.view(-1, self.embed_dim_),
             layer_weight.kv_weight_,
-            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
         if self.use_qk_norm_:
            q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
            k = cache_kv[:, 0 : self.tp_k_head_num_, :]
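
The change above (and the matching ones in the bloom and deepseek2 layers further down) replaces the old pattern of pre-allocating a kv buffer via _pre_cache_kv and writing into it with out= by simply viewing the projection result. A toy sketch of the equivalence, with made-up shapes and plain torch tensors instead of the repo's weight objects:

```python
import torch

tokens, embed_dim, kv_heads, head_dim = 4, 16, 2, 8
x = torch.randn(tokens, embed_dim)
kv_weight = torch.randn(embed_dim, kv_heads * head_dim)

# old pattern: pre-allocate the kv buffer (previously done by _pre_cache_kv) and write into it
buf = torch.empty(tokens, kv_heads, head_dim)
torch.mm(x, kv_weight, out=buf.view(-1, kv_heads * head_dim))

# new pattern: let the projection allocate its own output and just reshape it
cache_kv = torch.mm(x, kv_weight).view(-1, kv_heads, head_dim)

assert torch.allclose(buf, cache_kv)
```

Dropping the out= variant lets the kv projection produce its own tensor, which is why kv_buffer_shapedtype can be removed from InferStateInfo and basemodel.py in this commit.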

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 0 additions & 10 deletions
@@ -30,16 +30,6 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
     def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")

-    def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
-        cache_kv = self.alloc_tensor(
-            shape=infer_state.kv_buffer_shapedtype[0],
-            dtype=infer_state.kv_buffer_shapedtype[1],
-            device="cuda",
-            is_graph_out=False,
-            microbatch_index=infer_state.microbatch_index,
-        )
-        return cache_kv
-
     def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
         raise Exception("need to impl")

lightllm/distributed/communication_op.py

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,7 @@
     get_global_rank,
     get_current_rank_in_dp,
     create_new_group_for_current_dp,
+    create_dp_special_inter_group,
 )
 from lightllm.utils.device_utils import get_device_sm_count
 from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -62,6 +63,11 @@ def __init__(self):
         self.custom_gather = None
         self.dp_world_size = get_dp_world_size()
         self.device_group = create_new_group_for_current_dp("nccl")
+        if get_env_start_args().enable_dp_prefill_balance:
+            self.dp_prefill_balance_group = create_dp_special_inter_group("nccl")
+        else:
+            self.dp_prefill_balance_group = None
+
         self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo")

     def init_custom_reduce(self) -> None:
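
create_dp_special_inter_group comes from lightllm.utils.dist_utils and its body is not part of this commit. As a rough illustration of the general pattern such a helper usually follows (an assumption, not the actual implementation), a cross-dp group can be built by collecting, for each in-dp rank, the global ranks that hold that position across all dp replicas; every process must call dist.new_group with the same rank lists and keeps the group that contains its own rank:

```python
import torch.distributed as dist

def build_inter_dp_group(global_rank: int, dp: int, dp_world_size: int, backend: str = "nccl"):
    """Hypothetical helper: for each in-dp rank, group the processes that hold that
    rank across all dp replicas, and return the group this process belongs to."""
    my_group = None
    for rank_in_dp in range(dp_world_size):
        # assumes global ranks are laid out as dp_index * dp_world_size + rank_in_dp
        ranks = [dp_index * dp_world_size + rank_in_dp for dp_index in range(dp)]
        group = dist.new_group(ranks, backend=backend)  # collective: every process must call this
        if global_rank in ranks:
            my_group = group
    return my_group
```

An all-to-all issued inside such a group then moves tokens between dp ranks without touching the tensor-parallel groups.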

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 4 deletions
@@ -47,10 +47,7 @@ def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         return q, cache_kv

     def _context_attention_kernel(

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 50 additions & 27 deletions
@@ -143,14 +143,6 @@ def _bind_attention(self):
             Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
         )

-    def _pre_cache_kv(
-        self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
-    ) -> torch.Tensor:
-        # when q_lora_rank is not None, q_a_proj and kv_a_proj_with_mqa are fused
-        if self.q_lora_rank is None:
-            return super()._pre_cache_kv(infer_state, layer_weight)
-        return None
-
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -161,8 +153,7 @@ def _get_qkv(

         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
-            cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
+            cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         else:
             q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -203,8 +194,25 @@ def _tpsp_get_qkv(

             input = input.view(-1, self.embed_dim_)
             q = layer_weight.q_weight_.mm(input)
-            cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
+            cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
+            q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
+            q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            rmsnorm_forward(
+                cache_kv[:, :, : self.kv_lora_rank],
+                weight=layer_weight.kv_a_layernorm_.weight,
+                eps=self.eps_,
+                out=cache_kv[:, :, : self.kv_lora_rank],
+            )
+            rotary_emb_fwd(
+                q_rope,
+                cache_kv[:, :, self.kv_lora_rank :],
+                infer_state.position_cos,
+                infer_state.position_sin,
+            )
+            if infer_state.need_dp_prefill_balance:
+                q = infer_state._all_to_all_unbalance_get(data=q)
+                cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
+            return q, cache_kv
         else:
             input = input.view(-1, self.embed_dim_)
             qkv = layer_weight.qkv_a_proj_with_mqa_.mm(input)
@@ -217,25 +225,33 @@ def _tpsp_get_qkv(
             all_gather_into_tensor(gather_qkv, qkv, group=infer_state.dist_group, async_op=False)
             qkv = gather_qkv[0 : len(infer_state.position_cos), :]

+            if infer_state.need_dp_prefill_balance:
+                qkv = infer_state._all_to_all_unbalance_get(data=qkv)
+                position_cos = infer_state._unbalance_position_cos
+                position_sin = infer_state._unbalance_position_sin
+            else:
+                position_cos = infer_state.position_cos
+                position_sin = infer_state.position_sin
+
             q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
             q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
             cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
-        q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
-        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        rmsnorm_forward(
-            cache_kv[:, :, : self.kv_lora_rank],
-            weight=layer_weight.kv_a_layernorm_.weight,
-            eps=self.eps_,
-            out=cache_kv[:, :, : self.kv_lora_rank],
-        )
-        rotary_emb_fwd(
-            q_rope,
-            cache_kv[:, :, self.kv_lora_rank :],
-            infer_state.position_cos,
-            infer_state.position_sin,
-        )
-        return q, cache_kv
+            q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
+            q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            rmsnorm_forward(
+                cache_kv[:, :, : self.kv_lora_rank],
+                weight=layer_weight.kv_a_layernorm_.weight,
+                eps=self.eps_,
+                out=cache_kv[:, :, : self.kv_lora_rank],
+            )
+            rotary_emb_fwd(
+                q_rope,
+                cache_kv[:, :, self.kv_lora_rank :],
+                position_cos,
+                position_sin,
+            )
+            return q, cache_kv

     def _get_o(
         self, input: torch.Tensor, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
@@ -248,13 +264,20 @@ def _get_o(
     def _tpsp_get_o(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
+        if infer_state.need_dp_prefill_balance:
+            input = infer_state._all_to_all_balance_get(data=input)
+
         if input.shape[2] == self.kv_lora_rank:
             input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)

         input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)
         dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
         o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
         layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])
+        e_o_tensor = o_tensor[len(infer_state.position_cos) :, :]
+        if e_o_tensor.shape[0] > 0:
+            e_o_tensor.fill_(0)
+
         if self.tp_world_size_ > 1:
             sp_token_num = o_tensor.shape[0] // self.tp_world_size_
             reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device)
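
The newly added zeroing of e_o_tensor deserves a note: o_tensor's first dimension is rounded up to a multiple of tp_world_size_ so it can be split evenly for the subsequent reduce, but the matmul only writes the first len(infer_state.position_cos) rows, so the padding rows are cleared (presumably to keep stale allocator contents out of the reduction). A toy sketch of the arithmetic with made-up sizes:

```python
import torch
import triton

tp_world_size = 4
token_num = 10                 # rows actually produced for this rank (len(infer_state.position_cos))
embed_dim = 8
dest_size = triton.cdiv(token_num, tp_world_size) * tp_world_size  # rounded up to 12

o_tensor = torch.empty(dest_size, embed_dim)              # cache-allocated: may hold stale values
o_tensor[:token_num] = torch.randn(token_num, embed_dim)  # stands in for the o_weight_ matmul
e_o_tensor = o_tensor[token_num:, :]
if e_o_tensor.shape[0] > 0:
    e_o_tensor.fill_(0)                                   # zero the padding rows

# each tp rank later works on an equal (dest_size // tp_world_size)-row slice
sp_token_num = o_tensor.shape[0] // tp_world_size
assert sp_token_num * tp_world_size == dest_size
```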
