Skip to content

Commit 3aed1ca

Browse files
committed
move process to UCMDirectConnector
1 parent 95113f8 commit 3aed1ca

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

ucm/integration/vllm/ucm_connector.py

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -65,28 +65,6 @@ def __call__(self, input_data) -> int:
6565
md5_bytes = hashlib.md5(input_bytes).digest()
6666
return int.from_bytes(md5_bytes, byteorder="big")
6767

68-
def process(self, block_size: int, request: "Request") -> list[str]:
69-
token_ids = request.all_token_ids
70-
71-
ret = []
72-
parent_block_hash_value = None
73-
for start in range(0, len(token_ids), block_size):
74-
end = start + block_size
75-
block_token_ids = token_ids[start:end]
76-
# Do not hash the block if it is not full.
77-
if len(block_token_ids) < block_size:
78-
break
79-
80-
if not parent_block_hash_value:
81-
parent_block_hash_value = RequestHasher._SEED_HASH
82-
83-
block_token_ids_tuple = tuple(block_token_ids)
84-
hash_value = self((parent_block_hash_value, block_token_ids_tuple))
85-
parent_block_hash_value = hash_value
86-
ret.append(str(hash_value))
87-
88-
return ret
89-
9068

9169
class UCMDirectConnector(KVConnectorBase_V1):
9270
"""
@@ -172,6 +150,30 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
172150
config["io_size"] / 1024,
173151
)
174152

153+
def generate_hash(self, block_size: int, request: "Request") -> list[str]:
154+
token_ids = request.all_token_ids
155+
156+
ret = []
157+
parent_block_hash_value = None
158+
for start in range(0, len(token_ids), block_size):
159+
end = start + block_size
160+
block_token_ids = token_ids[start:end]
161+
# Do not hash the block if it is not full.
162+
if len(block_token_ids) < block_size:
163+
break
164+
165+
if not parent_block_hash_value:
166+
parent_block_hash_value = RequestHasher._SEED_HASH
167+
168+
block_token_ids_tuple = tuple(block_token_ids)
169+
hash_value = self.request_hasher(
170+
(parent_block_hash_value, block_token_ids_tuple)
171+
)
172+
parent_block_hash_value = hash_value
173+
ret.append(str(hash_value))
174+
175+
return ret
176+
175177
def get_num_new_matched_tokens(
176178
self,
177179
request: "Request",
@@ -181,7 +183,7 @@ def get_num_new_matched_tokens(
181183
assert num_computed_tokens % self.block_size == 0
182184
hbm_hit_block_num = num_computed_tokens // self.block_size
183185

184-
ucm_block_ids = self.request_hasher.process(self.block_size, request)
186+
ucm_block_ids = self.generate_hash(self.block_size, request)
185187

186188
external_block_ids = ucm_block_ids[hbm_hit_block_num:]
187189
if not external_block_ids:

0 commit comments

Comments
 (0)