1515from channels .exceptions import ChannelFull
1616from channels .layers import BaseChannelLayer
1717
18- from .utils import _consistent_hash
18+ from .utils import _consistent_hash , _wrap_close
1919
2020logger = logging .getLogger (__name__ )
2121
@@ -69,6 +69,26 @@ def put_nowait(self, item):
6969 return super (BoundedQueue , self ).put_nowait (item )
7070
7171
72+ class RedisLoopLayer :
73+ def __init__ (self , channel_layer ):
74+ self ._lock = asyncio .Lock ()
75+ self .channel_layer = channel_layer
76+ self ._connections = {}
77+
78+ def get_connection (self , index ):
79+ if index not in self ._connections :
80+ pool = self .channel_layer .create_pool (index )
81+ self ._connections [index ] = aioredis .Redis (connection_pool = pool )
82+
83+ return self ._connections [index ]
84+
85+ async def flush (self ):
86+ async with self ._lock :
87+ for index in list (self ._connections ):
88+ connection = self ._connections .pop (index )
89+ await connection .close (close_connection_pool = True )
90+
91+
7292class RedisChannelLayer (BaseChannelLayer ):
7393 """
7494 Redis channel layer.
@@ -101,8 +121,7 @@ def __init__(
101121 self .hosts = self .decode_hosts (hosts )
102122 self .ring_size = len (self .hosts )
103123 # Cached redis connection pools and the event loop they are from
104- self .pools = {}
105- self .pools_loop = None
124+ self ._layers = {}
106125 # Normal channels choose a host index by cycling through the available hosts
107126 self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
108127 self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -138,7 +157,7 @@ def create_pool(self, index):
138157 return aioredis .sentinel .SentinelConnectionPool (
139158 master_name ,
140159 aioredis .sentinel .Sentinel (sentinels , sentinel_kwargs = sentinel_kwargs ),
141- ** host
160+ ** host ,
142161 )
143162 else :
144163 return aioredis .ConnectionPool (** host )
@@ -331,7 +350,7 @@ async def receive(self, channel):
331350
332351 raise
333352
334- message , token , exception = None , None , None
353+ message = token = exception = None
335354 for task in done :
336355 try :
337356 result = task .result ()
@@ -367,7 +386,7 @@ async def receive(self, channel):
367386 message_channel , message = await self .receive_single (
368387 real_channel
369388 )
370- if type (message_channel ) is list :
389+ if isinstance (message_channel , list ) :
371390 for chan in message_channel :
372391 self .receive_buffer [chan ].put_nowait (message )
373392 else :
@@ -459,11 +478,7 @@ async def new_channel(self, prefix="specific"):
459478 Returns a new channel name that can be used by something in our
460479 process as a specific channel.
461480 """
462- return "%s.%s!%s" % (
463- prefix ,
464- self .client_prefix ,
465- uuid .uuid4 ().hex ,
466- )
481+ return f"{ prefix } .{ self .client_prefix } !{ uuid .uuid4 ().hex } "
467482
468483 ### Flush extension ###
469484
@@ -496,9 +511,8 @@ async def close_pools(self):
496511 # Flush all cleaners, in case somebody just wanted to close the
497512 # pools without flushing first.
498513 await self .wait_received ()
499-
500- for index in self .pools :
501- await self .pools [index ].disconnect ()
514+ for layer in self ._layers .values ():
515+ await layer .flush ()
502516
503517 async def wait_received (self ):
504518 """
@@ -667,7 +681,7 @@ def _group_key(self, group):
667681 """
668682 Common function to make the storage key for the group.
669683 """
670- return ( "%s:group:%s" % ( self .prefix , group )) .encode ("utf8" )
684+ return f" { self .prefix } : group: { group } " .encode ("utf8" )
671685
672686 ### Serialization ###
673687
@@ -711,7 +725,7 @@ def make_fernet(self, key):
711725 return Fernet (formatted_key )
712726
713727 def __str__ (self ):
714- return "%s(hosts=%s)" % ( self .__class__ .__name__ , self .hosts )
728+ return f" { self .__class__ .__name__ } (hosts= { self .hosts } )"
715729
716730 ### Connection handling ###
717731
@@ -723,18 +737,14 @@ def connection(self, index):
723737 # Catch bad indexes
724738 if not 0 <= index < self .ring_size :
725739 raise ValueError (
726- "There are only %s hosts - you asked for %s!" % ( self . ring_size , index )
740+ f "There are only { self . ring_size } hosts - you asked for { index } !"
727741 )
728742
743+ loop = asyncio .get_running_loop ()
729744 try :
730- loop = asyncio .get_running_loop ()
731- if self .pools_loop != loop :
732- self .pools = {}
733- self .pools_loop = loop
734- except RuntimeError :
735- pass
736-
737- if index not in self .pools :
738- self .pools [index ] = self .create_pool (index )
745+ layer = self ._layers [loop ]
746+ except KeyError :
747+ _wrap_close (self , loop )
748+ layer = self ._layers [loop ] = RedisLoopLayer (self )
739749
740- return aioredis . Redis ( connection_pool = self . pools [ index ] )
750+ return layer . get_connection ( index )
0 commit comments