Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InMemoryChannelLayer improvements, test fixes #1976

Merged
merged 8 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def __init__(
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
**kwargs,
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
**kwargs,
)
self.channels = {}
self.groups = {}
Expand All @@ -225,13 +225,14 @@ async def send(self, channel, message):
# name in message
assert "__asgi_channel__" not in message

queue = self.channels.setdefault(channel, asyncio.Queue())
# Are we full
if queue.qsize() >= self.capacity:
raise ChannelFull(channel)

queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)
bigfootjon marked this conversation as resolved.
Show resolved Hide resolved
# Add message
await queue.put((time.time() + self.expiry, deepcopy(message)))
try:
queue.put_nowait((time.time() + self.expiry, deepcopy(message)))
except asyncio.queues.QueueFull:
raise ChannelFull(channel)

async def receive(self, channel):
"""
Expand All @@ -242,14 +243,16 @@ async def receive(self, channel):
assert self.valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(channel, asyncio.Queue())
queue = self.channels.setdefault(
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
)

# Do a plain direct receive
try:
_, message = await queue.get()
finally:
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)
devkral marked this conversation as resolved.
Show resolved Hide resolved

return message

Expand Down Expand Up @@ -279,19 +282,17 @@ def _clean_expired(self):
self._remove_from_groups(channel)
# Is the channel now empty and needs deleting?
if queue.empty():
del self.channels[channel]
self.channels.pop(channel, None)

# Group Expiration
timeout = int(time.time()) - self.group_expiry
for group in self.groups:
for channel in list(self.groups.get(group, set())):
# If join time is older than group_expiry end the group membership
if (
self.groups[group][channel]
and int(self.groups[group][channel]) < timeout
):
for channels in self.groups.values():
for name, timestamp in list(channels.items()):
devkral marked this conversation as resolved.
Show resolved Hide resolved
# If join time is older than group_expiry
# end the group membership
if timestamp and timestamp < timeout:
# Delete from group
del self.groups[group][channel]
channels.pop(name, None)

# Flush extension

Expand All @@ -308,8 +309,7 @@ def _remove_from_groups(self, channel):
Removes a channel from all groups. Used when a message on it expires.
"""
for channels in self.groups.values():
if channel in channels:
del channels[channel]
channels.pop(channel, None)

# Groups extension

Expand All @@ -329,22 +329,29 @@ async def group_discard(self, group, channel):
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
# Remove from group set
if group in self.groups:
if channel in self.groups[group]:
del self.groups[group][channel]
if not self.groups[group]:
del self.groups[group]
group_channels = self.groups.get(group, None)
if group_channels:
# remove channel if in group
group_channels.pop(channel, None)
# is group now empty? If yes remove it
if not group_channels:
self.groups.pop(group, None)

async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
# Run clean
self._clean_expired()

# Send to each channel
for channel in self.groups.get(group, set()):
ops = []
if group in self.groups:
for channel in self.groups[group].keys():
ops.append(asyncio.create_task(self.send(channel, message)))
for send_result in asyncio.as_completed(ops):
try:
await self.send(channel, message)
await send_result
except ChannelFull:
pass

Expand Down
30 changes: 27 additions & 3 deletions tests/test_inmemorychannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,36 @@ async def test_send_receive(channel_layer):
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
message = await channel_layer.receive("test-channel-1")
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"
# not removed because not empty
assert "test-channel-1" in channel_layer.channels
message = await channel_layer.receive("test-channel-1")
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"
# removed because empty
assert "test-channel-1" not in channel_layer.channels


@pytest.mark.asyncio
async def test_race_empty(channel_layer):
"""
Makes sure the race is handled gracefully.
"""
receive_task = asyncio.create_task(channel_layer.receive("test-channel-1"))
await asyncio.sleep(0.1)
await channel_layer.send(
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
)
del channel_layer.channels["test-channel-1"]
await asyncio.sleep(0.1)
message = await receive_task
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
Expand Down Expand Up @@ -62,7 +89,6 @@ async def test_multi_send_receive(channel_layer):
"""
Tests overlapping sends and receives, and ordering.
"""
devkral marked this conversation as resolved.
Show resolved Hide resolved
channel_layer = InMemoryChannelLayer()
bigfootjon marked this conversation as resolved.
Show resolved Hide resolved
await channel_layer.send("test-channel-3", {"type": "message.1"})
await channel_layer.send("test-channel-3", {"type": "message.2"})
await channel_layer.send("test-channel-3", {"type": "message.3"})
Expand All @@ -76,7 +102,6 @@ async def test_groups_basic(channel_layer):
"""
Tests basic group operation.
"""
channel_layer = InMemoryChannelLayer()
await channel_layer.group_add("test-group", "test-gr-chan-1")
await channel_layer.group_add("test-group", "test-gr-chan-2")
await channel_layer.group_add("test-group", "test-gr-chan-3")
Expand All @@ -97,7 +122,6 @@ async def test_groups_channel_full(channel_layer):
"""
Tests that group_send ignores ChannelFull
"""
channel_layer = InMemoryChannelLayer()
await channel_layer.group_add("test-group", "test-gr-chan-1")
await channel_layer.group_send("test-group", {"type": "message.1"})
await channel_layer.group_send("test-group", {"type": "message.1"})
Expand Down