@@ -26,6 +26,7 @@
 import aiohttp
 import numpy as np
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
+from tqdm.asyncio import tqdm
 
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
@@ -167,22 +168,32 @@ def to_openai_role(role_value: str) -> str:
 async def get_request(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
+    concurrency: int = None,
 ) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]:
     input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request
 
-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+    if concurrency is not None:
+        # Concurrency-based request generation.
+        # This generator will be consumed by the benchmark function,
+        # which will manage the concurrency.
+        for request in input_requests:
+            yield request
+    else:
+        # Rate-based request generation (original logic).
+        for request in input_requests:
+            yield request
+
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)
 
 
 async def send_request(
-    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool
+    messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None
 ) -> None:
     if use_openai_api:
         # Use OpenAI API to send the request.
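For intuition on the rate-based branch that the patch preserves: drawing inter-arrival gaps from an exponential distribution with mean 1/λ makes the arrivals a Poisson process averaging λ requests per second. A standalone sketch, not part of the patch (rate and sample count are illustrative):

```python
import numpy as np

request_rate = 4.0  # target λ, in requests per second
rng = np.random.default_rng(0)

# Inter-arrival gaps of a Poisson process are exponential with mean 1/λ.
gaps = rng.exponential(1.0 / request_rate, size=100_000)
print(f"mean gap: {gaps.mean():.4f} s (expected {1.0 / request_rate:.4f} s)")
print(f"effective rate: {1.0 / gaps.mean():.2f} req/s (expected {request_rate:.2f})")
```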
@@ -216,7 +227,7 @@ async def send_request(
                 if is_first:
                     is_first = False
                     ttft = delta_time
-                text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
+                # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
                 if delta_time < 0.005:
                     receive_n += 1
                 chunks.append(delta_time)
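Context for the commented-out line: each streamed chunk is a server-sent event whose JSON payload follows a `data: ` prefix, which is what the `[6:]` slice strips before `json.loads`. Disabling the decode keeps the per-chunk timing loop free of JSON-parsing overhead (my reading of the intent; the patch does not say so). A minimal illustration of the slicing, with a made-up payload:

```python
import json

# A typical SSE line from an OpenAI-compatible streaming endpoint.
chunk = b'data: {"choices": [{"delta": {"content": "Hi"}}]}'

payload = chunk.decode("utf-8")[6:]  # drop the 6-character "data: " prefix
delta = json.loads(payload)["choices"][0]["delta"]
print(delta.get("content", ""))  # -> Hi
```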
@@ -261,18 +272,50 @@ async def send_request(
     request_latency = request_end_time - request_start_time
     REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft))
 
+    # Update progress bar if provided
+    if pbar:
+        pbar.update(1)
+
 
 async def benchmark(
     input_requests: List[Tuple[List[dict], str, int, int]],
     request_rate: float,
     use_openai_api: bool = False,
+    concurrency: int = None,
 ) -> None:
-    tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
-        messages, rendered_prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api))
-        tasks.append(task)
-    await asyncio.gather(*tasks)
+    total_requests = len(input_requests)
+
+    # Create progress bar
+    pbar = tqdm(total=total_requests, desc="Processing requests", unit="req")
+
+    if concurrency is not None:
+        # Concurrency-based processing
+        semaphore = asyncio.Semaphore(concurrency)
+        tasks: List[asyncio.Task] = []
+
+        async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len):
+            async with semaphore:
+                await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len))
+            tasks.append(task)
+
+        await asyncio.gather(*tasks)
+    else:
+        # Rate-based processing (original logic)
+        tasks: List[asyncio.Task] = []
+        async for request in get_request(input_requests, request_rate, concurrency):
+            messages, rendered_prompt, prompt_len, output_len = request
+            task = asyncio.create_task(
+                send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar)
+            )
+            tasks.append(task)
+        await asyncio.gather(*tasks)
+
+    # Close progress bar
+    pbar.close()
 
 
 def main(args: argparse.Namespace):
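The concurrency branch above creates every task eagerly but gates the request bodies behind an `asyncio.Semaphore`, so at most `concurrency` requests are in flight at once: a closed-loop benchmark, versus the open-loop Poisson mode. The pattern in isolation, with a sleep standing in for `send_request` (a sketch, not the patch's code):

```python
import asyncio

async def bounded_call(sem: asyncio.Semaphore, i: int) -> None:
    async with sem:  # only `limit` bodies may run at the same time
        await asyncio.sleep(0.1)  # stand-in for one send_request(...) call

async def run_all(limit: int, n: int) -> None:
    sem = asyncio.Semaphore(limit)
    await asyncio.gather(*(bounded_call(sem, i) for i in range(n)))

asyncio.run(run_all(limit=8, n=100))  # ceil(100 / 8) * 0.1 s ≈ 1.3 s wall time
```

Creating all tasks upfront is fine at benchmark scale; `tqdm.gather` from `tqdm.asyncio` would be a drop-in alternative to the manual `pbar.update(1)` calls.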
@@ -285,7 +328,7 @@ def main(args: argparse.Namespace):
     )
 
     benchmark_start_time = time.time()
-    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api))
+    asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency))
     benchmark_end_time = time.time()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -325,11 +368,22 @@ def main(args: argparse.Namespace):
         "Otherwise, we use Poisson process to synthesize "
         "the request arrival times.",
     )
+    parser.add_argument(
+        "--concurrency",
+        type=int,
+        default=None,
+        help="Number of concurrent requests to maintain. Cannot be used together with --request-rate.",
+    )
     parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.")
     parser.add_argument(
         "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply."
     )
     parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).")
     parser.add_argument("--seed", type=int, default=0)
     args = parser.parse_args()
+
+    # Validate that only one of request_rate or concurrency is set
+    if args.concurrency is not None and args.request_rate != float("inf"):
+        raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.")
+
     main(args)
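Because `--request-rate` defaults to `inf`, the guard only trips when a finite rate and a concurrency are both given explicitly. A standalone check of that behavior (the parser here is a trimmed-down stand-in for the script's real one, and the argv values are hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--request-rate", type=float, default=float("inf"))
parser.add_argument("--concurrency", type=int, default=None)

def validate(args: argparse.Namespace) -> None:
    # Mirrors the patch's post-parse check.
    if args.concurrency is not None and args.request_rate != float("inf"):
        raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.")

validate(parser.parse_args(["--concurrency", "32"]))  # OK: closed-loop mode
validate(parser.parse_args(["--request-rate", "4"]))  # OK: open-loop mode
# validate(parser.parse_args(["--request-rate", "4", "--concurrency", "32"]))  # raises ValueError
```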