
Commit 63c916b

[feature] optimize generate_tensor (#396)

optimize generate_tensor

1 parent 66e3e18 commit 63c916b

File tree

examples/ucm_config_example.yaml
ucm/integration/vllm/ucm_connector.py

2 files changed: +105 -130 lines changed

examples/ucm_config_example.yaml

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,8 @@ ucm_connector_name: "UcmNfsStore"
 ucm_connector_config:
   storage_backends: "/mnt/test"
   transferIoDirect: false
-  load_only_first_rank: false
+
+load_only_first_rank: false
 
 # Sparse attention configuration
 # Format 1: Dictionary format (for methods like ESA, KvComp)
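Note: `load_only_first_rank` moves out of `ucm_connector_config` and up to the top level of the launch config, matching the Python change below, where the flag is now read from `launch_config` instead of the per-connector `config` dict. As the patched `__init__` shows, the flag only ever takes effect for MLA models; a minimal sketch of the resolution logic (the function name is illustrative, the expression mirrors the diff):

def resolve_load_only_first_rank(launch_config: dict, is_mla: bool) -> bool:
    # Defaults to is_mla; the trailing "and is_mla" makes the flag a
    # no-op for non-MLA models even if a user sets it to true.
    return launch_config.get("load_only_first_rank", is_mla) and is_mla

assert resolve_load_only_first_rank({}, is_mla=True) is True
assert resolve_load_only_first_rank({"load_only_first_rank": False}, is_mla=True) is False
assert resolve_load_only_first_rank({"load_only_first_rank": True}, is_mla=False) is False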

ucm/integration/vllm/ucm_connector.py

Lines changed: 103 additions & 129 deletions
@@ -13,6 +13,7 @@
     KVConnectorRole,
 )
 from vllm.distributed.parallel_state import get_tp_group, get_world_group
+from vllm.platforms import current_platform
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request

@@ -102,16 +103,42 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
+        self.kv_cache_dtype: torch.dtype = None
+
+        if current_platform.is_cuda_alike():
+            logger.info("CUDA device is available.")
+            torch_dev = torch
+            dev_name = "cuda"
+        elif current_platform.is_npu():
+            logger.info("NPU device is available.")
+            torch_dev = torch.npu
+            dev_name = "npu"
+        else:
+            raise RuntimeError("Unsupported device platform for LMCache engine.")
+
+        if self.rank >= 0:
+            self.device = torch_dev.device(f"{dev_name}:{self.rank}")
+        self._layer_offset_cache = {}
 
         self.store: UcmKVStoreBase
 
         self.request_hasher = RequestHasher()
 
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}
+
         ucm_config = Config(vllm_config.kv_transfer_config)
         self.launch_config = ucm_config.get_config()
 
+        self.load_only_first_rank: bool = (
+            self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla
+        )
+        if self.load_only_first_rank:
+            if role == KVConnectorRole.WORKER:
+                self.group_coordinator = get_tp_group()
+                self.broadcast_fn = self.group_coordinator.broadcast
+                self.broadcast_stream = torch.cuda.Stream()
+
         if "ucm_connector_name" in self.launch_config:
             name = self.launch_config.get("ucm_connector_name")
             config = self.launch_config.get("ucm_connector_config") or {}
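Two things are hoisted into `__init__` here: the per-rank device (via the new platform dispatch) and the `load_only_first_rank` broadcast wiring, which previously sat inside the `ucm_connector_name` branch (removed in the next hunk) and now runs regardless of the connector block. A minimal sketch of the CUDA path of the dispatch, assuming only vLLM's `current_platform` probe (the NPU branch mirrors it through `torch.npu` when torch_npu is installed):

import torch
from vllm.platforms import current_platform

def select_device(rank: int) -> torch.device:
    # Mirrors the patched __init__: pick the backend first, then derive
    # the per-rank device from the tensor-parallel rank.
    if current_platform.is_cuda_alike():
        return torch.device(f"cuda:{rank}")
    raise RuntimeError("Unsupported device platform for LMCache engine.")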
@@ -137,14 +164,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             config["io_size"] = block_size_per_layer * (
                 1 if self.is_mla else num_head_per_tp
             )
-            self.load_only_first_rank: bool = (
-                config.get("load_only_first_rank", self.is_mla) and self.is_mla
-            )
-            if self.load_only_first_rank:
-                if role == KVConnectorRole.WORKER:
-                    self.group_coordinator = get_tp_group()
-                    self.broadcast_fn = self.group_coordinator.broadcast
-                    self.broadcast_stream = torch.cuda.Stream()
             self.store = UcmConnectorFactory.create_connector(name, config)
 
             logger.info("init UCConnectorImpl, connector: %s", name)
@@ -320,6 +339,8 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
             self.kv_caches[layer_name] = attn_layer.kv_cache[
                 forward_context.virtual_engine
             ]
+            if self.kv_cache_dtype is None:
+                self.kv_cache_dtype = self.kv_caches[layer_name][0].dtype
 
     @staticmethod
     def _extract_layer_index(layer_name: str) -> Optional[int]:
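Caching the dtype once, on first sight of a KV layer, pairs with the `self.device` computed in `__init__`: `_broadcast` (below) can then allocate its `rec_tensor` receive buffer directly instead of re-deriving shape metadata from the block tensors on every call.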
@@ -331,133 +352,88 @@ def _extract_layer_index(layer_name: str) -> Optional[int]:
                 return int(chunk)
         return None
 
-    def _data_offset(self, kv_layer, layer_id, is_v) -> int:
-        """
-        GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
-        MLA: one layer shape is (1, num_blocks, block_size, head_size)
-        """
-        elem_size = kv_layer[0].element_size()
+    def _precompute_layer_offsets(self):
+        if not self.kv_caches:
+            return
+
+        sample_kv_layer = next(iter(self.kv_caches.values()))
+        elem_size = sample_kv_layer[0].element_size()
         block_data_size = (
-            kv_layer[0].numel() if self.is_mla else kv_layer[0][0].numel()
+            sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel()
         ) * elem_size
-        if is_v:
-            return self._data_offset(kv_layer, layer_id, False) + block_data_size
-
         layer_data_size = block_data_size if self.is_mla else block_data_size * 2
-        return layer_data_size * layer_id
+
+        # precompute all layers offset
+        for layer_name, _ in self.kv_caches.items():
+            layer_id = self._extract_layer_index(layer_name)
+            assert layer_id is not None
+            k_offset = layer_data_size * layer_id
+            v_offset = k_offset + block_data_size if not self.is_mla else 0
+            self._layer_offset_cache[layer_name] = (k_offset, v_offset)
 
     def _get_tensor_and_offset(
         self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str
     ) -> tuple[list[torch.Tensor], list[int]]:
+        """
+        GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+        MLA: one layer shape is (num_blocks, block_size, head_size)
+        """
         k_tensors, k_offsets = [], []
         v_tensors, v_offsets = [], []
-        layer_id = self._extract_layer_index(layer_name)
-        assert layer_id is not None
+        k_offset, v_offset = self._layer_offset_cache[layer_name]
 
         for vllm_block_id in vllm_block_ids:
-            offset = self._data_offset(kv_layer, layer_id, False)
-            tensor = (
+            k_tensors.append(
                 kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id]
             )
-            k_tensors.append(tensor)
-            k_offsets.append(offset)
+            k_offsets.append(k_offset)
            if not self.is_mla:
-                v_offset = self._data_offset(kv_layer, layer_id, True)
                 v_tensors.append(kv_layer[1][vllm_block_id])
                 v_offsets.append(v_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets
 
-    def _generate_task(
-        self,
-        vllm_block_ids,
-        ucm_block_ids,
-        func: Callable[[List[str], List[int], List[torch.Tensor]], Task],
-    ) -> Task:
-        dst_tensor_addr, ucm_offsets = [], []
+    def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]):
+        if not self._layer_offset_cache:
+            self._precompute_layer_offsets()
+
+        num_layers = len(self.kv_caches)
+        num_blocks_per_layer = len(vllm_block_ids)
+        num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2)
+        dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer)
+        ucm_offsets = [0] * (num_layers * num_tensors_per_layer)
+
+        idx = 0
         for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            addrs, offsets = self._get_tensor_and_offset(
+            tensors, offsets = self._get_tensor_and_offset(
                 vllm_block_ids, one_layer_kv_cache, layer_name
             )
-            dst_tensor_addr.extend(addrs)
-            ucm_offsets.extend(offsets)
-        ucm_total_block_ids = ucm_block_ids * len(self.kv_caches)
-        if not self.is_mla:
-            ucm_total_block_ids *= 2
+            dst_tensor_addr[idx : idx + len(tensors)] = tensors
+            ucm_offsets[idx : idx + len(offsets)] = offsets
+            idx += len(tensors)
+
+        repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2)
+        ucm_total_block_ids = ucm_block_ids * repeat_times
+
         assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-        return func(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)
+        return ucm_total_block_ids, ucm_offsets, dst_tensor_addr
 
-    def _generate_load_task_for_broadcast(
-        self,
-        vllm_block_ids,
-        ucm_block_ids,
-        can_load: bool,
-    ) -> tuple[Task, dict[str, torch.Tensor], int]:
-        """
-        Load or Dump func is only called in rank 0 in MLA;
-        In rank != 0, worker will receive broadcast tensors from rank 0.
-        """
-        layer_to_tensors = {}
-        total_block_num = len(ucm_block_ids)
-        dst_tensor_addr, ucm_offsets = [], []
-        for layer_name, one_layer_kv_cache in self.kv_caches.items():
-            addrs, offsets = self._get_tensor_and_offset(
-                vllm_block_ids, one_layer_kv_cache, layer_name
-            )
-            layer_to_tensors[layer_name] = addrs[:total_block_num]
-            dst_tensor_addr.extend(addrs)
-            ucm_offsets.extend(offsets)
-        ucm_total_block_ids = ucm_block_ids * len(self.kv_caches)
-
-        task = None
-        if can_load:
-            assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
-            task = self.store.load(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)
-        return task, layer_to_tensors, total_block_num
-
-    def _broadcast_or_receive_blocks(
-        self, layer_to_tensors: dict[str : torch.Tensor], total_block_num
-    ):
-        receive_dict = {}
-        for layer_name, kv_layer in self.kv_caches.items():
-            k_tensors = layer_to_tensors[layer_name][:total_block_num]
+    def _broadcast(self, dst_tensor_addr: list[torch.Tensor]):
+        rec_tensor: torch.Tensor = None
+        with torch.cuda.stream(self.broadcast_stream):
             if self.rank == 0:
-                tensor_to_broadcast = torch.stack(k_tensors, dim=0)
+                tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0)
                 self.broadcast_fn(tensor_to_broadcast, 0)
             else:
-                shape = (len(k_tensors),) + k_tensors[0].shape
-                dtype = k_tensors[0].dtype
-                rec_tensor = torch.empty(shape, dtype=dtype, device=f"cuda:{self.rank}")
+                shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape
+                # TODO create earlier
+                rec_tensor = torch.empty(
+                    shape, dtype=self.kv_cache_dtype, device=self.device
+                )
                 self.broadcast_fn(rec_tensor, 0)
-            receive_dict[layer_name] = rec_tensor
-        return receive_dict
-
-    def _wait_for_broadcast(
-        self,
-        req_id: str,
-        task: Task,
-        layer_to_tensors: dict[str, torch.Tensor],
-        total_block_num: int,
-    ):
-        if self.rank == 0:
-            if self.store.wait(task) != 0:
-                logger.error(f"request {req_id} load kv cache failed.")
-                return
-            logger.debug(
-                f"request {req_id} load {total_block_num} blocks on rank {self.rank}"
-            )
-        with torch.cuda.stream(self.broadcast_stream):
-            receive_dict = self._broadcast_or_receive_blocks(
-                layer_to_tensors, total_block_num
-            )
         self.broadcast_stream.synchronize()
-        if self.rank > 0 and receive_dict:
-            for layer_name, kv_layer in self.kv_caches.items():
-                received_tensor = receive_dict[layer_name]
-                for i in range(total_block_num):
-                    layer_to_tensors[layer_name][i].copy_(received_tensor[i])
-            logger.debug(
-                f"request {req_id} receive broadcast {total_block_num} blocks on rank {self.rank}"
-            )
+        if self.rank != 0 and rec_tensor is not None:
+            for i, tensor in enumerate(dst_tensor_addr):
+                tensor.copy_(rec_tensor[i])
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
 
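The core of the optimization: per-layer K/V byte offsets are constants of the model layout, so `_precompute_layer_offsets` computes them once per layer instead of the old per-block, per-call recursive `_data_offset`. A worked example of the arithmetic under assumed shapes (GQA, fp16, block_size=16, num_kv_heads=8, head_size=128; all values hypothetical):

elem_size = 2                               # fp16
block_data_size = 16 * 8 * 128 * elem_size  # one K (or V) block: 32768 bytes
layer_data_size = 2 * block_data_size       # K + V per layer (non-MLA): 65536 bytes

layer_id = 3
k_offset = layer_data_size * layer_id       # 196608
v_offset = k_offset + block_data_size       # 229376
assert (k_offset, v_offset) == (196608, 229376)

`_generate_task` also changes shape: it preallocates the flat lists and returns the `(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)` triple instead of taking a `func` callback, so callers decide between `store.load` and `store.dump` themselves.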
@@ -467,35 +443,30 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         self._init_kv_caches_from_forward_context(forward_context)
 
         request_to_task: dict[str, Optional[Task]] = {}
-        req_to_layer = {}
+        req_broadcast_addr = {}
         for request_id, request in metadata.request_meta.items():
             if len(request.load_block_ids[0]) == 0:
                 continue
 
             ucm_block_ids, vllm_block_ids = request.load_block_ids
-            if self.load_only_first_rank:
-                can_load = self.rank == 0
-                task, layer_to_tensors, total_block_num = (
-                    self._generate_load_task_for_broadcast(
-                        vllm_block_ids, ucm_block_ids, can_load
-                    )
+            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                vllm_block_ids, ucm_block_ids
+            )
+            if self.rank == 0 or not self.load_only_first_rank:
+                request_to_task[request_id] = self.store.load(
+                    ucm_total_block_ids, ucm_offsets, dst_tensor_addr
                 )
-                req_to_layer[request_id] = (layer_to_tensors, total_block_num)
             else:
-                task = self._generate_task(
-                    vllm_block_ids, ucm_block_ids, self.store.load
-                )
-                request_to_task[request_id] = task
+                request_to_task[request_id] = None
+            req_broadcast_addr[request_id] = dst_tensor_addr
 
-        for req_id, task in request_to_task.items():
-            if self.load_only_first_rank:
-                layer_to_tensors, total_block_num = req_to_layer[req_id]
-                self._wait_for_broadcast(
-                    req_id, task, layer_to_tensors, total_block_num
-                )
-            else:
+        for request_id, task in request_to_task.items():
+            # TODO error handling
+            if self.rank == 0 or not self.load_only_first_rank:
                 if self.store.wait(task) != 0:
-                    logger.error(f"request {req_id} load kv cache failed.")
+                    logger.error(f"request {request_id} load kv cache failed.")
+            if self.load_only_first_rank:
+                self._broadcast(req_broadcast_addr[request_id])
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         pass
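With `load_only_first_rank`, only rank 0 waits on a store load; every rank then enters `_broadcast`, where rank 0 sends and the others receive and `copy_` the blocks into their local KV cache. A self-contained sketch of that pattern using raw torch.distributed (vLLM's `group_coordinator.broadcast` is assumed to behave like `dist.broadcast`, i.e. in-place on non-source ranks); run with `torchrun --nproc_per_node=2`:

import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="gloo")  # "nccl" on CUDA in practice
    rank = dist.get_rank()
    if rank == 0:
        # Rank 0 holds the blocks it just loaded from storage.
        blocks = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    else:
        # Other ranks allocate an empty buffer of matching shape/dtype,
        # as _broadcast does with kv_cache_dtype and self.device.
        blocks = torch.empty(2, 4, dtype=torch.float32)
    dist.broadcast(blocks, src=0)  # fills the buffer in-place on rank != 0
    assert blocks.sum().item() == 28.0
    dist.destroy_process_group()

if __name__ == "__main__":
    main()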
@@ -538,8 +509,11 @@ def wait_for_save(self) -> None:
                 continue
             ucm_block_ids = ucm_block_ids[:end]
             vllm_block_ids = vllm_block_ids[:end]
-            request_to_task[request_id] = self._generate_task(
-                vllm_block_ids, ucm_block_ids, self.store.dump
+            ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
+                vllm_block_ids, ucm_block_ids
+            )
+            request_to_task[request_id] = self.store.dump(
+                ucm_total_block_ids, ucm_offsets, dst_tensor_addr
             )
             request_to_blocks[request_id] = ucm_block_ids
 
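The save path now goes through the same shared builder as the load path, so dump and load are symmetric: build the triple once, hand it to the store, wait on the task. A tiny runnable sketch with a stand-in store (`UcmKVStoreBase`'s `load`/`dump`/`wait` are assumed to take and return the same shapes as in the diff):

class FakeStore:
    # Stand-in for UcmKVStoreBase: load/dump return a task handle,
    # wait returns 0 on success.
    def load(self, ids, offsets, tensors):
        return ("load", len(ids))
    def dump(self, ids, offsets, tensors):
        return ("dump", len(ids))
    def wait(self, task):
        return 0

store = FakeStore()
ids, offsets, tensors = ["blk0"] * 4, [0, 32768, 65536, 98304], [object()] * 4
for op in (store.load, store.dump):
    task = op(ids, offsets, tensors)
    assert store.wait(task) == 0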