
Commit dc42095

Author: t00939662 (committed)
[Fix] Add import checking to trace_replay and fix the issue of unclosed network resources
1 parent b53b23a commit dc42095
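
The "unclosed network resources" part of the fix replaces a bare aiohttp.ClientSession() with an async with block so the session is always closed. A minimal, self-contained sketch of that pattern (the helper function and URL below are illustrative, not taken from trace_replay.py):

import asyncio

import aiohttp


async def fetch_status(url: str) -> int:
    # "async with" closes the ClientSession (and its connector sockets) even if
    # the request raises, avoiding the "Unclosed client session" warning that a
    # session created without close()/async with can leave behind.
    async with aiohttp.ClientSession(
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=30),
    ) as session:
        async with session.get(url) as resp:
            return resp.status


if __name__ == "__main__":
    print(asyncio.run(fetch_status("https://example.com")))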

File tree

1 file changed: +150 -141 lines changed


benchmarks/trace_replay.py

Lines changed: 150 additions & 141 deletions
@@ -454,174 +454,175 @@ async def replay_trace_by_time(
     if test_request is None:
         raise ValueError("No request found for initial test run.")

-    session = aiohttp.ClientSession(
+    async with aiohttp.ClientSession(
         trust_env=True,
         timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
-    )
-
-    test_input = RequestFuncInput(
-        model=model_id,
-        model_name=model_name,
-        prompt=test_request.prompt,
-        api_url=api_url,
-        prompt_len=test_request.prompt_len,
-        output_len=test_request.expected_output_len,
-        logprobs=None,
-        multi_modal_content=getattr(test_request, "multi_modal_data", None),
-        ignore_eos=True,
-        extra_body={"temperature": 0.9},
-    )
-
-    test_output = await request_func(request_func_input=test_input, session=session)
-
-    if not getattr(test_output, "success", False):
-        raise ValueError(
-            "Initial test run failed - Please make sure arguments "
-            f"are correctly specified. Error: {getattr(test_output, 'error', '')}"
-        )
-    else:
-        print("Initial test run completed. Starting main run...")
-
-    total = sum(len(req_list) for req_list in req_groups.values())
-    pbar = None if disable_tqdm else tqdm(total=total)
-    semaphore = (
-        asyncio.Semaphore(args.max_concurrency)
-        if getattr(args, "max_concurrency", None)
-        else None
-    )
-    start_time = time.perf_counter()
-    print(f"Start time is {start_time}")
-    tasks = []
-    flat_requests = []
-
-    async def _run_one_request(sample_req):
-        sampling_params = {"temperature": 0.9}
-        req_input = RequestFuncInput(
+    ) as session:
+        test_input = RequestFuncInput(
             model=model_id,
             model_name=model_name,
-            prompt=sample_req.prompt,
+            prompt=test_request.prompt,
             api_url=api_url,
-            prompt_len=sample_req.prompt_len,
-            output_len=sample_req.expected_output_len,
+            prompt_len=test_request.prompt_len,
+            output_len=test_request.expected_output_len,
             logprobs=None,
-            extra_body=sampling_params,
+            multi_modal_content=getattr(test_request, "multi_modal_data", None),
             ignore_eos=True,
+            extra_body={"temperature": 0.9},
+        )
+
+        test_output = await request_func(request_func_input=test_input, session=session)
+
+        if not getattr(test_output, "success", False):
+            raise ValueError(
+                "Initial test run failed - Please make sure arguments "
+                f"are correctly specified. Error: {getattr(test_output, 'error', '')}"
+            )
+        else:
+            print("Initial test run completed. Starting main run...")
+
+        total = sum(len(req_list) for req_list in req_groups.values())
+        pbar = None if disable_tqdm else tqdm(total=total)
+        semaphore = (
+            asyncio.Semaphore(args.max_concurrency)
+            if getattr(args, "max_concurrency", None)
+            else None
         )
-        if semaphore is not None:
-            async with semaphore:
+        start_time = time.perf_counter()
+        print(f"Start time is {start_time}")
+        tasks = []
+        flat_requests = []
+
+        async def _run_one_request(sample_req):
+            sampling_params = {"temperature": 0.9}
+            req_input = RequestFuncInput(
+                model=model_id,
+                model_name=model_name,
+                prompt=sample_req.prompt,
+                api_url=api_url,
+                prompt_len=sample_req.prompt_len,
+                output_len=sample_req.expected_output_len,
+                logprobs=None,
+                extra_body=sampling_params,
+                ignore_eos=True,
+            )
+            if semaphore is not None:
+                async with semaphore:
+                    return await request_func(
+                        request_func_input=req_input, session=session, pbar=pbar
+                    )
+            else:
                 return await request_func(
                     request_func_input=req_input, session=session, pbar=pbar
                 )
-        else:
-            return await request_func(
-                request_func_input=req_input, session=session, pbar=pbar
-            )
-
-    for sec, reqs in sorted(req_groups.items()):
-        delay = sec - (time.perf_counter() - start_time)
-        delay = max(0, delay)

-        async def send_group(r=reqs, d=delay):
-            await asyncio.sleep(d)
-            print(
-                f"Sending request at {time.perf_counter() - start_time:.3f}s with {len(r)} reqs"
-            )
-            group_tasks = [asyncio.create_task(_run_one_request(req)) for req in r]
-            try:
-                return await asyncio.gather(*group_tasks)
-            except asyncio.TimeoutError:
-                print(f"Request timed out: group at delay {d:.3f}s")
-                return []
-            except Exception as e:
-                print(f"Request failed: {e}")
-                return []
-
-        tasks.append(asyncio.create_task(send_group(reqs, delay)))
-        flat_requests.extend(reqs)
-
-    group_results = await asyncio.gather(*tasks)
-    outputs = []
-    for res in group_results:
-        if isinstance(res, list):
-            outputs.extend(res)
-
-    if pbar is not None:
-        pbar.close()
-
-    benchmark_duration = time.perf_counter() - start_time
-    metrics, actual_output_lens = calculate_metrics(
-        input_requests=flat_requests,
-        outputs=outputs,
-        dur_s=benchmark_duration,
-        tokenizer=tokenizer,
-        selected_percentiles=[25.0, 50.0, 75.0, 99.0],
-        goodput_config_dict={"ttft": 2000, "tpot": 50},
-    )
+        for sec, reqs in sorted(req_groups.items()):
+            delay = sec - (time.perf_counter() - start_time)
+            delay = max(0, delay)

-    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
-    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
-    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
-    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
-    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
-    print(
-        "{:<40} {:<10.2f}".format(
-            "Request throughput (req/s):", metrics.request_throughput
-        )
-    )
-    print(
-        "{:<40} {:<10.2f}".format(
-            "Output token throughput (tok/s):", metrics.output_throughput
-        )
-    )
-    print(
-        "{:<40} {:<10.2f}".format(
-            "Total Token throughput (tok/s):", metrics.total_token_throughput
+            async def send_group(r=reqs, d=delay):
+                await asyncio.sleep(d)
+                print(
+                    f"Sending request at {time.perf_counter() - start_time:.3f}s with {len(r)} reqs"
+                )
+                group_tasks = [asyncio.create_task(_run_one_request(req)) for req in r]
+                try:
+                    return await asyncio.gather(*group_tasks)
+                except asyncio.TimeoutError:
+                    print(f"Request timed out: group at delay {d:.3f}s")
+                    return []
+                except Exception as e:
+                    print(f"Request failed: {e}")
+                    return []
+
+            tasks.append(asyncio.create_task(send_group(reqs, delay)))
+            flat_requests.extend(reqs)
+
+        group_results = await asyncio.gather(*tasks)
+        outputs = []
+        for res in group_results:
+            if isinstance(res, list):
+                outputs.extend(res)
+
+        if pbar is not None:
+            pbar.close()
+
+        benchmark_duration = time.perf_counter() - start_time
+        metrics, actual_output_lens = calculate_metrics(
+            input_requests=flat_requests,
+            outputs=outputs,
+            dur_s=benchmark_duration,
+            tokenizer=tokenizer,
+            selected_percentiles=[25.0, 50.0, 75.0, 99.0],
+            goodput_config_dict={"ttft": 2000, "tpot": 50},
         )
-    )

-    # Define the process_one_metric function, which can access the outer scope's selected_percentile_metrics
-    def process_one_metric(
-        metric_attribute_name: str,
-        metric_name: str,
-        metric_header: str,
-    ):
-        selected_percentile_metrics = ["ttft", "tpot", "itl", "e2el"]
-        if metric_attribute_name not in selected_percentile_metrics:
-            return
-        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+        print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+        print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+        print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
+        print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+        print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
         print(
             "{:<40} {:<10.2f}".format(
-                f"Mean {metric_name} (ms):",
-                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+                "Request throughput (req/s):", metrics.request_throughput
             )
         )
         print(
             "{:<40} {:<10.2f}".format(
-                f"Median {metric_name} (ms):",
-                getattr(metrics, f"median_{metric_attribute_name}_ms"),
+                "Output token throughput (tok/s):", metrics.output_throughput
             )
         )
-        # standard deviation
         print(
             "{:<40} {:<10.2f}".format(
-                f"Std {metric_name} (ms):",
-                getattr(metrics, f"std_{metric_attribute_name}_ms"),
+                "Total Token throughput (tok/s):", metrics.total_token_throughput
             )
         )
-        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
-            p_word = str(int(p)) if int(p) == p else str(p)
-            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
-
-    process_one_metric("ttft", "TTFT", "Time to First Token")
-    process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
-    process_one_metric("itl", "ITL", "Inter-token Latency")
-    process_one_metric("e2el", "E2EL", "End-to-end Latency")
-    print("=" * 50)
-
-    output_dir = args.result_dir if args.result_dir is not None else "./"
-    if args.save_result:
-        save_metrics_to_file(metrics=metrics, output_dir=output_dir)
-        save_req_results_to_file(outputs=outputs, output_dir=output_dir)
+
+        # Define the process_one_metric function, which can access the outer scope's selected_percentile_metrics
+        def process_one_metric(
+            metric_attribute_name: str,
+            metric_name: str,
+            metric_header: str,
+        ):
+            selected_percentile_metrics = ["ttft", "tpot", "itl", "e2el"]
+            if metric_attribute_name not in selected_percentile_metrics:
+                return
+            print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+            print(
+                "{:<40} {:<10.2f}".format(
+                    f"Mean {metric_name} (ms):",
+                    getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+                )
+            )
+            print(
+                "{:<40} {:<10.2f}".format(
+                    f"Median {metric_name} (ms):",
+                    getattr(metrics, f"median_{metric_attribute_name}_ms"),
+                )
+            )
+            # standard deviation
+            print(
+                "{:<40} {:<10.2f}".format(
+                    f"Std {metric_name} (ms):",
+                    getattr(metrics, f"std_{metric_attribute_name}_ms"),
+                )
+            )
+            for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
+                p_word = str(int(p)) if int(p) == p else str(p)
+                print(
+                    "{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)
+                )
+
+        process_one_metric("ttft", "TTFT", "Time to First Token")
+        process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
+        process_one_metric("itl", "ITL", "Inter-token Latency")
+        process_one_metric("e2el", "E2EL", "End-to-end Latency")
+        print("=" * 50)
+
+        output_dir = args.result_dir if args.result_dir is not None else "./"
+        if args.save_result:
+            save_metrics_to_file(metrics=metrics, output_dir=output_dir)
+            save_req_results_to_file(outputs=outputs, output_dir=output_dir)
     return

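The loop above groups requests by their second offset in the trace, sleeps each group until its offset, and fires that group's requests concurrently. A stripped-down, standalone sketch of that scheduling pattern (handle() and the toy trace below are illustrative stand-ins for the real request function and trace data):

import asyncio
import time


async def handle(item: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for the real HTTP request
    return item


async def replay_groups(groups: dict) -> list:
    # groups maps "seconds since trace start" -> list of requests for that second.
    start = time.perf_counter()

    async def send_group(sec, items):
        # Sleep until this group's offset from the replay start, then fire all
        # of its requests concurrently.
        await asyncio.sleep(max(0, sec - (time.perf_counter() - start)))
        return await asyncio.gather(*(handle(i) for i in items))

    results = await asyncio.gather(
        *(send_group(sec, items) for sec, items in sorted(groups.items()))
    )
    return [out for group in results for out in group]


if __name__ == "__main__":
    print(asyncio.run(replay_groups({0: ["a", "b"], 1: ["c"]})))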
@@ -678,6 +679,14 @@ def main(args: argparse.Namespace):


 if __name__ == "__main__":
+    # Check openpyxl for Excel export
+    try:
+        import openpyxl
+    except ImportError:
+        print("\nMissing package: openpyxl")
+        print("Please install openpyxl via pip install.\n")
+        sys.exit(1)
+
     parser = create_argument_trace()
     args = parser.parse_args()
     main(args)
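
The second hunk guards the script's entry point on openpyxl, which the result-export path needs. A standalone sketch of how that guard behaves outside the script (the final status line is illustrative; note that sys.exit(1) assumes trace_replay.py already imports sys, which this diff does not show):

import sys

try:
    import openpyxl  # only needed when exporting results to Excel
except ImportError:
    print("\nMissing package: openpyxl")
    print("Please install openpyxl via pip install.\n")
    sys.exit(1)

print(f"openpyxl {openpyxl.__version__} detected; Excel export is available.")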
