Skip to content

Commit

Permalink
Test websocket connection and auth
Browse files Browse the repository at this point in the history
  • Loading branch information
jennydaman committed Sep 9, 2024
1 parent b68145b commit 596f422
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 19 deletions.
2 changes: 1 addition & 1 deletion chris_backend/core/websockets/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from pacsfiles import consumers

websocket_urlpatterns = [
re_path(r'v1/pacs/progress/',
re_path(r'v1/pacs/ws/',
consumers.PACSFileProgress.as_asgi()),
]
41 changes: 23 additions & 18 deletions chris_backend/pacsfiles/consumers.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import json

import channels.exceptions
from channels.generic.websocket import WebsocketConsumer
import rest_framework.permissions
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from rest_framework import permissions

from pacsfiles.permissions import IsChrisOrIsPACSUserReadOnly


class PACSFileProgress(WebsocketConsumer):
class PACSFileProgress(AsyncJsonWebsocketConsumer):
"""
A WebSockets endpoint which relays progress messages from NATS sent by *oxidicom* to a client.
"""

permission_classes = (permissions.IsAuthenticated, IsChrisOrIsPACSUserReadOnly,)

def connect(self):
if not self._has_permission():
raise channels.exceptions.DenyConnection()
self.accept()
async def connect(self):
if not await self._has_permission():
await self.close()
else:
await self.accept()

async def receive_json(self, content, **kwargs):
...

async def disconnect(self, code):
...

@database_sync_to_async
def _has_permission(self) -> bool:
"""
Manual permissions check.
Expand All @@ -26,16 +36,11 @@ def _has_permission(self) -> bool:
self.user = self.scope.get('user', None)
if self.user is None:
return False
if getattr(self, 'method', None) is None:
# make it work with ``IsChrisOrIsPACSUserReadOnly``
self.method = rest_framework.permissions.SAFE_METHODS[0]

return all(
permission().has_permission(self, self.__class__)
for permission in self.permission_classes
)

def disconnect(self, close_code):
pass

def receive(self, text_data):
text_data_json = json.loads(text_data)
message = text_data_json["message"]

self.send(text_data=json.dumps({"message": message}))
54 changes: 54 additions & 0 deletions chris_backend/pacsfiles/tests/test_consumers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jwt
from channels.db import database_sync_to_async
from django.conf import settings
from django.contrib.auth.models import User, Group

# note: use TransactionTestCase instead of TestCase for async tests that speak to DB.
# See https://stackoverflow.com/a/71763849
from django.test import TransactionTestCase

from channels.testing import WebsocketCommunicator
from django.utils import timezone

from core.models import FileDownloadToken
from core.websockets.auth import TokenQsAuthMiddleware

from pacsfiles.consumers import PACSFileProgress


class PACSFileProgressTests(TransactionTestCase):

def setUp(self):
self.username = 'PintoGideon'
self.password = 'gideon1234'
self.email = '[email protected]'
self.user = User.objects.create_user(username=self.username,
email=self.email,
password=self.password)
pacs_grp, _ = Group.objects.get_or_create(name='pacs_users')
self.user.groups.set([pacs_grp])
self.user.save()

async def test_my_consumer(self):
token = await self._get_download_token()
app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi())
communicator = WebsocketCommunicator(app, f'v1/pacs/ws/?token={token.token}')
connected, subprotocol = await communicator.connect()
assert connected

async def test_unauthenticated_not_connected(self):
app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi())
communicator = WebsocketCommunicator(app, 'v1/pacs/ws/') # no token
connected, subprotocol = await communicator.connect()
assert not connected

@database_sync_to_async
def _get_download_token(self) -> FileDownloadToken:
"""
Copy-pasted from
https://github.com/FNNDSC/ChRIS_ultron_backEnd/blob/7bcccc2031386955875ef4e9758025577f5ee067/chris_backend/userfiles/tests/test_views.py#L210-L213
"""
dt = timezone.now() + timezone.timedelta(minutes=10)
token = jwt.encode({'user': self.user.username, 'exp': dt}, settings.SECRET_KEY,
algorithm='HS256')
return FileDownloadToken.objects.create(token=token, owner=self.user)

0 comments on commit 596f422

Please sign in to comment.