diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index aebfef230b..92681d9059 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -2,16 +2,22 @@
 import collections
 import random
 import socket
+import threading
 import warnings
+from abc import ABC, abstractmethod
+from copy import copy
+from itertools import chain
 from typing import (
     Any,
     Callable,
+    Coroutine,
     Deque,
     Dict,
     Generator,
     List,
     Mapping,
     Optional,
+    Set,
     Tuple,
     Type,
     TypeVar,
@@ -53,7 +60,10 @@
     ClusterDownError,
     ClusterError,
     ConnectionError,
+    CrossSlotTransactionError,
     DataError,
+    ExecAbortError,
+    InvalidPipelineStack,
     MaxConnectionsError,
     MovedError,
     RedisClusterException,
@@ -62,6 +72,7 @@
     SlotNotCoveredError,
     TimeoutError,
     TryAgainError,
+    WatchError,
 )
 from redis.typing import AnyKeyT, EncodableT, KeyT
 from redis.utils import (
@@ -870,10 +881,7 @@ def pipeline(
         if shard_hint:
             raise RedisClusterException("shard_hint is deprecated in cluster mode")
 
-        if transaction:
-            raise RedisClusterException("transaction is deprecated in cluster mode")
-
-        return ClusterPipeline(self)
+        return ClusterPipeline(self, transaction)
 
     def lock(
         self,
@@ -956,6 +964,30 @@ def lock(
             raise_on_release_error=raise_on_release_error,
         )
 
+    async def transaction(
+        self, func: Callable[["ClusterPipeline"], Coroutine], *watches, **kwargs
+    ):
+        """
+        Convenience method for executing the callable ``func`` as a transaction
+        while watching all keys specified in ``watches``. The ``func`` callable
+        should expect a single argument, which is a ClusterPipeline object.
+        """
+        shard_hint = kwargs.pop("shard_hint", None)
+        value_from_callable = kwargs.pop("value_from_callable", False)
+        watch_delay = kwargs.pop("watch_delay", None)
+        async with self.pipeline(True, shard_hint) as pipe:
+            while True:
+                try:
+                    if watches:
+                        await pipe.watch(*watches)
+                    func_value = await func(pipe)
+                    exec_value = await pipe.execute()
+                    return func_value if value_from_callable else exec_value
+                except WatchError:
+                    if watch_delay is not None and watch_delay > 0:
+                        await asyncio.sleep(watch_delay)
+                    continue
+
 
 class ClusterNode:
     """
@@ -1077,6 +1109,12 @@ def acquire_connection(self) -> Connection:
 
         raise MaxConnectionsError()
 
+    def release(self, connection: Connection) -> None:
+        """
+        Release connection back to free queue.
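+        Used by the transaction strategy to return a connection that was
+        pinned for WATCH/MULTI once the transaction completes or is reset.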
+ """ + self._free.append(connection) + async def parse_response( self, connection: Connection, command: str, **kwargs: Any ) -> Any: @@ -1247,6 +1285,9 @@ def set_nodes( task = asyncio.create_task(old[name].disconnect()) # noqa old[name] = node + def update_moved_exception(self, exception): + self._moved_exception = exception + def _update_moved_slots(self) -> None: e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) @@ -1514,41 +1555,47 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm | Existing :class:`~.RedisCluster` client """ - __slots__ = ("_command_stack", "_client") - - def __init__(self, client: RedisCluster) -> None: - self._client = client + __slots__ = ("cluster_client",) - self._command_stack: List["PipelineCommand"] = [] + def __init__( + self, client: RedisCluster, transaction: Optional[bool] = None + ) -> None: + self.cluster_client = client + self._transaction = transaction + self._execution_strategy: ExecutionStrategy = ( + PipelineStrategy(self) + if not self._transaction + else TransactionStrategy(self) + ) async def initialize(self) -> "ClusterPipeline": - if self._client._initialize: - await self._client.initialize() - self._command_stack = [] + if self.cluster_client._initialize: + await self.cluster_client.initialize() + self._execution_strategy._command_queue = [] return self async def __aenter__(self) -> "ClusterPipeline": return await self.initialize() async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - self._command_stack = [] + self._execution_strategy._command_queue = [] def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: return self.initialize().__await__() def __enter__(self) -> "ClusterPipeline": - self._command_stack = [] + self._execution_strategy._command_queue = [] return self def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - self._command_stack = [] + self._execution_strategy._command_queue = [] def __bool__(self) -> bool: "Pipeline instances should always evaluate to True on Python 3+" return True def __len__(self) -> int: - return len(self._command_stack) + return len(self._execution_strategy._command_queue) def execute_command( self, *args: Union[KeyT, EncodableT], **kwargs: Any @@ -1564,10 +1611,7 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - self._command_stack.append( - PipelineCommand(len(self._command_stack), *args, **kwargs) - ) - return self + return self._execution_strategy.execute_command(*args, **kwargs) async def execute( self, raise_on_error: bool = True, allow_redirections: bool = True @@ -1587,34 +1631,307 @@ async def execute( :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ - if not self._command_stack: + try: + return await self._execution_strategy.execute( + raise_on_error, allow_redirections + ) + finally: + await self.reset() + + def _split_command_across_slots( + self, command: str, *keys: KeyT + ) -> "ClusterPipeline": + for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values(): + self.execute_command(command, *slot_keys) + + return self + + async def reset(self): + """ + Reset back to empty pipeline. + """ + await self._execution_strategy.reset() + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. 
+ """ + self._execution_strategy.multi() + + async def discard(self): + """ """ + await self._execution_strategy.discard() + + async def watch(self, *names): + """Watches the values at keys ``names``""" + await self._execution_strategy.watch(*names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + await self._execution_strategy.unwatch() + + async def unlink(self, *names): + await self._execution_strategy.unlink(*names) + + def mset_nonatomic( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> "ClusterPipeline": + return self._execution_strategy.mset_nonatomic(mapping) + + +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.position = position + self.result: Union[Any, Exception] = None + + def __repr__(self) -> str: + return f"[{self.position}] {self.args} ({self.kwargs})" + + +class ExecutionStrategy(ABC): + @abstractmethod + async def initialize(self) -> "ClusterPipeline": + """ + Initialize the execution strategy. + + See ClusterPipeline.initialize() + """ + pass + + @abstractmethod + def execute_command( + self, *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> "ClusterPipeline": + """ + Append a raw command to the pipeline. + + See ClusterPipeline.execute_command() + """ + pass + + @abstractmethod + async def execute( + self, raise_on_error: bool = True, allow_redirections: bool = True + ) -> List[Any]: + """ + Execute the pipeline. + + It will retry the commands as specified by retries specified in :attr:`retry` + & then raise an exception. + + See ClusterPipeline.execute() + """ + pass + + @abstractmethod + def mset_nonatomic( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> "ClusterPipeline": + """ + Executes multiple MSET commands according to the provided slot/pairs mapping. + + See ClusterPipeline.mset_nonatomic() + """ + pass + + @abstractmethod + async def reset(self): + """ + Resets current execution strategy. + + See: ClusterPipeline.reset() + """ + pass + + @abstractmethod + def multi(self): + """ + Starts transactional context. + + See: ClusterPipeline.multi() + """ + pass + + @abstractmethod + async def watch(self, *names): + """ + Watch given keys. 
+
+        See: ClusterPipeline.watch()
+        """
+        pass
+
+    @abstractmethod
+    async def unwatch(self):
+        """
+        Unwatches all previously specified keys
+
+        See: ClusterPipeline.unwatch()
+        """
+        pass
+
+    @abstractmethod
+    async def discard(self):
+        """
+        Discards the currently queued commands and resets watched keys.
+
+        See: ClusterPipeline.discard()
+        """
+        pass
+
+    @abstractmethod
+    async def unlink(self, *names):
+        """
+        Unlink the keys specified by ``names``.
+
+        See: ClusterPipeline.unlink()
+        """
+        pass
+
+
+class AbstractStrategy(ExecutionStrategy):
+    def __init__(self, pipe: ClusterPipeline) -> None:
+        self._pipe: ClusterPipeline = pipe
+        self._command_queue: List["PipelineCommand"] = []
+
+    async def __aenter__(self) -> "ClusterPipeline":
+        return await self._pipe.initialize()
+
+    async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
+        self._command_queue = []
+
+    def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
+        return self._pipe.initialize().__await__()
+
+    def __enter__(self) -> "ClusterPipeline":
+        self._command_queue = []
+        return self._pipe
+
+    def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
+        self._command_queue = []
+
+    def __bool__(self) -> bool:
+        "Pipeline instances should always evaluate to True on Python 3+"
+        return True
+
+    async def initialize(self) -> "ClusterPipeline":
+        if self._pipe.cluster_client._initialize:
+            await self._pipe.cluster_client.initialize()
+        self._command_queue = []
+        return self._pipe
+
+    def execute_command(
+        self, *args: Union[KeyT, EncodableT], **kwargs: Any
+    ) -> "ClusterPipeline":
+        self._command_queue.append(
+            PipelineCommand(len(self._command_queue), *args, **kwargs)
+        )
+        return self._pipe
+
+    def _annotate_exception(self, exception, number, command):
+        """
+        Provides extra context to the exception prior to it being handled
+        """
+        cmd = " ".join(map(safe_str, command))
+        msg = (
+            f"Command # {number} ({truncate_text(cmd)}) of pipeline "
+            f"caused error: {exception.args[0]}"
+        )
+        exception.args = (msg,) + exception.args[1:]
+
+    @abstractmethod
+    def mset_nonatomic(
+        self, mapping: Mapping[AnyKeyT, EncodableT]
+    ) -> "ClusterPipeline":
+        pass
+
+    @abstractmethod
+    async def execute(
+        self, raise_on_error: bool = True, allow_redirections: bool = True
+    ) -> List[Any]:
+        pass
+
+    @abstractmethod
+    async def reset(self):
+        pass
+
+    @abstractmethod
+    def multi(self):
+        pass
+
+    @abstractmethod
+    async def watch(self, *names):
+        pass
+
+    @abstractmethod
+    async def unwatch(self):
+        pass
+
+    @abstractmethod
+    async def discard(self):
+        pass
+
+    @abstractmethod
+    async def unlink(self, *names):
+        pass
+
+
+class PipelineStrategy(AbstractStrategy):
+    def __init__(self, pipe: ClusterPipeline) -> None:
+        super().__init__(pipe)
+
+    def mset_nonatomic(
+        self, mapping: Mapping[AnyKeyT, EncodableT]
+    ) -> "ClusterPipeline":
+        encoder = self._pipe.cluster_client.encoder
+
+        slots_pairs = {}
+        for pair in mapping.items():
+            slot = key_slot(encoder.encode(pair[0]))
+            slots_pairs.setdefault(slot, []).extend(pair)
+
+        for pairs in slots_pairs.values():
+            self.execute_command("MSET", *pairs)
+
+        return self._pipe
+
+    async def execute(
+        self, raise_on_error: bool = True, allow_redirections: bool = True
+    ) -> List[Any]:
+        if not self._command_queue:
+            return []
 
         try:
-            retry_attempts = self._client.retry.get_retries()
+            retry_attempts = self._pipe.cluster_client.retry.get_retries()
             while True:
                 try:
-                    if self._client._initialize:
-                        await self._client.initialize()
+                    if self._pipe.cluster_client._initialize:
+                        await self._pipe.cluster_client.initialize()
                     return await self._execute(
-                        self._client,
-                        self._command_stack,
+                        self._pipe.cluster_client,
+                        self._command_queue,
                         raise_on_error=raise_on_error,
                         allow_redirections=allow_redirections,
                     )
-                except self.__class__.ERRORS_ALLOW_RETRY as e:
+                except RedisCluster.ERRORS_ALLOW_RETRY as e:
                     if retry_attempts > 0:
                         # Try again with the new cluster setup. All other errors
                         # should be raised.
                         retry_attempts -= 1
-                        await self._client.aclose()
+                        await self._pipe.cluster_client.aclose()
                         await asyncio.sleep(0.25)
                     else:
                         # All other errors should be raised.
                         raise e
         finally:
-            self._command_stack = []
+            self._command_queue = []
 
     async def _execute(
         self,
@@ -1694,50 +2011,401 @@ async def _execute(
                 for cmd in default_node[1]:
                     # Check if it has a command that failed with a relevant
                     # exception
-                    if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
+                    if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY:
                         client.replace_default_node()
                         break
 
         return [cmd.result for cmd in stack]
 
-    def _split_command_across_slots(
-        self, command: str, *keys: KeyT
-    ) -> "ClusterPipeline":
-        for slot_keys in self._client._partition_keys_by_slot(keys).values():
-            self.execute_command(command, *slot_keys)
+    async def reset(self):
+        """
+        Reset back to empty pipeline.
+        """
+        self._command_queue = []
 
-        return self
+    def multi(self):
+        raise RedisClusterException(
+            "method multi() is not supported outside of transactional context"
+        )
+
+    async def watch(self, *names):
+        raise RedisClusterException(
+            "method watch() is not supported outside of transactional context"
+        )
+
+    async def unwatch(self):
+        raise RedisClusterException(
+            "method unwatch() is not supported outside of transactional context"
+        )
+
+    async def discard(self):
+        raise RedisClusterException(
+            "method discard() is not supported outside of transactional context"
+        )
+
+    async def unlink(self, *names):
+        if len(names) != 1:
+            raise RedisClusterException(
+                "unlinking multiple keys is not implemented in pipeline command"
+            )
+
+        return self.execute_command("UNLINK", names[0])
+
+
+class TransactionStrategy(AbstractStrategy):
+    NO_SLOTS_COMMANDS = {"UNWATCH"}
+    IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
+    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
+    SLOT_REDIRECT_ERRORS = (AskError, MovedError)
+    CONNECTION_ERRORS = (
+        ConnectionError,
+        OSError,
+        ClusterDownError,
+        SlotNotCoveredError,
+    )
+
+    def __init__(self, pipe: ClusterPipeline) -> None:
+        super().__init__(pipe)
+        self._explicit_transaction = False
+        self._watching = False
+        self._pipeline_slots: Set[int] = set()
+        self._transaction_node: Optional[ClusterNode] = None
+        self._transaction_connection: Optional[Connection] = None
+        self._executing = False
+        self._retry = copy(self._pipe.cluster_client.retry)
+        self._retry.update_supported_errors(
+            RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
+        )
+
+    def _get_client_and_connection_for_transaction(
+        self,
+    ) -> Tuple[ClusterNode, Connection]:
+        """
+        Find a connection for a pipeline transaction.
+
+        For an atomic transaction, watched keys only guarantee that their
+        contents have not been altered if the WATCH commands for those keys
+        were sent over the same connection. So once we start watching a key,
+        we fetch a connection to the node that owns that slot and reuse it.
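+        The pinned connection is released back to the node's free queue only
+        when the transaction is reset.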
+ """ + if not self._pipeline_slots: + raise RedisClusterException( + "At least a command with a key is needed to identify a node" + ) + + node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot( + list(self._pipeline_slots)[0], False + ) + self._transaction_node = node + + if not self._transaction_connection: + connection: Connection = self._transaction_node.acquire_connection() + self._transaction_connection = connection + + return self._transaction_node, self._transaction_connection + + def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any": + # Given the limitation of ClusterPipeline sync API, we have to run it in thread. + response = None + error = None + + def runner(): + nonlocal response + nonlocal error + try: + response = asyncio.run(self._execute_command(*args, **kwargs)) + except Exception as e: + error = e + + thread = threading.Thread(target=runner) + thread.start() + thread.join() + + if error: + raise error + + return response + + async def _execute_command( + self, *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> Any: + if self._pipe.cluster_client._initialize: + await self._pipe.cluster_client.initialize() + + slot_number: Optional[int] = None + if args[0] not in self.NO_SLOTS_COMMANDS: + slot_number = await self._pipe.cluster_client._determine_slot(*args) + + if ( + self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS + ) and not self._explicit_transaction: + if args[0] == "WATCH": + self._validate_watch() + + if slot_number is not None: + if self._pipeline_slots and slot_number not in self._pipeline_slots: + raise CrossSlotTransactionError( + "Cannot watch or send commands on different slots" + ) + + self._pipeline_slots.add(slot_number) + elif args[0] not in self.NO_SLOTS_COMMANDS: + raise RedisClusterException( + f"Cannot identify slot number for command: {args[0]}," + "it cannot be triggered in a transaction" + ) + + return self._immediate_execute_command(*args, **kwargs) + else: + if slot_number is not None: + self._pipeline_slots.add(slot_number) + + return super().execute_command(*args, **kwargs) + + def _validate_watch(self): + if self._explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + self._watching = True + + async def _immediate_execute_command(self, *args, **options): + return await self._retry.call_with_retry( + lambda: self._get_connection_and_send_command(*args, **options), + self._reinitialize_on_error, + ) + + async def _get_connection_and_send_command(self, *args, **options): + redis_node, connection = self._get_client_and_connection_for_transaction() + return await self._send_command_parse_response( + connection, redis_node, args[0], *args, **options + ) + + async def _send_command_parse_response( + self, + connection: Connection, + redis_node: ClusterNode, + command_name, + *args, + **options, + ): + """ + Send a command and parse the response + """ + + await connection.send_command(*args) + output = await redis_node.parse_response(connection, command_name, **options) + + if command_name in self.UNWATCH_COMMANDS: + self._watching = False + return output + + async def _reinitialize_on_error(self, error): + if self._watching: + if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: + raise WatchError("Slot rebalancing occurred while watching keys") + + if ( + type(error) in self.SLOT_REDIRECT_ERRORS + or type(error) in self.CONNECTION_ERRORS + ): + if self._transaction_connection: + self._transaction_connection = None + + 
+            self._pipe.cluster_client.reinitialize_counter += 1
+            if (
+                self._pipe.cluster_client.reinitialize_steps
+                and self._pipe.cluster_client.reinitialize_counter
+                % self._pipe.cluster_client.reinitialize_steps
+                == 0
+            ):
+                await self._pipe.cluster_client.nodes_manager.initialize()
+                self._pipe.cluster_client.reinitialize_counter = 0
+            else:
+                self._pipe.cluster_client.nodes_manager.update_moved_exception(error)
+
+        self._executing = False
+
+    def _raise_first_error(self, responses, stack):
+        """
+        Raise the first exception on the stack
+        """
+        for r, cmd in zip(responses, stack):
+            if isinstance(r, Exception):
+                self._annotate_exception(r, cmd.position + 1, cmd.args)
+                raise r
 
     def mset_nonatomic(
         self, mapping: Mapping[AnyKeyT, EncodableT]
     ) -> "ClusterPipeline":
-        encoder = self._client.encoder
+        raise NotImplementedError("Method is not supported in transactional context.")
 
-        slots_pairs = {}
-        for pair in mapping.items():
-            slot = key_slot(encoder.encode(pair[0]))
-            slots_pairs.setdefault(slot, []).extend(pair)
+    async def execute(
+        self, raise_on_error: bool = True, allow_redirections: bool = True
+    ) -> List[Any]:
+        stack = self._command_queue
+        if not stack and (not self._watching or not self._pipeline_slots):
+            return []
 
-        for pairs in slots_pairs.values():
-            self.execute_command("MSET", *pairs)
+        return await self._execute_transaction_with_retries(stack, raise_on_error)
 
-        return self
+    async def _execute_transaction_with_retries(
+        self, stack: List["PipelineCommand"], raise_on_error: bool
+    ):
+        return await self._retry.call_with_retry(
+            lambda: self._execute_transaction(stack, raise_on_error),
+            self._reinitialize_on_error,
+        )
 
+    async def _execute_transaction(
+        self, stack: List["PipelineCommand"], raise_on_error: bool
+    ):
+        if len(self._pipeline_slots) > 1:
+            raise CrossSlotTransactionError(
+                "All keys involved in a cluster transaction must map to the same slot"
+            )
 
-for command in PIPELINE_BLOCKED_COMMANDS:
-    command = command.replace(" ", "_").lower()
-    if command == "mset_nonatomic":
-        continue
+        self._executing = True
 
-    setattr(ClusterPipeline, command, block_pipeline_command(command))
+        redis_node, connection = self._get_client_and_connection_for_transaction()
 
+        stack = chain(
+            [PipelineCommand(0, "MULTI")],
+            stack,
+            [PipelineCommand(0, "EXEC")],
+        )
+        commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
+        packed_commands = connection.pack_commands(commands)
+        await connection.send_packed_command(packed_commands)
+        errors = []
 
-class PipelineCommand:
-    def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
-        self.args = args
-        self.kwargs = kwargs
-        self.position = position
-        self.result: Union[Any, Exception] = None
+        # parse off the response for MULTI
+        # NOTE: we need to handle ResponseErrors here and continue
+        # so that we read all the additional command messages from
+        # the socket
+        try:
+            await redis_node.parse_response(connection, "MULTI")
+        except ResponseError as e:
+            self._annotate_exception(e, 0, "MULTI")
+            errors.append(e)
+        except self.CONNECTION_ERRORS as cluster_error:
+            self._annotate_exception(cluster_error, 0, "MULTI")
+            raise
 
-    def __repr__(self) -> str:
-        return f"[{self.position}] {self.args} ({self.kwargs})"
+        # and all the other commands
+        for i, command in enumerate(self._command_queue):
+            if EMPTY_RESPONSE in command.kwargs:
+                errors.append((i, command.kwargs[EMPTY_RESPONSE]))
+            else:
+                try:
+                    _ = await redis_node.parse_response(connection, "_")
+                except self.SLOT_REDIRECT_ERRORS as slot_error:
+                    self._annotate_exception(slot_error, i + 1, command.args)
+                    errors.append(slot_error)
+                except self.CONNECTION_ERRORS as cluster_error:
+                    self._annotate_exception(cluster_error, i + 1, command.args)
+                    raise
+                except ResponseError as e:
+                    self._annotate_exception(e, i + 1, command.args)
+                    errors.append(e)
+
+        response = None
+        # parse the EXEC.
+        try:
+            response = await redis_node.parse_response(connection, "EXEC")
+        except ExecAbortError:
+            if errors:
+                raise errors[0]
+            raise
-
+        self._executing = False
+
+        # EXEC clears any watched keys
+        self._watching = False
+
+        if response is None:
+            raise WatchError("Watched variable changed.")
+
+        # put any parse errors into the response
+        for i, e in errors:
+            response.insert(i, e)
+
+        if len(response) != len(self._command_queue):
+            raise InvalidPipelineStack(
+                "Unexpected response length for cluster pipeline EXEC."
+                " Command stack was {} but response had length {}".format(
+                    [c.args[0] for c in self._command_queue], len(response)
+                )
+            )
+
+        # find any errors in the response and raise if necessary
+        if raise_on_error or len(errors) > 0:
+            self._raise_first_error(
+                response,
+                self._command_queue,
+            )
+
+        # We have to run response callbacks manually
+        data = []
+        for r, cmd in zip(response, self._command_queue):
+            if not isinstance(r, Exception):
+                command_name = cmd.args[0]
+                if command_name in self._pipe.cluster_client.response_callbacks:
+                    r = self._pipe.cluster_client.response_callbacks[command_name](
+                        r, **cmd.kwargs
+                    )
+            data.append(r)
+        return data
+
+    async def reset(self):
+        self._command_queue = []
+
+        # make sure to reset the connection state in the event that we were
+        # watching something
+        if self._transaction_connection:
+            try:
+                # call this manually since our unwatch or
+                # immediate_execute_command methods can call reset()
+                await self._transaction_connection.send_command("UNWATCH")
+                await self._transaction_connection.read_response()
+                # we can safely return the connection to the pool here since we're
+                # sure we're no longer WATCHing anything
+                self._transaction_node.release(self._transaction_connection)
+                self._transaction_connection = None
+            except self.CONNECTION_ERRORS:
+                # disconnect will also remove any previous WATCHes
+                if self._transaction_connection:
+                    await self._transaction_connection.disconnect()
+
+        # clean up the other instance attributes
+        self._transaction_node = None
+        self._watching = False
+        self._explicit_transaction = False
+        self._pipeline_slots = set()
+        self._executing = False
+
+    def multi(self):
+        if self._explicit_transaction:
+            raise RedisError("Cannot issue nested calls to MULTI")
+        if self._command_queue:
+            raise RedisError(
+                "Commands without an initial WATCH have already been issued"
+            )
+        self._explicit_transaction = True
+
+    async def watch(self, *names):
+        if self._explicit_transaction:
+            raise RedisError("Cannot issue a WATCH after a MULTI")
+
+        return await self.execute_command("WATCH", *names)
+
+    async def unwatch(self):
+        if self._watching:
+            return await self.execute_command("UNWATCH")
+
+        return True
+
+    async def discard(self):
+        await self.reset()
+
+    async def unlink(self, *names):
+        return self.execute_command("UNLINK", *names)
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 1b3fbd5526..7f87131c7a 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -2750,10 +2750,6 @@ class TestClusterPipeline:
 
     async def test_blocked_arguments(self, r: RedisCluster) -> None:
         """Test handling for blocked pipeline arguments."""
-        with pytest.raises(RedisClusterException) as ex:
-            r.pipeline(transaction=True)
-
-        assert str(ex.value) == "transaction is deprecated in cluster mode"
 
         with pytest.raises(RedisClusterException) as ex:
             r.pipeline(shard_hint=True)
diff --git a/tests/test_asyncio/test_cluster_transaction.py b/tests/test_asyncio/test_cluster_transaction.py
new file mode 100644
index 0000000000..8b070cb289
--- /dev/null
+++ b/tests/test_asyncio/test_cluster_transaction.py
@@ -0,0 +1,397 @@
+from typing import Tuple
+from unittest.mock import patch, Mock
+
+import pytest
+
+import redis
+from redis import CrossSlotTransactionError, RedisClusterException
+from redis.asyncio import RedisCluster, Connection
+from redis.asyncio.cluster import ClusterNode, NodesManager
+from redis.asyncio.retry import Retry
+from redis.backoff import NoBackoff
+from redis.cluster import PRIMARY
+from tests.conftest import skip_if_server_version_lt
+
+
+def _find_source_and_target_node_for_slot(
+    r: RedisCluster, slot: int
+) -> Tuple[ClusterNode, ClusterNode]:
+    """Returns a pair of ClusterNodes, where the first node is the
+    one that owns the slot and the second is a possible target
+    for that slot, i.e. a primary node different from the first
+    one.
+    """
+    node_migrating = r.nodes_manager.get_node_from_slot(slot)
+    assert node_migrating, f"No node could be found that owns slot #{slot}"
+
+    available_targets = [
+        n
+        for n in r.nodes_manager.startup_nodes.values()
+        if node_migrating.name != n.name and n.server_type == PRIMARY
+    ]
+
+    assert available_targets, f"No possible target nodes for slot #{slot}"
+    return node_migrating, available_targets[0]
+
+
+@pytest.mark.onlycluster
+class TestClusterTransaction:
+    @pytest.mark.onlycluster
+    async def test_pipeline_is_true(self, r) -> None:
+        "Ensure pipeline instances are not false-y"
+        async with r.pipeline(transaction=True) as pipe:
+            assert pipe
+
+    @pytest.mark.onlycluster
+    async def test_pipeline_empty_transaction(self, r):
+        await r.set("a", 0)
+
+        async with r.pipeline(transaction=True) as pipe:
+            assert await pipe.execute() == []
+
+    @pytest.mark.onlycluster
+    async def test_executes_transaction_against_cluster(self, r) -> None:
+        async with r.pipeline(transaction=True) as tx:
+            tx.set("{foo}bar", "value1")
+            tx.set("{foo}baz", "value2")
+            tx.set("{foo}bad", "value3")
+            tx.get("{foo}bar")
+            tx.get("{foo}baz")
+            tx.get("{foo}bad")
+            assert await tx.execute() == [
+                True,
+                True,
+                True,
+                b"value1",
+                b"value2",
+                b"value3",
+            ]
+
+        await r.flushall()
+
+        tx = r.pipeline(transaction=True)
+        tx.set("{foo}bar", "value1")
+        tx.set("{foo}baz", "value2")
+        tx.set("{foo}bad", "value3")
+        tx.get("{foo}bar")
+        tx.get("{foo}baz")
+        tx.get("{foo}bad")
+        assert await tx.execute() == [True, True, True, b"value1", b"value2", b"value3"]
+
+    @pytest.mark.onlycluster
+    async def test_throws_exception_on_different_hash_slots(self, r):
+        async with r.pipeline(transaction=True) as tx:
+            tx.set("{foo}bar", "value1")
+            tx.set("{foobar}baz", "value2")
+
+            with pytest.raises(
+                CrossSlotTransactionError,
+                match="All keys involved in a cluster transaction must map to the same slot",
+            ):
+                await tx.execute()
+
+    @pytest.mark.onlycluster
+    async def test_throws_exception_with_watch_on_different_hash_slots(self, r):
+        async with r.pipeline(transaction=True) as tx:
+            with pytest.raises(
+                RedisClusterException,
+                match="WATCH - all keys must map to the same key slot",
+            ):
+                await tx.watch("key1", "key2")
+
+    @pytest.mark.onlycluster
+    async def test_transaction_with_watched_keys(self, r):
r.set("a", 0) + + async with r.pipeline(transaction=True) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + pipe.multi() + pipe.set("a", int(a) + 1) + assert await pipe.execute() == [True] + + @pytest.mark.onlycluster + async def test_retry_transaction_during_unfinished_slot_migration(self, r): + """ + When a transaction is triggered during a migration, MovedError + or AskError may appear (depends on the key being already migrated + or the key not existing already). The patch on parse_response + simulates such an error, but the slot cache is not updated + (meaning the migration is still ongoing) so the pipeline eventually + fails as if it was retried but the migration is not yet complete. + """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object( + ClusterNode, "parse_response" + ) as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(connection, *args, **options): + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + + parse_response.side_effect = ask_redirect_effect + + async with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + with pytest.raises(redis.exceptions.AskError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (SET book val) of pipeline caused error:" + f" {slot} {node_importing.name}" + ) + + manager_update_moved_slots.assert_called() + + @pytest.mark.onlycluster + async def test_retry_transaction_during_slot_migration_successful( + self, create_redis + ): + """ + If a MovedError or AskError appears when calling EXEC and no key is watched, + the pipeline is retried after updating the node manager slot table. If the + migration was completed, the transaction may then complete successfully. 
+ """ + r = await create_redis(flushdb=False) + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object( + ClusterNode, "parse_response" + ) as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(conn, *args, **options): + # first call should go here, we trigger an AskError + if f"{conn.host}:{conn.port}" == node_migrating.name: + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # if the slot table is updated, the next call will go here + elif f"{conn.host}:{conn.port}" == node_importing.name: + if "EXEC" in args: + return ["OK"] # mock value to validate this section was called + return + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + def update_moved_slot(): # simulate slot table update + ask_error = r.nodes_manager._moved_exception + assert ask_error is not None, "No AskError was previously triggered" + assert f"{ask_error.host}:{ask_error.port}" == node_importing.name + r.nodes_manager._moved_exception = None + r.nodes_manager.slots_cache[slot] = [node_importing] + + parse_response.side_effect = ask_redirect_effect + manager_update_moved_slots.side_effect = update_moved_slot + + result = None + async with r.pipeline(transaction=True) as pipe: + pipe.multi() + pipe.set(key, "val") + result = await pipe.execute() + + assert result and True in result, "Target node was not called" + + @pytest.mark.onlycluster + async def test_retry_transaction_with_watch_after_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling WATCH, the client + must attempt to recover itself before proceeding and no WatchError + should appear. + """ + key = "book" + slot = r.keyslot(key) + r.reinitialize_steps = 1 + + # force a MovedError on the first call to pipe.watch() + # by switching the node that owns the slot to another one + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + r.nodes_manager.slots_cache[slot] = [node_importing] + + async with r.pipeline(transaction=True) as pipe: + await pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + assert await pipe.execute() == [True] + + @pytest.mark.onlycluster + async def test_retry_transaction_with_watch_during_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling EXEC and keys were + being watched before the migration started, a WatchError should appear. + These errors imply resetting the connection and connecting to a new node, + so watches are lost anyway and the client code must be notified. 
+ """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(ClusterNode, "parse_response") as parse_response: + + def ask_redirect_effect(conn, *args, **options): + if f"{conn.host}:{conn.port}" == node_migrating.name: + # we simulate the watch was sent before the migration started + if "WATCH" in args: + return b"OK" + # but the pipeline was triggered after the migration started + elif "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # we should not try to connect to any other node + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + parse_response.side_effect = ask_redirect_effect + + async with r.pipeline(transaction=True) as pipe: + await pipe.watch(key) + + pipe.multi() + pipe.set(key, "val") + with pytest.raises(redis.exceptions.WatchError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Slot rebalancing occurred while watching keys" + ) + + @pytest.mark.onlycluster + async def test_retry_transaction_on_connection_error(self, r): + key = "book" + slot = r.keyslot(key) + + mock_connection = Mock(spec=Connection) + mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( + "Conn error" + ) + mock_connection.retry = Retry(NoBackoff(), 0) + + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + node_importing._free.append(mock_connection) + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 + + async with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + assert await pipe.execute() == [True] + + assert mock_connection.read_response.call_count == 1 + + @pytest.mark.onlycluster + async def test_retry_transaction_on_connection_error_with_watched_keys(self, r): + key = "book" + slot = r.keyslot(key) + + mock_connection = Mock(spec=Connection) + mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( + "Conn error" + ) + mock_connection.retry = Retry(NoBackoff(), 0) + + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + node_importing._free.append(mock_connection) + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 + + async with r.pipeline(transaction=True) as pipe: + await pipe.watch(key) + + pipe.multi() + pipe.set(key, "val") + assert await pipe.execute() == [True] + + assert mock_connection.read_response.call_count == 1 + + @pytest.mark.onlycluster + async def test_exec_error_raised(self, r): + hashkey = "{key}" + await r.set(f"{hashkey}:c", "a") + + async with r.pipeline(transaction=True) as pipe: + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) + pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH {key}:c 3) of pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set(f"{hashkey}:z", "zzz").execute() == [True] + assert await r.get(f"{hashkey}:z") == b"zzz" + + @pytest.mark.onlycluster + async def test_parse_error_raised(self, r): + hashkey = "{key}" + async with r.pipeline(transaction=True) as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) + with pytest.raises(redis.ResponseError) as ex: + 
+                await pipe.execute()
+
+            assert str(ex.value).startswith(
+                "Command # 2 (ZREM {key}:b) of pipeline caused error: wrong number"
+            )
+
+            # make sure the pipe was restored to a working state
+            assert await pipe.set(f"{hashkey}:z", "zzz").execute() == [True]
+            assert await r.get(f"{hashkey}:z") == b"zzz"
+
+    @pytest.mark.onlycluster
+    async def test_transaction_callable(self, r):
+        hashkey = "{key}"
+        await r.set(f"{hashkey}:a", 1)
+        await r.set(f"{hashkey}:b", 2)
+        has_run = []
+
+        async def my_transaction(pipe):
+            a_value = await pipe.get(f"{hashkey}:a")
+            assert a_value in (b"1", b"2")
+            b_value = await pipe.get(f"{hashkey}:b")
+            assert b_value == b"2"
+
+            # silly run-once code... incr's "a" so WatchError should be raised
+            # forcing this all to run again. this should incr "a" once to "2"
+            if not has_run:
+                await r.incr(f"{hashkey}:a")
+                has_run.append("it has")
+
+            pipe.multi()
+            pipe.set(f"{hashkey}:c", int(a_value) + int(b_value))
+
+        result = await r.transaction(my_transaction, f"{hashkey}:a", f"{hashkey}:b")
+        assert result == [True]
+        assert await r.get(f"{hashkey}:c") == b"4"
+
+    @pytest.mark.onlycluster
+    @skip_if_server_version_lt("2.0.0")
+    async def test_transaction_discard(self, r):
+        hashkey = "{key}"
+
+        # pipelines enabled as transactions can be discarded at any point
+        async with r.pipeline(transaction=True) as pipe:
+            await pipe.watch(f"{hashkey}:key")
+            await pipe.set(f"{hashkey}:key", "someval")
+            await pipe.discard()
+
+            assert not pipe._execution_strategy._watching
+            assert not pipe._execution_strategy._command_queue
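+
+    @pytest.mark.onlycluster
+    async def test_transaction_unwatch(self, r):
+        # Minimal sketch complementing test_transaction_discard: unwatch()
+        # alone should clear the watching state on the strategy.
+        hashkey = "{key}"
+        async with r.pipeline(transaction=True) as pipe:
+            await pipe.watch(f"{hashkey}:key")
+            assert pipe._execution_strategy._watching
+
+            await pipe.unwatch()
+            assert not pipe._execution_strategy._watching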