1414 LifespanStartupEvent ,
1515)
1616
17- from .._drivers .http_client import HTTPClient , HTTPClientSession , HTTPResponse
17+ from .._drivers .http_client import HTTPClient , HTTPClientError , HTTPClientSession , HTTPResponse
1818from .._drivers .http_server import HTTPServable , HTTPServableApp
1919from .._drivers .ssl import SSLCertificate
2020from ..proxy import ProxyServer
@@ -93,26 +93,33 @@ def get_server(self, host: str, port: int) -> tuple[SSLCertificate, LifespanMana
9393
9494
9595class MockHTTPClient (HTTPClient ):
96- def __init__ (self , mock_network : MockHTTPNetwork , host : str | None = None ):
97- # TODO: do we actually need to be able to specify the client's host?
96+ def __init__ (self , mock_network : MockHTTPNetwork , client_host : str | None = None ):
9897 self ._mock_network = mock_network
99- self ._host = host or "passive client"
98+ # Since the nodes use HTTP for P2P messaging,
99+ # we need to be able to report the client's hostname (used for DDoS protection).
100+ self ._client_host = client_host or "passive client"
100101
101102 async def fetch_certificate (self , host : str , port : int ) -> SSLCertificate :
102103 certificate , _manager = self ._mock_network .get_server (host , port )
103104 return certificate
104105
105106 @asynccontextmanager
106107 async def session (
107- self , _certificate : SSLCertificate | None = None
108+ self , certificate : SSLCertificate | None = None
108109 ) -> AsyncIterator ["MockHTTPClientSession" ]:
109- yield MockHTTPClientSession (self ._mock_network , self ._host )
110+ yield MockHTTPClientSession (self ._mock_network , self ._client_host , certificate )
110111
111112
112113class MockHTTPClientSession (HTTPClientSession ):
113- def __init__ (self , mock_network : MockHTTPNetwork , host : str = "mock_hostname" ):
114+ def __init__ (
115+ self ,
116+ mock_network : MockHTTPNetwork ,
117+ client_host : str = "mock_hostname" ,
118+ certificate : SSLCertificate | None = None ,
119+ ):
114120 self ._mock_network = mock_network
115- self ._host = host
121+ self ._client_host = client_host
122+ self ._certificate = certificate
116123
117124 async def get (self , url : str , params : Mapping [str , str ] = {}) -> HTTPResponse :
118125 response = await self ._request ("get" , url , params = params )
@@ -126,11 +133,14 @@ async def _request(self, method: str, url: str, *args: Any, **kwargs: Any) -> ht
126133 url_parts = urlparse (url )
127134 assert url_parts .hostname is not None , "Hostname is missing from the url"
128135 assert url_parts .port is not None , "Port is missing from the url"
129- _certificate , manager = self ._mock_network .get_server (url_parts .hostname , url_parts .port )
130- # TODO: check the cerificate's validity here
136+ certificate , manager = self ._mock_network .get_server (url_parts .hostname , url_parts .port )
137+
138+ if self ._certificate is not None and certificate != self ._certificate :
139+ raise HTTPClientError ("Certificate mismatch" )
140+
131141 # Unfortunately there are no unified types for hypercorn and httpx,
132142 # so we have to cast manually.
133143 app = cast ("httpx._transports.asgi._ASGIApp" , manager .app ) # noqa: SLF001
134- transport = httpx .ASGITransport (app = app , client = (str ( self ._host ) , 9999 ))
144+ transport = httpx .ASGITransport (app = app , client = (self ._client_host , 9999 ))
135145 async with httpx .AsyncClient (transport = transport ) as client :
136146 return await client .request (method , url , * args , ** kwargs )
0 commit comments