@@ -4,6 +4,7 @@
 import math
 import os
 import pickle
+import random
 import threading
 import time
 from collections import defaultdict
@@ -20,6 +21,7 @@
 from typing import TypeVar

 import boto3  # type: ignore
+from botocore.config import Config

 import tron.prom_metrics as prom_metrics
 from tron.core.job import Job
@@ -35,16 +37,33 @@
 # to contain other attributes like object name and number of partitions.
 OBJECT_SIZE = 200_000  # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
 MAX_SAVE_QUEUE = 500
-MAX_ATTEMPTS = 10
+# This is distinct from the number of retries in the retry_config, as this is used for handling unprocessed
+# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
+# infinite loops in the case where a key is truly unprocessable.
+MAX_UNPROCESSED_KEYS_RETRIES = 10
 MAX_TRANSACT_WRITE_ITEMS = 100
 log = logging.getLogger(__name__)
 T = TypeVar("T")


 class DynamoDBStateStore:
     def __init__(self, name, dynamodb_region, stopping=False) -> None:
-        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
-        self.client = boto3.client("dynamodb", region_name=dynamodb_region)
+        # Standard mode includes an exponential backoff by a base factor of 2 for a
+        # maximum backoff time of 20 seconds (min(b × r^i, MAX_BACKOFF), where b is a
+        # random number between 0 and 1 and r is the base factor of 2). This might
+        # look like:
+        #
+        #   seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
+        #
+        # By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
+        # on the random jitter.
+        #
+        # It handles transient errors like RequestTimeout and ConnectionError, as well
+        # as service-side errors like Throttling, SlowDown, and LimitExceeded.
+        retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
+
+        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
+        self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
         self.name = name
         self.dynamodb_region = dynamodb_region
         self.table = self.dynamodb.Table(name)
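For intuition, here is a minimal standalone sketch of the backoff schedule the comment above describes, assuming botocore's documented standard-mode formula min(b × r^i, MAX_BACKOFF) with b drawn uniformly from [0, 1] and r = 2; simulate_backoff is a hypothetical helper for illustration, not a botocore API:

import random

MAX_BACKOFF = 20  # seconds; the cap botocore applies in standard mode

def simulate_backoff(attempt: int) -> float:
    # b is the random jitter in [0, 1]; r is the base factor of 2
    b = random.uniform(0, 1)
    return min(b * (2**attempt), MAX_BACKOFF)

for attempt in range(1, 6):
    # worst case (b = 1) for attempts 1..5 is 2, 4, 8, 16, 20 seconds
    print(f"attempt {attempt}: up to {min(2**attempt, MAX_BACKOFF)}s, sampled {simulate_backoff(attempt):.2f}s")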
@@ -63,11 +82,11 @@ def build_key(self, type, iden) -> str:

     def restore(self, keys, read_json: bool = False) -> dict:
         """
-        Fetch all under the same parition key(s).
+        Fetch all under the same partition key(s).

         ret: <dict of key to states>
         """
         # format of the keys always passed here is
-        # job_state job_name --> high level info about the job: enabled, run_nums
+        # job_state     job_name --> high level info about the job: enabled, run_nums
         # job_run_state job_run_name --> high level info about the job run
         first_items = self._get_first_partitions(keys)
         remaining_items = self._get_remaining_partitions(first_items, read_json)
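For reference, the keys described by the comments above are two space-separated parts; with a hypothetical job MASTER.some_job and run number 42 they would look like:

    job_state MASTER.some_job
    job_run_state MASTER.some_job.42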
@@ -87,8 +106,11 @@ def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
-        attempts_to_retrieve_keys = 0
-        while len(cand_keys_list) != 0:
+        attempts = 0
+        base_delay = 0.5
+        max_delay = 10
+
+        while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
@@ -106,20 +128,33 @@ def _get_items(self, table_keys: list) -> object:
                 cand_keys_list = []
                 for resp in concurrent.futures.as_completed(responses):
                     try:
-                        items.extend(resp.result()["Responses"][self.name])
-                        # add any potential unprocessed keys to the thread pool
-                        if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
-                            cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-                        elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
-                            failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
-                            error = Exception(
-                                f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
-                            )
-                            raise error
+                        result = resp.result()
+                        items.extend(result.get("Responses", {}).get(self.name, []))
+
+                        # If DynamoDB returns unprocessed keys, we need to collect them and retry
+                        unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
+                        if unprocessed_keys:
+                            cand_keys_list.extend(unprocessed_keys)
                     except Exception as e:
                         log.exception("Encountered issues retrieving data from DynamoDB")
                         raise e
-            attempts_to_retrieve_keys += 1
+            if cand_keys_list:
+                attempts += 1
+                # Exponential backoff with full jitter (i.e. sleep anywhere from 0 up to the
+                # current cap) to minimize the number and length of retry calls
+                exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
+                delay = random.uniform(0, exponential_delay)
+                log.warning(
+                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
+                )
+                time.sleep(delay)
+        if cand_keys_list:
+            error = Exception(
+                f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
+            )
+            log.error(repr(error))
+            raise error
         return items

     def _get_first_partitions(self, keys: list):
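As a quick sanity check on what the new retry loop will sleep, here is a standalone sketch using the same constants as the code above (base_delay of 0.5, max_delay of 10, up to MAX_UNPROCESSED_KEYS_RETRIES attempts):

import random

base_delay = 0.5
max_delay = 10
MAX_UNPROCESSED_KEYS_RETRIES = 10

for attempts in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
    # the cap doubles from 0.5s and clamps at 10s: 0.5, 1, 2, 4, 8, 10, 10, ...
    exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
    # full jitter: sleep anywhere between 0 and the cap
    delay = random.uniform(0, exponential_delay)
    print(f"attempt {attempts}: cap {exponential_delay}s, sampled {delay:.2f}s")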
@@ -336,25 +371,15 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
                     "N": str(num_json_val_partitions),
                 }

-            count = 0
             items.append(item)

-            while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
+            if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
-                    break  # exit the while loop on successful writing
                 except Exception as e:
-                    count += 1
-                    if count > 3:
-                        timer(
-                            name="tron.dynamodb.setitem",
-                            delta=time.time() - start,
-                        )
-                        log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
-                        raise e
-                    else:
-                        log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
+                    log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
+                    raise e
         timer(
             name="tron.dynamodb.setitem",
             delta=time.time() - start,
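As background for the flush condition above: DynamoDB's TransactWriteItems accepts at most 100 actions per call, which is what MAX_TRANSACT_WRITE_ITEMS encodes. A minimal sketch of the same chunking pattern in isolation (the client and the pre-built items list are assumed to exist):

def write_in_chunks(client, items, chunk_size=100):
    # each slice becomes one atomic TransactWriteItems call, mirroring how
    # __setitem__ flushes a full batch and then whatever remains at the end
    for i in range(0, len(items), chunk_size):
        client.transact_write_items(TransactItems=items[i : i + chunk_size])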