@@ -71,10 +71,21 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
     def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
         return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))

-    def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor:
+    def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None):
+        b_token_len = b_seq_len
+        if b_ready_cache_len is not None:
+            b_token_len = b_seq_len - b_ready_cache_len
+        b_token_len_cumsum = torch.cumsum(b_token_len, dim=0)
+        b_last_mem_index = mem_indices[b_token_len_cumsum - 1]
+        return b_last_mem_index
+
+    # b_last_mem_index is only needed when b_ready_cache_len is None
+    def alloc_mem_indices(
+        self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None
+    ) -> torch.Tensor:
         page_size = get_page_size()
-        if page_size > 1 and b_req_idx is not None and b_seq_len is not None:
-            return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
+        if page_size > 1 and b_seq_len is not None:
+            return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index)
         else:
             return self.mem_manager.alloc(need_size)

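Note: alloc_mem_indices no longer receives b_req_idx; the decode path now expects the caller to carry each request's last allocated slot forward via b_last_mem_index, which calc_last_mem_index_in_prefill derives from the prefill allocation. A rough sketch of that wiring, assuming a token_mgr instance of this class and made-up batch tensors (none of the values below come from this commit):

    import torch

    # prefill: allocate slots for the uncached tokens of each request
    b_seq_len = torch.tensor([5, 3])            # total tokens per request
    b_ready_cache_len = torch.tensor([2, 0])    # tokens already in the cache
    need = int((b_seq_len - b_ready_cache_len).sum())
    mem_indices = token_mgr.alloc_mem_indices(
        need, b_seq_len=b_seq_len, b_ready_cache_len=b_ready_cache_len
    )
    # remember the last slot written per request so decode can continue from it
    b_last_mem_index = token_mgr.calc_last_mem_index_in_prefill(mem_indices, b_seq_len, b_ready_cache_len)

    # decode: one new token per request; b_ready_cache_len is None, so b_last_mem_index is required
    b_seq_len = b_seq_len + 1
    new_idxs = token_mgr.alloc_mem_indices(
        b_seq_len.numel(), b_seq_len=b_seq_len, b_last_mem_index=b_last_mem_index
    )
    b_last_mem_index = new_idxs  # carry forward for the next decode step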
@@ -114,12 +125,11 @@ def _expand_by_page_size(self, b_token_len, page_size):
         p_token_len[last_page_positions] = remainders
         return need_pages_num, p_token_len

-    def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
+    def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index):
+        b_seq_len = b_seq_len.cpu()
         if b_ready_cache_len is not None:
             # prefill
-            b_seq_len = b_seq_len.cpu()
             b_ready_cache_len = b_ready_cache_len.cpu()
-
             b_token_len = b_seq_len - b_ready_cache_len
             total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
             paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
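The prefill branch rounds each request's new-token count up to whole pages, grabs all of them in a single alloc call, and then keeps only the leading slots of each page. A self-contained sketch of that masking idea with page_size = 4; the tensors and the explicit p_token_len below are illustrative, since the body of _expand_by_page_size is not part of this hunk:

    import torch

    page_size = 4
    b_token_len = torch.tensor([5, 3])                        # uncached tokens per request
    need_pages = (b_token_len + page_size - 1) // page_size   # [2, 1] pages
    idxs = torch.arange(int(need_pages.sum()) * page_size)    # stand-in for mem_manager.alloc(...)
    pages = idxs.view(-1, page_size)                          # one row per page
    p_token_len = torch.tensor([4, 1, 3])                     # per-page token counts; last page holds the remainder
    mask = torch.arange(page_size).unsqueeze(0) < p_token_len.unsqueeze(1)
    token_idxs = pages[mask]                                  # 8 slots kept: [0, 1, 2, 3, 4, 8, 9, 10]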
@@ -128,19 +138,17 @@ def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cach
             return pages[mask]
         else:
             # decode
-            b_seq_len = b_seq_len.cuda()
-            b_req_idx = b_req_idx.cuda()
+            assert b_last_mem_index is not None
+            b_last_mem_index = b_last_mem_index.cpu()
             need_new_page_mask = (b_seq_len - 1) % page_size == 0
-            new_pages_num = need_new_page_mask.sum().cpu()
+            new_pages_num = need_new_page_mask.sum()
             token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
             if new_pages_num > 0:
-                new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size).cuda()
+                new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size)
                 token_idxs[need_new_page_mask] = new_pages_tokens[::page_size]
-
             mask = ~need_new_page_mask
             if mask.any():
-                seq_lens = b_seq_len[mask]
-                token_idxs[mask] = self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] + 1
+                token_idxs[mask] = b_last_mem_index[mask] + 1
             return token_idxs

     def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None):
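A quick check of the decode-branch arithmetic with page_size = 4 (illustrative numbers only): a request whose b_seq_len is now 5 has (5 - 1) % 4 == 0, so its new token takes the first slot of a freshly allocated page; a request at b_seq_len 6 has (6 - 1) % 4 == 1, so it keeps writing into its current page at b_last_mem_index + 1. The removed code recovered that previous slot from req_to_token_indexs[b_req_idx, seq_len - 2]; the new code reads it straight from the b_last_mem_index tensor supplied by the caller, so neither the lookup table nor b_req_idx is needed here.

    import torch

    page_size = 4
    b_seq_len = torch.tensor([5, 6])
    need_new_page_mask = (b_seq_len - 1) % page_size == 0   # tensor([True, False])
    # request 0 starts a new page; request 1 continues at b_last_mem_index[1] + 1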