Skip to content

Commit 8823274

Browse files
committed
fix comment
1 parent 72ac7a7 commit 8823274

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

examples/ucm_config_example.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ ucm_connectors:
1414
storage_backends: "/mnt/test"
1515
transferIoDirect: false
1616

17-
- ucm_connector_name: "UcmLocalStore"
18-
ucm_connector_config:
19-
cache_size: 12288
20-
2117
load_only_first_rank: false
2218

2319
# Sparse attention configuration

ucm/integration/vllm/ucm_connector.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,14 @@ class RequestHasher:
5858

5959
def __init__(self):
6060
if RequestHasher._SEED_HASH is None:
61-
RequestHasher._SEED_HASH = self.md5("UCM_HASH_SEED")
61+
RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
6262

63-
@staticmethod
64-
def md5(input_data) -> int:
63+
def __call__(self, input_data) -> int:
6564
input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
6665
md5_bytes = hashlib.md5(input_bytes).digest()
6766
return int.from_bytes(md5_bytes, byteorder="big")
6867

69-
def __call__(self, block_size: int, request: "Request") -> list[str]:
68+
def process(self, block_size: int, request: "Request") -> list[str]:
7069
token_ids = request.all_token_ids
7170

7271
ret = []
@@ -82,7 +81,7 @@ def __call__(self, block_size: int, request: "Request") -> list[str]:
8281
parent_block_hash_value = RequestHasher._SEED_HASH
8382

8483
block_token_ids_tuple = tuple(block_token_ids)
85-
hash_value = self.md5((parent_block_hash_value, block_token_ids_tuple))
84+
hash_value = self((parent_block_hash_value, block_token_ids_tuple))
8685
parent_block_hash_value = hash_value
8786
ret.append(str(hash_value))
8887

@@ -182,7 +181,7 @@ def get_num_new_matched_tokens(
182181
assert num_computed_tokens % self.block_size == 0
183182
hbm_hit_block_num = num_computed_tokens // self.block_size
184183

185-
ucm_block_ids = self.request_hasher(self.block_size, request)
184+
ucm_block_ids = self.request_hasher.process(self.block_size, request)
186185

187186
external_block_ids = ucm_block_ids[hbm_hit_block_num:]
188187
if not external_block_ids:
@@ -449,7 +448,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
449448
ucm_block_ids, vllm_block_ids = request.load_block_ids
450449
if self.rank != 0:
451450
for i, ucm_block_id in enumerate(ucm_block_ids):
452-
ucm_block_ids[i] = str(RequestHasher.md5((ucm_block_id, self.rank)))
451+
ucm_block_ids[i] = str(self.request_hasher((ucm_block_id, self.rank)))
453452
ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
454453
vllm_block_ids, ucm_block_ids
455454
)
@@ -498,7 +497,7 @@ def wait_for_save(self) -> None:
498497
ucm_block_ids, vllm_block_ids = request.dump_block_ids
499498
if self.rank != 0:
500499
for i, ucm_block_id in enumerate(ucm_block_ids):
501-
ucm_block_ids[i] = str(RequestHasher.md5((ucm_block_id, self.rank)))
500+
ucm_block_ids[i] = str(self.request_hasher((ucm_block_id, self.rank)))
502501
rets = self.store.create(ucm_block_ids)
503502
end = 0
504503
for i, ret in enumerate(rets):

0 commit comments

Comments
 (0)