2020
2121import aiohttp
2222
23+ from tensorrt_llm .llmapi .disagg_utils import ServerRole
2324from tensorrt_llm .logger import logger
2425from tensorrt_llm .serve .openai_protocol import (
2526 ChatCompletionRequest ,
4243
4344class OpenAIClient (ABC ):
4445 async def send_request (
45- self , server : str , request : UCompletionRequest , hooks : Optional [ResponseHooks ] = None
46+ self ,
47+ request : UCompletionRequest ,
48+ server : Optional [str ] = None ,
49+ hooks : Optional [ResponseHooks ] = None ,
4650 ) -> UCompletionResponseOrGenerator :
4751 if isinstance (request , CompletionRequest ):
4852 return await self ._send_request (
49- server , "v1/completions" , request , CompletionResponse , hooks
53+ "v1/completions" , request , CompletionResponse , server , hooks
5054 )
5155 elif isinstance (request , ChatCompletionRequest ):
5256 return await self ._send_request (
53- server , "v1/chat/completions" , request , ChatCompletionResponse , hooks
57+ "v1/chat/completions" , request , ChatCompletionResponse , server , hooks
5458 )
5559 else :
5660 raise ValueError (f"Invalid request type: { type (request )} " )
5761
5862 @abstractmethod
5963 async def _send_request (
6064 self ,
61- server : str ,
6265 endpoint : str ,
6366 request : UCompletionRequest ,
6467 response_type : Type [UCompletionResponse ],
68+ server : Optional [str ] = None ,
6569 hooks : Optional [ResponseHooks ] = None ,
6670 ) -> UCompletionResponseOrGenerator :
6771 """Send a request to the server and return the response and the body generator.
@@ -90,55 +94,58 @@ class OpenAIHttpClient(OpenAIClient):
9094 def __init__ (
9195 self ,
9296 router : Router ,
93- client_type : str ,
97+ role : ServerRole ,
9498 timeout_secs : int = 180 ,
9599 max_retries : int = 1 ,
100+ retry_interval_sec : int = 1 ,
96101 session : Optional [aiohttp .ClientSession ] = None ,
97102 ):
98- assert client_type in ["ctx" , "gen" ]
99103 self ._router = router
100- self ._client_type = client_type
101- self ._metrics_collector = ClientMetricsCollector (client_type )
104+ self ._role = role
105+ self ._metrics_collector = ClientMetricsCollector (role )
102106 self ._session = session or aiohttp .ClientSession (
103107 connector = aiohttp .TCPConnector (limit = 0 , limit_per_host = 0 , force_close = False ),
104108 timeout = aiohttp .ClientTimeout (total = timeout_secs ),
105109 )
106110 self ._max_retries = max_retries
107- self ._retry_interval = 1
111+ self ._retry_interval_sec = retry_interval_sec
108112
109113 async def _send_request (
110114 self ,
111- server : str ,
112115 endpoint : str ,
113116 request : UCompletionRequest ,
114117 response_type : Type [UCompletionResponse ],
118+ server : Optional [str ] = None ,
115119 hooks : Optional [ResponseHooks ] = None ,
116120 ) -> UCompletionResponseOrGenerator :
117- if len ( server ) == 0 :
121+ if server is None :
118122 server , _ = await self ._router .get_next_server (request )
119123 url = f"http://{ server } /{ endpoint } "
120124 logger .debug (
121- f"Sending { self ._client_type } request { request .disaggregated_params .ctx_request_id } to { url } "
125+ f"Sending { self ._role } request { request .disaggregated_params .ctx_request_id } to { url } "
122126 )
123127 try :
124- self ._metrics_collector .inc ("total_requests" )
128+ self ._metrics_collector .total_requests . inc ()
125129 resp_generator = self ._post_with_retry (server , url , request , hooks )
126130 if request .stream :
131+ # return the response generator, the request is not done yet
127132 return resp_generator
128133 else :
129134 # consume the generator to get the response and return it directly when it's not streaming
130135 response = None
131136 async for resp_json in resp_generator :
132137 response = response_type (** resp_json )
133138 if hooks :
134- if self ._client_type == "ctx" :
139+ if self ._role == ServerRole . CONTEXT :
135140 hooks .on_ctx_resp (server , response )
136141 else :
137142 hooks .on_first_token (server , request )
138143 hooks .on_resp_done (server , request , response )
139144 return response
140145 except Exception :
141- self ._metrics_collector .inc ("error_requests" )
146+ self ._metrics_collector .error_requests .inc ()
147+ # finish the request upon error
148+ await self ._finish_request (request )
142149 raise
143150
144151 async def _post_with_retry (
@@ -163,45 +170,45 @@ async def _post_with_retry(
163170 # do NOT return generator directly here or the response will go
164171 # out of scope and get destroyed
165172 async for line in self ._response_generator (
166- request , http_response , start_time , hooks , server
173+ request , http_response , start_time , server , hooks
167174 ):
168175 yield line
176+ # don't finish the request here since the response generator is not done yet
169177 else :
170178 http_response .raise_for_status ()
171179 response_dict = await http_response .json ()
172180 # yield here since python forbids return statements in async generators
173181 yield response_dict
182+ # finish the request after the successful response
183+ await self ._finish_request (request )
174184 break # break and skip retries if the whole response is processed without exception
175185 except (aiohttp .ClientError , OSError ) as e :
176186 if attempt == self ._max_retries :
177187 logger .error (
178- f"{ self . _client_type } client error to { url } : { e } - last retry { attempt } of { self ._max_retries } "
188+ f"Client error to { url } : { e } - last retry { attempt } of { self ._max_retries } "
179189 "failed" ,
180190 traceback .format_exc (),
181191 )
182192 raise
183-
184193 logger .error (
185- f"{ self ._client_type } client error to { url } : { e } - retry { attempt } of { self ._max_retries } " ,
194+ f"{ self ._role } client error to { url } : { e } - retry { attempt } of { self ._max_retries } " ,
186195 traceback .format_exc (),
187196 )
188- await asyncio .sleep (self ._retry_interval )
189- self ._metrics_collector .inc ("retry_requests" )
197+ await asyncio .sleep (self ._retry_interval_sec )
198+ self ._metrics_collector .retry_requests . inc ()
190199 except Exception as e :
191200 logger .error (
192- f"Unexpected error while processing { self ._client_type } request to { url } : { e } "
201+ f"Unexpected error while processing { self ._role } request to { url } : { e } "
193202 )
194203 raise
195- finally :
196- await self ._finish_request (request )
197204
198205 async def _response_generator (
199206 self ,
200207 request : UCompletionRequest ,
201208 http_response : aiohttp .ClientResponse ,
202209 start_time : float ,
210+ server : str ,
203211 hooks : Optional [ResponseHooks ] = None ,
204- server : str = "" ,
205212 ) -> AsyncGenerator [Any , None ]:
206213 assert request .stream , "Request is not streaming"
207214 assert "text/event-stream" in http_response .headers .get ("Content-Type" , "" ), (
@@ -215,12 +222,12 @@ async def _response_generator(
215222 if i == 0 :
216223 if hooks :
217224 hooks .on_first_token (server , request )
218- self ._metrics_collector .observe (
219- "first_token_latency_seconds" , now_time - last_token_time
225+ self ._metrics_collector .first_token_latency_seconds . observe (
226+ now_time - last_token_time
220227 )
221228 else :
222- self ._metrics_collector .observe (
223- "per_token_latency_seconds" , now_time - last_token_time
229+ self ._metrics_collector .per_token_latency_seconds . observe (
230+ now_time - last_token_time
224231 )
225232 i += 1
226233 if line :
@@ -230,20 +237,20 @@ async def _response_generator(
230237
231238 if hooks :
232239 hooks .on_resp_done (server , request , None )
233- self ._metrics_collector .inc ("completed_requests" )
234- self ._metrics_collector .observe (
235- "complete_latency_seconds" ,
236- get_steady_clock_now_in_seconds () - start_time ,
240+ self ._metrics_collector .completed_requests .inc ()
241+ self ._metrics_collector .complete_latency_seconds .observe (
242+ get_steady_clock_now_in_seconds () - start_time
237243 )
238244 except aiohttp .ClientError as e :
239245 # a client error is expected when the response stream is done if the connector has close=True
240- logger .error (f"{ self ._client_type } Client error: { e } " )
241- self ._metrics_collector .inc ("error_requests" )
246+ logger .error (f"{ self ._role } client { server } error: { e } " )
247+ self ._metrics_collector .error_requests . inc ()
242248 raise
243249 except Exception :
244- self ._metrics_collector .inc ("error_requests" )
250+ self ._metrics_collector .error_requests . inc ()
245251 raise
246252 finally :
253+ # finish the request after streaming response is done or error is raised
247254 await self ._finish_request (request )
248255
249256 async def _finish_request (self , request : UCompletionRequest ) -> None :
0 commit comments