Skip to content

Commit b0173a2

Browse files
committed
Add serialization registry
1 parent 13cef45 commit b0173a2

File tree

5 files changed

+255
-61
lines changed

5 files changed

+255
-61
lines changed

README.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
171171
If you want to enforce a matching order, use an ``OrderedDict`` as the
172172
argument; channels will then be matched in the order the dict provides them.
173173

174+
.. _encryption
174175
``symmetric_encryption_keys``
175176
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
176177

@@ -237,6 +238,44 @@ And then in your channels consumer, you can implement the handler:
237238
async def redis_disconnect(self, *args):
238239
# Handle disconnect
239240
241+
242+
243+
``serializer_format``
244+
~~~~~~~~~~~~~~~~~~~~~~
245+
By default every message which reach redis is encoded using `msgpack <https://msgpack.org/>`_.
246+
It is also possible to switch to `JSON <http://www.json.org/>`_:
247+
248+
.. code-block:: python
249+
250+
CHANNEL_LAYERS = {
251+
"default": {
252+
"BACKEND": "channels_redis.core.RedisChannelLayer",
253+
"CONFIG": {
254+
"hosts": ["redis://:[email protected]:6379/0"],
255+
"serializer_format": "json",
256+
},
257+
},
258+
}
259+
260+
A new serializer may be registered (or can be overriden) by using ``channels_redis.serializers.registry``,
261+
providing a class which extends ``channels_redis.serializers.BaseMessageSerializer``, implementing ``dumps``
262+
and ``loads`` methods, or which provides ``serialize``/``deserialize`` methods and calling the registration method on registry:
263+
264+
.. code-block:: python
265+
266+
from channels_redis.serializers import registry
267+
268+
class MyFormatSerializer:
269+
def serialize(self, message):
270+
...
271+
def deserialize(self, message):
272+
...
273+
274+
registry.register_serializer('myformat', MyFormatSerializer)
275+
276+
**NOTE**: Serializers also perform the encryption job see *symmetric_encryption_keys*.
277+
278+
240279
Dependencies
241280
------------
242281

channels_redis/core.py

Lines changed: 12 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55
import hashlib
66
import itertools
77
import logging
8-
import random
98
import time
109
import uuid
1110

12-
import msgpack
1311
from redis import asyncio as aioredis
1412

1513
from channels.exceptions import ChannelFull
1614
from channels.layers import BaseChannelLayer
1715

16+
from .serializers import registry
1817
from .utils import (
1918
_close_redis,
2019
_consistent_hash,
@@ -115,6 +114,7 @@ def __init__(
115114
capacity=100,
116115
channel_capacity=None,
117116
symmetric_encryption_keys=None,
117+
serializer_format="msgpack",
118118
):
119119
# Store basic information
120120
self.expiry = expiry
@@ -126,15 +126,23 @@ def __init__(
126126
# Configure the host objects
127127
self.hosts = decode_hosts(hosts)
128128
self.ring_size = len(self.hosts)
129+
# serialization
130+
self._serializer = registry.get_serializer(
131+
serializer_format,
132+
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
133+
random_prefix_length=12,
134+
expiry=self.expiry,
135+
symmetric_encryption_keys=symmetric_encryption_keys,
136+
)
137+
self.serialize = self._serializer.serialize
138+
self.deserialize = self._serializer.deserialize
129139
# Cached redis connection pools and the event loop they are from
130140
self._layers = {}
131141
# Normal channels choose a host index by cycling through the available hosts
132142
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
133143
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
134144
# Decide on a unique client prefix to use in ! sections
135145
self.client_prefix = uuid.uuid4().hex
136-
# Set up any encryption objects
137-
self._setup_encryption(symmetric_encryption_keys)
138146
# Number of coroutines trying to receive right now
139147
self.receive_count = 0
140148
# The receive lock
@@ -154,24 +162,6 @@ def __init__(
154162
def create_pool(self, index):
155163
return create_pool(self.hosts[index])
156164

157-
def _setup_encryption(self, symmetric_encryption_keys):
158-
# See if we can do encryption if they asked
159-
if symmetric_encryption_keys:
160-
if isinstance(symmetric_encryption_keys, (str, bytes)):
161-
raise ValueError(
162-
"symmetric_encryption_keys must be a list of possible keys"
163-
)
164-
try:
165-
from cryptography.fernet import MultiFernet
166-
except ImportError:
167-
raise ValueError(
168-
"Cannot run with encryption without 'cryptography' installed."
169-
)
170-
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
171-
self.crypter = MultiFernet(sub_fernets)
172-
else:
173-
self.crypter = None
174-
175165
### Channel layer API ###
176166

177167
extensions = ["groups", "flush"]
@@ -650,31 +640,6 @@ def _group_key(self, group):
650640
"""
651641
return f"{self.prefix}:group:{group}".encode("utf8")
652642

653-
### Serialization ###
654-
655-
def serialize(self, message):
656-
"""
657-
Serializes message to a byte string.
658-
"""
659-
value = msgpack.packb(message, use_bin_type=True)
660-
if self.crypter:
661-
value = self.crypter.encrypt(value)
662-
663-
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
664-
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
665-
return random_prefix + value
666-
667-
def deserialize(self, message):
668-
"""
669-
Deserializes from a byte string.
670-
"""
671-
# Removes the random prefix
672-
message = message[12:]
673-
674-
if self.crypter:
675-
message = self.crypter.decrypt(message, self.expiry + 10)
676-
return msgpack.unpackb(message, raw=False)
677-
678643
### Internal functions ###
679644

680645
def consistent_hash(self, value):

channels_redis/pubsub.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import logging
44
import uuid
55

6-
import msgpack
76
from redis import asyncio as aioredis
87

8+
from .serializers import registry
99
from .utils import (
1010
_close_redis,
1111
_consistent_hash,
@@ -25,10 +25,23 @@ async def _async_proxy(obj, name, *args, **kwargs):
2525

2626

2727
class RedisPubSubChannelLayer:
28-
def __init__(self, *args, **kwargs) -> None:
28+
def __init__(
29+
self,
30+
*args,
31+
symmetric_encryption_keys=None,
32+
serializer_format="msgpack",
33+
**kwargs,
34+
) -> None:
2935
self._args = args
3036
self._kwargs = kwargs
3137
self._layers = {}
38+
# serialization
39+
self._serializer = registry.get_serializer(
40+
serializer_format,
41+
symmetric_encryption_keys=symmetric_encryption_keys,
42+
)
43+
self.serialize = self._serializer.serialize
44+
self.deserialize = self._serializer.deserialize
3245

3346
def __getattr__(self, name):
3447
if name in (
@@ -44,18 +57,6 @@ def __getattr__(self, name):
4457
else:
4558
return getattr(self._get_layer(), name)
4659

47-
def serialize(self, message):
48-
"""
49-
Serializes message to a byte string.
50-
"""
51-
return msgpack.packb(message)
52-
53-
def deserialize(self, message):
54-
"""
55-
Deserializes from a byte string.
56-
"""
57-
return msgpack.unpackb(message)
58-
5960
def _get_layer(self):
6061
loop = asyncio.get_running_loop()
6162

channels_redis/serializers.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import json
2+
import random
3+
import abc
4+
5+
6+
class SerializerDoesNotExist(KeyError):
7+
"""The requested serializer was not found."""
8+
9+
10+
class BaseMessageSerializer(abc.ABC):
11+
12+
def __init__(
13+
self,
14+
symmetric_encryption_keys=None,
15+
random_prefix_length=0,
16+
expiry=None,
17+
):
18+
self.random_prefix_length = random_prefix_length
19+
self.expiry = expiry
20+
# Set up any encryption objects
21+
self._setup_encryption(symmetric_encryption_keys)
22+
23+
def _setup_encryption(self, symmetric_encryption_keys):
24+
# See if we can do encryption if they asked
25+
if symmetric_encryption_keys:
26+
if isinstance(symmetric_encryption_keys, (str, bytes)):
27+
raise ValueError(
28+
"symmetric_encryption_keys must be a list of possible keys"
29+
)
30+
try:
31+
from cryptography.fernet import MultiFernet
32+
except ImportError:
33+
raise ValueError(
34+
"Cannot run with encryption without 'cryptography' installed."
35+
)
36+
sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
37+
self.crypter = MultiFernet(sub_fernets)
38+
else:
39+
self.crypter = None
40+
41+
@abc.abstractmethod
42+
def dumps(self, message):
43+
raise NotImplementedError
44+
45+
@abc.abstractmethod
46+
def loads(self, message):
47+
raise NotImplementedError
48+
49+
def serialize(self, message):
50+
"""
51+
Serializes message to a byte string.
52+
"""
53+
message = self.dumps(message)
54+
# ensure message is bytes
55+
if isinstance(message, str):
56+
message = message.encode("utf-8")
57+
if self.crypter:
58+
message = self.crypter.encrypt(message)
59+
60+
if self.random_prefix_length > 0:
61+
# provide random prefix
62+
message = (
63+
random.getrandbits(8 * self.random_prefix_length).to_bytes(
64+
self.random_prefix_length, "big"
65+
)
66+
+ message
67+
)
68+
return message
69+
70+
def deserialize(self, message):
71+
"""
72+
Deserializes from a byte string.
73+
"""
74+
if self.random_prefix_length > 0:
75+
# Removes the random prefix
76+
message = message[self.random_prefix_length :] # noqa: E203
77+
78+
if self.crypter:
79+
ttl = self.expiry if self.expiry is None else self.expiry + 10
80+
message = self.crypter.decrypt(message, ttl)
81+
return self.loads(message)
82+
83+
84+
class MissingSerializer(BaseMessageSerializer):
85+
exception = None
86+
87+
def __init__(self, *args, **kwargs):
88+
raise self.exception
89+
90+
91+
class JSONSerializer(BaseMessageSerializer):
92+
dumps = staticmethod(json.dumps)
93+
loads = staticmethod(json.loads)
94+
95+
96+
# code ready for a future in which msgpack may become an optional dependency
97+
try:
98+
import msgpack
99+
except ImportError as exc:
100+
101+
class MsgPackSerializer(MissingSerializer):
102+
exception = exc
103+
104+
else:
105+
106+
class MsgPackSerializer(BaseMessageSerializer):
107+
dumps = staticmethod(msgpack.packb)
108+
loads = staticmethod(msgpack.unpackb)
109+
110+
111+
class SerializersRegistry:
112+
def __init__(self):
113+
self._registry = {}
114+
115+
def register_serializer(self, format, serializer_class):
116+
"""
117+
Register a new serializer for given format
118+
"""
119+
assert isinstance(serializer_class, type) and (
120+
issubclass(serializer_class, BaseMessageSerializer)
121+
or hasattr(serializer_class, "serialize")
122+
and hasattr(serializer_class, "deserialize")
123+
), """
124+
`serializer_class` should be a class which implements `serialize` and `deserialize` method
125+
or a subclass of `channels_redis.serializers.BaseMessageSerializer`
126+
"""
127+
128+
self._registry[format] = serializer_class
129+
130+
def get_serializer(self, format, *args, **kwargs):
131+
try:
132+
serializer_class = self._registry[format]
133+
except KeyError:
134+
raise SerializerDoesNotExist(format)
135+
136+
return serializer_class(*args, **kwargs)
137+
138+
139+
registry = SerializersRegistry()
140+
registry.register_serializer("json", JSONSerializer)
141+
registry.register_serializer("msgpack", MsgPackSerializer)

0 commit comments

Comments
 (0)