6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
@@ -214,6 +214,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
)
parser.add_argument(
"--past_future_scheduler",
action="store_true",
help="""use past_future_scheduler for adaptive request new token len prediction,
override --router_token_ratio and --router_max_new_token_len (still used during warmup)""",
)

parser.add_argument(
"--router_max_wait_tokens",
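For reference, a minimal usage sketch (not part of the PR): the new flag is a plain store_true switch, so enabling the scheduler only means adding it to the launch arguments. The launch command and the --model_dir value below are illustrative, and the snippet assumes the parser's remaining arguments keep their defaults.

# e.g. python -m lightllm.server.api_server --model_dir /path/to/model --past_future_scheduler
from lightllm.server.api_cli import make_argument_parser

args = make_argument_parser().parse_args(["--model_dir", "/path/to/model", "--past_future_scheduler"])
assert args.past_future_scheduler is True
assert args.router_max_new_token_len == 1024  # default above; still used as the warmup estimate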
4 changes: 2 additions & 2 deletions lightllm/server/core/objs/req.py
@@ -266,7 +266,7 @@ def get_all_prompt_metadata(self):
class ChunkedPrefillReq(Req):
_pack_ = 4

def get_tuple_tokens(self, is_busy, router_max_new_token_len):
def get_tuple_tokens(self, is_busy, router_max_new_token_len, has_out_len_factor=1.1):
args = get_env_start_args()
# During chunked prefill inference there are many kinds of delayed-step control, used to
# ensure better inter-packet timing or to improve prefill efficiency in dp mode, but when estimating token memory
@@ -283,7 +283,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len):
cur_max_new_token_len = self.sample_params.max_new_tokens
else:
cur_max_new_token_len = min(
self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
self.sample_params.max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len)
)

a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
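To make the changed estimate concrete, here is a standalone sketch (not from the PR) of the capped new-token prediction in get_tuple_tokens, with made-up numbers. With the default factor of 1.1 the prediction grows with the tokens a request has already produced; PastFutureQueue passes has_out_len_factor=1.0 because it supplies its own sampled estimate instead.

def predict_max_new_tokens(max_new_tokens, has_out_len, router_max_new_token_len,
                           has_out_len_factor=1.1):
    # Mirrors the non-busy branch above: never below the router's static estimate,
    # never above the request's own max_new_tokens.
    return min(max_new_tokens,
               max(int(has_out_len_factor * has_out_len), router_max_new_token_len))

print(predict_max_new_tokens(4096, 2000, 1024))                          # 2200
print(predict_max_new_tokens(4096, 2000, 1024, has_out_len_factor=1.0))  # 2000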
6 changes: 3 additions & 3 deletions lightllm/server/httpserver/manager.py
@@ -649,9 +649,9 @@ async def recycle_resource_loop(self):
continue

logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
f"left req id: {req_status.group_req_objs.group_req_id}, "
f"can release: {req_status.group_req_objs.shm_req_objs[0].can_released_mark}, "
f"refcount: {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
return

3 changes: 3 additions & 0 deletions lightllm/server/router/manager.py
@@ -22,6 +22,7 @@
from .shm_reqs_io_buffer import ShmReqsIOBuffer
from lightllm.utils.log_utils import init_logger, log_time_ready
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.router.req_queue.chunked_prefill.impl_past_future import PastFutureQueue
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager
@@ -319,6 +320,8 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):

def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
if isinstance(self.req_queue, PastFutureQueue):
self.req_queue.record_finished_len_from_batch(self.running_batch)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
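For context, a toy sketch (not the real classes) of what the new hook does: before finished requests are filtered out of the running batch, their final output lengths are harvested into the scheduler's sliding history window. FakeReq below is a hypothetical stand-in for lightllm's Req.

from collections import deque

history_output_len = deque(maxlen=200)  # mirrors PastFutureQueue.WINDOW_SIZE

class FakeReq:
    def __init__(self, released, out_len):
        self.shm_infer_released = released
        self.shm_cur_output_len = out_len

for req in [FakeReq(True, 310), FakeReq(False, 42), FakeReq(True, 1280)]:
    if req.shm_infer_released:  # same condition as record_finished_len_from_batch
        history_output_len.append(req.shm_cur_output_len)

print(list(history_output_len))  # [310, 1280]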
20 changes: 14 additions & 6 deletions lightllm/server/router/req_queue/__init__.py
@@ -1,29 +1,37 @@
from .chunked_prefill.impl_for_pd_decode import QueueForPDDecode
from .chunked_prefill.impl import ChunkedPrefillQueue
from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue
from .chunked_prefill.impl_past_future import PastFutureQueue
from .dp_base_queue import DpQueue


def _get_req_queue_class(args, router, dp_size_in_node: int):
if args.past_future_scheduler:
if args.diverse_mode:
raise ValueError("Diverse mode is not supported with past future scheduler yet")
chunked_prefill_queue_impl = PastFutureQueue
else:
chunked_prefill_queue_impl = ChunkedPrefillQueue

if args.diverse_mode:
return ChunkedBeamContinuesBatchQueue
if args.token_healing_mode:
return ChunkedPrefillQueue
return chunked_prefill_queue_impl
if args.output_constraint_mode != "none":
return ChunkedPrefillQueue
return chunked_prefill_queue_impl
if args.first_token_constraint_mode:
return ChunkedPrefillQueue
return chunked_prefill_queue_impl
if args.run_mode == "decode":
return QueueForPDDecode
if args.run_mode == "prefill":
return ChunkedPrefillQueue
return chunked_prefill_queue_impl

if args.disable_chunked_prefill:
# Although this also uses the chunked prefill queue, since args.chunked_prefill_size = args.max_req_total_len
# the actual scheduling behaves like the old continuous-batching scheduler, so the two implementations are unified into one to reduce code duplication.
return ChunkedPrefillQueue
return chunked_prefill_queue_impl
else:
return ChunkedPrefillQueue
return chunked_prefill_queue_impl


def build_req_queue(args, router, dp_size_in_node: int):
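A hedged sketch of what the dispatch above resolves to in the common case: with --past_future_scheduler set and no special modes, every remaining branch returns chunked_prefill_queue_impl, i.e. PastFutureQueue. The argparse.Namespace below is a minimal stand-in for the parsed args, not the full parser output.

import argparse
from lightllm.server.router.req_queue import _get_req_queue_class, PastFutureQueue

args = argparse.Namespace(
    past_future_scheduler=True, diverse_mode=False, token_healing_mode=False,
    output_constraint_mode="none", first_token_constraint_mode=False,
    run_mode="normal", disable_chunked_prefill=False,
)
assert _get_req_queue_class(args, router=None, dp_size_in_node=1) is PastFutureQueue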
8 changes: 6 additions & 2 deletions lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -21,8 +21,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
self.cache_len_list = []
return

# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
def _update_cache_len_list(self, req: Req, is_busy):
self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len)) # hard to analyze
self.cache_len_list.sort(key=lambda x: -x[1])

@@ -32,6 +31,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
return need_max_token_num

# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
need_max_token_num = self._update_cache_len_list(req, is_busy)
with g_router_lock.obj:
ok_token_num = (
need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
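For intuition, a worked example (made-up request sizes) of the peak-token estimate that the refactored _update_cache_len_list returns. Each entry is (tokens already held, remaining output); after sorting by descending remaining output, the estimate takes, for each prefix of the list, the tokens already held plus the prefix size times its smallest remaining output, and keeps the maximum.

import numpy as np

cache_len_list = [(900, 100), (500, 400), (200, 50)]
cache_len_list.sort(key=lambda x: -x[1])  # [(500, 400), (900, 100), (200, 50)]

left_out_len_array = np.array([e[1] for e in cache_len_list])
cum_run_len_array = np.cumsum([e[0] for e in cache_len_list])
size_array = np.arange(1, len(cache_len_list) + 1)

need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
print(need_max_token_num)  # max(400*1 + 500, 100*2 + 1400, 50*3 + 1600) = 1750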
80 changes: 80 additions & 0 deletions lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py
@@ -0,0 +1,80 @@
import bisect
from collections import deque
import random
from typing import List, Tuple
import numpy as np
from ...batch import Batch, Req
from .impl import ChunkedPrefillQueue


class PastFutureQueue(ChunkedPrefillQueue):
WINDOW_SIZE = 200
MINIMUM_SAMPLES = 200
MAXIMUM_LISTS = 5
REVERSED = 0.05
COMPLIANCE_IS_BUSY_FLAG = False

def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
super().__init__(args, router, dp_index, dp_size_in_node)
initial_len = args.router_max_new_token_len
self.history_output_len = deque([initial_len] * (self.WINDOW_SIZE // 2), maxlen=self.WINDOW_SIZE)

def _sample_cache_list(self, reqs: List[Req], is_busy, samples=1) -> List[List[Tuple[int, int]]]:
cache_len_lists = [[] for _ in range(samples)]
his_Lo = sorted(self.history_output_len)
for req in reqs:
dl = req.shm_cur_output_len
pos = bisect.bisect(his_Lo, dl)

sample_range = [dl] + his_Lo[pos:] + [req.sample_params.max_new_tokens] # at least 2 values

for i in range(samples):
random_p = np.random.random() * (len(sample_range) - 1)
l_pos = int(random_p)
l_val, r_val = sample_range[l_pos : l_pos + 2]

# Linear interpolation
sampled = round(l_val + (r_val - l_val) * (random_p - l_pos))
cache_len_lists[i].append(
req.get_tuple_tokens(is_busy and self.COMPLIANCE_IS_BUSY_FLAG, sampled, has_out_len_factor=1.0)
)

return cache_len_lists

def _calc_max_token_num_needed(self, cache_len_list: List[Tuple[int, int]]) -> int:
cache_len_list.sort(key=lambda x: -x[1])

left_out_len_array = np.array([e[1] for e in cache_len_list])
has_run_len_array = np.array([e[0] for e in cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(cache_len_list) + 1, 1)

need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
return need_max_token_num

def _init_cache_list(self, current_batch: Batch, is_busy):
if current_batch is not None:
n_lists = min(self.MAXIMUM_LISTS, int(self.MINIMUM_SAMPLES / len(current_batch.reqs)) + 1)
local_reqs = [req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index]
self._cache_len_lists = self._sample_cache_list(local_reqs, is_busy, samples=n_lists)
else:
self._cache_len_lists = [[]]
self.cache_len_list = self._cache_len_lists[0] # keep compatibility

def _update_cache_len_list(self, req: Req, is_busy):
need_max_token_nums = []
for li in self._cache_len_lists:
newreq_output_len_sample = random.choice(self.history_output_len)
li.append(
req.get_tuple_tokens(
is_busy and self.COMPLIANCE_IS_BUSY_FLAG, newreq_output_len_sample, has_out_len_factor=1.0
)
)
need_max_token_nums.append(self._calc_max_token_num_needed(li))
need_max_token_num = np.max(need_max_token_nums)
return need_max_token_num

def record_finished_len_from_batch(self, batch: Batch):
for req in batch.reqs:
if req.shm_infer_released:
self.history_output_len.append(req.shm_cur_output_len)
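To make the sampling step concrete, a standalone sketch (made-up history) of how one output-length guess is drawn per request in _sample_cache_list: the candidates are the request's current output length, every historical output length that already exceeds it, and the request's max_new_tokens; a uniform position over that sorted range is then linearly interpolated between its two neighbouring candidates.

import bisect
import numpy as np

history_output_len = sorted([64, 128, 300, 512, 900])  # his_Lo
cur_output_len = 200                                    # dl: tokens produced so far
max_new_tokens = 1024

pos = bisect.bisect(history_output_len, cur_output_len)
sample_range = [cur_output_len] + history_output_len[pos:] + [max_new_tokens]
# -> [200, 300, 512, 900, 1024]

random_p = np.random.random() * (len(sample_range) - 1)  # e.g. 2.4
l_pos = int(random_p)
l_val, r_val = sample_range[l_pos:l_pos + 2]
sampled = round(l_val + (r_val - l_val) * (random_p - l_pos))
print(sampled)  # a guess between 200 and 1024, biased toward recently observed lengths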