Add dynamodb retry config for throttling and other errors. Add exponential backoff and jitter for unprocessed keys. Fix edge case where we successfully process keys on our last attempt but still fail #1023

Merged
11 commits merged on Feb 12, 2025
58 changes: 36 additions & 22 deletions tests/serialize/runstate/dynamodb_state_store_test.py
@@ -8,7 +8,9 @@
from moto.dynamodb2.responses import dynamo_json_dump

from testifycompat import assert_equal
from testifycompat.assertions import assert_in
from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES


def mock_transact_write_items(self):
@@ -294,7 +296,8 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_objec
vals = store.restore([key])
assert key not in vals

def test_retry_saving(self, store, small_object, large_object):
@mock.patch("time.sleep", return_value=None)
Member:
my personal preference is usually to use the context manager way of mocking since that gives a little more control over where a mock is active, but not a blocker :)
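
For illustration, a minimal sketch of the context-manager form being suggested, using the same time.sleep patch target as the decorator above (the test body is elided):

    from unittest import mock

    def test_retry_saving(self, store, small_object, large_object):
        # The patch is only active inside the with-block, which makes it explicit
        # which code sees the fake time.sleep; the decorator form patches the whole test.
        with mock.patch("time.sleep", return_value=None) as mock_sleep:
            ...  # exercise the save/retry path here, then assert on mock_sleep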

Member Author:
Alright, not what you commented on at all, but upon closer inspection this test isn't really testing much. I'll rewrite this

def test_retry_saving(self, mock_sleep, store, small_object, large_object):
with mock.patch(
"moto.dynamodb2.responses.DynamoHandler.transact_write_items",
side_effect=KeyError("foo"),
@@ -307,45 +310,56 @@ def test_retry_saving(self, store, small_object, large_object):
except Exception:
assert_equal(mock_failed_write.call_count, 3)

def test_retry_reading(self, store, small_object, large_object):
@mock.patch("time.sleep")
@mock.patch("random.uniform")
def test_retry_reading(self, mock_random_uniform, mock_sleep, store, small_object, large_object):
unprocessed_value = {
    "Responses": {},
    "UnprocessedKeys": {
        store.name: {
            "ConsistentRead": True,
            "Keys": [
                {
                    "key": {"S": "job_state 0"},
                    "index": {"N": "0"},
                }
            ],
        },
    },
    "ResponseMetadata": {},
}
keys = [store.build_key("job_state", i) for i in range(1)]
value = small_object
pairs = zip(keys, (value for i in range(len(keys))))
pairs = zip(keys, [value] * len(keys))
store.save(pairs)
store._consume_save_queue()

# Mock random.uniform to return the upper limit of the range so that we are simulating max jitter
def side_effect_random_uniform(a, b):
return b

mock_random_uniform.side_effect = side_effect_random_uniform

with mock.patch.object(
store.client,
"batch_get_item",
return_value=unprocessed_value,
) as mock_failed_read:
try:
with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
"tron.config.static_config.build_configuration_watcher", autospec=True
):
store.restore(keys)
except Exception:
assert_equal(mock_failed_read.call_count, 11)
with pytest.raises(Exception) as exec_info, mock.patch(
"tron.config.static_config.load_yaml_file", autospec=True
), mock.patch("tron.config.static_config.build_configuration_watcher", autospec=True):
store.restore(keys)
assert_in("failed to retrieve items with keys", str(exec_info.value))
assert_equal(mock_failed_read.call_count, MAX_UNPROCESSED_KEYS_RETRIES)

# We also need to verify that sleep was called with expected delays
expected_delays = []
base_delay_seconds = 0.5
max_delay_seconds = 10
for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
expected_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
expected_delays.append(expected_delay)
actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
assert_equal(actual_delays, expected_delays)
Member:
i'd maybe extract the exponential backoff logic in tron/serialize/runstate/dynamodb_state_store.py to a function so that we can write a more targeted test for that and simplify this to checking if we called that function the right amount of times

(mostly 'cause I generally try to avoid for loops/calculations inside tests :p)
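
A sketch of what that extraction might look like (the helper name _calculate_backoff_delay and the module-level constants are hypothetical, not part of this PR); the test could then patch the helper and assert it is called once per retry with the expected attempt number:

    import random

    BASE_DELAY_SECONDS = 0.5
    MAX_DELAY_SECONDS = 10


    def _calculate_backoff_delay(attempt: int) -> float:
        # Exponential backoff capped at MAX_DELAY_SECONDS, with full jitter.
        # Attempt numbering starts at 1, matching the retry loop in this PR.
        exponential_delay = min(BASE_DELAY_SECONDS * (2 ** (attempt - 1)), MAX_DELAY_SECONDS)
        return random.uniform(0, exponential_delay)

The retry loop in _get_items would then just call time.sleep(_calculate_backoff_delay(attempts)).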


def test_restore_exception_propagation(self, store, small_object):
# This test is to ensure that restore propagates exceptions upwards: see DAR-2328
93 changes: 62 additions & 31 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -4,6 +4,7 @@
import math
import os
import pickle
import random
import threading
import time
from collections import defaultdict
@@ -20,6 +21,7 @@
from typing import TypeVar

import boto3 # type: ignore
from botocore.config import Config # type: ignore

import tron.prom_metrics as prom_metrics
from tron.core.job import Job
@@ -35,16 +37,33 @@
# to contain other attributes like object name and number of partitions.
OBJECT_SIZE = 200_000 # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
MAX_SAVE_QUEUE = 500
MAX_ATTEMPTS = 10
# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
# infinite loops in the case where a key is truly unprocessable.
MAX_UNPROCESSED_KEYS_RETRIES = 10
MAX_TRANSACT_WRITE_ITEMS = 100
log = logging.getLogger(__name__)
T = TypeVar("T")


class DynamoDBStateStore:
def __init__(self, name, dynamodb_region, stopping=False) -> None:
self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
self.client = boto3.client("dynamodb", region_name=dynamodb_region)
# Standard mode includes an exponential backoff by a base factor of 2 for a
# maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
# random number between 0 and 1 and r is the base factor of 2). This might
# look like:
#
# seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
#
# By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
# on the random jitter.
#
# It handles transient errors like RequestTimeout and ConnectionError, as well
# as Service-side errors like Throttling, SlowDown, and LimitExceeded.
retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})

self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
self.name = name
self.dynamodb_region = dynamodb_region
self.table = self.dynamodb.Table(name)
@@ -63,11 +82,11 @@ def build_key(self, type, iden) -> str:

def restore(self, keys, read_json: bool = False) -> dict:
"""
Fetch all under the same parition key(s).
Fetch all under the same partition key(s).
ret: <dict of key to states>
"""
# format of the keys always passed here is
# job_state job_name --> high level info about the job: enabled, run_nums
# job_state job_name --> high level info about the job: enabled, run_nums
# job_run_state job_run_name --> high level info about the job run
first_items = self._get_first_partitions(keys)
remaining_items = self._get_remaining_partitions(first_items, read_json)
@@ -87,8 +106,11 @@ def _get_items(self, table_keys: list) -> object:
items = []
# let's avoid potentially mutating our input :)
cand_keys_list = copy.copy(table_keys)
attempts_to_retrieve_keys = 0
while len(cand_keys_list) != 0:
attempts = 0
base_delay_seconds = 0.5
max_delay_seconds = 10

while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
Member:
just posting this here for future us: we'll probably want to refactor this at some point to not consume an attempt if we're making progress (i.e., we got at least one key back) and we're simply seeing dynamodb send us partial responses (unless we wanna take a hard line with what our data sizes are such that we can always get a full chunk back at any time) and only consume an attempt if we're doing an error-caused retry

Member Author:
Ah shit, meant to ticket that and link it in a TODO. Thanks for calling that out
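
A rough sketch of the progress-aware accounting described above (hypothetical; fetch_batch stands in for the batch_get_item fan-out and returns a (new_items, unprocessed_keys) pair):

    import time


    def _retrieve_with_progress_aware_retries(fetch_batch, keys, max_retries=10):
        items = []
        attempts = 0
        while keys and attempts < max_retries:
            new_items, keys = fetch_batch(keys)
            items.extend(new_items)
            if keys and not new_items:
                # Only consume an attempt (and back off) when a round trip made no
                # progress at all; partial responses keep looping for free.
                attempts += 1
                time.sleep(min(0.5 * 2 ** (attempts - 1), 10))
        if keys:
            raise Exception(f"failed to retrieve items with keys {keys} after {max_retries} retries")
        return items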

with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
responses = [
executor.submit(
@@ -106,20 +128,33 @@ def _get_items(self, table_keys: list) -> object:
cand_keys_list = []
for resp in concurrent.futures.as_completed(responses):
try:
items.extend(resp.result()["Responses"][self.name])
# add any potential unprocessed keys to the thread pool
if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
error = Exception(
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
)
raise error
result = resp.result()
Contributor:
I wonder if we should also print the response when we get into the exception block to also have an idea on why we got unprocessed keys and why we exceeded the attempts

Contributor:
so maybe we add it here

                except Exception as e:
                    log.exception("Encountered issues retrieving data from DynamoDB")
                    raise e

Member Author:
I was hesitant to dump the response because it can get pretty large. After a lot of reading I've landed on logging ResponseMetadata on ClientError. This should capture what we care about

See https://fluffy.yelpcorp.com/i/qWG1tRPrFt40M6pPr3lLkXnCSbJJBFhd.html
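
A minimal sketch of that pattern, assuming botocore's ClientError (which exposes the parsed error response, including ResponseMetadata, on e.response); the helper name is illustrative:

    import logging

    from botocore.exceptions import ClientError

    log = logging.getLogger(__name__)


    def result_or_log_metadata(future):
        # Unwrap a batch_get_item future; on ClientError, log only ResponseMetadata
        # (request id, HTTP status) rather than dumping the potentially large response.
        try:
            return future.result()
        except ClientError as e:
            log.exception("ClientError from DynamoDB, metadata: %s", e.response.get("ResponseMetadata"))
            raise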

items.extend(result.get("Responses", {}).get(self.name, []))

# If DynamoDB returns unprocessed keys, we need to collect them and retry
unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
if unprocessed_keys:
cand_keys_list.extend(unprocessed_keys)
except Exception as e:
log.exception("Encountered issues retrieving data from DynamoDB")
raise e
attempts_to_retrieve_keys += 1
if cand_keys_list:
attempts += 1
# Exponential backoff for retrying unprocessed keys
exponential_delay = min(base_delay_seconds * (2 ** (attempts - 1)), max_delay_seconds)
# Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
jitter = random.uniform(0, exponential_delay)
delay = jitter
log.warning(
f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
)
time.sleep(delay)
Member Author:
What to do about this lil guy?

Member Author:
!8ball we should use a restore thread

Member:
yea, we should probably try to figure out a non-blocking way to do this or have this run in a separate thread - if we get to the worst case of 5 attempts and this is running on the reactor thread, we'll essentially block all of tron from doing anything for 20s

Member:
although, actually - this is probably fine since we do all sorts of blocking stuff in restore and aren't expecting tron to be usable/do anything until we've restored everything

...so maybe this is fine?

if cand_keys_list:
error = Exception(
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
)
log.error(repr(error))
raise error
return items

def _get_first_partitions(self, keys: list):
Expand Down Expand Up @@ -337,25 +372,21 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
"N": str(num_json_val_partitions),
}

count = 0
items.append(item)

while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
# We want to write the items when we've either reached the max number of items
# for a transaction, or when we're done processing all partitions
if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
try:
self.client.transact_write_items(TransactItems=items)
items = []
break # exit the while loop on successful writing
except Exception as e:
count += 1
if count > 3:
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,
)
log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
raise e
else:
log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,
)
log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
Member:
i know this is old code being moved around, so you can leave this as-is but

Suggested change
log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
log.exception(f"Failed to save partition for key: {key}")

would include the full traceback for us automatically :)

Member:
although: it looks like there's a behavior change here?

might be worth adding a comment here (or in the docstring) that this function will not retry on its own and that it's the caller's responsibility to do so
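
For example, the suggested docstring note might read something like this (wording illustrative):

    from typing import Tuple


    def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
        """Write all partitions for ``key`` to DynamoDB.

        Note: beyond the retries boto3 performs via retry_config, a failed
        transact_write_items call is not retried here; retrying a failed save
        is the caller's responsibility.
        """
        ...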

raise e
timer(
name="tron.dynamodb.setitem",
delta=time.time() - start,