@@ -58,15 +58,14 @@ class RequestHasher:
5858
5959 def __init__ (self ):
6060 if RequestHasher ._SEED_HASH is None :
61- RequestHasher ._SEED_HASH = self . md5 ("UCM_HASH_SEED" )
61+ RequestHasher ._SEED_HASH = self ("UCM_HASH_SEED" )
6262
63- @staticmethod
64- def md5 (input_data ) -> int :
63+ def __call__ (self , input_data ) -> int :
6564 input_bytes = pickle .dumps (input_data , protocol = pickle .HIGHEST_PROTOCOL )
6665 md5_bytes = hashlib .md5 (input_bytes ).digest ()
6766 return int .from_bytes (md5_bytes , byteorder = "big" )
6867
69- def __call__ (self , block_size : int , request : "Request" ) -> list [str ]:
68+ def process (self , block_size : int , request : "Request" ) -> list [str ]:
7069 token_ids = request .all_token_ids
7170
7271 ret = []
@@ -82,7 +81,7 @@ def __call__(self, block_size: int, request: "Request") -> list[str]:
8281 parent_block_hash_value = RequestHasher ._SEED_HASH
8382
8483 block_token_ids_tuple = tuple (block_token_ids )
85- hash_value = self . md5 ((parent_block_hash_value , block_token_ids_tuple ))
84+ hash_value = self ((parent_block_hash_value , block_token_ids_tuple ))
8685 parent_block_hash_value = hash_value
8786 ret .append (str (hash_value ))
8887
@@ -182,7 +181,7 @@ def get_num_new_matched_tokens(
182181 assert num_computed_tokens % self .block_size == 0
183182 hbm_hit_block_num = num_computed_tokens // self .block_size
184183
185- ucm_block_ids = self .request_hasher (self .block_size , request )
184+ ucm_block_ids = self .request_hasher . process (self .block_size , request )
186185
187186 external_block_ids = ucm_block_ids [hbm_hit_block_num :]
188187 if not external_block_ids :
@@ -449,7 +448,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
449448 ucm_block_ids , vllm_block_ids = request .load_block_ids
450449 if self .rank != 0 :
451450 for i , ucm_block_id in enumerate (ucm_block_ids ):
452- ucm_block_ids [i ] = str (RequestHasher . md5 ((ucm_block_id , self .rank )))
451+ ucm_block_ids [i ] = str (self . request_hasher ((ucm_block_id , self .rank )))
453452 ucm_total_block_ids , ucm_offsets , dst_tensor_addr = self ._generate_task (
454453 vllm_block_ids , ucm_block_ids
455454 )
@@ -498,7 +497,7 @@ def wait_for_save(self) -> None:
498497 ucm_block_ids , vllm_block_ids = request .dump_block_ids
499498 if self .rank != 0 :
500499 for i , ucm_block_id in enumerate (ucm_block_ids ):
501- ucm_block_ids [i ] = str (RequestHasher . md5 ((ucm_block_id , self .rank )))
500+ ucm_block_ids [i ] = str (self . request_hasher ((ucm_block_id , self .rank )))
502501 rets = self .store .create (ucm_block_ids )
503502 end = 0
504503 for i , ret in enumerate (rets ):
0 commit comments