11# yapf: disable
22import asyncio
33from abc import ABC , abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type, Union
55
66import aiohttp
77
1212 CompletionResponse ,
1313 UCompletionRequest ,
1414 UCompletionResponse )
15- from tensorrt_llm .serve .perf_metrics import DisaggPerfMetricsCollector
16- from tensorrt_llm .serve .responses_utils import (ResponseHooks ,
15+ from tensorrt_llm .serve .perf_metrics import (ClientMetricsCollector ,
16+ DisaggPerfMetricsCollector )
17+ from tensorrt_llm .serve .responses_utils import (CompletionResponseGenerator ,
18+ ResponseHooks ,
1719 get_steady_clock_now_in_seconds )
1820from tensorrt_llm .serve .router import Router
1921
2022# yapf: enable
2123
22- CompletionResponseGenerator = AsyncGenerator [bytes , None ]
23-
2424
2525class OpenAIClient (ABC ):
2626
@@ -29,7 +29,7 @@ async def send_request(
2929 server : str ,
3030 request : UCompletionRequest ,
3131 hooks : Optional [ResponseHooks ] = None
32- ) -> Tuple [UCompletionResponse , AsyncGenerator [ bytes , None ] ]:
32+ ) -> Union [UCompletionResponse , CompletionResponseGenerator ]:
3333 if isinstance (request , CompletionRequest ):
3434 return await self ._send_request (server , "v1/completions" , request ,
3535 CompletionResponse , hooks )
@@ -48,7 +48,7 @@ async def _send_request(
4848 request : UCompletionRequest ,
4949 response_type : Type [UCompletionResponse ],
5050 hooks : Optional [ResponseHooks ] = None ,
51- ) -> Tuple [UCompletionResponse , AsyncGenerator [ bytes , None ] ]:
51+ ) -> Union [UCompletionResponse , CompletionResponseGenerator ]:
5252 """
5353 Send a request to the server and return either the parsed response (non-streaming) or a generator over the response body (streaming).
5454 The request is finished (in routers) when the generator is exhausted or there is an error.
@@ -83,16 +83,19 @@ def __init__(self,
8383 router : Router ,
8484 client_type : str ,
8585 timeout_secs : int = 180 ,
86+ max_retries : int = 1 ,
8687 perf_metrics_collector : DisaggPerfMetricsCollector = None ):
8788 assert client_type in ["ctx" , "gen" ]
8889 self ._router = router
8990 self ._client_type = client_type
90- self ._perf_metrics_collector = perf_metrics_collector
91+ self ._metrics_collector = ClientMetricsCollector ( client_type )
9192 self ._session = aiohttp .ClientSession (
9293 connector = aiohttp .TCPConnector (limit = 0 ,
9394 limit_per_host = 0 ,
9495 force_close = False ),
9596 timeout = aiohttp .ClientTimeout (total = timeout_secs ))
97+ self ._max_retries = max_retries
98+ self ._retry_interval = 1
9699
async def _send_request(
    self,
    server: str,
    endpoint: str,
    request: UCompletionRequest,
    response_type: Type[UCompletionResponse],
    hooks: Optional[ResponseHooks] = None,
) -> Union[UCompletionResponse, CompletionResponseGenerator]:
    """Send `request` to `server` and return its response.

    Args:
        server: Address used to build the URL. When empty, a server is
            requested from the router.
        endpoint: Path appended to ``http://{server}/``.
        request: The completion request to post.
        response_type: Class used to build the response object from the
            decoded JSON body (non-streaming requests only).
        hooks: Optional callbacks. For non-streaming requests they are
            invoked here; `on_ctx_resp` is only called when the client type
            is "ctx".

    Returns:
        For streaming requests, the async generator produced by
        `_post_with_retry` (nothing is sent until it is iterated).
        For non-streaming requests, a `response_type` instance.

    Raises:
        Any exception raised while posting or building the response; the
        error counter is incremented and the request is finished first.
    """
    if len(server) == 0:
        server, _ = await self._router.get_next_server(request)
    url = f"http://{server}/{endpoint}"
    try:
        self._metrics_collector.inc("total_requests")
        resp_generator = self._post_with_retry(server, url, request, hooks)
        if request.stream:
            # Errors for streaming requests surface lazily, while the
            # caller iterates the generator.
            return resp_generator

        # Non-streaming: the generator yields the decoded body exactly once.
        try:
            resp_json = await anext(resp_generator)
        finally:
            # Close explicitly so the HTTP response context held open inside
            # the generator is exited now instead of at garbage collection.
            await resp_generator.aclose()
        response = response_type(**resp_json)
        if hooks:
            if self._client_type == "ctx":
                hooks.on_ctx_resp(server, response)
            hooks.on_first_token(server, request)
            hooks.on_resp_done(server, request, response)
    except Exception:
        self._metrics_collector.inc("error_requests")
        await self._finish_request(request)
        raise
    # The streaming path finishes the request when its generator ends and the
    # error path finishes it above; do the same for a successful
    # non-streaming response. Kept outside the try so a failure here cannot
    # trigger a second finish.
    await self._finish_request(request)
    return response
134130
async def _post_with_retry(
    self,
    server: str,
    url: str,
    request: UCompletionRequest,
    hooks: Optional[ResponseHooks] = None
) -> AsyncGenerator[Union[bytes, Dict[str, Any]], None]:
    """POST `request` to `url` and yield the result, retrying failed attempts.

    Up to ``self._max_retries + 1`` attempts are made, sleeping
    ``self._retry_interval`` seconds between them.

    Args:
        server: Forwarded to `_response_generator` for streaming requests.
        url: Full URL to post to.
        request: The completion request; its `model_dump(exclude_unset=True)`
            is sent as the JSON body.
        hooks: Optional callbacks forwarded to `_response_generator`.

    Yields:
        Streaming requests: every item produced by `_response_generator`.
        Non-streaming requests: the decoded JSON body, exactly once.

    Raises:
        ValueError: an event-stream was received for a non-streaming request.
        aiohttp.ClientError, OSError: retries are exhausted, or the failure
            happened after streaming of the response had started.
        Exception: anything else is logged and re-raised without retry.
    """
    import traceback

    json_data = request.model_dump(exclude_unset=True)
    is_stream = request.stream
    for attempt in range(self._max_retries + 1):
        # Set once the response generator has been started for this attempt.
        streaming_started = False
        try:
            start_time = get_steady_clock_now_in_seconds()
            async with self._session.post(url,
                                          json=json_data) as http_response:
                content_type = http_response.headers.get("Content-Type", "")
                if not is_stream and "text/event-stream" in content_type:
                    raise ValueError(
                        "Received an event-stream although request stream was False"
                    )
                if is_stream:
                    # Do NOT return the generator directly: the response
                    # would leave this `async with` and be released.
                    streaming_started = True
                    async for line in self._response_generator(
                            request, http_response, start_time, hooks,
                            server):
                        yield line
                else:
                    http_response.raise_for_status()
                    response_dict = await http_response.json()
                    # An async generator cannot `return` a value, so the
                    # body is handed over with a single yield.
                    yield response_dict
            # The attempt succeeded: stop here so that resuming the generator
            # never re-sends the request.
            return
        except (aiohttp.ClientError, OSError) as e:
            # Once streaming has started the consumer may already have
            # received data and the response generator has already finished
            # the request, so a retry would duplicate output.
            if streaming_started or attempt == self._max_retries:
                raise
            logger.error(
                f"Client error: {e} - retry {attempt} of {self._max_retries}",
                traceback.format_exc())
            await asyncio.sleep(self._retry_interval)
            self._metrics_collector.inc("retry_requests")
        except Exception as e:
            logger.error(
                f"Error encountered while processing request to {url}: {e}")
            raise
174+
async def _response_generator(
        self,
        request: UCompletionRequest,
        http_response: aiohttp.ClientResponse,
        start_time: float,
        hooks: Optional[ResponseHooks] = None,
        server: str = "") -> CompletionResponseGenerator:
    """Relay a streaming HTTP response chunk by chunk, recording metrics.

    Only valid for streaming requests whose response carries a
    "text/event-stream" content type; both conditions are asserted.

    Args:
        request: The streaming request this response belongs to.
        http_response: The open aiohttp response whose content is relayed.
        start_time: Steady-clock timestamp taken before the request was
            posted; used as the baseline for latency metrics.
        hooks: Optional callbacks; `on_first_token` fires on the first
            chunk and `on_resp_done` after the stream ends.
        server: Server address passed through to the hooks.

    Yields:
        Every non-empty chunk read from the response body.

    The request is always finished (via `_finish_request`) when this
    generator ends, whether normally or with an error.
    """
    assert request.stream, "Request is not streaming"
    assert "text/event-stream" in http_response.headers.get(
        "Content-Type", ""), "Response is not streaming"
    try:
        previous_time = start_time
        seen_first_chunk = False
        async for chunk in http_response.content.iter_any():
            current_time = get_steady_clock_now_in_seconds()
            elapsed = current_time - previous_time
            if seen_first_chunk:
                self._metrics_collector.observe("per_token_latency_seconds",
                                                elapsed)
            else:
                if hooks:
                    hooks.on_first_token(server, request)
                self._metrics_collector.observe(
                    "first_token_latency_seconds", elapsed)
                seen_first_chunk = True
            if chunk:
                yield chunk
                # Give other tasks a chance to run between chunks.
                await asyncio.sleep(0)
            previous_time = current_time

        # The stream ended normally.
        if hooks:
            hooks.on_resp_done(server, request, None)
        self._metrics_collector.inc("completed_requests")
        total_elapsed = get_steady_clock_now_in_seconds() - start_time
        self._metrics_collector.observe("complete_latency_seconds",
                                        total_elapsed)
    except aiohttp.ClientError as e:
        # Original author's note: a client error is expected when the
        # response stream is done if the connector has close=True.
        logger.error(f"{self._client_type} Client error: {e}")
        self._metrics_collector.inc("error_requests")
        raise
    except Exception:
        self._metrics_collector.inc("error_requests")
        raise
    finally:
        await self._finish_request(request)
0 commit comments