Skip to content

Commit 38d74dd

Browse files
sevdogShi Feng
authored andcommitted
Refactored Redis connection utilities to share between layers. (django#352)
1 parent 40c1fb1 commit 38d74dd

File tree

3 files changed

+61
-65
lines changed

3 files changed

+61
-65
lines changed

channels_redis/core.py

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from channels.exceptions import ChannelFull
1616
from channels.layers import BaseChannelLayer
1717

18-
from .utils import _consistent_hash, _wrap_close
18+
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -120,7 +120,7 @@ def __init__(
120120
self.should_auto_discard_full_channels = should_auto_discard_full_channels
121121
assert isinstance(self.prefix, str), "Prefix must be unicode"
122122
# Configure the host objects
123-
self.hosts = self.decode_hosts(hosts)
123+
self.hosts = decode_hosts(hosts)
124124
self.ring_size = len(self.hosts)
125125
# Cached redis connection pools and the event loop they are from
126126
self._layers = {}
@@ -148,46 +148,7 @@ def __init__(
148148
self.receive_clean_locks = ChannelLock()
149149

150150
def create_pool(self, index):
151-
host = self.hosts[index]
152-
153-
if "address" in host:
154-
return aioredis.ConnectionPool.from_url(host["address"])
155-
elif "master_name" in host:
156-
sentinels = host.pop("sentinels")
157-
master_name = host.pop("master_name")
158-
sentinel_kwargs = host.pop("sentinel_kwargs", None)
159-
return aioredis.sentinel.SentinelConnectionPool(
160-
master_name,
161-
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
162-
**host,
163-
)
164-
else:
165-
return aioredis.ConnectionPool(**host)
166-
167-
def decode_hosts(self, hosts):
168-
"""
169-
Takes the value of the "hosts" argument passed to the class and returns
170-
a list of kwargs to use for the Redis connection constructor.
171-
"""
172-
# If no hosts were provided, return a default value
173-
if not hosts:
174-
return [{"address": "redis://localhost:6379"}]
175-
# If they provided just a string, scold them.
176-
if isinstance(hosts, (str, bytes)):
177-
raise ValueError(
178-
"You must pass a list of Redis hosts, even if there is only one."
179-
)
180-
181-
# Decode each hosts entry into a kwargs dict
182-
result = []
183-
for entry in hosts:
184-
if isinstance(entry, dict):
185-
result.append(entry)
186-
elif isinstance(entry, tuple):
187-
result.append({"host": entry[0], "port": entry[1]})
188-
else:
189-
result.append({"address": entry})
190-
return result
151+
return create_pool(self.hosts[index])
191152

192153
def _setup_encryption(self, symmetric_encryption_keys):
193154
# See if we can do encryption if they asked

channels_redis/pubsub.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import msgpack
77
from redis import asyncio as aioredis
88

9-
from .utils import _consistent_hash, _wrap_close
9+
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -81,12 +81,6 @@ def __init__(
8181
channel_layer=None,
8282
**kwargs,
8383
):
84-
if hosts is None:
85-
hosts = ["redis://localhost:6379"]
86-
assert (
87-
isinstance(hosts, list) and len(hosts) > 0
88-
), "`hosts` must be a list with at least one Redis server"
89-
9084
self.prefix = prefix
9185

9286
self.on_disconnect = on_disconnect
@@ -102,7 +96,9 @@ def __init__(
10296
self.groups = {}
10397

10498
# For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
105-
self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
99+
self._shards = [
100+
RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)
101+
]
106102

107103
def _get_shard(self, channel_or_group_name):
108104
"""
@@ -247,9 +243,7 @@ async def flush(self):
247243

248244
class RedisSingleShardConnection:
249245
def __init__(self, host, channel_layer):
250-
self.host = host.copy() if type(host) is dict else {"address": host}
251-
self.master_name = self.host.pop("master_name", None)
252-
self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None)
246+
self.host = host
253247
self.channel_layer = channel_layer
254248
self._subscribed_to = set()
255249
self._lock = asyncio.Lock()
@@ -331,18 +325,7 @@ def _receive_message(self, message):
331325

332326
def _ensure_redis(self):
333327
if self._redis is None:
334-
if self.master_name is None:
335-
pool = aioredis.ConnectionPool.from_url(self.host["address"])
336-
else:
337-
# aioredis default timeout is way too low
338-
pool = aioredis.sentinel.SentinelConnectionPool(
339-
self.master_name,
340-
aioredis.sentinel.Sentinel(
341-
self.host["sentinels"],
342-
socket_timeout=2,
343-
sentinel_kwargs=self.sentinel_kwargs,
344-
),
345-
)
328+
pool = create_pool(self.host)
346329
self._redis = aioredis.Redis(connection_pool=pool)
347330
self._pubsub = self._redis.pubsub()
348331

channels_redis/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import binascii
22
import types
33

4+
from redis import asyncio as aioredis
5+
46

57
def _consistent_hash(value, ring_size):
68
"""
@@ -31,3 +33,53 @@ def _wrapper(self, *args, **kwargs):
3133
return self.close(*args, **kwargs)
3234

3335
loop.close = types.MethodType(_wrapper, loop)
36+
37+
38+
def decode_hosts(hosts):
39+
"""
40+
Takes the value of the "hosts" argument and returns
41+
a list of kwargs to use for the Redis connection constructor.
42+
"""
43+
# If no hosts were provided, return a default value
44+
if not hosts:
45+
return [{"address": "redis://localhost:6379"}]
46+
# If they provided just a string, scold them.
47+
if isinstance(hosts, (str, bytes)):
48+
raise ValueError(
49+
"You must pass a list of Redis hosts, even if there is only one."
50+
)
51+
52+
# Decode each hosts entry into a kwargs dict
53+
result = []
54+
for entry in hosts:
55+
if isinstance(entry, dict):
56+
result.append(entry)
57+
elif isinstance(entry, (tuple, list)):
58+
result.append({"host": entry[0], "port": entry[1]})
59+
else:
60+
result.append({"address": entry})
61+
return result
62+
63+
64+
def create_pool(host):
65+
"""
66+
Takes the value of the "host" argument and returns a suited connection pool to
67+
the corresponding redis instance.
68+
"""
69+
# avoid side-effects from modifying host
70+
host = host.copy()
71+
if "address" in host:
72+
address = host.pop("address")
73+
return aioredis.ConnectionPool.from_url(address, **host)
74+
75+
master_name = host.pop("master_name", None)
76+
if master_name is not None:
77+
sentinels = host.pop("sentinels")
78+
sentinel_kwargs = host.pop("sentinel_kwargs", None)
79+
return aioredis.sentinel.SentinelConnectionPool(
80+
master_name,
81+
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
82+
**host
83+
)
84+
85+
return aioredis.ConnectionPool(**host)

0 commit comments

Comments
 (0)