diff --git a/channels_redis/core.py b/channels_redis/core.py index 7c04ecd1..c3eb3b38 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -15,7 +15,7 @@ from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer -from .utils import _consistent_hash +from .utils import _consistent_hash, _wrap_close logger = logging.getLogger(__name__) @@ -69,6 +69,26 @@ def put_nowait(self, item): return super(BoundedQueue, self).put_nowait(item) +class RedisLoopLayer: + def __init__(self, channel_layer): + self._lock = asyncio.Lock() + self.channel_layer = channel_layer + self._connections = {} + + def get_connection(self, index): + if index not in self._connections: + pool = self.channel_layer.create_pool(index) + self._connections[index] = aioredis.Redis(connection_pool=pool) + + return self._connections[index] + + async def flush(self): + async with self._lock: + for index in list(self._connections): + connection = self._connections.pop(index) + await connection.close(close_connection_pool=True) + + class RedisChannelLayer(BaseChannelLayer): """ Redis channel layer. @@ -101,8 +121,7 @@ def __init__( self.hosts = self.decode_hosts(hosts) self.ring_size = len(self.hosts) # Cached redis connection pools and the event loop they are from - self.pools = {} - self.pools_loop = None + self._layers = {} # Normal channels choose a host index by cycling through the available hosts self._receive_index_generator = itertools.cycle(range(len(self.hosts))) self._send_index_generator = itertools.cycle(range(len(self.hosts))) @@ -138,7 +157,7 @@ def create_pool(self, index): return aioredis.sentinel.SentinelConnectionPool( master_name, aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), - **host + **host, ) else: return aioredis.ConnectionPool(**host) @@ -331,7 +350,7 @@ async def receive(self, channel): raise - message, token, exception = None, None, None + message = token = exception = None for task in done: try: result = task.result() @@ -367,7 +386,7 @@ async def receive(self, channel): message_channel, message = await self.receive_single( real_channel ) - if type(message_channel) is list: + if isinstance(message_channel, list): for chan in message_channel: self.receive_buffer[chan].put_nowait(message) else: @@ -459,11 +478,7 @@ async def new_channel(self, prefix="specific"): Returns a new channel name that can be used by something in our process as a specific channel. """ - return "%s.%s!%s" % ( - prefix, - self.client_prefix, - uuid.uuid4().hex, - ) + return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}" ### Flush extension ### @@ -496,9 +511,8 @@ async def close_pools(self): # Flush all cleaners, in case somebody just wanted to close the # pools without flushing first. await self.wait_received() - - for index in self.pools: - await self.pools[index].disconnect() + for layer in self._layers.values(): + await layer.flush() async def wait_received(self): """ @@ -667,7 +681,7 @@ def _group_key(self, group): """ Common function to make the storage key for the group. """ - return ("%s:group:%s" % (self.prefix, group)).encode("utf8") + return f"{self.prefix}:group:{group}".encode("utf8") ### Serialization ### @@ -711,7 +725,7 @@ def make_fernet(self, key): return Fernet(formatted_key) def __str__(self): - return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts) + return f"{self.__class__.__name__}(hosts={self.hosts})" ### Connection handling ### @@ -723,18 +737,14 @@ def connection(self, index): # Catch bad indexes if not 0 <= index < self.ring_size: raise ValueError( - "There are only %s hosts - you asked for %s!" % (self.ring_size, index) + f"There are only {self.ring_size} hosts - you asked for {index}!" ) + loop = asyncio.get_running_loop() try: - loop = asyncio.get_running_loop() - if self.pools_loop != loop: - self.pools = {} - self.pools_loop = loop - except RuntimeError: - pass - - if index not in self.pools: - self.pools[index] = self.create_pool(index) + layer = self._layers[loop] + except KeyError: + _wrap_close(self, loop) + layer = self._layers[loop] = RedisLoopLayer(self) - return aioredis.Redis(connection_pool=self.pools[index]) + return layer.get_connection(index) diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index ccaef0fb..2ac8a08a 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -1,32 +1,16 @@ import asyncio import functools import logging -import types import uuid import msgpack from redis import asyncio as aioredis -from .utils import _consistent_hash +from .utils import _consistent_hash, _wrap_close logger = logging.getLogger(__name__) -def _wrap_close(proxy, loop): - original_impl = loop.close - - def _wrapper(self, *args, **kwargs): - if loop in proxy._layers: - layer = proxy._layers[loop] - del proxy._layers[loop] - loop.run_until_complete(layer.flush()) - - self.close = original_impl - return self.close(*args, **kwargs) - - loop.close = types.MethodType(_wrapper, loop) - - async def _async_proxy(obj, name, *args, **kwargs): # Must be defined as a function and not a method due to # https://bugs.python.org/issue38364 diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 7b30fdcf..d2405bb0 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -1,4 +1,5 @@ import binascii +import types def _consistent_hash(value, ring_size): @@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size): bigval = binascii.crc32(value) & 0xFFF ring_divisor = 4096 / float(ring_size) return int(bigval / ring_divisor) + + +def _wrap_close(proxy, loop): + original_impl = loop.close + + def _wrapper(self, *args, **kwargs): + if loop in proxy._layers: + layer = proxy._layers[loop] + del proxy._layers[loop] + loop.run_until_complete(layer.flush()) + + self.close = original_impl + return self.close(*args, **kwargs) + + loop.close = types.MethodType(_wrapper, loop)