Skip to content

Commit

Permalink
feature: improve layer typings, add method stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Jun 27, 2023
1 parent 0933260 commit 858ce88
Showing 1 changed file with 71 additions and 30 deletions.
101 changes: 71 additions & 30 deletions channels/layers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import asyncio
import fnmatch
import random
import re
import string
import time
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Tuple

from django.conf import settings
from django.core.signals import setting_changed
Expand All @@ -20,6 +23,8 @@ class ChannelLayerManager:
Takes a settings dictionary of backends and initialises them on request.
"""

backends: Dict[str, BaseChannelLayer]

def __init__(self):
self.backends = {}
setting_changed.connect(self._reset_backends)
Expand All @@ -36,14 +41,14 @@ def configs(self):
# Lazy load settings so we can be imported
return getattr(settings, "CHANNEL_LAYERS", {})

def make_backend(self, name):
def make_backend(self, name) -> BaseChannelLayer:
"""
Instantiate channel layer.
"""
config = self.configs[name].get("CONFIG", {})
return self._make_backend(name, config)

def make_test_backend(self, name):
def make_test_backend(self, name) -> BaseChannelLayer:
"""
Instantiate channel layer using its test config.
"""
Expand All @@ -53,7 +58,7 @@ def make_test_backend(self, name):
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
return self._make_backend(name, config)

def _make_backend(self, name, config):
def _make_backend(self, name, config) -> BaseChannelLayer:
# Check for old format config
if "ROUTING" in self.configs[name]:
raise InvalidChannelLayerError(
Expand Down Expand Up @@ -81,7 +86,7 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.configs

def set(self, key, layer):
def set(self, key: str, layer: BaseChannelLayer):
"""
Sets an alias to point to a new ChannelLayerWrapper instance, and
returns the old one that it replaced. Useful for swapping out the
Expand All @@ -99,13 +104,21 @@ class BaseChannelLayer:
"""

MAX_NAME_LENGTH = 100
extensions: Iterable[str] = ()

def __init__(self, expiry=60, capacity=100, channel_capacity=None):
def __init__(
self,
expiry=60,
capacity: Optional[int] = 100,
channel_capacity: Optional[int] = None,
):
self.expiry = expiry
self.capacity = capacity
self.channel_capacity = channel_capacity or {}

def compile_capacities(self, channel_capacity):
def compile_capacities(
self, channel_capacity
) -> List[Tuple[re.Pattern, Optional[int]]]:
"""
Takes an input channel_capacity dict and returns the compiled list
of regexes that get_capacity will look for as self.channel_capacity
Expand All @@ -120,7 +133,7 @@ def compile_capacities(self, channel_capacity):
result.append((re.compile(fnmatch.translate(pattern)), value))
return result

def get_capacity(self, channel):
def get_capacity(self, channel: str) -> Optional[int]:
"""
Gets the correct capacity for the given channel; either the default,
or a matching result from channel_capacity. Returns the first matching
Expand All @@ -132,7 +145,7 @@ def get_capacity(self, channel):
return capacity
return self.capacity

def match_type_and_length(self, name):
def match_type_and_length(self, name) -> bool:
if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
return True
return False
Expand All @@ -148,7 +161,7 @@ def match_type_and_length(self, name):
+ "not {}"
)

def valid_channel_name(self, name, receive=False):
def valid_channel_name(self, name: str, receive=False) -> bool:
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
Expand All @@ -159,13 +172,13 @@ def valid_channel_name(self, name, receive=False):
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
def valid_group_name(self, name: str) -> bool:
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))

def valid_channel_names(self, names, receive=False):
def valid_channel_names(self, names, receive=False) -> bool:
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"
Expand All @@ -175,7 +188,7 @@ def valid_channel_names(self, names, receive=False):
)
return True

def non_local_name(self, name):
def non_local_name(self, name: str) -> str:
"""
Given a channel name, returns the "non-local" part. If the channel name
is a process-specific channel (contains !) this means the part up to
Expand All @@ -186,6 +199,49 @@ def non_local_name(self, name):
else:
return name

async def send(self, channel: str, message: dict):
"""
Send a message onto a (general or specific) channel.
"""
raise NotImplementedError()

async def receive(self, channel: str) -> dict:
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
raise NotImplementedError()

async def new_channel(self, prefix: str = "specific.") -> str:
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
raise NotImplementedError()

# Flush extension

async def flush(self):
raise NotImplementedError()

async def close(self):
raise NotImplementedError()

# Groups extension

async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""
raise NotImplementedError()

async def group_discard(self, group: str, channel: str):
raise NotImplementedError()

async def group_send(self, group: str, message: dict):
raise NotImplementedError()


class InMemoryChannelLayer(BaseChannelLayer):
"""
Expand All @@ -198,13 +254,13 @@ def __init__(
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
**kwargs,
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
**kwargs,
)
self.channels = {}
self.groups = {}
Expand All @@ -215,9 +271,6 @@ def __init__(
extensions = ["groups", "flush"]

async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
Expand All @@ -234,11 +287,6 @@ async def send(self, channel, message):
await queue.put((time.time() + self.expiry, deepcopy(message)))

async def receive(self, channel):
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self._clean_expired()

Expand All @@ -254,10 +302,6 @@ async def receive(self, channel):
return message

async def new_channel(self, prefix="specific."):
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
return "%s.inmemory!%s" % (
prefix,
"".join(random.choice(string.ascii_letters) for i in range(12)),
Expand Down Expand Up @@ -314,9 +358,6 @@ def _remove_from_groups(self, channel):
# Groups extension

async def group_add(self, group, channel):
"""
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
Expand Down Expand Up @@ -349,7 +390,7 @@ async def group_send(self, group, message):
pass


def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]:
"""
Returns a channel layer by alias, or None if it is not configured.
"""
Expand Down

0 comments on commit 858ce88

Please sign in to comment.