lightllm/common/basemodel/basemodel.py (6 changes: 4 additions & 2 deletions)
@@ -684,10 +684,10 @@ def _check_max_len_infer(self):
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
mem_indexes = self.req_manager.alloc_mem_indices(len(dummy_input_ids), b_seq_len, b_ready_cache_len).cuda()
total_token_num = self.batch_max_tokens
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
model_input = ModelInput(
@@ -757,12 +757,14 @@ def _autotune_warmup(self):
0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
)
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_seq_len[:] = input_len
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
total_token_num = input_len
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
mem_indexes = self.req_manager.alloc_mem_indices(
len(dummy_input_ids), b_seq_len, b_ready_cache_len
).cuda()
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
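Both call sites above swap `self.mem_manager.alloc(len(dummy_input_ids))` for `self.req_manager.alloc_mem_indices(len(dummy_input_ids), b_seq_len, b_ready_cache_len)`, because with paged KV allocation the number of slots actually handed out has to be rounded up to whole pages per request. A minimal sketch of that rounding, assuming a configurable page size (the helper name and values below are illustrative, not part of the PR):

```python
import torch


def tokens_needed_for_prefill(b_seq_len: torch.Tensor, b_ready_cache_len: torch.Tensor, page_size: int) -> int:
    """Round each request's uncached token count up to whole pages."""
    new_tokens = b_seq_len - b_ready_cache_len            # tokens that still need KV slots
    pages = (new_tokens + page_size - 1) // page_size     # ceiling division per request
    return int(pages.sum()) * page_size


b_seq_len = torch.tensor([9, 9, 9], dtype=torch.int32)
b_ready_cache_len = torch.zeros(3, dtype=torch.int32)
print(tokens_needed_for_prefill(b_seq_len, b_ready_cache_len, page_size=1))  # 27 (old behaviour)
print(tokens_needed_for_prefill(b_seq_len, b_ready_cache_len, page_size=4))  # 36 (3 pages per request)
```

This is the same rounding that `ReqManager._get_need_paged_token_num` (added later in this diff) performs when `b_ready_cache_len` is given.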
lightllm/common/basemodel/cuda_graph.py (10 changes: 8 additions & 2 deletions)
@@ -196,13 +196,16 @@ def warmup(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_last_mem_index = torch.zeros_like(b_seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
mem_indexes = model.req_manager.alloc_mem_indices(
len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index
).cuda()

model_input = ModelInput(
batch_size=batch_size,
@@ -252,13 +255,16 @@ def warmup_overlap(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_last_mem_index = torch.zeros_like(b_seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
mem_indexes = model.req_manager.alloc_mem_indices(
len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index
).cuda()

micro_batch = ModelInput(
is_prefill=False,
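The warm-up now also builds a `b_last_mem_index` tensor, because the decode branch of `alloc_mem_indices` (see `req_manager.py` below) only allocates a fresh page when the new token is the first slot of a page; otherwise it keeps appending into the page that `b_last_mem_index` points at. A small sketch of that rule, with an illustrative helper name and made-up slot numbers:

```python
def decode_slot(seq_len: int, last_mem_index: int, page_size: int, next_free_page_start: int) -> int:
    """Return the KV slot for the newest token of a request during decode."""
    if (seq_len - 1) % page_size == 0:      # the new token opens a new page
        return next_free_page_start         # first slot of a freshly allocated page
    return last_mem_index + 1               # keep filling the current page


# page_size = 4: positions 1..4 share page 0, position 5 opens a new page
print(decode_slot(seq_len=5, last_mem_index=3, page_size=4, next_free_page_start=8))   # 8
print(decode_slot(seq_len=6, last_mem_index=8, page_size=4, next_free_page_start=12))  # 9
```

In the warm-up itself `b_last_mem_index` is just zeros, since the captured graph only needs tensors of the right shape.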
lightllm/common/deepseek2_mem_manager.py (8 changes: 7 additions & 1 deletion)
@@ -8,6 +8,7 @@
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.utils.envs_utils import get_page_size

logger = init_logger(__name__)

@@ -20,7 +21,12 @@ def get_cell_size(self):
return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda")
page_size = get_page_size()
self.kv_buffer = torch.empty(
(layer_num, (size // page_size + 1) * page_size, head_num, head_dim),
dtype=dtype,
device="cuda",
)

# todo, etp or edp use the same work buffer here
# it can also be used as a work buffer by any kernel, without needing to save info
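The buffer's token dimension changes from `size + 1` to `(size // page_size + 1) * page_size`: the smallest multiple of `page_size` strictly greater than `size`, so the layout stays page-aligned while still leaving at least one spare slot for the HOLD token index. The same change is applied to the standard `kv_buffer` in `mem_manager.py` below. A quick sanity check of the formula (values are arbitrary):

```python
def padded_rows(size: int, page_size: int) -> int:
    # smallest multiple of page_size that is strictly greater than size
    return (size // page_size + 1) * page_size


assert padded_rows(1000, 1) == 1001   # reduces to the old size + 1
assert padded_rows(1000, 4) == 1004   # page-aligned, with room for the HOLD index
assert padded_rows(1001, 4) == 1004
```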
lightllm/common/mem_manager.py (30 changes: 22 additions & 8 deletions)
@@ -2,14 +2,15 @@
import os
import torch
import torch.distributed as dist
import triton
from typing import List, Union
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.utils.dist_utils import get_current_device_id

@@ -81,7 +82,12 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
# ... allocation; internally it is not actually managed either. This token is reserved for special run
# modes (e.g. padding requests under multi-dp / overlap-microbatch execution) so inference can proceed
# normally. Its index is size, stored in the HOLD_TOKEN_MEMINDEX member, similar in role to HOLD_REQUEST_ID in req_manager.
self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")
page_size = get_page_size()
self.kv_buffer = torch.empty(
(layer_num, (size // page_size + 1) * page_size, 2 * head_num, head_dim),
dtype=dtype,
device="cuda",
)

def alloc_kv_move_buffer(self, max_req_total_len):
"""
@@ -244,6 +250,7 @@ def _free_buffers(self):
self.kv_buffer = None

def alloc(self, need_size) -> torch.Tensor:
assert need_size % get_page_size() == 0, "Need size must be a multiple of page size"
if need_size > self.mark_end - self.mark_start:
logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
assert False, "error alloc state"
@@ -265,18 +272,25 @@ def free(self, free_index: Union[torch.Tensor, List[int]]):
"""

end = self.mark_start
start = self.mark_start - len(free_index)
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
page_size = get_page_size()
free_len = page_size * triton.cdiv(len(free_index), page_size)
start = self.mark_start - free_len
assert start >= 0, f"error free state start: {self.mark_start} free len {free_len}"

if isinstance(free_index, list):
self.mem_state.numpy()[start:end] = free_index
free_index = torch.tensor(free_index)

# the copy from GPU to CPU is a blocking operation within the stream
if page_size > 1:
base_free_index = free_index[free_index % page_size == 0]
token_idxs = base_free_index[:, None] + torch.arange(page_size)
self.mem_state[start:end] = token_idxs.flatten()
else:
# the copy from GPU to CPU is a blocking operation within the stream
self.mem_state[start:end] = free_index

self.mark_start -= len(free_index)
self.mark_start -= free_len

self.can_use_mem_size += len(free_index)
self.can_use_mem_size += free_len
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.can_use_mem_size == len(self.mem_state):
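In `free()`, the count being released is now rounded up to whole pages (`free_len`), and when `page_size > 1` only the page-base indices are kept and re-expanded so the state array gets complete pages back. A standalone reproduction of that expansion, assuming the usual layout where every page's base slot index is a multiple of `page_size` and only a request's last page may be partially filled:

```python
import torch

page_size = 4
# slots used by one freed request of length 6: page 2 fully used, page 3 half used
free_index = torch.tensor([8, 9, 10, 11, 12, 13])

base_free_index = free_index[free_index % page_size == 0]         # tensor([8, 12])
token_idxs = base_free_index[:, None] + torch.arange(page_size)   # rebuild both full pages
print(token_idxs.flatten())   # tensor([ 8,  9, 10, 11, 12, 13, 14, 15])

# free_len rounds the freed token count up to whole pages (triton.cdiv is ceiling division)
free_len = page_size * ((len(free_index) + page_size - 1) // page_size)
print(free_len)               # 8, matching token_idxs.numel()
```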
lightllm/common/req_manager.py (76 changes: 75 additions & 1 deletion)
@@ -1,11 +1,12 @@
import torch
import collections
import triton
from lightllm.utils.log_utils import init_logger
from .mem_manager import MemoryManager
from typing import List, Optional
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size
from lightllm.utils.config_utils import get_vocab_size

logger = init_logger(__name__)
@@ -67,6 +68,27 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
self.max_request_num = max_request_num
self.HOLD_REQUEST_ID = max_request_num

def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))

def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None):
b_token_len = b_seq_len
if b_ready_cache_len is not None:
b_token_len = b_seq_len - b_ready_cache_len
b_token_len_cumsum = torch.cumsum(b_token_len, dim=0)
b_last_mem_index = mem_indices[b_token_len_cumsum - 1]
return b_last_mem_index

# b_last_mem_index is only needed when b_ready_cache_len is None
def alloc_mem_indices(
self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None
) -> torch.Tensor:
page_size = get_page_size()
if page_size > 1 and b_seq_len is not None:
return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index)
else:
return self.mem_manager.alloc(need_size)

def alloc(self):
return self.req_list.alloc()

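`calc_last_mem_index_in_prefill` records, per request, the slot of the last token written during prefill; the decode-time allocator then uses that value (via `b_last_mem_index`) to keep appending into a partially filled page. A small worked example with illustrative slot numbers (the `mem_indices` here match the prefill sketch further down):

```python
import torch

b_seq_len = torch.tensor([9, 6])
b_ready_cache_len = torch.tensor([0, 2])
mem_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15])

b_token_len = b_seq_len - b_ready_cache_len               # newly written tokens: [9, 4]
b_token_len_cumsum = torch.cumsum(b_token_len, dim=0)     # [9, 13]
b_last_mem_index = mem_indices[b_token_len_cumsum - 1]    # last written slot per request
print(b_last_mem_index)   # tensor([ 8, 15])
```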
@@ -92,6 +114,58 @@ def free_all(self):
self.req_list = _ReqLinkedList(self.max_request_num)
return

def _expand_by_page_size(self, b_token_len, page_size):
# expand seq_len into whole pages, e.g. with page_size = 4, seq_len = [9,9,9] -> p_token_len = [4,4,1,4,4,1,4,4,1]
b_page_len = triton.cdiv(b_token_len, page_size)
need_pages_num = b_page_len.sum()
p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
cumsum_pages = torch.cumsum(b_page_len, dim=0)
last_page_positions = cumsum_pages - 1
remainders = b_token_len - (b_page_len - 1) * page_size
p_token_len[last_page_positions] = remainders
return need_pages_num, p_token_len

def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index):
b_seq_len = b_seq_len.cpu()
if b_ready_cache_len is not None:
# prefill
b_ready_cache_len = b_ready_cache_len.cpu()
b_token_len = b_seq_len - b_ready_cache_len
total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
pages = paged_token_idxs.view(-1, page_size)
mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)
return pages[mask]
else:
# decode
assert b_last_mem_index is not None
b_last_mem_index = b_last_mem_index.cpu()
need_new_page_mask = (b_seq_len - 1) % page_size == 0
new_pages_num = need_new_page_mask.sum()
token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
if new_pages_num > 0:
new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size)
token_idxs[need_new_page_mask] = new_pages_tokens[::page_size]
mask = ~need_new_page_mask
if mask.any():
token_idxs[mask] = b_last_mem_index[mask] + 1
return token_idxs

def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None):
page_size = get_page_size()
if page_size == 1:
return 0

need_new_pages = 0
if b_ready_cache_len is not None:
need_tokens_array = b_seq_len - b_ready_cache_len
need_pages_array = (need_tokens_array + page_size - 1) // page_size
need_new_pages = need_pages_array.sum()
else:
mask = (b_seq_len - 1) % page_size == 0
need_new_pages = mask.sum()
return need_new_pages * page_size


class ReqSamplingParamsManager:
"""
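The prefill branch of `_alloc_paged_mem_indices` grabs whole pages, then masks out the unused tail of each request's last page. A self-contained sketch of that path, with a stand-in for `mem_manager.alloc` that simply hands out consecutive slot numbers:

```python
import torch

page_size = 4
b_seq_len = torch.tensor([9, 6], dtype=torch.int32)
b_ready_cache_len = torch.tensor([0, 2], dtype=torch.int32)

b_token_len = b_seq_len - b_ready_cache_len                  # uncached tokens per request: [9, 4]
b_page_len = (b_token_len + page_size - 1) // page_size      # pages per request: [3, 1]
total_pages = int(b_page_len.sum())                          # 4

# stand-in for mem_manager.alloc(total_pages * page_size)
paged_token_idxs = torch.arange(total_pages * page_size, dtype=torch.int32)
pages = paged_token_idxs.view(-1, page_size)

# expand per-request lengths to per-page lengths: [4, 4, 1, 4]
p_token_len = torch.full((total_pages,), page_size, dtype=torch.int32)
last_page_positions = torch.cumsum(b_page_len, dim=0) - 1
p_token_len[last_page_positions] = b_token_len - (b_page_len - 1) * page_size

# keep only the slots each page actually uses
mask = torch.arange(page_size) < p_token_len.unsqueeze(1)
mem_indices = pages[mask]
print(mem_indices)   # request 0 -> slots 0..8, request 1 -> slots 12..15
```

The decode branch is the mirror image: a request gets a whole new page only when `(b_seq_len - 1) % page_size == 0`, and otherwise continues at `b_last_mem_index + 1`.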
lightllm/models/deepseek2/flashattention_infer_struct.py (27 changes: 15 additions & 12 deletions)
@@ -2,15 +2,18 @@
import torch
import numpy as np
import torch.distributed as dist
import triton
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.envs_utils import get_page_size


class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
_shared_page_table_buffer = None

def __init__(self):
super().__init__()
self.page_size = get_page_size()

@classmethod
def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -39,19 +42,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.cu_seqlens_k = self.b1_cu_kv_seq_len
max_seq_len_k = self.max_kv_seq_len
if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(
model.graph_max_batch_size, model.graph_max_len_in_batch
length = triton.cdiv(model.graph_max_len_in_batch, self.page_size)
page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
self.batch_size, length
)
self.page_table = page_buffer[self.microbatch_index][
: self.batch_size * model.graph_max_len_in_batch
].reshape(self.batch_size, model.graph_max_len_in_batch)
else:
self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
input_ids.device
)
length = triton.cdiv(self.max_len_in_batch, self.page_size)
self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)

self.page_table[:, :max_seq_len_k].copy_(
model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
)
self.page_table[:, max_seq_len_k:].fill_(0)
length = triton.cdiv(max_seq_len_k, self.page_size)
token_indexs = model.req_manager.req_to_token_indexs[
self.b_req_idx, : length * self.page_size : self.page_size
]
self.page_table[:, :length].copy_(token_indexs // self.page_size)
self.page_table[:, length:].fill_(0)
return
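With paged allocation, `page_table` now holds page ids rather than per-token slot indices: the code takes every `page_size`-th entry of `req_to_token_indexs` and divides it by `page_size`, which works because each page's slots are contiguous and page-aligned. A one-request sketch of that conversion (slot numbers are illustrative):

```python
import torch

page_size = 4
# token slots for a request of length 9: pages 0 and 1 are full, page 3 holds one token
req_to_token = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 12], dtype=torch.int32)

max_seq_len_k = 9
length = (max_seq_len_k + page_size - 1) // page_size           # triton.cdiv equivalent: 3
token_indexs = req_to_token[: length * page_size : page_size]   # base slots 0, 4, 12
page_table_row = token_indexs // page_size                      # page ids 0, 1, 3
print(page_table_row)   # tensor([0, 1, 3], dtype=torch.int32)
```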
lightllm/models/deepseek2/flashinfer_struct.py (22 changes: 14 additions & 8 deletions)
@@ -2,8 +2,9 @@
import torch
import numpy as np
import torch.distributed as dist
import triton
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index


Expand All @@ -13,6 +14,7 @@ def __init__(self):
self.prefill_wrapper = None
self.decode_wrapper = None
self.flashinfer_extra_state = None
self.page_size = get_page_size()

def init_some_extra_state(self, model, input_ids: torch.Tensor):
super().init_some_extra_state(model, input_ids)
@@ -23,22 +25,26 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
if not self.is_prefill:
if get_env_start_args().enable_flashinfer_decode:
self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size)
if self.batch_size <= model.graph_max_batch_size:
self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
: self.batch_size * self.flashinfer_extra_state.max_seq_length
: self.batch_size * length
]
else:
self.kv_indices = torch.empty(
self.batch_size * self.flashinfer_extra_state.max_seq_length,
self.batch_size * length,
dtype=torch.int32,
device=input_ids.device,
)
b_page_len = triton.cdiv(self.b_seq_len, self.page_size)
self.kv_starts[1:] = b_page_len.cumsum(0)
repack_kv_index(
self.req_manager.req_to_token_indexs,
self.b_req_idx,
self.b_seq_len,
self.b_start_loc,
self.max_len_in_batch,
b_page_len,
self.kv_starts[:-1],
triton.cdiv(self.max_len_in_batch, self.page_size),
self.page_size,
self.kv_indices,
)
if self.decode_wrapper is None:
@@ -58,7 +64,7 @@
self.flashinfer_extra_state.tp_q_head_num,
self.flashinfer_extra_state.kv_lora_rank,
self.flashinfer_extra_state.qk_rope_head_dim,
1,
self.page_size,
False, # causal
self.flashinfer_extra_state.softmax_scale,
self.flashinfer_extra_state.q_data_type,
@@ -97,7 +103,7 @@ def copy_for_cuda_graph(self, new_infer_state):
new_infer_state.flashinfer_extra_state.tp_q_head_num,
new_infer_state.flashinfer_extra_state.kv_lora_rank,
new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
1,
self.page_size,
False, # causal
new_infer_state.flashinfer_extra_state.softmax_scale,
new_infer_state.flashinfer_extra_state.q_data_type,
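For flashinfer decode, `kv_indices` now indexes pages instead of tokens, so `kv_starts` has to advance by pages per request and the wrapper's page size argument changes from `1` to `self.page_size`. A short sketch of the page-granular `kv_starts` (inputs are illustrative; `triton.cdiv` is just ceiling division):

```python
import torch

page_size = 4
b_seq_len = torch.tensor([9, 5, 4], dtype=torch.int32)

b_page_len = (b_seq_len + page_size - 1) // page_size   # pages per request: [3, 2, 1]
kv_starts = torch.zeros(len(b_seq_len) + 1, dtype=torch.int32)
kv_starts[1:] = b_page_len.cumsum(0)
print(kv_starts)   # tensor([0, 3, 5, 6], dtype=torch.int32)
```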