Made core layer close connection pools on loop close. #347

Merged · 1 commit · Mar 28, 2023
64 changes: 37 additions & 27 deletions channels_redis/core.py
@@ -15,7 +15,7 @@
 from channels.exceptions import ChannelFull
 from channels.layers import BaseChannelLayer

-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close

 logger = logging.getLogger(__name__)

@@ -69,6 +69,26 @@ def put_nowait(self, item):
         return super(BoundedQueue, self).put_nowait(item)


+class RedisLoopLayer:
+    def __init__(self, channel_layer):
+        self._lock = asyncio.Lock()
+        self.channel_layer = channel_layer
+        self._connections = {}
+
+    def get_connection(self, index):
+        if index not in self._connections:
+            pool = self.channel_layer.create_pool(index)
+            self._connections[index] = aioredis.Redis(connection_pool=pool)
+
+        return self._connections[index]
+
+    async def flush(self):
+        async with self._lock:
+            for index in list(self._connections):
+                connection = self._connections.pop(index)
+                await connection.close(close_connection_pool=True)
+
+
 class RedisChannelLayer(BaseChannelLayer):
     """
     Redis channel layer.
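For context on the new class: RedisLoopLayer lazily creates one aioredis.Redis client per host index (backed by a pool from create_pool()) and caches it, while flush() closes every cached client together with its pool. A minimal sketch of that contract, assuming a layer configured for redis://localhost:6379 (the URL is illustrative, not part of this diff):

```python
import asyncio

from channels_redis.core import RedisChannelLayer


async def demo():
    # Illustrative host URL; adjust to your deployment.
    layer = RedisChannelLayer(hosts=["redis://localhost:6379"])

    # connection() goes through the RedisLoopLayer owned by the running loop,
    # so repeated calls for the same host index return the same cached client.
    first = layer.connection(0)
    second = layer.connection(0)
    assert first is second

    # close_pools() now delegates to RedisLoopLayer.flush(), closing each
    # cached client together with its underlying connection pool.
    await layer.close_pools()


asyncio.run(demo())
```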
@@ -101,8 +121,7 @@ def __init__(
         self.hosts = self.decode_hosts(hosts)
         self.ring_size = len(self.hosts)
         # Cached redis connection pools and the event loop they are from
-        self.pools = {}
-        self.pools_loop = None
+        self._layers = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -138,7 +157,7 @@ def create_pool(self, index):
             return aioredis.sentinel.SentinelConnectionPool(
                 master_name,
                 aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
-                **host
+                **host,
             )
         else:
             return aioredis.ConnectionPool(**host)
@@ -331,7 +350,7 @@ async def receive(self, channel):

                         raise

-                    message, token, exception = None, None, None
+                    message = token = exception = None
                     for task in done:
                         try:
                             result = task.result()
@@ -367,7 +386,7 @@ async def receive(self, channel):
                             message_channel, message = await self.receive_single(
                                 real_channel
                             )
-                            if type(message_channel) is list:
+                            if isinstance(message_channel, list):
                                 for chan in message_channel:
                                     self.receive_buffer[chan].put_nowait(message)
                             else:
@@ -459,11 +478,7 @@ async def new_channel(self, prefix="specific"):
         Returns a new channel name that can be used by something in our
         process as a specific channel.
         """
-        return "%s.%s!%s" % (
-            prefix,
-            self.client_prefix,
-            uuid.uuid4().hex,
-        )
+        return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"

     ### Flush extension ###

@@ -496,9 +511,8 @@ async def close_pools(self):
         # Flush all cleaners, in case somebody just wanted to close the
         # pools without flushing first.
        await self.wait_received()
-
-        for index in self.pools:
-            await self.pools[index].disconnect()
+        for layer in self._layers.values():
+            await layer.flush()

     async def wait_received(self):
         """
@@ -667,7 +681,7 @@ def _group_key(self, group):
         """
         Common function to make the storage key for the group.
         """
-        return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
+        return f"{self.prefix}:group:{group}".encode("utf8")

     ### Serialization ###

@@ -711,7 +725,7 @@ def make_fernet(self, key):
         return Fernet(formatted_key)

     def __str__(self):
-        return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
+        return f"{self.__class__.__name__}(hosts={self.hosts})"

     ### Connection handling ###

@@ -723,18 +737,14 @@ def connection(self, index):
         # Catch bad indexes
         if not 0 <= index < self.ring_size:
             raise ValueError(
-                "There are only %s hosts - you asked for %s!" % (self.ring_size, index)
+                f"There are only {self.ring_size} hosts - you asked for {index}!"
             )

+        loop = asyncio.get_running_loop()
         try:
-            loop = asyncio.get_running_loop()
-            if self.pools_loop != loop:
-                self.pools = {}
-                self.pools_loop = loop
-        except RuntimeError:
-            pass
-
-        if index not in self.pools:
-            self.pools[index] = self.create_pool(index)
+            layer = self._layers[loop]
+        except KeyError:
+            _wrap_close(self, loop)
+            layer = self._layers[loop] = RedisLoopLayer(self)

-        return aioredis.Redis(connection_pool=self.pools[index])
+        return layer.get_connection(index)
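Taken together, connection() now hands out per-event-loop clients and _wrap_close ensures they are released when that loop is closed. A minimal end-to-end sketch, assuming a Redis server reachable at redis://localhost:6379 (illustrative URL, not part of this diff):

```python
import asyncio

from channels_redis.core import RedisChannelLayer

# Illustrative host URL; adjust to your deployment.
layer = RedisChannelLayer(hosts=["redis://localhost:6379"])


async def ping():
    # The first connection() call on this loop creates its RedisLoopLayer
    # and patches loop.close via _wrap_close.
    connection = layer.connection(0)
    await connection.ping()


loop = asyncio.new_event_loop()
loop.run_until_complete(ping())

# Closing the loop now runs RedisLoopLayer.flush() first, so the cached
# clients and their connection pools are disconnected on this loop instead
# of lingering after it is gone.
loop.close()
```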
18 changes: 1 addition & 17 deletions channels_redis/pubsub.py
@@ -1,32 +1,16 @@
 import asyncio
 import functools
 import logging
-import types
 import uuid

 import msgpack
 from redis import asyncio as aioredis

-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close

 logger = logging.getLogger(__name__)


-def _wrap_close(proxy, loop):
-    original_impl = loop.close
-
-    def _wrapper(self, *args, **kwargs):
-        if loop in proxy._layers:
-            layer = proxy._layers[loop]
-            del proxy._layers[loop]
-            loop.run_until_complete(layer.flush())
-
-        self.close = original_impl
-        return self.close(*args, **kwargs)
-
-    loop.close = types.MethodType(_wrapper, loop)
-
-
 async def _async_proxy(obj, name, *args, **kwargs):
     # Must be defined as a function and not a method due to
     # https://bugs.python.org/issue38364
16 changes: 16 additions & 0 deletions channels_redis/utils.py
@@ -1,4 +1,5 @@
 import binascii
+import types


 def _consistent_hash(value, ring_size):
@@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
     bigval = binascii.crc32(value) & 0xFFF
     ring_divisor = 4096 / float(ring_size)
     return int(bigval / ring_divisor)
+
+
+def _wrap_close(proxy, loop):
+    original_impl = loop.close
+
+    def _wrapper(self, *args, **kwargs):
+        if loop in proxy._layers:
+            layer = proxy._layers[loop]
+            del proxy._layers[loop]
+            loop.run_until_complete(layer.flush())
+
+        self.close = original_impl
+        return self.close(*args, **kwargs)
+
+    loop.close = types.MethodType(_wrapper, loop)
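Since _wrap_close now lives in utils and is shared by both layer implementations, its behaviour can be exercised without Redis at all. A small sketch using hypothetical DummyProxy/DummyLayer stand-ins (neither exists in channels_redis; they only mimic the ._layers mapping and flush() coroutine the helper expects):

```python
import asyncio

from channels_redis.utils import _wrap_close


class DummyLayer:
    # Hypothetical stand-in for RedisLoopLayer: records that flush() ran.
    def __init__(self):
        self.flushed = False

    async def flush(self):
        self.flushed = True


class DummyProxy:
    # Hypothetical stand-in for the channel layer: only needs a ._layers dict.
    def __init__(self):
        self._layers = {}


proxy = DummyProxy()
loop = asyncio.new_event_loop()
proxy._layers[loop] = layer = DummyLayer()

_wrap_close(proxy, loop)

# Closing the loop triggers the wrapper: the layer is popped from
# proxy._layers and its flush() coroutine runs before the real close().
loop.close()
assert layer.flushed
assert loop not in proxy._layers
```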