diff --git a/channels/layers.py b/channels/layers.py index 12bbd2b8..1f418544 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -1,10 +1,23 @@ +from __future__ import annotations + import asyncio import fnmatch import random import re import string import time +from abc import ABC, abstractmethod from copy import deepcopy +from typing import ( + Dict, + Iterable, + List, + NoReturn, + Optional, + Protocol, + Tuple, + runtime_checkable, +) from django.conf import settings from django.core.signals import setting_changed @@ -20,6 +33,8 @@ class ChannelLayerManager: Takes a settings dictionary of backends and initialises them on request. """ + backends: Dict[str, BaseChannelLayer] + def __init__(self): self.backends = {} setting_changed.connect(self._reset_backends) @@ -36,14 +51,14 @@ def configs(self): # Lazy load settings so we can be imported return getattr(settings, "CHANNEL_LAYERS", {}) - def make_backend(self, name): + def make_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer. """ config = self.configs[name].get("CONFIG", {}) return self._make_backend(name, config) - def make_test_backend(self, name): + def make_test_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer using its test config. """ @@ -53,7 +68,7 @@ def make_test_backend(self, name): raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name) return self._make_backend(name, config) - def _make_backend(self, name, config): + def _make_backend(self, name, config) -> BaseChannelLayer: # Check for old format config if "ROUTING" in self.configs[name]: raise InvalidChannelLayerError( @@ -81,7 +96,7 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.configs - def set(self, key, layer): + def set(self, key: str, layer: BaseChannelLayer): """ Sets an alias to point to a new ChannelLayerWrapper instance, and returns the old one that it replaced. Useful for swapping out the @@ -92,20 +107,63 @@ def set(self, key, layer): return old -class BaseChannelLayer: +@runtime_checkable +class WithFlushExtension(Protocol): + async def flush(self) -> NoReturn: + """ + Clears messages and if available groups + """ + + async def close(self) -> NoReturn: + """ + Close connection to the layer. Called before stopping layer. + Unusable after. + """ + + +@runtime_checkable +class WithGroupsExtension(Protocol): + async def group_add(self, group: str, channel: str): + """ + Adds the channel name to a group. + """ + + async def group_discard(self, group: str, channel: str) -> NoReturn: + """ + Removes the channel name from a group when it exists. + """ + + async def group_send(self, group: str, message: dict) -> NoReturn: + """ + Sends message to group + """ + + +class BaseChannelLayer(ABC): """ Base channel layer class that others can inherit from, with useful common functionality. """ MAX_NAME_LENGTH = 100 + extensions: Iterable[str] = () + expiry: int + capacity: int + channel_capacity: Dict[str, int] - def __init__(self, expiry=60, capacity=100, channel_capacity=None): + def __init__( + self, + expiry: int = 60, + capacity: Optional[int] = 100, + channel_capacity: Optional[int] = None, + ): self.expiry = expiry self.capacity = capacity self.channel_capacity = channel_capacity or {} - def compile_capacities(self, channel_capacity): + def compile_capacities( + self, channel_capacity + ) -> List[Tuple[re.Pattern, Optional[int]]]: """ Takes an input channel_capacity dict and returns the compiled list of regexes that get_capacity will look for as self.channel_capacity @@ -120,7 +178,7 @@ def compile_capacities(self, channel_capacity): result.append((re.compile(fnmatch.translate(pattern)), value)) return result - def get_capacity(self, channel): + def get_capacity(self, channel: str) -> Optional[int]: """ Gets the correct capacity for the given channel; either the default, or a matching result from channel_capacity. Returns the first matching @@ -132,7 +190,7 @@ def get_capacity(self, channel): return capacity return self.capacity - def match_type_and_length(self, name): + def match_type_and_length(self, name) -> bool: if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH): return True return False @@ -148,7 +206,7 @@ def match_type_and_length(self, name): + "not {}" ) - def valid_channel_name(self, name, receive=False): + def valid_channel_name(self, name: str, receive=False) -> bool: if self.match_type_and_length(name): if bool(self.channel_name_regex.match(name)): # Check cases for special channels @@ -159,13 +217,13 @@ def valid_channel_name(self, name, receive=False): return True raise TypeError(self.invalid_name_error.format("Channel", name)) - def valid_group_name(self, name): + def valid_group_name(self, name: str) -> bool: if self.match_type_and_length(name): if bool(self.group_name_regex.match(name)): return True raise TypeError(self.invalid_name_error.format("Group", name)) - def valid_channel_names(self, names, receive=False): + def valid_channel_names(self, names: List[str], receive=False) -> bool: _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" @@ -175,7 +233,7 @@ def valid_channel_names(self, names, receive=False): ) return True - def non_local_name(self, name): + def non_local_name(self, name: str) -> str: """ Given a channel name, returns the "non-local" part. If the channel name is a process-specific channel (contains !) this means the part up to @@ -186,8 +244,34 @@ def non_local_name(self, name): else: return name + @abstractmethod + async def send(self, channel: str, message: dict): + """ + Send a message onto a (general or specific) channel. + """ -class InMemoryChannelLayer(BaseChannelLayer): + @abstractmethod + async def receive(self, channel: str) -> dict: + """ + Receive the first message that arrives on the channel. + If more than one coroutine waits on the same channel, a random one + of the waiting coroutines will get the result. + """ + + @abstractmethod + async def new_channel(self, prefix: str = "specific.") -> str: + """ + Returns a new channel name that can be used by something in our + process as a specific channel. + """ + + +# WARNING: Protocols must be last +class InMemoryChannelLayer( + BaseChannelLayer, + WithFlushExtension, + WithGroupsExtension, +): """ In-memory channel layer implementation """ @@ -198,13 +282,13 @@ def __init__( group_expiry=86400, capacity=100, channel_capacity=None, - **kwargs + **kwargs, ): super().__init__( expiry=expiry, capacity=capacity, channel_capacity=channel_capacity, - **kwargs + **kwargs, ) self.channels = {} self.groups = {} @@ -215,9 +299,6 @@ def __init__( extensions = ["groups", "flush"] async def send(self, channel, message): - """ - Send a message onto a (general or specific) channel. - """ # Typecheck assert isinstance(message, dict), "message is not a dict" assert self.valid_channel_name(channel), "Channel name not valid" @@ -234,11 +315,6 @@ async def send(self, channel, message): await queue.put((time.time() + self.expiry, deepcopy(message))) async def receive(self, channel): - """ - Receive the first message that arrives on the channel. - If more than one coroutine waits on the same channel, a random one - of the waiting coroutines will get the result. - """ assert self.valid_channel_name(channel) self._clean_expired() @@ -254,10 +330,6 @@ async def receive(self, channel): return message async def new_channel(self, prefix="specific."): - """ - Returns a new channel name that can be used by something in our - process as a specific channel. - """ return "%s.inmemory!%s" % ( prefix, "".join(random.choice(string.ascii_letters) for i in range(12)), @@ -314,9 +386,6 @@ def _remove_from_groups(self, channel): # Groups extension async def group_add(self, group, channel): - """ - Adds the channel name to a group. - """ # Check the inputs assert self.valid_group_name(group), "Group name not valid" assert self.valid_channel_name(channel), "Channel name not valid" @@ -349,7 +418,7 @@ async def group_send(self, group, message): pass -def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER): +def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]: """ Returns a channel layer by alias, or None if it is not configured. """ diff --git a/tests/test_layers.py b/tests/test_layers.py index 543a9f19..3ebbc287 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -13,6 +13,18 @@ ) +# when starting with Test it would be tried to collect by pytest +class StubChannelLayer(BaseChannelLayer): + async def send(self, channel: str, message: dict): + raise NotImplementedError() + + async def receive(self, channel: str) -> dict: + raise NotImplementedError() + + async def new_channel(self, prefix: str = "specific.") -> str: + raise NotImplementedError() + + class TestChannelLayerManager(unittest.TestCase): @override_settings( CHANNEL_LAYERS={"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} @@ -72,7 +84,7 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], + [StubChannelLayer().valid_channel_name, StubChannelLayer().valid_group_name], ) @pytest.mark.parametrize( "channel_name,expected_valid",