
Commit 6eec9ee

feat: add kv_trans_v2 version
1 parent 10a9b66 commit 6eec9ee

File tree

1 file changed: +120, -27 lines


lightllm/common/mem_manager.py

Lines changed: 120 additions & 27 deletions
@@ -2,12 +2,13 @@
 import os
 import torch
 import torch.distributed as dist
-from typing import List, Union
+from typing import List, Union, Optional
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
@@ -103,7 +104,10 @@ def send_to_decode_node(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
+        if dp_size_in_node > 1:
+            return self.send_to_decode_node_p2p(
+                move_tasks, mem_managers, dp_size_in_node, nccl_comm
+            )
 
         # First send the data to the buffer of a designated card, then send it.
 
@@ -143,8 +147,10 @@ def receive_from_prefill_node(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
-
+        if dp_size_in_node > 1:
+            return self.receive_from_prefill_node_p2p(
+                move_tasks, mem_managers, dp_size_in_node, nccl_comm
+            )
         # First receive the data into the buffer of a designated card, then copy it to the other cards.
 
         move_token_indexes = []
@@ -183,29 +189,73 @@ def send_to_decode_node_p2p(
         """
         Implementation that uses a p2p triton kernel for data copy and transfer.
         """
-        assert dp_size_in_node == 1
-
-        # First send the data to the buffer of a designated card, then send it.
+        if dp_size_in_node > 1:
+            mem_ptrs_dict = {}
+            # number of devices occupied by one dp replica
+            group_stride = max(1, len(mem_managers) // dp_size_in_node)
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers), group_stride):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+                if dp_size_in_node > 1:
+                    token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])
+
+        if len(move_token_indexes) == 0:
+            return
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
-        for i, mem in enumerate(mem_managers):
-            for layer_index in range(mem.layer_num):
-                move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                nccl_comm.send(move_buffer, dst=1)
+        token_dp_tensor = (
+            torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
+        )
+
+        for layer_index in range(self.layer_num):
+            move_buffer = self._get_kv_move_data_p2p(
+                move_token_indexes,
+                layer_index,
+                self.kv_move_buffer,
+                token_dp_indexes=token_dp_tensor,
+                dp_size_in_node=dp_size_in_node,
+                mem_ptrs_dict=mem_ptrs_dict
+            )
+            nccl_comm.send(move_buffer, dst=1)
         return
 
-    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+    def _get_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        layer_index: int,
+        kv_move_buffer: torch.Tensor,
+        token_dp_indexes: Optional[torch.Tensor] = None,
+        dp_size_in_node: int = 1,
+        mem_ptrs_dict: Optional[dict] = None
+    ):
         move_token_num = len(token_indexes)
         move_size = self.token_dim_size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
-        kv_trans(
-            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
-        )
+
+        if dp_size_in_node == 1 or token_dp_indexes is None:
+            kv_trans(
+                self.kv_buffer[layer_index, :, :, :],
+                token_indexes,
+                move_buffer,
+                self.kv_move_buf_indexes[0:move_token_num],
+            )
+        else:
+            kv_trans_v2_for_p_node(
+                input_mems=mem_ptrs_dict[layer_index],
+                input_idx=token_indexes,
+                input_dp_idx=token_dp_indexes,
+                output=move_buffer,
+                output_idx=self.kv_move_buf_indexes[0:move_token_num],
+                dp_size_in_node=dp_size_in_node,
+            )
        return move_buffer
 
     def receive_from_prefill_node_p2p(
@@ -215,29 +265,72 @@ def receive_from_prefill_node_p2p(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
-
-        # First receive the data into the buffer of a designated card, then copy it to the other cards.
+        if dp_size_in_node > 1:
+            mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for mem in mem_managers:
+                    mems_ptr.append(mem.kv_buffer[layer_index, :, :, :].data_ptr())
+                mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+                if dp_size_in_node > 1:
+                    token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])
+
+        if len(move_token_indexes) == 0:
+            return
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_tensor = (
+            torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
+        )
 
-        token_num = len(move_token_indexes)
-        move_size = self.token_dim_size * token_num
-        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
-        for i, mem in enumerate(mem_managers):
-            for layer_index in range(mem.layer_num):
-                nccl_comm.recv(recive_buffer, src=0)
-                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+        move_token_num = len(move_token_indexes)
+        move_size = self.token_dim_size * move_token_num
+        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
+
+        for layer_index in range(self.layer_num):
+            nccl_comm.recv(recive_buffer, src=0)
+            self._write_kv_move_data_p2p(
+                move_token_indexes,
+                recive_buffer,
+                layer_index,
+                token_dp_indexes=token_dp_tensor,
+                dp_size_in_node=dp_size_in_node,
+                mem_ptrs_dict=mem_ptrs_dict
+            )
        return
 
-    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+    def _write_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        buffer_tensor: torch.Tensor,
+        layer_index: int,
+        token_dp_indexes: Optional[torch.Tensor] = None,
+        dp_size_in_node: int = 1,
+        mem_ptrs_dict: Optional[dict] = None
+    ):
         move_token_num = len(token_indexes)
-        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        if dp_size_in_node == 1 or token_dp_indexes is None:
+            kv_trans(
+                buffer_tensor,
+                self.kv_move_buf_indexes[0:move_token_num],
+                self.kv_buffer[layer_index],
+                token_indexes,
+            )
+        else:
+            kv_trans_v2_for_d_node(
+                output_mems=mem_ptrs_dict[layer_index],
+                output_idx=token_indexes,
+                output_dp_idx=token_dp_indexes,
+                input=buffer_tensor,
+                input_idx=self.kv_move_buf_indexes[0:move_token_num],
+                dp_size_in_node=dp_size_in_node,
+            )
        return
 
     def _free_buffers(self):
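
For orientation, the sketch below is a rough pure-PyTorch picture of the gather/scatter semantics that the kv_trans_v2_for_p_node / kv_trans_v2_for_d_node call sites above suggest. It is only an illustration under assumptions: the real kernels are Triton kernels that dereference the per-layer table of raw uint64 data_ptr() values built in the diff, whereas this sketch substitutes a plain list of per-buffer tensors; the helper names and the exact dp-index-to-buffer mapping are hypothetical.

# Hypothetical reference implementation; the actual kernels operate on raw
# device pointers and run on the GPU, not on Python lists of tensors.
from typing import List

import torch


def ref_gather_for_p_node(
    input_bufs: List[torch.Tensor],   # stand-in for input_mems: one KV slice per dp group
    input_idx: torch.Tensor,          # token rows inside the source kv_buffer (int64)
    input_dp_idx: torch.Tensor,       # per-token dp index selecting the source buffer (int32)
    output: torch.Tensor,             # move buffer: [move_token_num, 2 * head_num, head_dim]
    output_idx: torch.Tensor,         # destination rows inside the move buffer
) -> None:
    # Prefill side: copy each token's KV row out of the buffer owned by its dp group.
    for k in range(input_idx.shape[0]):
        src = input_bufs[int(input_dp_idx[k])]
        output[output_idx[k]] = src[input_idx[k]]


def ref_scatter_for_d_node(
    output_bufs: List[torch.Tensor],  # stand-in for output_mems: per-destination KV slices
    output_idx: torch.Tensor,         # token rows inside the destination kv_buffer
    output_dp_idx: torch.Tensor,      # per-token dp index selecting the destination buffer
    input: torch.Tensor,              # received move buffer
    input_idx: torch.Tensor,          # source rows inside the move buffer
) -> None:
    # Decode side: write each received KV row back into the buffer of its dp group.
    for k in range(input_idx.shape[0]):
        dst = output_bufs[int(output_dp_idx[k])]
        dst[output_idx[k]] = input[input_idx[k]]

In the committed code the same routing is done per layer by the Triton kernels, with the per-layer pointer table playing the role of the buffer list here.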
