
Commit 0986b89

[feat]adapt GQA & modify config.yaml (#407)

* adapt GQA & modify config.yaml
* move process to UCMDirectConnector
* fix comment
* modify hash function
* fix style
* code style and modify hash
* init parent_block_hash_value
1 parent 6358406 commit 0986b89

File tree: 3 files changed, +83 -72 lines changed

examples/ucm_config_example.yaml

Lines changed: 5 additions & 6 deletions
@@ -8,12 +8,11 @@
 # for backward compatibility.
 
 # Connector name (e.g., "UcmNfsStore", "UcmDramStore")
-ucm_connector_name: "UcmNfsStore"
-
-# Connector-specific configuration
-ucm_connector_config:
-  storage_backends: "/mnt/test"
-  transferIoDirect: false
+ucm_connectors:
+  - ucm_connector_name: "UcmNfsStore"
+    ucm_connector_config:
+      storage_backends: "/mnt/test"
+      use_direct: false
 
 load_only_first_rank: false
 
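For readers updating existing configs, a minimal sketch of how the new list-style ucm_connectors section can be consumed from Python. The use of PyYAML and the file path are assumptions for illustration, not part of this commit; only the first list entry is read by the connector in this change.

    import yaml  # assumes PyYAML is installed

    with open("examples/ucm_config_example.yaml") as f:
        launch_config = yaml.safe_load(f)

    # "ucm_connectors" is now a list; this commit only consumes the first entry.
    connector = launch_config["ucm_connectors"][0]
    name = connector["ucm_connector_name"]              # "UcmNfsStore"
    config = connector.get("ucm_connector_config", {})  # storage_backends, use_direct, ...
    print(name, config)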

ucm/integration/vllm/ucm_connector.py

Lines changed: 77 additions & 65 deletions
@@ -56,37 +56,21 @@ class RequestHasher:
 
     _SEED_HASH = None
 
-    def __init__(self):
-        if RequestHasher._SEED_HASH is None:
-            RequestHasher._SEED_HASH = self._md5("UCM_HASH_SEED")
-
-    @staticmethod
-    def _md5(input_data) -> int:
-        input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
-        md5_bytes = hashlib.md5(input_bytes).digest()
-        return int.from_bytes(md5_bytes, byteorder="big")
-
-    def __call__(self, block_size: int, request: "Request") -> list[str]:
-        token_ids = request.all_token_ids
-
-        ret = []
-        parent_block_hash_value = None
-        for start in range(0, len(token_ids), block_size):
-            end = start + block_size
-            block_token_ids = token_ids[start:end]
-            # Do not hash the block if it is not full.
-            if len(block_token_ids) < block_size:
-                break
+    def __init__(self, vllm_config, rank_id):
+        meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}"
+        self.meta_bytes = meta.encode("utf-8")
 
-            if not parent_block_hash_value:
-                parent_block_hash_value = RequestHasher._SEED_HASH
+        if RequestHasher._SEED_HASH is None:
+            RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
 
-            block_token_ids_tuple = tuple(block_token_ids)
-            hash_value = self._md5((parent_block_hash_value, block_token_ids_tuple))
-            parent_block_hash_value = hash_value
-            ret.append(str(hash_value))
+    def __call__(self, input_data) -> int:
+        if isinstance(input_data, str):
+            input_bytes = input_data.encode("utf-8")
+        else:
+            input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
 
-        return ret
+        h = hashlib.md5(self.meta_bytes + input_bytes)
+        return int.from_bytes(h.digest(), byteorder="big")
 
 
 class UCMDirectConnector(KVConnectorBase_V1):
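A small illustration of the behavior introduced here (not part of the commit): the hasher now prefixes every MD5 with model/world-size/dtype/rank metadata, so identical inputs hash to different values on different ranks. The metadata strings below are made-up stand-ins for what __init__ builds from vllm_config.

    import hashlib
    import pickle

    def ucm_hash(meta_bytes: bytes, input_data) -> int:
        # Mirrors RequestHasher.__call__: strings are UTF-8 encoded, other objects pickled.
        if isinstance(input_data, str):
            input_bytes = input_data.encode("utf-8")
        else:
            input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
        return int.from_bytes(hashlib.md5(meta_bytes + input_bytes).digest(), byteorder="big")

    rank0_meta = b"my-model:2:torch.bfloat16:0"  # model:world_size:dtype:rank_id
    rank1_meta = b"my-model:2:torch.bfloat16:1"
    block = (12345, tuple(range(16)))            # (parent hash, block token ids)

    assert ucm_hash(rank0_meta, block) != ucm_hash(rank1_meta, block)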
@@ -114,15 +98,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             torch_dev = torch.npu
             dev_name = "npu"
         else:
-            raise RuntimeError("Unsupported device platform for LMCache engine.")
+            raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
 
         if self.rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.rank}")
             self._layer_offset_cache = {}
 
         self.store: UcmKVStoreBase
 
-        self.request_hasher = RequestHasher()
+        if role == KVConnectorRole.SCHEDULER:
+            self.request_hasher = RequestHasher(vllm_config, 0)
+        else:
+            self.request_hasher = RequestHasher(vllm_config, self.rank)
 
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}
@@ -139,41 +126,60 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self.broadcast_fn = self.group_coordinator.broadcast
             self.broadcast_stream = torch.cuda.Stream()
 
-        if "ucm_connector_name" in self.launch_config:
-            name = self.launch_config.get("ucm_connector_name")
-            config = self.launch_config.get("ucm_connector_config") or {}
-            config["device"] = self.rank
-            config["role"] = (
-                "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-            )
-            element_size = vllm_config.model_config.dtype.itemsize
-            single_head_dim = vllm_config.model_config.get_head_size()
-            num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-                vllm_config.parallel_config
-            )
-            total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-            num_layers = vllm_config.model_config.get_num_layers(
-                vllm_config.parallel_config
-            )
-            block_size_per_layer = self.block_size * element_size * single_head_dim
-            config["kv_block_size"] = (
-                block_size_per_layer
-                * num_layers
-                * (1 if self.is_mla else num_head_per_tp * total_tp_size * 2)
-            )
-            config["io_size"] = block_size_per_layer * (
-                1 if self.is_mla else num_head_per_tp
-            )
-            self.store = UcmConnectorFactory.create_connector(name, config)
+        connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(connector_configs) > 0, "no storage connector name in config."
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.rank
+        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
+        element_size = vllm_config.model_config.dtype.itemsize
+        single_head_dim = vllm_config.model_config.get_head_size()
+        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
+        num_layers = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+        block_size_per_layer = self.block_size * element_size * single_head_dim
+        config["kv_block_size"] = (
+            block_size_per_layer
+            * num_layers
+            * (1 if self.is_mla else num_head_per_tp * 2)
+        )
+        config["io_size"] = block_size_per_layer * (
+            1 if self.is_mla else num_head_per_tp
+        )
+        self.store = UcmConnectorFactory.create_connector(name, config)
+
+        logger.info("init UCConnectorImpl, connector: %s", name)
+        logger.info(
+            "single file size = %d MB, io_size = %d KB,",
+            config["kv_block_size"] / 1024 / 1024,
+            config["io_size"] / 1024,
+        )
+
+    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+        token_ids = request.all_token_ids
+
+        ret = []
+        parent_block_hash_value = RequestHasher._SEED_HASH
+        for start in range(0, len(token_ids), block_size):
+            end = start + block_size
+            block_token_ids = token_ids[start:end]
+            # Do not hash the block if it is not full.
+            if len(block_token_ids) < block_size:
+                break
 
-            logger.info("init UCConnectorImpl, connector: %s", name)
-            logger.info(
-                "single file size = %d MB, io_size = %d KB,",
-                config["kv_block_size"] / 1024 / 1024,
-                config["io_size"] / 1024,
+            block_token_ids_tuple = tuple(block_token_ids)
+            hash_value = self.request_hasher(
+                (parent_block_hash_value, block_token_ids_tuple)
             )
-        else:
-            raise TypeError(f"no storage connector name in config.")
+            parent_block_hash_value = hash_value
+            ret.append(str(hash_value))
+
+        return ret
 
     def get_num_new_matched_tokens(
         self,
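For intuition about the GQA sizing change in the hunk above, a back-of-the-envelope check with assumed dimensions (128-token blocks, bfloat16, head size 128, 4 KV heads per tensor-parallel rank, 32 layers, non-MLA); none of these numbers come from the commit.

    block_size = 128
    element_size = 2          # bfloat16
    single_head_dim = 128
    num_head_per_tp = 4       # KV heads owned by one tensor-parallel rank (GQA)
    num_layers = 32

    block_size_per_layer = block_size * element_size * single_head_dim          # 32 KiB
    io_size = block_size_per_layer * num_head_per_tp                            # 128 KiB per transfer
    kv_block_size = block_size_per_layer * num_layers * (num_head_per_tp * 2)   # 8 MiB per block

    print(io_size // 1024, "KB,", kv_block_size // (1024 * 1024), "MB")
    # The pre-commit formula also multiplied by total_tp_size, so with TP=2 the
    # same model would have been sized at 16 MiB per block file.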
@@ -184,7 +190,7 @@ def get_num_new_matched_tokens(
         assert num_computed_tokens % self.block_size == 0
         hbm_hit_block_num = num_computed_tokens // self.block_size
 
-        ucm_block_ids = self.request_hasher(self.block_size, request)
+        ucm_block_ids = self.generate_hash(self.block_size, request)
 
         external_block_ids = ucm_block_ids[hbm_hit_block_num:]
         if not external_block_ids:
@@ -210,7 +216,7 @@ def get_num_new_matched_tokens(
         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
         # once vLLM scheduler provides a better solution in the future.
-        if external_hit_tokens == request.num_prompt_tokens:
+        if total_hit_block_num * self.block_size == request.num_tokens:
             external_hit_tokens -= 1
 
         self.requests_meta[request.request_id] = RequestMeta(
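A tiny numeric illustration of the updated full-hit check (all values below are made up): when the cached blocks cover the entire request, one token is handed back to vLLM for recomputation.

    block_size = 128
    num_tokens = 256              # request length, fully covered by cached blocks
    hbm_hit_block_num = 1         # blocks already resident in HBM
    external_hit_block_num = 1    # blocks found in the external store
    total_hit_block_num = hbm_hit_block_num + external_hit_block_num

    external_hit_tokens = external_hit_block_num * block_size   # 128
    if total_hit_block_num * block_size == num_tokens:
        external_hit_tokens -= 1                                 # leave one token to recompute
    print(external_hit_tokens)                                    # 127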
@@ -449,6 +455,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 continue
 
             ucm_block_ids, vllm_block_ids = request.load_block_ids
+            if self.rank != 0 and not self.is_mla:
+                for i, ucm_block_id in enumerate(ucm_block_ids):
+                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
             ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
@@ -495,6 +504,9 @@ def wait_for_save(self) -> None:
                 continue
 
             ucm_block_ids, vllm_block_ids = request.dump_block_ids
+            if self.rank != 0:
+                for i, ucm_block_id in enumerate(ucm_block_ids):
+                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
             rets = self.store.create(ucm_block_ids)
             end = 0
             for i, ret in enumerate(rets):
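A sketch of the per-rank remapping added in start_load_kv and wait_for_save (illustration only; the ids and metadata string are hypothetical): workers with rank != 0 re-hash the scheduler-produced block ids through their own RequestHasher, so each tensor-parallel rank reads and writes its own block files.

    import hashlib

    def rehash(meta: str, block_id: str) -> str:
        # Same string path as RequestHasher.__call__: meta bytes + UTF-8 id, MD5, big-endian int.
        digest = hashlib.md5(meta.encode("utf-8") + block_id.encode("utf-8")).digest()
        return str(int.from_bytes(digest, byteorder="big"))

    scheduler_ids = ["1039848201", "5588127733"]   # produced with rank 0 metadata (placeholders)
    rank1_meta = "my-model:2:torch.bfloat16:1"

    rank1_ids = [rehash(rank1_meta, b) for b in scheduler_ids]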

ucm/store/nfsstore/nfsstore_connector.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def __init__(self, config: Dict):
         if transfer_enable:
             param.transferDeviceId = config["device"]
             param.transferIoSize = config["io_size"]
-            param.transferIoDirect = config.get("transferIoDirect", False)
+            param.transferIoDirect = config.get("use_direct", False)
 
         # NOTE: compatible with legacy nfsstore lib
         if hasattr(param, "storageCapacity"):
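The rename here is only at the config-key level; internally the NFS store still sets param.transferIoDirect. A hypothetical connector_config after this commit might look like the following (values are examples; "device" and "io_size" are filled in by UCMDirectConnector at init time).

    config = {
        "storage_backends": "/mnt/test",
        "use_direct": False,    # was "transferIoDirect" before this commit
        "device": 0,
        "io_size": 131072,
    }
    transfer_io_direct = config.get("use_direct", False)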
