@@ -56,37 +56,21 @@ class RequestHasher:
 
     _SEED_HASH = None
 
-    def __init__(self):
-        if RequestHasher._SEED_HASH is None:
-            RequestHasher._SEED_HASH = self._md5("UCM_HASH_SEED")
-
-    @staticmethod
-    def _md5(input_data) -> int:
-        input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
-        md5_bytes = hashlib.md5(input_bytes).digest()
-        return int.from_bytes(md5_bytes, byteorder="big")
-
-    def __call__(self, block_size: int, request: "Request") -> list[str]:
-        token_ids = request.all_token_ids
-
-        ret = []
-        parent_block_hash_value = None
-        for start in range(0, len(token_ids), block_size):
-            end = start + block_size
-            block_token_ids = token_ids[start:end]
-            # Do not hash the block if it is not full.
-            if len(block_token_ids) < block_size:
-                break
+    def __init__(self, vllm_config, rank_id):
+        meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}"
+        self.meta_bytes = meta.encode("utf-8")
 
-            if not parent_block_hash_value:
-                parent_block_hash_value = RequestHasher._SEED_HASH
+        if RequestHasher._SEED_HASH is None:
+            RequestHasher._SEED_HASH = self("UCM_HASH_SEED")
 
-            block_token_ids_tuple = tuple(block_token_ids)
-            hash_value = self._md5((parent_block_hash_value, block_token_ids_tuple))
-            parent_block_hash_value = hash_value
-            ret.append(str(hash_value))
+    def __call__(self, input_data) -> int:
+        if isinstance(input_data, str):
+            input_bytes = input_data.encode("utf-8")
+        else:
+            input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
 
-        return ret
+        h = hashlib.md5(self.meta_bytes + input_bytes)
+        return int.from_bytes(h.digest(), byteorder="big")
 
 
 class UCMDirectConnector(KVConnectorBase_V1):
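
A minimal, self-contained sketch of the reworked hasher (the config objects below are simplified stand-ins; the real vllm_config comes from vLLM). It shows the key behavioral change: model name, world size, dtype, and rank are folded into every MD5 digest, so the same tokens hash to different block IDs on different ranks.

import hashlib
import pickle
from types import SimpleNamespace

class RequestHasher:
    _SEED_HASH = None

    def __init__(self, vllm_config, rank_id):
        # Per-process metadata becomes a prefix of every hashed payload.
        meta = (
            f"{vllm_config.model_config.model}:"
            f"{vllm_config.parallel_config.world_size}:"
            f"{vllm_config.model_config.dtype}:{rank_id}"
        )
        self.meta_bytes = meta.encode("utf-8")
        if RequestHasher._SEED_HASH is None:
            RequestHasher._SEED_HASH = self("UCM_HASH_SEED")

    def __call__(self, input_data) -> int:
        if isinstance(input_data, str):
            input_bytes = input_data.encode("utf-8")
        else:
            input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
        return int.from_bytes(hashlib.md5(self.meta_bytes + input_bytes).digest(), "big")

cfg = SimpleNamespace(
    model_config=SimpleNamespace(model="demo-model", dtype="bfloat16"),
    parallel_config=SimpleNamespace(world_size=2),
)
h0, h1 = RequestHasher(cfg, 0), RequestHasher(cfg, 1)
assert h0((123, (1, 2, 3))) != h1((123, (1, 2, 3)))  # digests are rank-dependent

Note that _SEED_HASH is a class attribute computed by whichever instance is constructed first, so it is seeded with that first instance's metadata.
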
@@ -114,15 +98,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             torch_dev = torch.npu
             dev_name = "npu"
         else:
-            raise RuntimeError("Unsupported device platform for LMCache engine.")
+            raise RuntimeError("Unsupported device platform for UCMDirectConnector.")
 
         if self.rank >= 0:
             self.device = torch_dev.device(f"{dev_name}:{self.rank}")
             self._layer_offset_cache = {}
 
         self.store: UcmKVStoreBase
 
-        self.request_hasher = RequestHasher()
+        if role == KVConnectorRole.SCHEDULER:
+            self.request_hasher = RequestHasher(vllm_config, 0)
+        else:
+            self.request_hasher = RequestHasher(vllm_config, self.rank)
 
         # save block info, avoid hash request twice, and track them until request finished
         self.requests_meta: dict[str, RequestMeta] = {}
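
The scheduler always hashes as rank 0, so the block IDs it records in request metadata are rank-0 names regardless of which worker later consumes them. Continuing the sketch above (names reused from it), a non-zero rank derives its own ID by re-hashing the rank-0 ID string, mirroring the start_load_kv and wait_for_save hunks further down:

scheduler_hasher = RequestHasher(cfg, 0)  # what the scheduler uses
worker_hasher = RequestHasher(cfg, 1)     # what a rank-1 worker uses

rank0_block_id = str(scheduler_hasher((RequestHasher._SEED_HASH, (1, 2, 3))))
rank1_block_id = str(worker_hasher(rank0_block_id))  # rank-local name
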
@@ -139,41 +126,60 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.broadcast_fn = self.group_coordinator.broadcast
         self.broadcast_stream = torch.cuda.Stream()
 
-        if "ucm_connector_name" in self.launch_config:
-            name = self.launch_config.get("ucm_connector_name")
-            config = self.launch_config.get("ucm_connector_config") or {}
-            config["device"] = self.rank
-            config["role"] = (
-                "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
-            )
-            element_size = vllm_config.model_config.dtype.itemsize
-            single_head_dim = vllm_config.model_config.get_head_size()
-            num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
-                vllm_config.parallel_config
-            )
-            total_tp_size = vllm_config.parallel_config.tensor_parallel_size
-            num_layers = vllm_config.model_config.get_num_layers(
-                vllm_config.parallel_config
-            )
-            block_size_per_layer = self.block_size * element_size * single_head_dim
-            config["kv_block_size"] = (
-                block_size_per_layer
-                * num_layers
-                * (1 if self.is_mla else num_head_per_tp * total_tp_size * 2)
-            )
-            config["io_size"] = block_size_per_layer * (
-                1 if self.is_mla else num_head_per_tp
-            )
-            self.store = UcmConnectorFactory.create_connector(name, config)
+        connector_configs = self.launch_config.get("ucm_connectors", [])
+        assert len(connector_configs) > 0, "no storage connector name in config."
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.rank
+        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
+        element_size = vllm_config.model_config.dtype.itemsize
+        single_head_dim = vllm_config.model_config.get_head_size()
+        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
+        num_layers = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+        block_size_per_layer = self.block_size * element_size * single_head_dim
+        config["kv_block_size"] = (
+            block_size_per_layer
+            * num_layers
+            * (1 if self.is_mla else num_head_per_tp * 2)
+        )
+        config["io_size"] = block_size_per_layer * (
+            1 if self.is_mla else num_head_per_tp
+        )
+        self.store = UcmConnectorFactory.create_connector(name, config)
+
+        logger.info("init UCConnectorImpl, connector: %s", name)
+        logger.info(
+            "single file size = %d MB, io_size = %d KB,",
+            config["kv_block_size"] / 1024 / 1024,
+            config["io_size"] / 1024,
+        )
+
+    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+        token_ids = request.all_token_ids
+
+        ret = []
+        parent_block_hash_value = RequestHasher._SEED_HASH
+        for start in range(0, len(token_ids), block_size):
+            end = start + block_size
+            block_token_ids = token_ids[start:end]
+            # Do not hash the block if it is not full.
+            if len(block_token_ids) < block_size:
+                break
 
-            logger.info("init UCConnectorImpl, connector: %s", name)
-            logger.info(
-                "single file size = %d MB, io_size = %d KB,",
-                config["kv_block_size"] / 1024 / 1024,
-                config["io_size"] / 1024,
+            block_token_ids_tuple = tuple(block_token_ids)
+            hash_value = self.request_hasher(
+                (parent_block_hash_value, block_token_ids_tuple)
             )
-        else:
-            raise TypeError(f"no storage connector name in config.")
+            parent_block_hash_value = hash_value
+            ret.append(str(hash_value))
+
+        return ret
 
     def get_num_new_matched_tokens(
         self,
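
For a feel for the sizing math (numbers assumed, not from this change): with block_size = 128, bf16 (element_size = 2), head_size = 128, 8 KV heads per TP rank, and 32 layers, block_size_per_layer = 128 * 2 * 128 = 32 KB, io_size = 32 KB * 8 = 256 KB, and kv_block_size = 32 KB * 32 * (8 * 2) = 16 MB. The old formula also multiplied by total_tp_size because rank 0 wrote every rank's shard; now each rank stores only its own. The relocated generate_hash chains digests so each block ID commits to the entire token prefix; a rough sketch reusing h0 from the sketch above:

block_size = 4
token_ids = [11, 12, 13, 14, 21, 22, 23, 24, 31]  # trailing partial block

ret = []
parent = RequestHasher._SEED_HASH
for start in range(0, len(token_ids), block_size):
    block = token_ids[start:start + block_size]
    if len(block) < block_size:
        break  # partial blocks are never hashed
    parent = h0((parent, tuple(block)))  # digest covers the whole prefix
    ret.append(str(parent))

assert len(ret) == 2  # two full blocks; shared prompt prefixes yield shared IDs
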
@@ -184,7 +190,7 @@ def get_num_new_matched_tokens(
         assert num_computed_tokens % self.block_size == 0
         hbm_hit_block_num = num_computed_tokens // self.block_size
 
-        ucm_block_ids = self.request_hasher(self.block_size, request)
+        ucm_block_ids = self.generate_hash(self.block_size, request)
 
         external_block_ids = ucm_block_ids[hbm_hit_block_num:]
         if not external_block_ids:
@@ -210,7 +216,7 @@ def get_num_new_matched_tokens(
         # When all the tokens are cached in ssd or hbm,
         # we need to recompute the last token. This if condition will be removed
         # once vLLM scheduler provides a better solution in the future.
-        if external_hit_tokens == request.num_prompt_tokens:
+        if total_hit_block_num * self.block_size == request.num_tokens:
             external_hit_tokens -= 1
 
         self.requests_meta[request.request_id] = RequestMeta(
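
The rewritten full-hit guard compares whole cached blocks against request.num_tokens instead of num_prompt_tokens. Toy numbers (assumed, including the total_hit_block_num definition) show why one token is handed back to the scheduler:

block_size, num_tokens = 128, 256
hbm_hit_block_num, external_hit_block_num = 1, 1
total_hit_block_num = hbm_hit_block_num + external_hit_block_num
external_hit_tokens = external_hit_block_num * block_size
if total_hit_block_num * block_size == num_tokens:
    # Everything is cached, so vLLM would have nothing left to schedule;
    # report one token fewer to force recomputation of the last token.
    external_hit_tokens -= 1
assert external_hit_tokens == 127
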
@@ -449,6 +455,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 continue
 
             ucm_block_ids, vllm_block_ids = request.load_block_ids
+            if self.rank != 0 and not self.is_mla:
+                for i, ucm_block_id in enumerate(ucm_block_ids):
+                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
             ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
                 vllm_block_ids, ucm_block_ids
             )
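
On the load path, a non-zero rank translates the scheduler's rank-0 block IDs into its own namespace before reading; MLA models skip this, presumably because their KV cache is identical across TP ranks, so every rank can read the rank-0 copy. The save path in the next hunk applies the same translation for every non-zero rank. Continuing the sketch from above:

rank, is_mla = 1, False
ucm_block_ids = [rank0_block_id]  # as recorded by the scheduler
if rank != 0 and not is_mla:
    ucm_block_ids = [str(worker_hasher(b)) for b in ucm_block_ids]
assert ucm_block_ids == [rank1_block_id]  # the name this rank saved under
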
@@ -495,6 +504,9 @@ def wait_for_save(self) -> None:
                 continue
 
             ucm_block_ids, vllm_block_ids = request.dump_block_ids
+            if self.rank != 0:
+                for i, ucm_block_id in enumerate(ucm_block_ids):
+                    ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
             rets = self.store.create(ucm_block_ids)
             end = 0
             for i, ret in enumerate(rets):