Skip to content

feat: retrying HTTP calls on errors #440

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions solnlib/splunk_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def _request_handler(context):
'cert_file': string
'pool_connections', int,
'pool_maxsize', int,
'max_retries': int,
'retry_status_codes': list,
}
:type content: dict
"""
Expand Down Expand Up @@ -102,25 +104,32 @@ def _request_handler(context):
else:
cert = None

retries = Retry(
total=MAX_REQUEST_RETRIES,
backoff_factor=0.3,
status_forcelist=[500, 502, 503, 504],
allowed_methods=["GET", "POST", "PUT", "DELETE"],
raise_on_status=False,
)
if context.get("pool_connections", 0):
logging.info("Use HTTP connection pooling")
session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
max_retries=retries,
pool_connections=context.get("pool_connections", 10),
pool_maxsize=context.get("pool_maxsize", 10),
def adapter():
retries = Retry(
total=context.get("max_retries", MAX_REQUEST_RETRIES),
backoff_factor=0.3,
status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]),
allowed_methods=["GET", "POST", "PUT", "DELETE"],
raise_on_status=False,
)
session.mount("https://", adapter)
req_func = session.request
else:
req_func = requests.request

adapter_args = {
"max_retries": retries,
}

# By default, pool_connections and pool_maxsize are set to 10 in urllib3
if "pool_connections" in context:
adapter_args["pool_connections"] = context["pool_connections"]
if "pool_maxsize" in context:
adapter_args["pool_maxsize"] = context["pool_maxsize"]

return requests.adapters.HTTPAdapter(**adapter_args)

session = requests.Session()
session.mount("http://", adapter())
session.mount("https://", adapter())

req_func = session.request

def request(url, message, **kwargs):
"""
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json
import socket
from contextlib import closing
from http.server import BaseHTTPRequestHandler, HTTPServer
from threading import Thread

import pytest


@pytest.fixture(scope="session")
def http_mock_server():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]

class Mock:
def __init__(self, host, port):
self.host = host
self.port = port
self.get_func = None

def get(self, func):
self.get_func = func
return func

mock = Mock("localhost", port)

class RequestArg:
def __init__(self):
self.headers = {
"Content-Type": "application/json",
}
self.response_code = 200

def send_header(self, key, value):
self.headers[key] = value

def send_response(self, code):
self.response_code = code

class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if mock.get_func is None:
self.send_response(404)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(json.dumps({"error": "Not Found"}).encode("utf-8"))
return

request = RequestArg()
response = mock.get_func(request)

self.send_response(request.response_code)

for key, value in request.headers.items():
self.send_header(key, value)

self.end_headers()

if isinstance(response, dict):
response = json.dumps(response)

self.wfile.write(response.encode("utf-8"))

server_address = ("", mock.port)
httpd = HTTPServer(server_address, Handler)

thread = Thread(target=httpd.serve_forever)
thread.setDaemon(True)
thread.start()

yield mock

httpd.shutdown()
httpd.server_close()
thread.join()
66 changes: 66 additions & 0 deletions tests/unit/test_splunk_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from unittest import mock

import pytest
from splunklib.binding import HTTPError

from solnlib.splunk_rest_client import MAX_REQUEST_RETRIES

from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -109,3 +111,67 @@ def test_request_retry(http_conn_pool, http_resp, mock_get_splunkd_access_info):
http_conn_pool.side_effect = side_effects
with pytest.raises(ConnectionError):
rest_client.get("test")


@pytest.mark.parametrize("error_code", [429, 500, 503])
def test_request_throttling(http_mock_server, error_code):
@http_mock_server.get
def throttling(request):
"""Mock endpoint to simulate request throttling.

The endpoint will return an error status code for the first 5
requests, and a 200 status code for subsequent requests.
"""
number = getattr(throttling, "call_count", 0)
throttling.call_count = number + 1

if number < 2:
request.send_response(error_code)
request.send_header("Retry-After", "1")
return {"error": f"Error {number}"}

return {"content": "Success"}

rest_client = SplunkRestClient(
"msg_name_1",
"session_key",
"_",
scheme="http",
host="localhost",
port=http_mock_server.port,
)

resp = rest_client.get("test")
assert resp.status == 200
assert resp.body.read().decode("utf-8") == '{"content": "Success"}'


@pytest.mark.parametrize("error_code", [429, 500, 503])
def test_request_throttling_exceeded(http_mock_server, error_code):
@http_mock_server.get
def throttling(request):
"""Mock endpoint to simulate request throttling.

The endpoint will always return an error status code.
"""
number = getattr(throttling, "call_count", 0)
throttling.call_count = number + 1

request.send_response(error_code)
request.send_header("Retry-After", "1")
return {"error": f"Error {number}"}

rest_client = SplunkRestClient(
"msg_name_1",
"session_key",
"_",
scheme="http",
host="localhost",
port=http_mock_server.port,
)

with pytest.raises(HTTPError) as ex:
rest_client.get("test")

assert ex.value.status == error_code
assert ex.value.body.decode("utf-8") == '{"error": "Error 5"}'
Loading