Skip to content

Commit 40c1fb1

Browse files
sevdogShi Feng
authored andcommitted
Assured pools are closed on loop close in core (django#347)
1 parent 13ead0b commit 40c1fb1

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

channels_redis/core.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from channels.exceptions import ChannelFull
1616
from channels.layers import BaseChannelLayer
1717

18-
from .utils import _consistent_hash
18+
from .utils import _consistent_hash, _wrap_close
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -69,6 +69,26 @@ def put_nowait(self, item):
6969
return super(BoundedQueue, self).put_nowait(item)
7070

7171

72+
class RedisLoopLayer:
73+
def __init__(self, channel_layer):
74+
self._lock = asyncio.Lock()
75+
self.channel_layer = channel_layer
76+
self._connections = {}
77+
78+
def get_connection(self, index):
79+
if index not in self._connections:
80+
pool = self.channel_layer.create_pool(index)
81+
self._connections[index] = aioredis.Redis(connection_pool=pool)
82+
83+
return self._connections[index]
84+
85+
async def flush(self):
86+
async with self._lock:
87+
for index in list(self._connections):
88+
connection = self._connections.pop(index)
89+
await connection.close(close_connection_pool=True)
90+
91+
7292
class RedisChannelLayer(BaseChannelLayer):
7393
"""
7494
Redis channel layer.
@@ -103,8 +123,7 @@ def __init__(
103123
self.hosts = self.decode_hosts(hosts)
104124
self.ring_size = len(self.hosts)
105125
# Cached redis connection pools and the event loop they are from
106-
self.pools = {}
107-
self.pools_loop = None
126+
self._layers = {}
108127
# Normal channels choose a host index by cycling through the available hosts
109128
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
110129
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -140,7 +159,7 @@ def create_pool(self, index):
140159
return aioredis.sentinel.SentinelConnectionPool(
141160
master_name,
142161
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
143-
**host
162+
**host,
144163
)
145164
else:
146165
return aioredis.ConnectionPool(**host)
@@ -333,7 +352,7 @@ async def receive(self, channel):
333352

334353
raise
335354

336-
message, token, exception = None, None, None
355+
message = token = exception = None
337356
for task in done:
338357
try:
339358
result = task.result()
@@ -369,7 +388,7 @@ async def receive(self, channel):
369388
message_channel, message = await self.receive_single(
370389
real_channel
371390
)
372-
if type(message_channel) is list:
391+
if isinstance(message_channel, list):
373392
for chan in message_channel:
374393
self.receive_buffer[chan].put_nowait(message)
375394
else:
@@ -461,11 +480,7 @@ async def new_channel(self, prefix="specific"):
461480
Returns a new channel name that can be used by something in our
462481
process as a specific channel.
463482
"""
464-
return "%s.%s!%s" % (
465-
prefix,
466-
self.client_prefix,
467-
uuid.uuid4().hex,
468-
)
483+
return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"
469484

470485
### Flush extension ###
471486

@@ -498,9 +513,8 @@ async def close_pools(self):
498513
# Flush all cleaners, in case somebody just wanted to close the
499514
# pools without flushing first.
500515
await self.wait_received()
501-
502-
for index in self.pools:
503-
await self.pools[index].disconnect()
516+
for layer in self._layers.values():
517+
await layer.flush()
504518

505519
async def wait_received(self):
506520
"""
@@ -703,7 +717,7 @@ def _group_key(self, group):
703717
"""
704718
Common function to make the storage key for the group.
705719
"""
706-
return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
720+
return f"{self.prefix}:group:{group}".encode("utf8")
707721

708722
### Serialization ###
709723

@@ -747,7 +761,7 @@ def make_fernet(self, key):
747761
return Fernet(formatted_key)
748762

749763
def __str__(self):
750-
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
764+
return f"{self.__class__.__name__}(hosts={self.hosts})"
751765

752766
### Connection handling ###
753767

@@ -759,18 +773,14 @@ def connection(self, index):
759773
# Catch bad indexes
760774
if not 0 <= index < self.ring_size:
761775
raise ValueError(
762-
"There are only %s hosts - you asked for %s!" % (self.ring_size, index)
776+
f"There are only {self.ring_size} hosts - you asked for {index}!"
763777
)
764778

779+
loop = asyncio.get_running_loop()
765780
try:
766-
loop = asyncio.get_running_loop()
767-
if self.pools_loop != loop:
768-
self.pools = {}
769-
self.pools_loop = loop
770-
except RuntimeError:
771-
pass
772-
773-
if index not in self.pools:
774-
self.pools[index] = self.create_pool(index)
781+
layer = self._layers[loop]
782+
except KeyError:
783+
_wrap_close(self, loop)
784+
layer = self._layers[loop] = RedisLoopLayer(self)
775785

776-
return aioredis.Redis(connection_pool=self.pools[index])
786+
return layer.get_connection(index)

channels_redis/pubsub.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,16 @@
11
import asyncio
22
import functools
33
import logging
4-
import types
54
import uuid
65

76
import msgpack
87
from redis import asyncio as aioredis
98

10-
from .utils import _consistent_hash
9+
from .utils import _consistent_hash, _wrap_close
1110

1211
logger = logging.getLogger(__name__)
1312

1413

15-
def _wrap_close(proxy, loop):
16-
original_impl = loop.close
17-
18-
def _wrapper(self, *args, **kwargs):
19-
if loop in proxy._layers:
20-
layer = proxy._layers[loop]
21-
del proxy._layers[loop]
22-
loop.run_until_complete(layer.flush())
23-
24-
self.close = original_impl
25-
return self.close(*args, **kwargs)
26-
27-
loop.close = types.MethodType(_wrapper, loop)
28-
29-
3014
async def _async_proxy(obj, name, *args, **kwargs):
3115
# Must be defined as a function and not a method due to
3216
# https://bugs.python.org/issue38364

channels_redis/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import binascii
2+
import types
23

34

45
def _consistent_hash(value, ring_size):
@@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
1516
bigval = binascii.crc32(value) & 0xFFF
1617
ring_divisor = 4096 / float(ring_size)
1718
return int(bigval / ring_divisor)
19+
20+
21+
def _wrap_close(proxy, loop):
22+
original_impl = loop.close
23+
24+
def _wrapper(self, *args, **kwargs):
25+
if loop in proxy._layers:
26+
layer = proxy._layers[loop]
27+
del proxy._layers[loop]
28+
loop.run_until_complete(layer.flush())
29+
30+
self.close = original_impl
31+
return self.close(*args, **kwargs)
32+
33+
loop.close = types.MethodType(_wrapper, loop)

0 commit comments

Comments
 (0)