Skip to content

Commit 6a38a5a

Browse files
authored
fix: pings were missing requestIds since the last big refactor (#362)

* Jobs --> JobsProgress to track jobs in flight + tests
* Tests for rp_ping.Heartbeat
* refactor: moved JobScaler.process_job to rp_job.handle_job
* refactor: JobScaler distinguishes JobsProgress from JobsQueue
* fix: Job should accept arbitrary properties + tests
* log and raise the JobScaler.handle_job exceptions
* refactor: graceful shutdown of tasks when worker is killed
* tests: complete tests for rp_ping
1 parent 5830b50 commit 6a38a5a

File tree

9 files changed

+403
-144
lines changed

9 files changed

+403
-144
lines changed

runpod/serverless/modules/rp_fastapi.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .rp_handler import is_generator
1818
from .rp_job import run_job, run_job_generator
1919
from .rp_ping import Heartbeat
20-
from .worker_state import Jobs
20+
from .worker_state import Job, JobsProgress
2121

2222
RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None)
2323

@@ -96,7 +96,7 @@
9696

9797

9898
# ------------------------------ Initializations ----------------------------- #
99-
job_list = Jobs()
99+
job_list = JobsProgress()
100100
heartbeat = Heartbeat()
101101

102102

@@ -286,12 +286,12 @@ async def _realtime(self, job: Job):
286286
Performs model inference on the input data using the provided handler.
287287
If handler is not provided, returns an error message.
288288
"""
289-
job_list.add_job(job.id)
289+
job_list.add(job.id)
290290

291291
# Process the job using the provided handler, passing in the job input.
292292
job_results = await run_job(self.config["handler"], job.__dict__)
293293

294-
job_list.remove_job(job.id)
294+
job_list.remove(job.id)
295295

296296
# Return the results of the job processing.
297297
return jsonable_encoder(job_results)
@@ -304,7 +304,11 @@ async def _realtime(self, job: Job):
304304
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
305305
"""Development endpoint to simulate run behavior."""
306306
assigned_job_id = f"test-{uuid.uuid4()}"
307-
job_list.add_job(assigned_job_id, job_request.input, job_request.webhook)
307+
job_list.add({
308+
"id": assigned_job_id,
309+
"input": job_request.input,
310+
"webhook": job_request.webhook
311+
})
308312
return jsonable_encoder({"id": assigned_job_id, "status": "IN_PROGRESS"})
309313

310314
# ---------------------------------- runsync --------------------------------- #
@@ -341,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
341345
# ---------------------------------- stream ---------------------------------- #
342346
async def _sim_stream(self, job_id: str) -> StreamOutput:
343347
"""Development endpoint to simulate stream behavior."""
344-
stashed_job = job_list.get_job(job_id)
348+
stashed_job = job_list.get(job_id)
345349
if stashed_job is None:
346350
return jsonable_encoder(
347351
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -363,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
363367
}
364368
)
365369

366-
job_list.remove_job(job.id)
370+
job_list.remove(job.id)
367371

368372
if stashed_job.webhook:
369373
thread = threading.Thread(
@@ -380,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
380384
# ---------------------------------- status ---------------------------------- #
381385
async def _sim_status(self, job_id: str) -> JobOutput:
382386
"""Development endpoint to simulate status behavior."""
383-
stashed_job = job_list.get_job(job_id)
387+
stashed_job = job_list.get(job_id)
384388
if stashed_job is None:
385389
return jsonable_encoder(
386390
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -396,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
396400
else:
397401
job_output = await run_job(self.config["handler"], job.__dict__)
398402

399-
job_list.remove_job(job.id)
403+
job_list.remove(job.id)
400404

401405
if job_output.get("error", None):
402406
return jsonable_encoder(

runpod/serverless/modules/rp_job.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
from runpod.serverless.modules.rp_logger import RunPodLogger
1414

1515
from ...version import __version__ as runpod_version
16+
from ..utils import rp_debugger
17+
from .rp_handler import is_generator
18+
from .rp_http import send_result, stream_result
1619
from .rp_tips import check_return_size
17-
from .worker_state import WORKER_ID, JobsQueue
20+
from .worker_state import WORKER_ID, REF_COUNT_ZERO, JobsProgress
1821

1922
JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)
2023

2124
log = RunPodLogger()
22-
job_list = JobsQueue()
25+
job_progress = JobsProgress()
2326

2427

2528
def _job_get_url(batch_size: int = 1):
@@ -32,7 +35,7 @@ def _job_get_url(batch_size: int = 1):
3235
Returns:
3336
str: The prepared URL for the 'get' request to the serverless API.
3437
"""
35-
job_in_progress = "1" if job_list.get_job_count() else "0"
38+
job_in_progress = "1" if job_progress.get_job_count() else "0"
3639

3740
if batch_size > 1:
3841
job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
@@ -96,6 +99,47 @@ async def get_job(
9699
return []
97100

98101

102+
async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dict:
103+
if is_generator(config["handler"]):
104+
is_stream = True
105+
generator_output = run_job_generator(config["handler"], job)
106+
log.debug("Handler is a generator, streaming results.", job["id"])
107+
108+
job_result = {"output": []}
109+
async for stream_output in generator_output:
110+
log.debug(f"Stream output: {stream_output}", job["id"])
111+
if "error" in stream_output:
112+
job_result = stream_output
113+
break
114+
if config.get("return_aggregate_stream", False):
115+
job_result["output"].append(stream_output["output"])
116+
117+
await stream_result(session, stream_output, job)
118+
else:
119+
is_stream = False
120+
job_result = await run_job(config["handler"], job)
121+
122+
# If refresh_worker is set, pod will be reset after job is complete.
123+
if config.get("refresh_worker", False):
124+
log.info("refresh_worker flag set, stopping pod after job.", job["id"])
125+
job_result["stopPod"] = True
126+
127+
# If rp_debugger is set, debugger output will be returned.
128+
if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict):
129+
job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
130+
log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])
131+
132+
# Calculate ready delay for the debugger output.
133+
ready_delay = (config["reference_counter_start"] - REF_COUNT_ZERO) * 1000
134+
job_result["output"]["rp_debugger"]["ready_delay_ms"] = ready_delay
135+
else:
136+
log.debug("rp_debugger | Flag not set, skipping debugger output.", job["id"])
137+
rp_debugger.clear_debugger_output()
138+
139+
# Send the job result back to JOB_DONE_URL
140+
await send_result(session, job_result, job, is_stream=is_stream)
141+
142+
99143
async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
100144
"""
101145
Run the job using the handler.

runpod/serverless/modules/rp_ping.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@
1212

1313
from runpod.http_client import SyncClientSession
1414
from runpod.serverless.modules.rp_logger import RunPodLogger
15-
from runpod.serverless.modules.worker_state import WORKER_ID, JobsQueue
15+
from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress
1616
from runpod.version import __version__ as runpod_version
1717

1818
log = RunPodLogger()
19-
jobs = JobsQueue() # Contains the list of jobs that are currently running.
19+
jobs = JobsProgress() # Contains the list of jobs that are currently running.
2020

2121

2222
class Heartbeat:
2323
"""Sends heartbeats to the Runpod server."""
2424

25-
PING_URL = os.environ.get("RUNPOD_WEBHOOK_PING", "PING_NOT_SET")
26-
PING_URL = PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID)
27-
PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000
28-
2925
_thread_started = False
3026

3127
def __init__(self, pool_connections=10, retries=3) -> None:
3228
"""
3329
Initializes the Heartbeat class.
3430
"""
31+
self.PING_URL = os.environ.get("RUNPOD_WEBHOOK_PING", "PING_NOT_SET")
32+
self.PING_URL = self.PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID)
33+
self.PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000
34+
3535
self._session = SyncClientSession()
3636
self._session.headers.update(
3737
{"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"}
@@ -56,15 +56,15 @@ def start_ping(self, test=False):
5656
"""
5757
Sends heartbeat pings to the Runpod server.
5858
"""
59-
if os.environ.get("RUNPOD_AI_API_KEY") is None:
59+
if not os.environ.get("RUNPOD_AI_API_KEY"):
6060
log.debug("Not deployed on RunPod serverless, pings will not be sent.")
6161
return
6262

63-
if os.environ.get("RUNPOD_POD_ID") is None:
63+
if not os.environ.get("RUNPOD_POD_ID"):
6464
log.info("Not running on RunPod, pings will not be sent.")
6565
return
6666

67-
if self.PING_URL in ["PING_NOT_SET", None]:
67+
if (not self.PING_URL) or self.PING_URL == "PING_NOT_SET":
6868
log.error("Ping URL not set, cannot start ping.")
6969
return
7070

runpod/serverless/modules/rp_scale.py

+22-49
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@
77
from typing import Any, Dict
88

99
from ...http_client import ClientSession
10-
from ..utils import rp_debugger
11-
from .rp_handler import is_generator
12-
from .rp_http import send_result, stream_result
13-
from .rp_job import get_job, run_job, run_job_generator
10+
from .rp_job import get_job, handle_job
1411
from .rp_logger import RunPodLogger
15-
from .worker_state import JobsQueue, REF_COUNT_ZERO
12+
from .worker_state import JobsQueue, JobsProgress
1613

1714
log = RunPodLogger()
1815
job_list = JobsQueue()
16+
job_progress = JobsProgress()
1917

2018

2119
def _default_concurrency_modifier(current_concurrency: int) -> int:
@@ -68,15 +66,15 @@ async def get_jobs(self, session: ClientSession):
6866
Adds jobs to the JobsQueue
6967
"""
7068
while self.is_alive():
71-
log.debug(f"Jobs in progress: {job_list.get_job_count()}")
69+
log.debug(f"Jobs in progress: {job_progress.get_job_count()}")
7270

7371
try:
7472
self.current_concurrency = self.concurrency_modifier(
7573
self.current_concurrency
7674
)
7775
log.debug(f"Concurrency set to: {self.current_concurrency}")
7876

79-
jobs_needed = self.current_concurrency - job_list.get_job_count()
77+
jobs_needed = self.current_concurrency - job_progress.get_job_count()
8078
if not jobs_needed: # zero or less
8179
log.debug("Queue is full. Retrying soon.")
8280
continue
@@ -113,7 +111,7 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
113111
job = await job_list.get_job()
114112

115113
# Create a new task for each job and add it to the task list
116-
task = asyncio.create_task(self.process_job(session, config, job))
114+
task = asyncio.create_task(self.handle_job(session, config, job))
117115
tasks.append(task)
118116

119117
# Wait for any job to finish
@@ -133,51 +131,26 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
133131
# Ensure all remaining tasks finish before stopping
134132
await asyncio.gather(*tasks)
135133

136-
async def process_job(self, session: ClientSession, config: Dict[str, Any], job):
134+
async def handle_job(self, session: ClientSession, config: Dict[str, Any], job):
137135
"""
138136
Process an individual job. This function is run concurrently for multiple jobs.
139137
"""
140138
log.debug(f"Processing job: {job}")
139+
job_progress.add(job)
141140

142-
if is_generator(config["handler"]):
143-
is_stream = True
144-
generator_output = run_job_generator(config["handler"], job)
145-
log.debug("Handler is a generator, streaming results.", job["id"])
146-
147-
job_result = {"output": []}
148-
async for stream_output in generator_output:
149-
log.debug(f"Stream output: {stream_output}", job["id"])
150-
if "error" in stream_output:
151-
job_result = stream_output
152-
break
153-
if config.get("return_aggregate_stream", False):
154-
job_result["output"].append(stream_output["output"])
155-
156-
await stream_result(session, stream_output, job)
157-
else:
158-
is_stream = False
159-
job_result = await run_job(config["handler"], job)
160-
161-
# If refresh_worker is set, pod will be reset after job is complete.
162-
if config.get("refresh_worker", False):
163-
log.info("refresh_worker flag set, stopping pod after job.", job["id"])
164-
job_result["stopPod"] = True
165-
self.kill_worker()
166-
167-
# If rp_debugger is set, debugger output will be returned.
168-
if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict):
169-
job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
170-
log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])
171-
172-
# Calculate ready delay for the debugger output.
173-
ready_delay = (config["reference_counter_start"] - REF_COUNT_ZERO) * 1000
174-
job_result["output"]["rp_debugger"]["ready_delay_ms"] = ready_delay
175-
else:
176-
log.debug("rp_debugger | Flag not set, skipping debugger output.", job["id"])
177-
rp_debugger.clear_debugger_output()
141+
try:
142+
await handle_job(session, config, job)
143+
144+
if config.get("refresh_worker", False):
145+
self.kill_worker()
146+
147+
except Exception as err:
148+
log.error(f"Error handling job: {err}", job["id"])
149+
raise err
178150

179-
# Send the job result back to JOB_DONE_URL
180-
await send_result(session, job_result, job, is_stream=is_stream)
151+
finally:
152+
# Inform JobsQueue of a task completion
153+
job_list.task_done()
181154

182-
# Inform JobsQueue of a task completion
183-
job_list.task_done()
155+
# Job is no longer in progress
156+
job_progress.remove(job["id"])

0 commit comments

Comments (0)