Skip to content

Commit 579e250

Browse files
committed
Add a unit test for OpenAIHttpClient
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 850f245 commit 579e250

File tree

3 files changed

+275
-1
lines changed

3 files changed

+275
-1
lines changed

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
server_start_timeout_secs: int = 180,
6060
perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
6161
disagg_cluster_storage: Optional[ClusterStorage] = None,
62+
health_check_interval_secs: int = 3,
6263
):
6364
self._config = config
6465
self._ctx_router = ctx_router
@@ -70,6 +71,7 @@ def __init__(
7071
self._server_start_timeout_secs = server_start_timeout_secs
7172
self._perf_metrics_collector = perf_metrics_collector
7273
self._cluster_storage = disagg_cluster_storage
74+
self._health_check_interval_secs = health_check_interval_secs
7375

7476
self._ctx_client = None
7577
self._gen_client = None
@@ -250,7 +252,7 @@ async def teardown(self) -> None:
250252
async def _wait_for_all_servers_ready(self) -> None:
251253
async def check_servers_ready():
252254
elapsed_time = 0
253-
interval = 3
255+
interval = self._health_check_interval_secs
254256
while elapsed_time < self._server_start_timeout_secs:
255257
_, unready_ctx_servers = await self._ctx_client.check_ready()
256258
_, unready_gen_servers = await self._gen_client.check_ready()

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ l0_a10:
2323
# test list either).
2424
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
2525
- unittest/others/test_time_breakdown.py
26+
- unittest/disaggregated/test_disagg_openai_client.py
2627
- unittest/disaggregated/test_disagg_utils.py
2728
- unittest/disaggregated/test_router.py
2829
- unittest/disaggregated/test_remoteDictionary.py
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import AsyncMock, Mock, patch
16+
17+
import aiohttp
18+
import pytest
19+
20+
from tensorrt_llm.llmapi.disagg_utils import ServerRole
21+
from tensorrt_llm.serve.openai_client import OpenAIHttpClient
22+
from tensorrt_llm.serve.openai_protocol import (
23+
CompletionRequest,
24+
CompletionResponse,
25+
CompletionResponseChoice,
26+
DisaggregatedParams,
27+
UsageInfo,
28+
)
29+
from tensorrt_llm.serve.router import Router
30+
31+
32+
@pytest.fixture
def mock_router():
    """Provide a Router mock exposing two workers and async routing stubs."""
    fake_router = AsyncMock(spec=Router)
    fake_router.servers = ["localhost:8000", "localhost:8001"]
    # get_next_server yields (server, extra); tests only consume the server.
    fake_router.get_next_server = AsyncMock(return_value=("localhost:8000", None))
    fake_router.finish_request = AsyncMock()
    return fake_router
40+
41+
42+
@pytest.fixture
def mock_session():
    """Provide an async-mocked aiohttp.ClientSession."""
    session = AsyncMock(spec=aiohttp.ClientSession)
    return session
46+
47+
48+
@pytest.fixture
def openai_client(mock_router, mock_session):
    """Create an OpenAIHttpClient wired to the mocked router and session.

    The prometheus default registry is process-global, so metrics collectors
    registered by a previously constructed client must be removed first or
    instantiating another client raises a duplicated-metric error.
    """
    from prometheus_client.registry import REGISTRY

    # Unregister collectors through the public API instead of overwriting the
    # registry's private dictionaries, which bypasses unregister bookkeeping
    # and can leave the registry internally inconsistent.
    for collector in list(REGISTRY._collector_to_names):
        REGISTRY.unregister(collector)
    return OpenAIHttpClient(
        router=mock_router,
        role=ServerRole.CONTEXT,
        timeout_secs=180,
        max_retries=2,
        retry_interval_sec=1,
        session=mock_session,
    )
64+
65+
66+
@pytest.fixture
def completion_request():
    """Build a non-streaming CompletionRequest carrying generation-only disagg params."""
    disagg_params = DisaggregatedParams(
        request_type="generation_only", first_gen_tokens=[123], ctx_request_id=123
    )
    return CompletionRequest(
        model="test-model",
        prompt="Hello, world!",
        stream=False,
        disaggregated_params=disagg_params,
    )
77+
78+
79+
@pytest.fixture
def streaming_completion_request():
    """Build a streaming CompletionRequest carrying generation-only disagg params."""
    disagg_params = DisaggregatedParams(
        request_type="generation_only", first_gen_tokens=[456], ctx_request_id=456
    )
    return CompletionRequest(
        model="test-model",
        prompt="Hello, world!",
        stream=True,
        disaggregated_params=disagg_params,
    )
90+
91+
92+
class TestOpenAIHttpClient:
    """Test OpenAIHttpClient main functionality."""

    def dummy_response(self):
        # Minimal valid CompletionResponse used as the canned JSON payload
        # returned by the mocked HTTP layer in the tests below.
        return CompletionResponse(
            id="test-123",
            object="text_completion",
            created=1234567890,
            model="test-model",
            usage=UsageInfo(prompt_tokens=10, completion_tokens=10),
            choices=[CompletionResponseChoice(index=0, text="Hello!")],
        )

    def test_initialization(self, mock_router, mock_session):
        """Test client initialization."""
        client = OpenAIHttpClient(
            router=mock_router,
            role=ServerRole.GENERATION,
            timeout_secs=300,
            max_retries=5,
            session=mock_session,
        )
        # Constructor arguments should be stored unchanged on the instance.
        assert client._router == mock_router
        assert client._role == ServerRole.GENERATION
        assert client._session == mock_session
        assert client._max_retries == 5

    @pytest.mark.asyncio
    async def test_non_streaming_completion_request(
        self, openai_client, completion_request, mock_session, mock_router
    ):
        """Test non-streaming completion request end-to-end."""
        mock_response = self.dummy_response()

        # Mock HTTP response
        mock_http_response = AsyncMock()
        mock_http_response.status = 200
        mock_http_response.headers = {"Content-Type": "application/json"}
        mock_http_response.json = AsyncMock(return_value=mock_response.model_dump())
        # raise_for_status is synchronous on aiohttp responses, so a plain
        # Mock (not AsyncMock) is used here.
        mock_http_response.raise_for_status = Mock()
        # Stub the async context-manager protocol used by `async with session.post(...)`.
        mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response)
        mock_http_response.__aexit__ = AsyncMock()

        mock_session.post.return_value = mock_http_response

        # Send request
        response = await openai_client.send_request(completion_request)

        # Assertions
        assert isinstance(response, CompletionResponse)
        assert response.model == "test-model"
        mock_session.post.assert_called_once()
        # The client must release the request back to the router exactly once.
        mock_router.finish_request.assert_called_once_with(completion_request)

    @pytest.mark.asyncio
    async def test_streaming_completion_request(
        self, openai_client, streaming_completion_request, mock_session, mock_router
    ):
        """Test streaming completion request end-to-end."""
        # Mock HTTP streaming response
        mock_http_response = AsyncMock()
        mock_http_response.status = 200
        mock_http_response.headers = {"Content-Type": "text/event-stream"}

        # Raw SSE-framed chunks the mocked body will emit.
        dummy_data = [
            b'data: "Hello"\n\n',
            b'data: "world"\n\n',
            b'data: "!"\n\n',
        ]

        # Async generator standing in for aiohttp's content.iter_any().
        async def mock_iter_any():
            for data in dummy_data:
                yield data

        mock_http_response.content = AsyncMock()
        mock_http_response.content.iter_any = mock_iter_any
        mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response)
        mock_http_response.__aexit__ = AsyncMock()

        mock_session.post.return_value = mock_http_response

        # Send streaming request
        response_generator = await openai_client.send_request(streaming_completion_request)

        # Consume the generator
        chunks = []
        async for chunk in response_generator:
            chunks.append(chunk)

        # Assertions: chunks must pass through unmodified and in order.
        assert len(chunks) == 3
        for i, chunk in enumerate(chunks):
            assert chunk == dummy_data[i]
        mock_session.post.assert_called_once()
        mock_router.finish_request.assert_called_once_with(streaming_completion_request)

    @pytest.mark.asyncio
    async def test_request_with_custom_server(
        self, openai_client, completion_request, mock_session, mock_router
    ):
        """Test sending request to a specific server."""
        custom_server = "localhost:9000"
        mock_response = self.dummy_response()

        mock_http_response = AsyncMock()
        mock_http_response.headers = {"Content-Type": "application/json"}
        mock_http_response.json = AsyncMock(return_value=mock_response.model_dump())
        mock_http_response.raise_for_status = Mock()
        mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response)
        mock_http_response.__aexit__ = AsyncMock()

        mock_session.post.return_value = mock_http_response

        await openai_client.send_request(completion_request, server=custom_server)

        # Verify custom server was used in URL
        call_args = mock_session.post.call_args[0][0]
        assert custom_server in call_args
        # Router should not be called when server is specified
        mock_router.get_next_server.assert_not_called()

    @pytest.mark.asyncio
    async def test_request_error_handling(
        self, openai_client, completion_request, mock_session, mock_router
    ):
        """Test error handling when request fails."""
        # Every post attempt raises, so all retries are exhausted.
        mock_session.post.side_effect = aiohttp.ClientError("Connection failed")

        with pytest.raises(aiohttp.ClientError):
            await openai_client.send_request(completion_request)

        # Should finish request on error
        mock_router.finish_request.assert_called_once_with(completion_request)

    @pytest.mark.asyncio
    async def test_request_with_retry(
        self, openai_client, completion_request, mock_session, mock_router
    ):
        """Test retry mechanism on transient failures."""
        mock_response = self.dummy_response()

        mock_http_response = AsyncMock()
        mock_http_response.headers = {"Content-Type": "application/json"}
        mock_http_response.json = AsyncMock(return_value=mock_response.model_dump())
        mock_http_response.raise_for_status = Mock()
        mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response)
        mock_http_response.__aexit__ = AsyncMock()

        # First attempt fails, second succeeds
        mock_session.post.side_effect = [
            aiohttp.ClientError("Temporary failure"),
            mock_http_response,
        ]

        # Patch asyncio.sleep so the retry back-off does not slow the test.
        with patch("asyncio.sleep", new_callable=AsyncMock):
            response = await openai_client.send_request(completion_request)

        assert isinstance(response, CompletionResponse)
        assert mock_session.post.call_count == 2  # Initial + 1 retry

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(
        self, openai_client, completion_request, mock_session, mock_router
    ):
        """Test that request fails after max retries."""
        mock_session.post.side_effect = aiohttp.ClientError("Connection failed")

        with patch("asyncio.sleep", new_callable=AsyncMock):
            with pytest.raises(aiohttp.ClientError):
                await openai_client.send_request(completion_request)

        # Should try max_retries + 1 times
        assert mock_session.post.call_count == openai_client._max_retries + 1
        mock_router.finish_request.assert_called_once()

    @pytest.mark.asyncio
    async def test_invalid_request_type(self, openai_client):
        """Test handling of invalid request type."""
        # A bare string is neither a CompletionRequest nor a ChatCompletionRequest,
        # so the client is expected to reject it before any HTTP call.
        with pytest.raises(ValueError, match="Invalid request type"):
            await openai_client.send_request("invalid_request")

0 commit comments

Comments
 (0)