Skip to content

Commit 7779030

Browse files
David Perld-perl
authored andcommitted
refactor: separate message serialization and encoding
1 parent ce76ee2 commit 7779030

File tree

7 files changed

+62
-8
lines changed

7 files changed

+62
-8
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from pydantic import BaseModel, ConfigDict, computed_field
2+
3+
4+
class BecCodecInfo(BaseModel):
5+
type_name: str
6+
7+
8+
class BECSerializable(BaseModel):
9+
"""A base class for serializable BEC objects, especially BEC messages.
10+
Fields in subclasses which use non-primitive types must be in structured,
11+
type-hinted objects, and their encoders and JSON schema should be defined in
12+
this class."""
13+
14+
model_config = ConfigDict(
15+
json_schema_serialization_defaults_required=True,
16+
arbitrary_types_allowed=True,
17+
extra="forbid",
18+
)
19+
20+
@computed_field()
21+
@property
22+
def bec_codec(self) -> BecCodecInfo:
23+
return BecCodecInfo(type_name=self.__class__.__name__)

bec_lib/bec_lib/devicemanager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,9 +667,18 @@ def _get_redis_device_config(self) -> list:
667667

668668
def _add_multiple_devices_with_log(self, devices: Iterable[tuple[dict, DeviceInfoMessage]]):
669669
logs = (self._add_device(*conf_msg) for conf_msg in devices if conf_msg is not None)
670-
logger.info(f"Adding new devices:\n" + ", ".join(f"{name}: {t}" for name, t in logs)) # type: ignore # filtered
670+
if set(logs) == {None}:
671+
logger.warning("No devices added!")
672+
return
673+
logger.info(
674+
f"Adding new devices:\n"
675+
+ ", ".join(f"{log[0]}: {log[1]}" for log in logs if log is not None)
676+
)
671677

672678
def _add_device(self, dev: dict, msg: DeviceInfoMessage) -> tuple[str, str] | None:
679+
if msg is None:
680+
logger.error(f"No device info in Redis for: {dev}")
681+
return None
673682
name = msg.content["device"]
674683
info = msg.content["info"]
675684

bec_lib/bec_lib/messages.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from importlib.metadata import PackageNotFoundError
1010
from importlib.metadata import version as importlib_version
1111
from types import NoneType
12-
from typing import Annotated, Any, ClassVar, Literal, Self, Union
12+
from typing import Annotated, Any, ClassVar, Literal, Self, TypeVar, Union
1313
from uuid import uuid4
1414

1515
import msgpack
@@ -26,6 +26,7 @@
2626
)
2727
from typing_extensions import TypeAliasType
2828

29+
from bec_lib.bec_serializable import BECSerializable
2930
from bec_lib.metadata_schema import get_metadata_schema_for_scan
3031
from bec_lib.one_way_registry import OneWaySerializationRegistry
3132

@@ -113,7 +114,7 @@ class BECStatus(Enum):
113114
ERROR = -1
114115

115116

116-
class BECMessage(BaseModel):
117+
class BECMessage(BECSerializable):
117118
"""Base Model class for BEC Messages
118119
119120
Args:
@@ -122,7 +123,6 @@ class BECMessage(BaseModel):
122123
123124
"""
124125

125-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
126126
msg_type: ClassVar[str]
127127
metadata: JsonableDict = Field(default_factory=dict)
128128

@@ -1307,16 +1307,19 @@ class DAPResponseMessage(BECMessage):
13071307
dap_request: BECMessage | None = Field(default=None)
13081308

13091309

1310+
MessageType = TypeVar("MessageType", bound=BECMessage)
1311+
1312+
13101313
class AvailableResourceMessage(BECMessage):
13111314
"""Message for available resources such as scans, data processing plugins etc
13121315
13131316
Args:
1314-
resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description
1317+
resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description - may contain only one type of BECMessage
13151318
metadata (dict, optional): Metadata. Defaults to None.
13161319
"""
13171320

13181321
msg_type: ClassVar[str] = "available_resource_message"
1319-
resource: JsonableDict | list[JsonableDict] | BECMessage | list[BECMessage]
1322+
resource: JsonableDict | list[JsonableDict] | MessageType | list[MessageType]
13201323

13211324

13221325
class ProgressMessage(BECMessage):

bec_lib/bec_lib/serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class BECMessagePack(SerializationRegistry):
3636

3737
def dumps(self, obj):
3838
"""Pack object `obj` and return packed bytes."""
39+
if isinstance(obj, BECMessage):
40+
obj = obj.model_dump(mode="python", fallback=self.encode)
3941
return msgpack_module.packb(obj, default=self.encode)
4042

4143
def loads(self, raw_bytes):

bec_lib/bec_lib/serialization_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Callable, Type
55

66
from bec_lib import codecs as bec_codecs
7+
from bec_lib import messages
78
from bec_lib.logger import bec_logger
89

910
logger = bec_logger.logger
@@ -18,7 +19,6 @@ def __init__(self):
1819
self._registry: dict[str, tuple[Type, Callable, Callable]] = {}
1920
self._legacy_codecs = [] # can be removed in future versions, see issue #516
2021

21-
self.register_codec(bec_codecs.BECMessageEncoder)
2222
self.register_codec(bec_codecs.EndpointInfoEncoder)
2323
self.register_codec(bec_codecs.SetEncoder)
2424
self.register_codec(bec_codecs.BECTypeEncoder)
@@ -97,6 +97,11 @@ def encode(self, obj):
9797

9898
def decode(self, data):
9999
"""Decode an object using the registered codec."""
100+
if isinstance(data, dict) and "bec_codec" in data:
101+
codec_info = data.pop("bec_codec")
102+
msg_cls = messages.__dict__.get(codec_info.get("type_name"))
103+
if msg_cls is not None:
104+
return msg_cls.model_validate(data)
100105
if not isinstance(data, dict) or "__bec_codec__" not in data:
101106
return data
102107
codec_info = data["__bec_codec__"]

bec_lib/tests/test_bec_messages.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_bec_message_msgpack_serialization_version(version):
1717
assert "Unsupported BECMessage version" in str(exception.value)
1818
else:
1919
res = MsgpackSerialization.dumps(msg)
20-
res_expected = b"\x81\xad__bec_codec__\x83\xacencoder_name\xaaBECMessage\xa9type_name\xb8DeviceInstructionMessage\xa4data\x84\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00"
20+
res_expected = b"\x85\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00\xa9bec_codec\x81\xa9type_name\xb8DeviceInstructionMessage"
2121
assert res == res_expected
2222
res_loaded = MsgpackSerialization.loads(res)
2323
assert res_loaded == msg
@@ -682,3 +682,14 @@ def test_message_with_np_array_in_dict():
682682
arr = np.zeros(5)
683683
msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={})
684684
assert isinstance(msg.data["device"]["value"], np.ndarray)
685+
686+
687+
def test_message_service_config():
688+
msg = messages.MessagingServiceConfig(
689+
metadata={}, service_name="signal", scopes=["*"], enabled=True
690+
)
691+
dump = msg.model_dump(mode="python")
692+
assert dump["service_name"] == "signal"
693+
resource_msg = messages.AvailableResourceMessage(resource=[msg])
694+
resource_msg_dump = resource_msg.model_dump(mode="python")
695+
assert resource_msg_dump["resource"][0]["service_name"] == "signal"

bec_server/bec_server/procedures/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _log_on_end(future: Future):
5050

5151
def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T:
5252
if isinstance(msg, dict):
53+
msg.pop("bec_codec", None)
5354
return MsgType.model_validate(msg)
5455
return msg
5556

0 commit comments

Comments
 (0)