diff --git a/django_async_extensions/utils/__init__.py b/django_async_extensions/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/django_async_extensions/utils/decorators.py b/django_async_extensions/utils/decorators.py new file mode 100644 index 0000000..f7414e2 --- /dev/null +++ b/django_async_extensions/utils/decorators.py @@ -0,0 +1,128 @@ +from functools import wraps + +from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async + + +def decorator_from_middleware_with_args(middleware_class): + """ + Like decorator_from_middleware, but return a function + that accepts the arguments to be passed to the middleware_class. + Use like:: + + cache_page = decorator_from_middleware_with_args(CacheMiddleware) + # ... + + @cache_page(3600) + def my_view(request): + # ... + """ + return make_middleware_decorator(middleware_class) + + +def decorator_from_middleware(middleware_class): + """ + Given a middleware class (not an instance), return a view decorator. This + lets you use middleware functionality on a per-view basis. The middleware + is created with no params passed. + """ + return make_middleware_decorator(middleware_class)() + + +def make_middleware_decorator(middleware_class): + def _make_decorator(*m_args, **m_kwargs): + def _decorator(view_func): + middleware = middleware_class(view_func, *m_args, **m_kwargs) + + async def _pre_process_request(request, *args, **kwargs): + if hasattr(middleware, "process_request"): + result = await middleware.process_request(request) + if result is not None: + return result + if hasattr(middleware, "process_view"): + if iscoroutinefunction(middleware.process_view): + result = await middleware.process_view( + request, view_func, args, kwargs + ) + else: + result = await sync_to_async(middleware.process_view)( + request, view_func, args, kwargs + ) + if result is not None: + return result + return None + + async def _process_exception(request, exception): + if hasattr(middleware, "process_exception"): + if iscoroutinefunction(middleware.process_exception): + result = await middleware.process_exception(request, exception) + else: + result = await sync_to_async(middleware.process_exception)( + request, exception + ) + if result is not None: + return result + raise + + async def _post_process_request(request, response): + if hasattr(response, "render") and callable(response.render): + if hasattr(middleware, "process_template_response"): + if iscoroutinefunction(middleware.process_template_response): + response = await middleware.process_template_response( + request, response + ) + else: + response = await sync_to_async( + middleware.process_template_response + )(request, response) + # Defer running of process_response until after the template + # has been rendered: + if hasattr(middleware, "process_response"): + + async def callback(response): + return await middleware.process_response(request, response) + + response.add_post_render_callback(async_to_sync(callback)) + else: + if hasattr(middleware, "process_response"): + return await middleware.process_response(request, response) + return response + + if iscoroutinefunction(view_func): + + async def _view_wrapper(request, *args, **kwargs): + result = await _pre_process_request(request, *args, **kwargs) + if result is not None: + return result + + try: + response = await view_func(request, *args, **kwargs) + except Exception as e: + result = await _process_exception(request, e) + if result is not None: + return result + + return await _post_process_request(request, response) + + else: + + def _view_wrapper(request, *args, **kwargs): + result = async_to_sync(_pre_process_request)( + request, *args, **kwargs + ) + if result is not None: + return result + + try: + response = view_func(request, *args, **kwargs) + except Exception as e: + result = async_to_sync(_process_exception)(request, e) + if result is not None: + return result + + return async_to_sync(_post_process_request)(request, response) + + return wraps(view_func)(_view_wrapper) + + return _decorator + + return _make_decorator diff --git a/docs/middleware/base.md b/docs/middleware/base.md index 267fc47..df48b8e 100644 --- a/docs/middleware/base.md +++ b/docs/middleware/base.md @@ -12,6 +12,9 @@ with the following specification: ``` where `get_response` is an **async function**, sync functions are not supported and **will raise** an error. +**Note:** you can use middlewares drove from this base class with normal django middlewares, you can even write sync views +`get_response` is usually provided by django, so you don't have to worry about it being async. + ---------------------------- other methods are as follows: diff --git a/docs/middleware/decorate_views.md b/docs/middleware/decorate_views.md new file mode 100644 index 0000000..491f3a1 --- /dev/null +++ b/docs/middleware/decorate_views.md @@ -0,0 +1,56 @@ +`django_async_extensions.utils.decorators.decorator_from_middleware` and +`django_async_extensions.utils.decorators.decorator_from_middleware_with_args` +are provided to decorate a view with an async middleware directly. + +they work almost exactly like django's [decorator_from_middleware](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware) +and [decorator_from_middleware_with_args](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware_with_args) +but it expects an async middleware as described in [AsyncMiddlewareMixin](base.md) + +**Important:** if you are using a middleware that inherits from [AsyncMiddlewareMixin](base.md) you can only decorate async views +if you need to decorate a sync view change middleware's `__init__()` method to accept async `get_response` argument. + +with an async view +```python +from django.http.response import HttpResponse + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.utils.decorators import decorator_from_middleware + +class MyAsyncMiddleware(AsyncMiddlewareMixin): + async def process_request(self, request): + return HttpResponse() + + +deco = decorator_from_middleware(MyAsyncMiddleware) + + +@deco +async def my_view(request): + return HttpResponse() +``` + + +if you need to use a sync view design your middleware like this +```python +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + +from asgiref.sync import iscoroutinefunction, markcoroutinefunction + + +class MyMiddleware(AsyncMiddlewareMixin): + sync_capable = True + + def __init__(self, get_response): + if get_response is None: + raise ValueError("get_response must be provided.") + self.get_response = get_response + + self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction( + getattr(self.get_response, "__call__", None) + ) + if self.async_mode: + # Mark the class as async-capable. + markcoroutinefunction(self) + + super().__init__() +``` diff --git a/tests/test_async_utils/__init__.py b/tests/test_async_utils/__init__.py new file mode 100644 index 0000000..08fc7f5 --- /dev/null +++ b/tests/test_async_utils/__init__.py @@ -0,0 +1 @@ +# named like this to not conflict with something from django :/ diff --git a/tests/test_async_utils/test_decorators.py b/tests/test_async_utils/test_decorators.py new file mode 100644 index 0000000..50f9b2b --- /dev/null +++ b/tests/test_async_utils/test_decorators.py @@ -0,0 +1,197 @@ +from asgiref.sync import sync_to_async + +import pytest + +from django.http import HttpResponse +from django.template import engines +from django.template.response import TemplateResponse +from django.test import RequestFactory + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.utils.decorators import decorator_from_middleware + + +class ProcessViewMiddleware(AsyncMiddlewareMixin): + def __init__(self, get_response): + self.get_response = get_response + + async def process_view(self, request, view_func, view_args, view_kwargs): + pass + + +process_view_dec = decorator_from_middleware(ProcessViewMiddleware) + + +@process_view_dec +async def async_process_view(request): + return HttpResponse() + + +@process_view_dec +def process_view(request): + return HttpResponse() + + +class ClassProcessView: + def __call__(self, request): + return HttpResponse() + + +class_process_view = process_view_dec(ClassProcessView()) + + +class AsyncClassProcessView: + async def __call__(self, request): + return HttpResponse() + + +async_class_process_view = process_view_dec(AsyncClassProcessView()) + + +class FullMiddleware(AsyncMiddlewareMixin): + def __init__(self, get_response): + self.get_response = get_response + + async def process_request(self, request): + request.process_request_reached = True + + async def process_view(self, request, view_func, view_args, view_kwargs): + request.process_view_reached = True + + async def process_template_response(self, request, response): + request.process_template_response_reached = True + return response + + async def process_response(self, request, response): + # This should never receive unrendered content. + request.process_response_content = response.content + request.process_response_reached = True + return response + + +full_dec = decorator_from_middleware(FullMiddleware) + + +class TestDecoratorFromMiddleware: + """ + Tests for view decorators created using + ``django.utils.decorators.decorator_from_middleware``. + """ + + rf = RequestFactory() + + def test_process_view_middleware(self): + """ + Test a middleware that implements process_view. + """ + process_view(self.rf.get("/")) + + async def test_process_view_middleware_async(self, async_rf): + await async_process_view(async_rf.get("/")) + + async def test_sync_process_view_raises_in_async_context(self): + msg = ( + "You cannot use AsyncToSync in the same thread as an async event loop" + " - just await the async function directly." + ) + with pytest.raises(RuntimeError, match=msg): + process_view(self.rf.get("/")) + + def test_callable_process_view_middleware(self): + """ + Test a middleware that implements process_view, operating on a callable class. + """ + class_process_view(self.rf.get("/")) + + async def test_callable_process_view_middleware_async(self, async_rf): + await async_process_view(async_rf.get("/")) + + def test_full_dec_normal(self): + """ + All methods of middleware are called for normal HttpResponses + """ + + @full_dec + def normal_view(request): + template = engines["django"].from_string("Hello world") + return HttpResponse(template.render()) + + request = self.rf.get("/") + normal_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + # process_template_response must not be called for HttpResponse + assert getattr(request, "process_template_response_reached", False) is False + assert getattr(request, "process_response_reached", False) + + async def test_full_dec_normal_async(self, async_rf): + """ + All methods of middleware are called for normal HttpResponses + """ + + @full_dec + async def normal_view(request): + template = engines["django"].from_string("Hello world") + return HttpResponse(template.render()) + + request = async_rf.get("/") + await normal_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + # process_template_response must not be called for HttpResponse + assert getattr(request, "process_template_response_reached", False) is False + assert getattr(request, "process_response_reached", False) + + def test_full_dec_templateresponse(self): + """ + All methods of middleware are called for TemplateResponses in + the right sequence. + """ + + @full_dec + def template_response_view(request): + template = engines["django"].from_string("Hello world") + return TemplateResponse(request, template) + + request = self.rf.get("/") + response = template_response_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + assert getattr(request, "process_template_response_reached", False) + # response must not be rendered yet. + assert response._is_rendered is False + # process_response must not be called until after response is rendered, + # otherwise some decorators like csrf_protect and gzip_page will not + # work correctly. See #16004 + assert getattr(request, "process_response_reached", False) is False + response.render() + assert getattr(request, "process_response_reached", False) + # process_response saw the rendered content + assert request.process_response_content == b"Hello world" + + async def test_full_dec_templateresponse_async(self, async_rf): + """ + All methods of middleware are called for TemplateResponses in + the right sequence. + """ + + @full_dec + async def template_response_view(request): + template = engines["django"].from_string("Hello world") + return TemplateResponse(request, template) + + request = async_rf.get("/") + response = await template_response_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + assert getattr(request, "process_template_response_reached", False) + # response must not be rendered yet. + assert response._is_rendered is False + # process_response must not be called until after response is rendered, + # otherwise some decorators like csrf_protect and gzip_page will not + # work correctly. See #16004 + assert getattr(request, "process_response_reached", False) is False + await sync_to_async(response.render)() + assert getattr(request, "process_response_reached", False) + # process_response saw the rendered content + assert request.process_response_content == b"Hello world"