Skip to content

Commit

Permalink
Add regex support in ServerProtocol(origins=...).
Browse files Browse the repository at this point in the history
  • Loading branch information
daH005 committed Jan 13, 2025
1 parent e7a098e commit 1a43fe2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 13 deletions.
10 changes: 6 additions & 4 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hmac
import http
import logging
import re
import socket
import sys
from collections.abc import Awaitable, Generator, Iterable, Sequence
Expand Down Expand Up @@ -599,9 +600,10 @@ def handler(websocket):
See :meth:`~asyncio.loop.create_server` for details.
port: TCP port the server listens on.
See :meth:`~asyncio.loop.create_server` for details.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down Expand Up @@ -681,7 +683,7 @@ def __init__(
port: int | None = None,
*,
# WebSocket
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down
21 changes: 16 additions & 5 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import binascii
import email.utils
import http
import re
import warnings
from collections.abc import Generator, Sequence
from typing import Any, Callable, cast
Expand Down Expand Up @@ -49,9 +50,9 @@ class ServerProtocol(Protocol):
Sans-I/O implementation of a WebSocket server connection.
Args:
origins: Acceptable values of the ``Origin`` header; include
:obj:`None` in the list if the lack of an origin is acceptable.
This is useful for defending against Cross-Site WebSocket
origins: Acceptable values of the ``Origin`` header, including regular
expressions; include :obj:`None` in the list if the lack of an origin
is acceptable. This is useful for defending against Cross-Site WebSocket
Hijacking attacks.
extensions: List of supported extensions, in order in which they
should be tried.
Expand All @@ -73,7 +74,7 @@ class ServerProtocol(Protocol):
def __init__(
self,
*,
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down Expand Up @@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None:
if origin is not None:
origin = cast(Origin, origin)
if self.origins is not None:
if origin not in self.origins:
valid = False
for acceptable_origin_or_regex in self.origins:
if isinstance(acceptable_origin_or_regex, re.Pattern):
# `str(origin)` is needed for compatibility
# between `Pattern.match(string=...)` and `origin`.
valid = acceptable_origin_or_regex.match(str(origin)) is not None
else:
valid = acceptable_origin_or_regex == origin
if valid:
break
if not valid:
raise InvalidOrigin(origin)
return origin

Expand Down
10 changes: 6 additions & 4 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import http
import logging
import os
import re
import selectors
import socket
import ssl as ssl_module
Expand Down Expand Up @@ -325,7 +326,7 @@ def serve(
sock: socket.socket | None = None,
ssl: ssl_module.SSLContext | None = None,
# WebSocket
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down Expand Up @@ -399,9 +400,10 @@ def handler(websocket):
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down
37 changes: 37 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import http
import logging
import re
import sys
import unittest
from unittest.mock import patch
Expand Down Expand Up @@ -623,6 +624,42 @@ def test_unsupported_origin(self):
"invalid Origin header: https://original.example.com",
)

def test_supported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is supported
by a regular expression.
"""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
)
request = make_request()
request.headers["Origin"] = "https://other.example.com"
response = server.accept(request)
server.send_response(response)

self.assertHandshakeSuccess(server)
self.assertEqual(server.origin, "https://other.example.com")

def test_unsupported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is unsupported
by a regular expression.
"""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
)
request = make_request()
request.headers["Origin"] = "https://original.example.com"
response = server.accept(request)
server.send_response(response)

self.assertEqual(response.status_code, 403)
self.assertHandshakeError(
server,
InvalidOrigin,
"invalid Origin header: https://original.example.com",
)

def test_no_origin_accepted(self):
"""Handshake succeeds when the lack of an origin is accepted."""
server = ServerProtocol(origins=[None])
Expand Down

0 comments on commit 1a43fe2

Please sign in to comment.