-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path main.py
111 lines (92 loc) · 3.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import multiprocessing
import os
import queue

from tracestorm.logger import init_logger
from tracestorm.request_generator import generate_request
from tracestorm.result_analyzer import ResultAnalyzer
from tracestorm.trace_generator import generate_trace
from tracestorm.trace_player import play
from tracestorm.utils import round_robin_shard
# Module-wide logger, created through the project's tracestorm.logger helper.
logger = init_logger(__name__)
def get_args():
parser = argparse.ArgumentParser(
description="Run a replay of OpenAI requests."
)
parser.add_argument("--model", required=True, help="Model name")
parser.add_argument(
"--rps", type=int, default=1, help="Requests per second"
)
parser.add_argument(
"--pattern", default="uniform", help="Pattern for generating trace"
)
parser.add_argument(
"--duration", type=int, default=10, help="Duration in seconds"
)
parser.add_argument(
"--subprocesses", type=int, default=1, help="Number of subprocesses"
)
parser.add_argument(
"--base-url",
default=os.environ.get("OPENAI_BASE_URL", "http://localhost:8000/v1"),
help="OpenAI Base URL",
)
parser.add_argument(
"--api-key",
default=os.environ.get("OPENAI_API_KEY", "none"),
help="OpenAI API Key",
)
return parser.parse_args()
def _launch_players(raw_trace, requests, args, ipc_queue):
    """Start one ``play`` subprocess per shard and return the started processes.

    The trace and requests are split by ``round_robin_shard`` into
    ``args.subprocesses`` shards. Each player receives its shard, the target
    ``base_url``/``api_key`` and *ipc_queue*, on which it reports results.
    """
    processes = []
    for i, (partial_trace, partial_requests) in enumerate(
        round_robin_shard(raw_trace, requests, args.subprocesses), start=1
    ):
        player_name = f"TracePlayer-{i}"
        p = multiprocessing.Process(
            target=play,
            name=player_name,
            args=(
                player_name,
                partial_trace,
                partial_requests,
                args.base_url,
                args.api_key,
                ipc_queue,
            ),
        )
        p.start()
        processes.append(p)
    return processes


def _collect_results(ipc_queue, expected, timeout=30):
    """Read up to *expected* ``(name, timestamp, resp)`` tuples from *ipc_queue*.

    Stops early, logging why, if no result arrives within *timeout* seconds
    or a read from the queue fails. Returns whatever was collected.
    """
    aggregated_results = []
    while len(aggregated_results) < expected:
        # Only the queue read is guarded, so a problem while logging a result
        # can no longer discard that result and abort collection.
        try:
            name, timestamp, resp = ipc_queue.get(timeout=timeout)
        except queue.Empty:
            logger.error(
                f"Timeout after {timeout}s waiting for results on IPC queue"
            )
            break
        except Exception as e:
            logger.error(
                f"Error reading from IPC queue: {e}", exc_info=True
            )
            break
        aggregated_results.append((name, timestamp, resp))
        # resp is produced by play(); presumably a dict with "token_count",
        # but tolerate responses (e.g. failures) that lack it.
        try:
            token_count = resp["token_count"]
        except (KeyError, TypeError):
            token_count = "unknown"
        logger.info(
            f"Received result from {name} for timestamp {timestamp}: {token_count} tokens"
        )
    if len(aggregated_results) < expected:
        logger.warning(
            f"Collected {len(aggregated_results)} of {expected} expected results."
        )
    return aggregated_results


def _join_players(processes, timeout=30):
    """Wait for each player to exit, terminating any still alive after *timeout* s."""
    for p in processes:
        p.join(timeout)
        if p.is_alive():
            # A hung player, or one blocked flushing queue data that will never
            # be read (see "Joining processes that use queues" in the
            # multiprocessing docs), would make a plain join() block forever.
            # The queue is not used after this point, so terminating is safe.
            logger.warning(
                f"{p.name} did not exit within {timeout}s; terminating it."
            )
            p.terminate()
            p.join()
        elif p.exitcode != 0:
            logger.warning(f"{p.name} exited with code {p.exitcode}.")


def main():
    """Replay a generated request trace against an OpenAI-compatible server.

    Parses CLI arguments, builds a trace with ``generate_trace`` and the
    matching requests with ``generate_request``, and shards them across
    ``--subprocesses`` ``play`` workers. It then gathers their results from
    a shared queue and hands them to ``ResultAnalyzer`` for printing and a
    CDF plot. Returns early, with a warning, when the trace is empty.
    """
    args = get_args()

    raw_trace = generate_trace(args.rps, args.pattern, args.duration)
    total_requests = len(raw_trace)
    logger.debug(f"Raw trace: {raw_trace}")

    requests = generate_request(args.model, total_requests)
    logger.debug(f"Requests: {requests}")

    if total_requests == 0:
        logger.warning("No requests to process. Trace is empty.")
        return

    ipc_queue = multiprocessing.Queue()
    processes = _launch_players(raw_trace, requests, args, ipc_queue)

    aggregated_results = _collect_results(ipc_queue, total_requests)

    _join_players(processes)
    logger.info("All subprocesses have finished.")
    logger.debug(f"Aggregated results: {aggregated_results}")

    result_analyzer = ResultAnalyzer()
    result_analyzer.store_raw_results(aggregated_results)
    print(result_analyzer)
    result_analyzer.plot_cdf()
# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()