diff --git a/docker/Dockerfile b/docker/Dockerfile index 6d67fcf4d..54eb57d7d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly -# TODO: offline compile -# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . +RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \ + pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel diff --git a/docker/Dockerfile.deepep b/docker/Dockerfile.deepep index e765978b9..1243e0467 100644 --- a/docker/Dockerfile.deepep +++ b/docker/Dockerfile.deepep @@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly -# TODO: offline compile -# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . +RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \ + pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev diff --git a/docs/CN/source/getting_started/benchmark.rst b/docs/CN/source/getting_started/benchmark.rst index c9fc778aa..cfcd2c0ed 100644 --- a/docs/CN/source/getting_started/benchmark.rst +++ b/docs/CN/source/getting_started/benchmark.rst @@ -89,15 +89,15 @@ ShareGPT 数据集测试 (benchmark_sharegpt.py) python test/benchmark/service/benchmark_sharegpt.py \ --dataset /path/to/sharegpt_dataset.json \ --tokenizer /path/to/tokenizer \ - --num_prompts 1000 \ - --request_rate 10.0 + --num-prompts 1000 \ + --request-rate 10.0 **主要参数:** - ``--dataset``: ShareGPT 格式数据集路径 - ``--tokenizer``: 分词器路径 -- ``--num_prompts``: 测试提示数量 -- ``--request_rate``: 请求速率 (requests/s) +- ``--num-prompts``: 测试提示数量 +- ``--request-rate``: 请求速率 (requests/s) Prompt Cache 测试 diff --git a/docs/EN/source/getting_started/benchmark.rst b/docs/EN/source/getting_started/benchmark.rst index 87caaa06a..5587b8a11 100755 --- a/docs/EN/source/getting_started/benchmark.rst +++ b/docs/EN/source/getting_started/benchmark.rst @@ -88,15 +88,15 @@ Performance testing using ShareGPT real conversation data. 
python test/benchmark/service/benchmark_sharegpt.py \ --dataset /path/to/sharegpt_dataset.json \ --tokenizer /path/to/tokenizer \ - --num_prompts 1000 \ - --request_rate 10.0 + --num-prompts 1000 \ + --request-rate 10.0 **Main Parameters:** - ``--dataset``: ShareGPT format dataset path - ``--tokenizer``: Tokenizer path -- ``--num_prompts``: Number of test prompts -- ``--request_rate``: Request rate (requests/s) +- ``--num-prompts``: Number of test prompts +- ``--request-rate``: Request rate (requests/s) Prompt Cache Testing ~~~~~~~~~~~~~~~~~~~ diff --git a/lightllm/common/flash_attn.py b/lightllm/common/flash_attn.py new file mode 100644 index 000000000..66609e700 --- /dev/null +++ b/lightllm/common/flash_attn.py @@ -0,0 +1,103 @@ +import torch +from typing import List, Optional, Tuple, Union +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def get_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +try: + import flash_attn_3._C # Registers operators with PyTorch + + flash_attn_3_mtp = torch.ops.flash_attn_3 + + def flash_attn_with_kvcache_mtp( + q, + k, + v, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + q_v: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + is_causal=False, + window_size=(-1, -1), + softcap=0.0, # 0.0 means deactivated + is_rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, + pack_gqa=None, + sm_margin=0, + mtp_step=0, + ): + assert k.stride(-1) == 1, "k must have contiguous last dimension" + assert v.stride(-1) == 1, "v must have contiguous last dimension" + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (q_v.shape[-1] if q_v is not None else 0)) ** (-0.5) + seqused_k = get_contiguous(seqused_k) + + q, k, k_new, v_new = [get_contiguous(x) for x in (q, k, k_new, v_new)] + v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v + cu_seqlens_q, cu_seqlens_k_new = [get_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)] + page_table = get_contiguous(page_table) + out, softmax_lse, *rest = flash_attn_3_mtp.fwd( + q, + k, + v, + k_new, + v_new, + q_v, + None, # out + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + seqused_k, + max_seqlen_q, + None, # max_seqlen_k + page_table, + cache_batch_idx, + cache_leftpad, + rotary_cos, + rotary_sin, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + is_causal, + window_size[0], + window_size[1], + 0, + softcap, + is_rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + mtp_step, + ) + return out + +except: + flash_attn_3_mtp = None + flash_attn_with_kvcache_mtp = None + logger.warning("flash_attn_3._C is not available, please install flash-attention-3 package.") diff --git 
a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index e4f64d34d..af5dea062 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -30,6 +30,7 @@ from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp logger = init_logger(__name__) @@ -70,6 +71,8 @@ def __init__(self, layer_num, network_config, mode=[]): super().__init__(layer_num, network_config, mode) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] + self.mtp_step = get_env_start_args().mtp_step + self.mtp_size = self.mtp_step + 1 return def _bind_func(self): @@ -95,7 +98,11 @@ def _bind_attention(self): ) else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: + if get_env_start_args().enable_fa3_mtp: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self + ) + elif get_env_start_args().enable_fa3: self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self ) @@ -559,6 +566,35 @@ def _context_attention_kernel_origin_fp8( ) return o_tensor + def _token_gqa_decode_attention_mtp( + self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache_mtp( + q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim), + k=k_rope, + v=kv_nope, + q_v=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank), + page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size], + seqused_k=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(), + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=self.softmax_scale, + is_causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + mtp_step=self.mtp_step, + ) + return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank) + def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index a08147769..1d71bfed7 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -69,7 +69,7 @@ def __init__(self, kvargs): return def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: + if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp: self.infer_state_class = Deepseek2FlashAttentionStateInfo elif 
self.enable_flashinfer: self.infer_state_class = Deepseek2FlashInferStateInfo diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 8fa519578..19dbb2658 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -507,6 +507,11 @@ def make_argument_parser() -> argparse.ArgumentParser: but ensure that the model is compatible with the specified step count. currently, deepseekv3 model only support 1 step""", ) + parser.add_argument( + "--enable_fa3_mtp", + action="store_true", + help="""inference backend will use the fa3_mtp kernel for decode with MTP mode""", + ) parser.add_argument( "--kv_quant_calibration_config_path", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 5bd61666e..047409a4d 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -15,6 +15,7 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp from lightllm.utils.shm_size_check import check_recommended_shm_size logger = init_logger(__name__) @@ -139,6 +140,12 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + if args.enable_fa3_mtp: + assert args.mtp_mode is not None, "enable_fa3_mtp must set mtp_mode" + assert ( + flash_attn_with_kvcache_mtp is not None + ), "flash_attn_with_kvcache_mtp is None, please check if you have installed the fa3_mtp kernel" + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ce5ef56fc..c60d68c53 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -100,6 +100,7 @@ class StartArgs: mtp_mode: Optional[str] = field(default=None) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) + enable_fa3_mtp: bool = field(default=False) kv_quant_calibration_config_path: Optional[str] = field(default=None) nixl_pd_kv_page_num: int = field(default=16) nixl_pd_kv_page_size: int = field(default=1024) diff --git a/lightllm/utils/bench_utils.py b/lightllm/utils/bench_utils.py new file mode 100644 index 000000000..e4b2100c6 --- /dev/null +++ b/lightllm/utils/bench_utils.py @@ -0,0 +1,118 @@ +# This file is adapted from tile-ai/tilelang: +# https://github.com/tile-ai/tilelang/blob/main/tilelang/profiler/bench.py +# The original code and this file are licensed under the Apache License, Version 2.0. +# +# Copyright (c) sgl-project and other contributors. +# Modifications Copyright (c) LightLLM contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""The profiler and convert to torch utils""" + +import torch +from typing import Callable, List, Literal, Optional, Union + + +def do_bench( + fn: Callable, + warmup: float = 25, + rep: float = 100, + _n_warmup: int = 0, + _n_repeat: int = 0, + grad_to_none: Optional[List[torch.Tensor]] = None, + quantiles: Optional[List[float]] = None, + fast_flush: bool = True, + return_mode: Literal["min", "max", "mean", "median"] = "mean", +) -> Union[float, List[float]]: + """Benchmarks the runtime of a PyTorch function. + + This function handles: + - L2 cache flushing between runs for consistent timing + - Automatic warmup and repeat count calculation + - Optional gradient clearing for backward passes + - Multiple measurement modes (mean, median, min, max) + + Args: + fn: Function to benchmark + warmup: Target warmup time in milliseconds + rep: Target number of repetitions + _n_warmup: Override for number of warmup iterations + _n_repeat: Override for number of timing iterations + grad_to_none: Tensors whose gradients should be cleared between runs + quantiles: Optional performance percentiles to compute + fast_flush: Whether to use faster L2 cache flushing + return_mode: How to aggregate timing results ("mean", "median", "min", "max") + + Returns: + float: Aggregated runtime in milliseconds + """ + assert return_mode in ["min", "max", "mean", "median"] + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + if _n_warmup > 0: + n_warmup = _n_warmup + if _n_repeat > 0: + n_repeat = _n_repeat + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. 
So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py new file mode 100644 index 000000000..08564c669 --- /dev/null +++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py @@ -0,0 +1,214 @@ +# This file is adapted from tile-ai/tilelang: +# https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/example_mla_decode_paged.py +# The original code and this file are licensed under the Apache License, Version 2.0. +# +# Copyright (c) sgl-project and other contributors. +# Modifications Copyright (c) LightLLM contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# type: ignore +import torch +import argparse +import math +from typing import Callable, Optional, List, Literal, Union +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp +from lightllm.utils.bench_utils import do_bench + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype +): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + seq_len = cache_seqlens[i // 2] - ((i + 1) % 2) + kv_indices = block_table[i // 2, :seq_len] # 获取前seq_len个block索引 + O, LSE = scaled_dot_product_attention( + 
q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[kv_indices].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[kv_indices].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + out_torch, _ = ref_mla() + return out_torch + + +def run_fa3_mla_mtp( + mtp_size, + q, + block_table, + blocked_k, + max_seqlen_pad, + block_size, + b, + s_q, + cache_seqlens, + h_q, + h_kv, + d, + dv, + causal, + dtype, +): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + + batch_mtp = b // mtp_size + cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0) + cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k]) + scale = (1.0 / (dv + dpe)) ** 0.5 # log2(e) + k_descale, v_descale = None, None + BLOCK_H = h_q * mtp_size + + def flash_mla_fa3(): + out = flash_attn_with_kvcache_mtp( + q=q_pe.view(-1, BLOCK_H, dpe), + k=blocked_k_pe, + v=blocked_k_nope, + q_v=q_nope.view(-1, BLOCK_H, dv), + page_table=block_table, + seqused_k=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=scale, + is_causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + mtp_step=1, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_fa3() + t = do_bench(flash_mla_fa3) + + out_ref = run_torch_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + # 计算相对绝对误差 + def print_error(a, b, name=""): + max_absolute_error = torch.abs(a - b).max() + relative_abs_error = torch.abs(a - b) / (torch.abs(a) + 1e-4) + max_relative_abs_error = relative_abs_error.max() + mean_relative_abs_error = relative_abs_error.mean() + + print(f"{name}: Maximum absolute difference: {max_absolute_error:.6e}") + print(f"Maximum relative absolute error: {max_relative_abs_error:.6e}") + print(f"Mean relative absolute error: {mean_relative_abs_error:.6e}") + + print_error(out_flash, out_ref, "out_flash, out_ref") + torch.testing.assert_close(out_flash, out_ref, rtol=0.001, atol=0.001) + print("All close") + return out_flash, t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=16, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") + parser.add_argument("--mtp_size", type=int, default=2, help="Specifies the number of tokens per prediction.") + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + mtp_size = args.mtp_size + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 1 + batch_mtp = b // mtp_size + cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], dtype=torch.int32, device=device) + # 
print(cache_seqlens[-1]) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ?为什么对齐256 + + total_flops = s_q * total_seqlens * h_q * (d + dv) * 2 * mtp_size + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view( + batch_mtp, max_seqlen_pad + ) + + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, latency = run_fa3_mla_mtp( + mtp_size, + q, + block_table, + blocked_k, + max_seqlen_pad, + block_size, + b, + s_q, + cache_seqlens, + h_q, + h_kv, + d, + dv, + causal, + dtype, + ) + + print("Tile-lang: {:.3f} ms".format(latency)) + print("Tile-lang: {:.3f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/test/benchmark/service/benchmark_longbench.py b/test/benchmark/service/benchmark_longbench.py new file mode 100644 index 000000000..53b9eb360 --- /dev/null +++ b/test/benchmark/service/benchmark_longbench.py @@ -0,0 +1,359 @@ +# Adapted from benchmarks/benchmark_serving.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import json +import random +import time +from typing import AsyncGenerator, List, Tuple, Union + +import aiohttp +import numpy as np +from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase +from tqdm.asyncio import tqdm + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + + +def get_tokenizer( + tokenizer_name: str, + tokenizer_mode: str = "auto", + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): + pass + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs) + except TypeError as e: + err_msg = "Failed to load the tokenizer. 
{e}" + raise RuntimeError(err_msg) from e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + pass + return tokenizer + + +# (prompt len, output len, latency) +REQUEST_LATENCY: List[Tuple[int, int, float]] = [] + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + max_total_tokens: int = 16384, +) -> List[Tuple[List[dict], str, int, int]]: + # Load the dataset (jsonl) + dataset = [] + with open(dataset_path) as f: + for line in f.readlines(): + if not line.strip(): + continue + dataset.append(json.loads(line)) + print("read data set finish") + + def render_with_template(messages: List[dict]) -> str: + try: + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") + parts.append("assistant:") + return "\n".join(parts) + + built_examples: List[Tuple[List[dict], str, int, int]] = [] + + for data in dataset: + context = data.get("context") or "" + question = data.get("input") or "Summarizing government work reports" + answers = data.get("answers") + if not isinstance(context, str) or not isinstance(question, str): + continue + + # Build messages: system + user with context and question + system_prompt = "You are a helpful assistant. Read the context and answer the question concisely." + user_content = f"Context:\n{context}\nInput:\n{question}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + rendered_prompt = render_with_template(messages) + prompt_len = len(tokenizer(rendered_prompt).input_ids) + + # Estimate output length from reference answer if available + target_text = "" + if isinstance(answers, list) and len(answers) > 0: + first_ans = answers[0] + if isinstance(first_ans, str): + target_text = first_ans + else: + target_text = str(first_ans) + elif isinstance(answers, str): + target_text = answers + + estimated_out = len(tokenizer(target_text).input_ids) if target_text else 128 + + # Fit within max_total_tokens + available_out = max_total_tokens - 1 - prompt_len + if available_out < 4: + # Skip samples that are too long + continue + output_len = min(estimated_out, available_out) + + built_examples.append((messages, rendered_prompt, prompt_len, output_len)) + + # Take the first N valid samples + sampled_requests = built_examples[:num_requests] + sum_len = 0 + for _, _, prompt_len, output_len in sampled_requests: + sum_len += prompt_len + output_len + print("total tokens:", sum_len) + return sampled_requests + + +async def get_request( + input_requests: List[Tuple[List[dict], str, int, int]], + request_rate: float, + concurrency: int = None, +) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: + input_requests = iter(input_requests) + + if concurrency is not None: + # Concurrency-based request generation + # This generator will be consumed by the benchmark function + # which will manage the concurrency + for request in input_requests: + yield request + else: + # Rate-based request generation (original logic) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. 
+ await asyncio.sleep(interval) + + +async def send_request( + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None +) -> None: + if use_openai_api: + # Use OpenAI API to send the request. + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/v1/chat/completions" + + data = { + "model": "DeepSeek-R1", + "messages": messages, + "top_k": 1, + "top_p": 1.0, + "temperature": 0, + "stream": True, + "ignore_eos": True, + "max_tokens": output_len, + } + timeout = aiohttp.ClientTimeout(total=3 * 3600) + receive_n = 1 + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + text = "" + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + if delta_time < 0.005: + receive_n += 1 + chunks.append(delta_time) + start_time = now_time + + else: + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/generate_stream" + + data = { + "inputs": rendered_prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + }, + } + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + receive_n = 0 + text = "" + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + if delta_time < 0.005: + receive_n += 1 + chunks.append(chunk) + text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"] + start_time = now_time + + request_end_time = time.time() + request_latency = request_end_time - request_start_time + REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + + # Update progress bar if provided + if pbar: + pbar.update(1) + + +async def benchmark( + input_requests: List[Tuple[List[dict], str, int, int]], + request_rate: float, + use_openai_api: bool = False, + concurrency: int = None, +) -> None: + total_requests = len(input_requests) + + # Create progress bar + pbar = tqdm(total=total_requests, desc="Processing requests", unit="req") + + if concurrency is not None: + # Concurrency-based processing + semaphore = asyncio.Semaphore(concurrency) + tasks: List[asyncio.Task] = [] + + async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len): + async with semaphore: + await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len)) + tasks.append(task) + + await asyncio.gather(*tasks) + else: + # Rate-based processing (original logic) + tasks: List[asyncio.Task] = [] + async for request in 
get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task( + send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Close progress bar + pbar.close() + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + tokenizer = get_tokenizer(args.tokenizer, "slow") + input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_total_tokens) + + benchmark_start_time = time.time() + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency)) + benchmark_end_time = time.time() + benchmark_time = benchmark_end_time - benchmark_start_time + print(f"Total time: {benchmark_time:.2f} s") + print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") + + # Compute the latency statistics. + avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY]) + print(f"Average latency: {avg_latency:.2f} s") + avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY]) + print("Average time to first token: " f"{avg_time_to_first_token:.2f} s") + avg_per_token_latency = ( + np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000 + ) + print(f"Average latency per token: {avg_per_token_latency:.1f} ms") + # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY]) + # print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + avg_inter_token_latency = ( + np.mean( + [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1] + ) + * 1000 + ) + print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use OpenAI API for requests.") + parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") + parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize " + "the request arrival times.", + ) + parser.add_argument( + "--concurrency", + type=int, + default=None, + help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.", + ) + parser.add_argument("--num-prompts", type=int, default=1, help="Number of prompts to process.") + parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + # Validate that only one of request_rate or concurrency is set + if args.concurrency is not None and args.request_rate != float("inf"): + raise ValueError("Cannot set both --request-rate and --concurrency. 
Please use only one.") + + main(args) diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index c9f92f098..9a7ea556f 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -26,6 +26,7 @@ import aiohttp import numpy as np from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase +from tqdm.asyncio import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -63,112 +64,258 @@ def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int]]: + max_history_turns: int = 6, + max_total_tokens: int = 16384, +) -> List[Tuple[List[dict], str, int, int]]: # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + # Filter out the conversations with at least 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= max_history_turns] print("read data set finish") - # Tokenize the prompts and completions. - import random - - dataset = random.sample(dataset, num_requests * 3) - prompts = [prompt for prompt, _ in dataset] - completions = [completion for _, completion in dataset] + dataset = dataset[: num_requests * 3] + + def to_openai_role(role_value: str) -> str: + lower_value = role_value.lower() + if lower_value in ["human", "user", "system"]: + return "user" if lower_value != "system" else "system" + return "assistant" + + # Build messages and targets + built_examples: List[Tuple[List[dict], str]] = [] + for data in dataset: + convs = data.get("conversations", []) + if not convs: + continue + # Find the last assistant turn to be used as the completion target + last_assistant_idx = -1 + for idx in range(len(convs) - 1, -1, -1): + role_val = convs[idx].get("from") or convs[idx].get("role") or "assistant" + if to_openai_role(role_val) == "assistant": + last_assistant_idx = idx + break + if last_assistant_idx <= 0: + # Need at least one prompt message before the assistant response + continue + # Determine how many turns of history to keep before the target assistant turn + start_idx = max(0, last_assistant_idx - max_history_turns) + context_convs = convs[start_idx:last_assistant_idx] + completion_text = convs[last_assistant_idx].get("value") or convs[last_assistant_idx].get("content") or "" + if not completion_text: + continue + messages: List[dict] = [] + for turn in context_convs: + role_val = turn.get("from") or turn.get("role") or "user" + content_val = turn.get("value") or turn.get("content") or "" + if not content_val: + continue + messages.append({"role": to_openai_role(role_val), "content": content_val}) + if not messages: + continue + built_examples.append((messages, completion_text)) + + # Render prompts using chat template when possible + rendered_prompts: List[str] = [] + for messages, _ in built_examples: + rendered_text = None + try: + # Prefer using the tokenizer's chat template + rendered_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + # Fallback rendering if chat template is unavailable + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") 
+ parts.append("assistant:") + rendered_text = "\n".join(parts) + rendered_prompts.append(rendered_text) - prompt_token_ids = tokenizer(prompts).input_ids - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): + # Tokenize the prompts and completions. + prompt_token_ids = tokenizer(rendered_prompts).input_ids if rendered_prompts else [] + completion_texts = [completion for _, completion in built_examples] + completion_token_ids = tokenizer(completion_texts).input_ids if completion_texts else [] + + tokenized_dataset: List[Tuple[List[dict], str, int, int]] = [] + for i in range(len(built_examples)): + messages, _ = built_examples[i] + prompt_len = len(prompt_token_ids[i]) output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + tokenized_dataset.append((messages, rendered_prompts[i], prompt_len, output_len)) - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) + # Filter out too long or too short sequences. + filtered_dataset: List[Tuple[List[dict], str, int, int]] = [] + for messages, rendered_prompt, prompt_len, output_len in tokenized_dataset: if prompt_len < 4 or output_len < 4: - # Prune too short sequences. continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. + if (prompt_len + output_len) >= max_total_tokens: continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((messages, rendered_prompt, prompt_len, output_len)) # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) + sampled_requests = filtered_dataset[:num_requests] sum_len = 0 - for e in sampled_requests: - sum_len += e[1] + e[2] + for _, _, prompt_len, output_len in sampled_requests: + sum_len += prompt_len + output_len print("total tokens:", sum_len) return sampled_requests async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: + concurrency: int = None, +) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: input_requests = iter(input_requests) - for request in input_requests: - yield request - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. - await asyncio.sleep(interval) - - -async def send_request(prompt: str, prompt_len: int, output_len: int) -> None: - request_start_time = time.time() - headers = {"Content-Type": "application/json"} - headers = {"User-Agent": "Benchmark Client"} - url = "http://localhost:8000/generate" - - data = { - "inputs": prompt, - "parameters": { - "do_sample": False, + if concurrency is not None: + # Concurrency-based request generation + # This generator will be consumed by the benchmark function + # which will manage the concurrency + for request in input_requests: + yield request + else: + # Rate-based request generation (original logic) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. 
+ continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +async def send_request( + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None +) -> None: + if use_openai_api: + # Use OpenAI API to send the request. + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/v1/chat/completions" + + data = { + "model": "DeepSeek-R1", + "messages": messages, + "top_k": 1, + "top_p": 1.0, + "temperature": 0, + "stream": True, "ignore_eos": True, - "max_new_tokens": output_len, - # 'temperature': 0.1, - }, - } - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: + "max_tokens": output_len, + } + timeout = aiohttp.ClientTimeout(total=3 * 3600) + receive_n = 1 + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + text = "" + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + if delta_time < 0.005: + receive_n += 1 + chunks.append(delta_time) + start_time = now_time + + else: + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/generate_stream" + + data = { + "inputs": rendered_prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + }, + } + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + receive_n = 0 + text = "" async with session.post(url, headers=headers, json=data) as response: chunks = [] + start_time = time.time() + is_first = True async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + if delta_time < 0.005: + receive_n += 1 chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) - - if "error" not in output: - break + text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"] + start_time = now_time request_end_time = time.time() request_latency = request_end_time - request_start_time - REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) + REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + + # Update progress bar if provided + if pbar: + pbar.update(1) async def benchmark( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, + use_openai_api: bool = False, + concurrency: int = None, ) -> None: - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(prompt, prompt_len, output_len)) - tasks.append(task) - await asyncio.gather(*tasks) + total_requests = len(input_requests) + + # 
Create progress bar + pbar = tqdm(total=total_requests, desc="Processing requests", unit="req") + + if concurrency is not None: + # Concurrency-based processing + semaphore = asyncio.Semaphore(concurrency) + tasks: List[asyncio.Task] = [] + + async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len): + async with semaphore: + await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len)) + tasks.append(task) + + await asyncio.gather(*tasks) + else: + # Rate-based processing (original logic) + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task( + send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Close progress bar + pbar.close() def main(args: argparse.Namespace): @@ -176,28 +323,40 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) tokenizer = get_tokenizer(args.tokenizer, "slow") - input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + input_requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.history_turns, args.max_total_tokens + ) benchmark_start_time = time.time() - asyncio.run(benchmark(input_requests, args.request_rate)) + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency)) benchmark_end_time = time.time() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") # Compute the latency statistics. 
- avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY]) print(f"Average latency: {avg_latency:.2f} s") - avg_per_token_latency = np.mean( - [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY] + avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY]) + print("Average time to first token: " f"{avg_time_to_first_token:.2f} s") + avg_per_token_latency = ( + np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000 ) - print(f"Average latency per token: {avg_per_token_latency:.2f} s") - avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency in REQUEST_LATENCY]) - print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + print(f"Average latency per token: {avg_per_token_latency:.1f} ms") + # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY]) + # print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + avg_inter_token_latency = ( + np.mean( + [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1] + ) + * 1000 + ) + print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use OpenAI API for requests.") parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") parser.add_argument( @@ -209,7 +368,22 @@ def main(args: argparse.Namespace): "Otherwise, we use Poisson process to synthesize " "the request arrival times.", ) + parser.add_argument( + "--concurrency", + type=int, + default=None, + help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.", + ) parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") + parser.add_argument( + "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply." + ) + parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() + + # Validate that only one of request_rate or concurrency is set + if args.concurrency is not None and args.request_rate != float("inf"): + raise ValueError("Cannot set both --request-rate and --concurrency. 
Please use only one.") + main(args) diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ba90e709b..fa7e92dbd 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -21,6 +21,8 @@ def init_mtp_model(args: StartArgs, kvargs, main_model): mtp_step = args.mtp_step draft_models = [] + logger.info(f"Initializing {mtp_step} MTP draft models") + os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" mtp_model_kvargs = kvargs mtp_model_kvargs.update( @@ -33,18 +35,26 @@ def init_mtp_model(args: StartArgs, kvargs, main_model): } ) for i in range(mtp_step): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) - mtp_model_kvargs.update( - { - "weight_dir": args.spec_model_dir, - "max_total_token_num": main_model.mem_manager.size, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, - "main_model": main_model, - "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], - } - ) - draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + try: + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) + mtp_model_kvargs.update( + { + "weight_dir": args.mtp_draft_model_dir, + "max_total_token_num": main_model.mem_manager.size, + "disable_chunked_prefill": True, + "mtp_mode": args.mtp_mode, + "main_model": main_model, + "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + } + ) + draft_model = Deepseek3MTPModel(mtp_model_kvargs) + draft_models.append(draft_model) + logger.info(f"Successfully initialized draft model {i+1}/{mtp_step}") + except Exception as e: + logger.error(f"Failed to initialize draft model {i+1}: {str(e)}") + raise + + logger.info(f"Successfully initialized all {len(draft_models)} draft models") return draft_models @@ -68,12 +78,11 @@ def test_model_inference_mtp(args): "max_total_token_num": args.max_total_token_num, "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, - "mem_faction": args.mem_fraction, + "mem_fraction": args.mem_fraction, "max_req_num": 2000, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, - "spec_algo": args.spec_algo, "disable_cudagraph": args.disable_cudagraph, } proc = multiprocessing.Process( @@ -92,7 +101,7 @@ def test_model_inference_mtp(args): return -def torch_profile(fn, log_dir=None): +def torch_profile(fn, batch_size, log_dir=None): torch.cuda.synchronize() with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], @@ -101,69 +110,124 @@ def torch_profile(fn, log_dir=None): on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir), ) as prof: fn() + torch.cuda.synchronize() if get_current_rank_in_dp() == 0: - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - -def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False): + logger.info(f"batch_size {batch_size}\n{prof.key_averages().table(sort_by='cuda_time_total', row_limit=20)}") + table = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20) + logger.info(table if table else " (no ops recorded)") + + +def run_forward_once( + args, + input_len, + output_len, + batch_size, + main_model, + draft_models, + warmup=False, + enable_torch_profile=False, + skip_prefill=False, +): import time + import torch.distributed as dist + + dist.barrier() 
torch.cuda.synchronize() prefill_start_time = time.time() - test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) - test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data).cuda() - b_req_idx = torch.tensor( [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() - # Main model Prefill - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_len_in_batch=input_len, - input_ids=test_data, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - is_prefill=True, - b_ready_cache_len=b_ready_cache_len, - ) - model_output: ModelOutput = main_model.forward(model_input) - prob_out = torch.softmax(model_output.logits, dim=-1) - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - predict_ids = predict_ids.detach().cpu().numpy() - - draft_ids = [predict_ids] + if skip_prefill: + # Skip prefill computation but simulate the state after prefill + # Generate dummy output tokens as if prefill happened + draft_ids = [] + + # Generate dummy token IDs for main model and draft models + # Simulate one token output per model (main + draft models) + for model_idx in range(len(draft_models) + 1): + # Generate random token IDs as if they were predicted + dummy_predict_ids = np.random.randint(1000, 10000, (batch_size, 1)) + draft_ids.append(dummy_predict_ids) + + # Update sequence lengths to reflect that prefill "happened" + # No need to update b_seq_len as it already contains input_len + + if get_current_rank_in_dp() == 0 and not warmup: + logger.info(f"Skipped prefill phase, simulated {len(draft_ids)} draft outputs") + else: + # Generate test data for actual prefill + test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) + test_data = test_data.reshape(-1) + test_data = torch.from_numpy(test_data).cuda() + + # Allocate memory for prefill tokens + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + # Main model Prefill + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=input_len, + input_ids=test_data, + b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, + b_seq_len=b_seq_len, + mem_indexes=mem_indexes, + is_prefill=True, + b_ready_cache_len=b_ready_cache_len, + ) - # Draft model Prefill - # For simplicity, we'll just take the input of main_model to draft model. - model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens - for draft_model_id in range(len(draft_models)): - draft_model = draft_models[draft_model_id] - model_output = draft_model.forward(model_input) + model_output: ModelOutput = main_model.forward(model_input) prob_out = torch.softmax(model_output.logits, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() - draft_ids.append(predict_ids) + + draft_ids = [predict_ids] + + # Draft model Prefill + # For simplicity, we'll just take the input of main_model to draft model. 
model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + for draft_model_id in range(len(draft_models)): + draft_model = draft_models[draft_model_id] + model_output = draft_model.forward(model_input) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + draft_ids.append(predict_ids) + model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens torch.cuda.synchronize() prefill_end_time = time.time() - if get_current_rank_in_dp() == 0 and not warmup: - print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000) - print( - f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s" - ) + + rank_id = get_current_rank_in_dp() + + if rank_id == 0 and not warmup and not skip_prefill: + prefill_time = (prefill_end_time - prefill_start_time) * 1000 + dp_size = getattr(args, "dp", 1) + throughput = dp_size * batch_size * input_len / (prefill_end_time - prefill_start_time) + logger.info(f"prefill time cost: {prefill_time:.2f} ms, prefill throughput: {throughput:.2f} tokens/s") + + # Add profiling support for prefill + if enable_torch_profile and not warmup and not skip_prefill: + logger.info("Profile Prefill") + try: + torch_profile( + lambda: main_model.forward(model_input), + batch_size, + log_dir=f"./logs/forward_prefill_mtp_bs{batch_size}_{rank_id}", + ) + except Exception as e: + logger.error(f"Profiling error: {str(e)}") + # Continue without profiling torch.cuda.synchronize() @@ -172,12 +236,14 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ # build main decode input: nopad_b_seq_idx = [] + nopad_b_mtp_index = [] nopad_b_seq_len = [] nopad_total_token_num = 0 nopad_max_len_in_batch = 0 for i in range(batch_size): nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_mtp_index.append(0) seq_len = b_seq_len[i].item() nopad_b_seq_len.append(seq_len + 1) nopad_total_token_num += seq_len + 1 @@ -185,11 +251,13 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ for step in range(len(draft_models)): nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_mtp_index.append(step + 1) nopad_b_seq_len.append(seq_len + step + 2) nopad_total_token_num += seq_len + step + 2 nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") + nopad_b_mtp_index = torch.tensor(nopad_b_mtp_index, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() @@ -198,9 +266,10 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ total_token_num=nopad_total_token_num, max_len_in_batch=nopad_max_len_in_batch, input_ids=decode_input_ids, - mem_indexes=mem_indexes, b_req_idx=nopad_b_seq_idx, + b_mtp_index=nopad_b_mtp_index, b_seq_len=nopad_b_seq_len, + mem_indexes=mem_indexes, is_prefill=False, ) @@ -232,15 +301,31 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input.input_ids = predict_ids.reshape(-1) model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens torch.cuda.synchronize() - if i % 100 == 0 or i == output_len - 1: + if i % 100 == 0 or i == output_len - 
(len(draft_models) + 1): step_end_time = time.time() - if get_current_rank_in_dp() == 0 and not warmup: + if rank_id == 0 and not warmup: step_time = step_end_time - step_start_time - print(i, " step cost time:", step_time * 1000) - print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s") + dp_size = getattr(args, "dp", 1) + throughput = dp_size * batch_size * (len(draft_models) + 1) / step_time + logger.info(f"i: {i}, step cost time: {step_time * 1000:.2f} ms, throughput: {throughput:.2f} tokens/s") + + # Add profiling support for decode on last step + if enable_torch_profile and not warmup and i == output_len - (len(draft_models) + 1): + logger.info("Profile Decode") + try: + torch_profile( + lambda: main_model.forward(model_input), + batch_size, + log_dir=f"./logs/forward_decode_mtp_bs{batch_size}_{rank_id}", + ) + except Exception as e: + logger.error(f"Profiling error: {str(e)}") + # Continue without profiling main_model.mem_manager.free_all() main_model.req_manager.free_all() + torch.cuda.synchronize() + torch.cuda.empty_cache() def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, ans_queue): @@ -250,11 +335,22 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a from lightllm.distributed import dist_group_manager from lightllm.utils.dist_utils import set_current_device_id + # Handle batch_sizes as either int or list + if isinstance(batch_sizes, int): + batch_sizes = [batch_sizes] + else: + # Default batch sizes for comprehensive testing + batch_sizes = [16, 32, 64] + + logger.info(f"Testing batch sizes: {batch_sizes}") + import torch.distributed as dist enable_decode_overlap = args.enable_decode_microbatch_overlap group_size = 1 if enable_decode_overlap or args.enable_prefill_microbatch_overlap: + for bs in batch_sizes: + assert bs % 2 == 0, f"batch size {bs} must be even number for overlap mode" group_size = 2 init_distributed_env(model_kvargs) dist_group_manager.create_groups(group_size=group_size) @@ -265,14 +361,41 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a main_model, _ = get_model(model_cfg, model_kvargs) draft_models = init_mtp_model(args, model_kvargs, main_model) - if isinstance(batch_sizes, int): - batch_sizes = [batch_sizes] - + rank_id = model_kvargs["rank_id"] + skip_prefill = getattr(args, "skip_prefill", False) for batch_size in batch_sizes: + if rank_id == 0: + logger.info(f"Testing batch size {batch_size}") + # warm up - run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True) + run_forward_once( + args, + input_len, + 10, + batch_size, + main_model, + draft_models, + warmup=True, + enable_torch_profile=False, + skip_prefill=skip_prefill, + ) torch.cuda.synchronize() - run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False) + + # actual test + enable_profiling = getattr(args, "torch_profile", False) + run_forward_once( + args, + input_len, + output_len, + batch_size, + main_model, + draft_models, + warmup=False, + enable_torch_profile=enable_profiling, + skip_prefill=skip_prefill, + ) + if rank_id == 0: + logger.info("=" * 50) dist.barrier() ans_queue.put(True) diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py index 5b3751bcc..8725ac267 100644 --- a/test/benchmark/static_inference/test_model.py +++ b/test/benchmark/static_inference/test_model.py @@ -40,6 +40,11 @@ def test_model_infer(self): 
action="store_true", help="Enable torch profiler to profile the model", ) + parser.add_argument( + "--skip_prefill", + action="store_true", + help="Whether or not to skip prefill phase, because it is easy to have OOM in large batches", + ) args = parser.parse_args() set_env_start_args(args) torch.multiprocessing.set_start_method("spawn")