
Commit 5a6b911

Blocking job take call means 5-sec debounce no longer needed (#366)
Fix: the 5-second debounce was causing unnecessary delays in serverless workers; now that the job-take call blocks, the debounce is no longer needed.

- Refactored rp_job.get_job to work well under pause and unpause conditions, with more debug lines.
- Refactored rp_scale.JobScaler to handle shutdowns gracefully, cleaning up hanging tasks and connections, with better debug lines.
- Removed the unnecessarily long asyncio.sleep calls that rp_scale.JobScaler made before the blocking get_job calls.
- Improved worker_state's JobsProgress and JobsQueue to timestamp when jobs are added or removed.
- Moved the run loop from worker.run_worker into rp_scale.JobScaler, where it belongs, and simplified the entrypoint to job_scaler.start().
- Fixed non-errors being logged as errors in the tracer.
- Updated the unit tests that these changes mandate.
1 parent 5d1cec6 · commit 5a6b911

File tree

7 files changed, +316 −269 lines

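The heart of the change: instead of sleeping a fixed five seconds between polls, the take loop now blocks on the HTTP call itself and bounds it with asyncio.wait_for. A minimal, self-contained sketch of that pattern (take_job here is a hypothetical stand-in for the real get_job call):

import asyncio

async def take_job() -> list:
    # Hypothetical stand-in for rp_job.get_job: the real call holds the
    # HTTP connection open until the server hands back a job.
    await asyncio.sleep(1)
    return [{"id": "job-1", "input": {}}]

async def poll_loop():
    while True:
        try:
            # Block on the take call for up to 30 seconds; no fixed
            # 5-second debounce between iterations.
            jobs = await asyncio.wait_for(take_job(), timeout=30)
        except asyncio.TimeoutError:
            continue  # nothing arrived in time; reconnect immediately
        print(f"Acquired {len(jobs)} job(s)")
        break

asyncio.run(poll_loop())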

runpod/serverless/modules/rp_job.py

+39 −38

@@ -2,7 +2,6 @@
 Job related helpers.
 """
 
-import asyncio
 import inspect
 import json
 import os
@@ -60,43 +59,45 @@ async def get_job(
         session (ClientSession): The aiohttp ClientSession to use for the request.
         num_jobs (int): The number of jobs to get.
     """
-    try:
-        async with session.get(_job_get_url(num_jobs)) as response:
-            if response.status == 204:
-                log.debug("No content, no job to process.")
-                return
-
-            if response.status == 400:
-                log.debug("Received 400 status, expected when FlashBoot is enabled.")
-                return
-
-            if response.status != 200:
-                log.error(f"Failed to get job, status code: {response.status}")
-                return
-
-            jobs = await response.json()
-            log.debug(f"Request Received | {jobs}")
-
-            # legacy job-take API
-            if isinstance(jobs, dict):
-                if "id" not in jobs or "input" not in jobs:
-                    raise Exception("Job has missing field(s): id or input.")
-                return [jobs]
-
-            # batch job-take API
-            if isinstance(jobs, list):
-                return jobs
-
-    except asyncio.TimeoutError:
-        log.debug("Timeout error, retrying.")
-
-    except Exception as error:
-        log.error(
-            f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
-        )
-
-    # empty
-    return []
+    async with session.get(_job_get_url(num_jobs)) as response:
+        log.debug(f"- Response: {type(response).__name__} {response.status}")
+
+        if response.status == 204:
+            log.debug("- No content, no job to process.")
+            return
+
+        if response.status == 400:
+            log.debug("- Received 400 status, expected when FlashBoot is enabled.")
+            return
+
+        try:
+            response.raise_for_status()
+        except Exception:
+            log.error(f"- Failed to get job, status code: {response.status}")
+            return
+
+        # Verify if the content type is JSON
+        if response.content_type != "application/json":
+            log.error(f"- Unexpected content type: {response.content_type}")
+            return
+
+        # Check if there is a non-empty content to parse
+        if response.content_length == 0:
+            log.debug("- No content to parse.")
+            return
+
+        jobs = await response.json()
+        log.debug(f"- Received Job(s)")
+
+        # legacy job-take API
+        if isinstance(jobs, dict):
+            if "id" not in jobs or "input" not in jobs:
+                raise Exception("Job has missing field(s): id or input.")
+            return [jobs]
+
+        # batch job-take API
+        if isinstance(jobs, list):
+            return jobs
 
 
 async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dict:
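Note the return shape: get_job returns None on the 204/400/error paths and a list on success, normalizing the legacy single-job dict into a one-element list. A rough sketch of that normalization in isolation (the payloads below are invented for illustration):

def normalize_take_response(payload):
    # Legacy job-take API returns a single job object; the batch API
    # returns a list of jobs. Callers always want a list.
    if isinstance(payload, dict):
        if "id" not in payload or "input" not in payload:
            raise ValueError("Job has missing field(s): id or input.")
        return [payload]
    if isinstance(payload, list):
        return payload
    return []

print(normalize_take_response({"id": "job-1", "input": {"x": 1}}))
print(normalize_take_response([{"id": "a", "input": {}}, {"id": "b", "input": {}}]))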

runpod/serverless/modules/rp_scale.py

+112 −34

@@ -4,9 +4,10 @@
 """
 
 import asyncio
+import signal
 from typing import Any, Dict
 
-from ...http_client import ClientSession
+from ...http_client import AsyncClientSession, ClientSession
 from .rp_job import get_job, handle_job
 from .rp_logger import RunPodLogger
 from .worker_state import JobsQueue, JobsProgress
@@ -36,26 +37,91 @@ class JobScaler:
     Job Scaler. This class is responsible for scaling the number of concurrent requests.
     """
 
-    def __init__(self, concurrency_modifier: Any):
+    def __init__(self, config: Dict[str, Any]):
+        concurrency_modifier = config.get("concurrency_modifier")
         if concurrency_modifier is None:
             self.concurrency_modifier = _default_concurrency_modifier
         else:
             self.concurrency_modifier = concurrency_modifier
 
+        self._shutdown_event = asyncio.Event()
         self.current_concurrency = 1
-        self._is_alive = True
+        self.config = config
+
+    def start(self):
+        """
+        This is required for the worker to be able to shut down gracefully
+        when the user sends a SIGTERM or SIGINT signal. This is typically
+        the case when the worker is running in a container.
+        """
+        try:
+            # Register signal handlers for graceful shutdown
+            signal.signal(signal.SIGTERM, self.handle_shutdown)
+            signal.signal(signal.SIGINT, self.handle_shutdown)
+        except ValueError:
+            log.warning("Signal handling is only supported in the main thread.")
+
+        # Start the main loop
+        # Run forever until the worker is signalled to shut down.
+        asyncio.run(self.run())
+
+    def handle_shutdown(self, signum, frame):
+        """
+        Called when the worker is signalled to shut down.
+
+        This function is called when the worker receives a signal to shut down, such as
+        SIGTERM or SIGINT. It sets the shutdown event, which will cause the worker to
+        exit its main loop and shut down gracefully.
+
+        Args:
+            signum: The signal number that was received.
+            frame: The current stack frame.
+        """
+        log.debug(f"Received shutdown signal: {signum}.")
+        self.kill_worker()
+
+    async def run(self):
+        # Create an async session that will be closed when the worker is killed.
+
+        async with AsyncClientSession() as session:
+            # Create tasks for getting and running jobs.
+            jobtake_task = asyncio.create_task(self.get_jobs(session))
+            jobrun_task = asyncio.create_task(self.run_jobs(session))
+
+            tasks = [jobtake_task, jobrun_task]
+
+            try:
+                # Concurrently run both tasks and wait for both to finish.
+                await asyncio.gather(*tasks)
+            except asyncio.CancelledError:  # worker is killed
+                log.debug("Worker tasks cancelled.")
+                self.kill_worker()
+            finally:
+                # Handle the task cancellation gracefully
+                for task in tasks:
+                    if not task.done():
+                        task.cancel()
+                await asyncio.gather(*tasks, return_exceptions=True)
+                await self.cleanup()  # Ensure resources are cleaned up
+
+    async def cleanup(self):
+        # Perform any necessary cleanup here, such as closing connections
+        log.debug("Cleaning up resources before shutdown.")
+        # TODO: stop heartbeat or close any open connections
+        await asyncio.sleep(0)  # Give a chance for other tasks to run (optional)
+        log.debug("Cleanup complete.")
 
     def is_alive(self):
         """
         Return whether the worker is alive or not.
         """
-        return self._is_alive
+        return not self._shutdown_event.is_set()
 
     def kill_worker(self):
         """
         Whether to kill the worker.
         """
-        self._is_alive = False
+        self._shutdown_event.set()
 
     async def get_jobs(self, session: ClientSession):
         """
@@ -66,38 +132,50 @@ async def get_jobs(self, session: ClientSession):
         Adds jobs to the JobsQueue
         """
         while self.is_alive():
-            log.debug(f"Jobs in progress: {job_progress.get_job_count()}")
-
-            try:
-                self.current_concurrency = self.concurrency_modifier(
-                    self.current_concurrency
-                )
-                log.debug(f"Concurrency set to: {self.current_concurrency}")
-
-                jobs_needed = self.current_concurrency - job_progress.get_job_count()
-                if not jobs_needed:  # zero or less
-                    log.debug("Queue is full. Retrying soon.")
-                    continue
+            log.debug(f"JobScaler.get_jobs | Jobs in progress: {job_progress.get_job_count()}")
 
-                acquired_jobs = await get_job(session, jobs_needed)
-                if not acquired_jobs:
-                    log.debug("No jobs acquired.")
-                    continue
+            self.current_concurrency = self.concurrency_modifier(
+                self.current_concurrency
+            )
+            log.debug(f"JobScaler.get_jobs | Concurrency set to: {self.current_concurrency}")
 
-                for job in acquired_jobs:
-                    await job_list.add_job(job)
-
-                log.info(f"Jobs in queue: {job_list.get_job_count()}")
+            jobs_needed = self.current_concurrency - job_progress.get_job_count()
+            if jobs_needed <= 0:
+                log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
+                await asyncio.sleep(0.1)  # don't go rapidly
+                continue
 
+            try:
+                # Keep the connection to the blocking call up to 30 seconds
+                acquired_jobs = await asyncio.wait_for(
+                    get_job(session, jobs_needed), timeout=30
+                )
+            except asyncio.CancelledError:
+                log.debug("JobScaler.get_jobs | Request was cancelled.")
+                continue
+            except TimeoutError:
+                log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
+                continue
+            except TypeError as error:
+                log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.")
+                continue
             except Exception as error:
                 log.error(
                     f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
                 )
+                continue
 
-            finally:
-                await asyncio.sleep(5)  # yield control back to the event loop
+            if not acquired_jobs:
+                log.debug("JobScaler.get_jobs | No jobs acquired.")
+                await asyncio.sleep(0)
+                continue
 
-    async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
+            for job in acquired_jobs:
+                await job_list.add_job(job)
+
+            log.info(f"Jobs in queue: {job_list.get_job_count()}")
+
+    async def run_jobs(self, session: ClientSession):
         """
         Retrieve jobs from the jobs queue and process them concurrently.
 
@@ -111,7 +189,7 @@ async def run_jobs(self, session: ClientSession):
             job = await job_list.get_job()
 
             # Create a new task for each job and add it to the task list
-            task = asyncio.create_task(self.handle_job(session, config, job))
+            task = asyncio.create_task(self.handle_job(session, job))
             tasks.append(task)
 
             # Wait for any job to finish
@@ -131,19 +209,19 @@ async def run_jobs(self, session: ClientSession):
         # Ensure all remaining tasks finish before stopping
         await asyncio.gather(*tasks)
 
-    async def handle_job(self, session: ClientSession, config: Dict[str, Any], job):
+    async def handle_job(self, session: ClientSession, job: dict):
         """
         Process an individual job. This function is run concurrently for multiple jobs.
         """
-        log.debug(f"Processing job: {job}")
+        log.debug(f"JobScaler.handle_job | {job}")
         job_progress.add(job)
 
         try:
-            await handle_job(session, config, job)
+            await handle_job(session, self.config, job)
 
-            if config.get("refresh_worker", False):
+            if self.config.get("refresh_worker", False):
                 self.kill_worker()
 
         except Exception as err:
            log.error(f"Error handling job: {err}", job["id"])
            raise err
runpod/serverless/modules/worker_state.py

+8 −0

@@ -8,6 +8,11 @@
 from typing import Any, Dict, Optional
 from asyncio import Queue
 
+from .rp_logger import RunPodLogger
+
+
+log = RunPodLogger()
+
 REF_COUNT_ZERO = time.perf_counter()  # Used for benchmarking with the debugger.
 
 WORKER_ID = os.environ.get("RUNPOD_POD_ID", str(uuid.uuid4()))
@@ -87,6 +92,7 @@ def add(self, element: Any):
         if not isinstance(element, Job):
             raise TypeError("Only Job objects can be added to JobsProgress.")
 
+        log.debug(f"JobsProgress.add | {element}")
         return super().add(element)
 
     def remove(self, element: Any):
@@ -106,6 +112,7 @@ def remove(self, element: Any):
         if not isinstance(element, Job):
             raise TypeError("Only Job objects can be removed from JobsProgress.")
 
+        log.debug(f"JobsProgress.remove | {element}")
         return super().remove(element)
 
     def get(self, element: Any) -> Job:
@@ -155,6 +162,7 @@ async def add_job(self, job: dict):
         If the queue is full, wait until a free
         slot is available before adding item.
         """
+        log.debug(f"JobsQueue.add_job | {job}")
         return await self.put(job)
 
     async def get_job(self) -> dict:
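The new hooks in JobsProgress and JobsQueue are a thin wrap-and-log layer over the underlying set and queue operations. A hypothetical LoggedQueue showing the same idea (the real classes log through RunPodLogger rather than print, and the _queued_at field here is invented to illustrate the timestamping):

import asyncio
import time

class LoggedQueue(asyncio.Queue):
    # Illustrative: stamp jobs on entry and log on entry and exit,
    # the way JobsQueue.add_job/get_job now emit debug lines.
    async def add_job(self, job: dict):
        job["_queued_at"] = time.time()  # hypothetical timestamp field
        print(f"LoggedQueue.add_job | {job}")
        await self.put(job)

    async def get_job(self) -> dict:
        job = await self.get()
        waited = time.time() - job.pop("_queued_at")
        print(f"LoggedQueue.get_job | waited {waited:.4f}s | {job}")
        return job

async def demo():
    queue = LoggedQueue()
    await queue.add_job({"id": "job-1", "input": {}})
    await queue.get_job()

asyncio.run(demo())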

runpod/serverless/worker.py

+5 −26

@@ -7,7 +7,6 @@
 import os
 from typing import Any, Dict
 
-from runpod.http_client import AsyncClientSession
 from runpod.serverless.modules import rp_logger, rp_local, rp_ping, rp_scale
 
 log = rp_logger.RunPodLogger()
@@ -26,7 +25,7 @@ def _is_local(config) -> bool:
 
 
 # ------------------------- Main Worker Running Loop ------------------------- #
-async def run_worker(config: Dict[str, Any]) -> None:
+def run_worker(config: Dict[str, Any]) -> None:
     """
     Starts the worker loop for multi-processing.
@@ -39,29 +38,9 @@ async def run_worker(config: Dict[str, Any]) -> None:
     # Start pinging RunPod to show that the worker is alive.
     heartbeat.start_ping()
 
-    # Create an async session that will be closed when the worker is killed.
-    async with AsyncClientSession() as session:
-        # Create a JobScaler responsible for adjusting the concurrency
-        # of the worker based on the modifier callable.
-        job_scaler = rp_scale.JobScaler(
-            concurrency_modifier=config.get("concurrency_modifier", None)
-        )
-
-        # Create tasks for getting and running jobs.
-        jobtake_task = asyncio.create_task(job_scaler.get_jobs(session))
-        jobrun_task = asyncio.create_task(job_scaler.run_jobs(session, config))
-
-        tasks = [jobtake_task, jobrun_task]
-
-        try:
-            # Concurrently run both tasks and wait for both to finish.
-            await asyncio.gather(*tasks)
-        except asyncio.CancelledError:  # worker is killed
-            # Handle the task cancellation gracefully
-            for task in tasks:
-                task.cancel()
-            await asyncio.gather(*tasks, return_exceptions=True)
-            log.debug("Worker tasks cancelled.")
+    # Create a JobScaler responsible for adjusting the concurrency
+    job_scaler = rp_scale.JobScaler(config)
+    job_scaler.start()
 
 
 def main(config: Dict[str, Any]) -> None:
@@ -74,4 +53,4 @@ def main(config: Dict[str, Any]) -> None:
         asyncio.run(rp_local.run_local(config))
 
     else:
-        asyncio.run(run_worker(config))
+        run_worker(config)
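For handler authors nothing changes at the surface: the worker is still launched through the standard RunPod serverless entrypoint, which ends up in main(config) and, on a pod, the now-synchronous run_worker. A minimal handler for reference:

import runpod

def handler(job):
    # Echo the input back; replace with real work.
    return job["input"]

# main(config) routes to run_worker(config), which now just constructs
# rp_scale.JobScaler(config) and calls job_scaler.start().
runpod.serverless.start({"handler": handler})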
