Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ jobs:
tests:
name: Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
services:
postgres:
env:
POSTGRES_USER: channels
POSTGRES_PASSWORD: channels
POSTGRES_DB: channels
image: postgres:14-alpine
ports: ["5432:5432"]
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
strategy:
fail-fast: false
matrix:
Expand Down
7 changes: 4 additions & 3 deletions channels/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from asgiref.sync import async_to_sync

from . import DEFAULT_CHANNEL_LAYER
from .db import aclose_old_connections, database_sync_to_async
from .db import database_sync_to_async
from .exceptions import StopConsumer
from .layers import get_channel_layer
from .signals import consumer_started, consumer_terminated
from .utils import await_many_dispatch


Expand Down Expand Up @@ -62,15 +63,15 @@ async def __call__(self, scope, receive, send):
await await_many_dispatch([receive], self.dispatch)
except StopConsumer:
# Exit cleanly
pass
await consumer_terminated.asend(sender=self.__class__)

async def dispatch(self, message):
"""
Works out what to do with a message.
"""
handler = getattr(self, get_handler_name(message), None)
if handler:
await aclose_old_connections()
await consumer_started.asend(sender=self.__class__)
await handler(message)
else:
raise ValueError("No handler for message type %s" % message["type"])
Expand Down
12 changes: 9 additions & 3 deletions channels/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from asgiref.sync import SyncToAsync, sync_to_async
from django.db import close_old_connections
from .signals import consumer_started, consumer_terminated, db_sync_to_async


class DatabaseSyncToAsync(SyncToAsync):
Expand All @@ -8,16 +9,21 @@ class DatabaseSyncToAsync(SyncToAsync):
"""

def thread_handler(self, loop, *args, **kwargs):
close_old_connections()
db_sync_to_async.send(sender=self.__class__, start=True)
try:
return super().thread_handler(loop, *args, **kwargs)
finally:
close_old_connections()
db_sync_to_async.send(sender=self.__class__, start=False)


# The class is TitleCased, but we want to encourage use as a callable/decorator
database_sync_to_async = DatabaseSyncToAsync


async def aclose_old_connections():
async def aclose_old_connections(**kwargs):
return await sync_to_async(close_old_connections)()


consumer_started.connect(aclose_old_connections)
consumer_terminated.connect(aclose_old_connections)
db_sync_to_async.connect(close_old_connections)
2 changes: 0 additions & 2 deletions channels/generic/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from channels.consumer import AsyncConsumer

from ..db import aclose_old_connections
from ..exceptions import StopConsumer


Expand Down Expand Up @@ -89,5 +88,4 @@ async def http_disconnect(self, message):
Let the user do their cleanup and close the consumer.
"""
await self.disconnect()
await aclose_old_connections()
raise StopConsumer()
2 changes: 0 additions & 2 deletions channels/generic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from asgiref.sync import async_to_sync

from ..consumer import AsyncConsumer, SyncConsumer
from ..db import aclose_old_connections
from ..exceptions import (
AcceptConnection,
DenyConnection,
Expand Down Expand Up @@ -248,7 +247,6 @@ async def websocket_disconnect(self, message):
"BACKEND is unconfigured or doesn't support groups"
)
await self.disconnect(message["code"])
await aclose_old_connections()
raise StopConsumer()

async def disconnect(self, code):
Expand Down
5 changes: 5 additions & 0 deletions channels/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.dispatch import Signal

consumer_started = Signal()
consumer_terminated = Signal()
db_sync_to_async = Signal()
28 changes: 27 additions & 1 deletion channels/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .application import ApplicationCommunicator # noqa
from channels.db import aclose_old_connections
from channels.signals import consumer_started, consumer_terminated, db_sync_to_async
from django.db import close_old_connections
from asgiref.testing import ApplicationCommunicator # noqa
from .http import HttpCommunicator # noqa
from .live import ChannelsLiveServerTestCase # noqa
from .websocket import WebsocketCommunicator # noqa
Expand All @@ -8,4 +11,27 @@
"HttpCommunicator",
"ChannelsLiveServerTestCase",
"WebsocketCommunicator",
"ConsumerTestMixin",
]


class ConsumerTestMixin:
"""
Mixin to be applied to Django `TestCase` or `TransactionTestCase` to ensure
that database connections are not closed by consumers during test execution.
"""

@classmethod
def setUpClass(cls):
super().setUpClass()
consumer_started.disconnect(aclose_old_connections)
consumer_terminated.disconnect(aclose_old_connections)
db_sync_to_async.disconnect(close_old_connections)

@classmethod
def tearDownClass(cls):
super().tearDownClass()
consumer_started.connect(aclose_old_connections)
consumer_terminated.connect(aclose_old_connections)
db_sync_to_async.connect(close_old_connections)

17 changes: 0 additions & 17 deletions channels/testing/application.py

This file was deleted.

2 changes: 1 addition & 1 deletion channels/testing/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from urllib.parse import unquote, urlparse

from channels.testing.application import ApplicationCommunicator
from asgiref.testing import ApplicationCommunicator


class HttpCommunicator(ApplicationCommunicator):
Expand Down
2 changes: 1 addition & 1 deletion channels/testing/websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from urllib.parse import unquote, urlparse

from channels.testing.application import ApplicationCommunicator
from asgiref.testing import ApplicationCommunicator


class WebsocketCommunicator(ApplicationCommunicator):
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ tests =
pytest-django
pytest-asyncio
selenium
psycopg
daphne =
daphne>=4.0.0
types =
Expand Down
8 changes: 8 additions & 0 deletions tests/sample_project/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
# Override Django’s default behaviour of using an in-memory database
# in tests for SQLite, since that avoids connection.close() working.
"TEST": {"NAME": "test_db.sqlite3"},
},
"other": {
"ENGINE": "django.db.backends.postgresql",
"NAME": "channels",
"USER": "channels",
"PASSWORD": "channels",
"HOST": "localhost",
"PORT": 5432,
}
}

Expand Down
14 changes: 10 additions & 4 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from channels.db import database_sync_to_async
from channels.generic.http import AsyncHttpConsumer
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.testing import HttpCommunicator, WebsocketCommunicator
from channels.testing import ConsumerTestMixin, HttpCommunicator, WebsocketCommunicator


@database_sync_to_async
def basic_query():
with db.connections["default"].cursor() as cursor:
cursor.execute("SELECT 1234")
return cursor.fetchone()[0]
cursor.execute("SELECT 1234;")
cursor.fetchone()[0]

with db.connections["other"].cursor() as cursor:
cursor.execute("SELECT 1234;")
cursor.fetchone()[0]


class WebsocketConsumer(AsyncWebsocketConsumer):
Expand All @@ -30,7 +34,9 @@ async def handle(self, body):
)


class ConnectionClosingTests(TestCase):
class ConnectionClosingTests(ConsumerTestMixin, TestCase):
databases = {'default', 'other'}

async def test_websocket(self):
self.assertNotRegex(
db.connections["default"].settings_dict.get("NAME"),
Expand Down
Loading