diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8d089563..d68f5fe9 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,15 +71,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -HTTP Basic Authentication -......................... - -On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic -Authentication yet. - -For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` -performs HTTP Basic Authentication. - Following redirects ................... @@ -165,12 +156,12 @@ Server APIs | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | | :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| ``websockets.basic_auth_protocol_factory()`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ @@ -206,6 +197,75 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +.. _basic-auth: + +Performing HTTP Basic Authentication +.................................... + +.. admonition:: This section applies only to servers. + :class: tip + + On the client side, :func:`~asyncio.client.connect` performs HTTP Basic + Authentication automatically when the URI contains credentials. + +In the original implementation, the recommended way to add HTTP Basic +Authentication to a server was to set the ``create_protocol`` argument of +:func:`~legacy.server.serve` to a factory function generated by +:func:`~legacy.auth.basic_auth_protocol_factory`:: + + from websockets.legacy.auth import basic_auth_protocol_factory + from websockets.legacy.server import serve + + async with serve(..., create_protocol=basic_auth_protocol_factory(...)): + ... + +In the new implementation, the :func:`~asyncio.server.basic_auth` function +generates a ``process_request`` coroutine that performs HTTP Basic +Authentication:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve(..., process_request=basic_auth(...)): + ... + +:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or +a ``check_credentials`` coroutine as well as an optional ``realm`` just like +:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, +``check_credentials`` may be a function instead of a coroutine. + +This new API has more obvious semantics. That makes it easier to understand and +also easier to extend. + +In the original implementation, overriding ``create_protocol`` changed the type +of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, +a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP +Basic Authentication in its ``process_request`` method. If you wanted to +customize ``process_request`` further, you had: + +* an ill-defined option: add a ``process_request`` argument to + :func:`~legacy.server.serve`; to tell which one would run first, you had to + experiment or read the code; +* a cumbersome option: subclass + :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that + subclass in the ``create_protocol`` argument of + :func:`~legacy.auth.basic_auth_protocol_factory`. + +In the new implementation, you just write a ``process_request`` coroutine:: + + from websockets.asyncio.server import basic_auth, serve + + process_basic_auth = basic_auth(...) + + async def process_request(connection, request): + ... # some logic here + response = await process_basic_auth(connection, request) + if response is not None: + return response + ... # more logic here + + async with serve(..., process_request=process_request): + ... + Customizing the opening handshake ................................. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 955136ac..d940b1ea 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for + enforcing HTTP Basic Auth on the server side. + .. _13.0: 13.0 diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index bd5a34b1..d4d20aeb 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -81,3 +81,11 @@ Broadcast --------- .. autofunction:: websockets.asyncio.server.broadcast + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.asyncio.server.basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index a380f455..d9941e40 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -118,7 +118,7 @@ Server +------------------------------------+--------+--------+--------+--------+ | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 23cb0409..80e9c17b 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -60,3 +60,11 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.sync.server.basic_auth diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index be263e56..5f09bbf7 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -45,7 +45,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) serve( @@ -57,7 +56,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) The Window Bits and Memory Level values in these examples reduce memory usage diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 632d3ac2..033887e8 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -131,7 +131,9 @@ class connect: :func:`connect` may be used as an asynchronous context manager:: - async with websockets.asyncio.client.connect(...) as websocket: + from websockets.asyncio.client import connect + + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8f04ec31..f6cd9a1b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import hmac import http import logging import socket @@ -13,13 +14,19 @@ Generator, Iterable, Sequence, + Tuple, + cast, ) -from websockets.frames import CloseCode - +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate -from ..headers import validate_subprotocols +from ..frames import CloseCode +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol @@ -28,7 +35,14 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] class ServerConnection(Connection): @@ -79,6 +93,7 @@ def __init__( ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -548,19 +563,21 @@ class serve: :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: + from websockets.asyncio.server import serve + def handler(websocket): ... # set this future to exit the server stop = asyncio.get_running_loop().create_future() - async with websockets.asyncio.server.serve(handler, host, port): + async with serve(handler, host, port): await stop Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: - server = await websockets.asyncio.server.serve(handler, host, port) + server = await serve(handler, host, port) await server.serve_forever() Args: @@ -822,3 +839,123 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None, +) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function or coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + async def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + valid_credentials = check_credentials(username, password) + if isinstance(valid_credentials, Awaitable): + valid_credentials = await valid_credentials + + if not valid_credentials: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e33d53f6..3c700a37 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -156,7 +156,9 @@ def connect( :func:`connect` may be used as a context manager:: - with websockets.sync.client.connect(...) as websocket: + from websockets.sync.client import connect + + with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 86c162af..5e22e112 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hmac import http import logging import os @@ -10,12 +11,17 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence +from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode -from ..headers import validate_subprotocols +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol @@ -24,7 +30,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] class ServerConnection(Connection): @@ -65,6 +71,7 @@ def __init__( protocol, close_timeout=close_timeout, ) + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -368,10 +375,12 @@ def serve( that it will be closed and call :meth:`~Server.serve_forever` to serve requests:: + from websockets.sync.server import serve + def handler(websocket): ... - with websockets.sync.server.serve(handler, ...) as server: + with serve(handler, ...) as server: server.serve_forever() Args: @@ -587,3 +596,119 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerConnection, Request], Response | None]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.sync.server import basic_auth, serve + + with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + if not check_credentials(username, password): + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b899998f..f05b9f1e 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import hmac import http import logging import socket @@ -560,3 +561,162 @@ async def test_logger(self): logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) + + +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index e3dfeb27..39d7501b 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,5 @@ import dataclasses +import hmac import http import logging import socket @@ -413,6 +414,150 @@ def test_shutdown(self): server.socket.accept() +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + self.assertEval(client, "ws.username", "bye") + + def test_check_credentials(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) + + class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument."""