Skip to content

Commit 0273994

Browse files
Merge pull request #96 from runpod/api-improvements
API Improvements With Error Handling
2 parents d357ef5 + b8df719 commit 0273994

File tree

5 files changed

+88
-6
lines changed

5 files changed

+88
-6
lines changed

docs/api/handling_errors.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Handling Errors
2+
3+
```Python
4+
import runpod
5+
6+
try:
7+
# Use runpod to make a request
8+
except runpod.error.AuthenticationError as err:
9+
# Authentication with the API failed
10+
```

runpod/api_wrapper/ctl_commands.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# pylint: disable=too-many-arguments,too-many-locals
55

66
from typing import Optional
7+
78
from .queries import gpus
89
from .queries import pods as pod_queries
910
from .graphql import run_graphql_query
@@ -14,19 +15,19 @@ def get_gpus() -> dict:
1415
'''
1516
Get all GPU types
1617
'''
17-
raw_return = run_graphql_query(gpus.QUERY_GPU_TYPES)
18-
cleaned_return = raw_return["data"]["gpuTypes"]
18+
raw_response = run_graphql_query(gpus.QUERY_GPU_TYPES)
19+
cleaned_return = raw_response["data"]["gpuTypes"]
1920
return cleaned_return
2021

2122

2223
def get_gpu(gpu_id : str):
2324
'''
2425
Get a specific GPU type
25-
26+
2627
:param gpu_id: the id of the gpu
2728
'''
28-
raw_return = run_graphql_query(gpus.generate_gpu_query(gpu_id))
29-
cleaned_return = raw_return["data"]["gpuTypes"][0]
29+
raw_response = run_graphql_query(gpus.generate_gpu_query(gpu_id))
30+
cleaned_return = raw_response["data"]["gpuTypes"][0]
3031
return cleaned_return
3132

3233
def get_pods() -> dict:
@@ -56,7 +57,7 @@ def create_pod(name : str, image_name : str, gpu_type_id : str, cloud_type : str
5657
:param volume_in_gb: how big should the pod volume be
5758
:param ports: the ports to open in the pod, example format - "8888/http,666/tcp"
5859
:param volume_mount_path: where to mount the volume?
59-
:param env: the environment variables to inject into the pod,
60+
:param env: the environment variables to inject into the pod,
6061
for example {EXAMPLE_VAR:"example_value", EXAMPLE_VAR2:"example_value 2"}, will
6162
inject EXAMPLE_VAR and EXAMPLE_VAR2 into the pod with the mentioned values
6263

runpod/api_wrapper/graphql.py

+12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import requests
99

10+
from runpod import error
11+
12+
HTTP_STATUS_UNAUTHORIZED = 401
1013

1114
def run_graphql_query(query: str) -> Dict[str, Any]:
1215
'''
@@ -19,4 +22,13 @@ def run_graphql_query(query: str) -> Dict[str, Any]:
1922
}
2023
data = json.dumps({"query": query})
2124
response = requests.post(url, headers=headers, data=data, timeout=30)
25+
26+
print(response.json())
27+
28+
if response.status_code == HTTP_STATUS_UNAUTHORIZED:
29+
raise error.AuthenticationError("Unauthorized request, please check your API key.")
30+
31+
if "errors" in response.json():
32+
raise error.QueryError(response.json()["errors"][0]["message"])
33+
2234
return response.json()

runpod/error.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
'''
2+
runpd | error.py
3+
4+
This file contains the error classes for the runpod package.
5+
'''
6+
7+
from typing import Optional
8+
9+
10+
class RunPodError(Exception):
11+
'''
12+
Base class for all runpod errors
13+
'''
14+
def __init__(self, message: Optional[str] = None):
15+
super().__init__(message)
16+
17+
self.message = message
18+
19+
20+
21+
class AuthenticationError(RunPodError):
22+
'''
23+
Raised when authentication fails
24+
'''
25+
26+
27+
class QueryError(RunPodError):
28+
'''
29+
Raised when a GraphQL query fails
30+
'''

tests/test_api_wrapper/test_ctl_commands.py

+29
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,35 @@ def test_terminate_pod(self):
123123

124124
self.assertIsNone(ctl_commands.terminate_pod(pod_id="POD_ID"))
125125

126+
def test_raised_error(self):
127+
'''
128+
Test raised_error
129+
'''
130+
with patch("runpod.api_wrapper.graphql.requests.post") as patch_request:
131+
patch_request.return_value.json.return_value = {
132+
"errors": [
133+
{
134+
"message": "Error Message"
135+
}
136+
]
137+
}
138+
139+
with self.assertRaises(Exception) as context:
140+
ctl_commands.get_gpus()
141+
142+
self.assertEqual(str(context.exception), "Error Message")
143+
144+
145+
# Test Unauthorized with status code 401
146+
with patch("runpod.api_wrapper.graphql.requests.post") as patch_request:
147+
patch_request.return_value.status_code = 401
148+
149+
with self.assertRaises(Exception) as context:
150+
ctl_commands.get_gpus()
151+
152+
self.assertEqual(
153+
str(context.exception), "Unauthorized request, please check your API key.")
154+
126155
def test_get_pods(self):
127156
'''
128157
Tests get_pods

0 commit comments

Comments
 (0)