1- # yapf disagrees with isort in pre-commit hooks
21# yapf: disable
32import asyncio
43from abc import ABC , abstractmethod
5- from typing import Any , AsyncGenerator , Dict , List , Tuple , Type , Union
4+ from typing import Any , AsyncGenerator , Dict , List , Optional , Tuple , Type
65
76import aiohttp
87
1413 UCompletionRequest ,
1514 UCompletionResponse )
1615from tensorrt_llm .serve .perf_metrics import DisaggPerfMetricsCollector
17- from tensorrt_llm .serve .responses_utils import (CompletionResponseIterator ,
16+ from tensorrt_llm .serve .responses_utils import (ResponseHooks ,
1817 get_steady_clock_now_in_seconds )
1918from tensorrt_llm .serve .router import Router
2019
2120# yapf: enable
2221
# Alias for the streaming half of a client result: an async generator that
# yields raw chunks (bytes) of the upstream HTTP response body.
CompletionResponseGenerator = AsyncGenerator[bytes, None]
2324
2425class OpenAIClient (ABC ):
2526
2627 async def send_request (
27- self , server : str , request : UCompletionRequest
28+ self ,
29+ server : str ,
30+ request : UCompletionRequest ,
31+ hooks : Optional [ResponseHooks ] = None
2832 ) -> Tuple [UCompletionResponse , AsyncGenerator [bytes , None ]]:
2933 if isinstance (request , CompletionRequest ):
30- return await self .send_completion_request (server , request )
34+ return await self ._send_request (server , "v1/completions" , request ,
35+ CompletionResponse , hooks )
3136 elif isinstance (request , ChatCompletionRequest ):
32- return await self .send_chat_request (server , request )
37+ return await self ._send_request (server , "v1/chat/completions" ,
38+ request , ChatCompletionResponse ,
39+ hooks )
3340 else :
3441 raise ValueError (f"Invalid request type: { type (request )} " )
3542
36- async def send_completion_request (
37- self , server : str , request : CompletionRequest
38- ) -> Tuple [CompletionResponse , AsyncGenerator [bytes , None ]]:
39- return await self ._send_request (server , "v1/completions" , request ,
40- CompletionResponse )
41-
42- async def send_chat_request (
43- self , server : str , request : ChatCompletionRequest
44- ) -> Tuple [ChatCompletionResponse , AsyncGenerator [bytes , None ]]:
45- return await self ._send_request (server , "v1/chat/completions" , request ,
46- ChatCompletionResponse )
47-
4843 @abstractmethod
4944 async def _send_request (
5045 self ,
5146 server : str ,
5247 endpoint : str ,
53- request : Union [ CompletionRequest , ChatCompletionRequest ] ,
48+ request : UCompletionRequest ,
5449 response_type : Type [UCompletionResponse ],
50+ hooks : Optional [ResponseHooks ] = None ,
5551 ) -> Tuple [UCompletionResponse , AsyncGenerator [bytes , None ]]:
5652 """
5753 Send a request to the server and return the response and the body iterator.
@@ -95,7 +91,7 @@ def __init__(self,
9591 self ._session = aiohttp .ClientSession (
9692 connector = aiohttp .TCPConnector (limit = 0 ,
9793 limit_per_host = 0 ,
98- force_close = True ),
94+ force_close = False ),
9995 timeout = aiohttp .ClientTimeout (total = timeout_secs ))
10096
10197 async def _send_request (
@@ -104,7 +100,8 @@ async def _send_request(
104100 endpoint : str ,
105101 request : UCompletionRequest ,
106102 response_type : Type [UCompletionResponse ],
107- ) -> Tuple [UCompletionResponse , CompletionResponseIterator ]:
103+ hooks : Optional [ResponseHooks ] = None ,
104+ ) -> Tuple [UCompletionResponse , CompletionResponseGenerator ]:
108105 if len (server ) == 0 :
109106 server , _ = await self ._router .get_next_server (request )
110107 url = f"http://{ server } /{ endpoint } "
@@ -113,36 +110,46 @@ async def _send_request(
113110 self ._perf_metrics_collector .inc (
114111 f"{ self ._client_type } _total_requests" )
115112 async with self ._session .post (
116- url ,
117- json = request . model_dump ( exclude_unset = True )) as response :
118- content_type = response .headers .get ("Content-Type" , "" )
113+ url , json = request . model_dump (
114+ exclude_unset = True )) as http_response :
115+ content_type = http_response .headers .get ("Content-Type" , "" )
119116 if not request .stream and "text/event-stream" in content_type :
120117 raise ValueError (
121118 "Received an event-stream although request stream was False"
122119 )
123120
124- response_dict = await response .json ()
125- if not response .ok :
121+ response_dict = await http_response .json ()
122+ if not http_response .ok :
126123 logger .error (f"Received failed response { response_dict } " )
127- response .raise_for_status ()
128- return response_type (** response_dict ), self ._response_generator (
129- response , start_time )
124+ http_response .raise_for_status ()
125+ response = response_type (** response_dict )
126+
127+ return response , self ._response_generator (
128+ request , http_response , response , start_time , hooks )
130129 except Exception :
131130 self ._perf_metrics_collector .inc (
132131 f"{ self ._client_type } _error_requests" )
133- self .finish_request (request )
132+ await self ._finish_request (request )
134133 raise
135134
136135 async def _response_generator (
137- self , request : UCompletionRequest , response : aiohttp .ClientResponse ,
138- start_time : float ) -> AsyncGenerator [bytes , None ]:
136+ self ,
137+ request : UCompletionRequest ,
138+ http_response : aiohttp .ClientResponse ,
139+ response : UCompletionResponse ,
140+ start_time : float ,
141+ hooks : Optional [ResponseHooks ] = None
142+ ) -> CompletionResponseGenerator :
139143 try :
140- if request .stream and "text/event-stream" in response .headers .get (
144+ if request .stream and "text/event-stream" in http_response .headers .get (
141145 "Content-Type" , "" ):
142146 last_token_time = start_time
143- async for i , line in enumerate (response .content .iter_any ()):
147+ async for i , line in enumerate (
148+ http_response .content .iter_any ()):
144149 now_time = get_steady_clock_now_in_seconds ()
145150 if i == 0 :
151+ if hooks and hooks .on_first_token :
152+ hooks .on_first_token (request , response )
146153 self ._perf_metrics_collector .observe (
147154 f"{ self ._client_type } _first_token_latency_seconds" ,
148155 now_time - last_token_time ,
@@ -152,10 +159,12 @@ async def _response_generator(
152159 f"{ self ._client_type } _per_token_latency_seconds" ,
153160 now_time - last_token_time ,
154161 )
155- last_token_time = now_time
156162 if line :
157163 yield line
158164 await asyncio .sleep (0 )
165+ last_token_time = now_time
166+ if hooks and hooks .on_resp_done :
167+ hooks .on_resp_done (request , response )
159168 self ._perf_metrics_collector .inc (
160169 f"{ self ._client_type } _completed_requests" )
161170 self ._perf_metrics_collector .observe (
@@ -167,10 +176,10 @@ async def _response_generator(
167176 f"{ self ._client_type } _error_requests" )
168177 raise
169178 finally :
170- self .finish_request (request )
179+ await self ._finish_request (request )
171180
    async def _finish_request(self, request: UCompletionRequest) -> None:
        """Inform the router that *request* has completed.

        Called on both the success path (after the response stream is
        drained) and the error path, so router bookkeeping for this
        request is always released exactly once per attempt.
        """
        await self._router.finish_request(request)
174183
175184 async def collect_metrics (self ) -> Dict [str , Any ]:
176185 metrics = {}
# NOTE: trailing "0 commit comments" was GitHub page residue from the scraped diff, not source content.