from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
- from .worker_state import JobsQueue, JobsProgress
+ from .worker_state import JobsProgress, IS_LOCAL_TEST

log = RunPodLogger()
- job_list = JobsQueue()
job_progress = JobsProgress()

@@ -38,16 +37,50 @@ class JobScaler:
    """

    def __init__(self, config: Dict[str, Any]):
-         concurrency_modifier = config.get("concurrency_modifier")
-         if concurrency_modifier is None:
-             self.concurrency_modifier = _default_concurrency_modifier
-         else:
-             self.concurrency_modifier = concurrency_modifier
-
        self._shutdown_event = asyncio.Event()
        self.current_concurrency = 1
        self.config = config

+         self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+
+         self.concurrency_modifier = _default_concurrency_modifier
+         self.jobs_fetcher = get_job
+         self.jobs_fetcher_timeout = 90
+         self.jobs_handler = handle_job
+
+         if concurrency_modifier := config.get("concurrency_modifier"):
+             self.concurrency_modifier = concurrency_modifier
+
+         if not IS_LOCAL_TEST:
+             # below cannot be changed unless local
+             return
+
+         if jobs_fetcher := self.config.get("jobs_fetcher"):
+             self.jobs_fetcher = jobs_fetcher
+
+         if jobs_fetcher_timeout := self.config.get("jobs_fetcher_timeout"):
+             self.jobs_fetcher_timeout = jobs_fetcher_timeout
+
+         if jobs_handler := self.config.get("jobs_handler"):
+             self.jobs_handler = jobs_handler
+
+     async def set_scale(self):
+         self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
+
+         if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
+             # no need to resize
+             return
+
+         while self.current_occupancy() > 0:
+             # not safe to scale when jobs are in flight
+             await asyncio.sleep(1)
+             continue
+
+         self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
+         log.debug(
+             f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
+         )
+
    def start(self):
        """
        This is required for the worker to be able to shut down gracefully
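For reference, a minimal sketch of a concurrency_modifier that this constructor would pick up from the handler config (the ramp-up logic and the cap value are illustrative assumptions, not part of this diff):

# Hypothetical usage sketch (not part of this diff).
# A concurrency_modifier receives the current concurrency and returns the desired one.
def ramp_up_modifier(current_concurrency: int) -> int:
    max_concurrency = 4  # assumed cap, for illustration only
    return min(current_concurrency + 1, max_concurrency)

# Passed through the same config dict that JobScaler.__init__ reads:
# config = {"handler": handler, "concurrency_modifier": ramp_up_modifier}

The jobs_fetcher, jobs_fetcher_timeout, and jobs_handler keys, by contrast, are only read when IS_LOCAL_TEST is set, since __init__ returns early otherwise.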
@@ -105,6 +138,15 @@ def kill_worker(self):
        log.info("Kill worker.")
        self._shutdown_event.set()

+     def current_occupancy(self) -> int:
+         current_queue_count = self.jobs_queue.qsize()
+         current_progress_count = job_progress.get_job_count()
+
+         log.debug(
+             f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
+         )
+         return current_progress_count + current_queue_count
+
    async def get_jobs(self, session: ClientSession):
        """
        Retrieve multiple jobs from the server in batches using blocking requests.
@@ -114,45 +156,42 @@ async def get_jobs(self, session: ClientSession):
        Adds jobs to the JobsQueue
        """
        while self.is_alive():
-             log.debug("JobScaler.get_jobs | Starting job acquisition.")
-
-             self.current_concurrency = self.concurrency_modifier(
-                 self.current_concurrency
-             )
-             log.debug(f"JobScaler.get_jobs | current Concurrency set to: {self.current_concurrency}")
+             await self.set_scale()

-             current_progress_count = await job_progress.get_job_count()
-             log.debug(f"JobScaler.get_jobs | current Jobs in progress: {current_progress_count}")
-
-             current_queue_count = job_list.get_job_count()
-             log.debug(f"JobScaler.get_jobs | current Jobs in queue: {current_queue_count}")
-
-             jobs_needed = self.current_concurrency - current_progress_count - current_queue_count
+             jobs_needed = self.current_concurrency - self.current_occupancy()
            if jobs_needed <= 0:
                log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
                await asyncio.sleep(1)  # don't go rapidly
                continue

            try:
-                 # Keep the connection to the blocking call up to 30 seconds
+                 log.debug("JobScaler.get_jobs | Starting job acquisition.")
+
+                 # Keep the connection to the blocking call with timeout
                acquired_jobs = await asyncio.wait_for(
-                     get_job(session, jobs_needed), timeout=30
+                     self.jobs_fetcher(session, jobs_needed),
+                     timeout=self.jobs_fetcher_timeout,
                )

                if not acquired_jobs:
                    log.debug("JobScaler.get_jobs | No jobs acquired.")
                    continue

                for job in acquired_jobs:
-                     await job_list.add_job(job)
+                     await self.jobs_queue.put(job)
+                     job_progress.add(job)
+                     log.debug("Job Queued", job["id"])

-                 log.info(f"Jobs in queue: {job_list.get_job_count()}")
+                 log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")

            except TooManyRequests:
-                 log.debug(f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.")
+                 log.debug(
+                     f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds."
+                 )
                await asyncio.sleep(5)  # debounce for 5 seconds
            except asyncio.CancelledError:
                log.debug("JobScaler.get_jobs | Request was cancelled.")
+                 raise  # CancelledError is a BaseException
            except TimeoutError:
                log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
            except TypeError as error:
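A minimal sketch of a local-test fetcher matching the (session, jobs_needed) call above; the synthetic job payloads and config keys shown in the comment are illustrative, and this override is only honored when IS_LOCAL_TEST is set:

# Hypothetical local-test override (not part of this diff).
async def fake_jobs_fetcher(session, jobs_needed: int):
    # Return up to jobs_needed synthetic jobs instead of polling the job-take endpoint.
    return [{"id": f"local-{i}", "input": {"prompt": "test"}} for i in range(jobs_needed)]

# config = {"handler": handler, "jobs_fetcher": fake_jobs_fetcher, "jobs_fetcher_timeout": 10}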
@@ -173,10 +212,10 @@ async def run_jobs(self, session: ClientSession):
        """
        tasks = []  # Store the tasks for concurrent job processing

-         while self.is_alive() or not job_list.empty():
+         while self.is_alive() or not self.jobs_queue.empty():
            # Fetch as many jobs as the concurrency allows
-             while len(tasks) < self.current_concurrency and not job_list.empty():
-                 job = await job_list.get_job()
+             while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
+                 job = await self.jobs_queue.get()

                # Create a new task for each job and add it to the task list
                task = asyncio.create_task(self.handle_job(session, job))
@@ -204,9 +243,9 @@ async def handle_job(self, session: ClientSession, job: dict):
        Process an individual job. This function is run concurrently for multiple jobs.
        """
        try:
-             await job_progress.add(job)
+             log.debug("Handling Job", job["id"])

-             await handle_job(session, self.config, job)
+             await self.jobs_handler(session, self.config, job)

            if self.config.get("refresh_worker", False):
                self.kill_worker()
@@ -216,8 +255,10 @@ async def handle_job(self, session: ClientSession, job: dict):
            raise err

        finally:
-             # Inform JobsQueue of a task completion
-             job_list.task_done()
+             # Inform Queue of a task completion
+             self.jobs_queue.task_done()

            # Job is no longer in progress
-             await job_progress.remove(job["id"])
+             job_progress.remove(job)
+
+             log.debug("Finished Job", job["id"])
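A minimal sketch of a jobs_handler override matching the (session, config, job) call above, again only read when IS_LOCAL_TEST is set; the body is an assumption for illustration and skips everything the real handle_job does beyond invoking the user handler:

# Hypothetical local-test override (not part of this diff).
async def fake_jobs_handler(session, config, job):
    # Call the user handler directly and return its result instead of
    # streaming or posting anything back to the platform.
    return config["handler"](job)

# config = {"handler": handler, "jobs_handler": fake_jobs_handler}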