Commit 2543f34
Fix: JobScaler issues that cause request failures (#383)
* Integrated asyncio.Queue within JobScaler (removes JobsQueue) and takes full advantage of its blocking .get/.put methods
* Uses asyncio.Queue(maxsize) to dictate concurrency (via concurrency_modifier)
* JobScaler.set_scale() adjusts concurrency at runtime, when needed and safe to do so
* JobScaler.current_occupancy() uses the asyncio.Queue size and the JobsProgress (set) size to gate capacity
* Simpler, cleaner job-acquisition steps
* Removed legacy tracers for HTTP clients
1 parent a477453 commit 2543f34

13 files changed, +188 -780 lines changed
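The heart of the fix is letting a bounded asyncio.Queue double as the concurrency gate: maxsize mirrors the current concurrency, and free capacity is concurrency minus (queued + in progress). Below is a minimal, self-contained sketch of that pattern; fetch_jobs and the bare set are illustrative stand-ins for the SDK's real fetcher and JobsProgress.

```python
import asyncio

async def fetch_jobs(needed: int) -> list:
    # Stand-in for the SDK's job fetcher; pretend the server returned `needed` jobs.
    await asyncio.sleep(0.1)
    return [{"id": f"job-{i}"} for i in range(needed)]

async def main():
    concurrency = 2
    # maxsize is the concurrency gate: .put() blocks once the queue is full.
    queue: asyncio.Queue = asyncio.Queue(maxsize=concurrency)
    in_progress = set()  # plays the role of JobsProgress (a set)

    # Occupancy = queued + in flight; fetch only the shortfall.
    jobs_needed = concurrency - (queue.qsize() + len(in_progress))
    if jobs_needed > 0:
        for job in await fetch_jobs(jobs_needed):
            await queue.put(job)
            in_progress.add(job["id"])

    while not queue.empty():
        job = await queue.get()  # blocking .get, no busy-polling
        print("handling", job["id"])
        queue.task_done()
        in_progress.discard(job["id"])

asyncio.run(main())
```

Note that asyncio.Queue has no resize API, which is why set_scale in this commit waits for occupancy to drain to zero and then swaps in a fresh Queue with the new maxsize.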

pyproject.toml (+1)

@@ -54,6 +54,7 @@ runpod = "runpod.cli.entry:runpod_cli"
 test = [
     "asynctest",
     "nest_asyncio",
+    "faker",
     "pytest-asyncio",
     "pytest-cov",
     "pytest-timeout",

runpod/http_client.py (+1 -42)

@@ -8,7 +8,6 @@
 from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError
 
 from .cli.groups.config.functions import get_credentials
-from .tracer import create_aiohttp_tracer, create_request_tracer
 from .user_agent import USER_AGENT
 
 
@@ -37,13 +36,11 @@ def AsyncClientSession(*args, **kwargs):  # pylint: disable=invalid-name
     """
     Deprecation from aiohttp.ClientSession forbids inheritance.
     This is now a factory method
-    TODO: use httpx
     """
     return ClientSession(
         connector=TCPConnector(limit=0),
         headers=get_auth_header(),
         timeout=ClientTimeout(600, ceil_threshold=400),
-        trace_configs=[create_aiohttp_tracer()],
         *args,
         **kwargs,
     )
@@ -52,43 +49,5 @@ def AsyncClientSession(*args, **kwargs):  # pylint: disable=invalid-name
 class SyncClientSession(requests.Session):
     """
     Inherits requests.Session to override `request()` method for tracing
-    TODO: use httpx
     """
-
-    def request(self, method, url, **kwargs):  # pylint: disable=arguments-differ
-        """
-        Override for tracing. Not using super().request()
-        to capture metrics for connection and transfer times
-        """
-        with create_request_tracer() as tracer:
-            # Separate out the kwargs that are not applicable to `requests.Request`
-            request_kwargs = {
-                k: v
-                for k, v in kwargs.items()
-                # contains the names of the arguments
-                if k in requests.Request.__init__.__code__.co_varnames
-            }
-
-            # Separate out the kwargs that are applicable to `requests.Request`
-            send_kwargs = {k: v for k, v in kwargs.items() if k not in request_kwargs}
-
-            # Create a PreparedRequest object to hold the request details
-            req = requests.Request(method, url, **request_kwargs)
-            prepped = self.prepare_request(req)
-            tracer.request = prepped  # Assign the request to the tracer
-
-            # Merge environment settings
-            settings = self.merge_environment_settings(
-                prepped.url,
-                send_kwargs.get("proxies"),
-                send_kwargs.get("stream"),
-                send_kwargs.get("verify"),
-                send_kwargs.get("cert"),
-            )
-            send_kwargs.update(settings)
-
-            # Send the request
-            response = self.send(prepped, **send_kwargs)
-            tracer.response = response  # Assign the response to the tracer
-
-            return response
+    pass
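With tracing gone, AsyncClientSession is a thin factory over aiohttp.ClientSession. A usage sketch, assuming RunPod credentials are already configured; the health URL is a placeholder:

```python
import asyncio
from runpod.http_client import AsyncClientSession

async def main():
    # The factory returns a regular aiohttp.ClientSession preconfigured with
    # auth headers, TCPConnector(limit=0), and a 600 s total timeout.
    async with AsyncClientSession() as session:
        async with session.get("https://api.runpod.ai/v2/<ENDPOINT_ID>/health") as resp:
            print(resp.status)

asyncio.run(main())
```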

runpod/serverless/__init__.py (+7)

@@ -23,6 +23,13 @@
 
 log = RunPodLogger()
 
+
+def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
+    log.error(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")
+
+sys.excepthook = handle_uncaught_exception
+
+
 # ---------------------------------------------------------------------------- #
 # Run Time Arguments                                                            #
 # ---------------------------------------------------------------------------- #
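sys.excepthook fires only for exceptions that reach the top of the main thread, so the hook's job is simply to route fatal crashes through RunPodLogger instead of the default stderr traceback. A standalone sketch of the mechanism, with print standing in for the SDK logger:

```python
import sys

def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
    # Called by the interpreter for any exception that propagates to the top level.
    print(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")

sys.excepthook = handle_uncaught_exception

raise RuntimeError("boom")  # handled by the hook rather than the default traceback
```

One caveat: this hook does not cover exceptions raised in other threads or unretrieved asyncio task exceptions; those have their own hooks (threading.excepthook, loop.set_exception_handler).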

runpod/serverless/modules/rp_fastapi.py (+7 -7)

@@ -286,12 +286,12 @@ async def _realtime(self, job: Job):
         Performs model inference on the input data using the provided handler.
         If handler is not provided, returns an error message.
         """
-        await job_list.add(job.id)
+        job_list.add(job.id)
 
         # Process the job using the provided handler, passing in the job input.
         job_results = await run_job(self.config["handler"], job.__dict__)
 
-        await job_list.remove(job.id)
+        job_list.remove(job.id)
 
         # Return the results of the job processing.
         return jsonable_encoder(job_results)
@@ -304,7 +304,7 @@ async def _realtime(self, job: Job):
     async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
         """Development endpoint to simulate run behavior."""
         assigned_job_id = f"test-{uuid.uuid4()}"
-        await job_list.add({
+        job_list.add({
             "id": assigned_job_id,
             "input": job_request.input,
             "webhook": job_request.webhook
@@ -345,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
     # ---------------------------------- stream ---------------------------------- #
     async def _sim_stream(self, job_id: str) -> StreamOutput:
         """Development endpoint to simulate stream behavior."""
-        stashed_job = await job_list.get(job_id)
+        stashed_job = job_list.get(job_id)
         if stashed_job is None:
             return jsonable_encoder(
                 {"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -367,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
                 }
             )
 
-        await job_list.remove(job.id)
+        job_list.remove(job.id)
 
         if stashed_job.webhook:
             thread = threading.Thread(
@@ -384,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
     # ---------------------------------- status ---------------------------------- #
     async def _sim_status(self, job_id: str) -> JobOutput:
         """Development endpoint to simulate status behavior."""
-        stashed_job = await job_list.get(job_id)
+        stashed_job = job_list.get(job_id)
         if stashed_job is None:
             return jsonable_encoder(
                 {"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -400,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
         else:
             job_output = await run_job(self.config["handler"], job.__dict__)
 
-        await job_list.remove(job.id)
+        job_list.remove(job.id)
 
         if job_output.get("error", None):
             return jsonable_encoder(
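Every hunk here makes the same change: job_list is the JobsProgress singleton, which after this commit is a plain set subclass with synchronous add/remove/get, so the awaits are dropped. A speculative sketch of the interface these call sites imply; the real class lives in worker_state.py and may differ in detail:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass(frozen=True)
class StashedJob:
    # Only `id` participates in equality/hashing, so jobs dedupe by id.
    id: str
    input: dict = field(default_factory=dict, compare=False)
    webhook: Optional[str] = field(default=None, compare=False)

class JobsProgress(set):
    """Sketch of a set-backed, synchronous job tracker (illustrative only)."""

    def add(self, job):
        # The endpoints pass either a bare job id or a job dict.
        if isinstance(job, dict):
            job = StashedJob(job["id"], job.get("input", {}), job.get("webhook"))
        elif isinstance(job, str):
            job = StashedJob(job)
        super().add(job)

    def get(self, job_id: str) -> Optional[StashedJob]:
        return next((j for j in self if j.id == job_id), None)

    def remove(self, job):
        job_id = job["id"] if isinstance(job, dict) else getattr(job, "id", job)
        stashed = self.get(job_id)
        if stashed is not None:
            self.discard(stashed)  # discard() avoids KeyError on double-removal

    def get_job_count(self) -> int:
        return len(self)
```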

runpod/serverless/modules/rp_scale.py (+75 -34)

@@ -10,10 +10,9 @@
 from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
 from .rp_job import get_job, handle_job
 from .rp_logger import RunPodLogger
-from .worker_state import JobsQueue, JobsProgress
+from .worker_state import JobsProgress, IS_LOCAL_TEST
 
 log = RunPodLogger()
-job_list = JobsQueue()
 job_progress = JobsProgress()
 
 
@@ -38,16 +37,50 @@ class JobScaler:
     """
 
     def __init__(self, config: Dict[str, Any]):
-        concurrency_modifier = config.get("concurrency_modifier")
-        if concurrency_modifier is None:
-            self.concurrency_modifier = _default_concurrency_modifier
-        else:
-            self.concurrency_modifier = concurrency_modifier
-
         self._shutdown_event = asyncio.Event()
         self.current_concurrency = 1
         self.config = config
 
+        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+
+        self.concurrency_modifier = _default_concurrency_modifier
+        self.jobs_fetcher = get_job
+        self.jobs_fetcher_timeout = 90
+        self.jobs_handler = handle_job
+
+        if concurrency_modifier := config.get("concurrency_modifier"):
+            self.concurrency_modifier = concurrency_modifier
+
+        if not IS_LOCAL_TEST:
+            # below cannot be changed unless local
+            return
+
+        if jobs_fetcher := self.config.get("jobs_fetcher"):
+            self.jobs_fetcher = jobs_fetcher
+
+        if jobs_fetcher_timeout := self.config.get("jobs_fetcher_timeout"):
+            self.jobs_fetcher_timeout = jobs_fetcher_timeout
+
+        if jobs_handler := self.config.get("jobs_handler"):
+            self.jobs_handler = jobs_handler
+
+    async def set_scale(self):
+        self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
+
+        if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
+            # no need to resize
+            return
+
+        while self.current_occupancy() > 0:
+            # not safe to scale when jobs are in flight
+            await asyncio.sleep(1)
+            continue
+
+        self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+        log.debug(
+            f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
+        )
+
     def start(self):
         """
         This is required for the worker to be able to shut down gracefully
@@ -105,6 +138,15 @@ def kill_worker(self):
         log.info("Kill worker.")
         self._shutdown_event.set()
 
+    def current_occupancy(self) -> int:
+        current_queue_count = self.jobs_queue.qsize()
+        current_progress_count = job_progress.get_job_count()
+
+        log.debug(
+            f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
+        )
+        return current_progress_count + current_queue_count
+
     async def get_jobs(self, session: ClientSession):
         """
         Retrieve multiple jobs from the server in batches using blocking requests.
@@ -114,45 +156,42 @@ async def get_jobs(self, session: ClientSession):
         Adds jobs to the JobsQueue
         """
         while self.is_alive():
-            log.debug("JobScaler.get_jobs | Starting job acquisition.")
-
-            self.current_concurrency = self.concurrency_modifier(
-                self.current_concurrency
-            )
-            log.debug(f"JobScaler.get_jobs | current Concurrency set to: {self.current_concurrency}")
+            await self.set_scale()
 
-            current_progress_count = await job_progress.get_job_count()
-            log.debug(f"JobScaler.get_jobs | current Jobs in progress: {current_progress_count}")
-
-            current_queue_count = job_list.get_job_count()
-            log.debug(f"JobScaler.get_jobs | current Jobs in queue: {current_queue_count}")
-
-            jobs_needed = self.current_concurrency - current_progress_count - current_queue_count
+            jobs_needed = self.current_concurrency - self.current_occupancy()
             if jobs_needed <= 0:
                 log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
                 await asyncio.sleep(1)  # don't go rapidly
                 continue
 
             try:
-                # Keep the connection to the blocking call up to 30 seconds
+                log.debug("JobScaler.get_jobs | Starting job acquisition.")
+
+                # Keep the connection to the blocking call with timeout
                 acquired_jobs = await asyncio.wait_for(
-                    get_job(session, jobs_needed), timeout=30
+                    self.jobs_fetcher(session, jobs_needed),
+                    timeout=self.jobs_fetcher_timeout,
                 )
 
                 if not acquired_jobs:
                     log.debug("JobScaler.get_jobs | No jobs acquired.")
                     continue
 
                 for job in acquired_jobs:
-                    await job_list.add_job(job)
+                    await self.jobs_queue.put(job)
+                    job_progress.add(job)
+                    log.debug("Job Queued", job["id"])
 
-                log.info(f"Jobs in queue: {job_list.get_job_count()}")
+                log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
 
             except TooManyRequests:
-                log.debug(f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.")
+                log.debug(
+                    f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds."
                )
                 await asyncio.sleep(5)  # debounce for 5 seconds
             except asyncio.CancelledError:
                 log.debug("JobScaler.get_jobs | Request was cancelled.")
+                raise  # CancelledError is a BaseException
             except TimeoutError:
                 log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
             except TypeError as error:
@@ -173,10 +212,10 @@ async def run_jobs(self, session: ClientSession):
         """
         tasks = []  # Store the tasks for concurrent job processing
 
-        while self.is_alive() or not job_list.empty():
+        while self.is_alive() or not self.jobs_queue.empty():
             # Fetch as many jobs as the concurrency allows
-            while len(tasks) < self.current_concurrency and not job_list.empty():
-                job = await job_list.get_job()
+            while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
+                job = await self.jobs_queue.get()
 
                 # Create a new task for each job and add it to the task list
                 task = asyncio.create_task(self.handle_job(session, job))
@@ -204,9 +243,9 @@ async def handle_job(self, session: ClientSession, job: dict):
         Process an individual job. This function is run concurrently for multiple jobs.
         """
         try:
-            await job_progress.add(job)
+            log.debug("Handling Job", job["id"])
 
-            await handle_job(session, self.config, job)
+            await self.jobs_handler(session, self.config, job)
 
             if self.config.get("refresh_worker", False):
                 self.kill_worker()
@@ -216,8 +255,10 @@ async def handle_job(self, session: ClientSession, job: dict):
             raise err
 
         finally:
-            # Inform JobsQueue of a task completion
-            job_list.task_done()
+            # Inform Queue of a task completion
+            self.jobs_queue.task_done()
 
             # Job is no longer in progress
-            await job_progress.remove(job["id"])
+            job_progress.remove(job)
+
+            log.debug("Finished Job", job["id"])
