From 0f7686b2e90190a6545d4c9cf4ea847a49f7e8d9 Mon Sep 17 00:00:00 2001 From: Davide Date: Thu, 7 Aug 2025 18:28:14 +0200 Subject: [PATCH 1/6] Use signals to handle DB-close --- channels/consumer.py | 7 ++++--- channels/db.py | 13 ++++++++++--- channels/signals.py | 5 +++++ channels/testing/application.py | 26 +++++++++++++++----------- tests/test_database.py | 6 ++++-- 5 files changed, 38 insertions(+), 19 deletions(-) create mode 100644 channels/signals.py diff --git a/channels/consumer.py b/channels/consumer.py index fc065432b..fbc9c43d7 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -3,9 +3,10 @@ from asgiref.sync import async_to_sync from . import DEFAULT_CHANNEL_LAYER -from .db import aclose_old_connections, database_sync_to_async +from .db import database_sync_to_async from .exceptions import StopConsumer from .layers import get_channel_layer +from .signals import consumer_started, consumer_terminated from .utils import await_many_dispatch @@ -62,7 +63,7 @@ async def __call__(self, scope, receive, send): await await_many_dispatch([receive], self.dispatch) except StopConsumer: # Exit cleanly - pass + await consumer_terminated.asend(sender=self.__class__) async def dispatch(self, message): """ @@ -70,7 +71,7 @@ async def dispatch(self, message): """ handler = getattr(self, get_handler_name(message), None) if handler: - await aclose_old_connections() + await consumer_started.asend(sender=self.__class__) await handler(message) else: raise ValueError("No handler for message type %s" % message["type"]) diff --git a/channels/db.py b/channels/db.py index 2961b5cdb..ec2342edc 100644 --- a/channels/db.py +++ b/channels/db.py @@ -1,5 +1,6 @@ from asgiref.sync import SyncToAsync, sync_to_async from django.db import close_old_connections +from .signals import consumer_started, consumer_terminated, db_sync_to_async class DatabaseSyncToAsync(SyncToAsync): @@ -8,16 +9,22 @@ class DatabaseSyncToAsync(SyncToAsync): """ def thread_handler(self, loop, *args, **kwargs): - close_old_connections() + db_sync_to_async.send(sender=self.__class__, start=True) try: return super().thread_handler(loop, *args, **kwargs) finally: - close_old_connections() + db_sync_to_async.send(sender=self.__class__, start=False) # The class is TitleCased, but we want to encourage use as a callable/decorator database_sync_to_async = DatabaseSyncToAsync -async def aclose_old_connections(): +async def aclose_old_connections(**kwargs): return await sync_to_async(close_old_connections)() + + +consumer_started.connect(aclose_old_connections) +consumer_terminated.connect(aclose_old_connections) +db_sync_to_async.connect(close_old_connections) + diff --git a/channels/signals.py b/channels/signals.py new file mode 100644 index 000000000..96c613778 --- /dev/null +++ b/channels/signals.py @@ -0,0 +1,5 @@ +from django.dispatch import Signal + +consumer_started = Signal() +consumer_terminated = Signal() +db_sync_to_async = Signal() diff --git a/channels/testing/application.py b/channels/testing/application.py index 2003178c1..c06be628c 100644 --- a/channels/testing/application.py +++ b/channels/testing/application.py @@ -1,17 +1,21 @@ -from unittest import mock +from contextlib import asynccontextmanager from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator - - -def no_op(): - pass +from channels.db import aclose_old_connections +from channels.signals import consumer_started, consumer_terminated, db_sync_to_async +from django.db import close_old_connections class ApplicationCommunicator(BaseApplicationCommunicator): - async def send_input(self, message): - with mock.patch("channels.db.close_old_connections", no_op): - return await super().send_input(message) - async def receive_output(self, timeout=1): - with mock.patch("channels.db.close_old_connections", no_op): - return await super().receive_output(timeout) + @asynccontextmanager + async def handle_db(self): + consumer_started.disconnect(aclose_old_connections) + consumer_terminated.disconnect(aclose_old_connections) + db_sync_to_async.disconnect(close_old_connections) + try: + yield + finally: + consumer_started.connect(aclose_old_connections) + consumer_terminated.connect(aclose_old_connections) + db_sync_to_async.connect(close_old_connections) diff --git a/tests/test_database.py b/tests/test_database.py index 3faf05b5b..52978ada9 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -38,7 +38,8 @@ async def test_websocket(self): "This bug only occurs when the database is materialized on disk", ) communicator = WebsocketCommunicator(WebsocketConsumer.as_asgi(), "/") - connected, subprotocol = await communicator.connect() + async with communicator.handle_db(): + connected, subprotocol = await communicator.connect() self.assertTrue(connected) self.assertEqual(subprotocol, "fun") @@ -51,5 +52,6 @@ async def test_http(self): communicator = HttpCommunicator( HttpConsumer.as_asgi(), method="GET", path="/test/" ) - connected = await communicator.get_response() + async with communicator.handle_db(): + connected = await communicator.get_response() self.assertTrue(connected) From b606976ca89a27675483593f6e4658518528b44d Mon Sep 17 00:00:00 2001 From: Davide Date: Mon, 11 Aug 2025 16:09:26 +0200 Subject: [PATCH 2/6] Use a decorator/ctx-manager to handle DB keep-open --- channels/db.py | 1 - channels/testing/__init__.py | 70 ++++++++++++++++++++++++++++++++- channels/testing/application.py | 21 ---------- channels/testing/http.py | 2 +- channels/testing/websocket.py | 2 +- tests/test_database.py | 11 +++--- 6 files changed, 77 insertions(+), 30 deletions(-) delete mode 100644 channels/testing/application.py diff --git a/channels/db.py b/channels/db.py index ec2342edc..c07233b10 100644 --- a/channels/db.py +++ b/channels/db.py @@ -27,4 +27,3 @@ async def aclose_old_connections(**kwargs): consumer_started.connect(aclose_old_connections) consumer_terminated.connect(aclose_old_connections) db_sync_to_async.connect(close_old_connections) - diff --git a/channels/testing/__init__.py b/channels/testing/__init__.py index d7dee3ef7..05b1e0989 100644 --- a/channels/testing/__init__.py +++ b/channels/testing/__init__.py @@ -1,4 +1,11 @@ -from .application import ApplicationCommunicator # noqa +import asyncio +from contextlib import AbstractAsyncContextManager, AbstractContextManager +from functools import wraps + +from channels.db import aclose_old_connections +from channels.signals import consumer_started, consumer_terminated, db_sync_to_async +from django.db import close_old_connections +from asgiref.testing import ApplicationCommunicator # noqa from .http import HttpCommunicator # noqa from .live import ChannelsLiveServerTestCase # noqa from .websocket import WebsocketCommunicator # noqa @@ -8,4 +15,65 @@ "HttpCommunicator", "ChannelsLiveServerTestCase", "WebsocketCommunicator", + "keep_db_open", ] + + +class DatabaseWrapper(AbstractAsyncContextManager, AbstractContextManager): + """ + Wrapper which can be used as both context-manager or decorator to ensure + that database connections are not closed during test execution. + """ + + def __init__(self): + self._lock = asyncio.Lock() + self._counter = 0 + + def _disconnect(self): + if self._counter == 0: + consumer_started.disconnect(aclose_old_connections) + consumer_terminated.disconnect(aclose_old_connections) + db_sync_to_async.disconnect(close_old_connections) + + self._counter += 1 + + def _connect(self): + self._counter -= 1 + if self._counter <= 0: + consumer_started.connect(aclose_old_connections) + consumer_terminated.connect(aclose_old_connections) + db_sync_to_async.connect(close_old_connections) + + def __enter__(self): + self._disconnect() + + def __exit__(self, exc_type, exc_value, traceback): + self._disconnect() + + # in async mode also use a lock to reduce concurrency issue + # with the inner counter value + async def __aenter__(self): + async with self._lock: + self._disconnect() + + async def __aexit__(self, exc_type, exc_value, traceback): + async with self._lock: + self._connect() + + def __call__(self, func): + if asyncio.iscoroutinefunction(func): + + async def wrapper(*args, **kwargs): + async with self: + return await func(*args, **kwargs) + + else: + + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wraps(func)(wrapper) + + +keep_db_open = DatabaseWrapper() diff --git a/channels/testing/application.py b/channels/testing/application.py deleted file mode 100644 index c06be628c..000000000 --- a/channels/testing/application.py +++ /dev/null @@ -1,21 +0,0 @@ -from contextlib import asynccontextmanager - -from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator -from channels.db import aclose_old_connections -from channels.signals import consumer_started, consumer_terminated, db_sync_to_async -from django.db import close_old_connections - - -class ApplicationCommunicator(BaseApplicationCommunicator): - - @asynccontextmanager - async def handle_db(self): - consumer_started.disconnect(aclose_old_connections) - consumer_terminated.disconnect(aclose_old_connections) - db_sync_to_async.disconnect(close_old_connections) - try: - yield - finally: - consumer_started.connect(aclose_old_connections) - consumer_terminated.connect(aclose_old_connections) - db_sync_to_async.connect(close_old_connections) diff --git a/channels/testing/http.py b/channels/testing/http.py index 8130265a0..6b1514ca7 100644 --- a/channels/testing/http.py +++ b/channels/testing/http.py @@ -1,6 +1,6 @@ from urllib.parse import unquote, urlparse -from channels.testing.application import ApplicationCommunicator +from asgiref.testing import ApplicationCommunicator class HttpCommunicator(ApplicationCommunicator): diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index 24e58d369..57ea4a653 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -1,7 +1,7 @@ import json from urllib.parse import unquote, urlparse -from channels.testing.application import ApplicationCommunicator +from asgiref.testing import ApplicationCommunicator class WebsocketCommunicator(ApplicationCommunicator): diff --git a/tests/test_database.py b/tests/test_database.py index 52978ada9..961b906a4 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,7 +4,7 @@ from channels.db import database_sync_to_async from channels.generic.http import AsyncHttpConsumer from channels.generic.websocket import AsyncWebsocketConsumer -from channels.testing import HttpCommunicator, WebsocketCommunicator +from channels.testing import HttpCommunicator, WebsocketCommunicator, keep_db_open @database_sync_to_async @@ -31,6 +31,8 @@ async def handle(self, body): class ConnectionClosingTests(TestCase): + + @keep_db_open async def test_websocket(self): self.assertNotRegex( db.connections["default"].settings_dict.get("NAME"), @@ -38,11 +40,11 @@ async def test_websocket(self): "This bug only occurs when the database is materialized on disk", ) communicator = WebsocketCommunicator(WebsocketConsumer.as_asgi(), "/") - async with communicator.handle_db(): - connected, subprotocol = await communicator.connect() + connected, subprotocol = await communicator.connect() self.assertTrue(connected) self.assertEqual(subprotocol, "fun") + @keep_db_open async def test_http(self): self.assertNotRegex( db.connections["default"].settings_dict.get("NAME"), @@ -52,6 +54,5 @@ async def test_http(self): communicator = HttpCommunicator( HttpConsumer.as_asgi(), method="GET", path="/test/" ) - async with communicator.handle_db(): - connected = await communicator.get_response() + connected = await communicator.get_response() self.assertTrue(connected) From 662eb3b0d366f3968a7d4b535bc90d44f2d6612c Mon Sep 17 00:00:00 2001 From: Davide Date: Mon, 11 Aug 2025 17:06:34 +0200 Subject: [PATCH 3/6] Remove explicit call to aclose_db_connections since it is handled in base consumer class --- channels/generic/http.py | 2 -- channels/generic/websocket.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/channels/generic/http.py b/channels/generic/http.py index 0d043cc3a..909e85704 100644 --- a/channels/generic/http.py +++ b/channels/generic/http.py @@ -1,6 +1,5 @@ from channels.consumer import AsyncConsumer -from ..db import aclose_old_connections from ..exceptions import StopConsumer @@ -89,5 +88,4 @@ async def http_disconnect(self, message): Let the user do their cleanup and close the consumer. """ await self.disconnect() - await aclose_old_connections() raise StopConsumer() diff --git a/channels/generic/websocket.py b/channels/generic/websocket.py index b4d99119c..899ac8915 100644 --- a/channels/generic/websocket.py +++ b/channels/generic/websocket.py @@ -3,7 +3,6 @@ from asgiref.sync import async_to_sync from ..consumer import AsyncConsumer, SyncConsumer -from ..db import aclose_old_connections from ..exceptions import ( AcceptConnection, DenyConnection, @@ -248,7 +247,6 @@ async def websocket_disconnect(self, message): "BACKEND is unconfigured or doesn't support groups" ) await self.disconnect(message["code"]) - await aclose_old_connections() raise StopConsumer() async def disconnect(self, code): From 2b0b272ebd9e8906265d83582b740788fd38d0de Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 12 Aug 2025 09:43:55 +0200 Subject: [PATCH 4/6] Use a mixin for TestCase to handle db-connection signals --- channels/testing/__init__.py | 74 ++++++++---------------------------- tests/test_database.py | 6 +-- 2 files changed, 18 insertions(+), 62 deletions(-) diff --git a/channels/testing/__init__.py b/channels/testing/__init__.py index 05b1e0989..991cf12c6 100644 --- a/channels/testing/__init__.py +++ b/channels/testing/__init__.py @@ -1,7 +1,3 @@ -import asyncio -from contextlib import AbstractAsyncContextManager, AbstractContextManager -from functools import wraps - from channels.db import aclose_old_connections from channels.signals import consumer_started, consumer_terminated, db_sync_to_async from django.db import close_old_connections @@ -15,65 +11,27 @@ "HttpCommunicator", "ChannelsLiveServerTestCase", "WebsocketCommunicator", - "keep_db_open", + "ConsumerTestMixin", ] -class DatabaseWrapper(AbstractAsyncContextManager, AbstractContextManager): +class ConsumerTestMixin: """ - Wrapper which can be used as both context-manager or decorator to ensure - that database connections are not closed during test execution. + Mixin to be applied to Django `TestCase` or `TransactionTestCase` to ensure + that database connections are not closed by consumers during test execution. """ - def __init__(self): - self._lock = asyncio.Lock() - self._counter = 0 - - def _disconnect(self): - if self._counter == 0: - consumer_started.disconnect(aclose_old_connections) - consumer_terminated.disconnect(aclose_old_connections) - db_sync_to_async.disconnect(close_old_connections) - - self._counter += 1 - - def _connect(self): - self._counter -= 1 - if self._counter <= 0: - consumer_started.connect(aclose_old_connections) - consumer_terminated.connect(aclose_old_connections) - db_sync_to_async.connect(close_old_connections) - - def __enter__(self): - self._disconnect() - - def __exit__(self, exc_type, exc_value, traceback): - self._disconnect() - - # in async mode also use a lock to reduce concurrency issue - # with the inner counter value - async def __aenter__(self): - async with self._lock: - self._disconnect() - - async def __aexit__(self, exc_type, exc_value, traceback): - async with self._lock: - self._connect() - - def __call__(self, func): - if asyncio.iscoroutinefunction(func): - - async def wrapper(*args, **kwargs): - async with self: - return await func(*args, **kwargs) - - else: - - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wraps(func)(wrapper) + @classmethod + def setUpClass(cls): + super().setUpClass() + consumer_started.disconnect(aclose_old_connections) + consumer_terminated.disconnect(aclose_old_connections) + db_sync_to_async.disconnect(close_old_connections) + @classmethod + def tearDownClass(cls): + super().tearDownClass() + consumer_started.connect(aclose_old_connections) + consumer_terminated.connect(aclose_old_connections) + db_sync_to_async.connect(close_old_connections) -keep_db_open = DatabaseWrapper() diff --git a/tests/test_database.py b/tests/test_database.py index 961b906a4..7d210c3a3 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,7 +4,7 @@ from channels.db import database_sync_to_async from channels.generic.http import AsyncHttpConsumer from channels.generic.websocket import AsyncWebsocketConsumer -from channels.testing import HttpCommunicator, WebsocketCommunicator, keep_db_open +from channels.testing import ConsumerTestMixin, HttpCommunicator, WebsocketCommunicator @database_sync_to_async @@ -30,9 +30,8 @@ async def handle(self, body): ) -class ConnectionClosingTests(TestCase): +class ConnectionClosingTests(ConsumerTestMixin, TestCase): - @keep_db_open async def test_websocket(self): self.assertNotRegex( db.connections["default"].settings_dict.get("NAME"), @@ -44,7 +43,6 @@ async def test_websocket(self): self.assertTrue(connected) self.assertEqual(subprotocol, "fun") - @keep_db_open async def test_http(self): self.assertNotRegex( db.connections["default"].settings_dict.get("NAME"), From 42cc1edcc4c0bfc2e50ddd959246907375a568c8 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 12 Aug 2025 09:44:11 +0200 Subject: [PATCH 5/6] Add also postgres to database tests --- .github/workflows/tests.yml | 9 +++++++++ setup.cfg | 1 + tests/sample_project/config/settings.py | 6 ++++++ tests/test_database.py | 9 +++++++-- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2750cf916..a0559b2dd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,6 +10,15 @@ jobs: tests: name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest + services: + postgres: + env: + POSTGRES_USER: channels + POSTGRES_PASSWORD: channels + POSTGRES_DB: channels + image: postgres:14-alpine + ports: ["5432:5432"] + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 strategy: fail-fast: false matrix: diff --git a/setup.cfg b/setup.cfg index 45fa26294..1a15c66e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ tests = pytest-django pytest-asyncio selenium + psycopg daphne = daphne>=4.0.0 types = diff --git a/tests/sample_project/config/settings.py b/tests/sample_project/config/settings.py index 610572173..52e06d476 100644 --- a/tests/sample_project/config/settings.py +++ b/tests/sample_project/config/settings.py @@ -64,6 +64,12 @@ # Override Django’s default behaviour of using an in-memory database # in tests for SQLite, since that avoids connection.close() working. "TEST": {"NAME": "test_db.sqlite3"}, + }, + "other": { + "ENGINE": "django.db.backends.postgresql", + "NAME": "channels", + "USER": "channels", + "PASSWORD": "channels", } } diff --git a/tests/test_database.py b/tests/test_database.py index 7d210c3a3..4edd60311 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -10,8 +10,12 @@ @database_sync_to_async def basic_query(): with db.connections["default"].cursor() as cursor: - cursor.execute("SELECT 1234") - return cursor.fetchone()[0] + cursor.execute("SELECT 1234;") + cursor.fetchone()[0] + + with db.connections["other"].cursor() as cursor: + cursor.execute("SELECT 1234;") + cursor.fetchone()[0] class WebsocketConsumer(AsyncWebsocketConsumer): @@ -31,6 +35,7 @@ async def handle(self, body): class ConnectionClosingTests(ConsumerTestMixin, TestCase): + databases = {'default', 'other'} async def test_websocket(self): self.assertNotRegex( From bb1e5c97051b5f1ccd9a695d55cce03797fbfca4 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 12 Aug 2025 09:56:52 +0200 Subject: [PATCH 6/6] Adjust confings for postgres in tests --- tests/sample_project/config/settings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/sample_project/config/settings.py b/tests/sample_project/config/settings.py index 52e06d476..a472fbd23 100644 --- a/tests/sample_project/config/settings.py +++ b/tests/sample_project/config/settings.py @@ -70,6 +70,8 @@ "NAME": "channels", "USER": "channels", "PASSWORD": "channels", + "HOST": "localhost", + "PORT": 5432, } }