# Adapted from benchmarks/benchmark_serving.py
# of the vllm-project/vllm GitHub repository.
#
# Copyright 2023 ModelTC Team
# Copyright 2023 vLLM Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

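# Example invocation (a sketch: the script name, dataset path, and tokenizer name
# below are placeholders; only the flags come from the argparse setup at the bottom):
#   python benchmark_chat_serving.py \
#       --dataset ./data.jsonl \
#       --tokenizer deepseek-ai/DeepSeek-R1 \
#       --num-prompts 100 \
#       --request-rate 4 \
#       --use_openai_api
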
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple, Union

import aiohttp
import numpy as np
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)


def get_tokenizer(
    tokenizer_name: str,
    tokenizer_mode: str = "auto",
    *args,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Hugging Face."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs)
    except TypeError as e:
        raise RuntimeError(f"Failed to load the tokenizer: {e}") from e
    return tokenizer


# (prompt len, output len, latency, time to first token)
REQUEST_LATENCY: List[Tuple[int, int, float, float]] = []


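# Each line of the dataset is expected to be a JSON object with optional
# "context", "input", and "answers" fields, read by sample_requests below.
# The concrete values here are illustrative, not taken from any real dataset:
#   {"context": "<long document>", "input": "<question>", "answers": ["<reference answer>"]}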
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    max_total_tokens: int = 16384,
) -> List[Tuple[List[dict], str, int, int]]:
    # Load the dataset (jsonl)
    dataset = []
    with open(dataset_path) as f:
        for line in f:
            if not line.strip():
                continue
            dataset.append(json.loads(line))
    print("Finished reading the dataset.")

    def render_with_template(messages: List[dict]) -> str:
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Fall back to a plain "role: content" rendering when the tokenizer
            # has no chat template.
            parts = []
            for m in messages:
                parts.append(f"{m['role']}: {m['content']}")
            parts.append("assistant:")
            return "\n".join(parts)

    built_examples: List[Tuple[List[dict], str, int, int]] = []

    for data in dataset:
        context = data.get("context") or ""
        question = data.get("input") or "Summarizing government work reports"
        answers = data.get("answers")
        if not isinstance(context, str) or not isinstance(question, str):
            continue

        # Build messages: system + user with context and question
        system_prompt = "You are a helpful assistant. Read the context and answer the question concisely."
        user_content = f"Context:\n{context}\nInput:\n{question}"
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]

        rendered_prompt = render_with_template(messages)
        prompt_len = len(tokenizer(rendered_prompt).input_ids)

        # Estimate output length from reference answer if available
        target_text = ""
        if isinstance(answers, list) and len(answers) > 0:
            first_ans = answers[0]
            if isinstance(first_ans, str):
                target_text = first_ans
            else:
                target_text = str(first_ans)
        elif isinstance(answers, str):
            target_text = answers

        estimated_out = len(tokenizer(target_text).input_ids) if target_text else 128

        # Fit within max_total_tokens
        available_out = max_total_tokens - 1 - prompt_len
        if available_out < 4:
            # Skip samples that are too long
            continue
        output_len = min(estimated_out, available_out)

        built_examples.append((messages, rendered_prompt, prompt_len, output_len))

    # Take the first N valid samples
    sampled_requests = built_examples[:num_requests]
    sum_len = 0
    for _, _, prompt_len, output_len in sampled_requests:
        sum_len += prompt_len + output_len
    print("total tokens:", sum_len)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[List[dict], str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


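# send_request supports two backends on localhost:8000: an OpenAI-compatible chat
# completions endpoint (/v1/chat/completions) and a text-generation streaming
# endpoint (/generate_stream). Both branches parse server-sent events by stripping
# the leading "data:" prefix from each chunk; this assumes each network chunk
# carries exactly one complete SSE event, which may not hold for every server.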
async def send_request(
    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool
) -> None:
    ttft = 0.0
    if use_openai_api:
        # Use the OpenAI-compatible API on the local server to send the request.
        request_start_time = time.time()
        headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"}
        url = "http://localhost:8000/v1/chat/completions"

        data = {
            "model": "DeepSeek-R1",
            "messages": messages,
            "top_k": 1,
            "top_p": 1.0,
            "temperature": 0,
            "stream": True,
            "ignore_eos": True,
            "max_tokens": output_len,
        }
        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        receive_n = 1

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, headers=headers, json=data) as response:
                chunks = []
                text = ""
                start_time = time.time()
                is_first = True
                async for chunk, _ in response.content.iter_chunks():
                    now_time = time.time()
                    delta_time = now_time - start_time
                    if is_first:
                        is_first = False
                        ttft = delta_time
                    payload = chunk.decode("utf-8").strip()
                    if not payload.startswith("data:"):
                        # Ignore keep-alive chunks that are not SSE data events.
                        start_time = now_time
                        continue
                    body = payload[len("data:"):].strip()
                    if body == "[DONE]":
                        # Terminal event of the OpenAI-compatible stream.
                        break
                    text += json.loads(body)["choices"][0]["delta"].get("content", "")
                    if delta_time < 0.005:
                        receive_n += 1
                    chunks.append(delta_time)
                    start_time = now_time

    else:
        # Use the local server's generate_stream endpoint to send the request.
        request_start_time = time.time()
        headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"}
        url = "http://localhost:8000/generate_stream"

        data = {
            "inputs": rendered_prompt,
            "parameters": {
                "do_sample": False,
                "ignore_eos": True,
                "max_new_tokens": output_len,
            },
        }

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            receive_n = 0
            text = ""
            async with session.post(url, headers=headers, json=data) as response:
                chunks = []
                start_time = time.time()
                is_first = True
                async for chunk, _ in response.content.iter_chunks():
                    now_time = time.time()
                    delta_time = now_time - start_time
                    if is_first:
                        is_first = False
                        ttft = delta_time
                    if delta_time < 0.005:
                        receive_n += 1
                    chunks.append(chunk)
                    # Strip the leading "data:" prefix of the SSE event.
                    text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"]
                    start_time = now_time

    request_end_time = time.time()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft))


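# Schedule every request as its own asyncio task according to the arrival process
# produced by get_request, then wait for all of them to complete.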
async def benchmark(
    input_requests: List[Tuple[List[dict], str, int, int]],
    request_rate: float,
    use_openai_api: bool = False,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        messages, rendered_prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api))
        tasks.append(task)
    await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)
    tokenizer = get_tokenizer(args.tokenizer, "slow")
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_total_tokens)

    benchmark_start_time = time.time()
    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api))
    benchmark_end_time = time.time()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY])
    print(f"Average time to first token: {avg_time_to_first_token:.2f} s")
    avg_per_token_latency = (
        np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000
    )
    print(f"Average latency per token: {avg_per_token_latency:.1f} ms")
    # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY])
    # print(f"Average latency per output token: {avg_per_output_token_latency:.2f} s")
    avg_inter_token_latency = (
        np.mean(
            [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1]
        )
        * 1000
    )
    print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use the OpenAI API for requests.")
    parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.")
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize "
        "the request arrival times.",
    )
    parser.add_argument("--num-prompts", type=int, default=1, help="Number of prompts to process.")
    parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)