Commit 104984f

working on a stream request hang issue
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 483975c · commit 104984f

9 files changed: +517 −279 lines changed

docker/Makefile

Lines changed: 2 additions & 1 deletion
@@ -139,7 +139,7 @@ CCACHE_DIR ?= $(CODE_DIR)/cpp/.ccache
 CONAN_DIR ?= $(CODE_DIR)/cpp/.conan
 USER_CACHE_DIR ?= $(shell readlink -f "${HOME_DIR}/.cache")
 RUN_CMD ?=
-CONTAINER_NAME ?= tensorrt_llm_bug_wksp
+CONTAINER_NAME ?= tensorrt_llm
 WORK_DIR ?= $(CODE_DIR)
 DOCKER_PULL ?= 0

@@ -157,6 +157,7 @@ endif
 	$(GPU_OPTS) \
 	--volume $(SOURCE_DIR):$(CODE_DIR) \
 	$(EXTRA_VOLUMES) \
+	$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
 	--env "CCACHE_DIR=$(CCACHE_DIR)" \
 	--env "CCACHE_BASEDIR=$(CODE_DIR)" \
 	--env "CONAN_HOME=$(CONAN_DIR)" \

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 33 additions & 2 deletions
@@ -4,7 +4,7 @@
 import random
 import time
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
 from tensorrt_llm.logger import logger

@@ -44,6 +44,7 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
         self._current_ctx_workers = {}  # worker_id -> WorkerInfo
         self._current_gen_workers = {}  # worker_id -> WorkerInfo
         self._watch_handle = None
+        self._watch_task = None

     def __del__(self):
         try:

@@ -92,7 +93,14 @@ def current_gen_worker_num(self) -> int:
     def worker_key_prefix(self) -> str:
         return get_worker_key_prefix(self._config.cluster_name)

-    async def watch_workers(self, get_existing_first: bool = True):
+    async def watch_workers(
+            self,
+            get_existing_first: bool = True,
+            on_event: Optional[Callable[[WorkerInfo, WatchEventType],
+                                        None]] = None):
+        if self._watch_handle:
+            logger.error("Watch handle is already initialized")
+            return []
         workers = []
         if get_existing_first:
             # There is a tiny gap between getting existing workers and watching the key,

@@ -106,12 +114,35 @@ async def watch_workers(self, get_existing_first: bool = True):
                 workers.append(self._parse_worker_info(event))
         self._watch_handle = await self._cluster_storage.watch(
             self.worker_key_prefix)
+
+        async def on_event_wrapper():
+            logger.warning(
+                f"Initializing watch task with {len(workers)} existing workers")
+            for worker_info in workers:
+                await on_event(worker_info, WatchEventType.SET)
+            logger.warning("Start watching worker events")
+            while True:
+                try:
+                    worker_events = await self._watch_handle.drain()
+                    for event in worker_events:
+                        worker_info = self._parse_worker_info(event)
+                        await on_event(worker_info, event.event_type)
+                except Exception as e:
+                    logger.error(
+                        f"Error updating routers by worker events: {e}")
+                    await asyncio.sleep(1)
+
+        if on_event:
+            self._watch_task = asyncio.create_task(on_event_wrapper())
         return workers

     async def unwatch_workers(self) -> None:
         if self._watch_handle:
             await self._cluster_storage.unwatch(self.worker_key_prefix)
             self._watch_handle = None
+        if self._watch_task:
+            self._watch_task.cancel()
+            self._watch_task = None

     async def get_worker_events(
             self) -> List[Tuple[WorkerInfo, WatchEventType]]:
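
Note: a minimal sketch of how a caller might consume the new on_event hook of watch_workers(). The watcher object and the callback below are illustrative placeholders, not part of this commit; only the watch_workers(get_existing_first=..., on_event=...) and unwatch_workers() calls come from the diff above.

import asyncio


async def on_worker_event(worker_info, event_type) -> None:
    # Called once per existing worker (replayed as a SET event) and then for
    # every live worker event drained by the background watch task.
    print(f"worker event: {event_type} -> {worker_info}")


async def run(watcher) -> None:
    # `watcher` stands in for the object in disagg_auto_scaling.py that owns
    # watch_workers(); existing workers are returned and also replayed through
    # on_event before live watching starts.
    existing = await watcher.watch_workers(get_existing_first=True,
                                           on_event=on_worker_event)
    print(f"{len(existing)} workers already registered")
    try:
        await asyncio.sleep(60)  # let the watch task process events
    finally:
        # unwatch_workers() now also cancels the background watch task.
        await watcher.unwatch_workers()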

tensorrt_llm/serve/openai_client.py

Lines changed: 112 additions & 66 deletions
@@ -1,7 +1,7 @@
 # yapf: disable
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import aiohttp

@@ -12,15 +12,15 @@
                                                  CompletionResponse,
                                                  UCompletionRequest,
                                                  UCompletionResponse)
-from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
-from tensorrt_llm.serve.responses_utils import (ResponseHooks,
+from tensorrt_llm.serve.perf_metrics import (ClientMetricsCollector,
+                                             DisaggPerfMetricsCollector)
+from tensorrt_llm.serve.responses_utils import (CompletionResponseGenerator,
+                                                ResponseHooks,
                                                 get_steady_clock_now_in_seconds)
 from tensorrt_llm.serve.router import Router

 # yapf: enable

-CompletionResponseGenerator = AsyncGenerator[bytes, None]
-

 class OpenAIClient(ABC):

@@ -29,7 +29,7 @@ async def send_request(
         server: str,
         request: UCompletionRequest,
         hooks: Optional[ResponseHooks] = None
-    ) -> Tuple[UCompletionResponse, AsyncGenerator[bytes, None]]:
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
         if isinstance(request, CompletionRequest):
             return await self._send_request(server, "v1/completions", request,
                                             CompletionResponse, hooks)

@@ -48,7 +48,7 @@ async def _send_request(
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
         hooks: Optional[ResponseHooks] = None,
-    ) -> Tuple[UCompletionResponse, AsyncGenerator[bytes, None]]:
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
         """
         Send a request to the server and return the response and the body iterator.
         The request is finished (in routers) when the generator is exhausted or there is an error.

@@ -83,16 +83,19 @@ def __init__(self,
                  router: Router,
                  client_type: str,
                  timeout_secs: int = 180,
+                 max_retries: int = 1,
                  perf_metrics_collector: DisaggPerfMetricsCollector = None):
         assert client_type in ["ctx", "gen"]
         self._router = router
         self._client_type = client_type
-        self._perf_metrics_collector = perf_metrics_collector
+        self._metrics_collector = ClientMetricsCollector(client_type)
         self._session = aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(limit=0,
                                            limit_per_host=0,
                                            force_close=False),
             timeout=aiohttp.ClientTimeout(total=timeout_secs))
+        self._max_retries = max_retries
+        self._retry_interval = 1

     async def _send_request(
         self,

@@ -101,79 +104,122 @@ async def _send_request(
         request: UCompletionRequest,
         response_type: Type[UCompletionResponse],
         hooks: Optional[ResponseHooks] = None,
-    ) -> Tuple[UCompletionResponse, CompletionResponseGenerator]:
+    ) -> Union[UCompletionResponse, CompletionResponseGenerator]:
         if len(server) == 0:
             server, _ = await self._router.get_next_server(request)
         url = f"http://{server}/{endpoint}"
         try:
-            start_time = get_steady_clock_now_in_seconds()
-            self._perf_metrics_collector.inc(
-                f"{self._client_type}_total_requests")
-            async with self._session.post(
-                    url, json=request.model_dump(
-                        exclude_unset=True)) as http_response:
-                content_type = http_response.headers.get("Content-Type", "")
-                if not request.stream and "text/event-stream" in content_type:
-                    raise ValueError(
-                        "Received an event-stream although request stream was False"
-                    )
-
-                response_dict = await http_response.json()
-                if not http_response.ok:
-                    logger.error(f"Received failed response {response_dict}")
-                    http_response.raise_for_status()
-                response = response_type(**response_dict)
-
-                return response, self._response_generator(
-                    request, http_response, response, start_time, hooks)
+            self._metrics_collector.inc("total_requests")
+            resp_generator = self._post_with_retry(server, url, request, hooks)
+            if request.stream:
+                return resp_generator
+            else:
+                # consume the generator to get the response and return it directly when it's not streaming
+                resp_json = await anext(resp_generator)
+                response = response_type(**resp_json)
+                if hooks:
+                    if self._client_type == "ctx":
+                        hooks.on_ctx_resp(server, response)
+                    hooks.on_first_token(server, request)
+                    hooks.on_resp_done(server, request, response)
+                return response
         except Exception:
-            self._perf_metrics_collector.inc(
-                f"{self._client_type}_error_requests")
+            self._metrics_collector.inc("error_requests")
             await self._finish_request(request)
             raise

+    async def _post_with_retry(
+            self,
+            server: str,
+            url: str,
+            request: UCompletionRequest,
+            hooks: Optional[ResponseHooks] = None
+    ) -> Tuple[aiohttp.ClientResponse, Dict[str, Any]]:
+        json_data = request.model_dump(exclude_unset=True)
+        is_stream = request.stream
+        for attempt in range(self._max_retries + 1):
+            try:
+                start_time = get_steady_clock_now_in_seconds()
+                async with self._session.post(url,
+                                              json=json_data) as http_response:
+                    content_type = http_response.headers.get("Content-Type", "")
+                    if not is_stream and "text/event-stream" in content_type:
+                        raise ValueError(
+                            "Received an event-stream although request stream was False"
+                        )
+                    if is_stream:
+                        # do NOT return generator directly here or the response will go out of scope and get destroyed
+                        async for line in self._response_generator(
+                                request, http_response, start_time, hooks,
+                                server):
+                            yield line
+                    else:
+                        http_response.raise_for_status()
+                        response_dict = await http_response.json()
+                        # do yield here until python allows return statements in async generators
+                        yield response_dict
+            except (aiohttp.ClientError, OSError) as e:
+                if attempt == self._max_retries:
+                    raise
+                import traceback
+                logger.error(
+                    f"Client error: {e} - retry {attempt} of {self._max_retries}",
+                    traceback.format_exc())
+                await asyncio.sleep(self._retry_interval)
+                self._metrics_collector.inc("retry_requests")
+            except Exception as e:
+                logger.error(
+                    f"Error encountered while processing request to {url}: {e}")
+                raise
+
     async def _response_generator(
         self,
         request: UCompletionRequest,
         http_response: aiohttp.ClientResponse,
-        response: UCompletionResponse,
         start_time: float,
-        hooks: Optional[ResponseHooks] = None
-    ) -> CompletionResponseGenerator:
+        hooks: Optional[ResponseHooks] = None,
+        server: str = "") -> CompletionResponseGenerator:
+        """
+        If the request is streaming, yield the response line by line,
+        otherwise, yield nothing because the generator won't be used and the response will be returned directly.
+        """
+        assert request.stream, "Request is not streaming"
+        assert "text/event-stream" in http_response.headers.get(
+            "Content-Type", ""), "Response is not streaming"
         try:
-            if request.stream and "text/event-stream" in http_response.headers.get(
-                    "Content-Type", ""):
-                last_token_time = start_time
-                async for i, line in enumerate(
-                        http_response.content.iter_any()):
-                    now_time = get_steady_clock_now_in_seconds()
-                    if i == 0:
-                        if hooks and hooks.on_first_token:
-                            hooks.on_first_token(request, response)
-                        self._perf_metrics_collector.observe(
-                            f"{self._client_type}_first_token_latency_seconds",
-                            now_time - last_token_time,
-                        )
-                    else:
-                        self._perf_metrics_collector.observe(
-                            f"{self._client_type}_per_token_latency_seconds",
-                            now_time - last_token_time,
-                        )
-                    if line:
-                        yield line
-                        await asyncio.sleep(0)
-                    last_token_time = now_time
-                if hooks and hooks.on_resp_done:
-                    hooks.on_resp_done(request, response)
-                self._perf_metrics_collector.inc(
-                    f"{self._client_type}_completed_requests")
-                self._perf_metrics_collector.observe(
-                    f"{self._client_type}_complete_latency_seconds",
-                    get_steady_clock_now_in_seconds() - start_time,
-                )
+            last_token_time = start_time
+            i = 0
+            async for line in http_response.content.iter_any():
+                now_time = get_steady_clock_now_in_seconds()
+                if i == 0:
+                    if hooks:
+                        hooks.on_first_token(server, request)
+                    self._metrics_collector.observe(
+                        "first_token_latency_seconds",
+                        now_time - last_token_time)
+                else:
+                    self._metrics_collector.observe("per_token_latency_seconds",
+                                                    now_time - last_token_time)
+                i += 1
+                if line:
+                    yield line
+                    await asyncio.sleep(0)
+                last_token_time = now_time
+
+            if hooks:
+                hooks.on_resp_done(server, request, None)
+            self._metrics_collector.inc("completed_requests")
+            self._metrics_collector.observe(
+                "complete_latency_seconds",
+                get_steady_clock_now_in_seconds() - start_time,
+            )
+        except aiohttp.ClientError as e:
+            # a client error is expected when the response stream is done if the connector has close=True
+            logger.error(f"{self._client_type} Client error: {e}")
+            self._metrics_collector.inc("error_requests")
+            raise
         except Exception:
-            self._perf_metrics_collector.inc(
-                f"{self._client_type}_error_requests")
+            self._metrics_collector.inc("error_requests")
             raise
         finally:
             await self._finish_request(request)
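
Note: the key change for the stream hang is that _post_with_retry yields the streamed body from inside the `async with self._session.post(...)` block instead of handing a generator back while the response object goes out of scope. A standalone sketch of that aiohttp pattern follows; the endpoint URL, payload, and function names are illustrative, not this repo's API.

import asyncio
from typing import Any, AsyncGenerator, Dict

import aiohttp


async def stream_lines(session: aiohttp.ClientSession, url: str,
                       payload: Dict[str, Any]) -> AsyncGenerator[bytes, None]:
    # Yielding from inside `async with` keeps the ClientResponse (and its
    # connection) alive for as long as the caller is consuming the generator;
    # returning the generator from outside this block would let the response
    # be released before the body is read.
    async with session.post(url, json=payload) as resp:
        resp.raise_for_status()
        async for chunk in resp.content.iter_any():
            if chunk:
                yield chunk


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # Placeholder endpoint and body; any chunked or SSE endpoint works.
        async for chunk in stream_lines(session,
                                        "http://localhost:8000/v1/completions",
                                        {"prompt": "hi", "stream": True}):
            print(chunk)


if __name__ == "__main__":
    asyncio.run(main())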
