 import torch
 import collections
+import triton
 from lightllm.utils.log_utils import init_logger
 from .mem_manager import MemoryManager
 from typing import List, Optional

 logger = init_logger(__name__)


-def cdiv(a, b):
-    return (a + b - 1) // b
-
-
 class _ReqNode:
     def __init__(self, index):
         self.index = index
@@ -71,25 +68,60 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

+    def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
+        return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))
+
+    def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor:
+        page_size = get_page_size()
+        if page_size > 1 and b_req_idx is not None and b_seq_len is not None:
+            return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
+        else:
+            return self.mem_manager.alloc(need_size)
+
+    def alloc(self):
+        return self.req_list.alloc()
+
+    def free(self, free_req_indexes: List[int], free_token_index):
+        for req_index in free_req_indexes:
+            self.req_list.free(req_index)
+
+        if self.req_list.is_all_free():
+            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
+        self.mem_manager.free(free_token_index)
+
+    def free_req(self, free_req_index: int):
+        self.req_list.free(free_req_index)
+        if self.req_list.is_all_free():
+            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
+        return
+
+    def free_token(self, free_token_index):
+        self.mem_manager.free(free_token_index)
+        return
+
+    def free_all(self):
+        self.req_list = _ReqLinkedList(self.max_request_num)
+        return
+
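Note on the new entry point: `alloc_mem_indices` takes the paged path only when `page_size > 1` and both `b_req_idx` and `b_seq_len` are supplied, otherwise it degrades to a flat `mem_manager.alloc`. A minimal standalone sketch of that dispatch, assuming a hypothetical `FlatAllocator` in place of `MemoryManager`, an explicit `page_size` argument in place of `get_page_size()`, and a simple round-up-to-pages in place of the full `_alloc_paged_mem_indices` bookkeeping:

```python
import torch

class FlatAllocator:
    """Toy stand-in for MemoryManager: hands out consecutive slot indices."""
    def __init__(self):
        self._next = 0

    def alloc(self, need_size):
        out = torch.arange(self._next, self._next + need_size)
        self._next += need_size
        return out

def alloc_mem_indices(mem, page_size, need_size, b_req_idx=None, b_seq_len=None):
    # Same dispatch shape as the new method: the paged path is taken only
    # when page_size > 1 AND the batch metadata is supplied, so callers
    # that pass nothing extra still get a flat allocation.
    if page_size > 1 and b_req_idx is not None and b_seq_len is not None:
        # Stand-in for _alloc_paged_mem_indices: just round up to whole pages.
        need_size = ((need_size + page_size - 1) // page_size) * page_size
    return mem.alloc(need_size)

mem = FlatAllocator()
print(alloc_mem_indices(mem, page_size=1, need_size=3))  # tensor([0, 1, 2])
print(alloc_mem_indices(mem, page_size=4, need_size=3,
                        b_req_idx=torch.tensor([0]), b_seq_len=torch.tensor([3])))
# tensor([3, 4, 5, 6]) -- a whole page, one slot of padding
```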
     def _expand_by_page_size(self, b_token_len, page_size):
-        # Expand seq_len into whole pages, e.g. seq_len = [9,9,9] with page_size = 4 -> page_len = [4,4,1,4,4,1,4,4,1]
-        b_page_len = cdiv(b_token_len, page_size)
+        # Expand seq_len into whole pages, e.g. seq_len = [9,9,9] with page_size = 4 -> p_token_len = [4,4,1,4,4,1,4,4,1]
+        b_page_len = triton.cdiv(b_token_len, page_size)
         need_pages_num = b_page_len.sum()
         p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
         cumsum_pages = torch.cumsum(b_page_len, dim=0)
         last_page_positions = cumsum_pages - 1
         remainders = b_token_len - (b_page_len - 1) * page_size
         p_token_len[last_page_positions] = remainders
-        return need_pages_num, b_page_len, p_token_len
+        return need_pages_num, p_token_len
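The expansion is easiest to see with the comment's own numbers. A self-contained re-run of the math, assuming plain integer ceil-division in place of `triton.cdiv` (they agree for positive integer tensors):

```python
import torch

def expand_by_page_size(b_token_len, page_size):
    # One entry per page; every page is full except possibly the last
    # page of each request, which holds the remainder.
    b_page_len = (b_token_len + page_size - 1) // page_size  # == triton.cdiv
    need_pages_num = int(b_page_len.sum())
    p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype)
    last_page_positions = torch.cumsum(b_page_len, dim=0) - 1
    p_token_len[last_page_positions] = b_token_len - (b_page_len - 1) * page_size
    return need_pages_num, p_token_len

print(expand_by_page_size(torch.tensor([9, 9, 9]), page_size=4))
# -> (9, tensor([4, 4, 1, 4, 4, 1, 4, 4, 1]))
```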
-    def _alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
+    def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
         if b_ready_cache_len is not None:
             # prefill
             b_seq_len = b_seq_len.cpu()
             b_ready_cache_len = b_ready_cache_len.cpu()

             b_token_len = b_seq_len - b_ready_cache_len
-            total_pages_needed, b_page_len, p_token_len = self._expand_by_page_size(b_token_len, page_size)
+            total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
             paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
             pages = paged_token_idxs.view(-1, page_size)
             mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)
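The hunk is cut off after the `mask` line, but that line is the heart of the prefill path: after allocating whole pages, only the first `p_token_len[i]` slots of page `i` are real tokens. A toy illustration of the masking, with made-up tensors standing in for what `mem_manager.alloc` would return:

```python
import torch

page_size = 4
p_token_len = torch.tensor([4, 4, 1])            # from the expansion above
paged_token_idxs = torch.arange(3 * page_size)   # pretend mem_manager.alloc(12) gave these
pages = paged_token_idxs.view(-1, page_size)     # one row per page
mask = torch.arange(page_size) < p_token_len.unsqueeze(1)
print(pages[mask])
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) -- padding slots 9-11 are dropped
```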
@@ -126,41 +158,6 @@ def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): |
         need_new_pages = mask.sum()
         return need_new_pages * page_size
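`calc_real_need_token_num` (moved up in this commit) exists because with `page_size > 1` the allocator hands out whole pages, so the raw token count can undershoot what `mem_manager.alloc` must actually reserve. A back-of-the-envelope check, assuming three prefill requests with 9 new tokens each, `page_size = 4`, and no tokens already cached (the real `_get_need_paged_token_num` also discounts pages already resident per request via the `mask.sum()` above):

```python
page_size = 4
b_token_len = [9, 9, 9]                                  # new tokens per request
pages_per_req = [(t + page_size - 1) // page_size for t in b_token_len]
paged_need = sum(pages_per_req) * page_size              # 9 pages * 4 slots = 36
raw_need = sum(b_token_len)                              # 27 tokens actually written
print(max(raw_need, paged_need))                         # 36 -> the real budget to allocate
```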
-    def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
-        return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))
-
-    def alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor:
-        page_size = get_page_size()
-        if page_size > 1:
-            return self._alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
-        else:
-            return self.mem_manager.alloc(need_size)
-
-    def alloc(self):
-        return self.req_list.alloc()
-
-    def free(self, free_req_indexes: List[int], free_token_index):
-        for req_index in free_req_indexes:
-            self.req_list.free(req_index)
-
-        if self.req_list.is_all_free():
-            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
-        self.mem_manager.free(free_token_index)
-
-    def free_req(self, free_req_index: int):
-        self.req_list.free(free_req_index)
-        if self.req_list.is_all_free():
-            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
-        return
-
-    def free_token(self, free_token_index):
-        self.mem_manager.free(free_token_index)
-        return
-
-    def free_all(self):
-        self.req_list = _ReqLinkedList(self.max_request_num)
-        return
-
 class ReqSamplingParamsManager:
     """