Skip to content

Commit d11fd52

Browse files
authored
Merge pull request #46 from fjarri/todo-party
Todo party
2 parents fb7f597 + e183a0d commit d11fd52

44 files changed

Lines changed: 981 additions & 842 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

examples/examples_common.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99

1010
from nucypher_async._drivers.http_client import HTTPClient
1111
from nucypher_async._drivers.ssl import fetch_certificate
12-
from nucypher_async._drivers.time import SystemClock
12+
from nucypher_async._drivers.time import BaseClock, SystemClock
1313
from nucypher_async._mocks import MockCBDClient, MockClock, MockIdentityClient, MockPREClient
1414
from nucypher_async._p2p import Contact, NodeClient, Operator
15-
from nucypher_async.base.time import BaseClock
1615
from nucypher_async.blockchain.cbd import CBDClient
1716
from nucypher_async.blockchain.identity import AmountT, IdentityAccount, IdentityClient
18-
from nucypher_async.blockchain.pre import PREAccount, PREClient
17+
from nucypher_async.blockchain.pre import PREAccount, PREAmount, PREClient
1918
from nucypher_async.characters import MasterKey
2019
from nucypher_async.characters.cbd import Decryptor
2120
from nucypher_async.characters.pre import Reencryptor
@@ -54,6 +53,9 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
5453
cbd_client = MockCBDClient()
5554
clock = MockClock()
5655

56+
pre_account = PREAccount.random()
57+
pre_client.mock_set_balance(pre_account.address, PREAmount.ether(1))
58+
5759
logger.info("Mocked mode - starting nodes")
5860

5961
for i in range(3):
@@ -74,7 +76,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
7476
AmountT.ether(40000),
7577
)
7678

77-
logger = logger.get_child(f"Node{i + 1}")
79+
node_logger = logger.get_child(f"Node{i + 1}")
7880

7981
http_server_config = HTTPServerConfig.from_typed_values(
8082
bind_to_address=LOCALHOST,
@@ -88,7 +90,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
8890
pre_client=pre_client,
8991
cbd_client=cbd_client,
9092
node_client=NodeClient(HTTPClient()),
91-
logger=logger,
93+
logger=node_logger,
9294
seed_contacts=seed_contacts,
9395
clock=clock,
9496
)
@@ -118,7 +120,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
118120
clock=clock,
119121
seed_contact=Contact(LOCALHOST, PORT_BASE),
120122
server_handles=handles,
121-
pre_account=PREAccount.random(),
123+
pre_account=pre_account,
122124
)
123125

124126
@classmethod

examples/grant_and_retrieve.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from nucypher_async.blockchain.pre import PREAccountSigner
1010
from nucypher_async.characters import MasterKey
11-
from nucypher_async.characters.pre import Delegator, Publisher, Recipient
11+
from nucypher_async.characters.pre import Delegator, EncryptedMessage, Publisher, Recipient
1212
from nucypher_async.client.network import NetworkClient
13-
from nucypher_async.client.pre import LocalPREClient, pre_encrypt
13+
from nucypher_async.client.pre import LocalPREClient
1414

1515

1616
async def main(*, mocked: bool = True) -> None:
@@ -62,13 +62,13 @@ async def main(*, mocked: bool = True) -> None:
6262
)
6363

6464
message = b"a secret message"
65-
message_kit = pre_encrypt(enacted_policy, message)
65+
encrypted_message = EncryptedMessage(enacted_policy.policy, message)
6666

6767
context.logger.info("Bob retrieves and decrypts")
6868
decrypted = await bob_client.decrypt(
6969
bob,
7070
enacted_policy,
71-
message_kit,
71+
encrypted_message,
7272
alice.card(),
7373
publisher.card(),
7474
)

nucypher_async/_drivers/asgi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ async def body_json(self) -> JSON:
8383

8484
@property
8585
def remote_host(self) -> str | None:
86-
# TODO: we can get the port here too
86+
# We could get a port here too, but it won't be useful,
87+
# because it won't correspond to the port the server (if any) is bound to.
8788
return self._request.client.host if self._request.client else None
8889

8990
@property

nucypher_async/_drivers/http_client.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,23 @@
77

88
import httpx
99

10-
from .._utils import temp_file
1110
from ..base.types import JSON
12-
from .ssl import SSLCertificate, fetch_certificate
11+
from .ssl import SSLCertificate, fetch_certificate, make_client_ssl_context
1312

1413

1514
class HTTPClientError(Exception):
1615
pass
1716

1817

19-
# TODO: move logic to _drivers/ssl
20-
def make_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
21-
# Cannot create a context with a certificate directly, see https://bugs.python.org/issue16487.
22-
# So instead we do it via a temporary file, and negate the performance penalty by caching.
23-
with temp_file(certificate.to_pem_bytes()) as certificate_filename:
24-
return ssl.create_default_context(cafile=str(certificate_filename))
25-
26-
2718
class HTTPClient:
2819
# The default certificate cache size is chosen to cover the possible size of the network,
2920
# which at its best only had a few hundred nodes.
3021
def __init__(self, certificate_cache_size: int = 1024):
31-
self._cached_ssl_context = lru_cache(maxsize=certificate_cache_size)(make_ssl_context)
22+
@lru_cache(maxsize=certificate_cache_size)
23+
def cached_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
24+
return make_client_ssl_context(certificate)
25+
26+
self._cached_ssl_context = cached_ssl_context
3227

3328
async def fetch_certificate(self, host: str, port: int) -> SSLCertificate:
3429
return await fetch_certificate(host, port)
@@ -60,9 +55,11 @@ def json(self) -> JSON:
6055
return cast("JSON", self._response.json())
6156

6257
@property
63-
def status_code(self) -> http.HTTPStatus:
64-
# TODO: should we add a handling of an unrecognized status code?
65-
return http.HTTPStatus(self._response.status_code)
58+
def status_code(self) -> http.HTTPStatus | None:
59+
try:
60+
return http.HTTPStatus(self._response.status_code)
61+
except ValueError:
62+
return None
6663

6764

6865
class HTTPClientSession:

nucypher_async/_drivers/http_server.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Encapsulates a specific HTTP server running our ASGI app (currently ``hypercorn``)."""
22

3-
import os
43
from abc import ABC, abstractmethod
54
from ipaddress import IPv4Address
65
from ssl import SSLContext
@@ -12,9 +11,8 @@
1211
from hypercorn.trio import serve
1312
from hypercorn.typing import ASGIFramework
1413

15-
from .._utils import temp_file
1614
from ..logging import Logger
17-
from .ssl import SSLCertificate, SSLPrivateKey
15+
from .ssl import SSLCertificate, SSLPrivateKey, fill_ssl_context
1816

1917
if TYPE_CHECKING: # pragma: no cover
2018
import logging
@@ -42,35 +40,18 @@ def __init__(
4240

4341
# Have to keep the unencrypted private key in memory,
4442
# but at least we're not leaking it in the filesystem.
45-
# TODO: Can we do better? Zeroize it on cleanup?
4643
self.__ssl_private_key = ssl_private_key
4744

4845
def create_ssl_context(self) -> SSLContext | None:
49-
# sanity check
50-
if self.certfile or self.keyfile or self.ca_certs:
51-
raise RuntimeError(
52-
"Certificate/keyfile must be passed to the constructor in the serialized form"
53-
)
54-
5546
context = super().create_ssl_context()
5647

5748
# Since ssl_enabled() returns True, the context will be created,
5849
# but with no certificates loaded.
5950
assert context is not None, "SSL context was not created"
6051

61-
# Encrypt the temporary file we create with an emphemeral password.
62-
keyfile_password = os.urandom(32)
63-
64-
# TODO: move logic to _drivers/ssl
65-
if self.__ssl_ca_chain:
66-
chain_data = b"\n".join(cert.to_pem_bytes() for cert in self.__ssl_ca_chain).decode()
67-
context.load_verify_locations(cadata=chain_data)
68-
69-
with (
70-
temp_file(self.__ssl_certificate.to_pem_bytes()) as certfile,
71-
temp_file(self.__ssl_private_key.to_pem_bytes(keyfile_password)) as keyfile,
72-
):
73-
context.load_cert_chain(certfile=certfile, keyfile=keyfile, password=keyfile_password)
52+
fill_ssl_context(
53+
context, self.__ssl_private_key, self.__ssl_certificate, self.__ssl_ca_chain
54+
)
7455

7556
return context
7657

nucypher_async/_drivers/ssl.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import secrets
12
import ssl
3+
from collections.abc import Iterable
24
from ipaddress import ip_address
35
from typing import cast, get_args
46

@@ -14,6 +16,8 @@
1416
from cryptography.hazmat.primitives.serialization import load_pem_private_key
1517
from cryptography.x509.oid import NameOID
1618

19+
from .._utils import temp_file
20+
1721

1822
class SSLPrivateKey:
1923
@classmethod
@@ -85,8 +89,6 @@ def self_signed(
8589
host: str,
8690
days_valid: int = 365,
8791
) -> "SSLCertificate":
88-
# TODO: assert that the start date is in UTC?
89-
9092
public_key = private_key._public_key() # noqa: SLF001
9193

9294
end_date = start_date.shift(days=days_valid)
@@ -140,6 +142,10 @@ def list_from_pem_bytes(cls, data: bytes) -> list["SSLCertificate"]:
140142
certs_bytes = data.split(start_line)
141143
return [cls.from_pem_bytes(start_line + cert_bytes) for cert_bytes in certs_bytes[1:]]
142144

145+
@staticmethod
146+
def list_to_pem_str(certificates: Iterable["SSLCertificate"]) -> str:
147+
return b"\n".join(certificate.to_pem_bytes() for certificate in certificates).decode()
148+
143149
@classmethod
144150
def from_der_bytes(cls, data: bytes) -> "SSLCertificate":
145151
return cls(x509.load_der_x509_certificate(data))
@@ -175,6 +181,34 @@ def not_valid_after(self) -> arrow.Arrow:
175181
return arrow.get(self._certificate.not_valid_after_utc)
176182

177183

184+
def make_client_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
185+
# Cannot create a context with a certificate directly, see https://bugs.python.org/issue16487.
186+
# So instead we do it via a temporary file, and negate the performance penalty by caching.
187+
with temp_file(certificate.to_pem_bytes()) as certificate_filename:
188+
return ssl.create_default_context(cafile=str(certificate_filename))
189+
190+
191+
def fill_ssl_context(
192+
context: ssl.SSLContext,
193+
private_key: SSLPrivateKey,
194+
certificate: SSLCertificate,
195+
ca_chain: Iterable[SSLCertificate] | None,
196+
) -> None:
197+
if ca_chain:
198+
context.load_verify_locations(cadata=SSLCertificate.list_to_pem_str(ca_chain))
199+
200+
# Encrypt the temporary file we create with an emphemeral password.
201+
# Would be nice to zeroize it in the end, but we cannot zeroize `bytes`,
202+
# and `cryptography` does not accept `bytearray` as a password.
203+
keyfile_password = secrets.token_bytes(32)
204+
205+
with (
206+
temp_file(certificate.to_pem_bytes()) as certfile,
207+
temp_file(private_key.to_pem_bytes(keyfile_password)) as keyfile,
208+
):
209+
context.load_cert_chain(certfile=certfile, keyfile=keyfile, password=keyfile_password)
210+
211+
178212
async def fetch_certificate(host: str, port: int) -> SSLCertificate:
179213
# Do not verify the certificate, it is self-signed
180214
context = ssl.create_default_context()

nucypher_async/_drivers/time.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
13
import arrow
24

3-
from ..base.time import BaseClock
5+
6+
class BaseClock(ABC):
7+
"""
8+
An abstract class for getting the current time.
9+
A behavior different from just returning the system time may be needed for tests.
10+
"""
11+
12+
# not a staticmethod since some implementations may need to maintain an internal state
13+
@abstractmethod
14+
def utcnow(self) -> arrow.Arrow:
15+
pass
416

517

618
class SystemClock(BaseClock):

nucypher_async/_mocks/asgi.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
LifespanStartupEvent,
1515
)
1616

17-
from .._drivers.http_client import HTTPClient, HTTPClientSession, HTTPResponse
17+
from .._drivers.http_client import HTTPClient, HTTPClientError, HTTPClientSession, HTTPResponse
1818
from .._drivers.http_server import HTTPServable, HTTPServableApp
1919
from .._drivers.ssl import SSLCertificate
2020
from ..proxy import ProxyServer
@@ -93,26 +93,33 @@ def get_server(self, host: str, port: int) -> tuple[SSLCertificate, LifespanMana
9393

9494

9595
class MockHTTPClient(HTTPClient):
96-
def __init__(self, mock_network: MockHTTPNetwork, host: str | None = None):
97-
# TODO: do we actually need to be able to specify the client's host?
96+
def __init__(self, mock_network: MockHTTPNetwork, client_host: str | None = None):
9897
self._mock_network = mock_network
99-
self._host = host or "passive client"
98+
# Since the nodes use HTTP for P2P messaging,
99+
# we need to be able to report the client's hostname (used for DDoS protection).
100+
self._client_host = client_host or "passive client"
100101

101102
async def fetch_certificate(self, host: str, port: int) -> SSLCertificate:
102103
certificate, _manager = self._mock_network.get_server(host, port)
103104
return certificate
104105

105106
@asynccontextmanager
106107
async def session(
107-
self, _certificate: SSLCertificate | None = None
108+
self, certificate: SSLCertificate | None = None
108109
) -> AsyncIterator["MockHTTPClientSession"]:
109-
yield MockHTTPClientSession(self._mock_network, self._host)
110+
yield MockHTTPClientSession(self._mock_network, self._client_host, certificate)
110111

111112

112113
class MockHTTPClientSession(HTTPClientSession):
113-
def __init__(self, mock_network: MockHTTPNetwork, host: str = "mock_hostname"):
114+
def __init__(
115+
self,
116+
mock_network: MockHTTPNetwork,
117+
client_host: str = "mock_hostname",
118+
certificate: SSLCertificate | None = None,
119+
):
114120
self._mock_network = mock_network
115-
self._host = host
121+
self._client_host = client_host
122+
self._certificate = certificate
116123

117124
async def get(self, url: str, params: Mapping[str, str] = {}) -> HTTPResponse:
118125
response = await self._request("get", url, params=params)
@@ -126,11 +133,14 @@ async def _request(self, method: str, url: str, *args: Any, **kwargs: Any) -> ht
126133
url_parts = urlparse(url)
127134
assert url_parts.hostname is not None, "Hostname is missing from the url"
128135
assert url_parts.port is not None, "Port is missing from the url"
129-
_certificate, manager = self._mock_network.get_server(url_parts.hostname, url_parts.port)
130-
# TODO: check the cerificate's validity here
136+
certificate, manager = self._mock_network.get_server(url_parts.hostname, url_parts.port)
137+
138+
if self._certificate is not None and certificate != self._certificate:
139+
raise HTTPClientError("Certificate mismatch")
140+
131141
# Unfortunately there are no unified types for hypercorn and httpx,
132142
# so we have to cast manually.
133143
app = cast("httpx._transports.asgi._ASGIApp", manager.app) # noqa: SLF001
134-
transport = httpx.ASGITransport(app=app, client=(str(self._host), 9999))
144+
transport = httpx.ASGITransport(app=app, client=(self._client_host, 9999))
135145
async with httpx.AsyncClient(transport=transport) as client:
136146
return await client.request(method, url, *args, **kwargs)

nucypher_async/_mocks/eth.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,26 @@ async def call(self, call: BoundMethodCall) -> Any:
5959
async def transact(
6060
self, signer: Signer, call: BoundMethodCall, amount: Amount | None = None
6161
) -> None:
62-
# TODO: change the caller's balance appropriately
63-
# TODO: check that the call is payable if amount is not 0
64-
6562
# Lower the type from specific currency
6663
amount = Amount.wei(0 if amount is None else amount.as_wei())
6764

65+
signer_balance = self._balances[signer.address]
66+
if signer_balance < amount:
67+
raise ValueError(
68+
f"Not enough funds for {signer.address.checksum}: "
69+
f"need {amount}, got {signer_balance}"
70+
)
71+
72+
if amount.as_wei() != 0 and not call.payable:
73+
raise ValueError(
74+
"The call is not payable, but the transaction has funds attached to it"
75+
)
76+
6877
# Lower the signer address type
6978
address = Address(bytes(signer.address))
7079

7180
self._contracts[call.contract_address].transact(address, amount, call.data_bytes)
81+
self._balances[signer.address] -= amount
7282

7383
def set_balance(self, address: Address, amount: Amount) -> None:
7484
# Lower the type from specific currency

nucypher_async/_mocks/identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def isOperatorConfirmed(self, operator_address: Address) -> bool: # noqa: N802
3939
def getActiveStakingProviders( # noqa: N802
4040
self, _start_index: int, _max_staking_providers: int
4141
) -> tuple[int, list[bytes]]:
42-
# TODO: support pagination
42+
# TODO (#47): implement pagination
4343
total = sum(amount.as_wei() for amount in self._stakes.values())
4444
return total, [
4545
bytes(address) + amount.as_wei().to_bytes(12, byteorder="big")

0 commit comments

Comments
 (0)