Skip to content

Commit 3f78233

Browse files
authored
Fix: handle uncaught exception only for Serverless workers (#388)
* refactor: moved handle_uncaught_exception to rp_scale * refactor: bind handle_uncaught_exception on JobScaler init * fix: python <3.11 compatibility
1 parent d7a2131 commit 3f78233

File tree

4 files changed

+89
-7
lines changed

4 files changed

+89
-7
lines changed

runpod/serverless/__init__.py

-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import signal
1111
import sys
1212
import time
13-
import typing
1413
from typing import Any, Dict
1514

1615
from runpod.serverless import core
@@ -24,12 +23,6 @@
2423
log = RunPodLogger()
2524

2625

27-
def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
28-
log.error(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")
29-
30-
sys.excepthook = handle_uncaught_exception
31-
32-
3326
# ---------------------------------------------------------------------------- #
3427
# Run Time Arguments #
3528
# ---------------------------------------------------------------------------- #

runpod/serverless/modules/rp_scale.py

+9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import asyncio
77
import signal
8+
import sys
9+
import traceback
810
from typing import Any, Dict
911

1012
from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
@@ -16,6 +18,11 @@
1618
job_progress = JobsProgress()
1719

1820

21+
def _handle_uncaught_exception(exc_type, exc_value, exc_traceback):
22+
exc = traceback.format_exception(exc_type, exc_value, exc_traceback)
23+
log.error(f"Uncaught exception | {exc}")
24+
25+
1926
def _default_concurrency_modifier(current_concurrency: int) -> int:
2027
"""
2128
Default concurrency modifier.
@@ -87,6 +94,8 @@ def start(self):
8794
when the user sends a SIGTERM or SIGINT signal. This is typically
8895
the case when the worker is running in a container.
8996
"""
97+
sys.excepthook = _handle_uncaught_exception
98+
9099
try:
91100
# Register signal handlers for graceful shutdown
92101
signal.signal(signal.SIGTERM, self.handle_shutdown)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import sys
2+
import traceback
3+
from unittest import TestCase
4+
from unittest.mock import patch
5+
6+
from runpod.serverless.modules.rp_scale import _handle_uncaught_exception
7+
8+
9+
class TestHandleUncaughtException(TestCase):
10+
def setUp(self):
11+
sys.excepthook = sys.__excepthook__
12+
13+
@patch("runpod.serverless.modules.rp_scale.log")
14+
def test_handle_uncaught_exception(self, mock_logger):
15+
exc_type = ValueError
16+
exc_value = ValueError("This is a test error")
17+
exc_traceback = None # No traceback for simplicity
18+
19+
_handle_uncaught_exception(exc_type, exc_value, exc_traceback)
20+
21+
formatted_exception = traceback.format_exception(exc_type, exc_value, exc_traceback)
22+
23+
mock_logger.error.assert_called_once()
24+
log_message = mock_logger.error.call_args[0][0]
25+
assert "Uncaught exception | " in log_message
26+
assert str(formatted_exception) in log_message
27+
28+
@patch("runpod.serverless.modules.rp_scale.log")
29+
def test_handle_uncaught_exception_with_traceback(self, mock_logger):
30+
try:
31+
raise RuntimeError("This is a runtime error")
32+
except RuntimeError:
33+
exc_type, exc_value, exc_traceback = sys.exc_info()
34+
35+
_handle_uncaught_exception(exc_type, exc_value, exc_traceback)
36+
37+
formatted_exception = traceback.format_exception(exc_type, exc_value, exc_traceback)
38+
39+
mock_logger.error.assert_called_once()
40+
log_message = mock_logger.error.call_args[0][0]
41+
assert "Uncaught exception | " in log_message
42+
assert str(formatted_exception) in log_message
43+
44+
@patch("runpod.serverless.modules.rp_scale.log")
45+
def test_handle_uncaught_exception_with_no_exception(self, mock_logger):
46+
_handle_uncaught_exception(None, None, None)
47+
48+
mock_logger.error.assert_called_once()
49+
log_message = mock_logger.error.call_args[0][0]
50+
assert "Uncaught exception | " in log_message
51+
52+
def test_excepthook_not_set_when_start_not_invoked(self):
53+
assert sys.excepthook == sys.__excepthook__
54+
assert sys.excepthook != _handle_uncaught_exception

tests/test_serverless/test_worker.py

+26
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import argparse
66
import os
7+
import sys
78
from unittest import mock
89
from unittest.mock import patch, mock_open, Mock, MagicMock
910

@@ -12,6 +13,7 @@
1213

1314
import runpod
1415
from runpod.serverless.modules.rp_logger import RunPodLogger
16+
from runpod.serverless.modules.rp_scale import _handle_uncaught_exception
1517
from runpod.serverless import _signal_handler
1618

1719
nest_asyncio.apply()
@@ -187,6 +189,9 @@ async def asyncSetUp(self):
187189
"rp_args": {"rp_debugger": True, "rp_log_level": "DEBUG"},
188190
}
189191

192+
async def asyncTearDown(self):
193+
sys.excepthook = sys.__excepthook__
194+
190195
@patch("runpod.serverless.modules.rp_scale.AsyncClientSession")
191196
@patch("runpod.serverless.modules.rp_scale.get_job")
192197
@patch("runpod.serverless.modules.rp_job.run_job")
@@ -543,3 +548,24 @@ async def test_run_worker_with_sls_core(self):
543548
os.environ.pop("RUNPOD_USE_CORE")
544549

545550
assert mock_main.called
551+
552+
@patch("runpod.serverless.signal.signal")
553+
@patch("runpod.serverless.worker.rp_scale.JobScaler.run")
554+
def test_start_sets_excepthook(self, _, __):
555+
runpod.serverless.start({})
556+
assert sys.excepthook == _handle_uncaught_exception
557+
558+
@patch("runpod.serverless.signal.signal")
559+
@patch("runpod.serverless.rp_fastapi.WorkerAPI.start_uvicorn")
560+
@patch("runpod.serverless._set_config_args")
561+
def test_start_does_not_set_excepthook(self, mock_set_config_args, _, __):
562+
mock_set_config_args.return_value = self.config
563+
self.config.update({"rp_args": {
564+
"rp_serve_api": True,
565+
"rp_api_host": "localhost",
566+
"rp_api_port": 8000,
567+
"rp_api_concurrency": 1,
568+
}})
569+
570+
runpod.serverless.start(self.config)
571+
assert sys.excepthook != _handle_uncaught_exception

0 commit comments

Comments
 (0)