Commit 8ea9d89

implement prototype
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 8303cfa commit 8ea9d89

File tree

9 files changed, +958 -586 lines changed


tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 3 additions & 3 deletions
@@ -87,16 +87,16 @@ class MetadataServerConfig():
     refresh_interval: float = 10.0
 
 
-def get_ctx_gen_server_urls(
+def get_ctx_gen_server_addrs(
         server_configs: list[CtxGenServerConfig]
 ) -> tuple[list[str], list[str]]:
     ctx_server_urls = []
     gen_server_urls = []
     for cfg in server_configs:
         if cfg.type == "ctx":
-            ctx_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
+            ctx_server_urls.append(f"{cfg.hostname}:{cfg.port}")
         else:
-            gen_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
+            gen_server_urls.append(f"{cfg.hostname}:{cfg.port}")
 
     return ctx_server_urls, gen_server_urls
 
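For illustration, a minimal sketch of the behavioral change: the renamed helper now returns bare host:port addresses, leaving the scheme to the caller (for example the HTTP client added later in this commit). The CtxGenServerConfig constructor arguments below are an assumption; only the type, hostname, and port fields are confirmed by this diff.

# Minimal sketch, assuming CtxGenServerConfig can be constructed with the
# type/hostname/port fields that the helper above reads.
from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig,
                                              get_ctx_gen_server_addrs)

configs = [
    CtxGenServerConfig(type="ctx", hostname="node0", port=8001),
    CtxGenServerConfig(type="gen", hostname="node1", port=8002),
]
ctx_addrs, gen_addrs = get_ctx_gen_server_addrs(configs)
assert ctx_addrs == ["node0:8001"]  # no "http://" prefix anymore
assert gen_addrs == ["node1:8002"]  # callers prepend the scheme when needed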

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 36 additions & 2 deletions
@@ -4,7 +4,7 @@
 import random
 import time
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
 from tensorrt_llm.logger import logger
@@ -44,6 +44,7 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
         self._current_ctx_workers = {}  # worker_id -> WorkerInfo
         self._current_gen_workers = {}  # worker_id -> WorkerInfo
         self._watch_handle = None
+        self._watch_task = None
 
     def __del__(self):
         try:
@@ -92,7 +93,14 @@ def current_gen_worker_num(self) -> int:
     def worker_key_prefix(self) -> str:
         return get_worker_key_prefix(self._config.cluster_name)
 
-    async def watch_workers(self, get_existing_first: bool = True):
+    async def watch_workers(
+            self,
+            get_existing_first: bool = True,
+            on_event: Optional[Callable[[WorkerInfo, WatchEventType],
+                                        None]] = None):
+        if self._watch_handle:
+            logger.error("Watch handle is already initialized")
+            return []
         workers = []
         self._watch_handle = await self._cluster_storage.watch(
             self.worker_key_prefix)
@@ -109,12 +117,38 @@ async def watch_workers(self, get_existing_first: bool = True):
                 workers.append(self._parse_worker_info(event))
                 events.append(event)
             await self._watch_handle.add_events(events)
+
+            self._watch_handle = await self._cluster_storage.watch(
+                self.worker_key_prefix)
+
+        async def on_event_wrapper():
+            logger.warning(
+                f"Initializing watch task with {len(workers)} existing workers")
+            for worker_info in workers:
+                await on_event(worker_info, WatchEventType.SET)
+            logger.warning("Start watching worker events")
+            while True:
+                try:
+                    worker_events = await self._watch_handle.drain()
+                    for event in worker_events:
+                        worker_info = self._parse_worker_info(event)
+                        await on_event(worker_info, event.event_type)
+                except Exception as e:
+                    logger.error(
+                        f"Error updating routers by worker events: {e}")
+                await asyncio.sleep(1)
+
+        if on_event:
+            self._watch_task = asyncio.create_task(on_event_wrapper())
         return workers
 
     async def unwatch_workers(self) -> None:
         if self._watch_handle:
             await self._cluster_storage.unwatch(self.worker_key_prefix)
             self._watch_handle = None
+        if self._watch_task:
+            self._watch_task.cancel()
+            self._watch_task = None
 
     async def get_worker_events(
             self) -> List[Tuple[WorkerInfo, WatchEventType]]:
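A hedged usage sketch of the new push-style API: existing workers are replayed through the callback first, then the background task created by watch_workers() drains subsequent watch events. The DisaggClusterManager class name is an assumption inferred from the surrounding code (the diff only shows its methods); the callback is defined async because the wrapper above awaits it.

# Sketch only; assumes the patched class is importable as shown and that a
# manager instance was built elsewhere with a config and cluster storage.
import asyncio

from tensorrt_llm.serve.disagg_auto_scaling import (DisaggClusterManager,
                                                    WatchEventType, WorkerInfo)


async def on_worker_event(worker: WorkerInfo,
                          event_type: WatchEventType) -> None:
    # SET fires for existing/new workers, DELETE when a worker drops out.
    print(f"{event_type}: {worker}")


async def run(manager: DisaggClusterManager) -> None:
    # Replays existing workers, then streams new events via the watch task.
    await manager.watch_workers(get_existing_first=True,
                                on_event=on_worker_event)
    try:
        await asyncio.sleep(60)  # serve traffic while events stream in
    finally:
        await manager.unwatch_workers()  # also cancels the watch task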
Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
+# yapf: disable
+import asyncio
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import aiohttp
+
+from tensorrt_llm.logger import logger
+from tensorrt_llm.serve.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    CompletionResponse,
+    UCompletionRequest,
+    UCompletionResponse,
+)
+from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector, DisaggPerfMetricsCollector
+from tensorrt_llm.serve.responses_utils import (
+    CompletionResponseGenerator,
+    ResponseHooks,
+    get_steady_clock_now_in_seconds,
+)
+from tensorrt_llm.serve.router import Router
+
+# yapf: enable
+
+
+class OpenAIClient(ABC):
+    async def send_request(
+        self, server: str, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
+        if isinstance(request, CompletionRequest):
+            return await self._send_request(
+                server, "v1/completions", request, CompletionResponse, hooks
+            )
+        elif isinstance(request, ChatCompletionRequest):
+            return await self._send_request(
+                server, "v1/chat/completions", request, ChatCompletionResponse, hooks
+            )
+        else:
+            raise ValueError(f"Invalid request type: {type(request)}")
+
+    @abstractmethod
+    async def _send_request(
+        self,
+        server: str,
+        endpoint: str,
+        request: UCompletionRequest,
+        response_type: Type[UCompletionResponse],
+        hooks: Optional[ResponseHooks] = None,
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
+        """Send a request to the server and return the response or the body iterator.
+        The request is finished (in routers) when the generator is exhausted or there is an error.
+        """
+        ...
+
+    @abstractmethod
+    async def collect_metrics(self) -> Dict[str, Any]: ...
+
+    @abstractmethod
+    async def check_ready(self) -> Tuple[List[str], List[str]]:
+        """Return the list of ready servers and the list of unready servers."""
+        ...
+
+    async def shutdown(self) -> None: ...
+
+    @abstractmethod
+    async def _finish_request(self, request: UCompletionRequest) -> None:
+        """Finish the request in the router."""
+        ...
+
+
+class OpenAIHttpClient(OpenAIClient):
+    def __init__(
+        self,
+        router: Router,
+        client_type: str,
+        timeout_secs: int = 180,
+        max_retries: int = 1,
+        perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
+    ):
+        assert client_type in ["ctx", "gen"]
+        self._router = router
+        self._client_type = client_type
+        self._metrics_collector = ClientMetricsCollector(client_type)
+        self._session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=False),
+            timeout=aiohttp.ClientTimeout(total=timeout_secs),
+        )
+        self._max_retries = max_retries
+        self._retry_interval = 1
+
+    async def _send_request(
+        self,
+        server: str,
+        endpoint: str,
+        request: UCompletionRequest,
+        response_type: Type[UCompletionResponse],
+        hooks: Optional[ResponseHooks] = None,
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
+        if len(server) == 0:
+            server, _ = await self._router.get_next_server(request)
+        url = f"http://{server}/{endpoint}"
+        try:
+            self._metrics_collector.inc("total_requests")
+            resp_generator = self._post_with_retry(server, url, request, hooks)
+            if request.stream:
+                return resp_generator
+            else:
+                # consume the generator to get the response and return it
+                # directly when it's not streaming
+                resp_json = await anext(resp_generator)
+                response = response_type(**resp_json)
+                if hooks:
+                    if self._client_type == "ctx":
+                        hooks.on_ctx_resp(server, response)
+                    hooks.on_first_token(server, request)
+                    hooks.on_resp_done(server, request, response)
+                return response
+        except Exception:
+            self._metrics_collector.inc("error_requests")
+            await self._finish_request(request)
+            raise
+
+    async def _post_with_retry(
+        self,
+        server: str,
+        url: str,
+        request: UCompletionRequest,
+        hooks: Optional[ResponseHooks] = None,
+    ) -> CompletionResponseGenerator:
+        json_data = request.model_dump(exclude_unset=True)
+        is_stream = request.stream
+        for attempt in range(self._max_retries + 1):
+            try:
+                start_time = get_steady_clock_now_in_seconds()
+                async with self._session.post(url, json=json_data) as http_response:
+                    content_type = http_response.headers.get("Content-Type", "")
+                    if not is_stream and "text/event-stream" in content_type:
+                        raise ValueError(
+                            "Received an event-stream although request stream was False"
+                        )
+                    if is_stream:
+                        # do NOT return the generator directly here or the
+                        # response will go out of scope and get destroyed
+                        async for line in self._response_generator(
+                            request, http_response, start_time, hooks, server
+                        ):
+                            yield line
+                    else:
+                        http_response.raise_for_status()
+                        response_dict = await http_response.json()
+                        # yield here since async generators do not allow
+                        # return statements with a value
+                        yield response_dict
+                return  # success; do not fall through to another attempt
+            except (aiohttp.ClientError, OSError) as e:
+                if attempt == self._max_retries:
+                    raise
+                import traceback
+
+                logger.error(
+                    f"Client error: {e} - retry {attempt} of {self._max_retries}",
+                    traceback.format_exc(),
+                )
+                await asyncio.sleep(self._retry_interval)
+                self._metrics_collector.inc("retry_requests")
+            except Exception as e:
+                logger.error(f"Error encountered while processing request to {url}: {e}")
+                raise
+
+    async def _response_generator(
+        self,
+        request: UCompletionRequest,
+        http_response: aiohttp.ClientResponse,
+        start_time: float,
+        hooks: Optional[ResponseHooks] = None,
+        server: str = "",
+    ) -> CompletionResponseGenerator:
+        assert request.stream, "Request is not streaming"
+        assert "text/event-stream" in http_response.headers.get("Content-Type", ""), (
+            "Response is not streaming"
+        )
+        try:
+            last_token_time = start_time
+            i = 0
+            async for line in http_response.content.iter_any():
+                now_time = get_steady_clock_now_in_seconds()
+                if i == 0:
+                    if hooks:
+                        hooks.on_first_token(server, request)
+                    self._metrics_collector.observe(
+                        "first_token_latency_seconds", now_time - last_token_time
+                    )
+                else:
+                    self._metrics_collector.observe(
+                        "per_token_latency_seconds", now_time - last_token_time
+                    )
+                i += 1
+                if line:
+                    yield line
+                    await asyncio.sleep(0)
+                last_token_time = now_time
+
+            if hooks:
+                hooks.on_resp_done(server, request, None)
+            self._metrics_collector.inc("completed_requests")
+            self._metrics_collector.observe(
+                "complete_latency_seconds",
+                get_steady_clock_now_in_seconds() - start_time,
+            )
+        except aiohttp.ClientError as e:
+            # a client error is expected when the response stream is done if
+            # the connector has force_close=True
+            logger.error(f"{self._client_type} Client error: {e}")
+            self._metrics_collector.inc("error_requests")
+            raise
+        except Exception:
+            self._metrics_collector.inc("error_requests")
+            raise
+        finally:
+            await self._finish_request(request)
+
+    async def _finish_request(self, request: UCompletionRequest) -> None:
+        await self._router.finish_request(request)
+
+    async def collect_metrics(self) -> Dict[str, Any]:
+        metrics = {}
+        for server in self._router.servers:
+            try:
+                async with self._session.get(f"http://{server}/metrics") as response:
+                    metrics[server] = await response.json()
+            except Exception:
+                continue
+        return metrics
+
+    async def shutdown(self) -> None:
+        await self._session.close()
+
+    async def check_ready(self) -> Tuple[List[str], List[str]]:
+        async def check_server_ready(server: str) -> bool:
+            try:
+                async with self._session.get(f"http://{server}/health") as response:
+                    return response.status == 200
+            except Exception:
+                return False
+
+        servers_ready = await asyncio.gather(
+            *[check_server_ready(server) for server in self._router.servers]
+        )
+        return [server for server, ready in zip(self._router.servers, servers_ready) if ready], [
+            server for server, ready in zip(self._router.servers, servers_ready) if not ready
+        ]
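A hedged usage sketch of the non-streaming path through OpenAIHttpClient: passing an empty server string delegates server selection to the router, as in _send_request above. The Router construction and the exact CompletionRequest fields are assumptions not shown in this commit.

# Sketch only; assumes `router` is a Router over gen worker host:port
# addresses, built elsewhere (e.g. from get_ctx_gen_server_addrs output).
import asyncio

from tensorrt_llm.serve.openai_protocol import (CompletionRequest,
                                                CompletionResponse)


async def demo(router) -> None:
    client = OpenAIHttpClient(router, client_type="gen")
    try:
        request = CompletionRequest(model="llama", prompt="Hello",
                                    stream=False)
        # Empty server string -> the client asks the router for a server.
        response = await client.send_request("", request)
        assert isinstance(response, CompletionResponse)
        print(response.choices[0].text)
    finally:
        await client.shutdown()  # closes the shared aiohttp session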
