Commit 850f245

address review comments
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent ade914e commit 850f245

4 files changed: +138 -118 lines changed


tensorrt_llm/serve/openai_client.py

Lines changed: 43 additions & 36 deletions
```diff
@@ -20,6 +20,7 @@
 
 import aiohttp
 
+from tensorrt_llm.llmapi.disagg_utils import ServerRole
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.openai_protocol import (
     ChatCompletionRequest,
@@ -42,26 +43,29 @@
 
 class OpenAIClient(ABC):
     async def send_request(
-        self, server: str, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
+        self,
+        request: UCompletionRequest,
+        server: Optional[str] = None,
+        hooks: Optional[ResponseHooks] = None,
     ) -> UCompletionResponseOrGenerator:
         if isinstance(request, CompletionRequest):
             return await self._send_request(
-                server, "v1/completions", request, CompletionResponse, hooks
+                "v1/completions", request, CompletionResponse, server, hooks
             )
         elif isinstance(request, ChatCompletionRequest):
             return await self._send_request(
-                server, "v1/chat/completions", request, ChatCompletionResponse, hooks
+                "v1/chat/completions", request, ChatCompletionResponse, server, hooks
             )
         else:
             raise ValueError(f"Invalid request type: {type(request)}")
 
     @abstractmethod
     async def _send_request(
         self,
-        server: str,
         endpoint: str,
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
+        server: Optional[str] = None,
         hooks: Optional[ResponseHooks] = None,
     ) -> UCompletionResponseOrGenerator:
         """Send a request to the server and return the response and the body generator.
```
```diff
@@ -90,55 +94,58 @@ class OpenAIHttpClient(OpenAIClient):
     def __init__(
         self,
         router: Router,
-        client_type: str,
+        role: ServerRole,
         timeout_secs: int = 180,
         max_retries: int = 1,
+        retry_interval_sec: int = 1,
         session: Optional[aiohttp.ClientSession] = None,
     ):
-        assert client_type in ["ctx", "gen"]
         self._router = router
-        self._client_type = client_type
-        self._metrics_collector = ClientMetricsCollector(client_type)
+        self._role = role
+        self._metrics_collector = ClientMetricsCollector(role)
         self._session = session or aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=False),
             timeout=aiohttp.ClientTimeout(total=timeout_secs),
         )
         self._max_retries = max_retries
-        self._retry_interval = 1
+        self._retry_interval_sec = retry_interval_sec
 
     async def _send_request(
         self,
-        server: str,
         endpoint: str,
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
+        server: Optional[str] = None,
         hooks: Optional[ResponseHooks] = None,
     ) -> UCompletionResponseOrGenerator:
-        if len(server) == 0:
+        if server is None:
             server, _ = await self._router.get_next_server(request)
         url = f"http://{server}/{endpoint}"
         logger.debug(
-            f"Sending {self._client_type} request {request.disaggregated_params.ctx_request_id} to {url}"
+            f"Sending {self._role} request {request.disaggregated_params.ctx_request_id} to {url}"
         )
         try:
-            self._metrics_collector.inc("total_requests")
+            self._metrics_collector.total_requests.inc()
             resp_generator = self._post_with_retry(server, url, request, hooks)
             if request.stream:
+                # return the response generator, the request is not done yet
                 return resp_generator
             else:
                 # consume the generator to get the response and return it directly when it's not streaming
                 response = None
                 async for resp_json in resp_generator:
                     response = response_type(**resp_json)
                     if hooks:
-                        if self._client_type == "ctx":
+                        if self._role == ServerRole.CONTEXT:
                             hooks.on_ctx_resp(server, response)
                         else:
                             hooks.on_first_token(server, request)
                         hooks.on_resp_done(server, request, response)
                 return response
         except Exception:
-            self._metrics_collector.inc("error_requests")
+            self._metrics_collector.error_requests.inc()
+            # finish the request upon error
+            await self._finish_request(request)
             raise
 
     async def _post_with_retry(
@@ -163,45 +170,45 @@ async def _post_with_retry(
                     # do NOT return generator directly here or the response will go
                     # out of scope and get destroyed
                     async for line in self._response_generator(
-                        request, http_response, start_time, hooks, server
+                        request, http_response, start_time, server, hooks
                     ):
                         yield line
+                    # don't finish the request here since the response generator is not done yet
                 else:
                     http_response.raise_for_status()
                     response_dict = await http_response.json()
                     # yield here since python forbids return statements in async generators
                     yield response_dict
+                    # finish the request after the successful response
+                    await self._finish_request(request)
                 break  # break and skip retries if the whole response is processed without exception
             except (aiohttp.ClientError, OSError) as e:
                 if attempt == self._max_retries:
                     logger.error(
-                        f"{self._client_type} client error to {url}: {e} - last retry {attempt} of {self._max_retries}"
+                        f"Client error to {url}: {e} - last retry {attempt} of {self._max_retries}"
                         "failed",
                         traceback.format_exc(),
                     )
                     raise
-
                 logger.error(
-                    f"{self._client_type} client error to {url}: {e} - retry {attempt} of {self._max_retries}",
+                    f"{self._role} client error to {url}: {e} - retry {attempt} of {self._max_retries}",
                     traceback.format_exc(),
                 )
-                await asyncio.sleep(self._retry_interval)
-                self._metrics_collector.inc("retry_requests")
+                await asyncio.sleep(self._retry_interval_sec)
+                self._metrics_collector.retry_requests.inc()
             except Exception as e:
                 logger.error(
-                    f"Unexpected error while processing {self._client_type} request to {url}: {e}"
+                    f"Unexpected error while processing {self._role} request to {url}: {e}"
                 )
                 raise
-            finally:
-                await self._finish_request(request)
 
     async def _response_generator(
         self,
         request: UCompletionRequest,
         http_response: aiohttp.ClientResponse,
         start_time: float,
+        server: str,
         hooks: Optional[ResponseHooks] = None,
-        server: str = "",
     ) -> AsyncGenerator[Any, None]:
         assert request.stream, "Request is not streaming"
         assert "text/event-stream" in http_response.headers.get("Content-Type", ""), (
@@ -215,12 +222,12 @@ async def _response_generator(
                 if i == 0:
                     if hooks:
                         hooks.on_first_token(server, request)
-                    self._metrics_collector.observe(
-                        "first_token_latency_seconds", now_time - last_token_time
+                    self._metrics_collector.first_token_latency_seconds.observe(
+                        now_time - last_token_time
                     )
                 else:
-                    self._metrics_collector.observe(
-                        "per_token_latency_seconds", now_time - last_token_time
+                    self._metrics_collector.per_token_latency_seconds.observe(
+                        now_time - last_token_time
                     )
                 i += 1
                 if line:
@@ -230,20 +237,20 @@
 
             if hooks:
                 hooks.on_resp_done(server, request, None)
-            self._metrics_collector.inc("completed_requests")
-            self._metrics_collector.observe(
-                "complete_latency_seconds",
-                get_steady_clock_now_in_seconds() - start_time,
+            self._metrics_collector.completed_requests.inc()
+            self._metrics_collector.complete_latency_seconds.observe(
+                get_steady_clock_now_in_seconds() - start_time
             )
         except aiohttp.ClientError as e:
             # a client error is expected when the response stream is done if the connector has close=True
-            logger.error(f"{self._client_type} Client error: {e}")
-            self._metrics_collector.inc("error_requests")
+            logger.error(f"{self._role} client {server} error: {e}")
+            self._metrics_collector.error_requests.inc()
             raise
         except Exception:
-            self._metrics_collector.inc("error_requests")
+            self._metrics_collector.error_requests.inc()
             raise
         finally:
+            # finish the request after streaming response is done or error is raised
            await self._finish_request(request)
 
     async def _finish_request(self, request: UCompletionRequest) -> None:
```
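The metrics calls switch from string-keyed `inc("total_requests")` / `observe("...", v)` to attribute access such as `total_requests.inc()` and `first_token_latency_seconds.observe(...)`. The commit shows only the call sites, not `ClientMetricsCollector` itself; below is a sketch of a collector shape that would satisfy those call sites, assuming a prometheus_client-style backend (the class layout and metric naming are assumptions, not the project's actual implementation):

```python
# Hypothetical sketch of a collector matching the call sites in this diff;
# the real ClientMetricsCollector is not shown in this commit.
from prometheus_client import Counter, Histogram


class ClientMetricsCollectorSketch:
    def __init__(self, role) -> None:
        # e.g. ServerRole.CONTEXT -> "serverrole_context" (naming is assumed)
        prefix = str(role).lower().replace(".", "_")
        self.total_requests = Counter(f"{prefix}_total_requests", "Requests sent")
        self.retry_requests = Counter(f"{prefix}_retry_requests", "Requests retried")
        self.error_requests = Counter(f"{prefix}_error_requests", "Requests failed")
        self.completed_requests = Counter(f"{prefix}_completed_requests", "Requests completed")
        self.first_token_latency_seconds = Histogram(
            f"{prefix}_first_token_latency_seconds", "Time to first streamed token"
        )
        self.per_token_latency_seconds = Histogram(
            f"{prefix}_per_token_latency_seconds", "Latency between streamed tokens"
        )
        self.complete_latency_seconds = Histogram(
            f"{prefix}_complete_latency_seconds", "End-to-end request latency"
        )
```

One upside of the attribute style: a mistyped metric name fails immediately with an `AttributeError`, and each metric object is created once instead of being looked up by string on every call.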

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -31,7 +31,7 @@
 from tensorrt_llm.executor.executor import CppExecutorError
 from tensorrt_llm.llmapi import tracing
 from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
-                                              MetadataServerConfig,
+                                              MetadataServerConfig, ServerRole,
                                               get_ctx_gen_server_addrs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer,
@@ -136,8 +136,8 @@ async def validation_exception_handler(_, exc):
 
         self.register_routes()
 
-    def _create_client(self, router: Router, client_type: str, max_retries) -> OpenAIClient:
-        client = OpenAIHttpClient(router, client_type, self._req_timeout_secs, max_retries)
+    def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient:
+        client = OpenAIHttpClient(router, role, self._req_timeout_secs, max_retries)
         self._perf_metrics_collector.add_client(client)
         return client
```

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 28 additions & 61 deletions
```diff
@@ -52,7 +52,7 @@ def __init__(
         config: DisaggServerConfig,
         ctx_router: Router,
         gen_router: Router,
-        client_factory: Callable[[Router, str], OpenAIClient],
+        client_factory: Callable[[Router, ServerRole], OpenAIClient],
         metadata_server: Optional[JsonDictionary] = None,
         metadata_config: Optional[MetadataServerConfig] = None,
         req_timeout_secs: int = 180,
@@ -106,21 +106,25 @@ async def _send_disagg_request(
         if hooks:
             hooks.on_req_begin(request)
         # empty server means client decides which server to use
-        gen_server = ""
-        ctx_server = ""
+        reserved_gen_server = None
+        reserved_ctx_server = None
         # reserve a gen_server if conditional disagg is needed
-        gen_server, need_ctx = await self._check_conditional_disagg(request)
+        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
         need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
         ctx_response = None
         gen_req = request
         if need_ctx:
             ctx_req = self._get_ctx_request(request)
             # ctx generator is empty
-            ctx_response = await self._ctx_client.send_request(ctx_server, ctx_req, hooks)
+            ctx_response = await self._ctx_client.send_request(
+                ctx_req, server=reserved_ctx_server, hooks=hooks
+            )
             await self._verify_ctx_response(ctx_response)
             gen_req = self._get_gen_request(request, ctx_response)
         if ctx_response is None or self._need_gen(ctx_response):
-            return await self._gen_client.send_request(gen_server, gen_req, hooks)
+            return await self._gen_client.send_request(
+                gen_req, server=reserved_gen_server, hooks=hooks
+            )
         else:
             if request.stream:
                 # ctx client will never return a generator when streaming is requested
@@ -170,7 +174,7 @@ async def _check_conditional_disagg(self, request: UCompletionRequest) -> bool:
             ):
                 return gen_server, True
             return gen_server, False
-        return "", True
+        return None, True
 
     async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool:
         if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1":
@@ -206,8 +210,12 @@ def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]:
         return self._config.conditional_disagg_config
 
     async def setup(self) -> None:
-        self._ctx_client = self._client_factory(self._ctx_router, "ctx", self._config.max_retries)
-        self._gen_client = self._client_factory(self._gen_router, "gen", self._config.max_retries)
+        self._ctx_client = self._client_factory(
+            self._ctx_router, ServerRole.CONTEXT, self._config.max_retries
+        )
+        self._gen_client = self._client_factory(
+            self._gen_router, ServerRole.GENERATION, self._config.max_retries
+        )
 
         if self.disagg_cluster_config and self._cluster_storage:
             logger.info("Starting disagg cluster manager")
@@ -263,12 +271,17 @@ async def check_servers_ready():
     async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEventType):
         router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router}
         worker_addr = f"{worker_info.host}:{worker_info.port}"
-        router = router_map[worker_info.role]
-        if event_type == WatchEventType.SET:
-            await router.add_server(worker_addr)
-        elif event_type == WatchEventType.DELETE:
-            await router.remove_server(worker_addr)
-        logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")
+        try:
+            router = router_map[worker_info.role]
+            if event_type == WatchEventType.SET:
+                await router.add_server(worker_addr)
+            elif event_type == WatchEventType.DELETE:
+                await router.remove_server(worker_addr)
+            logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")
+        except KeyError:
+            logger.error(
+                f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}"
+            )
 
     async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
         if ctx_response:
@@ -281,49 +294,3 @@ async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
             if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
                 raise ValueError("Invalid disaggregated params in context phase response.")
         return ctx_response
-
-
-# FIXME: This is a demo to show the basic idea of disagg-service with pre-allocating generation
-class OpenAIDisaggregatedPreAllocService(OpenAIDisaggregatedService):
-    def _need_gen(self, request: UCompletionRequest) -> bool:
-        if isinstance(request, CompletionRequest) and request.max_tokens is not None:
-            return request.max_tokens > 1
-        if isinstance(request, ChatCompletionRequest) and request.max_completion_tokens is not None:
-            return request.max_completion_tokens > 1
-        return False
-
-    async def _send_disagg_request(
-        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
-    ) -> UCompletionResponseOrGenerator:
-        if hooks:
-            hooks.on_req_begin(request)
-        # empty server means client decides which server to use
-        gen_server = ""
-        ctx_server = ""
-        # reserve a gen_server if conditional disagg is needed
-        gen_server, need_ctx = await self._check_conditional_disagg(request)
-        need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
-        need_gen = self._need_gen(request)
-        # send ctx and gen requests in parallel
-        assert need_gen or need_ctx, "Neither generation nor context is required"
-        gen_task = None
-        ctx_task = None
-        tasks = []
-
-        async def _run_ctx_task():
-            # send ctx request and gen request in parallel
-            ctx_req = self._get_ctx_request(request)
-            ctx_response = await self._ctx_client.send_request(ctx_server, ctx_req, hooks)
-            return ctx_response
-
-        if need_ctx:
-            ctx_task = asyncio.create_task(_run_ctx_task())
-        if need_gen:
-            gen_task = asyncio.create_task(
-                self._gen_client.send_request(gen_server, request, hooks)
-            )
-            tasks.append(gen_task)
-        await asyncio.gather(*tasks)
-        if need_gen:
-            return gen_task.result()
-        return ctx_task.result()
```
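Putting the service-side changes together: the empty-string sentinel is replaced by `None`, reservations flow through the `reserved_*` names, and both clients are called with keyword arguments. A condensed control-flow sketch (the method names come from the diff; the streaming and error-handling branches of the real method are omitted):

```python
# Abridged control-flow sketch of the reworked _send_disagg_request;
# not the verbatim method body.
async def send_disagg_request_sketch(service, request, hooks=None):
    # None (not "") now means "no reservation; let the client's router decide".
    reserved_gen_server, need_ctx = await service._check_conditional_disagg(request)

    gen_req = request
    if need_ctx:
        ctx_req = service._get_ctx_request(request)
        # No reserved context server: the ctx client picks one via its router.
        ctx_response = await service._ctx_client.send_request(ctx_req, hooks=hooks)
        gen_req = service._get_gen_request(request, ctx_response)

    # The generation request honors a reservation when one exists.
    return await service._gen_client.send_request(
        gen_req, server=reserved_gen_server, hooks=hooks
    )
```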
