diff --git a/channels/layers.py b/channels/layers.py index 12bbd2b8..e64520da 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import fnmatch import random @@ -5,6 +7,7 @@ import string import time from copy import deepcopy +from typing import Dict, Iterable, List, Optional, Tuple from django.conf import settings from django.core.signals import setting_changed @@ -20,6 +23,8 @@ class ChannelLayerManager: Takes a settings dictionary of backends and initialises them on request. """ + backends: Dict[str, BaseChannelLayer] + def __init__(self): self.backends = {} setting_changed.connect(self._reset_backends) @@ -36,14 +41,14 @@ def configs(self): # Lazy load settings so we can be imported return getattr(settings, "CHANNEL_LAYERS", {}) - def make_backend(self, name): + def make_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer. """ config = self.configs[name].get("CONFIG", {}) return self._make_backend(name, config) - def make_test_backend(self, name): + def make_test_backend(self, name) -> BaseChannelLayer: """ Instantiate channel layer using its test config. """ @@ -53,7 +58,7 @@ def make_test_backend(self, name): raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name) return self._make_backend(name, config) - def _make_backend(self, name, config): + def _make_backend(self, name, config) -> BaseChannelLayer: # Check for old format config if "ROUTING" in self.configs[name]: raise InvalidChannelLayerError( @@ -81,7 +86,7 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.configs - def set(self, key, layer): + def set(self, key: str, layer: BaseChannelLayer): """ Sets an alias to point to a new ChannelLayerWrapper instance, and returns the old one that it replaced. Useful for swapping out the @@ -99,13 +104,21 @@ class BaseChannelLayer: """ MAX_NAME_LENGTH = 100 + extensions: Iterable[str] = () - def __init__(self, expiry=60, capacity=100, channel_capacity=None): + def __init__( + self, + expiry=60, + capacity: Optional[int] = 100, + channel_capacity: Optional[int] = None, + ): self.expiry = expiry self.capacity = capacity self.channel_capacity = channel_capacity or {} - def compile_capacities(self, channel_capacity): + def compile_capacities( + self, channel_capacity + ) -> List[Tuple[re.Pattern, Optional[int]]]: """ Takes an input channel_capacity dict and returns the compiled list of regexes that get_capacity will look for as self.channel_capacity @@ -120,7 +133,7 @@ def compile_capacities(self, channel_capacity): result.append((re.compile(fnmatch.translate(pattern)), value)) return result - def get_capacity(self, channel): + def get_capacity(self, channel: str) -> Optional[int]: """ Gets the correct capacity for the given channel; either the default, or a matching result from channel_capacity. Returns the first matching @@ -132,7 +145,7 @@ def get_capacity(self, channel): return capacity return self.capacity - def match_type_and_length(self, name): + def match_type_and_length(self, name) -> bool: if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH): return True return False @@ -148,7 +161,7 @@ def match_type_and_length(self, name): + "not {}" ) - def valid_channel_name(self, name, receive=False): + def valid_channel_name(self, name: str, receive=False) -> bool: if self.match_type_and_length(name): if bool(self.channel_name_regex.match(name)): # Check cases for special channels @@ -159,13 +172,13 @@ def valid_channel_name(self, name, receive=False): return True raise TypeError(self.invalid_name_error.format("Channel", name)) - def valid_group_name(self, name): + def valid_group_name(self, name: str) -> bool: if self.match_type_and_length(name): if bool(self.group_name_regex.match(name)): return True raise TypeError(self.invalid_name_error.format("Group", name)) - def valid_channel_names(self, names, receive=False): + def valid_channel_names(self, names: List[str], receive=False) -> bool: _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" @@ -175,7 +188,7 @@ def valid_channel_names(self, names, receive=False): ) return True - def non_local_name(self, name): + def non_local_name(self, name: str) -> str: """ Given a channel name, returns the "non-local" part. If the channel name is a process-specific channel (contains !) this means the part up to @@ -186,6 +199,49 @@ def non_local_name(self, name): else: return name + async def send(self, channel: str, message: dict): + """ + Send a message onto a (general or specific) channel. + """ + raise NotImplementedError() + + async def receive(self, channel: str) -> dict: + """ + Receive the first message that arrives on the channel. + If more than one coroutine waits on the same channel, a random one + of the waiting coroutines will get the result. + """ + raise NotImplementedError() + + async def new_channel(self, prefix: str = "specific.") -> str: + """ + Returns a new channel name that can be used by something in our + process as a specific channel. + """ + raise NotImplementedError() + + # Flush extension + + async def flush(self): + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + + # Groups extension + + async def group_add(self, group: str, channel: str): + """ + Adds the channel name to a group. + """ + raise NotImplementedError() + + async def group_discard(self, group: str, channel: str): + raise NotImplementedError() + + async def group_send(self, group: str, message: dict): + raise NotImplementedError() + class InMemoryChannelLayer(BaseChannelLayer): """ @@ -198,13 +254,13 @@ def __init__( group_expiry=86400, capacity=100, channel_capacity=None, - **kwargs + **kwargs, ): super().__init__( expiry=expiry, capacity=capacity, channel_capacity=channel_capacity, - **kwargs + **kwargs, ) self.channels = {} self.groups = {} @@ -215,9 +271,6 @@ def __init__( extensions = ["groups", "flush"] async def send(self, channel, message): - """ - Send a message onto a (general or specific) channel. - """ # Typecheck assert isinstance(message, dict), "message is not a dict" assert self.valid_channel_name(channel), "Channel name not valid" @@ -234,11 +287,6 @@ async def send(self, channel, message): await queue.put((time.time() + self.expiry, deepcopy(message))) async def receive(self, channel): - """ - Receive the first message that arrives on the channel. - If more than one coroutine waits on the same channel, a random one - of the waiting coroutines will get the result. - """ assert self.valid_channel_name(channel) self._clean_expired() @@ -254,10 +302,6 @@ async def receive(self, channel): return message async def new_channel(self, prefix="specific."): - """ - Returns a new channel name that can be used by something in our - process as a specific channel. - """ return "%s.inmemory!%s" % ( prefix, "".join(random.choice(string.ascii_letters) for i in range(12)), @@ -314,9 +358,6 @@ def _remove_from_groups(self, channel): # Groups extension async def group_add(self, group, channel): - """ - Adds the channel name to a group. - """ # Check the inputs assert self.valid_group_name(group), "Group name not valid" assert self.valid_channel_name(channel), "Channel name not valid" @@ -349,7 +390,7 @@ async def group_send(self, group, message): pass -def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER): +def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]: """ Returns a channel layer by alias, or None if it is not configured. """