Skip to content

Commit 26966c4

Browse files
sujaygarlankaPJSimon
authored andcommitted
Fix JWT retry logic (#479)
1 parent 345f848 commit 26966c4

File tree

5 files changed

+173
-18
lines changed

5 files changed

+173
-18
lines changed

HISTORY.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ Upcoming
88
- Fixed bug in get_admin_events function which caused errors when the optional event_types parameter was omitted.
99
- Added support for more attribute parameters when uploading new files and new versions of existing files.
1010
- Combined preflight check and lookup of accelerator URL into a single request for uploads.
11+
- Fixed JWT retry logic so the a new JTI claim is generated on each retry
12+
- Fixed retry logic so when a `Retry-After` header is passed back from the API, the SDK waits for the amount of time specified in the header before retrying
1113

1214
2.6.1 (2019-10-24)
1315
++++++++++++++++++

boxsdk/auth/jwt_auth.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import json
77
import random
88
import string
9+
import time
910

1011
from cryptography.hazmat.backends import default_backend
1112
from cryptography.hazmat.primitives import serialization
1213
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
1314
import jwt
1415
from six import binary_type, string_types, raise_from, text_type
1516

17+
from ..config import API
1618
from ..exception import BoxOAuthException
1719
from .oauth2 import OAuth2
1820
from ..object.user import User
@@ -239,14 +241,30 @@ def _auth_with_jwt(self, sub, sub_type):
239241
:rtype:
240242
`unicode`
241243
"""
242-
try:
243-
return self._construct_and_send_jwt_auth(sub, sub_type)
244-
except BoxOAuthException as ex:
245-
error_response = ex.network_response
246-
box_datetime = self._get_date_header(error_response)
247-
if box_datetime is not None and self._was_exp_claim_rejected_due_to_clock_skew(error_response):
248-
return self._construct_and_send_jwt_auth(sub, sub_type, box_datetime)
249-
raise
244+
attempt_number = 0
245+
jwt_time = None
246+
while True:
247+
try:
248+
return self._construct_and_send_jwt_auth(sub, sub_type, jwt_time)
249+
except BoxOAuthException as ex:
250+
network_response = ex.network_response
251+
code = network_response.status_code # pylint: disable=maybe-no-member
252+
box_datetime = self._get_date_header(network_response)
253+
254+
if attempt_number >= API.MAX_RETRY_ATTEMPTS:
255+
raise ex
256+
257+
if (code == 429 or code >= 500):
258+
jwt_time = None
259+
elif box_datetime is not None and self._was_exp_claim_rejected_due_to_clock_skew(network_response):
260+
jwt_time = box_datetime
261+
else:
262+
raise ex
263+
264+
time_delay = self._session.get_retry_after_time(attempt_number, network_response.headers.get('Retry-After', None)) # pylint: disable=maybe-no-member
265+
time.sleep(time_delay)
266+
attempt_number += 1
267+
self._logger.debug('Retrying JWT request')
250268

251269
@staticmethod
252270
def _get_date_header(network_response):

boxsdk/session/session.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Session(object):
2424

2525
_retry_randomization_factor = 0.5
2626
_retry_base_interval = 1
27+
_JWT_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:jwt-bearer'
2728

2829
"""
2930
Box API session. Provides automatic retry of failed requests.
@@ -268,7 +269,7 @@ def with_default_network_request_kwargs(self, extra_network_parameters):
268269
# We updated our retry strategy to use exponential backoff instead of the header returned from the API response.
269270
# This is something we can remove in latter major bumps.
270271
# pylint: disable=unused-argument
271-
def _get_retry_after_time(self, attempt_number, retry_after_header):
272+
def get_retry_after_time(self, attempt_number, retry_after_header):
272273
"""
273274
Get the amount of time to wait before retrying the API request, using the attempt number that failed to
274275
calculate the retry time for the next retry attempt.
@@ -285,6 +286,11 @@ def _get_retry_after_time(self, attempt_number, retry_after_header):
285286
:return: Number of seconds to wait before retrying.
286287
:rtype: `Number`
287288
"""
289+
if retry_after_header is not None:
290+
try:
291+
return int(retry_after_header)
292+
except (ValueError, TypeError):
293+
pass
288294
min_randomization = 1 - self._retry_randomization_factor
289295
max_randomization = 1 + self._retry_randomization_factor
290296
randomization = (random.uniform(0, 1) * (max_randomization - min_randomization)) + min_randomization
@@ -388,7 +394,7 @@ def _prepare_and_send_request(
388394
network_response = self._send_request(request, **kwargs)
389395

390396
while True:
391-
retry = self._get_retry_request_callable(network_response, attempt_number, request)
397+
retry = self._get_retry_request_callable(network_response, attempt_number, request, **kwargs)
392398

393399
if retry is None or attempt_number >= API.MAX_RETRY_ATTEMPTS:
394400
break
@@ -401,7 +407,7 @@ def _prepare_and_send_request(
401407

402408
return network_response
403409

404-
def _get_retry_request_callable(self, network_response, attempt_number, request):
410+
def _get_retry_request_callable(self, network_response, attempt_number, request, **kwargs):
405411
"""
406412
Get a callable that retries a request for certain types of failure.
407413
@@ -429,11 +435,15 @@ def _get_retry_request_callable(self, network_response, attempt_number, request)
429435
`callable`
430436
"""
431437
# pylint:disable=unused-argument
438+
data = kwargs.get('data', {})
439+
grant_type = None
440+
if 'grant_type' in data:
441+
grant_type = data['grant_type']
432442
code = network_response.status_code
433-
if code in (202, 429) or code >= 500:
443+
if (code in (202, 429) or code >= 500) and grant_type != self._JWT_GRANT_TYPE:
434444
return partial(
435445
self._network_layer.retry_after,
436-
self._get_retry_after_time(attempt_number, network_response.headers.get('Retry-After', None)),
446+
self.get_retry_after_time(attempt_number, network_response.headers.get('Retry-After', None)),
437447
self._send_request,
438448
)
439449
return None
@@ -547,7 +557,7 @@ def _renew_session(self, access_token_used):
547557
new_access_token, _ = self._oauth.refresh(access_token_used)
548558
return new_access_token
549559

550-
def _get_retry_request_callable(self, network_response, attempt_number, request):
560+
def _get_retry_request_callable(self, network_response, attempt_number, request, **kwargs):
551561
"""
552562
Get a callable that retries a request for certain types of failure.
553563
@@ -581,6 +591,7 @@ def _get_retry_request_callable(self, network_response, attempt_number, request)
581591
network_response,
582592
attempt_number,
583593
request,
594+
**kwargs
584595
)
585596

586597
def _send_request(self, request, **kwargs):

test/unit/auth/test_jwt_auth.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def unsuccessful_jwt_response(box_datetime, status_code, error_description, incl
491491
@pytest.mark.parametrize('jwt_algorithm', ('RS512',))
492492
@pytest.mark.parametrize('rsa_passphrase', (None,))
493493
@pytest.mark.parametrize('pass_private_key_by_path', (False,))
494-
@pytest.mark.parametrize('status_code', (400, 401, 429, 500))
494+
@pytest.mark.parametrize('status_code', (400, 401))
495495
@pytest.mark.parametrize('error_description', ('invalid box_sub_type claim', 'invalid kid', "check the 'exp' claim"))
496496
@pytest.mark.parametrize('error_code', ('invalid_grant', 'bad_request'))
497497
@pytest.mark.parametrize('include_date_header', (True, False))
@@ -512,8 +512,132 @@ def test_auth_retry_for_invalid_exp_claim(
512512
auth.authenticate_instance(enterprise_id)
513513
else:
514514
auth.authenticate_instance(enterprise_id)
515-
expected_calls = [call(enterprise_id, 'enterprise')]
515+
expected_calls = [call(enterprise_id, 'enterprise', None)]
516516
if expect_auth_retry:
517517
expected_calls.append(call(enterprise_id, 'enterprise', box_datetime.replace(microsecond=0, tzinfo=None)))
518518
assert len(mock_send_jwt.mock_calls) == len(expected_calls)
519519
mock_send_jwt.assert_has_calls(expected_calls)
520+
521+
522+
@pytest.mark.parametrize('jwt_algorithm', ('RS512',))
523+
@pytest.mark.parametrize('rsa_passphrase', (None,))
524+
@pytest.mark.parametrize('pass_private_key_by_path', (False,))
525+
@pytest.mark.parametrize('status_code', (429,))
526+
@pytest.mark.parametrize('error_description', ('Request rate limit exceeded',))
527+
@pytest.mark.parametrize('error_code', ('rate_limit_exceeded',))
528+
@pytest.mark.parametrize('include_date_header', (False,))
529+
def test_auth_retry_for_rate_limit_error(
530+
jwt_auth_init_mocks,
531+
unsuccessful_jwt_response,
532+
):
533+
# pylint:disable=redefined-outer-name
534+
enterprise_id = 'fake_enterprise_id'
535+
with jwt_auth_init_mocks(assert_authed=False) as params:
536+
auth = params[0]
537+
with patch.object(auth, '_construct_and_send_jwt_auth') as mock_send_jwt:
538+
side_effect = []
539+
expected_calls = []
540+
# Retries multiple times, but less than max retries. Then succeeds when it gets a token.
541+
for _ in range(API.MAX_RETRY_ATTEMPTS - 2):
542+
side_effect.append(BoxOAuthException(429, network_response=unsuccessful_jwt_response))
543+
expected_calls.append(call(enterprise_id, 'enterprise', None))
544+
side_effect.append('jwt_token')
545+
expected_calls.append(call(enterprise_id, 'enterprise', None))
546+
mock_send_jwt.side_effect = side_effect
547+
548+
auth.authenticate_instance(enterprise_id)
549+
assert len(mock_send_jwt.mock_calls) == len(expected_calls)
550+
mock_send_jwt.assert_has_calls(expected_calls)
551+
552+
553+
@pytest.mark.parametrize('jwt_algorithm', ('RS512',))
554+
@pytest.mark.parametrize('rsa_passphrase', (None,))
555+
@pytest.mark.parametrize('pass_private_key_by_path', (False,))
556+
@pytest.mark.parametrize('status_code', (429,))
557+
@pytest.mark.parametrize('error_description', ('Request rate limit exceeded',))
558+
@pytest.mark.parametrize('error_code', ('rate_limit_exceeded',))
559+
@pytest.mark.parametrize('include_date_header', (False,))
560+
def test_auth_max_retries_for_rate_limit_error(
561+
jwt_auth_init_mocks,
562+
unsuccessful_jwt_response,
563+
):
564+
# pylint:disable=redefined-outer-name
565+
enterprise_id = 'fake_enterprise_id'
566+
with jwt_auth_init_mocks(assert_authed=False) as params:
567+
auth = params[0]
568+
with patch.object(auth, '_construct_and_send_jwt_auth') as mock_send_jwt:
569+
side_effect = []
570+
expected_calls = []
571+
# Retries max number of times, then throws the error
572+
for _ in range(API.MAX_RETRY_ATTEMPTS + 1):
573+
side_effect.append(BoxOAuthException(429, network_response=unsuccessful_jwt_response))
574+
expected_calls.append(call(enterprise_id, 'enterprise', None))
575+
mock_send_jwt.side_effect = side_effect
576+
577+
with pytest.raises(BoxOAuthException) as error:
578+
auth.authenticate_instance(enterprise_id)
579+
assert error.value.status == 429
580+
assert len(mock_send_jwt.mock_calls) == len(expected_calls)
581+
mock_send_jwt.assert_has_calls(expected_calls)
582+
583+
584+
@pytest.mark.parametrize('jwt_algorithm', ('RS512',))
585+
@pytest.mark.parametrize('rsa_passphrase', (None,))
586+
@pytest.mark.parametrize('pass_private_key_by_path', (False,))
587+
@pytest.mark.parametrize('status_code', (500,))
588+
@pytest.mark.parametrize('error_description', ('Internal Server Error',))
589+
@pytest.mark.parametrize('error_code', ('internal_server_error',))
590+
@pytest.mark.parametrize('include_date_header', (False,))
591+
def test_auth_retry_for_internal_server_error(
592+
jwt_auth_init_mocks,
593+
unsuccessful_jwt_response,
594+
):
595+
# pylint:disable=redefined-outer-name
596+
enterprise_id = 'fake_enterprise_id'
597+
with jwt_auth_init_mocks(assert_authed=False) as params:
598+
auth = params[0]
599+
with patch.object(auth, '_construct_and_send_jwt_auth') as mock_send_jwt:
600+
side_effect = []
601+
expected_calls = []
602+
# Retries multiple times, but less than max retries. Then succeeds when it gets a token.
603+
for _ in range(API.MAX_RETRY_ATTEMPTS - 2):
604+
side_effect.append(BoxOAuthException(500, network_response=unsuccessful_jwt_response))
605+
expected_calls.append(call(enterprise_id, 'enterprise', None))
606+
side_effect.append('jwt_token')
607+
expected_calls.append(call(enterprise_id, 'enterprise', None))
608+
mock_send_jwt.side_effect = side_effect
609+
610+
auth.authenticate_instance(enterprise_id)
611+
assert len(mock_send_jwt.mock_calls) == len(expected_calls)
612+
mock_send_jwt.assert_has_calls(expected_calls)
613+
614+
615+
@pytest.mark.parametrize('jwt_algorithm', ('RS512',))
616+
@pytest.mark.parametrize('rsa_passphrase', (None,))
617+
@pytest.mark.parametrize('pass_private_key_by_path', (False,))
618+
@pytest.mark.parametrize('status_code', (500,))
619+
@pytest.mark.parametrize('error_description', ('Internal Server Error',))
620+
@pytest.mark.parametrize('error_code', ('internal_server_error',))
621+
@pytest.mark.parametrize('include_date_header', (False,))
622+
def test_auth_max_retries_for_internal_server_error(
623+
jwt_auth_init_mocks,
624+
unsuccessful_jwt_response,
625+
):
626+
# pylint:disable=redefined-outer-name
627+
enterprise_id = 'fake_enterprise_id'
628+
with jwt_auth_init_mocks(assert_authed=False) as params:
629+
auth = params[0]
630+
with patch.object(auth, '_construct_and_send_jwt_auth') as mock_send_jwt:
631+
side_effect = []
632+
expected_calls = []
633+
# Retries max number of times, then throws the error
634+
for _ in range(API.MAX_RETRY_ATTEMPTS + 1):
635+
side_effect.append(BoxOAuthException(500, network_response=unsuccessful_jwt_response))
636+
expected_calls.append(call(enterprise_id, 'enterprise', None))
637+
mock_send_jwt.side_effect = side_effect
638+
639+
with pytest.raises(BoxOAuthException) as error:
640+
auth.authenticate_instance(enterprise_id)
641+
assert error.value.status == 500
642+
assert len(mock_send_jwt.mock_calls) == len(expected_calls)
643+
mock_send_jwt.assert_has_calls(expected_calls)

test/unit/session/test_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_box_session_retries_response_after_retry_after(
150150
assert box_response.status_code == 200
151151
assert len(mock_network_layer.retry_after.call_args_list) == 1
152152
assert isinstance(mock_network_layer.retry_after.call_args[0][0], Number)
153-
assert round(mock_network_layer.retry_after.call_args[0][0], 4) == 1.18
153+
assert round(mock_network_layer.retry_after.call_args[0][0], 4) == 1
154154

155155

156156
@pytest.mark.parametrize('test_method', [
@@ -300,7 +300,7 @@ def test_session_uses_local_config(box_session, mock_network_layer, generic_succ
300300
)
301301
def test_get_retry_after_time(box_session, attempt_number, retry_after_header, expected_result):
302302
with patch('random.uniform', return_value=0.68):
303-
retry_time = box_session._get_retry_after_time(attempt_number, retry_after_header) # pylint: disable=protected-access
303+
retry_time = box_session.get_retry_after_time(attempt_number, retry_after_header) # pylint: disable=protected-access
304304
retry_time = round(retry_time, 4)
305305
assert retry_time == expected_result
306306

0 commit comments

Comments
 (0)