From 8244d70411fd2be573446b328fcc5fafb92002b7 Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 27 Jun 2023 13:04:59 +0200 Subject: [PATCH 1/3] feature: improve layer typings, add method stubs --- channels/layers.py | 101 +++++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 30 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 12bbd2b8..e64520da 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import fnmatch import random @@ -5,6 +7,7 @@ import string import time from copy import deepcopy +from typing import Dict, Iterable, List, Optional, Tuple from django.conf import settings from django.core.signals import setting_changed @@ -20,6 +23,8 @@ class ChannelLayerManager: Takes a settings dictionary of backends and initialises them on request. """ + backends: Dict[str, BaseChannelLayer] + def __init__(self): self.backends = {} setting_changed.connect(self._reset_backends) @@ -36,14 +41,14 @@ def configs(self): # Lazy load settings so we can be imported return getattr(settings, "CHANNEL_LAYERS", {}) - def make_backend(self, name): + def make_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer. """ config = self.configs[name].get("CONFIG", {}) return self._make_backend(name, config) - def make_test_backend(self, name): + def make_test_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer using its test config. """ @@ -53,7 +58,7 @@ def make_test_backend(self, name): raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name) return self._make_backend(name, config) - def _make_backend(self, name, config): + def _make_backend(self, name, config) -> BaseChannelLayer: # Check for old format config if "ROUTING" in self.configs[name]: raise InvalidChannelLayerError( @@ -81,7 +86,7 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.configs - def set(self, key, layer): + def set(self, key: str, layer: BaseChannelLayer): """ Sets an alias to point to a new ChannelLayerWrapper instance, and returns the old one that it replaced. Useful for swapping out the @@ -99,13 +104,21 @@ class BaseChannelLayer: """ MAX_NAME_LENGTH = 100 + extensions: Iterable[str] = () - def __init__(self, expiry=60, capacity=100, channel_capacity=None): + def __init__( + self, + expiry=60, + capacity: Optional[int] = 100, + channel_capacity: Optional[int] = None, + ): self.expiry = expiry self.capacity = capacity self.channel_capacity = channel_capacity or {} - def compile_capacities(self, channel_capacity): + def compile_capacities( + self, channel_capacity + ) -> List[Tuple[re.Pattern, Optional[int]]]: """ Takes an input channel_capacity dict and returns the compiled list of regexes that get_capacity will look for as self.channel_capacity @@ -120,7 +133,7 @@ def compile_capacities(self, channel_capacity): result.append((re.compile(fnmatch.translate(pattern)), value)) return result - def get_capacity(self, channel): + def get_capacity(self, channel: str) -> Optional[int]: """ Gets the correct capacity for the given channel; either the default, or a matching result from channel_capacity. Returns the first matching @@ -132,7 +145,7 @@ def get_capacity(self, channel): return capacity return self.capacity - def match_type_and_length(self, name): + def match_type_and_length(self, name) -> bool: if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH): return True return False @@ -148,7 +161,7 @@ def match_type_and_length(self, name): + "not {}" ) - def valid_channel_name(self, name, receive=False): + def valid_channel_name(self, name: str, receive=False) -> bool: if self.match_type_and_length(name): if bool(self.channel_name_regex.match(name)): # Check cases for special channels @@ -159,13 +172,13 @@ def valid_channel_name(self, name, receive=False): return True raise TypeError(self.invalid_name_error.format("Channel", name)) - def valid_group_name(self, name): + def valid_group_name(self, name: str) -> bool: if self.match_type_and_length(name): if bool(self.group_name_regex.match(name)): return True raise TypeError(self.invalid_name_error.format("Group", name)) - def valid_channel_names(self, names, receive=False): + def valid_channel_names(self, names: List[str], receive=False) -> bool: _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" @@ -175,7 +188,7 @@ def valid_channel_names(self, names, receive=False): ) return True - def non_local_name(self, name): + def non_local_name(self, name: str) -> str: """ Given a channel name, returns the "non-local" part. If the channel name is a process-specific channel (contains !) this means the part up to @@ -186,6 +199,49 @@ def non_local_name(self, name): else: return name + async def send(self, channel: str, message: dict): + """ + Send a message onto a (general or specific) channel. + """ + raise NotImplementedError() + + async def receive(self, channel: str) -> dict: + """ + Receive the first message that arrives on the channel. + If more than one coroutine waits on the same channel, a random one + of the waiting coroutines will get the result. + """ + raise NotImplementedError() + + async def new_channel(self, prefix: str = "specific.") -> str: + """ + Returns a new channel name that can be used by something in our + process as a specific channel. + """ + raise NotImplementedError() + + # Flush extension + + async def flush(self): + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + + # Groups extension + + async def group_add(self, group: str, channel: str): + """ + Adds the channel name to a group. + """ + raise NotImplementedError() + + async def group_discard(self, group: str, channel: str): + raise NotImplementedError() + + async def group_send(self, group: str, message: dict): + raise NotImplementedError() + class InMemoryChannelLayer(BaseChannelLayer): """ @@ -198,13 +254,13 @@ def __init__( group_expiry=86400, capacity=100, channel_capacity=None, - **kwargs + **kwargs, ): super().__init__( expiry=expiry, capacity=capacity, channel_capacity=channel_capacity, - **kwargs + **kwargs, ) self.channels = {} self.groups = {} @@ -215,9 +271,6 @@ def __init__( extensions = ["groups", "flush"] async def send(self, channel, message): - """ - Send a message onto a (general or specific) channel. - """ # Typecheck assert isinstance(message, dict), "message is not a dict" assert self.valid_channel_name(channel), "Channel name not valid" @@ -234,11 +287,6 @@ async def send(self, channel, message): await queue.put((time.time() + self.expiry, deepcopy(message))) async def receive(self, channel): - """ - Receive the first message that arrives on the channel. - If more than one coroutine waits on the same channel, a random one - of the waiting coroutines will get the result. - """ assert self.valid_channel_name(channel) self._clean_expired() @@ -254,10 +302,6 @@ async def receive(self, channel): return message async def new_channel(self, prefix="specific."): - """ - Returns a new channel name that can be used by something in our - process as a specific channel. - """ return "%s.inmemory!%s" % ( prefix, "".join(random.choice(string.ascii_letters) for i in range(12)), @@ -314,9 +358,6 @@ def _remove_from_groups(self, channel): # Groups extension async def group_add(self, group, channel): - """ - Adds the channel name to a group. - """ # Check the inputs assert self.valid_group_name(group), "Group name not valid" assert self.valid_channel_name(channel), "Channel name not valid" @@ -349,7 +390,7 @@ async def group_send(self, group, message): pass -def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER): +def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]: """ Returns a channel layer by alias, or None if it is not configured. """ From 39835c2852f0289d172ca61e37c024ea53509524 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 29 Jan 2024 05:14:35 +0100 Subject: [PATCH 2/3] add extensions as runtime-checkable protocols, make BaseChannelLayer abstract --- channels/layers.py | 76 +++++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index e64520da..20b91c43 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -6,8 +6,18 @@ import re import string import time +from abc import ABC, abstractmethod from copy import deepcopy -from typing import Dict, Iterable, List, Optional, Tuple +from typing import ( + Dict, + Iterable, + List, + NoReturn, + Optional, + Protocol, + Tuple, + runtime_checkable, +) from django.conf import settings from django.core.signals import setting_changed @@ -97,7 +107,39 @@ def set(self, key: str, layer: BaseChannelLayer): return old -class BaseChannelLayer: +@runtime_checkable +class WithFlushExtension(Protocol): + async def flush(self) -> NoReturn: + """ + Clears messages and if available groups + """ + + async def close(self) -> NoReturn: + """ + Close connection to the layer. Called before stopping layer. + Unusable after. + """ + + +@runtime_checkable +class WithGroupsExtension(Protocol): + async def group_add(self, group: str, channel: str): + """ + Adds the channel name to a group. + """ + + async def group_discard(self, group: str, channel: str) -> NoReturn: + """ + Removes the channel name from a group when it exists. + """ + + async def group_send(self, group: str, message: dict) -> NoReturn: + """ + Sends message to group + """ + + +class BaseChannelLayer(ABC): """ Base channel layer class that others can inherit from, with useful common functionality. @@ -199,51 +241,29 @@ def non_local_name(self, name: str) -> str: else: return name + @abstractmethod async def send(self, channel: str, message: dict): """ Send a message onto a (general or specific) channel. """ - raise NotImplementedError() + @abstractmethod async def receive(self, channel: str) -> dict: """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, a random one of the waiting coroutines will get the result. """ - raise NotImplementedError() + @abstractmethod async def new_channel(self, prefix: str = "specific.") -> str: """ Returns a new channel name that can be used by something in our process as a specific channel. """ - raise NotImplementedError() - - # Flush extension - - async def flush(self): - raise NotImplementedError() - - async def close(self): - raise NotImplementedError() - - # Groups extension - - async def group_add(self, group: str, channel: str): - """ - Adds the channel name to a group. - """ - raise NotImplementedError() - - async def group_discard(self, group: str, channel: str): - raise NotImplementedError() - - async def group_send(self, group: str, message: dict): - raise NotImplementedError() -class InMemoryChannelLayer(BaseChannelLayer): +class InMemoryChannelLayer(WithFlushExtension, WithGroupsExtension, BaseChannelLayer): """ In-memory channel layer implementation """ From e4bd0cf625dbf25c5adcb7a35f7391ed61d9f6ce Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 29 Jan 2024 05:21:30 +0100 Subject: [PATCH 3/3] fix tests and add typings for attributes --- channels/layers.py | 12 ++++++++++-- tests/test_layers.py | 14 +++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 20b91c43..1f418544 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -147,10 +147,13 @@ class BaseChannelLayer(ABC): MAX_NAME_LENGTH = 100 extensions: Iterable[str] = () + expiry: int + capacity: int + channel_capacity: Dict[str, int] def __init__( self, - expiry=60, + expiry: int = 60, capacity: Optional[int] = 100, channel_capacity: Optional[int] = None, ): @@ -263,7 +266,12 @@ async def new_channel(self, prefix: str = "specific.") -> str: """ -class InMemoryChannelLayer(WithFlushExtension, WithGroupsExtension, BaseChannelLayer): +# WARNING: Protocols must be last +class InMemoryChannelLayer( + BaseChannelLayer, + WithFlushExtension, + WithGroupsExtension, +): """ In-memory channel layer implementation """ diff --git a/tests/test_layers.py b/tests/test_layers.py index 543a9f19..3ebbc287 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -13,6 +13,18 @@ ) +# when starting with Test it would be tried to collect by pytest +class StubChannelLayer(BaseChannelLayer): + async def send(self, channel: str, message: dict): + raise NotImplementedError() + + async def receive(self, channel: str) -> dict: + raise NotImplementedError() + + async def new_channel(self, prefix: str = "specific.") -> str: + raise NotImplementedError() + + class TestChannelLayerManager(unittest.TestCase): @override_settings( CHANNEL_LAYERS={"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} @@ -72,7 +84,7 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], + [StubChannelLayer().valid_channel_name, StubChannelLayer().valid_group_name], ) @pytest.mark.parametrize( "channel_name,expected_valid",