@@ -4,9 +4,10 @@
 """
 
 import asyncio
+import signal
 from typing import Any, Dict
 
-from ...http_client import ClientSession
+from ...http_client import AsyncClientSession, ClientSession
 from .rp_job import get_job, handle_job
 from .rp_logger import RunPodLogger
 from .worker_state import JobsQueue, JobsProgress
@@ -36,26 +37,91 @@ class JobScaler:
     Job Scaler. This class is responsible for scaling the number of concurrent requests.
     """
 
-    def __init__(self, concurrency_modifier: Any):
+    def __init__(self, config: Dict[str, Any]):
+        concurrency_modifier = config.get("concurrency_modifier")
         if concurrency_modifier is None:
             self.concurrency_modifier = _default_concurrency_modifier
         else:
             self.concurrency_modifier = concurrency_modifier
 
+        self._shutdown_event = asyncio.Event()
         self.current_concurrency = 1
-        self._is_alive = True
+        self.config = config
+
+    def start(self):
+        """
+        This is required for the worker to be able to shut down gracefully
+        when the user sends a SIGTERM or SIGINT signal. This is typically
+        the case when the worker is running in a container.
+        """
+        try:
+            # Register signal handlers for graceful shutdown
+            signal.signal(signal.SIGTERM, self.handle_shutdown)
+            signal.signal(signal.SIGINT, self.handle_shutdown)
+        except ValueError:
+            log.warning("Signal handling is only supported in the main thread.")
+
+        # Start the main loop.
+        # Run until the worker is signalled to shut down.
+        asyncio.run(self.run())
+
+    def handle_shutdown(self, signum, frame):
+        """
+        Called when the worker is signalled to shut down.
+
+        This function is called when the worker receives a signal to shut down, such as
+        SIGTERM or SIGINT. It sets the shutdown event, which will cause the worker to
+        exit its main loop and shut down gracefully.
+
+        Args:
+            signum: The signal number that was received.
+            frame: The current stack frame.
+        """
+        log.debug(f"Received shutdown signal: {signum}.")
+        self.kill_worker()
+
+    async def run(self):
+        # Create an async session that will be closed when the worker is killed.
+
+        async with AsyncClientSession() as session:
+            # Create tasks for getting and running jobs.
+            jobtake_task = asyncio.create_task(self.get_jobs(session))
+            jobrun_task = asyncio.create_task(self.run_jobs(session))
+
+            tasks = [jobtake_task, jobrun_task]
+
+            try:
+                # Concurrently run both tasks and wait for both to finish.
+                await asyncio.gather(*tasks)
+            except asyncio.CancelledError:  # worker is killed
+                log.debug("Worker tasks cancelled.")
+                self.kill_worker()
+            finally:
+                # Handle the task cancellation gracefully
+                for task in tasks:
+                    if not task.done():
+                        task.cancel()
+                await asyncio.gather(*tasks, return_exceptions=True)
+                await self.cleanup()  # Ensure resources are cleaned up
+
+    async def cleanup(self):
+        # Perform any necessary cleanup here, such as closing connections
+        log.debug("Cleaning up resources before shutdown.")
+        # TODO: stop heartbeat or close any open connections
+        await asyncio.sleep(0)  # Give other tasks a chance to run (optional)
+        log.debug("Cleanup complete.")
 
     def is_alive(self):
         """
         Return whether the worker is alive or not.
         """
-        return self._is_alive
+        return not self._shutdown_event.is_set()
 
     def kill_worker(self):
         """
         Whether to kill the worker.
         """
-        self._is_alive = False
+        self._shutdown_event.set()
 
     async def get_jobs(self, session: ClientSession):
         """
@@ -66,38 +132,50 @@ async def get_jobs(self, session: ClientSession):
         Adds jobs to the JobsQueue
         """
         while self.is_alive():
-            log.debug(f"Jobs in progress: {job_progress.get_job_count()}")
-
-            try:
-                self.current_concurrency = self.concurrency_modifier(
-                    self.current_concurrency
-                )
-                log.debug(f"Concurrency set to: {self.current_concurrency}")
-
-                jobs_needed = self.current_concurrency - job_progress.get_job_count()
-                if not jobs_needed:  # zero or less
-                    log.debug("Queue is full. Retrying soon.")
-                    continue
+            log.debug(f"JobScaler.get_jobs | Jobs in progress: {job_progress.get_job_count()}")
 
-                acquired_jobs = await get_job(session, jobs_needed)
-                if not acquired_jobs:
-                    log.debug("No jobs acquired.")
-                    continue
+            self.current_concurrency = self.concurrency_modifier(
+                self.current_concurrency
+            )
+            log.debug(f"JobScaler.get_jobs | Concurrency set to: {self.current_concurrency}")
 
-                for job in acquired_jobs:
-                    await job_list.add_job(job)
-
-                log.info(f"Jobs in queue: {job_list.get_job_count()}")
+            jobs_needed = self.current_concurrency - job_progress.get_job_count()
+            if jobs_needed <= 0:
+                log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
+                await asyncio.sleep(0.1)  # don't poll too rapidly
+                continue
 
+            try:
+                # Keep the connection to the blocking call open for up to 30 seconds
+                acquired_jobs = await asyncio.wait_for(
+                    get_job(session, jobs_needed), timeout=30
+                )
+            except asyncio.CancelledError:
+                log.debug("JobScaler.get_jobs | Request was cancelled.")
+                continue
+            except TimeoutError:
+                log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
+                continue
+            except TypeError as error:
+                log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.")
+                continue
             except Exception as error:
                 log.error(
                     f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
                 )
+                continue
 
-            finally:
-                await asyncio.sleep(5)  # yield control back to the event loop
+            if not acquired_jobs:
+                log.debug("JobScaler.get_jobs | No jobs acquired.")
+                await asyncio.sleep(0)
+                continue
 
-    async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
+            for job in acquired_jobs:
+                await job_list.add_job(job)
+
+            log.info(f"Jobs in queue: {job_list.get_job_count()}")
+
+    async def run_jobs(self, session: ClientSession):
         """
         Retrieve jobs from the jobs queue and process them concurrently.
 
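
A note on the `except TimeoutError` branch above: on Python 3.11+, `asyncio.wait_for` raises the builtin `TimeoutError`, but on earlier versions it raises `asyncio.TimeoutError`, which is not a subclass of the builtin. A version-proof sketch of the same bounded long-poll pattern, with a hypothetical `fetch` coroutine standing in for `get_job`:

```python
import asyncio

async def poll_once(fetch, timeout: float = 30):
    """Bounded long poll: wait up to `timeout` seconds, else give up quietly."""
    try:
        return await asyncio.wait_for(fetch(), timeout=timeout)
    except (TimeoutError, asyncio.TimeoutError):
        # Covers both the 3.11+ builtin and the pre-3.11 asyncio exception.
        return None
```
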
@@ -111,7 +189,7 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
             job = await job_list.get_job()
 
             # Create a new task for each job and add it to the task list
-            task = asyncio.create_task(self.handle_job(session, config, job))
+            task = asyncio.create_task(self.handle_job(session, job))
             tasks.append(task)
 
             # Wait for any job to finish
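
The surrounding context (elided from this hunk) creates one task per job and reaps whichever finishes first. A self-contained sketch of that fan-out/reap pattern, with a stand-in `fake_job` coroutine:

```python
import asyncio

async def fake_job(i: int) -> int:
    await asyncio.sleep(0.1 * i)  # stand-in for real job handling
    return i

async def main():
    tasks = [asyncio.create_task(fake_job(i)) for i in range(3)]
    while tasks:
        # Reap whichever task completes first; keep waiting on the rest.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        tasks = list(pending)
        for task in done:
            print("finished:", task.result())

asyncio.run(main())
```
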
@@ -131,19 +209,19 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
             # Ensure all remaining tasks finish before stopping
             await asyncio.gather(*tasks)
 
-    async def handle_job(self, session: ClientSession, config: Dict[str, Any], job):
+    async def handle_job(self, session: ClientSession, job: dict):
         """
         Process an individual job. This function is run concurrently for multiple jobs.
         """
-        log.debug(f"Processing job: {job}")
+        log.debug(f"JobScaler.handle_job | {job}")
         job_progress.add(job)
 
         try:
-            await handle_job(session, config, job)
+            await handle_job(session, self.config, job)
 
-            if config.get("refresh_worker", False):
+            if self.config.get("refresh_worker", False):
                 self.kill_worker()
-
+
         except Exception as err:
             log.error(f"Error handling job: {err}", job["id"])
             raise err
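
To see the shutdown mechanism on its own: a signal handler sets an `asyncio.Event`, `is_alive()` inverts it, and the polling loops exit on their next iteration. A standalone sketch of that pattern (not from this PR):

```python
import asyncio
import signal

shutdown = asyncio.Event()

def handle_shutdown(signum, frame):
    # Safe from a signal handler: Event.set() just flips a flag in the main thread.
    shutdown.set()

async def main():
    signal.signal(signal.SIGTERM, handle_shutdown)
    signal.signal(signal.SIGINT, handle_shutdown)
    while not shutdown.is_set():
        await asyncio.sleep(0.1)  # stand-in for get_jobs/run_jobs work
    print("shutting down gracefully")

asyncio.run(main())
```
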