Skip to content

Commit 483975c

Browse files
committed
tested
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 3b5d901 commit 483975c

File tree

9 files changed

+227
-315
lines changed

9 files changed

+227
-315
lines changed

docker/Makefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ CCACHE_DIR ?= $(CODE_DIR)/cpp/.ccache
139139
CONAN_DIR ?= $(CODE_DIR)/cpp/.conan
140140
USER_CACHE_DIR ?= $(shell readlink -f "${HOME_DIR}/.cache")
141141
RUN_CMD ?=
142-
CONTAINER_NAME ?= tensorrt_llm
142+
CONTAINER_NAME ?= tensorrt_llm_bug_wksp
143143
WORK_DIR ?= $(CODE_DIR)
144144
DOCKER_PULL ?= 0
145145

@@ -157,7 +157,6 @@ endif
157157
$(GPU_OPTS) \
158158
--volume $(SOURCE_DIR):$(CODE_DIR) \
159159
$(EXTRA_VOLUMES) \
160-
$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
161160
--env "CCACHE_DIR=$(CCACHE_DIR)" \
162161
--env "CCACHE_BASEDIR=$(CODE_DIR)" \
163162
--env "CONAN_HOME=$(CONAN_DIR)" \

tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,16 @@ class MetadataServerConfig():
8080
refresh_interval: float = 10.0
8181

8282

83-
def get_ctx_gen_server_urls(
83+
def get_ctx_gen_server_addrs(
8484
server_configs: list[CtxGenServerConfig]
8585
) -> tuple[list[str], list[str]]:
8686
ctx_server_urls = []
8787
gen_server_urls = []
8888
for cfg in server_configs:
8989
if cfg.type == "ctx":
90-
ctx_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
90+
ctx_server_urls.append(f"{cfg.hostname}:{cfg.port}")
9191
else:
92-
gen_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
92+
gen_server_urls.append(f"{cfg.hostname}:{cfg.port}")
9393

9494
return ctx_server_urls, gen_server_urls
9595

tensorrt_llm/serve/openai_client.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
# yapf disagrees with isort in pre-commit hooks
21
# yapf: disable
32
import asyncio
43
from abc import ABC, abstractmethod
5-
from typing import Any, AsyncGenerator, Dict, List, Tuple, Type, Union
4+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type
65

76
import aiohttp
87

@@ -14,44 +13,41 @@
1413
UCompletionRequest,
1514
UCompletionResponse)
1615
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
17-
from tensorrt_llm.serve.responses_utils import (CompletionResponseIterator,
16+
from tensorrt_llm.serve.responses_utils import (ResponseHooks,
1817
get_steady_clock_now_in_seconds)
1918
from tensorrt_llm.serve.router import Router
2019

2120
# yapf: enable
2221

22+
CompletionResponseGenerator = AsyncGenerator[bytes, None]
23+
2324

2425
class OpenAIClient(ABC):
2526

2627
async def send_request(
27-
self, server: str, request: UCompletionRequest
28+
self,
29+
server: str,
30+
request: UCompletionRequest,
31+
hooks: Optional[ResponseHooks] = None
2832
) -> Tuple[UCompletionResponse, AsyncGenerator[bytes, None]]:
2933
if isinstance(request, CompletionRequest):
30-
return await self.send_completion_request(server, request)
34+
return await self._send_request(server, "v1/completions", request,
35+
CompletionResponse, hooks)
3136
elif isinstance(request, ChatCompletionRequest):
32-
return await self.send_chat_request(server, request)
37+
return await self._send_request(server, "v1/chat/completions",
38+
request, ChatCompletionResponse,
39+
hooks)
3340
else:
3441
raise ValueError(f"Invalid request type: {type(request)}")
3542

36-
async def send_completion_request(
37-
self, server: str, request: CompletionRequest
38-
) -> Tuple[CompletionResponse, AsyncGenerator[bytes, None]]:
39-
return await self._send_request(server, "v1/completions", request,
40-
CompletionResponse)
41-
42-
async def send_chat_request(
43-
self, server: str, request: ChatCompletionRequest
44-
) -> Tuple[ChatCompletionResponse, AsyncGenerator[bytes, None]]:
45-
return await self._send_request(server, "v1/chat/completions", request,
46-
ChatCompletionResponse)
47-
4843
@abstractmethod
4944
async def _send_request(
5045
self,
5146
server: str,
5247
endpoint: str,
53-
request: Union[CompletionRequest, ChatCompletionRequest],
48+
request: UCompletionRequest,
5449
response_type: Type[UCompletionResponse],
50+
hooks: Optional[ResponseHooks] = None,
5551
) -> Tuple[UCompletionResponse, AsyncGenerator[bytes, None]]:
5652
"""
5753
Send a request to the server and return the response and the body iterator.
@@ -95,7 +91,7 @@ def __init__(self,
9591
self._session = aiohttp.ClientSession(
9692
connector=aiohttp.TCPConnector(limit=0,
9793
limit_per_host=0,
98-
force_close=True),
94+
force_close=False),
9995
timeout=aiohttp.ClientTimeout(total=timeout_secs))
10096

10197
async def _send_request(
@@ -104,7 +100,8 @@ async def _send_request(
104100
endpoint: str,
105101
request: UCompletionRequest,
106102
response_type: Type[UCompletionResponse],
107-
) -> Tuple[UCompletionResponse, CompletionResponseIterator]:
103+
hooks: Optional[ResponseHooks] = None,
104+
) -> Tuple[UCompletionResponse, CompletionResponseGenerator]:
108105
if len(server) == 0:
109106
server, _ = await self._router.get_next_server(request)
110107
url = f"http://{server}/{endpoint}"
@@ -113,36 +110,46 @@ async def _send_request(
113110
self._perf_metrics_collector.inc(
114111
f"{self._client_type}_total_requests")
115112
async with self._session.post(
116-
url,
117-
json=request.model_dump(exclude_unset=True)) as response:
118-
content_type = response.headers.get("Content-Type", "")
113+
url, json=request.model_dump(
114+
exclude_unset=True)) as http_response:
115+
content_type = http_response.headers.get("Content-Type", "")
119116
if not request.stream and "text/event-stream" in content_type:
120117
raise ValueError(
121118
"Received an event-stream although request stream was False"
122119
)
123120

124-
response_dict = await response.json()
125-
if not response.ok:
121+
response_dict = await http_response.json()
122+
if not http_response.ok:
126123
logger.error(f"Received failed response {response_dict}")
127-
response.raise_for_status()
128-
return response_type(**response_dict), self._response_generator(
129-
response, start_time)
124+
http_response.raise_for_status()
125+
response = response_type(**response_dict)
126+
127+
return response, self._response_generator(
128+
request, http_response, response, start_time, hooks)
130129
except Exception:
131130
self._perf_metrics_collector.inc(
132131
f"{self._client_type}_error_requests")
133-
self.finish_request(request)
132+
await self._finish_request(request)
134133
raise
135134

136135
async def _response_generator(
137-
self, request: UCompletionRequest, response: aiohttp.ClientResponse,
138-
start_time: float) -> AsyncGenerator[bytes, None]:
136+
self,
137+
request: UCompletionRequest,
138+
http_response: aiohttp.ClientResponse,
139+
response: UCompletionResponse,
140+
start_time: float,
141+
hooks: Optional[ResponseHooks] = None
142+
) -> CompletionResponseGenerator:
139143
try:
140-
if request.stream and "text/event-stream" in response.headers.get(
144+
if request.stream and "text/event-stream" in http_response.headers.get(
141145
"Content-Type", ""):
142146
last_token_time = start_time
143-
async for i, line in enumerate(response.content.iter_any()):
147+
async for i, line in enumerate(
148+
http_response.content.iter_any()):
144149
now_time = get_steady_clock_now_in_seconds()
145150
if i == 0:
151+
if hooks and hooks.on_first_token:
152+
hooks.on_first_token(request, response)
146153
self._perf_metrics_collector.observe(
147154
f"{self._client_type}_first_token_latency_seconds",
148155
now_time - last_token_time,
@@ -152,10 +159,12 @@ async def _response_generator(
152159
f"{self._client_type}_per_token_latency_seconds",
153160
now_time - last_token_time,
154161
)
155-
last_token_time = now_time
156162
if line:
157163
yield line
158164
await asyncio.sleep(0)
165+
last_token_time = now_time
166+
if hooks and hooks.on_resp_done:
167+
hooks.on_resp_done(request, response)
159168
self._perf_metrics_collector.inc(
160169
f"{self._client_type}_completed_requests")
161170
self._perf_metrics_collector.observe(
@@ -167,10 +176,10 @@ async def _response_generator(
167176
f"{self._client_type}_error_requests")
168177
raise
169178
finally:
170-
self.finish_request(request)
179+
await self._finish_request(request)
171180

172-
async def finish_request(self, request: UCompletionRequest) -> None:
173-
self._router.finish_request(request)
181+
async def _finish_request(self, request: UCompletionRequest) -> None:
182+
await self._router.finish_request(request)
174183

175184
async def collect_metrics(self) -> Dict[str, Any]:
176185
metrics = {}

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
#!/usr/bin/env python
2-
import asyncio
3-
import copy
4-
import itertools
5-
import os
6-
import signal
72
import traceback
8-
from collections import deque
93
from contextlib import asynccontextmanager
104
from typing import Callable, Optional, Union
115

@@ -18,19 +12,19 @@
1812

1913
# yapf: disable
2014
from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
21-
MetadataServerConfig)
15+
MetadataServerConfig,
16+
get_ctx_gen_server_addrs)
2217
from tensorrt_llm.logger import logger
2318
from tensorrt_llm.serve.metadata_server import create_metadata_server
2419
from tensorrt_llm.serve.openai_client import OpenAIClient, OpenAIHttpClient
2520
from tensorrt_llm.serve.openai_disagg_service import (
2621
OpenAIDisaggregatedService, ResponseHooks)
2722
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
2823
CompletionRequest)
24+
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
2925
from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware,
3026
get_steady_clock_now_in_seconds)
3127
from tensorrt_llm.serve.router import Router, create_router
32-
from tensorrt_llm.tensorrt_llm.serve.perf_metrics import \
33-
DisaggPerfMetricsCollector
3428
from tensorrt_llm.version import __version__ as VERSION
3529

3630
# yapf: enable
@@ -44,21 +38,22 @@ def __init__(self,
4438
server_start_timeout_secs: int = 180,
4539
metadata_server_cfg: Optional[MetadataServerConfig] = None,
4640
metrics_interval_secs: int = 0):
47-
self.config = config
48-
self.req_timeout_secs = req_timeout_secs
49-
self.server_start_timeout_secs = server_start_timeout_secs
50-
self.metadata_server_cfg = metadata_server_cfg
51-
self.metrics_interval_secs = metrics_interval_secs
52-
53-
self._ctx_router = create_router(config.ctx_router_config, config.ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
54-
self._gen_router = create_router(config.gen_router_config, config.gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
41+
self._config = config
42+
self._req_timeout_secs = req_timeout_secs
43+
self._server_start_timeout_secs = server_start_timeout_secs
44+
self._metadata_server_cfg = metadata_server_cfg
45+
self._metrics_interval_secs = metrics_interval_secs
46+
47+
self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
48+
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
49+
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
5550
self._metadata_server = create_metadata_server(metadata_server_cfg)
56-
self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests, [])
51+
self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)
5752

58-
self._service = OpenAIDisaggregatedService(config, self._ctx_router, self._gen_router, self._create_client, self.metadata_server, req_timeout_secs, server_start_timeout_secs, self._perf_metrics_collector)
53+
self._service = OpenAIDisaggregatedService(self._config, self._ctx_router, self._gen_router, self._create_client, self._metadata_server, self._req_timeout_secs, self._server_start_timeout_secs, self._perf_metrics_collector)
5954

6055
@asynccontextmanager
61-
async def lifespan() -> None:
56+
async def lifespan(app) -> None:
6257
await self._service.setup()
6358
yield
6459
await self._service.teardown()
@@ -72,10 +67,9 @@ async def validation_exception_handler(_, exc):
7267
return JSONResponse(status_code=400, content={"error": str(exc)})
7368

7469
self.register_routes()
75-
self.mount_metrics()
7670

7771
def _create_client(self, router: Router, client_type: str, perf_metrics_collector: DisaggPerfMetricsCollector) -> OpenAIClient:
78-
return OpenAIHttpClient(router, client_type, self.req_timeout_secs, perf_metrics_collector)
72+
return OpenAIHttpClient(router, client_type, self._req_timeout_secs, perf_metrics_collector)
7973

8074

8175
def register_routes(self):
@@ -90,19 +84,27 @@ def register_routes(self):
9084
self.app.mount("/prometheus/metrics", metrics_app)
9185

9286
def _wrap_entry_point(self, entry_point: Callable) -> Callable:
93-
async def wrapper(req: Union[CompletionRequest, ChatCompletionRequest], raw_request: Request) -> Response:
87+
async def wrapper(req: Union[CompletionRequest, ChatCompletionRequest], raw_req: Request) -> Response:
88+
def update_arrival_time(req: Union[CompletionRequest, ChatCompletionRequest]):
89+
raw_req.state.server_arrival_time = get_steady_clock_now_in_seconds()
90+
def update_first_token_time(req: Union[CompletionRequest, ChatCompletionRequest], response: Response):
91+
raw_req.state.server_first_token_time = get_steady_clock_now_in_seconds()
9492
try:
95-
hooks = ResponseHooks();
96-
hooks.on_req_begin = lambda req: raw_request.state.server_arrival_time = get_steady_clock_now_in_seconds()
97-
hooks.on_first_token = lambda req, response: raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
98-
response, iterator = await entry_point(req, hooks)
93+
hooks = ResponseHooks()
94+
hooks.on_req_begin = update_arrival_time
95+
hooks.on_first_token = update_first_token_time
96+
response, generator = await entry_point(req, hooks)
9997
if req.stream:
100-
return StreamingResponse(content=iterator, media_type="text/event-stream")
98+
return StreamingResponse(content=generator, media_type="text/event-stream")
10199
else:
102100
return response
103101
except Exception as e:
104102
logger.error(f"Error in entry point: {e}")
105-
return Response(status_code=500, content=f"Internal server error: {e}")
103+
print(traceback.format_exc())
104+
import sys
105+
sys.exit(1)
106+
#return Response(status_code=500, content=f"Internal server error: {e}")
107+
raise e
106108
return wrapper
107109

108110

@@ -117,11 +119,11 @@ async def cluster_info(self) -> JSONResponse:
117119
async def version(self) -> JSONResponse:
118120
return JSONResponse(content={"version": VERSION})
119121

120-
async def __call__(self, host, port):
122+
async def __call__(self, host: str, port: int):
121123
config = uvicorn.Config(self.app,
122124
host=host,
123125
port=port,
124-
log_level="info",
126+
log_level=logger.level,
125127
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
126128
await uvicorn.Server(config).serve()
127129

@@ -150,7 +152,7 @@ async def set_steady_clock_offset(server_url: str, offset: float) -> None:
150152
async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
151153
if response.status != 200:
152154
logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned")
153-
for server_url in self.ctx_servers + self.gen_servers:
155+
for server_url in self._ctx_servers + self._gen_servers:
154156
delay, offset = await query_steady_clock_offset(server_url)
155157
if delay is None or offset is None:
156158
logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment")

0 commit comments

Comments
 (0)