diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index d68f5fe9..602d8a4e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -197,6 +197,16 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +Tracking open connections +......................... + +The new implementation of :class:`~asyncio.server.Server` provides a +:attr:`~asyncio.server.Server.connections` property, which is a set of all open +connections. This didn't exist in the original implementation. + +If you were keeping track of open connections, you may be able to simplify your +code by using this property. + .. _basic-auth: Performing HTTP Basic Authentication diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d940b1ea..1c87882f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,8 +35,11 @@ notice. New features ............ -* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for - enforcing HTTP Basic Auth on the server side. +* Made the set of active connections available in the :attr:`Server.connections + ` property. + +* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` + implementations of servers. .. _13.0: diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index d4d20aeb..2fcaeb41 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -17,6 +17,8 @@ Running a server .. autoclass:: Server + .. autoattribute:: connections + .. automethod:: close .. automethod:: wait_closed diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index f6cd9a1b..29860e56 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout @@ -313,6 +313,18 @@ def __init__( # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + It can be useful in combination with :func:`~broadcast`. + + """ + return {connection for connection in self.handlers if connection.state is OPEN} + def wrap(self, server: asyncio.Server) -> None: """ Attach to a given :class:`asyncio.Server`. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index f05b9f1e..38f22690 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -350,6 +350,12 @@ async def test_disable_keepalive(self): latency = eval(await client.recv()) self.assertEqual(latency, 0) + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -362,6 +368,16 @@ def create_connection(*args, **kwargs): async with run_client(server) as client: await self.assertEval(client, "ws.create_connection_ran", "True") + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with run_client(server) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + async def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -555,14 +571,6 @@ async def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): - async def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - async with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 39d7501b..a1763471 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -235,6 +235,12 @@ def test_disable_compression(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -247,6 +253,19 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_fileno(self): + """Server provides a fileno attribute.""" + with run_server() as server: + self.assertIsInstance(server.fileno(), int) + + def test_shutdown(self): + """Server provides a shutdown method.""" + with run_server() as server: + server.shutdown() + # Check that the server socket is closed. + with self.assertRaises(OSError): + server.socket.accept() + def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -393,27 +412,6 @@ def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.TestCase): - def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - def test_fileno(self): - """Server provides a fileno attribute.""" - with run_server() as server: - self.assertIsInstance(server.fileno(), int) - - def test_shutdown(self): - """Server provides a shutdown method.""" - with run_server() as server: - server.shutdown() - # Check that the server socket is closed. - with self.assertRaises(OSError): - server.socket.accept() - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication."""