diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 477a198..05eb897 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -3,7 +3,7 @@ name: Tests
 on:
   push:
     branches:
-      - master
+      - main
   pull_request:
 
 jobs:
diff --git a/README.rst b/README.rst
index e09af6f..a0fd130 100644
--- a/README.rst
+++ b/README.rst
@@ -31,6 +31,17 @@ Set up the channel layer in your Django settings file like so::
     },
 }
 
+Or, you can use the alternate implementation which uses Redis Pub/Sub::
+
+    CHANNEL_LAYERS = {
+        "default": {
+            "BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
+            "CONFIG": {
+                "hosts": [("localhost", 6379)],
+            },
+        },
+    }
+
 Possible options for ``CONFIG`` are listed below.
 
 ``hosts``
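As context for reviewers: once configured as above, the new backend is driven through the normal Channels APIs. A minimal consumer sketch (``ChatConsumer`` and the ``"chat.room"`` group are hypothetical names, not part of this PR)::

    from channels.generic.websocket import AsyncJsonWebsocketConsumer

    class ChatConsumer(AsyncJsonWebsocketConsumer):
        async def connect(self):
            # group_add() subscribes this consumer's specific channel to the group.
            await self.channel_layer.group_add("chat.room", self.channel_name)
            await self.accept()

        async def disconnect(self, close_code):
            await self.channel_layer.group_discard("chat.room", self.channel_name)

        async def chat_message(self, event):
            # Invoked for group_send("chat.room", {"type": "chat.message", ...}).
            await self.send_json(event)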
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py
new file mode 100644
index 0000000..16c1d5f
--- /dev/null
+++ b/channels_redis/pubsub.py
@@ -0,0 +1,347 @@
+import asyncio
+import logging
+import uuid
+
+import aioredis
+import msgpack
+
+logger = logging.getLogger(__name__)
+
+
+class RedisPubSubChannelLayer:
+    """
+    Channel Layer that uses Redis's pub/sub functionality.
+    """
+
+    def __init__(self, hosts=None, prefix="asgi", **kwargs):
+        if hosts is None:
+            hosts = [("localhost", 6379)]
+        assert (
+            isinstance(hosts, list) and len(hosts) > 0
+        ), "`hosts` must be a list with at least one Redis server"
+
+        self.prefix = prefix
+
+        # Each consumer gets its own *specific* channel, created with the `new_channel()` method.
+        # This dict maps `channel_name` to a queue of messages for that channel.
+        self.channels = {}
+
+        # A channel can subscribe to zero or more groups.
+        # This dict maps `group_name` to the set of channel names that are subscribed to that group.
+        self.groups = {}
+
+        # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
+        self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
+
+    def _get_shard(self, channel_or_group_name):
+        """
+        Return the shard that is used exclusively for this channel or group.
+        """
+        if len(self._shards) == 1:
+            # Avoid the overhead of hashing and modulo when it is unnecessary.
+            return self._shards[0]
+        shard_index = abs(hash(channel_or_group_name)) % len(self._shards)
+        return self._shards[shard_index]
+
+    def _get_group_channel_name(self, group):
+        """
+        Return the channel name used by a group.
+        Includes '__group__' in the returned string so that these names are
+        distinguished from those returned by `new_channel()`.
+        Technically collisions are possible, but it would take what I believe
+        is intentional abuse to produce colliding names.
+        """
+        return f"{self.prefix}__group__{group}"
+
+    extensions = ["groups", "flush"]
+
+    ################################################################################
+    # Channel layer API
+    ################################################################################
+
+    async def send(self, channel, message):
+        """
+        Send a message onto a (general or specific) channel.
+        """
+        shard = self._get_shard(channel)
+        await shard.publish(channel, message)
+
+    async def new_channel(self, prefix="specific."):
+        """
+        Returns a new channel name that can be used by a consumer in our
+        process as a specific channel.
+        """
+        channel = f"{self.prefix}{prefix}{uuid.uuid4().hex}"
+        self.channels[channel] = asyncio.Queue()
+        shard = self._get_shard(channel)
+        await shard.subscribe(channel)
+        return channel
+
+    async def receive(self, channel):
+        """
+        Receive the first message that arrives on the channel.
+        If more than one coroutine waits on the same channel, a random one
+        of the waiting coroutines will get the result.
+        """
+        if channel not in self.channels:
+            raise RuntimeError(
+                'You should only call receive() on channels that you "own" and that were created with `new_channel()`.'
+            )
+
+        q = self.channels[channel]
+
+        try:
+            message = await q.get()
+        except asyncio.CancelledError:
+            # We assume here that the reason we are cancelled is because the consumer
+            # is exiting, therefore we need to clean up by unsubscribing below. Indeed,
+            # given the way that Django Channels currently works, this is a safe assumption.
+            # In the future, Django Channels could change to call a *new* method that
+            # would serve as the antithesis of `new_channel()`; this new method might
+            # be named `delete_channel()`. If that were the case, we would do the
+            # following cleanup from that new `delete_channel()` method, but, since
+            # that's not how Django Channels works (yet), we do the cleanup below:
+            if channel in self.channels:
+                del self.channels[channel]
+            try:
+                shard = self._get_shard(channel)
+                await shard.unsubscribe(channel)
+            except BaseException:
+                logger.exception("Unexpected exception while cleaning-up channel:")
+                # We don't re-raise here because we want the CancelledError to be the one re-raised.
+            raise
+
+        return msgpack.unpackb(message)
+
+    ################################################################################
+    # Groups extension
+    ################################################################################
+
+    async def group_add(self, group, channel):
+        """
+        Adds the channel name to a group.
+        """
+        if channel not in self.channels:
+            raise RuntimeError(
+                "You can only call group_add() on channels that exist in-process.\n"
+                "Consumers are encouraged to use the common pattern:\n"
+                f"    self.channel_layer.group_add({repr(group)}, self.channel_name)"
+            )
+        group_channel = self._get_group_channel_name(group)
+        if group_channel not in self.groups:
+            self.groups[group_channel] = set()
+        group_channels = self.groups[group_channel]
+        if channel not in group_channels:
+            group_channels.add(channel)
+        shard = self._get_shard(group_channel)
+        await shard.subscribe(group_channel)
+
+    async def group_discard(self, group, channel):
+        """
+        Removes the channel from a group.
+        """
+        group_channel = self._get_group_channel_name(group)
+        assert group_channel in self.groups
+        group_channels = self.groups[group_channel]
+        assert channel in group_channels
+        group_channels.remove(channel)
+        if len(group_channels) == 0:
+            del self.groups[group_channel]
+            shard = self._get_shard(group_channel)
+            await shard.unsubscribe(group_channel)
+
+    async def group_send(self, group, message):
+        """
+        Send the message to all subscribers of the group.
+        """
+        group_channel = self._get_group_channel_name(group)
+        shard = self._get_shard(group_channel)
+        await shard.publish(group_channel, message)
+
+    ################################################################################
+    # Flush extension
+    ################################################################################
+
+    async def flush(self):
+        """
+        Flush the layer, making it like new. It can continue to be used as if it
+        were just created. This also closes connections, serving as a clean-up
+        method; connections will be re-opened if you continue using this layer.
+        """
+        self.channels = {}
+        self.groups = {}
+        for shard in self._shards:
+            await shard.flush()
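One property of the sharding scheme worth noting for reviewers: with multiple ``hosts``, a given channel or group name is always pinned to one shard, because ``_get_shard()`` hashes only the name. A standalone sketch of the formula (not part of the PR; note that Python's string hashing is salted per process unless ``PYTHONHASHSEED`` is fixed, so the name-to-shard mapping is stable within a process but not necessarily across processes)::

    hosts = [("redis-1", 6379), ("redis-2", 6379)]  # hypothetical shard hosts
    name = "asgi__group__chat"  # a group channel name, per _get_group_channel_name()
    shard_index = abs(hash(name)) % len(hosts)  # the same arithmetic as _get_shard()
    # All publishes and subscriptions for `name` in this process use hosts[shard_index].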
+
+
+def on_close_noop(sender, exc=None):
+    """
+    If you don't pass an `on_close` function to the `Receiver`, then it
+    defaults to one that closes the Receiver whenever the last subscriber
+    unsubscribes. That is not what we want; instead, we want the Receiver
+    to continue even if no one is subscribed, because soon someone *will*
+    subscribe and we want things to continue from there. Passing this
+    empty function solves it.
+    """
+    pass
+
+
+class RedisSingleShardConnection:
+    def __init__(self, host, channel_layer):
+        self.host = host
+        self.channel_layer = channel_layer
+        self._subscribed_to = set()
+        self._lock = None
+        self._pub_conn = None
+        self._sub_conn = None
+        self._receiver = None
+        self._receive_task = None
+        self._keepalive_task = None
+
+    async def publish(self, channel, message):
+        conn = await self._get_pub_conn()
+        await conn.publish(channel, msgpack.packb(message))
+
+    async def subscribe(self, channel):
+        if channel not in self._subscribed_to:
+            self._subscribed_to.add(channel)
+            conn = await self._get_sub_conn()
+            await conn.subscribe(self._receiver.channel(channel))
+
+    async def unsubscribe(self, channel):
+        if channel in self._subscribed_to:
+            self._subscribed_to.remove(channel)
+            conn = await self._get_sub_conn()
+            await conn.unsubscribe(channel)
+
+    async def flush(self):
+        for task in [self._keepalive_task, self._receive_task]:
+            if task is not None:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+        self._keepalive_task = None
+        self._receive_task = None
+        self._receiver = None
+        if self._sub_conn is not None:
+            self._sub_conn.close()
+            await self._sub_conn.wait_closed()
+            self._sub_conn = None
+        if self._pub_conn is not None:
+            self._pub_conn.close()
+            await self._pub_conn.wait_closed()
+            self._pub_conn = None
+        self._subscribed_to = set()
+
+    async def _get_pub_conn(self):
+        """
+        Return the connection to this shard that is used for *publishing* messages.
+
+        If the connection is dead, automatically reconnect.
+        """
+        if self._lock is None:
+            self._lock = asyncio.Lock()
+        async with self._lock:
+            if self._pub_conn is not None and self._pub_conn.closed:
+                self._pub_conn = None
+            while self._pub_conn is None:
+                try:
+                    self._pub_conn = await aioredis.create_redis(self.host)
+                except BaseException:
+                    logger.warning(
+                        f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
+                    )
+                    await asyncio.sleep(1)
+            return self._pub_conn
+
+    async def _get_sub_conn(self):
+        """
+        Return the connection to this shard that is used for *subscribing* to channels.
+
+        If the connection is dead, automatically reconnect and resubscribe to all our channels!
+        """
+        if self._keepalive_task is None:
+            self._keepalive_task = asyncio.ensure_future(self._do_keepalive())
+        if self._lock is None:
+            self._lock = asyncio.Lock()
+        async with self._lock:
+            if self._sub_conn is not None and self._sub_conn.closed:
+                self._sub_conn = None
+            if self._sub_conn is None:
+                if self._receive_task is not None:
+                    self._receive_task.cancel()
+                    try:
+                        await self._receive_task
+                    except asyncio.CancelledError:
+                        # This is the normal case, that `asyncio.CancelledError` is thrown. All good.
+                        pass
+                    except BaseException:
+                        logger.exception(
+                            "Unexpected exception while canceling the receiver task:"
+                        )
+                        # Don't re-raise here. We don't actually care why `_receive_task` didn't exit cleanly.
+                    self._receive_task = None
+                while self._sub_conn is None:
+                    try:
+                        self._sub_conn = await aioredis.create_redis(self.host)
+                    except BaseException:
+                        logger.warning(
+                            f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
+                        )
+                        await asyncio.sleep(1)
+                self._receiver = aioredis.pubsub.Receiver(on_close=on_close_noop)
+                self._receive_task = asyncio.ensure_future(self._do_receiving())
+                if len(self._subscribed_to) > 0:
+                    # Do our best to recover by resubscribing to the channels that we were previously subscribed to.
+                    resubscribe_to = [
+                        self._receiver.channel(name) for name in self._subscribed_to
+                    ]
+                    await self._sub_conn.subscribe(*resubscribe_to)
+            return self._sub_conn
+
+    async def _do_receiving(self):
+        async for ch, message in self._receiver.iter():
+            name = ch.name
+            if isinstance(name, bytes):
+                # Reversing what happens here:
+                # https://github.com/aio-libs/aioredis-py/blob/8a207609b7f8a33e74c7c8130d97186e78cc0052/aioredis/util.py#L17
+                name = name.decode()
+            if name in self.channel_layer.channels:
+                self.channel_layer.channels[name].put_nowait(message)
+            elif name in self.channel_layer.groups:
+                for channel_name in self.channel_layer.groups[name]:
+                    if channel_name in self.channel_layer.channels:
+                        self.channel_layer.channels[channel_name].put_nowait(message)
+
+    async def _do_keepalive(self):
+        """
+        This task's simple job is just to call `self._get_sub_conn()` periodically.
+
+        Why? Calling `self._get_sub_conn()` has the nice side-effect that, if
+        that connection has died (because Redis was restarted, or there was a
+        networking hiccup, for example), it will reconnect and restore our old
+        subscriptions. Thus, we want to do this on a predictable schedule.
+        This is somewhat sub-optimal, but I can't find a way in aioredis to get
+        a notification when the connection dies. In practice this periodic check
+        works fine: if Redis restarts, we reconnect and resubscribe *quickly
+        enough*. A Redis restart already causes messages to be lost, so this
+        check at least keeps the damage to a minimum.
+
+        Note that you wouldn't need this if you were *sure* there would be a
+        steady stream of subscribe/unsubscribe events, because each of those
+        also calls `self._get_sub_conn()`. On a heavily-trafficked site this
+        task may therefore be unnecessary; then again, with more than one
+        Django server replica, an under-utilized replica still benefits from
+        this periodic connection check in the same way a low-traffic site does.
+        """
+        while True:
+            await asyncio.sleep(1)
+            try:
+                await self._get_sub_conn()
+            except Exception:
+                logger.exception("Unexpected exception in keepalive task:")
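With the implementation in place, the full send/receive round-trip can be exercised outside of Django with a short script (a sketch, assuming a Redis server on ``localhost:6379`` and the aioredis 1.x event-loop conventions this module uses)::

    import asyncio

    from channels_redis.pubsub import RedisPubSubChannelLayer

    async def main():
        layer = RedisPubSubChannelLayer(hosts=[("localhost", 6379)])
        channel = await layer.new_channel()  # subscribes before we publish
        await layer.send(channel, {"type": "test.message", "text": "Ahoy-hoy!"})
        print(await layer.receive(channel))  # -> {'type': 'test.message', 'text': 'Ahoy-hoy!'}
        await layer.flush()  # closes the shard connections

    asyncio.get_event_loop().run_until_complete(main())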
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
new file mode 100644
index 0000000..f79b9d5
--- /dev/null
+++ b/tests/test_pubsub.py
@@ -0,0 +1,103 @@
+import asyncio
+import random
+
+import async_timeout
+import pytest
+from async_generator import async_generator, yield_
+
+from channels_redis.pubsub import RedisPubSubChannelLayer
+
+TEST_HOSTS = [("localhost", 6379)]
+
+
+@pytest.fixture()
+@async_generator
+async def channel_layer():
+    """
+    Channel layer fixture that flushes automatically.
+    """
+    channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
+    await yield_(channel_layer)
+    await channel_layer.flush()
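A side note on the fixture above: the ``@async_generator``/``yield_`` pair is there for Python 3.5 compatibility. On Python 3.6+ with pytest-asyncio, the equivalent native async-generator fixture would be (a sketch, not part of this PR)::

    @pytest.fixture()
    async def channel_layer():
        channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
        yield channel_layer
        await channel_layer.flush()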
+
+
+@pytest.mark.asyncio
+async def test_send_receive(channel_layer):
+    """
+    Makes sure we can send a message to a normal channel then receive it.
+    """
+    channel = await channel_layer.new_channel()
+    await channel_layer.send(channel, {"type": "test.message", "text": "Ahoy-hoy!"})
+    message = await channel_layer.receive(channel)
+    assert message["type"] == "test.message"
+    assert message["text"] == "Ahoy-hoy!"
+
+
+@pytest.mark.asyncio
+async def test_multi_send_receive(channel_layer):
+    """
+    Tests overlapping sends and receives, and ordering.
+    """
+    channel = await channel_layer.new_channel()
+    await channel_layer.send(channel, {"type": "message.1"})
+    await channel_layer.send(channel, {"type": "message.2"})
+    await channel_layer.send(channel, {"type": "message.3"})
+    assert (await channel_layer.receive(channel))["type"] == "message.1"
+    assert (await channel_layer.receive(channel))["type"] == "message.2"
+    assert (await channel_layer.receive(channel))["type"] == "message.3"
+
+
+@pytest.mark.asyncio
+async def test_groups_basic(channel_layer):
+    """
+    Tests basic group operation.
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer.group_add("test-group", channel_name3)
+    await channel_layer.group_discard("test-group", channel_name2)
+    await channel_layer.group_send("test-group", {"type": "message.1"})
+    # Make sure we get the message on the two channels that are still in the group.
+    async with async_timeout.timeout(1):
+        assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
+        assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
+    # Make sure the removed channel did not get the message.
+    with pytest.raises(asyncio.TimeoutError):
+        async with async_timeout.timeout(1):
+            await channel_layer.receive(channel_name2)
+
+
+@pytest.mark.asyncio
+async def test_groups_same_prefix(channel_layer):
+    """
+    Tests group_send with multiple channels that share the same channel prefix.
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer.group_add("test-group", channel_name3)
+    await channel_layer.group_send("test-group", {"type": "message.1"})
+
+    # Make sure we get the message on all channels that are in the group.
+    async with async_timeout.timeout(1):
+        assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
+        assert (await channel_layer.receive(channel_name2))["type"] == "message.1"
+        assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
+
+
+@pytest.mark.asyncio
+async def test_random_reset__channel_name(channel_layer):
+    """
+    Makes sure that resetting the random seed does not make us reuse channel names.
+    """
+    random.seed(1)
+    channel_name_1 = await channel_layer.new_channel()
+    random.seed(1)
+    channel_name_2 = await channel_layer.new_channel()
+
+    assert channel_name_1 != channel_name_2
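One last note for anyone debugging this layer by hand: messages travel over ordinary Redis Pub/Sub channels as plain msgpack, so they can be inspected with any Redis client plus a msgpack decoder. A round-trip sketch (assuming msgpack >= 1.0, where ``unpackb()`` decodes ``str`` keys by default)::

    import msgpack

    # packb()/unpackb() mirror what RedisSingleShardConnection.publish()
    # and RedisPubSubChannelLayer.receive() do at the edges of the layer.
    packed = msgpack.packb({"type": "chat.message", "text": "hi"})
    assert msgpack.unpackb(packed) == {"type": "chat.message", "text": "hi"}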