1313 KVConnectorRole ,
1414)
1515from vllm .distributed .parallel_state import get_tp_group , get_world_group
16+ from vllm .platforms import current_platform
1617from vllm .v1 .core .sched .output import SchedulerOutput
1718from vllm .v1 .request import Request
1819
@@ -102,16 +103,42 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
102103 )
103104 self .block_size = self ._vllm_config .cache_config .block_size
104105 self .is_mla = self ._vllm_config .model_config .is_deepseek_mla
106+ self .kv_cache_dtype : torch .dtype = None
107+
108+ if current_platform .is_cuda_alike ():
109+ logger .info ("CUDA device is available." )
110+ torch_dev = torch
111+ dev_name = "cuda"
112+ elif current_platform .is_npu ():
113+ logger .info ("NPU device is available." )
114+ torch_dev = torch .npu
115+ dev_name = "npu"
116+ else :
117+ raise RuntimeError ("Unsupported device platform for LMCache engine." )
118+
119+ if self .rank >= 0 :
120+ self .device = torch_dev .device (f"{ dev_name } :{ self .rank } " )
121+ self ._layer_offset_cache = {}
105122
106123 self .store : UcmKVStoreBase
107124
108125 self .request_hasher = RequestHasher ()
109126
110127 # save block info, avoid hash request twice, and track them until request finished
111128 self .requests_meta : dict [str , RequestMeta ] = {}
129+
112130 ucm_config = Config (vllm_config .kv_transfer_config )
113131 self .launch_config = ucm_config .get_config ()
114132
133+ self .load_only_first_rank : bool = (
134+ self .launch_config .get ("load_only_first_rank" , self .is_mla ) and self .is_mla
135+ )
136+ if self .load_only_first_rank :
137+ if role == KVConnectorRole .WORKER :
138+ self .group_coordinator = get_tp_group ()
139+ self .broadcast_fn = self .group_coordinator .broadcast
140+ self .broadcast_stream = torch .cuda .Stream ()
141+
115142 if "ucm_connector_name" in self .launch_config :
116143 name = self .launch_config .get ("ucm_connector_name" )
117144 config = self .launch_config .get ("ucm_connector_config" ) or {}
@@ -137,14 +164,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
137164 config ["io_size" ] = block_size_per_layer * (
138165 1 if self .is_mla else num_head_per_tp
139166 )
140- self .load_only_first_rank : bool = (
141- config .get ("load_only_first_rank" , self .is_mla ) and self .is_mla
142- )
143- if self .load_only_first_rank :
144- if role == KVConnectorRole .WORKER :
145- self .group_coordinator = get_tp_group ()
146- self .broadcast_fn = self .group_coordinator .broadcast
147- self .broadcast_stream = torch .cuda .Stream ()
148167 self .store = UcmConnectorFactory .create_connector (name , config )
149168
150169 logger .info ("init UCConnectorImpl, connector: %s" , name )
@@ -320,6 +339,8 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
320339 self .kv_caches [layer_name ] = attn_layer .kv_cache [
321340 forward_context .virtual_engine
322341 ]
342+ if self .kv_cache_dtype is None :
343+ self .kv_cache_dtype = self .kv_caches [layer_name ][0 ].dtype
323344
324345 @staticmethod
325346 def _extract_layer_index (layer_name : str ) -> Optional [int ]:
@@ -331,133 +352,88 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
331352 return int (chunk )
332353 return None
333354
334- def _data_offset (self , kv_layer , layer_id , is_v ) -> int :
335- """
336- GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
337- MLA: one layer shape is (1, num_blocks, block_size, head_size)
338- """
339- elem_size = kv_layer [0 ].element_size ()
355+ def _precompute_layer_offsets (self ) :
356+ if not self . kv_caches :
357+ return
358+
359+ sample_kv_layer = next ( iter ( self . kv_caches . values ()))
360+ elem_size = sample_kv_layer [0 ].element_size ()
340361 block_data_size = (
341- kv_layer [0 ].numel () if self .is_mla else kv_layer [0 ][0 ].numel ()
362+ sample_kv_layer [0 ].numel () if self .is_mla else sample_kv_layer [0 ][0 ].numel ()
342363 ) * elem_size
343- if is_v :
344- return self ._data_offset (kv_layer , layer_id , False ) + block_data_size
345-
346364 layer_data_size = block_data_size if self .is_mla else block_data_size * 2
347- return layer_data_size * layer_id
365+
366+ # precompute all layers offset
367+ for layer_name , _ in self .kv_caches .items ():
368+ layer_id = self ._extract_layer_index (layer_name )
369+ assert layer_id is not None
370+ k_offset = layer_data_size * layer_id
371+ v_offset = k_offset + block_data_size if not self .is_mla else 0
372+ self ._layer_offset_cache [layer_name ] = (k_offset , v_offset )
348373
349374 def _get_tensor_and_offset (
350375 self , vllm_block_ids : list [int ], kv_layer : torch .Tensor , layer_name : str
351376 ) -> tuple [list [torch .Tensor ], list [int ]]:
377+ """
378+ GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
379+ MLA: one layer shape is (num_blocks, block_size, head_size)
380+ """
352381 k_tensors , k_offsets = [], []
353382 v_tensors , v_offsets = [], []
354- layer_id = self ._extract_layer_index (layer_name )
355- assert layer_id is not None
383+ k_offset , v_offset = self ._layer_offset_cache [layer_name ]
356384
357385 for vllm_block_id in vllm_block_ids :
358- offset = self ._data_offset (kv_layer , layer_id , False )
359- tensor = (
386+ k_tensors .append (
360387 kv_layer [vllm_block_id ] if self .is_mla else kv_layer [0 ][vllm_block_id ]
361388 )
362- k_tensors .append (tensor )
363- k_offsets .append (offset )
389+ k_offsets .append (k_offset )
364390 if not self .is_mla :
365- v_offset = self ._data_offset (kv_layer , layer_id , True )
366391 v_tensors .append (kv_layer [1 ][vllm_block_id ])
367392 v_offsets .append (v_offset )
368393 return k_tensors + v_tensors , k_offsets + v_offsets
369394
370- def _generate_task (
371- self ,
372- vllm_block_ids ,
373- ucm_block_ids ,
374- func : Callable [[List [str ], List [int ], List [torch .Tensor ]], Task ],
375- ) -> Task :
376- dst_tensor_addr , ucm_offsets = [], []
395+ def _generate_task (self , vllm_block_ids : List [int ], ucm_block_ids : List [str ]):
396+ if not self ._layer_offset_cache :
397+ self ._precompute_layer_offsets ()
398+
399+ num_layers = len (self .kv_caches )
400+ num_blocks_per_layer = len (vllm_block_ids )
401+ num_tensors_per_layer = num_blocks_per_layer * (1 if self .is_mla else 2 )
402+ dst_tensor_addr = [None ] * (num_layers * num_tensors_per_layer )
403+ ucm_offsets = [0 ] * (num_layers * num_tensors_per_layer )
404+
405+ idx = 0
377406 for layer_name , one_layer_kv_cache in self .kv_caches .items ():
378- addrs , offsets = self ._get_tensor_and_offset (
407+ tensors , offsets = self ._get_tensor_and_offset (
379408 vllm_block_ids , one_layer_kv_cache , layer_name
380409 )
381- dst_tensor_addr .extend (addrs )
382- ucm_offsets .extend (offsets )
383- ucm_total_block_ids = ucm_block_ids * len (self .kv_caches )
384- if not self .is_mla :
385- ucm_total_block_ids *= 2
410+ dst_tensor_addr [idx : idx + len (tensors )] = tensors
411+ ucm_offsets [idx : idx + len (offsets )] = offsets
412+ idx += len (tensors )
413+
414+ repeat_times = len (self .kv_caches ) * (1 if self .is_mla else 2 )
415+ ucm_total_block_ids = ucm_block_ids * repeat_times
416+
386417 assert len (ucm_total_block_ids ) == len (ucm_offsets ) == len (dst_tensor_addr )
387- return func ( ucm_total_block_ids , ucm_offsets , dst_tensor_addr )
418+ return ucm_total_block_ids , ucm_offsets , dst_tensor_addr
388419
389- def _generate_load_task_for_broadcast (
390- self ,
391- vllm_block_ids ,
392- ucm_block_ids ,
393- can_load : bool ,
394- ) -> tuple [Task , dict [str , torch .Tensor ], int ]:
395- """
396- Load or Dump func is only called in rank 0 in MLA;
397- In rank != 0, worker will receive broadcast tensors from rank 0.
398- """
399- layer_to_tensors = {}
400- total_block_num = len (ucm_block_ids )
401- dst_tensor_addr , ucm_offsets = [], []
402- for layer_name , one_layer_kv_cache in self .kv_caches .items ():
403- addrs , offsets = self ._get_tensor_and_offset (
404- vllm_block_ids , one_layer_kv_cache , layer_name
405- )
406- layer_to_tensors [layer_name ] = addrs [:total_block_num ]
407- dst_tensor_addr .extend (addrs )
408- ucm_offsets .extend (offsets )
409- ucm_total_block_ids = ucm_block_ids * len (self .kv_caches )
410-
411- task = None
412- if can_load :
413- assert len (ucm_total_block_ids ) == len (ucm_offsets ) == len (dst_tensor_addr )
414- task = self .store .load (ucm_total_block_ids , ucm_offsets , dst_tensor_addr )
415- return task , layer_to_tensors , total_block_num
416-
417- def _broadcast_or_receive_blocks (
418- self , layer_to_tensors : dict [str : torch .Tensor ], total_block_num
419- ):
420- receive_dict = {}
421- for layer_name , kv_layer in self .kv_caches .items ():
422- k_tensors = layer_to_tensors [layer_name ][:total_block_num ]
    def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
        """Broadcast loaded KV block tensors from rank 0 to the other TP ranks.

        Used when ``load_only_first_rank`` is enabled: rank 0, which alone read
        the blocks from the store, stacks them and broadcasts on a dedicated
        stream; every other rank receives into a fresh staging tensor and then
        copies each slice back into its own KV-cache block views in place.

        NOTE(review): uses ``torch.cuda`` streams unconditionally — confirm
        this is valid when the platform selected in ``__init__`` is NPU.
        NOTE(review): ``dst_tensor_addr[0]`` assumes a non-empty list —
        presumably callers skip requests with zero load blocks; verify.
        """
        # Staging tensor for receivers; stays None on rank 0.
        rec_tensor: Optional[torch.Tensor] = None
        with torch.cuda.stream(self.broadcast_stream):
            if self.rank == 0:
                # Stack into one contiguous tensor so a single collective
                # moves all blocks of this request at once.
                tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
                self.broadcast_fn(tensor_to_broadcast, 0)
            else:
                # Receive shape: (num_blocks, *per-block shape of a KV view).
                shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
                # TODO create earlier
                rec_tensor = torch.empty(
                    shape, dtype=self.kv_cache_dtype, device=self.device
                )
                self.broadcast_fn(rec_tensor, 0)
        # Ensure the collective finished before the KV cache is consumed.
        self.broadcast_stream.synchronize()
        if self.rank != 0 and rec_tensor is not None:
            # Scatter the received stack back into the cache block views.
            for i, tensor in enumerate(dst_tensor_addr):
                tensor.copy_(rec_tensor[i])
461437
462438 def start_load_kv (self , forward_context : "ForwardContext" , ** kwargs ) -> None :
463439
@@ -467,35 +443,30 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
467443 self ._init_kv_caches_from_forward_context (forward_context )
468444
469445 request_to_task : dict [str , Optional [Task ]] = {}
470- req_to_layer = {}
446+ req_broadcast_addr = {}
471447 for request_id , request in metadata .request_meta .items ():
472448 if len (request .load_block_ids [0 ]) == 0 :
473449 continue
474450
475451 ucm_block_ids , vllm_block_ids = request .load_block_ids
476- if self .load_only_first_rank :
477- can_load = self . rank == 0
478- task , layer_to_tensors , total_block_num = (
479- self ._generate_load_task_for_broadcast (
480- vllm_block_ids , ucm_block_ids , can_load
481- )
452+ ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
453+ vllm_block_ids , ucm_block_ids
454+ )
455+ if self . rank == 0 or not self .load_only_first_rank :
456+ request_to_task [ request_id ] = self . store . load (
457+ ucm_total_block_ids , ucm_offsets , dst_tensor_addr
482458 )
483- req_to_layer [request_id ] = (layer_to_tensors , total_block_num )
484459 else :
485- task = self ._generate_task (
486- vllm_block_ids , ucm_block_ids , self .store .load
487- )
488- request_to_task [request_id ] = task
460+ request_to_task [request_id ] = None
461+ req_broadcast_addr [request_id ] = dst_tensor_addr
489462
490- for req_id , task in request_to_task .items ():
491- if self .load_only_first_rank :
492- layer_to_tensors , total_block_num = req_to_layer [req_id ]
493- self ._wait_for_broadcast (
494- req_id , task , layer_to_tensors , total_block_num
495- )
496- else :
463+ for request_id , task in request_to_task .items ():
464+ # TODO error handling
465+ if self .rank == 0 or not self .load_only_first_rank :
497466 if self .store .wait (task ) != 0 :
498- logger .error (f"request { req_id } load kv cache failed." )
467+ logger .error (f"request { request_id } load kv cache failed." )
468+ if self .load_only_first_rank :
469+ self ._broadcast (req_broadcast_addr [request_id ])
499470
500471 def wait_for_layer_load (self , layer_name : str ) -> None :
501472 pass
@@ -538,8 +509,11 @@ def wait_for_save(self) -> None:
538509 continue
539510 ucm_block_ids = ucm_block_ids [:end ]
540511 vllm_block_ids = vllm_block_ids [:end ]
541- request_to_task [request_id ] = self ._generate_task (
542- vllm_block_ids , ucm_block_ids , self .store .dump
512+ ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
513+ vllm_block_ids , ucm_block_ids
514+ )
515+ request_to_task [request_id ] = self .store .dump (
516+ ucm_total_block_ids , ucm_offsets , dst_tensor_addr
543517 )
544518 request_to_blocks [request_id ] = ucm_block_ids
545519
0 commit comments