Skip to content

Commit ed5c7a7

Browse files
authored
Merge pull request #1023 from Yelp/u/kkasp/TRON-2342-exponential-backoff-dynamo-get
Add dynamodb retry config for throttling and other errors. Add exponential backoff and jitter for unprocessed keys. Fix edge case where we successfully process keys on our last attempt but still fail
2 parents 98c1879 + e0f2cce commit ed5c7a7

File tree

2 files changed

+150
-81
lines changed

2 files changed

+150
-81
lines changed

tests/serialize/runstate/dynamodb_state_store_test.py

+65-42
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from testifycompat import assert_equal
1111
from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
12+
from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES
1213

1314

1415
def mock_transact_write_items(self):
@@ -294,58 +295,80 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_objec
294295
vals = store.restore([key])
295296
assert key not in vals
296297

297-
def test_retry_saving(self, store, small_object, large_object):
298-
with mock.patch(
299-
"moto.dynamodb2.responses.DynamoHandler.transact_write_items",
300-
side_effect=KeyError("foo"),
301-
) as mock_failed_write:
302-
keys = [store.build_key("job_state", i) for i in range(1)]
303-
value = small_object
304-
pairs = zip(keys, (value for i in range(len(keys))))
305-
try:
306-
store.save(pairs)
307-
except Exception:
308-
assert_equal(mock_failed_write.call_count, 3)
309-
310-
def test_retry_reading(self, store, small_object, large_object):
298+
@pytest.mark.parametrize(
299+
"test_object, side_effects, expected_save_errors, expected_queue_length",
300+
[
301+
# All attempts fail
302+
("small_object", [KeyError("foo")] * 3, 3, 1),
303+
("large_object", [KeyError("foo")] * 3, 3, 1),
304+
# Failure followed by success
305+
("small_object", [KeyError("foo"), {}], 0, 0),
306+
("large_object", [KeyError("foo"), {}], 0, 0),
307+
],
308+
)
309+
def test_retry_saving(
310+
self, test_object, side_effects, expected_save_errors, expected_queue_length, store, small_object, large_object
311+
):
312+
object_mapping = {
313+
"small_object": small_object,
314+
"large_object": large_object,
315+
}
316+
value = object_mapping[test_object]
317+
318+
with mock.patch.object(
319+
store.client,
320+
"transact_write_items",
321+
side_effect=side_effects,
322+
) as mock_transact_write:
323+
keys = [store.build_key("job_state", 0)]
324+
pairs = zip(keys, [value])
325+
store.save(pairs)
326+
327+
for _ in side_effects:
328+
store._consume_save_queue()
329+
330+
assert mock_transact_write.call_count == len(side_effects)
331+
assert store.save_errors == expected_save_errors
332+
assert len(store.save_queue) == expected_queue_length
333+
334+
@pytest.mark.parametrize(
335+
"attempt, expected_delay",
336+
[
337+
(1, 1),
338+
(2, 2),
339+
(3, 4),
340+
(4, 8),
341+
(5, 10),
342+
(6, 10),
343+
(7, 10),
344+
],
345+
)
346+
def test_calculate_backoff_delay(self, store, attempt, expected_delay):
347+
delay = store._calculate_backoff_delay(attempt)
348+
assert_equal(delay, expected_delay)
349+
350+
def test_retry_reading(self, store):
311351
unprocessed_value = {
312-
"Responses": {
313-
store.name: [
314-
{
315-
"index": {"N": "0"},
316-
"key": {"S": "job_state 0"},
317-
},
318-
],
319-
},
352+
"Responses": {},
320353
"UnprocessedKeys": {
321354
store.name: {
355+
"Keys": [{"key": {"S": store.build_key("job_state", 0)}, "index": {"N": "0"}}],
322356
"ConsistentRead": True,
323-
"Keys": [
324-
{
325-
"index": {"N": "0"},
326-
"key": {"S": "job_state 0"},
327-
}
328-
],
329-
},
357+
}
330358
},
331-
"ResponseMetadata": {},
332359
}
333-
keys = [store.build_key("job_state", i) for i in range(1)]
334-
value = small_object
335-
pairs = zip(keys, (value for i in range(len(keys))))
336-
store.save(pairs)
360+
361+
keys = [store.build_key("job_state", 0)]
362+
337363
with mock.patch.object(
338364
store.client,
339365
"batch_get_item",
340366
return_value=unprocessed_value,
341-
) as mock_failed_read:
342-
try:
343-
with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
344-
"tron.config.static_config.build_configuration_watcher", autospec=True
345-
):
346-
store.restore(keys)
347-
except Exception:
348-
assert_equal(mock_failed_read.call_count, 11)
367+
) as mock_batch_get_item, mock.patch("time.sleep") as mock_sleep, pytest.raises(Exception) as exec_info:
368+
store.restore(keys)
369+
assert "failed to retrieve items with keys" in str(exec_info.value)
370+
assert mock_batch_get_item.call_count == MAX_UNPROCESSED_KEYS_RETRIES
371+
assert mock_sleep.call_count == MAX_UNPROCESSED_KEYS_RETRIES
349372

350373
def test_restore_exception_propagation(self, store, small_object):
351374
# This test is to ensure that restore propagates exceptions upwards: see DAR-2328

tron/serialize/runstate/dynamodb_state_store.py

+85-39
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typing import TypeVar
2121

2222
import boto3 # type: ignore
23+
import botocore # type: ignore
24+
from botocore.config import Config # type: ignore
2325

2426
import tron.prom_metrics as prom_metrics
2527
from tron.core.job import Job
@@ -35,16 +37,34 @@
3537
# to contain other attributes like object name and number of partitions.
3638
OBJECT_SIZE = 200_000 # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
3739
MAX_SAVE_QUEUE = 500
38-
MAX_ATTEMPTS = 10
40+
# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
41+
# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
42+
# infinite loops in the case where a key is truly unprocessable. We allow for more retries than it should
43+
# ever take to avoid failing restores due to transient issues.
44+
MAX_UNPROCESSED_KEYS_RETRIES = 30
3945
MAX_TRANSACT_WRITE_ITEMS = 100
4046
log = logging.getLogger(__name__)
4147
T = TypeVar("T")
4248

4349

4450
class DynamoDBStateStore:
4551
def __init__(self, name, dynamodb_region, stopping=False) -> None:
46-
self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
47-
self.client = boto3.client("dynamodb", region_name=dynamodb_region)
52+
# Standard mode includes an exponential backoff by a base factor of 2 for a
53+
# maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
54+
# random number between 0 and 1 and r is the base factor of 2). This might
55+
# look like:
56+
#
57+
# seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
58+
#
59+
# By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
60+
# on the random jitter.
61+
#
62+
# It handles transient errors like RequestTimeout and ConnectionError, as well
63+
# as Service-side errors like Throttling, SlowDown, and LimitExceeded.
64+
retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
65+
66+
self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
67+
self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
4868
self.name = name
4969
self.dynamodb_region = dynamodb_region
5070
self.table = self.dynamodb.Table(name)
@@ -63,11 +83,11 @@ def build_key(self, type, iden) -> str:
6383

6484
def restore(self, keys, read_json: bool = False) -> dict:
6585
"""
66-
Fetch all under the same parition key(s).
86+
Fetch all under the same partition key(s).
6787
ret: <dict of key to states>
6888
"""
6989
# format of the keys always passed here is
70-
# job_state job_name --> high level info about the job: enabled, run_nums
90+
# job_state job_name --> high level info about the job: enabled, run_nums
7191
# job_run_state job_run_name --> high level info about the job run
7292
first_items = self._get_first_partitions(keys)
7393
remaining_items = self._get_remaining_partitions(first_items, read_json)
@@ -83,12 +103,22 @@ def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
83103
cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
84104
return cand_keys_chunks
85105

106+
def _calculate_backoff_delay(self, attempt: int) -> int:
107+
# Clamp attempt to 1 to avoid negative or zero exponent
108+
safe_attempt = max(attempt, 1)
109+
base_delay_seconds = 1
110+
max_delay_seconds = 10
111+
delay: int = min(base_delay_seconds * (2 ** (safe_attempt - 1)), max_delay_seconds)
112+
return delay
113+
86114
def _get_items(self, table_keys: list) -> object:
87115
items = []
88116
# let's avoid potentially mutating our input :)
89117
cand_keys_list = copy.copy(table_keys)
90-
attempts_to_retrieve_keys = 0
91-
while len(cand_keys_list) != 0:
118+
attempts = 0
119+
120+
# TODO: TRON-2363 - We should refactor this to not consume attempts when we are still making progress
121+
while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
92122
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
93123
responses = [
94124
executor.submit(
@@ -106,20 +136,35 @@ def _get_items(self, table_keys: list) -> object:
106136
cand_keys_list = []
107137
for resp in concurrent.futures.as_completed(responses):
108138
try:
109-
items.extend(resp.result()["Responses"][self.name])
110-
# add any potential unprocessed keys to the thread pool
111-
if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
112-
cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
113-
elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
114-
failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
115-
error = Exception(
116-
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
117-
)
118-
raise error
119-
except Exception as e:
139+
result = resp.result()
140+
items.extend(result.get("Responses", {}).get(self.name, []))
141+
142+
# If DynamoDB returns unprocessed keys, we need to collect them and retry
143+
unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
144+
if unprocessed_keys:
145+
cand_keys_list.extend(unprocessed_keys)
146+
except botocore.exceptions.ClientError as e:
147+
log.exception(f"ClientError during batch_get_item: {e.response}")
148+
raise
149+
except Exception:
120150
log.exception("Encountered issues retrieving data from DynamoDB")
121-
raise e
122-
attempts_to_retrieve_keys += 1
151+
raise
152+
if cand_keys_list:
153+
# We use _calculate_backoff_delay to get a delay that increases exponentially
154+
# with each retry. These retry attempts are distinct from the boto3 retry_config
155+
# and are used specifically to handle unprocessed keys.
156+
attempts += 1
157+
delay = self._calculate_backoff_delay(attempts)
158+
log.warning(
159+
f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - "
160+
f"Retrying {len(cand_keys_list)} unprocessed keys after {delay}s delay."
161+
)
162+
time.sleep(delay)
163+
if cand_keys_list:
164+
msg = f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
165+
log.error(msg)
166+
167+
raise KeyError(msg)
123168
return items
124169

125170
def _get_first_partitions(self, keys: list):
@@ -291,12 +336,17 @@ def _save_loop(self):
291336
def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
292337
"""
293338
Partition the item and write up to MAX_TRANSACT_WRITE_ITEMS
294-
partitions atomically. Retry up to 3 times on failure.
339+
partitions atomically using TransactWriteItems.
340+
341+
The function examines the size of pickled_val and json_val,
342+
splitting them into multiple segments based on OBJECT_SIZE,
343+
storing each segment under the same partition key.
295344
296-
Examine the size of `pickled_val` and `json_val`, and
297-
splice them into different parts based on `OBJECT_SIZE`
298-
with different sort keys, and save them under the same
299-
partition key built.
345+
It relies on the boto3/botocore retry_config to handle
346+
certain errors (e.g. throttling). If an error is not
347+
addressed by boto3's internal logic, the transaction fails
348+
and raises an exception. It is the caller's responsibility
349+
to implement further retries.
300350
"""
301351
start = time.time()
302352

@@ -337,25 +387,21 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
337387
"N": str(num_json_val_partitions),
338388
}
339389

340-
count = 0
341390
items.append(item)
342391

343-
while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
392+
# We want to write the items when we've either reached the max number of items
393+
# for a transaction, or when we're done processing all partitions
394+
if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
344395
try:
345396
self.client.transact_write_items(TransactItems=items)
346397
items = []
347-
break # exit the while loop on successful writing
348-
except Exception as e:
349-
count += 1
350-
if count > 3:
351-
timer(
352-
name="tron.dynamodb.setitem",
353-
delta=time.time() - start,
354-
)
355-
log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
356-
raise e
357-
else:
358-
log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
398+
except Exception:
399+
timer(
400+
name="tron.dynamodb.setitem",
401+
delta=time.time() - start,
402+
)
403+
log.exception(f"Failed to save partition for key: {key}")
404+
raise
359405
timer(
360406
name="tron.dynamodb.setitem",
361407
delta=time.time() - start,

0 commit comments

Comments
 (0)