From 4c12c3bc557c597c5b467c128d6113e3aa86fb99 Mon Sep 17 00:00:00 2001 From: Jennings Zhang Date: Tue, 10 Sep 2024 17:02:17 -0400 Subject: [PATCH] Implement LONK-WS and test --- chris_backend/pacsfiles/consumers.py | 54 ++++- chris_backend/pacsfiles/lonk.py | 209 ++++++++++++++++++ chris_backend/pacsfiles/tests/mocks.py | 47 ++++ .../pacsfiles/tests/test_consumers.py | 89 +++++++- 4 files changed, 383 insertions(+), 16 deletions(-) create mode 100644 chris_backend/pacsfiles/lonk.py create mode 100644 chris_backend/pacsfiles/tests/mocks.py diff --git a/chris_backend/pacsfiles/consumers.py b/chris_backend/pacsfiles/consumers.py index 7acfec40..7ce50662 100755 --- a/chris_backend/pacsfiles/consumers.py +++ b/chris_backend/pacsfiles/consumers.py @@ -1,8 +1,15 @@ import rest_framework.permissions from channels.db import database_sync_to_async from channels.generic.websocket import AsyncJsonWebsocketConsumer +from django.conf import settings from rest_framework import permissions +from pacsfiles.lonk import ( + LonkClient, + validate_subscription, + LonkWsSubscription, + Lonk, +) from pacsfiles.permissions import IsChrisOrIsPACSUserReadOnly @@ -11,19 +18,54 @@ class PACSFileProgress(AsyncJsonWebsocketConsumer): A WebSockets endpoint which relays progress messages from NATS sent by *oxidicom* to a client. """ - permission_classes = (permissions.IsAuthenticated, IsChrisOrIsPACSUserReadOnly,) + permission_classes = ( + permissions.IsAuthenticated, + IsChrisOrIsPACSUserReadOnly, + ) async def connect(self): if not await self._has_permission(): - await self.close() - else: - await self.accept() + return await self.close() + self.client: LonkClient = await LonkClient.connect( + settings.NATS_ADDRESS + ) + await self.accept() async def receive_json(self, content, **kwargs): - ... + if validate_subscription(content): + await self._subscribe( + content['pacs_name'], content['SeriesInstanceUID'] + ) + return + await self.close(code=400, reason='Invalid subscription') + + async def _subscribe(self, pacs_name: str, series_instance_uid: str): + """ + Subscribe to progress notifications about the reception of a DICOM series. + """ + try: + await self.client.subscribe( + pacs_name, series_instance_uid, lambda msg: self.send_json(msg) + ) + response = Lonk( + pacs_name=pacs_name, + SeriesInstanceUID=series_instance_uid, + message=LonkWsSubscription(subscription='subscribed'), + ) + await self.send_json(response) + except Exception as e: + response = Lonk( + pacs_name=pacs_name, + SeriesInstanceUID=series_instance_uid, + message=LonkWsSubscription(subscription='error'), + ) + await self.send_json(response) + await self.close(code=500) + raise e async def disconnect(self, code): - ... + await super().disconnect(code) + await self.client.close() @database_sync_to_async def _has_permission(self) -> bool: diff --git a/chris_backend/pacsfiles/lonk.py b/chris_backend/pacsfiles/lonk.py new file mode 100644 index 00000000..b4dc51bd --- /dev/null +++ b/chris_backend/pacsfiles/lonk.py @@ -0,0 +1,209 @@ +""" +Implementation of the "Light Oxidicom NotifiKations Encoding" + +See https://chrisproject.org/docs/oxidicom/lonk +""" +import asyncio +import enum +from sys import byteorder +from typing import ( + Self, + Callable, + TypedDict, + Literal, + TypeGuard, + Any, + Awaitable, +) + +import nats +from nats import NATS +from nats.aio.subscription import Subscription +from nats.aio.msg import Msg + + +class SubscriptionRequest(TypedDict): + """ + A request to subscribe to LONK notifications about a DICOM series. + """ + + pacs_name: str + SeriesInstanceUID: str + action: Literal['subscribe'] + + +def validate_subscription(data: Any) -> TypeGuard[SubscriptionRequest]: + if not isinstance(data, dict): + return False + return ( + data.get('action', None) == 'subscribe' + and isinstance(data.get('SeriesInstanceUID', None), str) + and isinstance(data.get('pacs_name', None), str) + ) + + +class LonkProgress(TypedDict): + """ + LONK "done" message. + + https://chrisproject.org/docs/oxidicom/lonk#lonk-message-encoding + """ + + ndicom: int + + +class LonkError(TypedDict): + """ + LONK "error" message. + + https://chrisproject.org/docs/oxidicom/lonk#lonk-message-encoding + """ + + error: str + + +class LonkDone(TypedDict): + """ + LONK "done" message. + + https://chrisproject.org/docs/oxidicom/lonk#lonk-message-encoding + """ + + done: bool + + +class LonkWsSubscription(TypedDict): + """ + LONK-WS "subscribed" message. + + https://chrisproject.org/docs/oxidicom/lonk-ws#lonk-ws-subscription + """ + + subscription: Literal['subscribed', 'error'] + + +LonkMessageData = LonkProgress | LonkError | LonkDone | LonkWsSubscription +""" +Lonk message data. + +https://chrisproject.org/docs/oxidicom/lonk-ws#messages +""" + + +class Lonk(TypedDict): + """ + Serialized LONK message about a DICOM series. + + https://chrisproject.org/docs/oxidicom#lonk-message-encoding + """ + + SeriesInstanceUID: str + pacs_name: str + message: LonkMessageData + + +class LonkClient: + """ + "Light Oxidicom NotifiKations Encoding" client: + A client for the messages sent by *oxidicom* over NATS. + + https://chrisproject.org/docs/oxidicom/lonk + """ + + def __init__(self, nc: NATS): + self._nc = nc + self._subscriptions: list[Subscription] = [] + + @classmethod + async def connect(cls, servers: str | list[str]) -> Self: + return cls(await nats.connect(servers)) + + async def subscribe( + self, + pacs_name: str, + series_instance_uid: str, + cb: Callable[[Lonk], Awaitable[None]], + ): + subject = subject_of(pacs_name, series_instance_uid) + cb = _curry_message2json(pacs_name, series_instance_uid, cb) + subscription = await self._nc.subscribe(subject, cb=cb) + self._subscriptions.append(subscription) + return subscription + + async def close(self): + await asyncio.gather(*(s.unsubscribe() for s in self._subscriptions)) + await self._nc.close() + + +def subject_of(pacs_name: str, series_instance_uid: str) -> str: + """ + Get the NATS subject for a series. + + Equivalent to https://github.com/FNNDSC/oxidicom/blob/33838f22a5431a349b3b83a313035b8e22d16bb1/src/lonk.rs#L36-L48 + """ + return f'oxidicom.{_sanitize_topic_part(pacs_name)}.{_sanitize_topic_part(series_instance_uid)}' + + +def _sanitize_topic_part(s: str) -> str: + return ( + s.replace('\0', '') + .replace(' ', '_') + .replace('.', '_') + .replace('*', '_') + .replace('>', '_') + ) + + +def _message2json( + pacs_name: str, series_instance_uid: str, message: Msg +) -> Lonk: + return Lonk( + pacs_name=pacs_name, + SeriesInstanceUID=series_instance_uid, + message=_serialize_to_lonkws(message.data), + ) + + +def _curry_message2json( + pacs_name: str, + series_instance_uid: str, + cb: Callable[[Lonk], Awaitable[None]], +): + async def nats_callback(message: Msg): + lonk = _message2json(pacs_name, series_instance_uid, message) + await cb(lonk) + + return nats_callback + + +@enum.unique +class LonkMagicByte(enum.IntEnum): + """ + LONK message first magic byte. + """ + + DONE = 0x00 + PROGRESS = 0x01 + ERROR = 0x02 + + +def _serialize_to_lonkws(payload: bytes) -> LonkMessageData: + """ + Translate LONK binary encoding to LONK-WS JSON. + """ + if len(payload) == 0: + raise ValueError('Empty message') + data = payload[1:] + + match payload[0]: + case LonkMagicByte.DONE.value: + return LonkDone(done=True) + case LonkMagicByte.PROGRESS.value: + ndicom = int.from_bytes(data, 'little', signed=False) + return LonkProgress(ndicom=ndicom) + case LonkMagicByte.ERROR.value: + error = data.decode(encoding='utf-8') + return LonkError(error=error) + case _: + hexstr = ' '.join(hex(b) for b in payload) + raise ValueError(f'Unrecognized message: {hexstr}') diff --git a/chris_backend/pacsfiles/tests/mocks.py b/chris_backend/pacsfiles/tests/mocks.py new file mode 100644 index 00000000..4f4825bf --- /dev/null +++ b/chris_backend/pacsfiles/tests/mocks.py @@ -0,0 +1,47 @@ +from typing import Self + +import nats +from nats import NATS + +from pacsfiles import lonk +from pacsfiles.lonk import LonkMagicByte + + +class Mockidicom: + """ + A mock *oxidicom* which sends LONK messages to NATS. + + Somewhat similar to https://github.com/FNNDSC/oxidicom/blob/e6bb83d1ea2fbaf5bb4af7dbf518a4b1a2957f2d/src/lonk.rs + """ + + def __init__(self, nc: NATS): + self._nc = nc + + @classmethod + async def connect(cls, servers: str | list[str]) -> Self: + nc = await nats.connect(servers) + return cls(nc) + + async def send_progress( + self, pacs_name: str, SeriesInstanceUID: str, ndicom: int + ): + subject = lonk.subject_of(pacs_name, SeriesInstanceUID) + u32 = ndicom.to_bytes(length=4, byteorder='little', signed=False) + data = LonkMagicByte.PROGRESS.value.to_bytes() + u32 + await self._nc.publish(subject, data) + + async def send_done(self, pacs_name: str, SeriesInstanceUID: str): + subject = lonk.subject_of(pacs_name, SeriesInstanceUID) + await self._nc.publish(subject, LonkMagicByte.DONE.value.to_bytes()) + + async def send_error( + self, pacs_name: str, SeriesInstanceUID: str, error: str + ): + subject = lonk.subject_of(pacs_name, SeriesInstanceUID) + data = LonkMagicByte.ERROR.value.to_bytes() + error.encode( + encoding='utf-8' + ) + await self._nc.publish(subject, data) + + async def close(self): + self._nc.close() diff --git a/chris_backend/pacsfiles/tests/test_consumers.py b/chris_backend/pacsfiles/tests/test_consumers.py index 58baa41c..458992a7 100644 --- a/chris_backend/pacsfiles/tests/test_consumers.py +++ b/chris_backend/pacsfiles/tests/test_consumers.py @@ -5,7 +5,7 @@ # note: use TransactionTestCase instead of TestCase for async tests that speak to DB. # See https://stackoverflow.com/a/71763849 -from django.test import TransactionTestCase +from django.test import TransactionTestCase, tag from channels.testing import WebsocketCommunicator from django.utils import timezone @@ -13,29 +13,95 @@ from core.models import FileDownloadToken from core.websockets.auth import TokenQsAuthMiddleware +from pacsfiles.lonk import ( + SubscriptionRequest, + Lonk, + LonkWsSubscription, + LonkProgress, + LonkDone, + LonkError, +) from pacsfiles.consumers import PACSFileProgress +from pacsfiles.tests.mocks import Mockidicom class PACSFileProgressTests(TransactionTestCase): - def setUp(self): self.username = 'PintoGideon' self.password = 'gideon1234' self.email = 'gideon@example.org' - self.user = User.objects.create_user(username=self.username, - email=self.email, - password=self.password) + self.user = User.objects.create_user( + username=self.username, email=self.email, password=self.password + ) pacs_grp, _ = Group.objects.get_or_create(name='pacs_users') self.user.groups.set([pacs_grp]) self.user.save() - async def test_my_consumer(self): + @tag('integration') + async def test_lonk_ws(self): token = await self._get_download_token() app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi()) - communicator = WebsocketCommunicator(app, f'v1/pacs/ws/?token={token.token}') + communicator = WebsocketCommunicator( + app, f'v1/pacs/ws/?token={token.token}' + ) connected, subprotocol = await communicator.connect() assert connected + oxidicom: Mockidicom = await Mockidicom.connect(settings.NATS_ADDRESS) + + series1 = {'pacs_name': 'MyPACS', 'SeriesInstanceUID': '1.234.567890'} + subscription_request = SubscriptionRequest( + action='subscribe', **series1 + ) + await communicator.send_json_to(subscription_request) + self.assertEqual( + await communicator.receive_json_from(), + Lonk( + message=LonkWsSubscription(subscription='subscribed'), + **series1, + ), + ) + series2 = {'pacs_name': 'MyPACS', 'SeriesInstanceUID': '5.678.90123'} + subscription_request = SubscriptionRequest( + action='subscribe', **series2 + ) + await communicator.send_json_to(subscription_request) + self.assertEqual( + await communicator.receive_json_from(), + Lonk( + message=LonkWsSubscription(subscription='subscribed'), + **series2, + ), + ) + + await oxidicom.send_progress(ndicom=1, **series1) + self.assertEqual( + await communicator.receive_json_from(), + Lonk(message=LonkProgress(ndicom=1), **series1), + ) + await oxidicom.send_progress(ndicom=115, **series1) + self.assertEqual( + await communicator.receive_json_from(), + Lonk(message=LonkProgress(ndicom=115), **series1), + ) + + await oxidicom.send_error(error='stuck in chimney', **series2) + self.assertEqual( + await communicator.receive_json_from(), + Lonk(message=LonkError(error='stuck in chimney'), **series2), + ) + + await oxidicom.send_progress(ndicom=192, **series1) + self.assertEqual( + await communicator.receive_json_from(), + Lonk(message=LonkProgress(ndicom=192), **series1), + ) + await oxidicom.send_done(**series1) + self.assertEqual( + await communicator.receive_json_from(), + Lonk(message=LonkDone(done=True), **series1), + ) + async def test_unauthenticated_not_connected(self): app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi()) communicator = WebsocketCommunicator(app, 'v1/pacs/ws/') # no token @@ -49,6 +115,9 @@ def _get_download_token(self) -> FileDownloadToken: https://github.com/FNNDSC/ChRIS_ultron_backEnd/blob/7bcccc2031386955875ef4e9758025577f5ee067/chris_backend/userfiles/tests/test_views.py#L210-L213 """ dt = timezone.now() + timezone.timedelta(minutes=10) - token = jwt.encode({'user': self.user.username, 'exp': dt}, settings.SECRET_KEY, - algorithm='HS256') - return FileDownloadToken.objects.create(token=token, owner=self.user) \ No newline at end of file + token = jwt.encode( + {'user': self.user.username, 'exp': dt}, + settings.SECRET_KEY, + algorithm='HS256', + ) + return FileDownloadToken.objects.create(token=token, owner=self.user)