Commit 57451e6

fix tests
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 577e9db commit 57451e6

8 files changed, +206 -101 lines changed

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 10 additions & 7 deletions

@@ -4,7 +4,7 @@
 import random
 import time
 from dataclasses import asdict, dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

 from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
 from tensorrt_llm.logger import logger
@@ -97,7 +97,7 @@ async def watch_workers(
         self,
         get_existing_first: bool = True,
         on_event: Optional[Callable[[WorkerInfo, WatchEventType],
-                                    None]] = None):
+                                    Awaitable[Any]]] = None):
         if self._watch_handle:
             logger.error("Watch handle is already initialized")
             return []
@@ -121,25 +121,28 @@ async def watch_workers(
         self._watch_handle = await self._cluster_storage.watch(
             self.worker_key_prefix)

-        async def on_event_wrapper():
-            logger.warning(
-                f"Initializing watch task with {len(workers)} existing workers")
+        async def worker_event_loop():
+            logger.info(
+                f"Start watching worker events with {len(workers)} existing workers"
+            )
             for worker_info in workers:
                 await on_event(worker_info, WatchEventType.SET)
-            logger.warning("Start watching worker events")
             while True:
                 try:
                     worker_events = await self._watch_handle.drain()
                     for event in worker_events:
                         worker_info = self._parse_worker_info(event)
                         await on_event(worker_info, event.event_type)
+                except asyncio.CancelledError:
+                    break
                 except Exception as e:
                     logger.error(
                         f"Error updating routers by worker events: {e}")
                     await asyncio.sleep(1)
+            logger.info("Stop watching worker events")

         if on_event:
-            self._watch_task = asyncio.create_task(on_event_wrapper())
+            self._watch_task = asyncio.create_task(worker_event_loop())
         return workers

     async def unwatch_workers(self) -> None:
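
A side note on the worker_event_loop change above: catching asyncio.CancelledError explicitly turns task.cancel() into a clean exit from the while loop, so the closing "Stop watching worker events" log runs and the task completes normally instead of the cancellation propagating out of the loop. A minimal, runnable sketch of the same pattern, assuming a hypothetical DemoWatchHandle queue in place of the real cluster-storage watch API:

import asyncio

# Hypothetical stand-in for the cluster-storage watch handle used above.
class DemoWatchHandle:
    def __init__(self):
        self.events = asyncio.Queue()

    async def drain(self) -> list:
        # Block for one event, then grab anything else already queued.
        batch = [await self.events.get()]
        while not self.events.empty():
            batch.append(self.events.get_nowait())
        return batch

async def worker_event_loop(handle, on_event) -> None:
    while True:
        try:
            for event in await handle.drain():
                await on_event(event)
        except asyncio.CancelledError:
            break  # task.cancel() lands here: exit the loop instead of retrying
        except Exception as e:
            print(f"error handling worker events: {e}")
            await asyncio.sleep(1)  # back off, then keep watching
    print("Stop watching worker events")  # runs on clean shutdown

async def main() -> None:
    handle = DemoWatchHandle()

    async def on_event(event: str) -> None:
        print(f"worker event: {event}")

    task = asyncio.create_task(worker_event_loop(handle, on_event))
    await handle.events.put("worker-1: SET")
    await asyncio.sleep(0.1)
    task.cancel()  # cooperative shutdown, presumably what unwatch_workers triggers
    await asyncio.gather(task, return_exceptions=True)

asyncio.run(main())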

tensorrt_llm/serve/openai_client.py

Lines changed: 50 additions & 21 deletions

@@ -1,7 +1,22 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # yapf: disable
 import asyncio
+import traceback
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type

 import aiohttp

@@ -16,8 +31,8 @@
 )
 from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector, DisaggPerfMetricsCollector
 from tensorrt_llm.serve.responses_utils import (
-    CompletionResponseGenerator,
     ResponseHooks,
+    UCompletionResponseOrGenerator,
     get_steady_clock_now_in_seconds,
 )
 from tensorrt_llm.serve.router import Router
@@ -28,7 +43,7 @@
 class OpenAIClient(ABC):
     async def send_request(
         self, server: str, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
-    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
+    ) -> UCompletionResponseOrGenerator:
         if isinstance(request, CompletionRequest):
             return await self._send_request(
                 server, "v1/completions", request, CompletionResponse, hooks
@@ -48,8 +63,9 @@ async def _send_request(
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
         hooks: Optional[ResponseHooks] = None,
-    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
-        """Send a request to the server and return the response and the body iterator.
+    ) -> UCompletionResponseOrGenerator:
+        """Send a request to the server and return the response and the body generator.
+
         The request is finished (in routers) when the generator is exhausted or there is an error.
         """
         ...
@@ -59,7 +75,7 @@ async def collect_metrics(self) -> Dict[str, Any]: ...

     @abstractmethod
     async def check_ready(self) -> Tuple[List[str], List[str]]:
-        """Return the list of ready servers and the list of unready servers"""
+        """Return the list of ready servers and the list of unready servers."""
         ...

     async def shutdown(self) -> None: ...
@@ -97,28 +113,32 @@ async def _send_request(
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
         hooks: Optional[ResponseHooks] = None,
-    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
+    ) -> UCompletionResponseOrGenerator:
         if len(server) == 0:
             server, _ = await self._router.get_next_server(request)
         url = f"http://{server}/{endpoint}"
+        logger.debug(
+            f"Sending {self._client_type} request {request.disaggregated_params.ctx_request_id} to {url}"
+        )
         try:
             self._metrics_collector.inc("total_requests")
             resp_generator = self._post_with_retry(server, url, request, hooks)
             if request.stream:
                 return resp_generator
             else:
                 # consume the generator to get the response and return it directly when it's not streaming
-                resp_json = await anext(resp_generator)
-                response = response_type(**resp_json)
-                if hooks:
-                    if self._client_type == "ctx":
-                        hooks.on_ctx_resp(server, response)
-                    hooks.on_first_token(server, request)
-                    hooks.on_resp_done(server, request, response)
+                response = None
+                async for resp_json in resp_generator:
+                    response = response_type(**resp_json)
+                    if hooks:
+                        if self._client_type == "ctx":
+                            hooks.on_ctx_resp(server, response)
+                        else:
+                            hooks.on_first_token(server, request)
+                        hooks.on_resp_done(server, request, response)
                 return response
         except Exception:
             self._metrics_collector.inc("error_requests")
-            await self._finish_request(request)
             raise

     async def _post_with_retry(
@@ -127,7 +147,7 @@ async def _post_with_retry(
         url: str,
         request: UCompletionRequest,
         hooks: Optional[ResponseHooks] = None,
-    ) -> Tuple[aiohttp.ClientResponse, Dict[str, Any]]:
+    ) -> AsyncGenerator[Any, None]:
         json_data = request.model_dump(exclude_unset=True)
         is_stream = request.stream
         for attempt in range(self._max_retries + 1):
@@ -149,22 +169,31 @@
                 else:
                     http_response.raise_for_status()
                     response_dict = await http_response.json()
-                    # do yield here until python allows return statements in async generators
+                    # yield here since python forbids return statements in async generators
                     yield response_dict
+                break  # break and skip retries if the whole response is processed without exception
             except (aiohttp.ClientError, OSError) as e:
                 if attempt == self._max_retries:
+                    logger.error(
+                        f"{self._client_type} client error to {url}: {e} - last retry {attempt} of {self._max_retries}"
+                        "failed",
+                        traceback.format_exc(),
+                    )
                     raise
-                import traceback

                 logger.error(
-                    f"Client error: {e} - retry {attempt} of {self._max_retries}",
+                    f"{self._client_type} client error to {url}: {e} - retry {attempt} of {self._max_retries}",
                     traceback.format_exc(),
                 )
                 await asyncio.sleep(self._retry_interval)
                 self._metrics_collector.inc("retry_requests")
             except Exception as e:
-                logger.error(f"Error encountered while processing request to {url}: {e}")
+                logger.error(
+                    f"Unexpected error while processing {self._client_type} request to {url}: {e}"
+                )
                 raise
+            finally:
+                await self._finish_request(request)

     async def _response_generator(
         self,
@@ -173,7 +202,7 @@ async def _response_generator(
         start_time: float,
         hooks: Optional[ResponseHooks] = None,
         server: str = "",
-    ) -> CompletionResponseGenerator:
+    ) -> AsyncGenerator[Any, None]:
         assert request.stream, "Request is not streaming"
         assert "text/event-stream" in http_response.headers.get("Content-Type", ""), (
             "Response is not streaming"

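The _post_with_retry hunks lean on two async-generator quirks worth seeing in isolation: `return <value>` is a syntax error inside an async generator, so success is expressed as a yield followed by a break out of the retry loop, and the new finally block makes the per-request cleanup run on every attempt no matter how the generator is consumed. A minimal sketch under those assumptions; fetch and finish are hypothetical stand-ins, not the real client methods:

import asyncio
from typing import Any, AsyncGenerator, Awaitable, Callable

async def post_with_retry(
    fetch: Callable[[], Awaitable[dict]],   # hypothetical request coroutine
    finish: Callable[[], Awaitable[None]],  # hypothetical per-request cleanup
    max_retries: int = 2,
    retry_interval: float = 0.1,
) -> AsyncGenerator[Any, None]:
    for attempt in range(max_retries + 1):
        try:
            payload = await fetch()
            # `return payload` is a SyntaxError in an async generator, so...
            yield payload
            break  # ...yield it, then break to skip the remaining retries
        except ConnectionError as e:
            if attempt == max_retries:
                raise  # out of retries: let the consumer see the error
            print(f"retry {attempt}: {e}")
            await asyncio.sleep(retry_interval)
        finally:
            # Mirrors the hunk's placement: cleanup runs per attempt,
            # on success and failure alike.
            await finish()

async def main() -> None:
    attempts = {"n": 0}

    async def flaky_fetch() -> dict:
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise ConnectionError("transient network error")
        return {"text": "ok"}

    async def finish() -> None:
        print("request finished")

    # Drain the generator like the non-streaming branch of _send_request now
    # does: the last yielded payload becomes the response.
    response = None
    async for payload in post_with_retry(flaky_fetch, finish):
        response = payload
    print(response)

asyncio.run(main())

The async for loop in main mirrors why _send_request switched away from a single anext() call: it consumes the generator to exhaustion, which is also what marks the request finished in the routers.
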
tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 31 additions & 7 deletions

@@ -1,4 +1,19 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 #!/usr/bin/env python
+
+# yapf: disable
 import asyncio
 import signal
 import traceback
@@ -10,7 +25,7 @@
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import CollectorRegistry, make_asgi_app
+from prometheus_client import make_asgi_app

 # yapf: disable
 from tensorrt_llm.executor import CppExecutorError
@@ -49,14 +64,18 @@ def on_req_begin(self, request: UCompletionRequest):

     def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
         self.ctx_server = ctx_server
+        logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}")

     def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
         self.gen_server = gen_server
         self.server_first_token_time = get_steady_clock_now_in_seconds()
+        logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}")

     def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
-        ctx_req_id = request.disaggregated_params.ctx_request_id
-        asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time))
+        if request.disaggregated_params:
+            ctx_req_id = request.disaggregated_params.ctx_request_id
+            asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time))
+            logger.debug(f"Request {ctx_req_id} completed")


 class OpenAIDisaggServer:
@@ -81,7 +100,14 @@ def __init__(self,

         self._disagg_cluster_storage = create_cluster_storage(config.disagg_cluster_config.cluster_uri, config.disagg_cluster_config.cluster_name) if config.disagg_cluster_config else None

-        self._service = OpenAIDisaggregatedService(self._config, self._ctx_router, self._gen_router, self._create_client, self._metadata_server, self._req_timeout_secs, self._server_start_timeout_secs, self._perf_metrics_collector, self._disagg_cluster_storage)
+        self._service = OpenAIDisaggregatedService(
+            self._config, self._ctx_router, self._gen_router, self._create_client,
+            metadata_server=self._metadata_server,
+            metadata_config=self._metadata_server_cfg,
+            req_timeout_secs=self._req_timeout_secs,
+            server_start_timeout_secs=self._server_start_timeout_secs,
+            perf_metrics_collector=self._perf_metrics_collector,
+            disagg_cluster_storage=self._disagg_cluster_storage)

         try:
             otlp_cfg = config.otlp_config
@@ -123,9 +149,7 @@ def register_routes(self):
         self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"])
         self.app.add_api_route("/version", self.version, methods=["GET"])
         self.app.add_api_route("/perf_metrics", self._perf_metrics_collector.get_perf_metrics, methods=["GET"])
-        registry = CollectorRegistry()
-        metrics_app = make_asgi_app(registry=registry)
-        self.app.mount("/prometheus/metrics", metrics_app)
+        self.app.mount("/prometheus/metrics", make_asgi_app())
         if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer):
             self._disagg_cluster_storage.add_routes(self.app)
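The last hunk is more than a simplification: make_asgi_app() with no argument serves prometheus_client's process-wide default registry, which is where metric objects register themselves by default, whereas the removed code exposed a freshly created CollectorRegistry() that no collector ever joined, so the endpoint would have served no metrics. A minimal standalone FastAPI sketch of the corrected wiring (the counter and route names are illustrative, not from this repo):

from fastapi import FastAPI
from prometheus_client import Counter, make_asgi_app

app = FastAPI()

# Metric objects attach to prometheus_client's global default REGISTRY.
REQUESTS = Counter("demo_requests_total", "Total demo requests served")

@app.get("/ping")
async def ping() -> dict:
    REQUESTS.inc()
    return {"ok": True}

# As in the hunk: serve the default registry. Mounting an app built from a
# fresh CollectorRegistry() would expose an empty page, since no collector
# registers with a new registry on its own.
app.mount("/prometheus/metrics", make_asgi_app())

# Try it: uvicorn demo:app, then GET /ping and GET /prometheus/metrics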
