
Commit d98579d

Commit message: finish, pass test
Parent: fc22294

File tree (6 files changed: +22 / -12 lines changed):

- docs/CN/source/getting_started/benchmark.rst
- docs/EN/source/getting_started/benchmark.rst
- lightllm/common/basemodel/cuda_graph.py
- lightllm/common/basemodel/triton_kernel/gen_decode_params.py
- lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
- lightllm/server/api_start.py

docs/CN/source/getting_started/benchmark.rst

Lines changed: 4 additions & 4 deletions
@@ -89,15 +89,15 @@ ShareGPT dataset test (benchmark_sharegpt.py)
    python test/benchmark/service/benchmark_sharegpt.py \
        --dataset /path/to/sharegpt_dataset.json \
        --tokenizer /path/to/tokenizer \
-       --num_prompts 1000 \
-       --request_rate 10.0
+       --num-prompts 1000 \
+       --request-rate 10.0

 **Main parameters:**

 - ``--dataset``: path to the ShareGPT-format dataset
 - ``--tokenizer``: path to the tokenizer
-- ``--num_prompts``: number of test prompts
-- ``--request_rate``: request rate (requests/s)
+- ``--num-prompts``: number of test prompts
+- ``--request-rate``: request rate (requests/s)


 Prompt Cache test

docs/EN/source/getting_started/benchmark.rst

Lines changed: 4 additions & 4 deletions
@@ -88,15 +88,15 @@ Performance testing using ShareGPT real conversation data.
    python test/benchmark/service/benchmark_sharegpt.py \
        --dataset /path/to/sharegpt_dataset.json \
        --tokenizer /path/to/tokenizer \
-       --num_prompts 1000 \
-       --request_rate 10.0
+       --num-prompts 1000 \
+       --request-rate 10.0

 **Main Parameters:**

 - ``--dataset``: ShareGPT format dataset path
 - ``--tokenizer``: Tokenizer path
-- ``--num_prompts``: Number of test prompts
-- ``--request_rate``: Request rate (requests/s)
+- ``--num-prompts``: Number of test prompts
+- ``--request-rate``: Request rate (requests/s)

 Prompt Cache Testing
 ~~~~~~~~~~~~~~~~~~~
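Both documentation files now spell the two flags with hyphens, matching common CLI conventions. A minimal sketch of why this spelling still works on the Python side, assuming benchmark_sharegpt.py parses its flags with argparse (an assumption, not something shown in this commit): argparse maps hyphens in long option names to underscores on the parsed namespace, so the user-facing --num-prompts is read in code as args.num_prompts.

# Illustrative only: assumes the benchmark script uses argparse for its flags.
import argparse

parser = argparse.ArgumentParser(description="ShareGPT benchmark (sketch)")
parser.add_argument("--num-prompts", type=int, default=1000)
parser.add_argument("--request-rate", type=float, default=10.0)

args = parser.parse_args(["--num-prompts", "1000", "--request-rate", "10.0"])
# Hyphenated option names become underscored attributes on the namespace.
print(args.num_prompts, args.request_rate)  # 1000 10.0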

lightllm/common/basemodel/cuda_graph.py

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         batch_sizes.append(max_batch_size)
         batch_sizes.sort()

+        if self.args.enable_fa3_mtp:
+            step_size = self.args.mtp_step + 1
+            batch_sizes = [b for b in batch_sizes if b % step_size == 0]
+
         self.cuda_graph_batch_sizes = batch_sizes
         assert batch_sizes[-1] == self.max_batch_size
         logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

lightllm/common/basemodel/triton_kernel/gen_decode_params.py

Lines changed: 4 additions & 2 deletions
@@ -12,10 +12,12 @@ def gen_decode_params(b_seq_len: torch.Tensor):
     mtp_step = get_env_start_args().mtp_step
     mtp_size = mtp_step + 1
     enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
+    b_q_seq_len = torch.ones_like(b_seq_len)

     if enable_fa3_mtp:
-        b_q_seq_len = torch.ones_like(b_seq_len[: len(b_seq_len) // mtp_size])
-        b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size])
+        b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(
+            b_q_seq_len[: len(b_seq_len) // mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
+        )
     else:
         b_q_seq_len = torch.ones_like(b_seq_len)
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
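The refactor initialises b_q_seq_len once for the whole padded batch and, in the FA3 MTP branch, feeds the cumulative-sum helper one query length per request plus the kv length of each request's last slot (every mtp_size-th entry). A small indexing sketch with plain torch; the explicit cumsum-with-leading-zero below is only an assumed stand-in for gen_cumsum_pad0_tensor:

import torch

# Hypothetical batch: 2 requests, mtp_step = 1, so mtp_size = 2 slots per request.
mtp_size = 2
b_seq_len = torch.tensor([17, 18, 33, 34])   # kv length per slot
b_q_seq_len = torch.ones_like(b_seq_len)     # one new token per slot

# One query length per request, and the kv length of each request's last slot.
q_per_req = b_q_seq_len[: len(b_seq_len) // mtp_size]      # tensor([1, 1])
kv_per_req = b_seq_len[mtp_size - 1 :: mtp_size]           # tensor([18, 34])

# Assumed behaviour of gen_cumsum_pad0_tensor: cumulative sums with a leading 0.
cu_q = torch.cat([torch.zeros(1, dtype=q_per_req.dtype), q_per_req.cumsum(0)])
cu_kv = torch.cat([torch.zeros(1, dtype=kv_per_req.dtype), kv_per_req.cumsum(0)])
print(cu_q.tolist(), cu_kv.tolist())   # [0, 1, 2] [0, 18, 52]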

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions
@@ -569,7 +569,7 @@ def _token_gqa_decode_attention_mtp(
             v_cache=kv_nope,
             qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
             page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size],
-            cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size],
+            cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
@@ -582,7 +582,7 @@ def _token_gqa_decode_attention_mtp(
             return_softmax_lse=False,
             mtp_step=self.mtp_step,
         )
-        return o_tensor
+        return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank)

     def _token_gqa_decode_attention_flashattention(
         self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
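Two small fixes in the MTP decode-attention path: b_seq_len[self.mtp_size - 1 :: self.mtp_size] is a strided view, and .contiguous() materialises it before it is handed to the attention kernel, which generally expects a densely packed buffer; the kernel output is then viewed back to one row per decode token with tp_q_head_num_ heads of width kv_lora_rank. A plain-torch illustration of both points, with hypothetical sizes:

import torch

mtp_size = 2
b_seq_len = torch.tensor([17, 18, 33, 34])

strided = b_seq_len[mtp_size - 1 :: mtp_size]   # picks the last slot of each request
print(strided.is_contiguous())                  # False: a view with stride 2
print(strided.contiguous().is_contiguous())     # True: safe to pass to a CUDA kernel

# Reshaping the attention output back to (tokens, heads, head_dim); sizes are hypothetical.
tp_q_head_num, kv_lora_rank = 16, 512
o_tensor = torch.empty(2, tp_q_head_num * mtp_size, kv_lora_rank)  # per-request layout
o = o_tensor.view(-1, tp_q_head_num, kv_lora_rank)
print(o.shape)   # torch.Size([4, 16, 512]) -> one row per decode token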

lightllm/server/api_start.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .router.manager import start_router_process
 from lightllm.utils.process_check import is_process_active
 from lightllm.utils.multinode_utils import send_and_receive_node_ip
+from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp

 logger = init_logger(__name__)

@@ -140,6 +141,9 @@ def normal_or_p_d_start(args):

     if args.enable_fa3_mtp:
         assert args.mtp_mode is not None, "enable_fa3_mtp must set mtp_mode"
+        assert (
+            flash_attn_with_kvcache_mtp is not None
+        ), "flash_attn_with_kvcache_mtp is None, please check if you have installed the fa3_mtp kernel"

     # Check that there are enough GPUs
     if args.visual_gpu_ids is None:
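The extra startup assert fails fast when enable_fa3_mtp is requested but the optional fa3_mtp attention kernel is missing; this presumes lightllm.common.flash_attn exposes flash_attn_with_kvcache_mtp as None when the kernel cannot be imported. A generic sketch of that optional-import, fail-fast pattern (module and function names below are illustrative, not the repository's actual layout):

# Illustrative optional-import guard; "fa3_mtp_kernels" is a hypothetical package.
from typing import Optional

try:
    from fa3_mtp_kernels import flash_attn_with_kvcache_mtp
except ImportError:
    flash_attn_with_kvcache_mtp = None

def check_startup_args(enable_fa3_mtp: bool, mtp_mode: Optional[str]) -> None:
    # Validate the flag combination before any worker process is spawned.
    if enable_fa3_mtp:
        assert mtp_mode is not None, "enable_fa3_mtp must set mtp_mode"
        assert flash_attn_with_kvcache_mtp is not None, (
            "flash_attn_with_kvcache_mtp is None, please check if you have "
            "installed the fa3_mtp kernel"
        )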
