@@ -307,7 +307,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext):
307307 self .init_static_flag = True
308308
309309 def wait_transfer_task_done (self ):
310- assert len (self .tasks ) > 0
310+ # assert len(self.tasks) > 0
311311 for task_hash , task in self .tasks .items ():
312312 # TODO: handle exceptions
313313 ret = self .store_instance .wait (task )
@@ -352,9 +352,10 @@ def wait_retrieval_and_start_load(self):
352352 self .pre_topk_block_hashes , diff_blocks = diff_two_map (
353353 self .pre_topk_block_hashes , target_map
354354 )
355- self .launch_transfer_task (
356- "load" , list (diff_blocks .values ()), list (diff_blocks .keys ())
357- )
355+ if diff_blocks :
356+ self .launch_transfer_task (
357+ "load" , list (diff_blocks .values ()), list (diff_blocks .keys ())
358+ )
358359
359360 ## 2. load all
360361 # self.launch_transfer_task(
@@ -438,7 +439,8 @@ def attention_begin(
438439 self .k_cache [vllm_block_ids [- local_window_sz :]] = self .local_window
439440 self .start_retrieval (query , forward_context )
440441 self .wait_retrieval_and_start_load ()
441- self .wait_transfer_task_done ()
442+ if len (self .tasks ) > 0 :
443+ self .wait_transfer_task_done ()
442444
443445 def attention_finished (
444446 self ,
0 commit comments