1+ # Copyright (c) 2025, NVIDIA CORPORATION.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115# yapf: disable
216import asyncio
17+ import traceback
318from abc import ABC , abstractmethod
4- from typing import Any , Dict , List , Optional , Tuple , Type , Union
19+ from typing import Any , AsyncGenerator , Dict , List , Optional , Tuple , Type
520
621import aiohttp
722
1631)
1732from tensorrt_llm .serve .perf_metrics import ClientMetricsCollector , DisaggPerfMetricsCollector
1833from tensorrt_llm .serve .responses_utils import (
19- CompletionResponseGenerator ,
2034 ResponseHooks ,
35+ UCompletionResponseOrGenerator ,
2136 get_steady_clock_now_in_seconds ,
2237)
2338from tensorrt_llm .serve .router import Router
2843class OpenAIClient (ABC ):
2944 async def send_request (
3045 self , server : str , request : UCompletionRequest , hooks : Optional [ResponseHooks ] = None
31- ) -> Union [ UCompletionResponse , CompletionResponseGenerator ] :
46+ ) -> UCompletionResponseOrGenerator :
3247 if isinstance (request , CompletionRequest ):
3348 return await self ._send_request (
3449 server , "v1/completions" , request , CompletionResponse , hooks
@@ -48,8 +63,9 @@ async def _send_request(
4863 request : UCompletionRequest ,
4964 response_type : Type [UCompletionResponse ],
5065 hooks : Optional [ResponseHooks ] = None ,
51- ) -> Union [UCompletionResponse , CompletionResponseGenerator ]:
52- """Send a request to the server and return the response and the body iterator.
66+ ) -> UCompletionResponseOrGenerator :
67+ """Send a request to the server and return the response and the body generator.
68+
5369 The request is finished (in routers) when the generator is exhausted or there is an error.
5470 """
5571 ...
@@ -59,7 +75,7 @@ async def collect_metrics(self) -> Dict[str, Any]: ...
5975
@abstractmethod
async def check_ready(self) -> Tuple[List[str], List[str]]:
    """Return the list of ready servers and the list of unready servers.

    Returns:
        A two-element tuple of string lists: the ready servers first,
        then the unready servers.
    """
    # Abstract stub: there is no behavior here; implementations supply it.
    ...
6480
async def shutdown(self) -> None:
    """Shut down the client; this default implementation does nothing.

    NOTE(review): presumably overridden by clients that hold resources
    needing cleanup -- confirm against the subclasses.
    """
    ...
@@ -97,28 +113,32 @@ async def _send_request(
97113 request : UCompletionRequest ,
98114 response_type : Type [UCompletionResponse ],
99115 hooks : Optional [ResponseHooks ] = None ,
100- ) -> Union [ UCompletionResponse , CompletionResponseGenerator ] :
116+ ) -> UCompletionResponseOrGenerator :
101117 if len (server ) == 0 :
102118 server , _ = await self ._router .get_next_server (request )
103119 url = f"http://{ server } /{ endpoint } "
120+ logger .debug (
121+ f"Sending { self ._client_type } request { request .disaggregated_params .ctx_request_id } to { url } "
122+ )
104123 try :
105124 self ._metrics_collector .inc ("total_requests" )
106125 resp_generator = self ._post_with_retry (server , url , request , hooks )
107126 if request .stream :
108127 return resp_generator
109128 else :
110129 # consume the generator to get the response and return it directly when it's not streaming
111- resp_json = await anext (resp_generator )
112- response = response_type (** resp_json )
113- if hooks :
114- if self ._client_type == "ctx" :
115- hooks .on_ctx_resp (server , response )
116- hooks .on_first_token (server , request )
117- hooks .on_resp_done (server , request , response )
130+ response = None
131+ async for resp_json in resp_generator :
132+ response = response_type (** resp_json )
133+ if hooks :
134+ if self ._client_type == "ctx" :
135+ hooks .on_ctx_resp (server , response )
136+ else :
137+ hooks .on_first_token (server , request )
138+ hooks .on_resp_done (server , request , response )
118139 return response
119140 except Exception :
120141 self ._metrics_collector .inc ("error_requests" )
121- await self ._finish_request (request )
122142 raise
123143
124144 async def _post_with_retry (
@@ -127,7 +147,7 @@ async def _post_with_retry(
127147 url : str ,
128148 request : UCompletionRequest ,
129149 hooks : Optional [ResponseHooks ] = None ,
130- ) -> Tuple [ aiohttp . ClientResponse , Dict [ str , Any ] ]:
150+ ) -> AsyncGenerator [ Any , None ]:
131151 json_data = request .model_dump (exclude_unset = True )
132152 is_stream = request .stream
133153 for attempt in range (self ._max_retries + 1 ):
@@ -149,22 +169,31 @@ async def _post_with_retry(
149169 else :
150170 http_response .raise_for_status ()
151171 response_dict = await http_response .json ()
152- # do yield here until python allows return statements in async generators
172+ # yield here since python forbids return statements in async generators
153173 yield response_dict
174+ break # break and skip retries if the whole response is processed without exception
154175 except (aiohttp .ClientError , OSError ) as e :
155176 if attempt == self ._max_retries :
177+ logger .error (
178+ f"{ self ._client_type } client error to { url } : { e } - last retry { attempt } of { self ._max_retries } "
179+ "failed" ,
180+ traceback .format_exc (),
181+ )
156182 raise
157- import traceback
158183
159184 logger .error (
160- f"Client error: { e } - retry { attempt } of { self ._max_retries } " ,
185+ f"{ self . _client_type } client error to { url } : { e } - retry { attempt } of { self ._max_retries } " ,
161186 traceback .format_exc (),
162187 )
163188 await asyncio .sleep (self ._retry_interval )
164189 self ._metrics_collector .inc ("retry_requests" )
165190 except Exception as e :
166- logger .error (f"Error encountered while processing request to { url } : { e } " )
191+ logger .error (
192+ f"Unexpected error while processing { self ._client_type } request to { url } : { e } "
193+ )
167194 raise
195+ finally :
196+ await self ._finish_request (request )
168197
169198 async def _response_generator (
170199 self ,
@@ -173,7 +202,7 @@ async def _response_generator(
173202 start_time : float ,
174203 hooks : Optional [ResponseHooks ] = None ,
175204 server : str = "" ,
176- ) -> CompletionResponseGenerator :
205+ ) -> AsyncGenerator [ Any , None ] :
177206 assert request .stream , "Request is not streaming"
178207 assert "text/event-stream" in http_response .headers .get ("Content-Type" , "" ), (
179208 "Response is not streaming"
0 commit comments