Commit 51f3bdc

Share aiohttp.ClientSessions per worker
Slightly refactor `openAIModelServerClient` to add a new method, `process_request_with_session`, which accepts a `ReusableHTTPClientSession` per request and thereby lets the caller reuse one HTTP client session per worker. The previous method, `process_request`, now creates a fresh HTTP client session and calls `process_request_with_session`, preserving the previous behavior.

Prior to this commit, a new `aiohttp.ClientSession` was created for each request. Not only is this inefficient and throughput-lowering; in certain environments it also leads to inotify watch issues:

    aiodns - WARNING - Failed to create DNS resolver channel with automatic
    monitoring of resolver configuration changes. This usually means the
    system ran out of inotify watches. Falling back to socket state callback.
    Consider increasing the system inotify watch limit: Failed to initialize
    c-ares channel

Indeed, because a new DNS resolver is created for each new `ClientSession`, creating many short-lived `ClientSession`s eventually exhausts the available inotify watches. Sharing `ClientSession`s resolves this.

Relevant links:
- https://docs.aiohttp.org/en/stable/http_request_lifecycle.html
- https://stackoverflow.com/questions/62707369/one-aiohttp-clientsession-per-thread
- home-assistant/core#144457 (comment)

Relevant PR: kubernetes-sigs#247 (doesn't address the issue of worker sharing).
1 parent dab80ce commit 51f3bdc
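
For illustration only (not code from this commit), a minimal sketch of the difference described above; the URL parameter is a placeholder:

import aiohttp

# Before: a new ClientSession per request. With aiodns, each session gets its
# own DNS resolver channel, which can consume inotify watches.
async def fetch_per_request(url: str) -> int:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return response.status

# After: one ClientSession per worker, shared across its requests, so
# connections and the DNS resolver are reused.
async def fetch_shared(session: aiohttp.ClientSession, url: str) -> int:
    async with session.get(url) as response:
        return response.status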

File tree

4 files changed: +77 -14 lines changed

- inference_perf/client/modelserver/__init__.py
- inference_perf/client/modelserver/base.py
- inference_perf/client/modelserver/openai_client.py
- inference_perf/loadgen/load_generator.py

inference_perf/client/modelserver/__init__.py

Lines changed: 8 additions & 2 deletions

@@ -11,10 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import ModelServerClient
+from .base import ModelServerClient, ReusableHTTPClientSession
 from .mock_client import MockModelServerClient
 from .vllm_client import vLLMModelServerClient
 from .sglang_client import SGlangModelServerClient
 
 
-__all__ = ["ModelServerClient", "MockModelServerClient", "vLLMModelServerClient", "SGlangModelServerClient"]
+__all__ = [
+    "ModelServerClient",
+    "ReusableHTTPClientSession",
+    "MockModelServerClient",
+    "vLLMModelServerClient",
+    "SGlangModelServerClient",
+]

inference_perf/client/modelserver/base.py

Lines changed: 31 additions & 3 deletions

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any
 from inference_perf.client.metricsclient.base import MetricsMetadata
 from inference_perf.config import APIConfig, APIType
-
 from inference_perf.apis import InferenceAPIData
+import aiohttp
 
 
 class ModelServerPrometheusMetric:
@@ -87,10 +87,38 @@ def get_supported_apis(self) -> List[APIType]:
         raise NotImplementedError
 
     @abstractmethod
-    async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+    async def process_request(
+        self, data: InferenceAPIData, stage_id: int, scheduled_time: float, *args: Any, **kwargs: Any
+    ) -> None:
         raise NotImplementedError
 
     @abstractmethod
     def get_prometheus_metric_metadata(self) -> PrometheusMetricMetadata:
         # assumption: all metrics clients have metrics exported in Prometheus format
         raise NotImplementedError
+
+
+class ReusableHTTPClientSession:
+    """
+    A wrapper for aiohttp.ClientSession to allow for reusable sessions.
+    This is useful for sharing among many HTTP clients.
+    """
+
+    def __init__(self, session: aiohttp.ClientSession, dont_close: bool = False) -> None:
+        self.session = session
+        self.dont_close = dont_close
+
+    def make_dont_close(self) -> "ReusableHTTPClientSession":
+        return ReusableHTTPClientSession(session=self.session, dont_close=True)
+
+    async def close(self) -> None:
+        if self.dont_close:
+            self.dont_close = False
+            return
+        await self.session.close()
+
+    async def __aenter__(self) -> None:
+        pass
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
+        await self.close()
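
To make the wrapper's semantics concrete, here is a small usage sketch (the endpoint URL is a placeholder): `make_dont_close()` returns a view whose `close()` is a one-shot no-op, so an inner `async with` can use the usual session-per-block idiom without tearing down the shared session.

import asyncio
import aiohttp

from inference_perf.client.modelserver.base import ReusableHTTPClientSession


async def main() -> None:
    shared = ReusableHTTPClientSession(aiohttp.ClientSession())

    for _ in range(2):
        # Exiting this block awaits close() on the dont_close view, which
        # just resets the flag; the underlying session stays open.
        async with shared.make_dont_close():
            async with shared.session.get("http://localhost:8000/health") as response:
                print(response.status)

    # Only this close() actually closes the underlying aiohttp session.
    await shared.close()


asyncio.run(main())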

inference_perf/client/modelserver/openai_client.py

Lines changed: 24 additions & 7 deletions

@@ -17,7 +17,7 @@
 from inference_perf.config import APIConfig, APIType, CustomTokenizerConfig
 from inference_perf.apis import InferenceAPIData, InferenceInfo, RequestLifecycleMetric, ErrorResponseInfo
 from inference_perf.utils import CustomTokenizer
-from .base import ModelServerClient, PrometheusMetricMetadata
+from .base import ModelServerClient, PrometheusMetricMetadata, ReusableHTTPClientSession
 from typing import List, Optional
 import aiohttp
 import asyncio
@@ -30,6 +30,8 @@
 
 
 class openAIModelServerClient(ModelServerClient):
+    _session: aiohttp.ClientSession | None = None
+
     def __init__(
         self,
         metrics_collector: RequestDataCollector,
@@ -71,6 +73,17 @@ def __init__(
             self.tokenizer = CustomTokenizer(tokenizer_config)
 
     async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled_time: float) -> None:
+        session = self.new_reusable_session()
+        async with session:
+            await self.process_request_with_session(data, stage_id, scheduled_time, session)
+
+    async def process_request_with_session(
+        self,
+        data: InferenceAPIData,
+        stage_id: int,
+        scheduled_time: float,
+        session: ReusableHTTPClientSession,
+    ) -> None:
         payload = data.to_payload(
             model_name=self.model_name,
             max_tokens=self.max_completion_tokens,
@@ -87,14 +100,10 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
 
         request_data = json.dumps(payload)
 
-        timeout = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel
-
-        async with aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=self.max_tcp_connections), timeout=timeout
-        ) as session:
+        async with session.make_dont_close():
             start = time.perf_counter()
             try:
-                async with session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
+                async with session.session.post(self.uri + data.get_route(), headers=headers, data=request_data) as response:
                     response_info = await data.process_response(
                         response=response, config=self.api_config, tokenizer=self.tokenizer
                     )
@@ -138,6 +147,14 @@ async def process_request(self, data: InferenceAPIData, stage_id: int, scheduled
                 )
             )
 
+    def new_reusable_session(self) -> ReusableHTTPClientSession:
+        return ReusableHTTPClientSession(
+            aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.helpers.sentinel,
+                connector=aiohttp.TCPConnector(limit=self.max_tcp_connections),
+            )
+        )
+
     def get_supported_apis(self) -> List[APIType]:
         return []
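
The intended calling pattern, sketched with hypothetical `client` (an `openAIModelServerClient`) and `requests` values:

# One session serves every request; close it once at the end.
session = client.new_reusable_session()
try:
    for i, request in enumerate(requests):
        await client.process_request_with_session(request, stage_id=0, scheduled_time=float(i), session=session)
finally:
    await session.close()

Callers that do not manage a session keep the old behavior: `process_request()` opens a fresh session, delegates to `process_request_with_session()`, and closes it on exit.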

inference_perf/loadgen/load_generator.py

Lines changed: 14 additions & 2 deletions

@@ -17,7 +17,8 @@
 from .load_timer import LoadTimer, ConstantLoadTimer, PoissonLoadTimer, TraceReplayLoadTimer
 from inference_perf.datagen import DataGenerator
 from inference_perf.apis import InferenceAPIData
-from inference_perf.client.modelserver import ModelServerClient
+from inference_perf.client.modelserver import ModelServerClient, ReusableHTTPClientSession
+from inference_perf.client.modelserver.openai_client import openAIModelServerClient
 from inference_perf.circuit_breaker import get_circuit_breaker
 from inference_perf.config import LoadConfig, LoadStage, LoadType, StageGenType, TraceFormat
 from asyncio import (
@@ -83,6 +84,10 @@ async def loop(self) -> None:
         item = None
         timeout = 0.5
 
+        session: ReusableHTTPClientSession | None = None
+        if isinstance(self.client, openAIModelServerClient):
+            session = self.client.new_reusable_session()
+
         while not self.stop_signal.is_set():
             while self.request_phase.is_set() and not self.cancel_signal.is_set():
                 await semaphore.acquire()
@@ -120,7 +125,12 @@ async def schedule_client(
                     with self.active_requests_counter.get_lock():
                         self.active_requests_counter.value += 1
                     inflight = True
-                    await self.client.process_request(request_data, stage_id, request_time)
+
+                    if isinstance(self.client, openAIModelServerClient):
+                        assert session
+                        await self.client.process_request_with_session(request_data, stage_id, request_time, session)
+                    else:
+                        await self.client.process_request(request_data, stage_id, request_time)
                 except CancelledError:
                     pass
                 finally:
@@ -151,6 +161,8 @@ async def schedule_client(
             logger.debug(f"[Worker {self.id}] waiting for next phase")
             self.request_phase.wait()
 
+        if session:
+            await session.close()
         logger.debug(f"[Worker {self.id}] stopped")
 
     def run(self) -> None:
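
Distilled shape of the per-worker lifecycle above (a hypothetical standalone helper, not project code): create one session before the loop, reuse it for every request, and close it exactly once after the stop signal.

async def worker_loop(client, queue, stop_signal):  # types elided for brevity
    session = None
    if isinstance(client, openAIModelServerClient):
        session = client.new_reusable_session()
    while not stop_signal.is_set():
        data, stage_id, scheduled_time = await queue.get()
        if session is not None:
            await client.process_request_with_session(data, stage_id, scheduled_time, session)
        else:
            await client.process_request(data, stage_id, scheduled_time)
    if session:
        await session.close()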
