 import torch
+import torch.distributed as dist
 from .radix_cache import RadixCache, TreeNode, match
 from typing import Tuple, Dict, Set, List
 from lightllm.common.mem_manager import MemoryManager
@@ -23,12 +24,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
             self.is_hi_radix_cache = True
             all_buffers = self.mem_manager.kv_buffer
             all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
-            self.py_cache_service = PyLocalCacheService(
-                file="cache/cache_file",
-                storage_size=128 * (1024 ** 3),
-                num_shard=32,
-                kvcache_tensor=all_buffers,
-                num_worker=32,
+            self.py_cache_service = (
+                PyLocalCacheService(
+                    file="cache/cache_file",
+                    storage_size=128 * (1024 ** 3),
+                    num_shard=32,
+                    kvcache_tensor=all_buffers,
+                    num_worker=32,
+                )
+                if self.do_store
+                else None
             )
             self.working_tasks = {}
         except Exception as e:
@@ -48,7 +53,7 @@ def insert_disk(self, req_id, key, value):
         logger.info(f"Created store task for req {req_id}.")

     def abort_req_store_task(self, req_id):
-        if not self.do_store:
+        if not self.do_store or req_id not in self.working_tasks:
             return
         if self.working_tasks[req_id].ready():
             logger.info(f"Calling abort for req {req_id}, but is finished.")
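A minimal sketch of the same guard written with `dict.get`, assuming `working_tasks` maps request ids to cache-service tasks and that `ready()` (the only task method visible in this hunk) reports completion; the effect matches the added membership check, avoiding a `KeyError` for requests that never started a store task:

```python
def abort_req_store_task(self, req_id):
    if not self.do_store:
        return
    task = self.working_tasks.get(req_id)  # None if no store task was created for this request
    if task is None:
        return
    if task.ready():
        logger.info(f"Calling abort for req {req_id}, but is finished.")
        return
    # ...the existing handling for an in-flight store task continues here.
```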
@@ -126,48 +131,63 @@ def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, upd
             self.evict_tree_set.add(node)

     def match_prefix(self, key, update_refs=False):
-        st_time = time.time()
         assert len(key) != 0
         ans_value_list = []
-        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
-        # add a parameter if get long enough (>50%)
-        first_query_time = time.time()
-        logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}")
-        max_len = self._query_hi_cache(key)  # x64
-        hi_cache_query_time = time.time()
-        logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}")
-        logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
+        pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
+        if self.do_store:
+            # st_time = time.time()
+            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+            # add a parameter if get long enough (>50%)
+            # first_query_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s")
+            max_len = self._query_hi_cache(key)  # x64
+            # hi_cache_q_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s")
+            logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
+            pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
+            # hi_cache_q_time = time.time()
+        dist.broadcast(pull_hi_cache_tensor, src=0)
+        # logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}")
         pull_hi_cache = False
-        if max_len > sum(len(s) for s in ans_value_list):
+        # logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}")
+
+        if pull_hi_cache_tensor[0] == 0 and not self.do_store:
+            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+        elif pull_hi_cache_tensor[0] > 0:
             pull_hi_cache = True
+            max_len = pull_hi_cache_tensor[0]
             try:
                 self.free_radix_cache_to_get_enough_token(max_len)
             except:
-                if update_refs:
-                    tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+                logger.info(f"Unable to free on rank {self.rank_in_node}")
+                pull_hi_cache_tensor[0] = 0
                 pull_hi_cache = False
+                ans_value_list = []
+                tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
         if pull_hi_cache:
             buffers = self.mem_manager.alloc(max_len)
-            before_pull_time = time.time()
-            logger.info(
-                f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_query_time}"
-            )
-            read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
-            while not read_task.ready():
-                time.sleep(0.1)
-            hicache_pull_time = time.time()
-            logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - before_pull_time}")
+            # before_pull_time = time.time()
+            # logger.info(
+            #     f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}"
+            # )
+            if self.do_store:
+                read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
+                while not read_task.ready():
+                    time.sleep(0.05)
+            dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
+            # hicache_pull_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s")
             logger.info(f"HiCache pulled one cache with len = {max_len}")
             # maybe try: add a function to only insert middle part of kv cache
             self._insert_helper(self.root_node, key, buffers)
-            insert_time = time.time()
-            logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
+            # insert_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
             ans_value_list = []
             tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-            logger.info(
-                f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
-                + f" matched {sum(len(s) for s in ans_value_list)} tokens"
-            )
+            # logger.info(
+            #     f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
+            #     + f" matched {sum(len(s) for s in ans_value_list)} tokens"
+            # )
         if tree_node != self.root_node:
             if len(ans_value_list) != 0:
                 value = torch.concat(ans_value_list)
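The match_prefix rework above follows a rank-0-decides pattern: only the storing rank queries the disk cache, every rank joins a dist.broadcast of the decision (how many tokens to pull), and the pulled KV buffer is then broadcast so all tensor-parallel ranks stay consistent. Below is a minimal standalone sketch of that control flow, assuming an already-initialized NCCL process group (one rank per GPU) and using hypothetical query_disk_len / pull_from_disk helpers in place of PyLocalCacheService and the memory manager:

```python
import torch
import torch.distributed as dist


def synced_disk_match(rank: int, gpu_match_len: int, query_disk_len, pull_from_disk):
    """Rank 0 decides whether the disk cache beats the local GPU prefix match,
    then every rank follows the broadcast decision so collectives stay aligned.

    query_disk_len() and pull_from_disk(n) are hypothetical stand-ins for the
    PyLocalCacheService query/read calls in the diff.
    """
    device = torch.device("cuda", rank)
    decision = torch.zeros(1, dtype=torch.int64, device=device)

    if rank == 0:
        disk_len = query_disk_len()  # disk lookup happens on the storing rank only
        if disk_len > gpu_match_len:
            decision[0] = disk_len   # non-zero => pull this many tokens from disk
    dist.broadcast(decision, src=0)  # every rank learns the same decision

    pull_len = int(decision.item())
    if pull_len == 0:
        return None                  # fall back to the GPU radix-tree match

    kv = torch.empty(pull_len, dtype=torch.float16, device=device)
    if rank == 0:
        kv.copy_(pull_from_disk(pull_len))  # only rank 0 reads from disk
    dist.broadcast(kv, src=0)               # replicate the pulled KV to all ranks
    return kv
```

Broadcasting the decision first keeps ranks from diverging on whether to allocate buffers; broadcasting the data afterward means only one rank touches the disk cache, mirroring the do_store guards in the diff.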