Skip to content

Commit

Permalink
Add serialization registry
Browse files Browse the repository at this point in the history
  • Loading branch information
sevdog committed Sep 3, 2024
1 parent 13cef45 commit 162cc74
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 85 deletions.
46 changes: 46 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.

.. _encryption
``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -237,6 +238,51 @@ And then in your channels consumer, you can implement the handler:
async def redis_disconnect(self, *args):
# Handle disconnect
``serializer_format``
~~~~~~~~~~~~~~~~~~~~~~
By default every message sent to redis is encoded using `msgpack <https://msgpack.org/>`_ (_currently ``msgpack`` is a mandatory dependency of this package, it may become optional in a future release_).
It is also possible to switch to `JSON <http://www.json.org/>`_:

.. code-block:: python
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": ["redis://:[email protected]:6379/0"],
"serializer_format": "json",
},
},
}
Custom serializer can be defined by:

- extending ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps`` and ``loads`` methods
- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods

Then it may be registerd (or can be overriden) by using ``channels_redis.serializers.registry``:

.. code-block:: python
from channels_redis.serializers import registry
class MyFormatSerializer:
def serialize(self, message):
...
def deserialize(self, message):
...
registry.register_serializer('myformat', MyFormatSerializer)
**NOTE**: the registry allows to override the serializer class used for a specific format without any particular check nor constraint, thus it is recommended to pay attention with order-of-imports when using third-party serializers which may override a built-in format.


Serializers are also responsible for encryption *symmetric_encryption_keys*. When extending ``channels_redis.serializers.BaseMessageSerializer`` encryption is already configured in the base class, unless you override ``serialize``/``deserialize`` methods: in this case you should call ``self.crypter.encrypt`` in serialization and ``self.crypter.decrypt`` in deserialization process. When using full custom serializer expect an optional sequence of keys to be passed via ``symmetric_encryption_keys``.


Dependencies
------------

Expand Down
61 changes: 13 additions & 48 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import asyncio
import base64
import collections
import functools
import hashlib
import itertools
import logging
import random
import time
import uuid

import msgpack
from redis import asyncio as aioredis

from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
Expand Down Expand Up @@ -115,6 +112,8 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
random_prefix_length=12,
serializer_format="msgpack",
):
# Store basic information
self.expiry = expiry
Expand All @@ -126,15 +125,21 @@ def __init__(
# Configure the host objects
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
# serialization
self._serializer = registry.get_serializer(
serializer_format,
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix_length=random_prefix_length,
expiry=self.expiry,
symmetric_encryption_keys=symmetric_encryption_keys,
)
# Cached redis connection pools and the event loop they are from
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
# Decide on a unique client prefix to use in ! sections
self.client_prefix = uuid.uuid4().hex
# Set up any encryption objects
self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
Expand All @@ -154,24 +159,6 @@ def __init__(
def create_pool(self, index):
return create_pool(self.hosts[index])

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
if symmetric_encryption_keys:
if isinstance(symmetric_encryption_keys, (str, bytes)):
raise ValueError(
"symmetric_encryption_keys must be a list of possible keys"
)
try:
from cryptography.fernet import MultiFernet
except ImportError:
raise ValueError(
"Cannot run with encryption without 'cryptography' installed."
)
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
self.crypter = MultiFernet(sub_fernets)
else:
self.crypter = None

### Channel layer API ###

extensions = ["groups", "flush"]
Expand Down Expand Up @@ -656,41 +643,19 @@ def serialize(self, message):
"""
Serializes message to a byte string.
"""
value = msgpack.packb(message, use_bin_type=True)
if self.crypter:
value = self.crypter.encrypt(value)

# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
return random_prefix + value
return self._serializer.serialize(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
# Removes the random prefix
message = message[12:]

if self.crypter:
message = self.crypter.decrypt(message, self.expiry + 10)
return msgpack.unpackb(message, raw=False)
return self._serializer.deserialize(message)

### Internal functions ###

def consistent_hash(self, value):
return _consistent_hash(value, self.ring_size)

def make_fernet(self, key):
"""
Given a single encryption key, returns a Fernet instance using it.
"""
from cryptography.fernet import Fernet

if isinstance(key, str):
key = key.encode("utf8")
formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
return Fernet(formatted_key)

def __str__(self):
return f"{self.__class__.__name__}(hosts={self.hosts})"

Expand Down
39 changes: 25 additions & 14 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import logging
import uuid

import msgpack
from redis import asyncio as aioredis

from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
Expand All @@ -25,10 +25,33 @@ async def _async_proxy(obj, name, *args, **kwargs):


class RedisPubSubChannelLayer:
def __init__(self, *args, **kwargs) -> None:
def __init__(
self,
*args,
symmetric_encryption_keys=None,
serializer_format="msgpack",
**kwargs,
) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}
# serialization
self._serializer = registry.get_serializer(
serializer_format,
symmetric_encryption_keys=symmetric_encryption_keys,
)

def serialize(self, message):
"""
Serializes message to a byte string.
"""
return self._serializer.serialize(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
return self._serializer.deserialize(message)

def __getattr__(self, name):
if name in (
Expand All @@ -44,18 +67,6 @@ def __getattr__(self, name):
else:
return getattr(self._get_layer(), name)

def serialize(self, message):
"""
Serializes message to a byte string.
"""
return msgpack.packb(message)

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
return msgpack.unpackb(message)

def _get_layer(self):
loop = asyncio.get_running_loop()

Expand Down
Loading

0 comments on commit 162cc74

Please sign in to comment.