Skip to content

Commit 95a1950

Browse files
sevdogcarltongibson
authored andcommitted
Assure pools are closed on loop close in core (#332)
1 parent 89b29ad commit 95a1950

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

channels_redis/core.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from channels.exceptions import ChannelFull
1616
from channels.layers import BaseChannelLayer
1717

18-
from .utils import _consistent_hash
18+
from .utils import _consistent_hash, _wrap_close
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -69,6 +69,26 @@ def put_nowait(self, item):
6969
return super(BoundedQueue, self).put_nowait(item)
7070

7171

72+
class RedisLoopLayer:
73+
def __init__(self, channel_layer):
74+
self._lock = asyncio.Lock()
75+
self.channel_layer = channel_layer
76+
self._connections = {}
77+
78+
def get_connection(self, index):
79+
if index not in self._connections:
80+
pool = self.channel_layer.create_pool(index)
81+
self._connections[index] = aioredis.Redis(connection_pool=pool)
82+
83+
return self._connections[index]
84+
85+
async def flush(self):
86+
async with self._lock:
87+
for index in list(self._connections):
88+
connection = self._connections.pop(index)
89+
await connection.close(close_connection_pool=True)
90+
91+
7292
class RedisChannelLayer(BaseChannelLayer):
7393
"""
7494
Redis channel layer.
@@ -101,8 +121,7 @@ def __init__(
101121
self.hosts = self.decode_hosts(hosts)
102122
self.ring_size = len(self.hosts)
103123
# Cached redis connection pools and the event loop they are from
104-
self.pools = {}
105-
self.pools_loop = None
124+
self._layers = {}
106125
# Normal channels choose a host index by cycling through the available hosts
107126
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
108127
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -138,7 +157,7 @@ def create_pool(self, index):
138157
return aioredis.sentinel.SentinelConnectionPool(
139158
master_name,
140159
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
141-
**host
160+
**host,
142161
)
143162
else:
144163
return aioredis.ConnectionPool(**host)
@@ -331,7 +350,7 @@ async def receive(self, channel):
331350

332351
raise
333352

334-
message, token, exception = None, None, None
353+
message = token = exception = None
335354
for task in done:
336355
try:
337356
result = task.result()
@@ -367,7 +386,7 @@ async def receive(self, channel):
367386
message_channel, message = await self.receive_single(
368387
real_channel
369388
)
370-
if type(message_channel) is list:
389+
if isinstance(message_channel, list):
371390
for chan in message_channel:
372391
self.receive_buffer[chan].put_nowait(message)
373392
else:
@@ -459,11 +478,7 @@ async def new_channel(self, prefix="specific"):
459478
Returns a new channel name that can be used by something in our
460479
process as a specific channel.
461480
"""
462-
return "%s.%s!%s" % (
463-
prefix,
464-
self.client_prefix,
465-
uuid.uuid4().hex,
466-
)
481+
return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"
467482

468483
### Flush extension ###
469484

@@ -496,9 +511,8 @@ async def close_pools(self):
496511
# Flush all cleaners, in case somebody just wanted to close the
497512
# pools without flushing first.
498513
await self.wait_received()
499-
500-
for index in self.pools:
501-
await self.pools[index].disconnect()
514+
for layer in self._layers.values():
515+
await layer.flush()
502516

503517
async def wait_received(self):
504518
"""
@@ -667,7 +681,7 @@ def _group_key(self, group):
667681
"""
668682
Common function to make the storage key for the group.
669683
"""
670-
return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
684+
return f"{self.prefix}:group:{group}".encode("utf8")
671685

672686
### Serialization ###
673687

@@ -711,7 +725,7 @@ def make_fernet(self, key):
711725
return Fernet(formatted_key)
712726

713727
def __str__(self):
714-
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
728+
return f"{self.__class__.__name__}(hosts={self.hosts})"
715729

716730
### Connection handling ###
717731

@@ -723,18 +737,14 @@ def connection(self, index):
723737
# Catch bad indexes
724738
if not 0 <= index < self.ring_size:
725739
raise ValueError(
726-
"There are only %s hosts - you asked for %s!" % (self.ring_size, index)
740+
f"There are only {self.ring_size} hosts - you asked for {index}!"
727741
)
728742

743+
loop = asyncio.get_running_loop()
729744
try:
730-
loop = asyncio.get_running_loop()
731-
if self.pools_loop != loop:
732-
self.pools = {}
733-
self.pools_loop = loop
734-
except RuntimeError:
735-
pass
736-
737-
if index not in self.pools:
738-
self.pools[index] = self.create_pool(index)
745+
layer = self._layers[loop]
746+
except KeyError:
747+
_wrap_close(self, loop)
748+
layer = self._layers[loop] = RedisLoopLayer(self)
739749

740-
return aioredis.Redis(connection_pool=self.pools[index])
750+
return layer.get_connection(index)

channels_redis/pubsub.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,16 @@
11
import asyncio
22
import functools
33
import logging
4-
import types
54
import uuid
65

76
import msgpack
87
from redis import asyncio as aioredis
98

10-
from .utils import _consistent_hash
9+
from .utils import _consistent_hash, _wrap_close
1110

1211
logger = logging.getLogger(__name__)
1312

1413

15-
def _wrap_close(proxy, loop):
16-
original_impl = loop.close
17-
18-
def _wrapper(self, *args, **kwargs):
19-
if loop in proxy._layers:
20-
layer = proxy._layers[loop]
21-
del proxy._layers[loop]
22-
loop.run_until_complete(layer.flush())
23-
24-
self.close = original_impl
25-
return self.close(*args, **kwargs)
26-
27-
loop.close = types.MethodType(_wrapper, loop)
28-
29-
3014
async def _async_proxy(obj, name, *args, **kwargs):
3115
# Must be defined as a function and not a method due to
3216
# https://bugs.python.org/issue38364

channels_redis/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import binascii
2+
import types
23

34

45
def _consistent_hash(value, ring_size):
@@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
1516
bigval = binascii.crc32(value) & 0xFFF
1617
ring_divisor = 4096 / float(ring_size)
1718
return int(bigval / ring_divisor)
19+
20+
21+
def _wrap_close(proxy, loop):
22+
original_impl = loop.close
23+
24+
def _wrapper(self, *args, **kwargs):
25+
if loop in proxy._layers:
26+
layer = proxy._layers[loop]
27+
del proxy._layers[loop]
28+
loop.run_until_complete(layer.flush())
29+
30+
self.close = original_impl
31+
return self.close(*args, **kwargs)
32+
33+
loop.close = types.MethodType(_wrapper, loop)

0 commit comments

Comments
 (0)