
Commit 49843b2

Author: niushengxiao
feat: support page size variable for deepseek2
Parent: 7f3c9c6

File tree: 5 files changed (+145, −26 lines)
lightllm/common/deepseek2_page_size_variable_mem_manager.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+import numpy as np
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+from .page_size_variable_mem_manager import PageSizeVariableMemoryManager
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_page_size
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
+logger = init_logger(__name__)
+
+
+class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
+
+    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+        self.kv_buffer = torch.empty(
+            (layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim),
+            dtype=dtype,
+            device="cuda",
+        )
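The key change in this new manager is _init_buffers: the token dimension of the KV buffer is rounded up to a whole number of pages via cdiv. A minimal, self-contained sketch of that rounding, using an assumed PAGE_SIZE value purely for illustration (the real value comes from get_page_size()):

def cdiv(a, b):
    # ceiling division: smallest integer n such that n * b >= a
    return (a + b - 1) // b

PAGE_SIZE = 64          # assumed page size, for illustration only
size = 1000             # requested number of cacheable tokens
padded_size = cdiv(size, PAGE_SIZE) * PAGE_SIZE
print(padded_size)      # 1024 -> the kv_buffer is allocated with 16 full pages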

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 21 additions & 12 deletions
@@ -4,6 +4,11 @@
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.utils.envs_utils import get_page_size
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b
 
 
 class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
@@ -38,20 +43,24 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len
         self.cu_seqlens_k = self.b1_cu_kv_seq_len
         max_seq_len_k = self.max_kv_seq_len
+        page_size = get_page_size()
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-            page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(
-                model.graph_max_batch_size, model.graph_max_len_in_batch
+            length = cdiv(model.graph_max_len_in_batch, page_size)
+            page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
+            self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
+                self.batch_size, length
             )
-            self.page_table = page_buffer[self.microbatch_index][
-                : self.batch_size * model.graph_max_len_in_batch
-            ].reshape(self.batch_size, model.graph_max_len_in_batch)
         else:
-            self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
-                input_ids.device
-            )
+            length = cdiv(self.max_len_in_batch, page_size)
+            self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)
 
-        self.page_table[:, :max_seq_len_k].copy_(
-            model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
-        )
-        self.page_table[:, max_seq_len_k:].fill_(0)
+        if "page_size_variable" in model.mode:
+            length = cdiv(max_seq_len_k, page_size)
+            self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
+            self.page_table[:, length:].fill_(0)
+        else:
+            self.page_table[:, :max_seq_len_k].copy_(
+                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
+            )
+            self.page_table[:, max_seq_len_k:].fill_(0)
         return
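With page_size_variable enabled, the page table stores one entry per page instead of one per token, so its width shrinks by a factor of page_size. A rough standalone sketch of that indexing, assuming a hypothetical page size of 4 and stand-in tensors in place of the request manager's req_to_page_indexs:

import torch

def cdiv(a, b):
    return (a + b - 1) // b

page_size = 4                                  # assumed for illustration
batch_size, max_seq_len_k = 2, 10              # 10 tokens -> 3 pages per request
length = cdiv(max_seq_len_k, page_size)

req_to_page_indexs = torch.arange(2 * 8, dtype=torch.int32).reshape(2, 8)  # fake page ids
page_table = torch.zeros((batch_size, 8), dtype=torch.int32)
page_table[:, :length].copy_(req_to_page_indexs[:, :length])               # first 3 columns hold page ids
page_table[:, length:].fill_(0)                                            # rest is padded with zeros
print(page_table)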

lightllm/models/deepseek2/flashinfer_struct.py

Lines changed: 31 additions & 13 deletions
@@ -3,16 +3,21 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
 
 
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
 class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
     def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.flashinfer_extra_state = None
+        self.page_size = get_page_size()
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
@@ -23,24 +28,37 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
+                length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size)
                 if self.batch_size <= model.graph_max_batch_size:
                     self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
-                        : self.batch_size * self.flashinfer_extra_state.max_seq_length
+                        : self.batch_size * length
                     ]
                 else:
                     self.kv_indices = torch.empty(
-                        self.batch_size * self.flashinfer_extra_state.max_seq_length,
+                        self.batch_size * length,
                         dtype=torch.int32,
                         device=input_ids.device,
                     )
-                repack_kv_index(
-                    self.req_manager.req_to_token_indexs,
-                    self.b_req_idx,
-                    self.b_seq_len,
-                    self.b_start_loc,
-                    self.max_len_in_batch,
-                    self.kv_indices,
-                )
+                if "page_size_variable" in model.mode:
+                    b_page_len = cdiv(self.b_seq_len, self.page_size)
+                    self.kv_starts[1:] = b_page_len.cumsum(0)
+                    repack_kv_index(
+                        self.req_manager.req_to_page_indexs,
+                        self.b_req_idx,
+                        b_page_len,
+                        self.kv_starts[:-1],
+                        cdiv(self.max_len_in_batch, self.page_size),
+                        self.kv_indices,
+                    )
+                else:
+                    repack_kv_index(
+                        self.req_manager.req_to_token_indexs,
+                        self.b_req_idx,
+                        self.b_seq_len,
+                        self.b_start_loc,
+                        self.max_len_in_batch,
+                        self.kv_indices,
+                    )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
                         self.flashinfer_extra_state.workspace_buffer,
@@ -58,7 +76,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.flashinfer_extra_state.tp_q_head_num,
                     self.flashinfer_extra_state.kv_lora_rank,
                     self.flashinfer_extra_state.qk_rope_head_dim,
-                    1,
+                    self.page_size,
                     False,  # causal
                     self.flashinfer_extra_state.softmax_scale,
                     self.flashinfer_extra_state.q_data_type,
@@ -97,7 +115,7 @@ def copy_for_cuda_graph(self, new_infer_state):
                new_infer_state.flashinfer_extra_state.tp_q_head_num,
                new_infer_state.flashinfer_extra_state.kv_lora_rank,
                new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
-               1,
+               self.page_size,
                False,  # causal
                new_infer_state.flashinfer_extra_state.softmax_scale,
                new_infer_state.flashinfer_extra_state.q_data_type,
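For the flashinfer decode path, kv_indices now holds page ids and kv_starts becomes a prefix sum of per-request page counts rather than token counts. A small illustration of that bookkeeping, assuming a hypothetical page size of 4:

import torch

def cdiv(a, b):
    return (a + b - 1) // b

page_size = 4
b_seq_len = torch.tensor([5, 12, 3])        # tokens per request in the batch
b_page_len = cdiv(b_seq_len, page_size)     # pages per request: [2, 3, 1]

kv_starts = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int64)
kv_starts[1:] = b_page_len.cumsum(0)
print(kv_starts.tolist())                   # [0, 2, 5, 6]: offsets of each request's pages in kv_indices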

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 63 additions & 1 deletion
@@ -26,7 +26,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -93,6 +93,18 @@ def _bind_attention(self):
                 self._token_attention_kernel = partial(
                     Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
                 )
+        elif "page_size_variable" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            if get_env_start_args().enable_fa3:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self
+                )
+            elif get_env_start_args().enable_flashinfer_decode:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self
+                )
+            else:
+                raise Exception("Page size variable mode is not supported in other backends.")
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
             if get_env_start_args().enable_fa3:
@@ -574,6 +586,36 @@ def _token_gqa_decode_attention_flashattention(
         )
         return o_tensor
 
+    def _token_gqa_decode_attention_flashattention_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank)
+        k_descale, v_descale = None, None
+        o_tensor = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope,
+            v_cache=kv_nope,
+            qv=q_nope,
+            page_table=infer_state.page_table,
+            cache_seqlens=infer_state.b_seq_len,
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            window_size=(-1, -1),
+            softcap=0.0,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashinfer(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -593,6 +635,26 @@ def _token_gqa_decode_attention_flashinfer(
         )
         return o_tensor
 
+    def _token_gqa_decode_attention_flashinfer_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank),
+            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim),
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
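Both new paged decode kernels view one layer's flat KV cache as (num_pages, page_size, head_num, dim) before handing it to the backend, splitting the last dimension into the nope and rope parts. A standalone sketch of that reshape with made-up sizes (kv_lora_rank=512 and qk_rope_head_dim=64 are stand-ins, not values read from the model config here):

import torch

page_size = 16                          # assumed for illustration
kv_lora_rank, qk_rope_head_dim = 512, 64

# one layer's flat cache: (num_tokens, head_num=1, kv_lora_rank + qk_rope_head_dim)
kv = torch.zeros(2 * page_size, 1, kv_lora_rank + qk_rope_head_dim)

kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, page_size, 1, kv_lora_rank)
k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, page_size, 1, qk_rope_head_dim)
print(kv_nope.shape, k_rope.shape)      # (2, 16, 1, 512) and (2, 16, 1, 64)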

lightllm/models/deepseek2/model.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_page_size_variable_mem_manager import Deepseek2PageSizeVariableMemoryManager
 from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
@@ -97,6 +98,10 @@ def _init_mem_manager(self):
         manager_class = Deepseek2MemoryManager
         if "triton_fp8kv" in self.mode:
             manager_class = Deepseek2FP8KVMemoryManager
+        elif "page_size_variable" in self.mode:
+            manager_class = Deepseek2PageSizeVariableMemoryManager
+        elif self.mode:
+            raise ValueError(f"Unsupported mode for deepseek2: {self.mode}")
 
         # In MTP mode, the mem manager needs to be extended with the layers used by the draft model
         added_mtp_layer_num = 0
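A condensed sketch of the memory-manager dispatch added above, with the class names replaced by strings so it runs standalone (the real code assigns the classes imported at the top of model.py):

def pick_manager_class(mode):
    manager_class = "Deepseek2MemoryManager"
    if "triton_fp8kv" in mode:
        manager_class = "Deepseek2FP8KVMemoryManager"
    elif "page_size_variable" in mode:
        manager_class = "Deepseek2PageSizeVariableMemoryManager"
    elif mode:
        # any other non-empty mode list is rejected
        raise ValueError(f"Unsupported mode for deepseek2: {mode}")
    return manager_class

print(pick_manager_class([]))                        # default manager
print(pick_manager_class(["page_size_variable"]))    # page-size-variable manager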
