From d3874dc9e920dd655c6d486187aeff700d83d1f7 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 13 Oct 2025 09:10:15 -0400 Subject: [PATCH 1/8] Fix locking logic for NodesManager --- redis/cluster.py | 496 ++++++++++++++++-------------- tests/test_cluster_transaction.py | 11 +- 2 files changed, 262 insertions(+), 245 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 1d4a3e0d0c..a8f2537756 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -8,7 +8,18 @@ from copy import copy from enum import Enum from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Union, + final, +) from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan @@ -1275,7 +1286,7 @@ def _execute_command(self, target_node, *args, **kwargs): # Reset the counter self.reinitialize_counter = 0 else: - self.nodes_manager.update_moved_exception(e) + self.nodes_manager.move_slot(e) moved = True except TryAgainError: if ttl < self.RedisClusterRequestTTL / 2: @@ -1414,8 +1425,9 @@ class LoadBalancer: """ def __init__(self, start_index: int = 0) -> None: - self.primary_to_idx = {} - self.start_index = start_index + self.primary_to_idx: dict[str, int] = {} + self.start_index: int = start_index + self._lock: threading.Lock = threading.Lock() def get_server_index( self, @@ -1433,7 +1445,8 @@ def get_server_index( ) def reset(self) -> None: - self.primary_to_idx.clear() + with self._lock: + self.primary_to_idx.clear() def _get_random_replica_index(self, list_size: int) -> int: return random.randint(1, list_size - 1) @@ -1441,22 +1454,22 @@ def _get_random_replica_index(self, list_size: int) -> int: def _get_round_robin_index( self, primary: str, list_size: int, replicas_only: bool ) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - if replicas_only and server_index == 0: - 
# skip the primary node index - server_index = 1 - # Update the index for the next round - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index + with self._lock: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + if replicas_only and server_index == 0: + # skip the primary node index + server_index = 1 + # Update the index for the next round + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index class NodesManager: def __init__( self, - startup_nodes, + startup_nodes: list[ClusterNode], from_url=False, require_full_coverage=False, - lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, @@ -1466,25 +1479,33 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, **kwargs, ): - self.nodes_cache: Dict[str, Redis] = {} - self.slots_cache = {} - self.startup_nodes = {} - self.default_node = None - self.populate_startup_nodes(startup_nodes) + self.nodes_cache: dict[str, ClusterNode] = {} + self.slots_cache: dict[int, list[ClusterNode]] = {} + self.startup_nodes: dict[str, ClusterNode] = {n.name: n for n in startup_nodes} + self.default_node: ClusterNode | None = None + self._epoch: int = 0 self.from_url = from_url self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap - self._cache = cache - self._cache_config = cache_config - self._cache_factory = cache_factory - self._moved_exception = None + self._cache: CacheInterface | None = None + if cache: + self._cache = cache + elif cache_factory is not None: + self._cache = cache_factory.get_cache() + elif cache_config is not None: + self._cache = CacheFactory(cache_config).get_cache() self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() - if lock is None: - lock = 
threading.RLock() - self._lock = lock + + # nodes_cache / slots_cache / startup_nodes / default_node are protected by _lock + self._lock: threading.RLock = threading.RLock() + # initialize holds _initialization_lock to dedup multiple calls to reinitialize; + # note that if we hold both _lock and _initialization_lock, we _must_ acquire + # _initialization_lock first (ie: to have a consistent order) to avoid deadlock. + self._initialization_lock: threading.Lock = threading.Lock() + if event_dispatcher is None: self._event_dispatcher = EventDispatcher() else: @@ -1494,7 +1515,12 @@ def __init__( ) self.initialize() - def get_node(self, host=None, port=None, node_name=None): + def get_node( + self, + host: str | None = None, + port: int | None = None, + node_name: str | None = None, + ) -> ClusterNode | None: """ Get the requested node from the cluster's nodes. nodes. @@ -1504,53 +1530,50 @@ def get_node(self, host=None, port=None, node_name=None): # the user passed host and port if host == "localhost": host = socket.gethostbyname(host) - return self.nodes_cache.get(get_node_name(host=host, port=port)) + with self._lock: + return self.nodes_cache.get(get_node_name(host=host, port=port)) elif node_name: - return self.nodes_cache.get(node_name) + with self._lock: + return self.nodes_cache.get(node_name) else: return None - def update_moved_exception(self, exception): - self._moved_exception = exception - - def _update_moved_slots(self): + def move_slot(self, e: AskError | MovedError): """ Update the slot's node with the redirected one """ - e = self._moved_exception - redirected_node = self.get_node(host=e.host, port=e.port) - if redirected_node is not None: - # The node already exists - if redirected_node.server_type is not PRIMARY: - # Update the node's server type - redirected_node.server_type = PRIMARY - else: - # This is a new node, we will add it to the nodes cache - redirected_node = ClusterNode(e.host, e.port, PRIMARY) - self.nodes_cache[redirected_node.name] = 
redirected_node - if redirected_node in self.slots_cache[e.slot_id]: - # The MOVED error resulted from a failover, and the new slot owner - # had previously been a replica. - old_primary = self.slots_cache[e.slot_id][0] - # Update the old primary to be a replica and add it to the end of - # the slot's node list - old_primary.server_type = REPLICA - self.slots_cache[e.slot_id].append(old_primary) - # Remove the old replica, which is now a primary, from the slot's - # node list - self.slots_cache[e.slot_id].remove(redirected_node) - # Override the old primary with the new one - self.slots_cache[e.slot_id][0] = redirected_node - if self.default_node == old_primary: - # Update the default node with the new primary - self.default_node = redirected_node - else: - # The new slot owner is a new server, or a server from a different - # shard. We need to remove all current nodes from the slot's list - # (including replications) and add just the new node. - self.slots_cache[e.slot_id] = [redirected_node] - # Reset moved_exception - self._moved_exception = None + with self._lock: + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node is not None: + # The node already exists + if redirected_node.server_type is not PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode(e.host, e.port, PRIMARY) + self.nodes_cache[redirected_node.name] = redirected_node + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. 
+ old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] @deprecated_args( args_to_warn=["server_type"], @@ -1562,66 +1585,56 @@ def _update_moved_slots(self): ) def get_node_from_slot( self, - slot, - read_from_replicas=False, - load_balancing_strategy=None, - server_type=None, + slot: int, + read_from_replicas: bool = False, + load_balancing_strategy: LoadBalancingStrategy | None = None, + server_type: Literal["primary", "replica"] | None = None, ) -> ClusterNode: """ Gets a node that servers this hash slot """ - if self._moved_exception: - with self._lock: - if self._moved_exception: - self._update_moved_slots() - - if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. 
' - f'"require_full_coverage={self._require_full_coverage}"' - ) if read_from_replicas is True and load_balancing_strategy is None: load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN - if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: - # get the server index using the strategy defined in load_balancing_strategy - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]), load_balancing_strategy - ) - elif ( - server_type is None - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 - ): - # return a primary - node_idx = 0 - else: - # return a replica - # randomly choose one of the replicas - node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) + with self._lock: + if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + + f'"require_full_coverage={self._require_full_coverage}"' + ) + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]), load_balancing_strategy + ) + elif ( + server_type is None + or server_type == PRIMARY + or len(self.slots_cache[slot]) == 1 + ): + # return a primary + node_idx = 0 + else: + # return a replica + # randomly choose one of the replicas + node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) - return self.slots_cache[slot][node_idx] + return self.slots_cache[slot][node_idx] - def get_nodes_by_server_type(self, server_type): + def get_nodes_by_server_type(self, server_type: Literal["primary", "replica"]): """ Get all nodes with the specified server type :param server_type: 'primary' or 'replica' :return: list of ClusterNode """ - return [ - node - for node in 
self.nodes_cache.values() - if node.server_type == server_type - ] - - def populate_startup_nodes(self, nodes): - """ - Populate all startup nodes and filters out any duplicates - """ - for n in nodes: - self.startup_nodes[n.name] = n + with self._lock: + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] def check_slots_coverage(self, slots_cache): # Validate if all slots are covered or if we should try next @@ -1688,7 +1701,8 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): # before creating a new cluster node, check if the cluster node already # exists in the current nodes cache and has a valid connection so we can # reuse it - target_node = self.nodes_cache.get(node_name) + with self._lock: + target_node = self.nodes_cache.get(node_name) if target_node is None or target_node.redis_connection is None: # create new cluster node for this cluster target_node = ClusterNode(host, port, role) @@ -1713,135 +1727,143 @@ def initialize(self): fully_covered = False kwargs = self.connection_kwargs exception = None - # Convert to tuple to prevent RuntimeError if self.startup_nodes - # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): - try: - if startup_node.redis_connection: - r = startup_node.redis_connection - else: - # Create a new Redis connection - r = self.create_redis_node( - startup_node.host, startup_node.port, **kwargs - ) - self.startup_nodes[startup_node.name].redis_connection = r - # Make sure cluster mode is enabled on this node - try: - cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - r.connection_pool.disconnect() - except ResponseError: - raise RedisClusterException( - "Cluster mode is not enabled on this node" - ) - startup_nodes_reachable = True - except Exception as e: - # Try the next startup node. - # The exception is saved and raised only if we have no more nodes. 
- exception = e - continue - - # CLUSTER SLOTS command results in the following output: - # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] - # where each node contains the following list: [IP, port, node_id] - # Therefore, cluster_slots[0][2][0] will be the IP address of the - # primary node of the first slot section. - # If there's only one server in the cluster, its ``host`` is '' - # Fix it to the host in startup_nodes - if ( - len(cluster_slots) == 1 - and len(cluster_slots[0][2][0]) == 0 - and len(self.startup_nodes) == 1 - ): - cluster_slots[0][2][0] = startup_node.host - - for slot in cluster_slots: - primary_node = slot[2] - host = str_if_bytes(primary_node[0]) - if host == "": - host = startup_node.host - port = int(primary_node[1]) - host, port = self.remap_host_port(host, port) + with self._lock: + epoch = self._epoch - nodes_for_slot = [] + with self._initialization_lock: + # randomly order the startup nodes to ensure multiple clients evenly + # distribute topology discovery requests across the cluster. 
+ with self._lock: + if epoch != self._epoch: + # another thread has already re-initialized the nodes; don't + # bother running again + return - target_node = self._get_or_create_cluster_node( - host, port, PRIMARY, tmp_nodes_cache + startup_nodes = random.sample( + list(self.startup_nodes.values()), k=len(self.startup_nodes) ) - nodes_for_slot.append(target_node) - replica_nodes = slot[3:] - for replica_node in replica_nodes: - host = str_if_bytes(replica_node[0]) - port = int(replica_node[1]) + for startup_node in startup_nodes: + try: + if startup_node.redis_connection: + r = startup_node.redis_connection + else: + # Create a new Redis connection + r = self.create_redis_node( + startup_node.host, startup_node.port, **kwargs + ) + self.startup_nodes[startup_node.name].redis_connection = r + # Make sure cluster mode is enabled on this node + try: + cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + except ResponseError: + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + startup_nodes_reachable = True + except Exception as e: + # Try the next startup node. + # The exception is saved and raised only if we have no more nodes. + exception = e + continue + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. 
+ # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + primary_node = slot[2] + host = str_if_bytes(primary_node[0]) + if host == "": + host = startup_node.host + port = int(primary_node[1]) host, port = self.remap_host_port(host, port) - target_replica_node = self._get_or_create_cluster_node( - host, port, REPLICA, tmp_nodes_cache - ) - nodes_for_slot.append(target_replica_node) - for i in range(int(slot[0]), int(slot[1]) + 1): - if i not in tmp_slots: - tmp_slots[i] = nodes_for_slot - else: - # Validate that 2 nodes want to use the same slot cache - # setup - tmp_slot = tmp_slots[i][0] - if tmp_slot.name != target_node.name: - disagreements.append( - f"{tmp_slot.name} vs {target_node.name} on slot: {i}" - ) - - if len(disagreements) > 5: - raise RedisClusterException( - f"startup_nodes could not agree on a valid " - f"slots cache: {', '.join(disagreements)}" + nodes_for_slot = [] + + target_node = self._get_or_create_cluster_node( + host, port, PRIMARY, tmp_nodes_cache + ) + nodes_for_slot.append(target_node) + + replica_nodes = slot[3:] + for replica_node in replica_nodes: + host = str_if_bytes(replica_node[0]) + port = int(replica_node[1]) + host, port = self.remap_host_port(host, port) + target_replica_node = self._get_or_create_cluster_node( + host, port, REPLICA, tmp_nodes_cache + ) + nodes_for_slot.append(target_replica_node) + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = nodes_for_slot + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" ) - fully_covered = self.check_slots_coverage(tmp_slots) - if 
fully_covered: - # Don't need to continue to the next startup node if all - # slots are covered - break + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f"slots cache: {', '.join(disagreements)}" + ) - if not startup_nodes_reachable: - raise RedisClusterException( - f"Redis Cluster cannot be connected. Please provide at least " - f"one reachable node: {str(exception)}" - ) from exception + fully_covered = self.check_slots_coverage(tmp_slots) + if fully_covered: + # Don't need to continue to the next startup node if all + # slots are covered + break - if self._cache is None and self._cache_config is not None: - if self._cache_factory is None: - self._cache = CacheFactory(self._cache_config).get_cache() - else: - self._cache = self._cache_factory.get_cache() + if not startup_nodes_reachable: + raise RedisClusterException( + f"Redis Cluster cannot be connected. Please provide at least " + f"one reachable node: {str(exception)}" + ) from exception - # Create Redis connections to all nodes - self.create_redis_connections(list(tmp_nodes_cache.values())) + # Create Redis connections to all nodes + self.create_redis_connections(list(tmp_nodes_cache.values())) - # Check if the slots are not fully covered - if not fully_covered and self._require_full_coverage: - # Despite the requirement that the slots be covered, there - # isn't a full coverage - raise RedisClusterException( - f"All slots are not covered after query all startup_nodes. " - f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " - f"covered..." - ) + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." 
+ ) - # Set the tmp variables to the real variables - self.nodes_cache = tmp_nodes_cache - self.slots_cache = tmp_slots - # Set the default node - self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] - if self._dynamic_startup_nodes: - # Populate the startup nodes with all discovered nodes - self.startup_nodes = tmp_nodes_cache - # If initialize was called after a MovedError, clear it - self._moved_exception = None + # Set the tmp variables to the real variables + with self._lock: + self.nodes_cache = tmp_nodes_cache + self.slots_cache = tmp_slots + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + if self._dynamic_startup_nodes: + # Populate the startup nodes with all discovered nodes + self.startup_nodes = tmp_nodes_cache def close(self) -> None: - self.default_node = None - for node in self.nodes_cache.values(): + with self._lock: + self.default_node = None + nodes = tuple(self.nodes_cache.values()) + for node in nodes: if node.redis_connection: node.redis_connection.close() @@ -1862,15 +1884,17 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def find_connection_owner(self, connection: Connection) -> Optional[Redis]: + def find_connection_owner(self, connection: Connection) -> ClusterNode | None: node_name = get_node_name(connection.host, connection.port) - for node in tuple(self.nodes_cache.values()): - if node.redis_connection: - conn_args = node.redis_connection.connection_pool.connection_kwargs - if node_name == get_node_name( - conn_args.get("host"), conn_args.get("port") - ): - return node + with self._lock: + for node in tuple(self.nodes_cache.values()): + if node.redis_connection: + conn_args = node.redis_connection.connection_pool.connection_kwargs + if node_name == get_node_name( + conn_args.get("host"), conn_args.get("port") + ): + return node + return None class ClusterPubSub(PubSub): @@ -3166,7 +3190,7 @@ def 
_reinitialize_on_error(self, error): self.reinitialize_counter = 0 else: if isinstance(error, AskError): - self._nodes_manager.update_moved_exception(error) + self._nodes_manager.move_slot(error) self._executing = False diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index 6ebd6df566..08bdc250fe 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -126,9 +126,6 @@ def test_retry_transaction_during_unfinished_slot_migration(self, r): with ( patch.object(Redis, "parse_response") as parse_response, - patch.object( - NodesManager, "_update_moved_slots" - ) as manager_update_moved_slots, ): def ask_redirect_effect(connection, *args, **options): @@ -151,8 +148,6 @@ def ask_redirect_effect(connection, *args, **options): f" {slot} {node_importing.name}" ) - manager_update_moved_slots.assert_called() - @pytest.mark.onlycluster def test_retry_transaction_during_slot_migration_successful(self, r): """ @@ -166,9 +161,6 @@ def test_retry_transaction_during_slot_migration_successful(self, r): with ( patch.object(Redis, "parse_response") as parse_response, - patch.object( - NodesManager, "_update_moved_slots" - ) as manager_update_moved_slots, ): def ask_redirect_effect(conn, *args, **options): @@ -198,7 +190,7 @@ def update_moved_slot(): # simulate slot table update r.nodes_manager.slots_cache[slot] = [node_importing] parse_response.side_effect = ask_redirect_effect - manager_update_moved_slots.side_effect = update_moved_slot + # manager_update_moved_slots.side_effect = update_moved_slot result = None with r.pipeline(transaction=True) as pipe: @@ -311,6 +303,7 @@ def test_retry_transaction_on_connection_error_with_watched_keys( mock_pool.get_connection.return_value = mock_connection mock_pool._available_connections = [mock_connection] mock_pool._lock = threading.RLock() + mock_pool.connection_kwargs = {} _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) 
node_importing.redis_connection.connection_pool = mock_pool From 1c0d6d4249637e6be76b5f6292d89b1f0842fb4f Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 13 Oct 2025 15:04:41 -0400 Subject: [PATCH 2/8] Apply similar changes to async cluster code --- redis/asyncio/cluster.py | 263 +++++++++--------- .../test_asyncio/test_cluster_transaction.py | 17 +- tests/test_cluster_transaction.py | 8 - 3 files changed, 135 insertions(+), 153 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 225fd3b79f..51a3a73854 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -912,7 +912,7 @@ async def _execute_command( # Reset the counter self.reinitialize_counter = 0 else: - self.nodes_manager._moved_exception = e + self.nodes_manager.move_slot(e) moved = True except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) @@ -1266,12 +1266,13 @@ async def _mock(self, error: RedisError): class NodesManager: __slots__ = ( "_dynamic_startup_nodes", - "_moved_exception", "_event_dispatcher", "connection_kwargs", "default_node", "nodes_cache", + "_epoch", "read_load_balancer", + "_initialize_lock", "require_full_coverage", "slots_cache", "startup_nodes", @@ -1295,10 +1296,11 @@ def __init__( self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} self.slots_cache: Dict[int, List["ClusterNode"]] = {} + self._epoch: int = 0 self.read_load_balancer = LoadBalancer() + self._initialize_lock: asyncio.Lock = asyncio.Lock() self._dynamic_startup_nodes: bool = dynamic_startup_nodes - self._moved_exception: MovedError = None if event_dispatcher is None: self._event_dispatcher = EventDispatcher() else: @@ -1340,11 +1342,7 @@ def set_nodes( task = asyncio.create_task(old[name].disconnect()) # noqa old[name] = node - def update_moved_exception(self, exception): - self._moved_exception = exception - - def _update_moved_slots(self) -> None: - e = self._moved_exception + def move_slot(self, e: AskError 
| MovedError): redirected_node = self.get_node(host=e.host, port=e.port) if redirected_node: # The node already exists @@ -1378,8 +1376,6 @@ def _update_moved_slots(self) -> None: # shard. We need to remove all current nodes from the slot's list # (including replications) and add just the new node. self.slots_cache[e.slot_id] = [redirected_node] - # Reset moved_exception - self._moved_exception = None def get_node_from_slot( self, @@ -1387,9 +1383,6 @@ def get_node_from_slot( read_from_replicas: bool = False, load_balancing_strategy=None, ) -> "ClusterNode": - if self._moved_exception: - self._update_moved_slots() - if read_from_replicas is True and load_balancing_strategy is None: load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN @@ -1423,135 +1416,147 @@ async def initialize(self) -> None: startup_nodes_reachable = False fully_covered = False exception = None - # Convert to tuple to prevent RuntimeError if self.startup_nodes - # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): - try: - # Make sure cluster mode is enabled on this node + epoch = self._epoch + + async with self._initialize_lock: + if self._epoch != epoch: + # another initialize call has already reinitialized the + # nodes since we started waiting for the lock; + # we don't need to do it again. 
+ return + + # Convert to tuple to prevent RuntimeError if self.startup_nodes + # is modified during iteration + for startup_node in tuple(self.startup_nodes.values()): try: - self._event_dispatcher.dispatch( - AfterAsyncClusterInstantiationEvent( - self.nodes_cache, - self.connection_kwargs.get("credential_provider", None), + # Make sure cluster mode is enabled on this node + try: + self._event_dispatcher.dispatch( + AfterAsyncClusterInstantiationEvent( + self.nodes_cache, + self.connection_kwargs.get("credential_provider", None), + ) ) - ) - cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") - except ResponseError: - raise RedisClusterException( - "Cluster mode is not enabled on this node" - ) - startup_nodes_reachable = True - except Exception as e: - # Try the next startup node. - # The exception is saved and raised only if we have no more nodes. - exception = e - continue - - # CLUSTER SLOTS command results in the following output: - # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] - # where each node contains the following list: [IP, port, node_id] - # Therefore, cluster_slots[0][2][0] will be the IP address of the - # primary node of the first slot section. 
- # If there's only one server in the cluster, its ``host`` is '' - # Fix it to the host in startup_nodes - if ( - len(cluster_slots) == 1 - and not cluster_slots[0][2][0] - and len(self.startup_nodes) == 1 - ): - cluster_slots[0][2][0] = startup_node.host - - for slot in cluster_slots: - for i in range(2, len(slot)): - slot[i] = [str_if_bytes(val) for val in slot[i]] - primary_node = slot[2] - host = primary_node[0] - if host == "": - host = startup_node.host - port = int(primary_node[1]) - host, port = self.remap_host_port(host, port) - - nodes_for_slot = [] - - target_node = tmp_nodes_cache.get(get_node_name(host, port)) - if not target_node: - target_node = ClusterNode( - host, port, PRIMARY, **self.connection_kwargs - ) - # add this node to the nodes cache - tmp_nodes_cache[target_node.name] = target_node - nodes_for_slot.append(target_node) - - replica_nodes = slot[3:] - for replica_node in replica_nodes: - host = replica_node[0] - port = replica_node[1] + cluster_slots = await startup_node.execute_command( + "CLUSTER SLOTS" + ) + except ResponseError: + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + startup_nodes_reachable = True + except Exception as e: + # Try the next startup node. + # The exception is saved and raised only if we have no more nodes. + exception = e + continue + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. 
+ # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and not cluster_slots[0][2][0] + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + for i in range(2, len(slot)): + slot[i] = [str_if_bytes(val) for val in slot[i]] + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) host, port = self.remap_host_port(host, port) - target_replica_node = tmp_nodes_cache.get(get_node_name(host, port)) - if not target_replica_node: - target_replica_node = ClusterNode( - host, port, REPLICA, **self.connection_kwargs + nodes_for_slot = [] + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if not target_node: + target_node = ClusterNode( + host, port, PRIMARY, **self.connection_kwargs ) # add this node to the nodes cache - tmp_nodes_cache[target_replica_node.name] = target_replica_node - nodes_for_slot.append(target_replica_node) + tmp_nodes_cache[target_node.name] = target_node + nodes_for_slot.append(target_node) - for i in range(int(slot[0]), int(slot[1]) + 1): - if i not in tmp_slots: - tmp_slots[i] = nodes_for_slot - else: - # Validate that 2 nodes want to use the same slot cache - # setup - tmp_slot = tmp_slots[i][0] - if tmp_slot.name != target_node.name: - disagreements.append( - f"{tmp_slot.name} vs {target_node.name} on slot: {i}" - ) + replica_nodes = slot[3:] + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + host, port = self.remap_host_port(host, port) - if len(disagreements) > 5: - raise RedisClusterException( - f"startup_nodes could not agree on a valid " - f"slots cache: {', '.join(disagreements)}" + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + # 
add this node to the nodes cache + tmp_nodes_cache[target_replica_node.name] = target_replica_node + nodes_for_slot.append(target_replica_node) + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = nodes_for_slot + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" ) - # Validate if all slots are covered or if we should try next startup node - fully_covered = True - for i in range(REDIS_CLUSTER_HASH_SLOTS): - if i not in tmp_slots: - fully_covered = False + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f"slots cache: {', '.join(disagreements)}" + ) + + # Validate if all slots are covered or if we should try next startup node + fully_covered = True + for i in range(REDIS_CLUSTER_HASH_SLOTS): + if i not in tmp_slots: + fully_covered = False + break + if fully_covered: break - if fully_covered: - break - if not startup_nodes_reachable: - raise RedisClusterException( - f"Redis Cluster cannot be connected. Please provide at least " - f"one reachable node: {str(exception)}" - ) from exception - - # Check if the slots are not fully covered - if not fully_covered and self.require_full_coverage: - # Despite the requirement that the slots be covered, there - # isn't a full coverage - raise RedisClusterException( - f"All slots are not covered after query all startup_nodes. " - f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " - f"covered..." - ) + if not startup_nodes_reachable: + raise RedisClusterException( + f"Redis Cluster cannot be connected. 
Please provide at least " + f"one reachable node: {str(exception)}" + ) from exception + + # Check if the slots are not fully covered + if not fully_covered and self.require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) - # Set the tmp variables to the real variables - self.slots_cache = tmp_slots - self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) + # Set the tmp variables to the real variables + self.slots_cache = tmp_slots + self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) - if self._dynamic_startup_nodes: - # Populate the startup nodes with all discovered nodes - self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + if self._dynamic_startup_nodes: + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) - # Set the default node - self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] - # If initialize was called after a MovedError, clear it - self._moved_exception = None + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + self._epoch += 1 async def aclose(self, attr: str = "nodes_cache") -> None: self.default_node = None @@ -2255,9 +2260,7 @@ async def _reinitialize_on_error(self, error): self.reinitialize_counter = 0 else: if isinstance(error, AskError): - self._pipe.cluster_client.nodes_manager.update_moved_exception( - error - ) + self._pipe.cluster_client.nodes_manager.move_slot(error) self._executing = False diff --git a/tests/test_asyncio/test_cluster_transaction.py b/tests/test_asyncio/test_cluster_transaction.py index 5e540eae5e..e39d4aaab9 100644 --- a/tests/test_asyncio/test_cluster_transaction.py +++ b/tests/test_asyncio/test_cluster_transaction.py @@ -127,9 
+127,7 @@ async def test_retry_transaction_during_unfinished_slot_migration(self, r): with ( patch.object(ClusterNode, "parse_response") as parse_response, - patch.object( - NodesManager, "_update_moved_slots" - ) as manager_update_moved_slots, + patch.object(NodesManager, "move_slot") as manager_move_slot, ): def ask_redirect_effect(connection, *args, **options): @@ -152,7 +150,7 @@ def ask_redirect_effect(connection, *args, **options): f" {slot} {node_importing.name}" ) - manager_update_moved_slots.assert_called() + manager_move_slot.assert_called() @pytest.mark.onlycluster async def test_retry_transaction_during_slot_migration_successful( @@ -170,9 +168,6 @@ async def test_retry_transaction_during_slot_migration_successful( with ( patch.object(ClusterNode, "parse_response") as parse_response, - patch.object( - NodesManager, "_update_moved_slots" - ) as manager_update_moved_slots, ): def ask_redirect_effect(conn, *args, **options): @@ -192,15 +187,7 @@ def ask_redirect_effect(conn, *args, **options): else: assert False, f"unexpected node {conn.host}:{conn.port} was called" - def update_moved_slot(): # simulate slot table update - ask_error = r.nodes_manager._moved_exception - assert ask_error is not None, "No AskError was previously triggered" - assert f"{ask_error.host}:{ask_error.port}" == node_importing.name - r.nodes_manager._moved_exception = None - r.nodes_manager.slots_cache[slot] = [node_importing] - parse_response.side_effect = ask_redirect_effect - manager_update_moved_slots.side_effect = update_moved_slot result = None async with r.pipeline(transaction=True) as pipe: diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index 08bdc250fe..b7efed8815 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -182,15 +182,7 @@ def ask_redirect_effect(conn, *args, **options): else: assert False, f"unexpected node {conn.host}:{conn.port} was called" - def update_moved_slot(): # simulate slot table 
update - ask_error = r.nodes_manager._moved_exception - assert ask_error is not None, "No AskError was previously triggered" - assert f"{ask_error.host}:{ask_error.port}" == node_importing.name - r.nodes_manager._moved_exception = None - r.nodes_manager.slots_cache[slot] = [node_importing] - parse_response.side_effect = ask_redirect_effect - # manager_update_moved_slots.side_effect = update_moved_slot result = None with r.pipeline(transaction=True) as pipe: From 8187504fede011ccbbf759948cd6e79e68964c2e Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 13 Oct 2025 23:26:34 -0400 Subject: [PATCH 3/8] Fix tests --- dev_requirements.txt | 2 +- redis/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index e61f37f101..3fe5ac25b1 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,7 +4,7 @@ invoke==2.2.0 mock packaging>=20.4 pytest -pytest-asyncio>=0.23.0 +pytest-asyncio>=0.24.0 pytest-cov pytest-profiling==1.8.1 pytest-timeout diff --git a/redis/__init__.py b/redis/__init__.py index 67f165d9fe..795662d2e2 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -46,7 +46,7 @@ def int_or_str(value): return value -__version__ = "6.2.0" +__version__ = "6.4.0" VERSION = tuple(map(int_or_str, __version__.split("."))) From 77b40410e5b7af850f1fcc00ef661621fce76680 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 17 Oct 2025 06:44:49 -0700 Subject: [PATCH 4/8] Add tests --- redis/cluster.py | 13 +- tests/test_cluster.py | 363 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 374 insertions(+), 2 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index a8f2537756..26afad4fed 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1713,6 +1713,14 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): return target_node + def _get_epoch(self) -> int: + """ + Get the current epoch value. 
This method exists primarily to allow + tests to mock the epoch fetch and control race condition timing. + """ + with self._lock: + return self._epoch + def initialize(self): """ Initializes the nodes cache, slots cache and redis connections. @@ -1727,8 +1735,7 @@ def initialize(self): fully_covered = False kwargs = self.connection_kwargs exception = None - with self._lock: - epoch = self._epoch + epoch = self._get_epoch() with self._initialization_lock: # randomly order the startup nodes to ensure multiple clients evenly @@ -1858,6 +1865,8 @@ def initialize(self): if self._dynamic_startup_nodes: # Populate the startup nodes with all discovered nodes self.startup_nodes = tmp_nodes_cache + # Increment the epoch to signal that initialization has completed + self._epoch += 1 def close(self) -> None: with self._lock: diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 2936bb0024..7ccfc63669 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2947,6 +2947,369 @@ def test_allow_custom_queue_class(self, queue_class): for node in rc.nodes_manager.nodes_cache.values(): assert node.redis_connection.connection_pool.queue_class == queue_class + def test_concurrent_get_node(self): + """ + Test that concurrent get_node calls are thread-safe + """ + rc = get_mocked_redis_client(host=default_host, port=default_port) + n_manager = rc.nodes_manager + + results = [] + errors = [] + + def get_node_repeatedly(host, port, iterations): + """Get a node repeatedly""" + try: + for _ in range(iterations): + node = n_manager.get_node(host=host, port=port) + results.append(node) + except Exception as e: + errors.append(e) + + # Create threads that will concurrently read nodes + threads = [] + for _ in range(10): + threads.append( + threading.Thread( + target=get_node_repeatedly, args=("127.0.0.1", 7000, 100) + ) + ) + + # Start all threads + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check that no errors 
occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Check that we got results + assert len(results) == 1000 + + def test_concurrent_initialize(self): + """ + Test that concurrent initialize calls are properly deduplicated + by the _initialization_lock + """ + initialization_count = {"count": 0} + + with ( + patch.object(Redis, "execute_command") as execute_command_mock, + patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize, + ): + + def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": + # Track how many times we actually fetch cluster slots + initialization_count["count"] += 1 + # Add a small delay to make race conditions more likely + # sleep(0.01) + return default_cluster_slots + elif _args[0] == "COMMAND": + return {"get": [], "set": []} + elif _args[0] == "INFO": + return {"cluster_enabled": True} + elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "no"} + else: + return execute_command_mock(*_args, **_kwargs) + + execute_command_mock.side_effect = execute_command + + def cmd_init_mock(self, r): + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + nm = NodesManager( + startup_nodes=[ClusterNode(host=default_host, port=default_port)], + from_url=False, + require_full_coverage=False, + dynamic_startup_nodes=True, + ) + + # Reset the counter after initial setup + initialization_count["count"] = 0 + + errors: list[Exception] = [] + + def initialize_repeatedly(iterations: int): + """Call initialize repeatedly""" + try: + for _ in range(iterations): + nm.initialize() + except Exception as e: + errors.append(e) + + # Create multiple threads that will try to initialize concurrently + threads: list[threading.Thread] = [] + for _ in range(10): + threads.append( + 
threading.Thread(target=initialize_repeatedly, args=(5,)) + ) + + # Start all threads + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check that no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Due to the _initialization_lock, we should see far fewer + # actual initializations than the 50 calls (10 threads * 5 calls) + # In practice, we should see around 50 or fewer depending on timing + assert initialization_count["count"] <= 50 + + # Verify that the nodes_cache is still consistent + assert len(nm.nodes_cache) > 0 + assert len(nm.slots_cache) > 0 + + def test_concurrent_initialize_exact_timing(self): + """ + Test that exactly two concurrent initialize calls result in only + one actual cluster slots fetch by forcing them to start simultaneously + """ + initialization_count = {"count": 0} + epoch_barrier = threading.Barrier(2) + + with ( + patch.object(Redis, "execute_command") as execute_command_mock, + patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize, + ): + + def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": + # Track how many times we actually fetch cluster slots + initialization_count["count"] += 1 + return default_cluster_slots + elif _args[0] == "COMMAND": + return {"get": [], "set": []} + elif _args[0] == "INFO": + return {"cluster_enabled": True} + elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "no"} + else: + return execute_command_mock(*_args, **_kwargs) + + execute_command_mock.side_effect = execute_command + + def cmd_init_mock(self, r): + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + nm = NodesManager( + startup_nodes=[ClusterNode(host=default_host, 
port=default_port)], + from_url=False, + require_full_coverage=False, + dynamic_startup_nodes=True, + ) + + # Reset the counter after initial setup + initialization_count["count"] = 0 + + # Store the original method + original_get_epoch = nm._get_epoch + + def mocked_get_epoch(): + """ + Mock _get_epoch to control race timing: + 1. First thread fetches epoch + 2. Both threads sync at epoch_barrier (ensures 2nd thread also fetches epoch) + 3. The barrier releases only once both threads hold the same epoch (before the lock) + 4. Both threads proceed to try to acquire _initialization_lock + """ + epoch = original_get_epoch() + # Wait for both threads to have fetched the epoch before proceeding + epoch_barrier.wait() + return epoch + + # Patch the instance method directly + nm._get_epoch = mocked_get_epoch + + errors: list[Exception] = [] + + def initialize_thread(): + """Call initialize to test concurrent access""" + try: + nm.initialize() + except Exception as e: + errors.append(e) + + # Create exactly 2 threads that will initialize at the same time + threads: list[threading.Thread] = [] + for _ in range(2): + threads.append(threading.Thread(target=initialize_thread)) + + # Start both threads + for t in threads: + t.start() + + # Wait for both threads to complete + for t in threads: + t.join() + + # Check that no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Due to the _initialization_lock, only one thread should have + # actually fetched cluster slots + assert initialization_count["count"] == 1 + + # Verify that the nodes_cache is still consistent + assert len(nm.nodes_cache) > 0 + assert len(nm.slots_cache) > 0 + + def test_concurrent_move_slot_and_get_node_from_slot(self): + """ + Test that concurrent move_slot and get_node_from_slot calls + don't cause race conditions or return inconsistent data + """ + rc = get_mocked_redis_client(host=default_host, port=default_port) + n_manager = rc.nodes_manager + + slot = 5000 + + errors: 
list[Exception] = [] + results: list[ClusterNode] = [] + + def move_slot_repeatedly(iterations: int): + """Simulate concurrent MOVED errors""" + try: + for i in range(iterations): + # Alternate between two nodes + if i % 2 == 0: + moved_error = MovedError(f"{slot} 127.0.0.1:7001") + else: + moved_error = MovedError(f"{slot} 127.0.0.1:7000") + n_manager.move_slot(moved_error) + except Exception as e: + errors.append(e) + + def get_node_from_slot_repeatedly(iterations: int): + """Get node from slot repeatedly""" + try: + for _ in range(iterations): + node = n_manager.get_node_from_slot(slot) + # The node should always be a valid primary + assert node is not None + assert node.server_type == PRIMARY + results.append(node) + except Exception as e: + errors.append(e) + + # Create threads: some moving slots, others reading + threads: list[threading.Thread] = [] + for _ in range(3): + threads.append(threading.Thread(target=move_slot_repeatedly, args=(100,))) + for _ in range(7): + threads.append( + threading.Thread(target=get_node_from_slot_repeatedly, args=(100,)) + ) + + # Start all threads + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check that no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify we got results from readers + assert len(results) == 700 + + # Verify the final state is consistent + assert len(n_manager.slots_cache[slot]) >= 1 + assert n_manager.slots_cache[slot][0].server_type == PRIMARY + + def test_concurrent_nodes_cache_access(self): + """ + Test that concurrent access to nodes_cache doesn't cause + dictionary modification errors + """ + rc = get_mocked_redis_client(host=default_host, port=default_port) + n_manager = rc.nodes_manager + + errors = [] + results = [] + + def iterate_nodes_repeatedly(iterations): + """Iterate over nodes repeatedly""" + try: + for _ in range(iterations): + # Simulate iterating over nodes (like get_nodes() does) + with 
n_manager._lock: + nodes = list(n_manager.nodes_cache.values()) + results.append(len(nodes)) + except Exception as e: + errors.append(e) + + def modify_nodes_cache(iterations): + """Simulate node additions via move_slot""" + try: + for i in range(iterations): + # Create a MOVED error to a new node + new_port = 8000 + i + moved_error = MovedError(f"{i} 127.0.0.1:{new_port}") + n_manager.move_slot(moved_error) + except Exception as e: + errors.append(e) + + # Create threads that read and modify concurrently + threads = [] + for _ in range(5): + threads.append( + threading.Thread(target=iterate_nodes_repeatedly, args=(50,)) + ) + for _ in range(2): + threads.append(threading.Thread(target=modify_nodes_cache, args=(25,))) + + # Start all threads + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check that no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify we got results + assert len(results) == 250 + @pytest.mark.onlycluster class TestClusterPubSubObject: From 3cb600d4b0977d36c8e32ebb596b98e98f0d7515 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 17 Oct 2025 08:24:46 -0700 Subject: [PATCH 5/8] Fix tests --- tests/test_cluster.py | 358 +++++++++++++----------------------------- 1 file changed, 105 insertions(+), 253 deletions(-) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 7ccfc63669..b29352f0ea 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2947,141 +2947,6 @@ def test_allow_custom_queue_class(self, queue_class): for node in rc.nodes_manager.nodes_cache.values(): assert node.redis_connection.connection_pool.queue_class == queue_class - def test_concurrent_get_node(self): - """ - Test that concurrent get_node calls are thread-safe - """ - rc = get_mocked_redis_client(host=default_host, port=default_port) - n_manager = rc.nodes_manager - - results = [] - errors = [] - - def get_node_repeatedly(host, port, iterations): - """Get a 
node repeatedly""" - try: - for _ in range(iterations): - node = n_manager.get_node(host=host, port=port) - results.append(node) - except Exception as e: - errors.append(e) - - # Create threads that will concurrently read nodes - threads = [] - for _ in range(10): - threads.append( - threading.Thread( - target=get_node_repeatedly, args=("127.0.0.1", 7000, 100) - ) - ) - - # Start all threads - for t in threads: - t.start() - - # Wait for all threads to complete - for t in threads: - t.join() - - # Check that no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" - - # Check that we got results - assert len(results) == 1000 - - def test_concurrent_initialize(self): - """ - Test that concurrent initialize calls are properly deduplicated - by the _initialization_lock - """ - initialization_count = {"count": 0} - - with ( - patch.object(Redis, "execute_command") as execute_command_mock, - patch.object( - CommandsParser, "initialize", autospec=True - ) as cmd_parser_initialize, - ): - - def execute_command(*_args, **_kwargs): - if _args[0] == "CLUSTER SLOTS": - # Track how many times we actually fetch cluster slots - initialization_count["count"] += 1 - # Add a small delay to make race conditions more likely - # sleep(0.01) - return default_cluster_slots - elif _args[0] == "COMMAND": - return {"get": [], "set": []} - elif _args[0] == "INFO": - return {"cluster_enabled": True} - elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": "no"} - else: - return execute_command_mock(*_args, **_kwargs) - - execute_command_mock.side_effect = execute_command - - def cmd_init_mock(self, r): - self.commands = { - "get": { - "name": "get", - "arity": 2, - "flags": ["readonly", "fast"], - "first_key_pos": 1, - "last_key_pos": 1, - "step_count": 1, - } - } - - cmd_parser_initialize.side_effect = cmd_init_mock - - nm = NodesManager( - startup_nodes=[ClusterNode(host=default_host, port=default_port)], - 
from_url=False, - require_full_coverage=False, - dynamic_startup_nodes=True, - ) - - # Reset the counter after initial setup - initialization_count["count"] = 0 - - errors: list[Exception] = [] - - def initialize_repeatedly(iterations: int): - """Call initialize repeatedly""" - try: - for _ in range(iterations): - nm.initialize() - except Exception as e: - errors.append(e) - - # Create multiple threads that will try to initialize concurrently - threads: list[threading.Thread] = [] - for _ in range(10): - threads.append( - threading.Thread(target=initialize_repeatedly, args=(5,)) - ) - - # Start all threads - for t in threads: - t.start() - - # Wait for all threads to complete - for t in threads: - t.join() - - # Check that no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" - - # Due to the _initialization_lock, we should see far fewer - # actual initializations than the 50 calls (10 threads * 5 calls) - # In practice, we should see around 50 or fewer depending on timing - assert initialization_count["count"] <= 50 - - # Verify that the nodes_cache is still consistent - assert len(nm.nodes_cache) > 0 - assert len(nm.slots_cache) > 0 - def test_concurrent_initialize_exact_timing(self): """ Test that exactly two concurrent initialize calls result in only @@ -3092,9 +2957,6 @@ def test_concurrent_initialize_exact_timing(self): with ( patch.object(Redis, "execute_command") as execute_command_mock, - patch.object( - CommandsParser, "initialize", autospec=True - ) as cmd_parser_initialize, ): def execute_command(*_args, **_kwargs): @@ -3102,30 +2964,17 @@ def execute_command(*_args, **_kwargs): # Track how many times we actually fetch cluster slots initialization_count["count"] += 1 return default_cluster_slots - elif _args[0] == "COMMAND": - return {"get": [], "set": []} - elif _args[0] == "INFO": - return {"cluster_enabled": True} - elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": 
"no"} else: return execute_command_mock(*_args, **_kwargs) execute_command_mock.side_effect = execute_command - def cmd_init_mock(self, r): - self.commands = { - "get": { - "name": "get", - "arity": 2, - "flags": ["readonly", "fast"], - "first_key_pos": 1, - "last_key_pos": 1, - "step_count": 1, - } - } - - cmd_parser_initialize.side_effect = cmd_init_mock + r = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_enabled=True, + ) + nm = r.nodes_manager nm = NodesManager( startup_nodes=[ClusterNode(host=default_host, port=default_port)], @@ -3149,7 +2998,7 @@ def mocked_get_epoch(): 4. Both threads proceed to try to acquire _initialization_lock """ epoch = original_get_epoch() - # Wait for both threads to have fetched the epoch before proceeding + # Signal that this thread has fetched the epoch epoch_barrier.wait() return epoch @@ -3189,54 +3038,36 @@ def initialize_thread(): assert len(nm.nodes_cache) > 0 assert len(nm.slots_cache) > 0 - def test_concurrent_move_slot_and_get_node_from_slot(self): - """ - Test that concurrent move_slot and get_node_from_slot calls - don't cause race conditions or return inconsistent data - """ - rc = get_mocked_redis_client(host=default_host, port=default_port) - n_manager = rc.nodes_manager - - slot = 5000 - + def test_concurrent_slot_moves(self): + # ensure multiple concurrently moved slots are processed correctly, + # eg: not dropping updates + r = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_enabled=True, + ) + nm = r.nodes_manager + # Move slots 0-1000 to 127.0.0.1 in concurrent threads + num_threads = 20 + slots_per_thread = 50 # 1000 slots / 20 threads = 50 slots per thread errors: list[Exception] = [] - results: list[ClusterNode] = [] - - def move_slot_repeatedly(iterations: int): - """Simulate concurrent MOVED errors""" - try: - for i in range(iterations): - # Alternate between two nodes - if i % 2 == 0: - moved_error = MovedError(f"{slot} 127.0.0.1:7001") - else: - 
moved_error = MovedError(f"{slot} 127.0.0.1:7000") - n_manager.move_slot(moved_error) - except Exception as e: - errors.append(e) - def get_node_from_slot_repeatedly(iterations: int): - """Get node from slot repeatedly""" + def move_slots_worker(thread_id: int): + """Each thread moves a subset of slots to 127.0.0.1""" try: - for _ in range(iterations): - node = n_manager.get_node_from_slot(slot) - # The node should always be a valid primary - assert node is not None - assert node.server_type == PRIMARY - results.append(node) + for i in range(slots_per_thread): + moved_error = MovedError( + f"{thread_id * slots_per_thread + i} 127.0.0.1:7001" + ) + nm.move_slot(moved_error) except Exception as e: errors.append(e) - # Create threads: some moving slots, others reading + # Start all threads threads: list[threading.Thread] = [] - for _ in range(3): - threads.append(threading.Thread(target=move_slot_repeatedly, args=(100,))) - for _ in range(7): - threads.append( - threading.Thread(target=get_node_from_slot_repeatedly, args=(100,)) - ) + for i in range(num_threads): + threads.append(threading.Thread(target=move_slots_worker, args=(i,))) - # Start all threads for t in threads: t.start() @@ -3247,68 +3078,89 @@ def get_node_from_slot_repeatedly(iterations: int): # Check that no errors occurred assert len(errors) == 0, f"Errors occurred: {errors}" - # Verify we got results from readers - assert len(results) == 700 - - # Verify the final state is consistent - assert len(n_manager.slots_cache[slot]) >= 1 - assert n_manager.slots_cache[slot][0].server_type == PRIMARY - - def test_concurrent_nodes_cache_access(self): - """ - Test that concurrent access to nodes_cache doesn't cause - dictionary modification errors - """ - rc = get_mocked_redis_client(host=default_host, port=default_port) - n_manager = rc.nodes_manager - - errors = [] - results = [] - - def iterate_nodes_repeatedly(iterations): - """Iterate over nodes repeatedly""" - try: - for _ in range(iterations): - # 
Simulate iterating over nodes (like get_nodes() does) - with n_manager._lock: - nodes = list(n_manager.nodes_cache.values()) - results.append(len(nodes)) - except Exception as e: - errors.append(e) + # Verify that all slots 0-999 are moved to 127.0.0.1:7001 + for slot_id in range(num_threads * slots_per_thread): + assert slot_id in nm.slots_cache, f"Slot {slot_id} missing" + slot_nodes = nm.slots_cache[slot_id] + assert len(slot_nodes) >= 1, f"Slot {slot_id} has no nodes" + primary_node = slot_nodes[0] + assert primary_node.host == "127.0.0.1", ( + f"Slot {slot_id} not moved to 127.0.0.1, " + f"current host: {primary_node.host}" + ) + assert primary_node.port == 7001, ( + f"Slot {slot_id} not moved to port 7001, " + f"current port: {primary_node.port}" + ) + assert primary_node.server_type == PRIMARY - def modify_nodes_cache(iterations): - """Simulate node additions via move_slot""" - try: - for i in range(iterations): - # Create a MOVED error to a new node - new_port = 8000 + i - moved_error = MovedError(f"{i} 127.0.0.1:{new_port}") - n_manager.move_slot(moved_error) - except Exception as e: - errors.append(e) + def test_concurrent_initialize_and_move_slot(self): + # race initialize & move slot to ensure that the two operations + # don't conflict with each other. 
- # Create threads that read and modify concurrently - threads = [] - for _ in range(5): - threads.append( - threading.Thread(target=iterate_nodes_repeatedly, args=(50,)) + with ( + patch.object(Redis, "execute_command") as execute_command_mock, + ): + r = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_enabled=True, ) - for _ in range(2): - threads.append(threading.Thread(target=modify_nodes_cache, args=(25,))) + nm = r.nodes_manager - # Start all threads - for t in threads: - t.start() + def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": + return default_cluster_slots + else: + return execute_command_mock(*_args, **_kwargs) - # Wait for all threads to complete - for t in threads: - t.join() + execute_command_mock.side_effect = execute_command - # Check that no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" + errors: list[Exception] = [] + + def initialize_worker(): + """Reinitialize the cluster""" + try: + nm.initialize() + except Exception as e: + errors.append(e) - # Verify we got results - assert len(results) == 250 + def move_slots_worker(): + """Move slots while initialize is running""" + for slot_id in range(10): + try: + # move slot to a mix of :7001 & :7003, which simulates what we'd see in + # both failovers & slot migrations. 
+ new_slot = 7001 if slot_id % 2 == 0 else 7003 + moved_error = MovedError(f"{slot_id} 127.0.0.1:{new_slot}") + nm.move_slot(moved_error) + except Exception as e: + errors.append(e) + + for _ in range(100): + t1 = threading.Thread(target=initialize_worker) + t2 = threading.Thread(target=move_slots_worker) + + t1.start() + t2.start() + + t1.join() + t2.join() + + # check that no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # verify data consistency + for slot_id in range(REDIS_CLUSTER_HASH_SLOTS): + assert slot_id in nm.slots_cache, f"Slot {slot_id} missing" + slot_nodes = nm.slots_cache[slot_id] + assert len(slot_nodes) == 2 + + for node in slot_nodes: + assert node.name in nm.nodes_cache + + # primary should be first + assert slot_nodes[0].server_type == PRIMARY @pytest.mark.onlycluster From b2676d71043592713e5a6f72d4b77248d77dc48d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 17 Oct 2025 08:24:56 -0700 Subject: [PATCH 6/8] Fix impl --- redis/cluster.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 26afad4fed..088ddfe106 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1404,14 +1404,6 @@ def __repr__(self): def __eq__(self, obj): return isinstance(obj, ClusterNode) and obj.name == self.name - def __del__(self): - try: - if self.redis_connection is not None: - self.redis_connection.close() - except Exception: - # Ignore errors when closing the connection - pass - class LoadBalancingStrategy(Enum): ROUND_ROBIN = "round_robin" @@ -1701,13 +1693,14 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): # before creating a new cluster node, check if the cluster node already # exists in the current nodes cache and has a valid connection so we can # reuse it + redis_connection: Redis | None = None with self._lock: - target_node = self.nodes_cache.get(node_name) - if target_node is None or target_node.redis_connection 
is None: - # create new cluster node for this cluster - target_node = ClusterNode(host, port, role) - if target_node.server_type != role: - target_node.server_type = role + previous_node = self.nodes_cache.get(node_name) + if previous_node: + redis_connection = previous_node.redis_connection + # don't update the old ClusterNode, so we don't update its role + # outside of the lock + target_node = ClusterNode(host, port, role, redis_connection) # add this node to the nodes cache tmp_nodes_cache[target_node.name] = target_node From 1d54c113e71ebc878393ec4d65c660a04cd92001 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 17 Oct 2025 09:01:48 -0700 Subject: [PATCH 7/8] Move comment to make more sense --- redis/cluster.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 088ddfe106..5ccc306584 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1731,14 +1731,14 @@ def initialize(self): epoch = self._get_epoch() with self._initialization_lock: - # randomly order the startup nodes to ensure multiple clients evenly - # distribute topology discovery requests across the cluster. with self._lock: if epoch != self._epoch: # another thread has already re-initialized the nodes; don't # bother running again return + # randomly order the startup nodes to ensure multiple clients evenly + # distribute topology discovery requests across the cluster. 
startup_nodes = random.sample( list(self.startup_nodes.values()), k=len(self.startup_nodes) ) From a20733d7c14f62758d1769e18e3fd091b5ed513d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 17 Oct 2025 09:02:34 -0700 Subject: [PATCH 8/8] Lint --- redis/cluster.py | 1 - tests/test_cluster_transaction.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 5ccc306584..bd061af4f5 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -18,7 +18,6 @@ Set, Tuple, Union, - final, ) from redis._parsers import CommandsParser, Encoder diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index b7efed8815..d5b21abc9d 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -8,7 +8,7 @@ from redis import CrossSlotTransactionError, ConnectionPool, RedisClusterException from redis.backoff import NoBackoff from redis.client import Redis -from redis.cluster import PRIMARY, ClusterNode, NodesManager, RedisCluster +from redis.cluster import PRIMARY, ClusterNode, RedisCluster from redis.retry import Retry from .conftest import skip_if_server_version_lt