Skip to content

Commit 50348fd

Browse files
committed
fix tests
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 577e9db commit 50348fd

File tree

8 files changed

+189
-76
lines changed

8 files changed

+189
-76
lines changed

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import random
55
import time
66
from dataclasses import asdict, dataclass
7-
from typing import Any, Callable, Dict, List, Optional, Tuple
7+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
88

99
from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
1010
from tensorrt_llm.logger import logger
@@ -97,7 +97,7 @@ async def watch_workers(
9797
self,
9898
get_existing_first: bool = True,
9999
on_event: Optional[Callable[[WorkerInfo, WatchEventType],
100-
None]] = None):
100+
Awaitable[Any]]] = None):
101101
if self._watch_handle:
102102
logger.error("Watch handle is already initialized")
103103
return []
@@ -121,25 +121,28 @@ async def watch_workers(
121121
self._watch_handle = await self._cluster_storage.watch(
122122
self.worker_key_prefix)
123123

124-
async def on_event_wrapper():
125-
logger.warning(
126-
f"Initializing watch task with {len(workers)} existing workers")
124+
async def worker_event_loop():
125+
logger.info(
126+
f"Start watching worker events with {len(workers)} existing workers"
127+
)
127128
for worker_info in workers:
128129
await on_event(worker_info, WatchEventType.SET)
129-
logger.warning("Start watching worker events")
130130
while True:
131131
try:
132132
worker_events = await self._watch_handle.drain()
133133
for event in worker_events:
134134
worker_info = self._parse_worker_info(event)
135135
await on_event(worker_info, event.event_type)
136+
except asyncio.CancelledError:
137+
break
136138
except Exception as e:
137139
logger.error(
138140
f"Error updating routers by worker events: {e}")
139141
await asyncio.sleep(1)
142+
logger.info("Stop watching worker events")
140143

141144
if on_event:
142-
self._watch_task = asyncio.create_task(on_event_wrapper())
145+
self._watch_task = asyncio.create_task(worker_event_loop())
143146
return workers
144147

145148
async def unwatch_workers(self) -> None:

tensorrt_llm/serve/openai_client.py

Lines changed: 45 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,20 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
# yapf: disable
216
import asyncio
17+
import traceback
318
from abc import ABC, abstractmethod
419
from typing import Any, Dict, List, Optional, Tuple, Type, Union
520

@@ -18,6 +33,7 @@
1833
from tensorrt_llm.serve.responses_utils import (
1934
CompletionResponseGenerator,
2035
ResponseHooks,
36+
UCompletionResponseOrGenerator,
2137
get_steady_clock_now_in_seconds,
2238
)
2339
from tensorrt_llm.serve.router import Router
@@ -48,8 +64,9 @@ async def _send_request(
4864
request: UCompletionRequest,
4965
response_type: Type[UCompletionResponse],
5066
hooks: Optional[ResponseHooks] = None,
51-
) -> Union[UCompletionResponse, CompletionResponseGenerator]:
52-
"""Send a request to the server and return the response and the body iterator.
67+
) -> UCompletionResponseOrGenerator:
68+
"""Send a request to the server and return the response and the body generator.
69+
5370
The request is finished (in routers) when the generator is exhausted or there is an error.
5471
"""
5572
...
@@ -59,7 +76,7 @@ async def collect_metrics(self) -> Dict[str, Any]: ...
5976

6077
@abstractmethod
6178
async def check_ready(self) -> Tuple[List[str], List[str]]:
62-
"""Return the list of ready servers and the list of unready servers"""
79+
"""Return the list of ready servers and the list of unready servers."""
6380
...
6481

6582
async def shutdown(self) -> None: ...
@@ -101,24 +118,28 @@ async def _send_request(
101118
if len(server) == 0:
102119
server, _ = await self._router.get_next_server(request)
103120
url = f"http://{server}/{endpoint}"
121+
logger.debug(
122+
f"Sending {self._client_type} request {request.disaggregated_params.ctx_request_id} to {url}"
123+
)
104124
try:
105125
self._metrics_collector.inc("total_requests")
106126
resp_generator = self._post_with_retry(server, url, request, hooks)
107127
if request.stream:
108128
return resp_generator
109129
else:
110130
# consume the generator to get the response and return it directly when it's not streaming
111-
resp_json = await anext(resp_generator)
112-
response = response_type(**resp_json)
113-
if hooks:
114-
if self._client_type == "ctx":
115-
hooks.on_ctx_resp(server, response)
116-
hooks.on_first_token(server, request)
117-
hooks.on_resp_done(server, request, response)
131+
response = None
132+
async for resp_json in resp_generator:
133+
response = response_type(**resp_json)
134+
if hooks:
135+
if self._client_type == "ctx":
136+
hooks.on_ctx_resp(server, response)
137+
else:
138+
hooks.on_first_token(server, request)
139+
hooks.on_resp_done(server, request, response)
118140
return response
119141
except Exception:
120142
self._metrics_collector.inc("error_requests")
121-
await self._finish_request(request)
122143
raise
123144

124145
async def _post_with_retry(
@@ -149,22 +170,31 @@ async def _post_with_retry(
149170
else:
150171
http_response.raise_for_status()
151172
response_dict = await http_response.json()
152-
# do yield here until python allows return statements in async generators
173+
# yield here since python forbids return statements in async generators
153174
yield response_dict
175+
break # break and skip retries if the whole response is processed without exception
154176
except (aiohttp.ClientError, OSError) as e:
155177
if attempt == self._max_retries:
178+
logger.error(
179+
f"{self._client_type} client error to {url}: {e} - last retry {attempt} of {self._max_retries}"
180+
"failed",
181+
traceback.format_exc(),
182+
)
156183
raise
157-
import traceback
158184

159185
logger.error(
160-
f"Client error: {e} - retry {attempt} of {self._max_retries}",
186+
f"{self._client_type} client error to {url}: {e} - retry {attempt} of {self._max_retries}",
161187
traceback.format_exc(),
162188
)
163189
await asyncio.sleep(self._retry_interval)
164190
self._metrics_collector.inc("retry_requests")
165191
except Exception as e:
166-
logger.error(f"Error encountered while processing request to {url}: {e}")
192+
logger.error(
193+
f"Unexpected error while processing {self._client_type} request to {url}: {e}"
194+
)
167195
raise
196+
finally:
197+
await self._finish_request(request)
168198

169199
async def _response_generator(
170200
self,

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,19 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
#!/usr/bin/env python
15+
16+
# yapf: disable
217
import asyncio
318
import signal
419
import traceback
@@ -10,7 +25,7 @@
1025
from fastapi import FastAPI, HTTPException, Request
1126
from fastapi.exceptions import RequestValidationError
1227
from fastapi.responses import JSONResponse, Response, StreamingResponse
13-
from prometheus_client import CollectorRegistry, make_asgi_app
28+
from prometheus_client import make_asgi_app
1429

1530
# yapf: disable
1631
from tensorrt_llm.executor import CppExecutorError
@@ -49,14 +64,16 @@ def on_req_begin(self, request: UCompletionRequest):
4964

5065
def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
5166
self.ctx_server = ctx_server
67+
logger.info(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}")
5268

5369
def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
5470
self.gen_server = gen_server
5571
self.server_first_token_time = get_steady_clock_now_in_seconds()
5672

5773
def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
58-
ctx_req_id = request.disaggregated_params.ctx_request_id
59-
asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time))
74+
if request.disaggregated_params:
75+
ctx_req_id = request.disaggregated_params.ctx_request_id
76+
asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time))
6077

6178

6279
class OpenAIDisaggServer:
@@ -81,7 +98,14 @@ def __init__(self,
8198

8299
self._disagg_cluster_storage = create_cluster_storage(config.disagg_cluster_config.cluster_uri, config.disagg_cluster_config.cluster_name) if config.disagg_cluster_config else None
83100

84-
self._service = OpenAIDisaggregatedService(self._config, self._ctx_router, self._gen_router, self._create_client, self._metadata_server, self._req_timeout_secs, self._server_start_timeout_secs, self._perf_metrics_collector, self._disagg_cluster_storage)
101+
self._service = OpenAIDisaggregatedService(
102+
self._config, self._ctx_router, self._gen_router, self._create_client,
103+
metadata_server=self._metadata_server,
104+
metadata_config=self._metadata_server_cfg,
105+
req_timeout_secs=self._req_timeout_secs,
106+
server_start_timeout_secs=self._server_start_timeout_secs,
107+
perf_metrics_collector=self._perf_metrics_collector,
108+
disagg_cluster_storage=self._disagg_cluster_storage)
85109

86110
try:
87111
otlp_cfg = config.otlp_config
@@ -123,9 +147,7 @@ def register_routes(self):
123147
self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"])
124148
self.app.add_api_route("/version", self.version, methods=["GET"])
125149
self.app.add_api_route("/perf_metrics", self._perf_metrics_collector.get_perf_metrics, methods=["GET"])
126-
registry = CollectorRegistry()
127-
metrics_app = make_asgi_app(registry=registry)
128-
self.app.mount("/prometheus/metrics", metrics_app)
150+
self.app.mount("/prometheus/metrics", make_asgi_app())
129151
if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer):
130152
self._disagg_cluster_storage.add_routes(self.app)
131153

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 46 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import asyncio
216
import copy
317
import os
@@ -7,6 +21,7 @@
721
ConditionalDisaggConfig,
822
DisaggClusterConfig,
923
DisaggServerConfig,
24+
MetadataServerConfig,
1025
ServerRole,
1126
)
1227
from tensorrt_llm.logger import logger
@@ -37,8 +52,9 @@ def __init__(
3752
config: DisaggServerConfig,
3853
ctx_router: Router,
3954
gen_router: Router,
40-
client_factory: Callable[[Router, str], OpenAIClient] = None,
55+
client_factory: Callable[[Router, str], OpenAIClient],
4156
metadata_server: Optional[JsonDictionary] = None,
57+
metadata_config: Optional[MetadataServerConfig] = None,
4258
req_timeout_secs: int = 180,
4359
server_start_timeout_secs: int = 180,
4460
perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
@@ -49,6 +65,7 @@ def __init__(
4965
self._gen_router = gen_router
5066
self._client_factory = client_factory
5167
self._metadata_server = metadata_server
68+
self._metadata_config = metadata_config
5269
self._req_timeout_secs = req_timeout_secs
5370
self._server_start_timeout_secs = server_start_timeout_secs
5471
self._perf_metrics_collector = perf_metrics_collector
@@ -86,11 +103,6 @@ async def openai_chat_completion(
86103
async def _send_disagg_request(
87104
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
88105
) -> Union[UCompletionResponse, CompletionResponseGenerator]:
89-
"""This is the main disaggregated serving logic:
90-
1. send context request to the context server if ctx is needed, return the context response if gen is not needed
91-
2. build a generation request based on the context response and send it to the generation server if gen is needed,
92-
return the generation response
93-
"""
94106
if hooks:
95107
hooks.on_req_begin(request)
96108
# empty server means client decides which server to use
@@ -104,7 +116,7 @@ async def _send_disagg_request(
104116
if need_ctx:
105117
ctx_req = self._get_ctx_request(request)
106118
# ctx generator is empty
107-
ctx_response = await self._ctx_client.send_request(ctx_server, ctx_req)
119+
ctx_response = await self._ctx_client.send_request(ctx_server, ctx_req, hooks)
108120
await self._verify_ctx_response(ctx_response)
109121
gen_req = self._get_gen_request(request, ctx_response)
110122
if ctx_response is None or self._need_gen(ctx_response):
@@ -206,13 +218,13 @@ async def setup(self) -> None:
206218
await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event)
207219
logger.info("Disagg cluster manager started")
208220
else:
209-
if self._metadata_server:
221+
if self._metadata_server and self._metadata_config:
210222
logger.info("Starting server monitoring via metadata service")
211223
await self._ctx_router.start_server_monitoring(
212-
self.metadata_server.refresh_interval
224+
self._metadata_config.refresh_interval
213225
)
214226
await self._gen_router.start_server_monitoring(
215-
self.metadata_server.refresh_interval
227+
self._metadata_config.refresh_interval
216228
)
217229
await self._wait_for_servers_ready()
218230

@@ -274,7 +286,11 @@ async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
274286
# FIXME: This is a demo to show the basic idea of disagg-service with pre-allocating generation
275287
class OpenAIDisaggregatedPreAllocService(OpenAIDisaggregatedService):
276288
def _need_gen(self, request: UCompletionRequest) -> bool:
277-
return request.max_tokens > 1
289+
if isinstance(request, CompletionRequest) and request.max_tokens is not None:
290+
return request.max_tokens > 1
291+
if isinstance(request, ChatCompletionRequest) and request.max_completion_tokens is not None:
292+
return request.max_completion_tokens > 1
293+
return False
278294

279295
async def _send_disagg_request(
280296
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
@@ -290,23 +306,24 @@ async def _send_disagg_request(
290306
need_gen = self._need_gen(request)
291307
# send ctx and gen requests in parallel
292308
assert need_gen or need_ctx, "Neither generation nor context is required"
293-
with asyncio.TaskGroup() as tg:
294-
if need_ctx:
295-
296-
async def _run_ctx_task():
297-
# send ctx request and gen request in parallel
298-
ctx_req = self._get_ctx_request(request)
299-
ctx_response = await self._ctx_client.send_completion_request(
300-
ctx_server, ctx_req
301-
)
302-
return ctx_response
303-
304-
ctx_task = tg.create_task(_run_ctx_task())
305-
if need_gen:
306-
gen_task = tg.create_task(
307-
self._gen_client.send_completion_request(gen_server, request, hooks)
308-
)
309+
gen_task = None
310+
ctx_task = None
311+
tasks = []
312+
313+
async def _run_ctx_task():
314+
# send ctx request and gen request in parallel
315+
ctx_req = self._get_ctx_request(request)
316+
ctx_response = await self._ctx_client.send_request(ctx_server, ctx_req, hooks)
317+
return ctx_response
318+
319+
if need_ctx:
320+
ctx_task = asyncio.create_task(_run_ctx_task())
321+
if need_gen:
322+
gen_task = asyncio.create_task(
323+
self._gen_client.send_request(gen_server, request, hooks)
324+
)
325+
tasks.append(gen_task)
326+
await asyncio.gather(*tasks)
309327
if need_gen:
310328
return gen_task.result()
311-
else:
312-
return ctx_task.result()
329+
return ctx_task.result()

0 commit comments

Comments
 (0)