Skip to content

implement make_middleware-decorator and related utils #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
128 changes: 128 additions & 0 deletions django_async_extensions/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from functools import wraps

from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async


def decorator_from_middleware_with_args(middleware_class):
"""
Like decorator_from_middleware, but return a function
that accepts the arguments to be passed to the middleware_class.
Use like::

cache_page = decorator_from_middleware_with_args(CacheMiddleware)
# ...

@cache_page(3600)
def my_view(request):
# ...
"""
return make_middleware_decorator(middleware_class)


def decorator_from_middleware(middleware_class):
"""
Given a middleware class (not an instance), return a view decorator. This
lets you use middleware functionality on a per-view basis. The middleware
is created with no params passed.
"""
return make_middleware_decorator(middleware_class)()


def make_middleware_decorator(middleware_class):
def _make_decorator(*m_args, **m_kwargs):
def _decorator(view_func):
middleware = middleware_class(view_func, *m_args, **m_kwargs)

async def _pre_process_request(request, *args, **kwargs):
if hasattr(middleware, "process_request"):
result = await middleware.process_request(request)
if result is not None:
return result
if hasattr(middleware, "process_view"):
if iscoroutinefunction(middleware.process_view):
result = await middleware.process_view(
request, view_func, args, kwargs
)
else:
result = await sync_to_async(middleware.process_view)(
request, view_func, args, kwargs
)
if result is not None:
return result
return None

async def _process_exception(request, exception):
if hasattr(middleware, "process_exception"):
if iscoroutinefunction(middleware.process_exception):
result = await middleware.process_exception(request, exception)
else:
result = await sync_to_async(middleware.process_exception)(
request, exception
)
if result is not None:
return result
raise

async def _post_process_request(request, response):
if hasattr(response, "render") and callable(response.render):
if hasattr(middleware, "process_template_response"):
if iscoroutinefunction(middleware.process_template_response):
response = await middleware.process_template_response(
request, response
)
else:
response = await sync_to_async(
middleware.process_template_response
)(request, response)
# Defer running of process_response until after the template
# has been rendered:
if hasattr(middleware, "process_response"):

async def callback(response):
return await middleware.process_response(request, response)

response.add_post_render_callback(async_to_sync(callback))
else:
if hasattr(middleware, "process_response"):
return await middleware.process_response(request, response)
return response

if iscoroutinefunction(view_func):

async def _view_wrapper(request, *args, **kwargs):
result = await _pre_process_request(request, *args, **kwargs)
if result is not None:
return result

try:
response = await view_func(request, *args, **kwargs)
except Exception as e:
result = await _process_exception(request, e)
if result is not None:
return result

return await _post_process_request(request, response)

else:

def _view_wrapper(request, *args, **kwargs):
result = async_to_sync(_pre_process_request)(
request, *args, **kwargs
)
if result is not None:
return result

try:
response = view_func(request, *args, **kwargs)
except Exception as e:
result = async_to_sync(_process_exception)(request, e)
if result is not None:
return result

return async_to_sync(_post_process_request)(request, response)

return wraps(view_func)(_view_wrapper)

return _decorator

return _make_decorator
3 changes: 3 additions & 0 deletions docs/middleware/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ with the following specification:
```
where `get_response` is an **async function**, sync functions are not supported and **will raise** an error.

**Note:** you can use middlewares drove from this base class with normal django middlewares, you can even write sync views
`get_response` is usually provided by django, so you don't have to worry about it being async.

----------------------------

other methods are as follows:
Expand Down
56 changes: 56 additions & 0 deletions docs/middleware/decorate_views.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
`django_async_extensions.utils.decorators.decorator_from_middleware` and
`django_async_extensions.utils.decorators.decorator_from_middleware_with_args`
are provided to decorate a view with an async middleware directly.

they work almost exactly like django's [decorator_from_middleware](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware)
and [decorator_from_middleware_with_args](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware_with_args)
but it expects an async middleware as described in [AsyncMiddlewareMixin](base.md)

**Important:** if you are using a middleware that inherits from [AsyncMiddlewareMixin](base.md) you can only decorate async views
if you need to decorate a sync view change middleware's `__init__()` method to accept async `get_response` argument.

with an async view
```python
from django.http.response import HttpResponse

from django_async_extensions.middleware.base import AsyncMiddlewareMixin
from django_async_extensions.utils.decorators import decorator_from_middleware

class MyAsyncMiddleware(AsyncMiddlewareMixin):
async def process_request(self, request):
return HttpResponse()


deco = decorator_from_middleware(MyAsyncMiddleware)


@deco
async def my_view(request):
return HttpResponse()
```


if you need to use a sync view design your middleware like this
```python
from django_async_extensions.middleware.base import AsyncMiddlewareMixin

from asgiref.sync import iscoroutinefunction, markcoroutinefunction


class MyMiddleware(AsyncMiddlewareMixin):
sync_capable = True

def __init__(self, get_response):
if get_response is None:
raise ValueError("get_response must be provided.")
self.get_response = get_response

self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction(
getattr(self.get_response, "__call__", None)
)
if self.async_mode:
# Mark the class as async-capable.
markcoroutinefunction(self)

super().__init__()
```
1 change: 1 addition & 0 deletions tests/test_async_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# named like this to not conflict with something from django :/
197 changes: 197 additions & 0 deletions tests/test_async_utils/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from asgiref.sync import sync_to_async

import pytest

from django.http import HttpResponse
from django.template import engines
from django.template.response import TemplateResponse
from django.test import RequestFactory

from django_async_extensions.middleware.base import AsyncMiddlewareMixin
from django_async_extensions.utils.decorators import decorator_from_middleware


class ProcessViewMiddleware(AsyncMiddlewareMixin):
def __init__(self, get_response):
self.get_response = get_response

async def process_view(self, request, view_func, view_args, view_kwargs):
pass


process_view_dec = decorator_from_middleware(ProcessViewMiddleware)


@process_view_dec
async def async_process_view(request):
return HttpResponse()


@process_view_dec
def process_view(request):
return HttpResponse()


class ClassProcessView:
def __call__(self, request):
return HttpResponse()


class_process_view = process_view_dec(ClassProcessView())


class AsyncClassProcessView:
async def __call__(self, request):
return HttpResponse()


async_class_process_view = process_view_dec(AsyncClassProcessView())


class FullMiddleware(AsyncMiddlewareMixin):
def __init__(self, get_response):
self.get_response = get_response

async def process_request(self, request):
request.process_request_reached = True

async def process_view(self, request, view_func, view_args, view_kwargs):
request.process_view_reached = True

async def process_template_response(self, request, response):
request.process_template_response_reached = True
return response

async def process_response(self, request, response):
# This should never receive unrendered content.
request.process_response_content = response.content
request.process_response_reached = True
return response


full_dec = decorator_from_middleware(FullMiddleware)


class TestDecoratorFromMiddleware:
"""
Tests for view decorators created using
``django.utils.decorators.decorator_from_middleware``.
"""

rf = RequestFactory()

def test_process_view_middleware(self):
"""
Test a middleware that implements process_view.
"""
process_view(self.rf.get("/"))

async def test_process_view_middleware_async(self, async_rf):
await async_process_view(async_rf.get("/"))

async def test_sync_process_view_raises_in_async_context(self):
msg = (
"You cannot use AsyncToSync in the same thread as an async event loop"
" - just await the async function directly."
)
with pytest.raises(RuntimeError, match=msg):
process_view(self.rf.get("/"))

def test_callable_process_view_middleware(self):
"""
Test a middleware that implements process_view, operating on a callable class.
"""
class_process_view(self.rf.get("/"))

async def test_callable_process_view_middleware_async(self, async_rf):
await async_process_view(async_rf.get("/"))

def test_full_dec_normal(self):
"""
All methods of middleware are called for normal HttpResponses
"""

@full_dec
def normal_view(request):
template = engines["django"].from_string("Hello world")
return HttpResponse(template.render())

request = self.rf.get("/")
normal_view(request)
assert getattr(request, "process_request_reached", False)
assert getattr(request, "process_view_reached", False)
# process_template_response must not be called for HttpResponse
assert getattr(request, "process_template_response_reached", False) is False
assert getattr(request, "process_response_reached", False)

async def test_full_dec_normal_async(self, async_rf):
"""
All methods of middleware are called for normal HttpResponses
"""

@full_dec
async def normal_view(request):
template = engines["django"].from_string("Hello world")
return HttpResponse(template.render())

request = async_rf.get("/")
await normal_view(request)
assert getattr(request, "process_request_reached", False)
assert getattr(request, "process_view_reached", False)
# process_template_response must not be called for HttpResponse
assert getattr(request, "process_template_response_reached", False) is False
assert getattr(request, "process_response_reached", False)

def test_full_dec_templateresponse(self):
"""
All methods of middleware are called for TemplateResponses in
the right sequence.
"""

@full_dec
def template_response_view(request):
template = engines["django"].from_string("Hello world")
return TemplateResponse(request, template)

request = self.rf.get("/")
response = template_response_view(request)
assert getattr(request, "process_request_reached", False)
assert getattr(request, "process_view_reached", False)
assert getattr(request, "process_template_response_reached", False)
# response must not be rendered yet.
assert response._is_rendered is False
# process_response must not be called until after response is rendered,
# otherwise some decorators like csrf_protect and gzip_page will not
# work correctly. See #16004
assert getattr(request, "process_response_reached", False) is False
response.render()
assert getattr(request, "process_response_reached", False)
# process_response saw the rendered content
assert request.process_response_content == b"Hello world"

async def test_full_dec_templateresponse_async(self, async_rf):
"""
All methods of middleware are called for TemplateResponses in
the right sequence.
"""

@full_dec
async def template_response_view(request):
template = engines["django"].from_string("Hello world")
return TemplateResponse(request, template)

request = async_rf.get("/")
response = await template_response_view(request)
assert getattr(request, "process_request_reached", False)
assert getattr(request, "process_view_reached", False)
assert getattr(request, "process_template_response_reached", False)
# response must not be rendered yet.
assert response._is_rendered is False
# process_response must not be called until after response is rendered,
# otherwise some decorators like csrf_protect and gzip_page will not
# work correctly. See #16004
assert getattr(request, "process_response_reached", False) is False
await sync_to_async(response.render)()
assert getattr(request, "process_response_reached", False)
# process_response saw the rendered content
assert request.process_response_content == b"Hello world"
Loading