Skip to content

Commit

Permalink
Merge pull request #150 from NilCoalescing/fix/#148_drf-permisions-cl…
Browse files Browse the repository at this point in the history
…asses

Wrap DRF permissions classes.
  • Loading branch information
hishnash authored Jul 10, 2022
2 parents 3e62d75 + 746fe3d commit 43d6697
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 43 deletions.
56 changes: 30 additions & 26 deletions djangochannelsrestframework/consumers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import json
import typing
from collections import defaultdict
Expand All @@ -10,12 +9,13 @@
from django.http import HttpRequest, HttpResponse
from django.http.response import Http404
from django.template.response import SimpleTemplateResponse
from django.contrib.auth.models import AnonymousUser
from rest_framework.exceptions import PermissionDenied, MethodNotAllowed, APIException
from rest_framework.permissions import BasePermission as DRFBasePermission
from rest_framework.response import Response

from djangochannelsrestframework.permissions import BasePermission
from djangochannelsrestframework.settings import api_settings
from djangochannelsrestframework.permissions import BasePermission, WrappedDRFPermission
from djangochannelsrestframework.scope_utils import request_from_scope, ensure_async


class APIConsumerMetaclass(type):
Expand All @@ -38,12 +38,6 @@ def __new__(mcs, name, bases, body):
return cls


def ensure_async(method: typing.Callable):
if asyncio.iscoroutinefunction(method):
return method
return database_sync_to_async(method)


class AsyncAPIConsumer(AsyncJsonWebsocketConsumer, metaclass=APIConsumerMetaclass):
"""
This provides an async API consumer that is very inspired by DjangoRestFrameworks ViewSets.
Expand All @@ -57,9 +51,8 @@ class AsyncAPIConsumer(AsyncJsonWebsocketConsumer, metaclass=APIConsumerMetaclas
# The following policies may be set at either globally, or per-view.
# take the default values set for django rest framework!

permission_classes = (
api_settings.DEFAULT_PERMISSION_CLASSES
) # type: List[Type[BasePermission]]
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
# type: List[Type[BasePermission]]

groups = {}

Expand All @@ -68,6 +61,20 @@ class AsyncAPIConsumer(AsyncJsonWebsocketConsumer, metaclass=APIConsumerMetaclas
lambda: defaultdict(set)
)

async def websocket_connect(self, message):
"""
Called when a WebSocket connection is opened.
"""
try:
for permission in await self.get_permissions(action="connect"):
if not await ensure_async(permission.can_connect)(
scope=self.scope, consumer=self, message=message
):
raise PermissionDenied()
await super().websocket_connect(message)
except PermissionDenied:
await self.close()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.groups = set(self.groups or [])
Expand Down Expand Up @@ -100,7 +107,16 @@ async def get_permissions(self, action: str, **kwargs):
"""
Instantiates and returns the list of permissions that this view requires.
"""
return [permission() for permission in self.permission_classes]
permission_instances = []
for permission_class in self.permission_classes:
instance = permission_class()

# If the permission is an DRF permission instance
if isinstance(instance, DRFBasePermission):
instance = WrappedDRFPermission(instance)
permission_instances.append(instance)

return permission_instances

async def check_permissions(self, action: str, **kwargs):
"""
Expand Down Expand Up @@ -245,24 +261,12 @@ async def handle_action(self, action: str, request_id: str, **kwargs):

@database_sync_to_async
def call_view(self, action: str, **kwargs):

request = HttpRequest()
request.path = self.scope.get("path")
request.session = self.scope.get("session", None)
request.user = self.scope.get("user", AnonymousUser)

request.META["HTTP_CONTENT_TYPE"] = "application/json"
request.META["HTTP_ACCEPT"] = "application/json"

for (header_name, value) in self.scope.get("headers", []):
request.META[header_name.decode("utf-8")] = value.decode("utf-8")
request = request_from_scope(self.scope)

args, view_kwargs = self.get_view_args(action=action, **kwargs)

request.method = self.actions[action]
request.POST = json.dumps(kwargs.get("data", {}))
if self.scope.get("cookies"):
request.COOKIES = self.scope.get("cookies")

for key, value in kwargs.get("query", {}).items():
if isinstance(value, list):
Expand Down
4 changes: 1 addition & 3 deletions djangochannelsrestframework/generics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Any, Dict, Type, Optional, List, OrderedDict, Union
from typing import Any, Dict, Type, Optional

from django.db.models import QuerySet, Model
from rest_framework.generics import get_object_or_404
from rest_framework.serializers import Serializer
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList

from djangochannelsrestframework.consumers import AsyncAPIConsumer
from djangochannelsrestframework.settings import api_settings


class GenericAsyncAPIConsumer(AsyncAPIConsumer):
Expand Down
7 changes: 3 additions & 4 deletions djangochannelsrestframework/mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, Tuple, Dict, Optional, OrderedDict, Union
from djangochannelsrestframework.observer.model_observer import Action
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList

from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList

from .decorators import action
from djangochannelsrestframework.settings import api_settings
from djangochannelsrestframework.decorators import action


class CreateModelMixin:
Expand Down Expand Up @@ -419,7 +418,7 @@ def perform_delete(self, instance, **kwargs):


class PaginatedModelListMixin(ListModelMixin):

permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS

@action()
Expand Down
4 changes: 2 additions & 2 deletions djangochannelsrestframework/pagination.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union

from rest_framework.pagination import LimitOffsetPagination
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList

from djangochannelsrestframework.settings import api_settings
from rest_framework.response import Response
from rest_framework.pagination import LimitOffsetPagination


def _positive_int(integer_string, strict=False, cutoff=None):
Expand Down
53 changes: 52 additions & 1 deletion djangochannelsrestframework/permissions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Dict, Any


from channels.consumer import AsyncConsumer
from rest_framework.permissions import BasePermission as DRFBasePermission

from djangochannelsrestframework.scope_utils import ensure_async
from djangochannelsrestframework.scope_utils import request_from_scope


class OperationHolderMixin:
Expand Down Expand Up @@ -95,8 +100,21 @@ class BasePermission(metaclass=BasePermissionMetaclass):
async def has_permission(
self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
) -> bool:
"""
Called on every websocket message sent before the corresponding action handler is called.
"""
pass

async def can_connect(
self, scope: Dict[str, Any], consumer: AsyncConsumer, message=None
) -> bool:
"""
Called during connection to validate if a given client can establish a websocket connection.
By default, this returns True and permits all connections to be made.
"""
return True


class AllowAny(BasePermission):
"""Allow any permission class"""
Expand All @@ -108,7 +126,7 @@ async def has_permission(


class IsAuthenticated(BasePermission):
"""Allow authenticated only class"""
"""Allow authenticated users"""

async def has_permission(
self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
Expand All @@ -117,3 +135,36 @@ async def has_permission(
if not user:
return False
return user.pk and user.is_authenticated


class WrappedDRFPermission(BasePermission):
"""
Used to wrap an instance of DRF permissions class.
"""

permission: DRFBasePermission

mapped_actions = {
"create": "PUT",
"update": "PATCH",
"list": "GET",
"retrieve": "GET",
"connect": "HEAD",
}

def __init__(self, permission: DRFBasePermission):
self.permission = permission

async def has_permission(
self, scope: Dict[str, Any], consumer: AsyncConsumer, action: str, **kwargs
) -> bool:
request = request_from_scope(scope)
request.method = self.mapped_actions.get(action, action.upper())
return await ensure_async(self.permission.has_permission)(request, consumer)

async def can_connect(
self, scope: Dict[str, Any], consumer: AsyncConsumer, message=None
) -> bool:
request = request_from_scope(scope)
request.method = self.mapped_actions.get("connect", "CONNECT")
return await ensure_async(self.permission.has_permission)(request, consumer)
33 changes: 33 additions & 0 deletions djangochannelsrestframework/scope_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
from typing import Any, Dict, Callable

from channels.db import database_sync_to_async
from django.http import HttpRequest


def ensure_async(method: Callable):
"""
Ensure method is async if not wrap it in database_sync_to_async.
"""
if asyncio.iscoroutinefunction(method):
return method
return database_sync_to_async(method)


def request_from_scope(scope: Dict[str, Any]) -> HttpRequest:
from django.contrib.auth.models import AnonymousUser

request = HttpRequest()
request.path = scope.get("path")
request.session = scope.get("session", None)
request.user = scope.get("user", AnonymousUser)

request.META["HTTP_CONTENT_TYPE"] = "application/json"
request.META["HTTP_ACCEPT"] = "application/json"

for (header_name, value) in scope.get("headers", []):
request.META[header_name.decode("utf-8")] = value.decode("utf-8")

if scope.get("cookies"):
request.COOKIES = scope.get("cookies")
return request
62 changes: 56 additions & 6 deletions djangochannelsrestframework/settings.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,65 @@
from django.conf import settings

from rest_framework.settings import APISettings
from django.core.signals import setting_changed
from rest_framework.settings import perform_import

DEFAULTS = {
"DEFAULT_PAGE_SIZE": 25,
"DEFAULT_PERMISSION_CLASSES": ("djangochannelsrestframework.permissions.AllowAny",),
"DEFAULT_PAGINATION_CLASS": None,
"PAGE_SIZE": None,
}
IMPORT_STRINGS = ("DEFAULT_PERMISSION_CLASSES",)
IMPORT_STRINGS = ("DEFAULT_PERMISSION_CLASSES", "DEFAULT_PAGINATION_CLASS")


class APISettings:
def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = user_settings
self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS
self._cached_attrs = set()

@property
def user_settings(self):
if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, "DJANGO_CHANNELS_REST_API", {})
return self._user_settings

def __getattr__(self, attr):
if attr not in self.defaults:
raise AttributeError("Invalid API setting: '%s'" % attr)

try:
# Check if present in user settings
val = self.user_settings[attr]
except KeyError:
# Fall back to defaults
val = self.defaults[attr]

# Coerce import strings into classes
if attr in self.import_strings:
val = perform_import(val, attr)

# Cache the result
self._cached_attrs.add(attr)
setattr(self, attr, val)
return val

def reload(self):
for attr in self._cached_attrs:
delattr(self, attr)
self._cached_attrs.clear()
if hasattr(self, "_user_settings"):
delattr(self, "_user_settings")


api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)


def reload_api_settings(*args, **kwargs):
setting = kwargs["setting"]
if setting == "DJANGO_CHANNELS_REST_API":
api_settings.reload()


api_settings = APISettings(
getattr(settings, "DJANGO_CHANNELS_REST_API", None), DEFAULTS, IMPORT_STRINGS
)
setting_changed.connect(reload_api_settings)
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
[tool:pytest]
addopts = tests/
Loading

0 comments on commit 43d6697

Please sign in to comment.