Skip to content

Commit d0fead9

Browse files
committed
feat: abstract worker
1 parent 3de0b46 commit d0fead9

File tree

3 files changed

+49
-27
lines changed

3 files changed

+49
-27
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- [Installation](#installation)
1111
- [API Language Library](#api-language-library)
1212
- [SDK - Serverless Worker](#sdk---serverless-worker)
13+
- [Quick Start](#quick-start)
1314

1415
## Installation
1516

@@ -30,3 +31,15 @@ runpod.api_key = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
3031
## SDK - Serverless Worker
3132

3233
This python package can also be used to create a serverless worker that can be deployed to RunPod.
34+
35+
### Quick Start
36+
37+
Create an executable file called 'worker' in the root of your project that contains the following:
38+
39+
```python
40+
#!/usr/bin/env python
41+
42+
import runpod
43+
44+
runpod.serverless.pod_worker()
45+
```

src/runpod/serverless/pod_worker.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,44 @@
1111
from .modules.logging import log
1212

1313

14-
worker_life = lifecycle.LifecycleManager()
14+
def start_worker():
15+
'''
16+
Starts the worker.
17+
'''
18+
worker_life = lifecycle.LifecycleManager()
19+
20+
if not worker_life.is_worker_zero:
21+
log("Not worker zero, starting TTL timer thread.")
22+
threading.Thread(target=worker_life.check_worker_ttl_thread).start()
23+
else:
24+
log("Worker zero, not starting TTL timer thread.")
1525

16-
if not worker_life.is_worker_zero:
17-
log("Not worker zero, starting TTL timer thread.")
18-
threading.Thread(target=worker_life.check_worker_ttl_thread).start()
19-
else:
20-
log("Worker zero, not starting TTL timer thread.")
26+
while True:
27+
if os.environ.get('TEST_LOCAL', 'false') != 'true':
28+
next_job = job.get(worker_life.worker_id)
29+
else:
30+
next_job = job.get_local()
2131

32+
if next_job is not None:
33+
worker_life.work_in_progress = True # Rests when "reset_worker_ttl" is called
2234

23-
while True:
24-
if os.environ.get('TEST_LOCAL', 'false') != 'true':
25-
next_job = job.get(worker_life.worker_id)
26-
else:
27-
next_job = job.get_local()
35+
try:
36+
output_urls, job_duration_ms = job.run(next_job['id'], next_job['input'])
37+
job.post(worker_life.worker_id, next_job['id'], output_urls, job_duration_ms)
38+
except ValueError as err:
39+
job.error(worker_life.worker_id, next_job['id'], str(err))
2840

29-
if next_job is not None:
30-
worker_life.work_in_progress = True # Rests when "reset_worker_ttl" is called
41+
# -------------------------------- Job Cleanup ------------------------------- #
42+
shutil.rmtree("input_objects", ignore_errors=True)
43+
shutil.rmtree("output_objects", ignore_errors=True)
44+
os.remove('output.zip')
3145

32-
try:
33-
output_urls, job_duration_ms = job.run(next_job['id'], next_job['input'])
34-
job.post(worker_life.worker_id, next_job['id'], output_urls, job_duration_ms)
35-
except ValueError as err:
36-
job.error(worker_life.worker_id, next_job['id'], str(err))
46+
worker_life.reset_worker_ttl()
3747

38-
# -------------------------------- Job Cleanup ------------------------------- #
39-
shutil.rmtree("input_objects", ignore_errors=True)
40-
shutil.rmtree("output_objects", ignore_errors=True)
41-
os.remove('output.zip')
48+
if os.environ.get('TEST_LOCAL', 'false') == 'true':
49+
log("Local testing complete, exiting.")
50+
break
4251

43-
worker_life.reset_worker_ttl()
4452

45-
if os.environ.get('TEST_LOCAL', 'false') == 'true':
46-
log("Local testing complete, exiting.")
47-
break
53+
if __name__ == '__main__':
54+
start_worker()

tests/test_serverless_module_download.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ def __init__(self, content, status_code):
3030
class TestDownloadInputObjects(unittest.TestCase):
3131
''' Tests for download_input_objects '''
3232

33+
@patch('os.makedirs', return_value=None)
3334
@patch('requests.get', side_effect=mock_requests_get)
3435
@patch('builtins.open', new_callable=mock_open)
35-
def test_download_input_objects(self, mock_open_file, mock_get):
36+
def test_download_input_objects(self, mock_open_file, mock_get, mock_makedirs):
3637
'''
3738
Tests download_input_objects
3839
'''
@@ -43,3 +44,4 @@ def test_download_input_objects(self, mock_open_file, mock_get):
4344
self.assertEqual(len(objects), 1)
4445
self.assertIn('https://example.com/picture.jpg', mock_get.call_args_list[0][0])
4546
mock_open_file.assert_called_once_with(objects[0], 'wb')
47+
mock_makedirs.assert_called_once_with('input_objects', exist_ok=True)

0 commit comments

Comments
 (0)