
Commit 594de73

Improved the testing script
1 parent d3bf481 commit 594de73


4 files changed: +336 −102 lines changed


test/benchmark/service/benchmark_longbench.py

Lines changed: 72 additions & 18 deletions
@@ -26,6 +26,7 @@
 import aiohttp
 import numpy as np
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
+from tqdm.asyncio import tqdm

 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -142,22 +143,32 @@ def render_with_template(messages: List[dict]) -> str:
 async def get_request(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
+    concurrency: int = None,
 ) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]:
     input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request

-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+    if concurrency is not None:
+        # Concurrency-based request generation
+        # This generator will be consumed by the benchmark function
+        # which will manage the concurrency
+        for request in input_requests:
+            yield request
+    else:
+        # Rate-based request generation (original logic)
+        for request in input_requests:
+            yield request
+
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)


 async def send_request(
-    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool
+    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None
 ) -> None:
     if use_openai_api:
         # Use OpenAI API to send the request.
@@ -191,7 +202,7 @@ async def send_request(
                 if is_first:
                     is_first = False
                     ttft = delta_time
-                text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
+                # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
                 if delta_time < 0.005:
                     receive_n += 1
                 chunks.append(delta_time)
@@ -236,18 +247,50 @@ async def send_request(
     request_latency = request_end_time - request_start_time
     REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft))

+    # Update progress bar if provided
+    if pbar:
+        pbar.update(1)
+

 async def benchmark(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
     use_openai_api: bool = False,
+    concurrency: int = None,
 ) -> None:
-    tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
-        messages, rendered_prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api))
-        tasks.append(task)
-    await asyncio.gather(*tasks)
+    total_requests = len(input_requests)
+
+    # Create progress bar
+    pbar = tqdm(total=total_requests, desc="Processing requests", unit="req")
+
+    if concurrency is not None:
+        # Concurrency-based processing
+        semaphore = asyncio.Semaphore(concurrency)
+        tasks: List[asyncio.Task] = []
+
+        async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len):
+            async with semaphore:
+                await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len))
+            tasks.append(task)
+
+        await asyncio.gather(*tasks)
+    else:
+        # Rate-based processing (original logic)
+        tasks: List[asyncio.Task] = []
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(
+                send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+    # Close progress bar
+    pbar.close()


 def main(args: argparse.Namespace):
@@ -258,7 +301,7 @@ def main(args: argparse.Namespace):
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_total_tokens)

     benchmark_start_time = time.time()
-    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api))
+    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency))
     benchmark_end_time = time.time()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -298,8 +341,19 @@ def main(args: argparse.Namespace):
         "Otherwise, we use Poisson process to synthesize "
         "the request arrival times.",
     )
+    parser.add_argument(
+        "--concurrency",
+        type=int,
+        default=None,
+        help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.",
+    )
     parser.add_argument("--num-prompts", type=int, default=1, help="Number of prompts to process.")
     parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).")
     parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
+
+    # Validate that only one of request_rate or concurrency is set
+    if args.concurrency is not None and args.request_rate != float("inf"):
+        raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.")
+
     main(args)
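
Both scripts in this commit add the same concurrency pattern: an asyncio.Semaphore caps how many requests are in flight at once, every request still runs as its own task, and a tqdm progress bar is ticked as each request completes. Below is a minimal, self-contained sketch of that pattern, not the benchmark itself; fake_send and its random sleep are assumptions that stand in for send_request and the real HTTP round trip.

# Minimal sketch of the semaphore + tqdm pattern used by benchmark().
# fake_send is a hypothetical stand-in for send_request, added only so the
# example runs on its own.
import asyncio
import random

from tqdm.asyncio import tqdm


async def fake_send(request_id: int, pbar: tqdm) -> None:
    # Simulate a request with variable latency, then tick the progress bar.
    await asyncio.sleep(random.uniform(0.05, 0.2))
    pbar.update(1)


async def run(num_requests: int = 20, concurrency: int = 4) -> None:
    semaphore = asyncio.Semaphore(concurrency)
    pbar = tqdm(total=num_requests, desc="Processing requests", unit="req")

    async def send_with_semaphore(request_id: int) -> None:
        # At most `concurrency` coroutines get past this point at the same time.
        async with semaphore:
            await fake_send(request_id, pbar)

    # All tasks are created up front; the semaphore, not task creation, limits load.
    tasks = [asyncio.create_task(send_with_semaphore(i)) for i in range(num_requests)]
    await asyncio.gather(*tasks)
    pbar.close()


if __name__ == "__main__":
    asyncio.run(run())

Because the semaphore is acquired inside each task rather than around task creation, the diff keeps a constant number of requests in flight (closed-loop load), while the unchanged rate mode keeps issuing requests at Poisson intervals regardless of how many are outstanding (open-loop load).
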

test/benchmark/service/benchmark_sharegpt.py

Lines changed: 72 additions & 18 deletions
@@ -26,6 +26,7 @@
 import aiohttp
 import numpy as np
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
+from tqdm.asyncio import tqdm

 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -167,22 -168,32 @@ def to_openai_role(role_value: str) -> str:
 async def get_request(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
+    concurrency: int = None,
 ) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]:
     input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request

-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+    if concurrency is not None:
+        # Concurrency-based request generation
+        # This generator will be consumed by the benchmark function
+        # which will manage the concurrency
+        for request in input_requests:
+            yield request
+    else:
+        # Rate-based request generation (original logic)
+        for request in input_requests:
+            yield request
+
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)


 async def send_request(
-    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool
+    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None
 ) -> None:
     if use_openai_api:
         # Use OpenAI API to send the request.
@@ -216,7 +227,7 @@ async def send_request(
                 if is_first:
                     is_first = False
                     ttft = delta_time
-                text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
+                # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
                 if delta_time < 0.005:
                     receive_n += 1
                 chunks.append(delta_time)
@@ -261,18 +272,50 @@ async def send_request(
     request_latency = request_end_time - request_start_time
     REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft))

+    # Update progress bar if provided
+    if pbar:
+        pbar.update(1)
+

 async def benchmark(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
     use_openai_api: bool = False,
+    concurrency: int = None,
 ) -> None:
-    tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
-        messages, rendered_prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api))
-        tasks.append(task)
-    await asyncio.gather(*tasks)
+    total_requests = len(input_requests)
+
+    # Create progress bar
+    pbar = tqdm(total=total_requests, desc="Processing requests", unit="req")
+
+    if concurrency is not None:
+        # Concurrency-based processing
+        semaphore = asyncio.Semaphore(concurrency)
+        tasks: List[asyncio.Task] = []
+
+        async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len):
+            async with semaphore:
+                await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len))
+            tasks.append(task)
+
+        await asyncio.gather(*tasks)
+    else:
+        # Rate-based processing (original logic)
+        tasks: List[asyncio.Task] = []
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(
+                send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+    # Close progress bar
+    pbar.close()


 def main(args: argparse.Namespace):
@@ -285,7 +328,7 @@ def main(args: argparse.Namespace):
     )

     benchmark_start_time = time.time()
-    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api))
+    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency))
     benchmark_end_time = time.time()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -325,11 +368,22 @@ def main(args: argparse.Namespace):
         "Otherwise, we use Poisson process to synthesize "
         "the request arrival times.",
     )
+    parser.add_argument(
+        "--concurrency",
+        type=int,
+        default=None,
+        help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.",
+    )
     parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.")
     parser.add_argument(
         "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply."
     )
     parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).")
     parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
+
+    # Validate that only one of request_rate or concurrency is set
+    if args.concurrency is not None and args.request_rate != float("inf"):
+        raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.")
+
     main(args)
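
For reference, a hedged sketch of how the two modes might be invoked after this change. Only --concurrency, --request-rate, --num-prompts, and --max-total-tokens appear in this diff; the --dataset flag and the example values are assumptions for illustration.

# Concurrency mode added by this commit: keep at most 32 requests in flight.
# (--dataset is an assumed flag; point it at your local dataset file.)
python test/benchmark/service/benchmark_sharegpt.py --dataset ./sharegpt.json --num-prompts 1000 --concurrency 32

# Original rate mode: Poisson arrivals at 4 requests per second.
python test/benchmark/service/benchmark_sharegpt.py --dataset ./sharegpt.json --num-prompts 1000 --request-rate 4

# Passing both --request-rate and --concurrency now raises a ValueError before the benchmark starts.
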
