@@ -2,12 +2,13 @@
 import os
 import torch
 import torch.distributed as dist
-from typing import List, Union
+from typing import List, Union, Optional
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
@@ -103,7 +104,10 @@ def send_to_decode_node(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
+        if dp_size_in_node > 1:
+            return self.send_to_decode_node_p2p(
+                move_tasks, mem_managers, dp_size_in_node, nccl_comm
+            )
 
         # Stage the data in a buffer on a designated GPU first, then send it.
 
@@ -143,8 +147,10 @@ def receive_from_prefill_node(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
-
+        if dp_size_in_node > 1:
+            return self.receive_from_prefill_node_p2p(
+                move_tasks, mem_managers, dp_size_in_node, nccl_comm
+            )
        # Receive the data into a buffer on a designated GPU first, then copy it to the other GPUs.
 
         move_token_indexes = []
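The next hunk builds a per-layer table of raw KV-buffer pointers, sampling one mem manager per dp replica via a stride. As a rough illustration of that indexing, here is a tiny sketch; the 8-GPU / 2-replica numbers are invented for the example and assume replicas occupy contiguous device ranks, which is what the stride-based sampling implies.

# Hypothetical layout: 8 mem_managers on the node and dp_size_in_node = 2, so each
# replica spans len(mem_managers) // dp_size_in_node = 4 devices and the prefill
# sender samples the KV buffers of ranks 0 and 4 as per-replica representatives.
mem_manager_count = 8
dp_size_in_node = 2
group_stride = max(1, mem_manager_count // dp_size_in_node)
representative_ranks = list(range(0, mem_manager_count, group_stride))
assert representative_ranks == [0, 4]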
@@ -183,29 +189,73 @@ def send_to_decode_node_p2p(
         """
         Implementation that uses p2p triton kernels to copy and transfer the data.
         """
-        assert dp_size_in_node == 1
-
-        # Stage the data in a buffer on a designated GPU first, then send it.
+        mem_ptrs_dict = {}
+        if dp_size_in_node > 1:
+            # Number of devices occupied by one dp replica
+            group_stride = max(1, len(mem_managers) // dp_size_in_node)
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers), group_stride):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+                if dp_size_in_node > 1:
+                    token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])
+
+        if len(move_token_indexes) == 0:
+            return
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
-        for i, mem in enumerate(mem_managers):
-            for layer_index in range(mem.layer_num):
-                move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                nccl_comm.send(move_buffer, dst=1)
+        token_dp_tensor = (
+            torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
+        )
+
+        for layer_index in range(self.layer_num):
+            move_buffer = self._get_kv_move_data_p2p(
+                move_token_indexes,
+                layer_index,
+                self.kv_move_buffer,
+                token_dp_indexes=token_dp_tensor,
+                dp_size_in_node=dp_size_in_node,
+                mem_ptrs_dict=mem_ptrs_dict,
+            )
+            nccl_comm.send(move_buffer, dst=1)
         return
 
-    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+    def _get_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        layer_index: int,
+        kv_move_buffer: torch.Tensor,
+        token_dp_indexes: Optional[torch.Tensor] = None,
+        dp_size_in_node: int = 1,
+        mem_ptrs_dict: Optional[dict] = None,
+    ):
         move_token_num = len(token_indexes)
         move_size = self.token_dim_size * move_token_num
         move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
-        kv_trans(
-            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
-        )
+
+        if dp_size_in_node == 1 or token_dp_indexes is None:
+            kv_trans(
+                self.kv_buffer[layer_index, :, :, :],
+                token_indexes,
+                move_buffer,
+                self.kv_move_buf_indexes[0:move_token_num],
+            )
+        else:
+            kv_trans_v2_for_p_node(
+                input_mems=mem_ptrs_dict[layer_index],
+                input_idx=token_indexes,
+                input_dp_idx=token_dp_indexes,
+                output=move_buffer,
+                output_idx=self.kv_move_buf_indexes[0:move_token_num],
+                dp_size_in_node=dp_size_in_node,
+            )
         return move_buffer
 
     def receive_from_prefill_node_p2p(
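For readers who have not seen the v2 kernel, the following is a minimal PyTorch sketch of the gather semantics that kv_trans_v2_for_p_node is assumed to implement on the prefill side; it is not the Triton kernel itself. The helper name kv_trans_v2_for_p_node_ref and the kv_buffers_by_dp list are illustrative stand-ins for the tensors behind the uint64 pointers stored in mem_ptrs_dict[layer_index].

# Assumed gather semantics (reference sketch): for every moved token, read its row
# from the KV buffer of the dp replica that produced it and pack it into the
# contiguous staging buffer that is then sent over NCCL.
from typing import List
import torch

def kv_trans_v2_for_p_node_ref(
    kv_buffers_by_dp: List[torch.Tensor],  # stand-in for mem_ptrs_dict[layer_index]
    input_idx: torch.Tensor,     # int64, token slot inside the owning replica's buffer
    input_dp_idx: torch.Tensor,  # int32, dp replica index of each moved token
    output: torch.Tensor,        # staging buffer [move_token_num, 2 * head_num, head_dim]
    output_idx: torch.Tensor,    # int64, destination rows 0..move_token_num-1
) -> None:
    for pos in range(input_idx.shape[0]):
        dp = int(input_dp_idx[pos])
        output[output_idx[pos]] = kv_buffers_by_dp[dp][input_idx[pos]]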
@@ -215,29 +265,72 @@ def receive_from_prefill_node_p2p(
         dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
-        assert dp_size_in_node == 1
-
-        # Receive the data into a buffer on a designated GPU first, then copy it to the other GPUs.
+        mem_ptrs_dict = {}
+        if dp_size_in_node > 1:
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for mem in mem_managers:
+                    mems_ptr.append(mem.kv_buffer[layer_index, :, :, :].data_ptr())
+                mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
 
         move_token_indexes = []
+        token_dp_indexes = []
         for task in move_tasks:
             if task.move_kv_len != 0:
                 move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+                if dp_size_in_node > 1:
+                    token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])
+
+        if len(move_token_indexes) == 0:
+            return
 
         move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_tensor = (
+            torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
+        )
 
-        token_num = len(move_token_indexes)
-        move_size = self.token_dim_size * token_num
-        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
-        for i, mem in enumerate(mem_managers):
-            for layer_index in range(mem.layer_num):
-                nccl_comm.recv(recive_buffer, src=0)
-                mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
+        move_token_num = len(move_token_indexes)
+        move_size = self.token_dim_size * move_token_num
+        recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
+
+        for layer_index in range(self.layer_num):
+            nccl_comm.recv(recive_buffer, src=0)
+            self._write_kv_move_data_p2p(
+                move_token_indexes,
+                recive_buffer,
+                layer_index,
+                token_dp_indexes=token_dp_tensor,
+                dp_size_in_node=dp_size_in_node,
+                mem_ptrs_dict=mem_ptrs_dict,
+            )
         return
 
-    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+    def _write_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        buffer_tensor: torch.Tensor,
+        layer_index: int,
+        token_dp_indexes: Optional[torch.Tensor] = None,
+        dp_size_in_node: int = 1,
+        mem_ptrs_dict: Optional[dict] = None,
+    ):
         move_token_num = len(token_indexes)
-        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        if dp_size_in_node == 1 or token_dp_indexes is None:
+            kv_trans(
+                buffer_tensor,
+                self.kv_move_buf_indexes[0:move_token_num],
+                self.kv_buffer[layer_index],
+                token_indexes,
+            )
+        else:
+            kv_trans_v2_for_d_node(
+                output_mems=mem_ptrs_dict[layer_index],
+                output_idx=token_indexes,
+                output_dp_idx=token_dp_indexes,
+                input=buffer_tensor,
+                input_idx=self.kv_move_buf_indexes[0:move_token_num],
+                dp_size_in_node=dp_size_in_node,
+            )
        return
 
     def _free_buffers(self):
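Mirroring the prefill side, here is a sketch of the scatter that kv_trans_v2_for_d_node is assumed to perform on the decode node. Note that the decode-side pointer table is built from every mem manager on the node (not one per replica), so the sketch assumes each received row is written to all managers of the dp replica that owns the destination token, with contiguous group membership inferred from dp_size_in_node; the helper name kv_trans_v2_for_d_node_ref and the kv_buffers list are illustrative stand-ins, not the library API.

# Assumed scatter semantics (reference sketch, not the Triton kernel): write every
# received row into the KV buffers of the owning dp replica's mem managers.
from typing import List
import torch

def kv_trans_v2_for_d_node_ref(
    kv_buffers: List[torch.Tensor],  # stand-in for mem_ptrs_dict[layer_index], one entry per mem manager
    output_idx: torch.Tensor,        # int64, destination token slots in the KV buffers
    output_dp_idx: torch.Tensor,     # int32, dp replica owning each received token
    input: torch.Tensor,             # received staging buffer [move_token_num, 2 * head_num, head_dim]
    input_idx: torch.Tensor,         # int64, source rows 0..move_token_num-1
    dp_size_in_node: int,
) -> None:
    managers_per_dp = len(kv_buffers) // dp_size_in_node
    for pos in range(output_idx.shape[0]):
        dp = int(output_dp_idx[pos])
        # assumed contiguous layout: replica dp owns managers [dp * managers_per_dp, (dp + 1) * managers_per_dp)
        for m in range(dp * managers_per_dp, (dp + 1) * managers_per_dp):
            kv_buffers[m][output_idx[pos]] = input[input_idx[pos]]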