44import logging
55import os
66import time
7+ from datetime import datetime
78from functools import cmp_to_key
89
910import requests
@@ -156,6 +157,7 @@ def __init__(
156157 keys = None ,
157158 source = "" ,
158159 cache_time = 300 ,
160+ ignore_errors_period = 0 ,
159161 fileformat = "jwks" ,
160162 keytype = "RSA" ,
161163 keyusage = None ,
@@ -188,6 +190,8 @@ def __init__(
188190 self .remote = False
189191 self .local = False
190192 self .cache_time = cache_time
193+ self .ignore_errors_period = ignore_errors_period
194+ self .ignore_errors_until = None # UNIX timestamp of last error
191195 self .time_out = 0
192196 self .etag = ""
193197 self .source = None
@@ -314,7 +318,11 @@ def do_local_jwk(self, filename):
314318 Load a JWKS from a local file
315319
316320 :param filename: Name of the file from which the JWKS should be loaded
321+ :return: True if load was successful or False if file hasn't been modified
317322 """
323+ if not self ._local_update_required ():
324+ return False
325+
318326 LOGGER .info ("Reading local JWKS from %s" , filename )
319327 with open (filename ) as input_file :
320328 _info = json .load (input_file )
@@ -324,6 +332,7 @@ def do_local_jwk(self, filename):
324332 self .do_keys ([_info ])
325333 self .last_local = time .time ()
326334 self .time_out = self .last_local + self .cache_time
335+ return True
327336
328337 def do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
329338 """
@@ -332,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
332341 :param filename: Name of the file
333342 :param keytype: Presently 'rsa' and 'ec' supported
334343 :param keyusage: encryption ('enc') or signing ('sig') or both
344+ :return: True if load was successful or False if file hasn't been modified
335345 """
346+ if not self ._local_update_required ():
347+ return False
348+
336349 LOGGER .info ("Reading local DER from %s" , filename )
337350 key_args = {}
338351 _kty = keytype .lower ()
@@ -355,16 +368,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
355368 self .do_keys ([key_args ])
356369 self .last_local = time .time ()
357370 self .time_out = self .last_local + self .cache_time
371+ return True
358372
359373 def do_remote (self ):
360374 """
361375 Load a JWKS from a webpage.
362376
363- :return: True or False if load was successful
377+ :return: True if load was successful or False if remote hasn't been modified
364378 """
365379 # if self.verify_ssl is not None:
366380 # self.httpc_params["verify"] = self.verify_ssl
367381
382+ if self .ignore_errors_until and time .time () < self .ignore_errors_until :
383+ LOGGER .warning (
384+ "Not reading remote JWKS from %s (in error holddown until %s)" ,
385+ self .source ,
386+ datetime .fromtimestamp (self .ignore_errors_until ),
387+ )
388+ return False
389+
368390 LOGGER .info ("Reading remote JWKS from %s" , self .source )
369391 try :
370392 LOGGER .debug ("KeyBundle fetch keys from: %s" , self .source )
@@ -378,7 +400,10 @@ def do_remote(self):
378400 LOGGER .error (err )
379401 raise UpdateFailed (REMOTE_FAILED .format (self .source , str (err )))
380402
381- if _http_resp .status_code == 200 : # New content
403+ load_successful = _http_resp .status_code == 200
404+ not_modified = _http_resp .status_code == 304
405+
406+ if load_successful :
382407 self .time_out = time .time () + self .cache_time
383408
384409 self .imp_jwks = self ._parse_remote_response (_http_resp )
@@ -390,25 +415,27 @@ def do_remote(self):
390415 self .do_keys (self .imp_jwks ["keys" ])
391416 except KeyError :
392417 LOGGER .error ("No 'keys' keyword in JWKS" )
418+ self .ignore_errors_until = time .time () + self .ignore_errors_period
393419 raise UpdateFailed (MALFORMED .format (self .source ))
394420
395421 if hasattr (_http_resp , "headers" ):
396422 headers = getattr (_http_resp , "headers" )
397423 self .last_remote = headers .get ("last-modified" ) or headers .get ("date" )
398-
399- elif _http_resp .status_code == 304 : # Not modified
424+ elif not_modified :
400425 LOGGER .debug ("%s not modified since %s" , self .source , self .last_remote )
401426 self .time_out = time .time () + self .cache_time
402-
403427 else :
404428 LOGGER .warning (
405429 "HTTP status %d reading remote JWKS from %s" ,
406430 _http_resp .status_code ,
407431 self .source ,
408432 )
433+ self .ignore_errors_until = time .time () + self .ignore_errors_period
409434 raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
435+
410436 self .last_updated = time .time ()
411- return True
437+ self .ignore_errors_until = None
438+ return load_successful
412439
413440 def _parse_remote_response (self , response ):
414441 """
@@ -433,23 +460,20 @@ def _parse_remote_response(self, response):
433460 return None
434461
435462 def _uptodate (self ):
436- res = False
437463 if self .remote or self .local :
438464 if time .time () > self .time_out :
439- if self .local and not self ._local_update_required ():
440- res = True
441- elif self .update ():
442- res = True
443- return res
465+ return self .update ()
466+ return False
444467
445468 def update (self ):
446469 """
447470 Reload the keys if necessary.
448471
449472 This is a forced update, will happen even if cache time has not elapsed.
450473 Replaced keys will be marked as inactive and not removed.
474+
475+ :return: True if update was ok or False if we encountered an error during update.
451476 """
452- res = True # An update was successful
453477 if self .source :
454478 _old_keys = self ._keys # just in case
455479
@@ -459,24 +483,27 @@ def update(self):
459483 try :
460484 if self .local :
461485 if self .fileformat in ["jwks" , "jwk" ]:
462- self .do_local_jwk (self .source )
486+ updated = self .do_local_jwk (self .source )
463487 elif self .fileformat == "der" :
464- self .do_local_der (self .source , self .keytype , self .keyusage )
488+ updated = self .do_local_der (self .source , self .keytype , self .keyusage )
465489 elif self .remote :
466- res = self .do_remote ()
490+ updated = self .do_remote ()
467491 except Exception as err :
468492 LOGGER .error ("Key bundle update failed: %s" , err )
469493 self ._keys = _old_keys # restore
470494 return False
471495
472- now = time .time ()
473- for _key in _old_keys :
474- if _key not in self ._keys :
475- if not _key .inactive_since : # If already marked don't mess
476- _key .inactive_since = now
477- self ._keys .append (_key )
496+ if updated :
497+ now = time .time ()
498+ for _key in _old_keys :
499+ if _key not in self ._keys :
500+ if not _key .inactive_since : # If already marked don't mess
501+ _key .inactive_since = now
502+ self ._keys .append (_key )
503+ else :
504+ self ._keys = _old_keys
478505
479- return res
506+ return True
480507
481508 def get (self , typ = "" , only_active = True ):
482509 """
0 commit comments