diff --git a/channels_redis/core.py b/channels_redis/core.py index 23d9213..911376d 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -15,7 +15,7 @@ from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer -from .utils import _consistent_hash, _wrap_close +from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts logger = logging.getLogger(__name__) @@ -120,7 +120,7 @@ def __init__( self.should_auto_discard_full_channels = should_auto_discard_full_channels assert isinstance(self.prefix, str), "Prefix must be unicode" # Configure the host objects - self.hosts = self.decode_hosts(hosts) + self.hosts = decode_hosts(hosts) self.ring_size = len(self.hosts) # Cached redis connection pools and the event loop they are from self._layers = {} @@ -148,46 +148,7 @@ def __init__( self.receive_clean_locks = ChannelLock() def create_pool(self, index): - host = self.hosts[index] - - if "address" in host: - return aioredis.ConnectionPool.from_url(host["address"]) - elif "master_name" in host: - sentinels = host.pop("sentinels") - master_name = host.pop("master_name") - sentinel_kwargs = host.pop("sentinel_kwargs", None) - return aioredis.sentinel.SentinelConnectionPool( - master_name, - aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), - **host, - ) - else: - return aioredis.ConnectionPool(**host) - - def decode_hosts(self, hosts): - """ - Takes the value of the "hosts" argument passed to the class and returns - a list of kwargs to use for the Redis connection constructor. - """ - # If no hosts were provided, return a default value - if not hosts: - return [{"address": "redis://localhost:6379"}] - # If they provided just a string, scold them. - if isinstance(hosts, (str, bytes)): - raise ValueError( - "You must pass a list of Redis hosts, even if there is only one." - ) - - # Decode each hosts entry into a kwargs dict - result = [] - for entry in hosts: - if isinstance(entry, dict): - result.append(entry) - elif isinstance(entry, tuple): - result.append({"host": entry[0], "port": entry[1]}) - else: - result.append({"address": entry}) - return result + return create_pool(self.hosts[index]) def _setup_encryption(self, symmetric_encryption_keys): # See if we can do encryption if they asked diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 2ac8a08..78db68e 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -6,7 +6,7 @@ import msgpack from redis import asyncio as aioredis -from .utils import _consistent_hash, _wrap_close +from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts logger = logging.getLogger(__name__) @@ -81,12 +81,6 @@ def __init__( channel_layer=None, **kwargs, ): - if hosts is None: - hosts = ["redis://localhost:6379"] - assert ( - isinstance(hosts, list) and len(hosts) > 0 - ), "`hosts` must be a list with at least one Redis server" - self.prefix = prefix self.on_disconnect = on_disconnect @@ -102,7 +96,9 @@ def __init__( self.groups = {} # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host. - self._shards = [RedisSingleShardConnection(host, self) for host in hosts] + self._shards = [ + RedisSingleShardConnection(host, self) for host in decode_hosts(hosts) + ] def _get_shard(self, channel_or_group_name): """ @@ -247,9 +243,7 @@ async def flush(self): class RedisSingleShardConnection: def __init__(self, host, channel_layer): - self.host = host.copy() if type(host) is dict else {"address": host} - self.master_name = self.host.pop("master_name", None) - self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None) + self.host = host self.channel_layer = channel_layer self._subscribed_to = set() self._lock = asyncio.Lock() @@ -331,18 +325,7 @@ def _receive_message(self, message): def _ensure_redis(self): if self._redis is None: - if self.master_name is None: - pool = aioredis.ConnectionPool.from_url(self.host["address"]) - else: - # aioredis default timeout is way too low - pool = aioredis.sentinel.SentinelConnectionPool( - self.master_name, - aioredis.sentinel.Sentinel( - self.host["sentinels"], - socket_timeout=2, - sentinel_kwargs=self.sentinel_kwargs, - ), - ) + pool = create_pool(self.host) self._redis = aioredis.Redis(connection_pool=pool) self._pubsub = self._redis.pubsub() diff --git a/channels_redis/utils.py b/channels_redis/utils.py index d2405bb..98e06ca 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -1,6 +1,8 @@ import binascii import types +from redis import asyncio as aioredis + def _consistent_hash(value, ring_size): """ @@ -31,3 +33,53 @@ def _wrapper(self, *args, **kwargs): return self.close(*args, **kwargs) loop.close = types.MethodType(_wrapper, loop) + + +def decode_hosts(hosts): + """ + Takes the value of the "hosts" argument and returns + a list of kwargs to use for the Redis connection constructor. + """ + # If no hosts were provided, return a default value + if not hosts: + return [{"address": "redis://localhost:6379"}] + # If they provided just a string, scold them. + if isinstance(hosts, (str, bytes)): + raise ValueError( + "You must pass a list of Redis hosts, even if there is only one." + ) + + # Decode each hosts entry into a kwargs dict + result = [] + for entry in hosts: + if isinstance(entry, dict): + result.append(entry) + elif isinstance(entry, (tuple, list)): + result.append({"host": entry[0], "port": entry[1]}) + else: + result.append({"address": entry}) + return result + + +def create_pool(host): + """ + Takes the value of the "host" argument and returns a suited connection pool to + the corresponding redis instance. + """ + # avoid side-effects from modifying host + host = host.copy() + if "address" in host: + address = host.pop("address") + return aioredis.ConnectionPool.from_url(address, **host) + + master_name = host.pop("master_name", None) + if master_name is not None: + sentinels = host.pop("sentinels") + sentinel_kwargs = host.pop("sentinel_kwargs", None) + return aioredis.sentinel.SentinelConnectionPool( + master_name, + aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), + **host + ) + + return aioredis.ConnectionPool(**host)