15
15
from channels .exceptions import ChannelFull
16
16
from channels .layers import BaseChannelLayer
17
17
18
- from .utils import _consistent_hash
18
+ from .utils import _consistent_hash , _wrap_close
19
19
20
20
logger = logging .getLogger (__name__ )
21
21
@@ -69,6 +69,27 @@ def put_nowait(self, item):
69
69
return super (BoundedQueue , self ).put_nowait (item )
70
70
71
71
72
+ class RedisLoopLayer :
73
+
74
+ def __init__ (self , channel_layer ):
75
+ self ._lock = asyncio .Lock ()
76
+ self .channel_layer = channel_layer
77
+ self ._connections = {}
78
+
79
+ def get_connection (self , index ):
80
+ if index not in self ._connections :
81
+ pool = self .channel_layer .create_pool (index )
82
+ self ._connections [index ] = aioredis .Redis (connection_pool = pool )
83
+
84
+ return self ._connections [index ]
85
+
86
+ async def flush (self ):
87
+ async with self ._lock :
88
+ for index in list (self ._connections ):
89
+ connection = self ._connections .pop (index )
90
+ await connection .close (close_connection_pool = True )
91
+
92
+
72
93
class RedisChannelLayer (BaseChannelLayer ):
73
94
"""
74
95
Redis channel layer.
@@ -101,8 +122,7 @@ def __init__(
101
122
self .hosts = self .decode_hosts (hosts )
102
123
self .ring_size = len (self .hosts )
103
124
# Cached redis connection pools and the event loop they are from
104
- self .pools = {}
105
- self .pools_loop = None
125
+ self ._layers = {}
106
126
# Normal channels choose a host index by cycling through the available hosts
107
127
self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
108
128
self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -331,7 +351,7 @@ async def receive(self, channel):
331
351
332
352
raise
333
353
334
- message , token , exception = None , None , None
354
+ message = token = exception = None
335
355
for task in done :
336
356
try :
337
357
result = task .result ()
@@ -367,7 +387,7 @@ async def receive(self, channel):
367
387
message_channel , message = await self .receive_single (
368
388
real_channel
369
389
)
370
- if type (message_channel ) is list :
390
+ if isinstance (message_channel , list ) :
371
391
for chan in message_channel :
372
392
self .receive_buffer [chan ].put_nowait (message )
373
393
else :
@@ -459,11 +479,7 @@ async def new_channel(self, prefix="specific"):
459
479
Returns a new channel name that can be used by something in our
460
480
process as a specific channel.
461
481
"""
462
- return "%s.%s!%s" % (
463
- prefix ,
464
- self .client_prefix ,
465
- uuid .uuid4 ().hex ,
466
- )
482
+ return f"{ prefix } .{ self .client_prefix } !{ uuid .uuid4 ().hex } "
467
483
468
484
### Flush extension ###
469
485
@@ -496,9 +512,8 @@ async def close_pools(self):
496
512
# Flush all cleaners, in case somebody just wanted to close the
497
513
# pools without flushing first.
498
514
await self .wait_received ()
499
-
500
- for index in self .pools :
501
- await self .pools [index ].disconnect ()
515
+ for layer in self ._layers .values ():
516
+ await layer .flush ()
502
517
503
518
async def wait_received (self ):
504
519
"""
@@ -667,7 +682,7 @@ def _group_key(self, group):
667
682
"""
668
683
Common function to make the storage key for the group.
669
684
"""
670
- return ( "%s:group:%s" % ( self .prefix , group )) .encode ("utf8" )
685
+ return f" { self .prefix } : group: { group } " .encode ("utf8" )
671
686
672
687
### Serialization ###
673
688
@@ -711,7 +726,7 @@ def make_fernet(self, key):
711
726
return Fernet (formatted_key )
712
727
713
728
def __str__ (self ):
714
- return "%s(hosts=%s)" % ( self .__class__ .__name__ , self .hosts )
729
+ return f" { self .__class__ .__name__ } (hosts= { self .hosts } )"
715
730
716
731
### Connection handling ###
717
732
@@ -723,18 +738,14 @@ def connection(self, index):
723
738
# Catch bad indexes
724
739
if not 0 <= index < self .ring_size :
725
740
raise ValueError (
726
- "There are only %s hosts - you asked for %s!" % ( self . ring_size , index )
741
+ f "There are only { self . ring_size } hosts - you asked for { index } !"
727
742
)
728
743
744
+ loop = asyncio .get_running_loop ()
729
745
try :
730
- loop = asyncio .get_running_loop ()
731
- if self .pools_loop != loop :
732
- self .pools = {}
733
- self .pools_loop = loop
734
- except RuntimeError :
735
- pass
736
-
737
- if index not in self .pools :
738
- self .pools [index ] = self .create_pool (index )
746
+ layer = self ._layers [loop ]
747
+ except KeyError :
748
+ _wrap_close (self , loop )
749
+ layer = self ._layers [loop ] = RedisLoopLayer (self )
739
750
740
- return aioredis . Redis ( connection_pool = self . pools [ index ] )
751
+ return layer . get_connection ( index )
0 commit comments