Skip to content

Commit 2fca31c

Browse files
committed
Assure pools are closed on loop close in core (#332)
1 parent a7094c5 commit 2fca31c

File tree

3 files changed

+54
-43
lines changed

3 files changed

+54
-43
lines changed

channels_redis/core.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from channels.exceptions import ChannelFull
1616
from channels.layers import BaseChannelLayer
1717

18-
from .utils import _consistent_hash
18+
from .utils import _consistent_hash, _wrap_close
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -69,6 +69,27 @@ def put_nowait(self, item):
6969
return super(BoundedQueue, self).put_nowait(item)
7070

7171

72+
class RedisLoopLayer:
73+
74+
def __init__(self, channel_layer):
75+
self._lock = asyncio.Lock()
76+
self.channel_layer = channel_layer
77+
self._connections = {}
78+
79+
def get_connection(self, index):
80+
if index not in self._connections:
81+
pool = self.channel_layer.create_pool(index)
82+
self._connections[index] = aioredis.Redis(connection_pool=pool)
83+
84+
return self._connections[index]
85+
86+
async def flush(self):
87+
async with self._lock:
88+
for index in list(self._connections):
89+
connection = self._connections.pop(index)
90+
await connection.close(close_connection_pool=True)
91+
92+
7293
class RedisChannelLayer(BaseChannelLayer):
7394
"""
7495
Redis channel layer.
@@ -101,8 +122,7 @@ def __init__(
101122
self.hosts = self.decode_hosts(hosts)
102123
self.ring_size = len(self.hosts)
103124
# Cached redis connection pools and the event loop they are from
104-
self.pools = {}
105-
self.pools_loop = None
125+
self._layers = {}
106126
# Normal channels choose a host index by cycling through the available hosts
107127
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
108128
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -331,7 +351,7 @@ async def receive(self, channel):
331351

332352
raise
333353

334-
message, token, exception = None, None, None
354+
message = token = exception = None
335355
for task in done:
336356
try:
337357
result = task.result()
@@ -367,7 +387,7 @@ async def receive(self, channel):
367387
message_channel, message = await self.receive_single(
368388
real_channel
369389
)
370-
if type(message_channel) is list:
390+
if isinstance(message_channel, list):
371391
for chan in message_channel:
372392
self.receive_buffer[chan].put_nowait(message)
373393
else:
@@ -459,11 +479,7 @@ async def new_channel(self, prefix="specific"):
459479
Returns a new channel name that can be used by something in our
460480
process as a specific channel.
461481
"""
462-
return "%s.%s!%s" % (
463-
prefix,
464-
self.client_prefix,
465-
uuid.uuid4().hex,
466-
)
482+
return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"
467483

468484
### Flush extension ###
469485

@@ -496,9 +512,8 @@ async def close_pools(self):
496512
# Flush all cleaners, in case somebody just wanted to close the
497513
# pools without flushing first.
498514
await self.wait_received()
499-
500-
for index in self.pools:
501-
await self.pools[index].disconnect()
515+
for layer in self._layers.values():
516+
await layer.flush()
502517

503518
async def wait_received(self):
504519
"""
@@ -667,7 +682,7 @@ def _group_key(self, group):
667682
"""
668683
Common function to make the storage key for the group.
669684
"""
670-
return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
685+
return f"{self.prefix}:group:{group}".encode("utf8")
671686

672687
### Serialization ###
673688

@@ -711,7 +726,7 @@ def make_fernet(self, key):
711726
return Fernet(formatted_key)
712727

713728
def __str__(self):
714-
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
729+
return f"{self.__class__.__name__}(hosts={self.hosts})"
715730

716731
### Connection handling ###
717732

@@ -723,18 +738,14 @@ def connection(self, index):
723738
# Catch bad indexes
724739
if not 0 <= index < self.ring_size:
725740
raise ValueError(
726-
"There are only %s hosts - you asked for %s!" % (self.ring_size, index)
741+
f"There are only {self.ring_size} hosts - you asked for {index}!"
727742
)
728743

744+
loop = asyncio.get_running_loop()
729745
try:
730-
loop = asyncio.get_running_loop()
731-
if self.pools_loop != loop:
732-
self.pools = {}
733-
self.pools_loop = loop
734-
except RuntimeError:
735-
pass
736-
737-
if index not in self.pools:
738-
self.pools[index] = self.create_pool(index)
746+
layer = self._layers[loop]
747+
except KeyError:
748+
_wrap_close(self, loop)
749+
layer = self._layers[loop] = RedisLoopLayer(self)
739750

740-
return aioredis.Redis(connection_pool=self.pools[index])
751+
return layer.get_connection(index)

channels_redis/pubsub.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,16 @@
11
import asyncio
22
import functools
33
import logging
4-
import types
54
import uuid
65

76
import msgpack
87
from redis import asyncio as aioredis
98

10-
from .utils import _consistent_hash
9+
from .utils import _consistent_hash, _wrap_close
1110

1211
logger = logging.getLogger(__name__)
1312

1413

15-
def _wrap_close(proxy, loop):
16-
original_impl = loop.close
17-
18-
def _wrapper(self, *args, **kwargs):
19-
if loop in proxy._layers:
20-
layer = proxy._layers[loop]
21-
del proxy._layers[loop]
22-
loop.run_until_complete(layer.flush())
23-
24-
self.close = original_impl
25-
return self.close(*args, **kwargs)
26-
27-
loop.close = types.MethodType(_wrapper, loop)
28-
29-
3014
async def _async_proxy(obj, name, *args, **kwargs):
3115
# Must be defined as a function and not a method due to
3216
# https://bugs.python.org/issue38364

channels_redis/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import binascii
2+
import types
23

34

45
def _consistent_hash(value, ring_size):
@@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
1516
bigval = binascii.crc32(value) & 0xFFF
1617
ring_divisor = 4096 / float(ring_size)
1718
return int(bigval / ring_divisor)
19+
20+
21+
def _wrap_close(proxy, loop):
22+
original_impl = loop.close
23+
24+
def _wrapper(self, *args, **kwargs):
25+
if loop in proxy._layers:
26+
layer = proxy._layers[loop]
27+
del proxy._layers[loop]
28+
loop.run_until_complete(layer.flush())
29+
30+
self.close = original_impl
31+
return self.close(*args, **kwargs)
32+
33+
loop.close = types.MethodType(_wrapper, loop)

0 commit comments

Comments
 (0)