1515from channels .exceptions import ChannelFull
1616from channels .layers import BaseChannelLayer
1717
18- from .utils import _consistent_hash
18+ from .utils import _consistent_hash , _wrap_close
1919
2020logger = logging .getLogger (__name__ )
2121
@@ -69,6 +69,26 @@ def put_nowait(self, item):
6969 return super (BoundedQueue , self ).put_nowait (item )
7070
7171
72+ class RedisLoopLayer :
73+ def __init__ (self , channel_layer ):
74+ self ._lock = asyncio .Lock ()
75+ self .channel_layer = channel_layer
76+ self ._connections = {}
77+
78+ def get_connection (self , index ):
79+ if index not in self ._connections :
80+ pool = self .channel_layer .create_pool (index )
81+ self ._connections [index ] = aioredis .Redis (connection_pool = pool )
82+
83+ return self ._connections [index ]
84+
85+ async def flush (self ):
86+ async with self ._lock :
87+ for index in list (self ._connections ):
88+ connection = self ._connections .pop (index )
89+ await connection .close (close_connection_pool = True )
90+
91+
7292class RedisChannelLayer (BaseChannelLayer ):
7393 """
7494 Redis channel layer.
@@ -103,8 +123,7 @@ def __init__(
103123 self .hosts = self .decode_hosts (hosts )
104124 self .ring_size = len (self .hosts )
105125 # Cached redis connection pools and the event loop they are from
106- self .pools = {}
107- self .pools_loop = None
126+ self ._layers = {}
108127 # Normal channels choose a host index by cycling through the available hosts
109128 self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
110129 self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -140,7 +159,7 @@ def create_pool(self, index):
140159 return aioredis .sentinel .SentinelConnectionPool (
141160 master_name ,
142161 aioredis .sentinel .Sentinel (sentinels , sentinel_kwargs = sentinel_kwargs ),
143- ** host
162+ ** host ,
144163 )
145164 else :
146165 return aioredis .ConnectionPool (** host )
@@ -333,7 +352,7 @@ async def receive(self, channel):
333352
334353 raise
335354
336- message , token , exception = None , None , None
355+ message = token = exception = None
337356 for task in done :
338357 try :
339358 result = task .result ()
@@ -369,7 +388,7 @@ async def receive(self, channel):
369388 message_channel , message = await self .receive_single (
370389 real_channel
371390 )
372- if type (message_channel ) is list :
391+ if isinstance (message_channel , list ) :
373392 for chan in message_channel :
374393 self .receive_buffer [chan ].put_nowait (message )
375394 else :
@@ -461,11 +480,7 @@ async def new_channel(self, prefix="specific"):
461480 Returns a new channel name that can be used by something in our
462481 process as a specific channel.
463482 """
464- return "%s.%s!%s" % (
465- prefix ,
466- self .client_prefix ,
467- uuid .uuid4 ().hex ,
468- )
483+ return f"{ prefix } .{ self .client_prefix } !{ uuid .uuid4 ().hex } "
469484
470485 ### Flush extension ###
471486
@@ -498,9 +513,8 @@ async def close_pools(self):
498513 # Flush all cleaners, in case somebody just wanted to close the
499514 # pools without flushing first.
500515 await self .wait_received ()
501-
502- for index in self .pools :
503- await self .pools [index ].disconnect ()
516+ for layer in self ._layers .values ():
517+ await layer .flush ()
504518
505519 async def wait_received (self ):
506520 """
@@ -703,7 +717,7 @@ def _group_key(self, group):
703717 """
704718 Common function to make the storage key for the group.
705719 """
706- return ( "%s:group:%s" % ( self .prefix , group )) .encode ("utf8" )
720+ return f" { self .prefix } : group: { group } " .encode ("utf8" )
707721
708722 ### Serialization ###
709723
@@ -747,7 +761,7 @@ def make_fernet(self, key):
747761 return Fernet (formatted_key )
748762
749763 def __str__ (self ):
750- return "%s(hosts=%s)" % ( self .__class__ .__name__ , self .hosts )
764+ return f" { self .__class__ .__name__ } (hosts= { self .hosts } )"
751765
752766 ### Connection handling ###
753767
@@ -759,18 +773,14 @@ def connection(self, index):
759773 # Catch bad indexes
760774 if not 0 <= index < self .ring_size :
761775 raise ValueError (
762- "There are only %s hosts - you asked for %s!" % ( self . ring_size , index )
776+ f "There are only { self . ring_size } hosts - you asked for { index } !"
763777 )
764778
779+ loop = asyncio .get_running_loop ()
765780 try :
766- loop = asyncio .get_running_loop ()
767- if self .pools_loop != loop :
768- self .pools = {}
769- self .pools_loop = loop
770- except RuntimeError :
771- pass
772-
773- if index not in self .pools :
774- self .pools [index ] = self .create_pool (index )
781+ layer = self ._layers [loop ]
782+ except KeyError :
783+ _wrap_close (self , loop )
784+ layer = self ._layers [loop ] = RedisLoopLayer (self )
775785
776- return aioredis . Redis ( connection_pool = self . pools [ index ] )
786+ return layer . get_connection ( index )
0 commit comments