15
15
from channels .exceptions import ChannelFull
16
16
from channels .layers import BaseChannelLayer
17
17
18
- from .utils import _consistent_hash
18
+ from .utils import _consistent_hash , _wrap_close
19
19
20
20
logger = logging .getLogger (__name__ )
21
21
@@ -69,6 +69,26 @@ def put_nowait(self, item):
69
69
return super (BoundedQueue , self ).put_nowait (item )
70
70
71
71
72
+ class RedisLoopLayer :
73
+ def __init__ (self , channel_layer ):
74
+ self ._lock = asyncio .Lock ()
75
+ self .channel_layer = channel_layer
76
+ self ._connections = {}
77
+
78
+ def get_connection (self , index ):
79
+ if index not in self ._connections :
80
+ pool = self .channel_layer .create_pool (index )
81
+ self ._connections [index ] = aioredis .Redis (connection_pool = pool )
82
+
83
+ return self ._connections [index ]
84
+
85
+ async def flush (self ):
86
+ async with self ._lock :
87
+ for index in list (self ._connections ):
88
+ connection = self ._connections .pop (index )
89
+ await connection .close (close_connection_pool = True )
90
+
91
+
72
92
class RedisChannelLayer (BaseChannelLayer ):
73
93
"""
74
94
Redis channel layer.
@@ -101,8 +121,7 @@ def __init__(
101
121
self .hosts = self .decode_hosts (hosts )
102
122
self .ring_size = len (self .hosts )
103
123
# Cached redis connection pools and the event loop they are from
104
- self .pools = {}
105
- self .pools_loop = None
124
+ self ._layers = {}
106
125
# Normal channels choose a host index by cycling through the available hosts
107
126
self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
108
127
self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -138,7 +157,7 @@ def create_pool(self, index):
138
157
return aioredis .sentinel .SentinelConnectionPool (
139
158
master_name ,
140
159
aioredis .sentinel .Sentinel (sentinels , sentinel_kwargs = sentinel_kwargs ),
141
- ** host
160
+ ** host ,
142
161
)
143
162
else :
144
163
return aioredis .ConnectionPool (** host )
@@ -331,7 +350,7 @@ async def receive(self, channel):
331
350
332
351
raise
333
352
334
- message , token , exception = None , None , None
353
+ message = token = exception = None
335
354
for task in done :
336
355
try :
337
356
result = task .result ()
@@ -367,7 +386,7 @@ async def receive(self, channel):
367
386
message_channel , message = await self .receive_single (
368
387
real_channel
369
388
)
370
- if type (message_channel ) is list :
389
+ if isinstance (message_channel , list ) :
371
390
for chan in message_channel :
372
391
self .receive_buffer [chan ].put_nowait (message )
373
392
else :
@@ -459,11 +478,7 @@ async def new_channel(self, prefix="specific"):
459
478
Returns a new channel name that can be used by something in our
460
479
process as a specific channel.
461
480
"""
462
- return "%s.%s!%s" % (
463
- prefix ,
464
- self .client_prefix ,
465
- uuid .uuid4 ().hex ,
466
- )
481
+ return f"{ prefix } .{ self .client_prefix } !{ uuid .uuid4 ().hex } "
467
482
468
483
### Flush extension ###
469
484
@@ -496,9 +511,8 @@ async def close_pools(self):
496
511
# Flush all cleaners, in case somebody just wanted to close the
497
512
# pools without flushing first.
498
513
await self .wait_received ()
499
-
500
- for index in self .pools :
501
- await self .pools [index ].disconnect ()
514
+ for layer in self ._layers .values ():
515
+ await layer .flush ()
502
516
503
517
async def wait_received (self ):
504
518
"""
@@ -667,7 +681,7 @@ def _group_key(self, group):
667
681
"""
668
682
Common function to make the storage key for the group.
669
683
"""
670
- return ( "%s:group:%s" % ( self .prefix , group )) .encode ("utf8" )
684
+ return f" { self .prefix } : group: { group } " .encode ("utf8" )
671
685
672
686
### Serialization ###
673
687
@@ -711,7 +725,7 @@ def make_fernet(self, key):
711
725
return Fernet (formatted_key )
712
726
713
727
def __str__ (self ):
714
- return "%s(hosts=%s)" % ( self .__class__ .__name__ , self .hosts )
728
+ return f" { self .__class__ .__name__ } (hosts= { self .hosts } )"
715
729
716
730
### Connection handling ###
717
731
@@ -723,18 +737,14 @@ def connection(self, index):
723
737
# Catch bad indexes
724
738
if not 0 <= index < self .ring_size :
725
739
raise ValueError (
726
- "There are only %s hosts - you asked for %s!" % ( self . ring_size , index )
740
+ f "There are only { self . ring_size } hosts - you asked for { index } !"
727
741
)
728
742
743
+ loop = asyncio .get_running_loop ()
729
744
try :
730
- loop = asyncio .get_running_loop ()
731
- if self .pools_loop != loop :
732
- self .pools = {}
733
- self .pools_loop = loop
734
- except RuntimeError :
735
- pass
736
-
737
- if index not in self .pools :
738
- self .pools [index ] = self .create_pool (index )
745
+ layer = self ._layers [loop ]
746
+ except KeyError :
747
+ _wrap_close (self , loop )
748
+ layer = self ._layers [loop ] = RedisLoopLayer (self )
739
749
740
- return aioredis . Redis ( connection_pool = self . pools [ index ] )
750
+ return layer . get_connection ( index )
0 commit comments