diff --git a/channels/testing/__init__.py b/channels/testing/__init__.py index f96625cd..d7dee3ef 100644 --- a/channels/testing/__init__.py +++ b/channels/testing/__init__.py @@ -1,5 +1,4 @@ -from asgiref.testing import ApplicationCommunicator # noqa - +from .application import ApplicationCommunicator # noqa from .http import HttpCommunicator # noqa from .live import ChannelsLiveServerTestCase # noqa from .websocket import WebsocketCommunicator # noqa diff --git a/channels/testing/application.py b/channels/testing/application.py new file mode 100644 index 00000000..2003178c --- /dev/null +++ b/channels/testing/application.py @@ -0,0 +1,17 @@ +from unittest import mock + +from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator + + +def no_op(): + pass + + +class ApplicationCommunicator(BaseApplicationCommunicator): + async def send_input(self, message): + with mock.patch("channels.db.close_old_connections", no_op): + return await super().send_input(message) + + async def receive_output(self, timeout=1): + with mock.patch("channels.db.close_old_connections", no_op): + return await super().receive_output(timeout) diff --git a/channels/testing/http.py b/channels/testing/http.py index 6b1514ca..8130265a 100644 --- a/channels/testing/http.py +++ b/channels/testing/http.py @@ -1,6 +1,6 @@ from urllib.parse import unquote, urlparse -from asgiref.testing import ApplicationCommunicator +from channels.testing.application import ApplicationCommunicator class HttpCommunicator(ApplicationCommunicator): diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index 57ea4a65..24e58d36 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -1,7 +1,7 @@ import json from urllib.parse import unquote, urlparse -from asgiref.testing import ApplicationCommunicator +from channels.testing.application import ApplicationCommunicator class WebsocketCommunicator(ApplicationCommunicator): diff --git a/docs/topics/testing.rst b/docs/topics/testing.rst index a3c14a00..c3547fd8 100644 --- a/docs/topics/testing.rst +++ b/docs/topics/testing.rst @@ -73,8 +73,8 @@ you might need to fall back to it if you are testing things like HTTP chunked responses or long-polling, which aren't supported in ``HttpCommunicator`` yet. .. note:: - ``ApplicationCommunicator`` is actually provided by the base ``asgiref`` - package, but we let you import it from ``channels.testing`` for convenience. + ``ApplicationCommunicator`` extends the class provided by the base ``asgiref`` + package. Channels adds support for running unit tests with async consumers. To construct it, pass it an application and a scope: diff --git a/tests/conftest.py b/tests/conftest.py index 8e7b3155..94c9803a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,14 @@ def pytest_configure(): settings.configure( - DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3"}}, + DATABASES={ + "default": { + "ENGINE": "django.db.backends.sqlite3", + # Override Django’s default behaviour of using an in-memory database + # in tests for SQLite, since that avoids connection.close() working. + "TEST": {"NAME": "test_db.sqlite3"}, + } + }, INSTALLED_APPS=[ "django.contrib.auth", "django.contrib.contenttypes", diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..3faf05b5 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,55 @@ +from django import db +from django.test import TestCase + +from channels.db import database_sync_to_async +from channels.generic.http import AsyncHttpConsumer +from channels.generic.websocket import AsyncWebsocketConsumer +from channels.testing import HttpCommunicator, WebsocketCommunicator + + +@database_sync_to_async +def basic_query(): + with db.connections["default"].cursor() as cursor: + cursor.execute("SELECT 1234") + return cursor.fetchone()[0] + + +class WebsocketConsumer(AsyncWebsocketConsumer): + async def connect(self): + await basic_query() + await self.accept("fun") + + +class HttpConsumer(AsyncHttpConsumer): + async def handle(self, body): + await basic_query() + await self.send_response( + 200, + b"", + headers={b"Content-Type": b"text/plain"}, + ) + + +class ConnectionClosingTests(TestCase): + async def test_websocket(self): + self.assertNotRegex( + db.connections["default"].settings_dict.get("NAME"), + "memorydb", + "This bug only occurs when the database is materialized on disk", + ) + communicator = WebsocketCommunicator(WebsocketConsumer.as_asgi(), "/") + connected, subprotocol = await communicator.connect() + self.assertTrue(connected) + self.assertEqual(subprotocol, "fun") + + async def test_http(self): + self.assertNotRegex( + db.connections["default"].settings_dict.get("NAME"), + "memorydb", + "This bug only occurs when the database is materialized on disk", + ) + communicator = HttpCommunicator( + HttpConsumer.as_asgi(), method="GET", path="/test/" + ) + connected = await communicator.get_response() + self.assertTrue(connected)