 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -330,14 +328,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )

         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
@@ -350,12 +340,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -371,12 +355,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -458,25 +436,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -502,20 +464,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

             if self.graph.need_capture(find_graph_batch_size):
@@ -540,20 +490,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)

             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -654,10 +592,12 @@ def _check_max_len_infer(self):
             logger.info("begin check max_len infer")
             dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.mem_manager.alloc(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            ).cuda()
             total_token_num = self.batch_max_tokens
             model_input = ModelInput(
                 batch_size=1,
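Taken together, the deleted init_req_to_token_indexes / copy_kv_index_to_req call sites and the extra arguments now passed to self.mem_manager.alloc(...) suggest that the req-to-token index bookkeeping moves inside MemoryManager.alloc. The sketch below illustrates that idea only; it is not lightllm's actual implementation. SketchMemoryManager and the internals of its alloc method are hypothetical stand-ins, assuming alloc(need_size, b_req_idx, b_seq_len, b_ready_cache_len, is_prefill) both hands out KV-cache slots and records them in req_to_token_indexs.

import torch


class SketchMemoryManager:
    """Hypothetical sketch: an alloc() that also fills the req -> token-slot table."""

    def __init__(self, total_tokens: int, max_reqs: int, max_seq_len: int):
        self.free_slots = list(range(total_tokens))
        # req_to_token_indexs[req_idx, pos] -> slot of that token in the KV-cache pool
        self.req_to_token_indexs = torch.zeros((max_reqs, max_seq_len), dtype=torch.int32)

    def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False):
        # Hand out `need_size` free KV-cache slots.
        mem_index = torch.tensor([self.free_slots.pop() for _ in range(need_size)], dtype=torch.int32)
        if b_req_idx is None:
            return mem_index
        start = 0
        for i in range(len(b_req_idx)):
            req = int(b_req_idx[i])
            seq_len = int(b_seq_len[i])
            if is_prefill:
                # Prefill: the new slots cover every position after the already-cached prefix.
                ready = int(b_ready_cache_len[i]) if b_ready_cache_len is not None else 0
                n_new = seq_len - ready
                self.req_to_token_indexs[req, ready:seq_len] = mem_index[start : start + n_new]
                start += n_new
            else:
                # Decode: exactly one new token per request, stored at position seq_len - 1.
                self.req_to_token_indexs[req, seq_len - 1] = mem_index[start]
                start += 1
        return mem_index


# Usage mirroring the new call in _check_max_len_infer (toy CPU sizes):
mm = SketchMemoryManager(total_tokens=64, max_reqs=4, max_seq_len=32)
b_req_idx = torch.tensor([0], dtype=torch.int32)
b_seq_len = torch.tensor([8], dtype=torch.int32)
b_ready_cache_len = torch.zeros(1, dtype=torch.int32)
mem_indexes = mm.alloc(8, b_req_idx, b_seq_len, b_ready_cache_len, True)

Centralizing this bookkeeping in alloc would explain why the separate index-copy calls before init_some_extra_state are no longer needed in the prefill, decode, and microbatch-overlap paths above.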