Commit b361235

Add dynamodb retry config for throttling and other errors. Add exponential backoff and jitter for unprocessed keys. Fix edge case where we successfully process keys on our last attempt but still fail.
1 parent 04418af commit b361235

File tree

1 file changed: +55 -30 lines

tron/serialize/runstate/dynamodb_state_store.py (+55 -30)

@@ -4,6 +4,7 @@
 import math
 import os
 import pickle
+import random
 import threading
 import time
 from collections import defaultdict
@@ -20,6 +21,7 @@
 from typing import TypeVar
 
 import boto3  # type: ignore
+from botocore.config import Config
 
 import tron.prom_metrics as prom_metrics
 from tron.core.job import Job
@@ -35,16 +37,33 @@
 # to contain other attributes like object name and number of partitions.
 OBJECT_SIZE = 200_000  # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
 MAX_SAVE_QUEUE = 500
-MAX_ATTEMPTS = 10
+# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
+# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
+# infinite loops in the case where a key is truly unprocessable.
+MAX_UNPROCESSED_KEYS_RETRIES = 10
 MAX_TRANSACT_WRITE_ITEMS = 100
 log = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
 class DynamoDBStateStore:
     def __init__(self, name, dynamodb_region, stopping=False) -> None:
-        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
-        self.client = boto3.client("dynamodb", region_name=dynamodb_region)
+        # Standard mode includes an exponential backoff by a base factor of 2 for a
+        # maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
+        # random number between 0 and 1 and r is the base factor of 2). This might
+        # look like:
+        #
+        # seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
+        #
+        # By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
+        # on the random jitter.
+        #
+        # It handles transient errors like RequestTimeout and ConnectionError, as well
+        # as Service-side errors like Throttling, SlowDown, and LimitExceeded.
+        retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
+
+        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
+        self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
         self.name = name
         self.dynamodb_region = dynamodb_region
         self.table = self.dynamodb.Table(name)
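The delay schedule described in that comment is easy to sanity-check outside botocore. A minimal sketch, assuming the documented standard-mode formula min(b * r^i, MAX_BACKOFF); simulate_backoff is a hypothetical helper, not part of the commit:

    import random

    MAX_BACKOFF = 20  # botocore caps each standard-mode sleep at 20 seconds

    def simulate_backoff(max_attempts: int = 5, base_factor: int = 2) -> list:
        """Approximate standard-mode sleeps: min(b * r**i, MAX_BACKOFF) with b ~ U(0, 1)."""
        return [min(random.uniform(0, 1) * base_factor**i, MAX_BACKOFF) for i in range(1, max_attempts + 1)]

    # Jitter makes every run different, e.g. [1.7, 0.4, 6.2, 13.9, 18.0]
    print([round(s, 1) for s in simulate_backoff()])
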
@@ -63,11 +82,11 @@ def build_key(self, type, iden) -> str:
 
     def restore(self, keys, read_json: bool = False) -> dict:
         """
-        Fetch all under the same parition key(s).
+        Fetch all under the same partition key(s).
         ret: <dict of key to states>
         """
         # format of the keys always passed here is
-        # job_state job_name --> high level info about the job: enabled, run_nums
+        # job_state job_name         --> high level info about the job: enabled, run_nums
         # job_run_state job_run_name --> high level info about the job run
         first_items = self._get_first_partitions(keys)
         remaining_items = self._get_remaining_partitions(first_items, read_json)
@@ -87,8 +106,11 @@ def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
-        attempts_to_retrieve_keys = 0
-        while len(cand_keys_list) != 0:
+        attempts = 0
+        base_delay = 0.5
+        max_delay = 10
+
+        while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
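Each future submitted above presumably wraps a client.batch_get_item call over a slice of the candidate keys, since DynamoDB's BatchGetItem API accepts at most 100 keys per request. A sketch of that fan-out under those assumptions, with the chunking helper and names illustrative rather than the commit's exact code:

    import concurrent.futures

    def chunked(keys: list, size: int = 100) -> list:
        # BatchGetItem allows at most 100 keys per request
        return [keys[i : i + size] for i in range(0, len(keys), size)]

    def submit_batches(client, table_name: str, keys: list) -> list:
        # Fan the chunks out over a small thread pool and return the futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            return [
                executor.submit(
                    client.batch_get_item,
                    RequestItems={table_name: {"Keys": batch}},
                )
                for batch in chunked(keys)
            ]
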
@@ -106,20 +128,33 @@ def _get_items(self, table_keys: list) -> object:
             cand_keys_list = []
             for resp in concurrent.futures.as_completed(responses):
                 try:
-                    items.extend(resp.result()["Responses"][self.name])
-                    # add any potential unprocessed keys to the thread pool
-                    if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
-                        cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-                    elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
-                        failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
-                        error = Exception(
-                            f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
-                        )
-                        raise error
+                    result = resp.result()
+                    items.extend(result.get("Responses", {}).get(self.name, []))
+
+                    # If DynamoDB returns unprocessed keys, we need to collect them and retry
+                    unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
+                    if unprocessed_keys:
+                        cand_keys_list.extend(unprocessed_keys)
                 except Exception as e:
                     log.exception("Encountered issues retrieving data from DynamoDB")
                     raise e
-            attempts_to_retrieve_keys += 1
+            if cand_keys_list:
+                attempts += 1
+                # Exponential backoff for retrying unprocessed keys
+                exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
+                # Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
+                jitter = random.uniform(0, exponential_delay)
+                delay = jitter
+                log.warning(
+                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
+                )
+                time.sleep(delay)
+        if cand_keys_list:
+            error = Exception(
+                f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
+            )
+            log.error(repr(error))
+            raise error
         return items
 
     def _get_first_partitions(self, keys: list):
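The sleep computed in that loop is the "full jitter" strategy from AWS's exponential backoff guidance: instead of sleeping the deterministic exponential value, the caller sleeps a uniform random amount between zero and the capped exponential delay, which spreads concurrent retries apart. Condensed to just the schedule (a sketch; the constants mirror base_delay and max_delay above):

    import random

    BASE_DELAY = 0.5  # seconds, mirrors base_delay in _get_items
    MAX_DELAY = 10  # seconds, mirrors max_delay in _get_items

    def full_jitter_delay(attempt: int) -> float:
        """Sleep before retry number `attempt` (1-indexed): uniform over [0, capped exponential]."""
        exponential_delay = min(BASE_DELAY * 2 ** (attempt - 1), MAX_DELAY)
        return random.uniform(0, exponential_delay)

    # Upper bounds double per attempt (0.5, 1, 2, 4, 8, ...) until capped at 10s.
    print([round(full_jitter_delay(a), 2) for a in range(1, 11)])

Note also that the final if cand_keys_list check sits outside the while loop, which is the commit message's edge-case fix: keys processed successfully on the last allowed attempt no longer raise a spurious failure.
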
@@ -336,25 +371,15 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
                         "N": str(num_json_val_partitions),
                     }
 
-            count = 0
             items.append(item)
 
             while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
-                    break  # exit the while loop on successful writing
                 except Exception as e:
-                    count += 1
-                    if count > 3:
-                        timer(
-                            name="tron.dynamodb.setitem",
-                            delta=time.time() - start,
-                        )
-                        log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
-                        raise e
-                    else:
-                        log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
+                    log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
+                    raise e
         timer(
             name="tron.dynamodb.setitem",
             delta=time.time() - start,
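
With the client-level retry_config now covering throttling and transient errors, the hand-rolled three-attempt loop around transact_write_items is redundant: the first exception that escapes botocore's retries is logged and re-raised immediately rather than retried with a manual counter.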
