@@ -20,6 +20,8 @@
 from typing import TypeVar

 import boto3  # type: ignore
+import botocore  # type: ignore
+from botocore.config import Config  # type: ignore

 import tron.prom_metrics as prom_metrics
 from tron.core.job import Job
@@ -35,16 +37,34 @@
 # to contain other attributes like object name and number of partitions.
 OBJECT_SIZE = 200_000  # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
 MAX_SAVE_QUEUE = 500
-MAX_ATTEMPTS = 10
+# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
+# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
+# infinite loops in the case where a key is truly unprocessable. We allow for more retries than it should
+# ever take to avoid failing restores due to transient issues.
+MAX_UNPROCESSED_KEYS_RETRIES = 30
 MAX_TRANSACT_WRITE_ITEMS = 100
 log = logging.getLogger(__name__)
 T = TypeVar("T")


 class DynamoDBStateStore:
     def __init__(self, name, dynamodb_region, stopping=False) -> None:
-        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
-        self.client = boto3.client("dynamodb", region_name=dynamodb_region)
+        # Standard mode includes an exponential backoff by a base factor of 2 for a
+        # maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
+        # random number between 0 and 1 and r is the base factor of 2). This might
+        # look like:
+        #
+        # seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
+        #
+        # By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
+        # on the random jitter.
+        #
+        # It handles transient errors like RequestTimeout and ConnectionError, as well
+        # as service-side errors like Throttling, SlowDown, and LimitExceeded.
+        retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
+
+        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
+        self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
         self.name = name
         self.dynamodb_region = dynamodb_region
         self.table = self.dynamodb.Table(name)
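A quick illustration of the backoff arithmetic described in the comment above. This is a sketch of the formula only, not botocore's actual retry handler; `random.random()` stands in for the jitter term b:

```python
import random

MAX_BACKOFF = 20  # seconds; the cap referenced in the comment above

def sketch_standard_mode_sleep(attempt: int) -> float:
    # min(b * r^i, MAX_BACKOFF): b is jitter in [0, 1), r is the base factor of 2
    b = random.random()
    return min(b * (2**attempt), MAX_BACKOFF)

# By the 5th retry, 2^5 = 32 > 20, so sleeps are capped anywhere up to 20s
for attempt in range(1, 6):
    print(attempt, round(sketch_standard_mode_sleep(attempt), 2))
```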
@@ -63,11 +83,11 @@ def build_key(self, type, iden) -> str:

     def restore(self, keys, read_json: bool = False) -> dict:
         """
-        Fetch all under the same parition key(s).
+        Fetch all under the same partition key(s).
         ret: <dict of key to states>
         """
         # format of the keys always passed here is
-        # job_state job_name --> high level info about the job: enabled, run_nums
+        # job_state     job_name --> high level info about the job: enabled, run_nums
         # job_run_state job_run_name --> high level info about the job run
         first_items = self._get_first_partitions(keys)
         remaining_items = self._get_remaining_partitions(first_items, read_json)
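For concreteness, the two key shapes the comment above describes might look like this (the job names are hypothetical, and the exact separator is whatever `build_key` produces):

```python
# Hypothetical examples of the two key shapes passed to restore():
keys = [
    "job_state MASTER.some_job",          # high level info: enabled, run_nums
    "job_run_state MASTER.some_job.42",   # high level info about one job run
]
# given a DynamoDBStateStore instance `store` (placeholder), one would call:
# state_by_key = store.restore(keys, read_json=False)
```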
@@ -83,12 +103,22 @@ def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
             cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
         return cand_keys_chunks

+    def _calculate_backoff_delay(self, attempt: int) -> int:
+        # Clamp attempt to 1 to avoid negative or zero exponent
+        safe_attempt = max(attempt, 1)
+        base_delay_seconds = 1
+        max_delay_seconds = 10
+        delay: int = min(base_delay_seconds * (2 ** (safe_attempt - 1)), max_delay_seconds)
+        return delay
+
     def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
-        attempts_to_retrieve_keys = 0
-        while len(cand_keys_list) != 0:
+        attempts = 0
+
+        # TODO: TRON-2363 - We should refactor this to not consume attempts when we are still making progress
+        while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
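Unlike boto3's jittered backoff, `_calculate_backoff_delay` is deterministic; a quick check of the sequence it produces with the 1s base and 10s cap:

```python
# delay = min(1 * 2**(attempt - 1), 10), with attempt clamped to >= 1
# attempt: 1  2  3  4  5   6  ...
# delay:   1  2  4  8  10  10 ...
for attempt in range(1, 7):
    safe_attempt = max(attempt, 1)
    print(attempt, min(1 * (2 ** (safe_attempt - 1)), 10))
```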
@@ -106,20 +136,35 @@ def _get_items(self, table_keys: list) -> object:
             cand_keys_list = []
             for resp in concurrent.futures.as_completed(responses):
                 try:
-                    items.extend(resp.result()["Responses"][self.name])
-                    # add any potential unprocessed keys to the thread pool
-                    if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
-                        cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-                    elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
-                        failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
-                        error = Exception(
-                            f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb \n{resp.result()}"
-                        )
-                        raise error
-                except Exception as e:
+                    result = resp.result()
+                    items.extend(result.get("Responses", {}).get(self.name, []))
+
+                    # If DynamoDB returns unprocessed keys, we need to collect them and retry
+                    unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
+                    if unprocessed_keys:
+                        cand_keys_list.extend(unprocessed_keys)
+                except botocore.exceptions.ClientError as e:
+                    log.exception(f"ClientError during batch_get_item: {e.response}")
+                    raise
+                except Exception:
                     log.exception("Encountered issues retrieving data from DynamoDB")
-                    raise e
-            attempts_to_retrieve_keys += 1
+                    raise
+            if cand_keys_list:
+                # We use _calculate_backoff_delay to get a delay that increases exponentially
+                # with each retry. These retry attempts are distinct from the boto3 retry_config
+                # and are used specifically to handle unprocessed keys.
+                attempts += 1
+                delay = self._calculate_backoff_delay(attempts)
+                log.warning(
+                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - "
+                    f"Retrying {len(cand_keys_list)} unprocessed keys after {delay}s delay."
+                )
+                time.sleep(delay)
+        if cand_keys_list:
+            msg = f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
+            log.error(msg)
+
+            raise KeyError(msg)
         return items

     def _get_first_partitions(self, keys: list):
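Condensed for reference, the drain-and-retry shape of the loop above can be sketched standalone as below. The table name, region, and retry budget are placeholders, and the real code additionally fans out over a thread pool and chunks keys to DynamoDB's 100-key `batch_get_item` limit:

```python
import time

import boto3  # type: ignore

client = boto3.client("dynamodb", region_name="us-west-1")  # placeholder region
TABLE = "example-table"  # placeholder table name

def drain_batch_get(keys: list, max_retries: int = 30) -> list:
    items: list = []
    pending = list(keys)  # avoid mutating the caller's list
    attempts = 0
    while pending and attempts < max_retries:
        resp = client.batch_get_item(RequestItems={TABLE: {"Keys": pending}})
        items.extend(resp.get("Responses", {}).get(TABLE, []))
        # DynamoDB may process only a subset; retry the remainder with backoff
        pending = resp.get("UnprocessedKeys", {}).get(TABLE, {}).get("Keys", [])
        if pending:
            attempts += 1
            time.sleep(min(2 ** (attempts - 1), 10))
    if pending:
        raise KeyError(f"failed to retrieve {len(pending)} keys after {max_retries} retries")
    return items
```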
@@ -291,12 +336,17 @@ def _save_loop(self):
     def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
         """
         Partition the item and write up to MAX_TRANSACT_WRITE_ITEMS
-        partitions atomically. Retry up to 3 times on failure.
+        partitions atomically using TransactWriteItems.
+
+        The function examines the size of pickled_val and json_val,
+        splitting them into multiple segments based on OBJECT_SIZE,
+        storing each segment under the same partition key.

-        Examine the size of `pickled_val` and `json_val`, and
-        splice them into different parts based on `OBJECT_SIZE`
-        with different sort keys, and save them under the same
-        partition key built.
+        It relies on the boto3/botocore retry_config to handle
+        certain errors (e.g. throttling). If an error is not
+        addressed by boto3's internal logic, the transaction fails
+        and raises an exception. It is the caller's responsibility
+        to implement further retries.
         """
         start = time.time()

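Since the docstring leaves retries beyond boto3's built-ins to the caller, a minimal caller-side pattern might look like the sketch below, where the flaky write is a stand-in for `store[key] = (pickled_val, json_val)`:

```python
import random

def flaky_write() -> None:
    # stand-in for store[key] = (pickled_val, json_val); fails transiently
    if random.random() < 0.5:
        raise RuntimeError("error not absorbed by the boto3 retry_config")

for attempt in range(3):  # caller-chosen retry budget
    try:
        flaky_write()
        break
    except RuntimeError:
        if attempt == 2:
            raise  # budget exhausted; surface the failure
```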
@@ -337,25 +387,21 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
                     "N": str(num_json_val_partitions),
                 }

-            count = 0
             items.append(item)

-            while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
+            # We want to write the items when we've either reached the max number of items
+            # for a transaction, or when we're done processing all partitions
+            if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
-                    break  # exit the while loop on successful writing
-                except Exception as e:
-                    count += 1
-                    if count > 3:
-                        timer(
-                            name="tron.dynamodb.setitem",
-                            delta=time.time() - start,
-                        )
-                        log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
-                        raise e
-                    else:
-                        log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
+                except Exception:
+                    timer(
+                        name="tron.dynamodb.setitem",
+                        delta=time.time() - start,
+                    )
+                    log.exception(f"Failed to save partition for key: {key}")
+                    raise
         timer(
             name="tron.dynamodb.setitem",
             delta=time.time() - start,