From 078c7f4d1ca3a6bdc492c766e50acc506bab9d57 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Sat, 27 May 2023 23:18:04 +0200 Subject: [PATCH] Refatoring for v2 --- .github/workflows/ci.yml | 4 +- .../workflows/linter-requirements.txt | 0 README.md | 18 +- examples/request_msearch.py | 68 ------- examples/respond_msearch.py | 78 -------- pyproject.toml | 10 + ssdp/__init__.py | 177 +----------------- ssdp/__main__.py | 138 ++++++++++++++ ssdp/asyncio.py | 56 ++++++ ssdp/lexers.py | 54 ++++++ ssdp/messages.py | 120 ++++++++++++ ssdp/network.py | 16 ++ ssdp/server.py | 55 ++++++ tests/test_ssdp.py | 13 +- 14 files changed, 469 insertions(+), 338 deletions(-) rename requirements.txt => .github/workflows/linter-requirements.txt (100%) delete mode 100644 examples/request_msearch.py delete mode 100644 examples/respond_msearch.py create mode 100644 ssdp/__main__.py create mode 100644 ssdp/asyncio.py create mode 100644 ssdp/lexers.py create mode 100644 ssdp/messages.py create mode 100644 ssdp/network.py create mode 100644 ssdp/server.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b438f0d..37c1829 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,8 +24,8 @@ jobs: with: python-version: "3.x" cache: "pip" - cache-dependency-path: "requirements.txt" - - run: python -m pip install -r requirements.txt + cache-dependency-path: ".github/workflows/linter-requirements.txt" + - run: python -m pip install -r .github/workflows/linter-requirements.txt - run: ${{ matrix.lint-command }} dist: diff --git a/requirements.txt b/.github/workflows/linter-requirements.txt similarity index 100% rename from requirements.txt rename to .github/workflows/linter-requirements.txt diff --git a/README.md b/README.md index 84c31fc..68eac30 100644 --- a/README.md +++ b/README.md @@ -14,32 +14,34 @@ python3 -m pip install ssdp ## Usage ```python +import ssdp.asyncio +import ssdp.messages import asyncio import socket import ssdp -class MyProtocol(ssdp.SimpleServiceDiscoveryProtocol): +class MyProtocol(ssdp.asyncio.SimpleServiceDiscoveryProtocol): - def response_received(self, response, addr): - print(response, addr) + def response_received(self, response, addr): + print(response, addr) - def request_received(self, request, addr): - print(request, addr) + def request_received(self, request, addr): + print(request, addr) loop = asyncio.get_event_loop() connect = loop.create_datagram_endpoint(MyProtocol, family=socket.AF_INET) transport, protocol = loop.run_until_complete(connect) -notify = ssdp.SSDPRequest('NOTIFY') +notify = ssdp.message.SSDPRequest('NOTIFY') notify.sendto(transport, (MyProtocol.MULTICAST_ADDRESS, 1982)) try: - loop.run_forever() + loop.run_forever() except KeyboardInterrupt: - pass + pass transport.close() loop.close() diff --git a/examples/request_msearch.py b/examples/request_msearch.py deleted file mode 100644 index a5918e1..0000000 --- a/examples/request_msearch.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3 -"""Send out a M-SEARCH request and listening for responses.""" -import asyncio -import socket - -import ssdp - - -class MyProtocol(ssdp.SimpleServiceDiscoveryProtocol): - """Protocol to handle responses and requests.""" - - def response_received(self, response: ssdp.SSDPResponse, addr: tuple): - """Handle an incoming response.""" - print( - "received response: {} {} {}".format( - response.status_code, response.reason, response.version - ) - ) - - for header in response.headers: - print("header: {}".format(header)) - - print() - - def request_received(self, request: ssdp.SSDPRequest, addr: tuple): - """Handle an incoming request.""" - print( - "received request: {} {} {}".format( - request.method, request.uri, request.version - ) - ) - - for header in request.headers: - print("header: {}".format(header)) - - print() - - -def main(): - # Start the asyncio loop. - loop = asyncio.get_event_loop() - connect = loop.create_datagram_endpoint(MyProtocol, family=socket.AF_INET) - transport, protocol = loop.run_until_complete(connect) - - # Send out an M-SEARCH request, requesting all service types. - search_request = ssdp.SSDPRequest( - "M-SEARCH", - headers={ - "HOST": "239.255.255.250:1900", - "MAN": '"ssdp:discover"', - "MX": "4", - "ST": "ssdp:all", - }, - ) - search_request.sendto(transport, (MyProtocol.MULTICAST_ADDRESS, 1900)) - - # Keep on running for 4 seconds. - try: - loop.run_until_complete(asyncio.sleep(4)) - except KeyboardInterrupt: - pass - - transport.close() - loop.close() - - -if __name__ == "__main__": - main() diff --git a/examples/respond_msearch.py b/examples/respond_msearch.py deleted file mode 100644 index 20d6325..0000000 --- a/examples/respond_msearch.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -"""Waiting for a M-SEARCH request and respond to it.""" -import asyncio -import socket - -import ssdp - - -class MyProtocol(ssdp.SimpleServiceDiscoveryProtocol): - """Protocol to handle responses and requests.""" - - def response_received(self, response: ssdp.SSDPResponse, addr: tuple): - """Handle an incoming response.""" - print( - "received response: {} {} {}".format( - response.status_code, response.reason, response.version - ) - ) - - for header in response.headers: - print("header: {}".format(header)) - - print() - - def request_received(self, request: ssdp.SSDPRequest, addr: tuple): - """Handle an incoming request and respond to it.""" - print( - "received request: {} {} {}".format( - request.method, request.uri, request.version - ) - ) - - for header in request.headers: - print("header: {}".format(header)) - - print() - - # Build response and send it. - print("Sending a response back to {}:{}".format(*addr)) - ssdp_response = ssdp.SSDPResponse( - 200, - "OK", - headers={ - "Cache-Control": "max-age=30", - "Location": "http://127.0.0.1:80/Device.xml", - "Server": "Python UPnP/1.0 SSDP", - "ST": "urn:schemas-upnp-org:service:ExampleService:1", - "USN": "uuid:2fac1234-31f8-11b4-a222-08002b34c003::urn:schemas-upnp-org:service:Example:1", - "EXT": "", - }, - ) - ssdp_response.sendto(self.transport, addr) - - -def main(): - # Start the asyncio loop. - loop = asyncio.get_event_loop() - connect = loop.create_datagram_endpoint( - MyProtocol, - family=socket.AF_INET, - local_addr=(MyProtocol.MULTICAST_ADDRESS, 1900), - ) - transport, protocol = loop.run_until_complete(connect) - - # Ensure MyProtocol has something send to. - MyProtocol.transport = transport - - try: - loop.run_forever() - except KeyboardInterrupt: - pass - - transport.close() - loop.close() - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 006782a..bc427e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,16 @@ test = [ "pytest", "pytest-cov", ] +cli = [ + "click", + "Pygments", +] + +[project.scripts] +ssdp = "ssdp:__main__.cli" + +[project.entry-points."pygments.lexers"] +ssdp = "ssdp.lexer:SSDPLexer" [project.urls] Project-URL = "https://github.com/codingjoe/ssdp" diff --git a/ssdp/__init__.py b/ssdp/__init__.py index 10aa863..3340956 100644 --- a/ssdp/__init__.py +++ b/ssdp/__init__.py @@ -1,182 +1,7 @@ """Python asyncio library for Simple Service Discovery Protocol (SSDP).""" -import asyncio -import email.parser -import errno -import logging - from . import _version __version__ = _version.version -__all__ = ["SSDPRequest", "SSDPResponse", "SimpleServiceDiscoveryProtocol"] +__all__ = [] VERSION = _version.version_tuple -logger = logging.getLogger("ssdp") - - -class SSDPMessage: - """Simplified HTTP message to serve as a SSDP message.""" - - def __init__(self, version="HTTP/1.1", headers=None): - if headers is None: - headers = [] - elif isinstance(headers, dict): - headers = headers.items() - - self.version = version - self.headers = list(headers) - - @classmethod - def parse(cls, msg): - """ - Parse message a string into a :class:`SSDPMessage` instance. - - Args: - msg (str): Message string. - - Returns: - SSDPMessage: Message parsed from string. - - """ - raise NotImplementedError() - - @classmethod - def parse_headers(cls, msg): - """ - Parse HTTP headers. - - Args: - msg (str): HTTP message. - - Returns: - (List[Tuple[str, str]]): List of header tuples. - - """ - return list(email.parser.Parser().parsestr(msg).items()) - - def __str__(self): - """Return full HTTP message.""" - raise NotImplementedError() - - def __bytes__(self): - """Return full HTTP message as bytes.""" - return self.__str__().encode().replace(b"\n", b"\r\n") - - -class SSDPResponse(SSDPMessage): - """Simple Service Discovery Protocol (SSDP) response.""" - - def __init__(self, status_code, reason, **kwargs): - self.status_code = int(status_code) - self.reason = reason - super().__init__(**kwargs) - - @classmethod - def parse(cls, msg): - """Parse message string to response object.""" - lines = msg.splitlines() - version, status_code, reason = lines[0].split() - headers = cls.parse_headers("\r\n".join(lines[1:])) - return cls( - version=version, status_code=status_code, reason=reason, headers=headers - ) - - def __str__(self): - """Return complete SSDP response.""" - lines = list() - lines.append(" ".join([self.version, str(self.status_code), self.reason])) - for header in self.headers: - lines.append("%s: %s" % header) - return "\n".join(lines) - - -class SSDPRequest(SSDPMessage): - """Simple Service Discovery Protocol (SSDP) request.""" - - def __init__(self, method, uri="*", version="HTTP/1.1", headers=None): - self.method = method - self.uri = uri - super().__init__(version=version, headers=headers) - - @classmethod - def parse(cls, msg): - """Parse message string to request object.""" - lines = msg.splitlines() - method, uri, version = lines[0].split() - headers = cls.parse_headers("\r\n".join(lines[1:])) - return cls(version=version, uri=uri, method=method, headers=headers) - - def sendto(self, transport, addr): - """ - Send request to a given address via given transport. - - Args: - transport (asyncio.DatagramTransport): - Write transport to send the message on. - addr (Tuple[str, int]): - IP address and port pair to send the message to. - - """ - msg = bytes(self) + b"\r\n" + b"\r\n" - logger.debug("%s:%s < %s", *(addr + (self,))) - transport.sendto(msg, addr) - - def __str__(self): - """Return complete SSDP request.""" - lines = list() - lines.append(" ".join([self.method, self.uri, self.version])) - for header in self.headers: - lines.append("%s: %s" % header) - return "\n".join(lines) - - -class SimpleServiceDiscoveryProtocol(asyncio.DatagramProtocol): - """ - Simple Service Discovery Protocol (SSDP). - - SSDP is part of UPnP protocol stack. For more information see: - https://en.wikipedia.org/wiki/Simple_Service_Discovery_Protocol - """ - - MULTICAST_ADDRESS = "239.255.255.250" - - def datagram_received(self, data, addr): - data = data.decode() - logger.debug("%s:%s > %s", *(addr + (data,))) - - if data.startswith("HTTP/"): - self.response_received(SSDPResponse.parse(data), addr) - else: - self.request_received(SSDPRequest.parse(data), addr) - - def response_received(self, response, addr): - """ - Being called when some response is received. - - Args: - response (SSDPResponse): Received response. - addr (Tuple[str, int]): Tuple containing IP address and port number. - - """ - raise NotImplementedError() - - def request_received(self, request, addr): - """ - Being called when some request is received. - - Args: - request (SSDPRequest): Received request. - addr (Tuple[str, int]): Tuple containing IP address and port number. - - """ - raise NotImplementedError() - - def error_received(self, exc): - if exc == errno.EAGAIN or exc == errno.EWOULDBLOCK: - logger.error("Error received: %s", exc) - else: - raise IOError("Unexpected connection error") from exc - - def connection_lost(self, exc): - logger.exception("Socket closed, stop the event loop", exc_info=exc) - loop = asyncio.get_event_loop() - loop.stop() diff --git a/ssdp/__main__.py b/ssdp/__main__.py new file mode 100644 index 0000000..cc18f91 --- /dev/null +++ b/ssdp/__main__.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +import asyncio +import logging +import platform +import socket +import sys +import time +from http.server import _get_best_family + +import ssdp.asyncio +import ssdp.messages +from ssdp import network +from ssdp.server import SSDPMessageHandler, SSDPServer + +try: + import click + + from .lexers import prettify_msg +except ImportError: + print("The SSDP CLI requires needs to be installed via `pip install ssdp[cli]`.") + exit(1) + + +import ssdp + + +class PrintProcessor: + """Print SSDP messages to stdout.""" + + def process_request(self, request: ssdp.messages.SSDPRequest, addr: tuple): + """Handle an incoming request.""" + self.pprint(request, addr) + + def process_response(self, response: ssdp.messages.SSDPResponse, addr: tuple): + """Handle an incoming response.""" + self.pprint(response, addr) + + @staticmethod + def pprint(msg, addr): + """Pretty print the message.""" + host = f"[{addr[0]}]" if ":" in addr[0] else addr[0] + host = click.style(host, fg="green", bold=True) + port = click.style(str(addr[1]), fg="yellow", bold=True) + click.echo( + "%s:%s - - [%s] %s" % (host, port, time.asctime(), prettify_msg(msg)) + ) + + +class PrintSSDMessageProtocol( + PrintProcessor, ssdp.asyncio.SimpleServiceDiscoveryProtocol +): + pass + + +@click.group() +@click.option("-v", "--verbose", count=True, help="Increase verbosity.") +def cli(verbose): + """SSDP command line interface.""" + logging.basicConfig( + level=max(10, 10 * (2 - verbose)), + format="%(levelname)s: [%(asctime)s] %(message)s", + handlers=[logging.StreamHandler()], + ) + + +@cli.command() +@click.option( + "--bind", + "-b", + help="Specify alternate bind address [default: all interfaces]", +) +def discover(bind): + """Send out an M-SEARCH request and listening for responses.""" + family, addr = _get_best_family(bind, network.PORT) + loop = asyncio.get_event_loop() + + connect = loop.create_datagram_endpoint(PrintSSDMessageProtocol, family=family) + transport, protocol = loop.run_until_complete(connect) + + target = network.MULTICAST_ADDRESS_IPV4, network.PORT + + search_request = ssdp.messages.SSDPRequest( + "M-SEARCH", + headers={ + "HOST": "%s:%d" % target, + "MAN": '"ssdp:discover"', + "MX": "4", + "ST": "ssdp:all", + }, + ) + + target = network.MULTICAST_ADDRESS_IPV4, network.PORT + + search_request.sendto(transport, target) + + PrintSSDMessageProtocol.pprint(search_request, addr[:2]) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + finally: + transport.close() + + +class PrintSSDPMessageHandler(PrintProcessor, SSDPMessageHandler): + pass + + +@cli.command(name="server") +@click.option( + "--bind", + "-b", + help="Specify alternate bind address [default: all interfaces]", +) +def serve(bind, ServerClass=SSDPServer): + if platform.system() == "Darwin": + # macOS doesn't support IPv6 multicast + ServerClass.address_family, addr = _get_best_family( + bind, network.PORT, socket.AF_INET + ) + else: + ServerClass.address_family, addr = _get_best_family(bind, network.PORT) + + with ServerClass(addr, PrintSSDPMessageHandler) as ssdpd: + host, port = ssdpd.socket.getsockname()[:2] + url_host = f"[{host}]" if ":" in host else host + click.echo( + f"Serving SSDP on {host} port {port} " f"(http://{url_host}:{port}/) ..." + ) + try: + ssdpd.serve_forever() + except KeyboardInterrupt: + click.echo("\nKeyboard interrupt received, exiting.") + sys.exit(0) + + +if __name__ == "__main__": + cli() diff --git a/ssdp/asyncio.py b/ssdp/asyncio.py new file mode 100644 index 0000000..6df78af --- /dev/null +++ b/ssdp/asyncio.py @@ -0,0 +1,56 @@ +import asyncio +import errno +import logging + +from . import messages + +logger = logging.getLogger(__name__) + + +class SimpleServiceDiscoveryProtocol(asyncio.DatagramProtocol): + """ + Simple Service Discovery Protocol (SSDP). + + SSDP is part of UPnP protocol stack. For more information see: + https://en.wikipedia.org/wiki/Simple_Service_Discovery_Protocol + """ + + def datagram_received(self, data, addr): + data = data.decode() + logger.debug("%s:%s – – %s", *addr, data) + + if data.startswith("HTTP/"): + self.process_response(messages.SSDPResponse.parse(data), addr) + else: + self.process_request(messages.SSDPRequest.parse(data), addr) + + def process_response(self, response, addr): + """ + Being called when some response is received. + + Args: + response (ssdp.messages.SSDPResponse): Received response. + addr (Tuple[str, int]): Tuple containing IP address and port number. + + """ + raise NotImplementedError() + + def process_request(self, request, addr): + """ + Being called when some request is received. + + Args: + request (ssdp.messages.SSDPRequest): Received request. + addr (Tuple[str, int]): Tuple containing IP address and port number. + + """ + raise NotImplementedError() + + def error_received(self, exc): + if exc == errno.EAGAIN or exc == errno.EWOULDBLOCK: + logger.exception("Blocking IO error", exc_info=exc) + else: + raise exc + + def connection_lost(self, exc): + logger.exception("Connection lost", exc_info=exc) diff --git a/ssdp/lexers.py b/ssdp/lexers.py new file mode 100644 index 0000000..be999ec --- /dev/null +++ b/ssdp/lexers.py @@ -0,0 +1,54 @@ +from pygments import formatters, highlight, lexer, lexers, token + + +class SSDPLexer(lexers.HttpLexer): + tokens = { + "root": [ + ( + r"(M-SEARCH|NOTIFY)( +)([^ ]+)( +)" + r"(HTTP)(/)(1\.[01]|2(?:\.0)?|3)(\r?\n|\Z)", + lexer.bygroups( + token.Name.Function, + token.Text, + token.Name.Namespace, + token.Text, + token.Keyword.Reserved, + token.Operator, + token.Number, + token.Text, + ), + "headers", + ), + ( + r"(HTTP)(/)(1\.[01]|2(?:\.0)?|3)( +)(\d{3})(?:( +)([^\r\n]*))?(\r?\n|\Z)", + lexer.bygroups( + token.Keyword.Reserved, + token.Operator, + token.Number, + token.Text, + token.Number, + token.Text, + token.Name.Exception, + token.Text, + ), + "headers", + ), + ], + "headers": [ + ( + r"([^\s:]+)( *)(:)( *)([^\r\n]+)(\r?\n|\Z)", + lexers.HttpLexer.header_callback, + ), + ( + r"([\t ]+)([^\r\n]+)(\r?\n|\Z)", + lexers.HttpLexer.continuous_header_callback, + ), + (r"\r?\n", token.Text, "content"), + ], + "content": [(r".+", lexers.HttpLexer.content_callback)], + } + + +def prettify_msg(msg): + """Return a pretty-printed version of a SSDP message.""" + return highlight(str(msg), SSDPLexer(), formatters.TerminalFormatter()) diff --git a/ssdp/messages.py b/ssdp/messages.py new file mode 100644 index 0000000..56a6530 --- /dev/null +++ b/ssdp/messages.py @@ -0,0 +1,120 @@ +import email.parser +import logging + +logger = logging.getLogger("ssdp") + + +class SSDPMessage: + """Simplified HTTP message to serve as a SSDP message.""" + + def __init__(self, version="HTTP/1.1", headers=None): + if headers is None: + headers = [] + elif isinstance(headers, dict): + headers = headers.items() + + self.version = version + self.headers = list(headers) + + @classmethod + def parse(cls, msg): + """ + Parse message a string into a :class:`SSDPMessage` instance. + + Args: + msg (str): Message string. + + Returns: + SSDPMessage: Message parsed from string. + + """ + raise NotImplementedError() + + @classmethod + def parse_headers(cls, msg): + """ + Parse HTTP headers. + + Args: + msg (str): HTTP message. + + Returns: + (List[Tuple[str, str]]): List of header tuples. + + """ + return list(email.parser.Parser().parsestr(msg).items()) + + def __str__(self): + """Return full HTTP message.""" + raise NotImplementedError() + + def __bytes__(self): + """Return full HTTP message as bytes.""" + return self.__str__().encode().replace(b"\n", b"\r\n") + + +class SSDPResponse(SSDPMessage): + """Simple Service Discovery Protocol (SSDP) response.""" + + def __init__(self, status_code, reason, **kwargs): + self.status_code = int(status_code) + self.reason = reason + super().__init__(**kwargs) + + @classmethod + def parse(cls, msg): + """Parse message string to response object.""" + lines = msg.splitlines() + version, status_code, reason = lines[0].split() + headers = cls.parse_headers("\r\n".join(lines[1:])) + return cls( + version=version, status_code=status_code, reason=reason, headers=headers + ) + + def __str__(self): + """Return complete SSDP response.""" + lines = list() + lines.append(" ".join([self.version, str(self.status_code), self.reason])) + for header in self.headers: + lines.append("%s: %s" % header) + return "\r\n".join(lines) + + +class SSDPRequest(SSDPMessage): + """Simple Service Discovery Protocol (SSDP) request.""" + + def __init__(self, method, uri="*", version="HTTP/1.1", headers=None): + self.method = method + self.uri = uri + super().__init__(version=version, headers=headers) + + @classmethod + def parse(cls, msg): + """Parse message string to request object.""" + lines = msg.splitlines() + method, uri, version = lines[0].split() + headers = cls.parse_headers("\r\n".join(lines[1:])) + return cls(version=version, uri=uri, method=method, headers=headers) + + def sendto(self, transport, addr): + """ + Send request to a given address via given transport. + + Args: + transport (asyncio.DatagramTransport): + Write transport to send the message on. + addr (Tuple[str, int]): + IP address and port pair to send the message to. + + """ + msg = bytes(self) + b"\r\n" + b"\r\n" + logger.debug("%s:%s - - %s", *(addr + (self,))) + transport.sendto(msg, addr) + + def __str__(self): + """Return complete SSDP request.""" + lines = list() + lines.append(" ".join([self.method, self.uri, self.version])) + for header in self.headers: + lines.append("%s: %s" % header) + return "\r\n".join(lines) diff --git a/ssdp/network.py b/ssdp/network.py new file mode 100644 index 0000000..0326ac1 --- /dev/null +++ b/ssdp/network.py @@ -0,0 +1,16 @@ +__all__ = [ + "MULTICAST_ADDRESS_IPV4", + "MULTICAST_ADDRESS_IPV6_LINK_LOCAL", + "MULTICAST_ADDRESS_IPV6_SITE_LOCAL", + "MULTICAST_ADDRESS_IPV6_ORG_LOCAL", + "MULTICAST_ADDRESS_IPV6_GLOBAL", + "PORT", +] + +MULTICAST_ADDRESS_IPV4 = "239.255.255.250" +MULTICAST_ADDRESS_IPV6_LINK_LOCAL = "ff02::c" +MULTICAST_ADDRESS_IPV6_SITE_LOCAL = "ff05::c" +MULTICAST_ADDRESS_IPV6_ORG_LOCAL = "ff08::c" +MULTICAST_ADDRESS_IPV6_GLOBAL = "ff0e::c" + +PORT = 1900 diff --git a/ssdp/server.py b/ssdp/server.py new file mode 100644 index 0000000..751816c --- /dev/null +++ b/ssdp/server.py @@ -0,0 +1,55 @@ +import logging +import socket +import socketserver +import struct + +from . import messages, network + +logger = logging.getLogger(__name__) + + +class SSDPMessageHandler(socketserver.BaseRequestHandler): + def handle(self): + data = self.request[0] + data = data.decode() + logger.debug("%s:%d – – %s", *self.client_address, data) + + if data.startswith("HTTP/"): + self.process_response( + messages.SSDPResponse.parse(data), self.client_address + ) + else: + self.process_request(messages.SSDPRequest.parse(data), self.client_address) + + def process_request( + self, request: messages.SSDPRequest, client_address: tuple[str, int] + ): + raise NotImplementedError() + + def process_response( + self, response: messages.SSDPResponse, client_address: tuple[str, int] + ): + raise NotImplementedError() + + +class SSDPServer(socketserver.UDPServer): + allow_reuse_address = True + + def server_bind(self): + if self.address_family == socket.AF_INET: + self.socket.setsockopt( + socket.IPPROTO_IP, + socket.IP_ADD_MEMBERSHIP, + socket.inet_aton(network.MULTICAST_ADDRESS_IPV4) + struct.pack("@I", 0), + ) + elif self.address_family == socket.AF_INET6: + ifis = struct.pack("@I", 0) + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, ifis) + group = ( + socket.inet_pton( + self.address_family, network.MULTICAST_ADDRESS_IPV6_SITE_LOCAL + ) + + ifis + ) + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, group) + self.socket.bind(self.server_address[:2]) diff --git a/tests/test_ssdp.py b/tests/test_ssdp.py index 8c1b7ed..62ea81e 100644 --- a/tests/test_ssdp.py +++ b/tests/test_ssdp.py @@ -3,7 +3,8 @@ import pytest -from ssdp import SimpleServiceDiscoveryProtocol, SSDPMessage, SSDPRequest, SSDPResponse +from ssdp.asyncio import SimpleServiceDiscoveryProtocol +from ssdp.messages import SSDPMessage, SSDPRequest, SSDPResponse from . import fixtures @@ -53,10 +54,10 @@ def test_parse(self): def test_str(self): response = SSDPResponse( - 200, "OK", headers=[("Location", "Location: http://192.168.1.239:55443")] + 200, "OK", headers=[("Location", "http://192.168.1.239:55443")] ) assert str(response) == ( - "HTTP/1.1 200 OK\n" "Location: Location: http://192.168.1.239:55443" + "HTTP/1.1 200 OK\r\nLocation: http://192.168.1.239:55443" ) @@ -70,16 +71,16 @@ def test_str(self): request = SSDPRequest( "NOTIFY", "*", headers=[("Cache-Control", "max-age=3600")] ) - assert str(request) == ("NOTIFY * HTTP/1.1\n" "Cache-Control: max-age=3600") + assert str(request) == ("NOTIFY * HTTP/1.1\r\nCache-Control: max-age=3600") def test_sendto(self): requests = [] class MyProtocol(SimpleServiceDiscoveryProtocol): - def response_received(self, response, addr): + def process_response(self, response, addr): print(response, addr) - def request_received(self, request, addr): + def process_request(self, request, addr): requests.append(request) print(request, addr)