Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/examples_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

from nucypher_async._drivers.http_client import HTTPClient
from nucypher_async._drivers.ssl import fetch_certificate
from nucypher_async._drivers.time import SystemClock
from nucypher_async._drivers.time import BaseClock, SystemClock
from nucypher_async._mocks import MockCBDClient, MockClock, MockIdentityClient, MockPREClient
from nucypher_async._p2p import Contact, NodeClient, Operator
from nucypher_async.base.time import BaseClock
from nucypher_async.blockchain.cbd import CBDClient
from nucypher_async.blockchain.identity import AmountT, IdentityAccount, IdentityClient
from nucypher_async.blockchain.pre import PREAccount, PREClient
from nucypher_async.blockchain.pre import PREAccount, PREAmount, PREClient
from nucypher_async.characters import MasterKey
from nucypher_async.characters.cbd import Decryptor
from nucypher_async.characters.pre import Reencryptor
Expand Down Expand Up @@ -54,6 +53,9 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
cbd_client = MockCBDClient()
clock = MockClock()

pre_account = PREAccount.random()
pre_client.mock_set_balance(pre_account.address, PREAmount.ether(1))

logger.info("Mocked mode - starting nodes")

for i in range(3):
Expand All @@ -74,7 +76,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
AmountT.ether(40000),
)

logger = logger.get_child(f"Node{i + 1}")
node_logger = logger.get_child(f"Node{i + 1}")

http_server_config = HTTPServerConfig.from_typed_values(
bind_to_address=LOCALHOST,
Expand All @@ -88,7 +90,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
pre_client=pre_client,
cbd_client=cbd_client,
node_client=NodeClient(HTTPClient()),
logger=logger,
logger=node_logger,
seed_contacts=seed_contacts,
clock=clock,
)
Expand Down Expand Up @@ -118,7 +120,7 @@ async def local(cls, nursery: trio.Nursery) -> "Context":
clock=clock,
seed_contact=Contact(LOCALHOST, PORT_BASE),
server_handles=handles,
pre_account=PREAccount.random(),
pre_account=pre_account,
)

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions examples/grant_and_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from nucypher_async.blockchain.pre import PREAccountSigner
from nucypher_async.characters import MasterKey
from nucypher_async.characters.pre import Delegator, Publisher, Recipient
from nucypher_async.characters.pre import Delegator, EncryptedMessage, Publisher, Recipient
from nucypher_async.client.network import NetworkClient
from nucypher_async.client.pre import LocalPREClient, pre_encrypt
from nucypher_async.client.pre import LocalPREClient


async def main(*, mocked: bool = True) -> None:
Expand Down Expand Up @@ -62,13 +62,13 @@ async def main(*, mocked: bool = True) -> None:
)

message = b"a secret message"
message_kit = pre_encrypt(enacted_policy, message)
encrypted_message = EncryptedMessage(enacted_policy.policy, message)

context.logger.info("Bob retrieves and decrypts")
decrypted = await bob_client.decrypt(
bob,
enacted_policy,
message_kit,
encrypted_message,
alice.card(),
publisher.card(),
)
Expand Down
3 changes: 2 additions & 1 deletion nucypher_async/_drivers/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def body_json(self) -> JSON:

@property
def remote_host(self) -> str | None:
# TODO: we can get the port here too
# We could get a port here too, but it won't be useful,
# because it won't correspond to the port the server (if any) is bound to.
return self._request.client.host if self._request.client else None

@property
Expand Down
25 changes: 11 additions & 14 deletions nucypher_async/_drivers/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,23 @@

import httpx

from .._utils import temp_file
from ..base.types import JSON
from .ssl import SSLCertificate, fetch_certificate
from .ssl import SSLCertificate, fetch_certificate, make_client_ssl_context


class HTTPClientError(Exception):
pass


# TODO: move logic to _drivers/ssl
def make_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
# Cannot create a context with a certificate directly, see https://bugs.python.org/issue16487.
# So instead we do it via a temporary file, and negate the performance penalty by caching.
with temp_file(certificate.to_pem_bytes()) as certificate_filename:
return ssl.create_default_context(cafile=str(certificate_filename))


class HTTPClient:
# The default certificate cache size is chosen to cover the possible size of the network,
# which at its best only had a few hundred nodes.
def __init__(self, certificate_cache_size: int = 1024):
self._cached_ssl_context = lru_cache(maxsize=certificate_cache_size)(make_ssl_context)
@lru_cache(maxsize=certificate_cache_size)
def cached_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
return make_client_ssl_context(certificate)

self._cached_ssl_context = cached_ssl_context

async def fetch_certificate(self, host: str, port: int) -> SSLCertificate:
return await fetch_certificate(host, port)
Expand Down Expand Up @@ -60,9 +55,11 @@ def json(self) -> JSON:
return cast("JSON", self._response.json())

@property
def status_code(self) -> http.HTTPStatus:
# TODO: should we add a handling of an unrecognized status code?
return http.HTTPStatus(self._response.status_code)
def status_code(self) -> http.HTTPStatus | None:
try:
return http.HTTPStatus(self._response.status_code)
except ValueError:
return None


class HTTPClientSession:
Expand Down
27 changes: 4 additions & 23 deletions nucypher_async/_drivers/http_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Encapsulates a specific HTTP server running our ASGI app (currently ``hypercorn``)."""

import os
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from ssl import SSLContext
Expand All @@ -12,9 +11,8 @@
from hypercorn.trio import serve
from hypercorn.typing import ASGIFramework

from .._utils import temp_file
from ..logging import Logger
from .ssl import SSLCertificate, SSLPrivateKey
from .ssl import SSLCertificate, SSLPrivateKey, fill_ssl_context

if TYPE_CHECKING: # pragma: no cover
import logging
Expand Down Expand Up @@ -42,35 +40,18 @@ def __init__(

# Have to keep the unencrypted private key in memory,
# but at least we're not leaking it in the filesystem.
# TODO: Can we do better? Zeroize it on cleanup?
self.__ssl_private_key = ssl_private_key

def create_ssl_context(self) -> SSLContext | None:
# sanity check
if self.certfile or self.keyfile or self.ca_certs:
raise RuntimeError(
"Certificate/keyfile must be passed to the constructor in the serialized form"
)

context = super().create_ssl_context()

# Since ssl_enabled() returns True, the context will be created,
# but with no certificates loaded.
assert context is not None, "SSL context was not created"

# Encrypt the temporary file we create with an emphemeral password.
keyfile_password = os.urandom(32)

# TODO: move logic to _drivers/ssl
if self.__ssl_ca_chain:
chain_data = b"\n".join(cert.to_pem_bytes() for cert in self.__ssl_ca_chain).decode()
context.load_verify_locations(cadata=chain_data)

with (
temp_file(self.__ssl_certificate.to_pem_bytes()) as certfile,
temp_file(self.__ssl_private_key.to_pem_bytes(keyfile_password)) as keyfile,
):
context.load_cert_chain(certfile=certfile, keyfile=keyfile, password=keyfile_password)
fill_ssl_context(
context, self.__ssl_private_key, self.__ssl_certificate, self.__ssl_ca_chain
)

return context

Expand Down
38 changes: 36 additions & 2 deletions nucypher_async/_drivers/ssl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import secrets
import ssl
from collections.abc import Iterable
from ipaddress import ip_address
from typing import cast, get_args

Expand All @@ -14,6 +16,8 @@
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.oid import NameOID

from .._utils import temp_file


class SSLPrivateKey:
@classmethod
Expand Down Expand Up @@ -85,8 +89,6 @@ def self_signed(
host: str,
days_valid: int = 365,
) -> "SSLCertificate":
# TODO: assert that the start date is in UTC?

public_key = private_key._public_key() # noqa: SLF001

end_date = start_date.shift(days=days_valid)
Expand Down Expand Up @@ -140,6 +142,10 @@ def list_from_pem_bytes(cls, data: bytes) -> list["SSLCertificate"]:
certs_bytes = data.split(start_line)
return [cls.from_pem_bytes(start_line + cert_bytes) for cert_bytes in certs_bytes[1:]]

@staticmethod
def list_to_pem_str(certificates: Iterable["SSLCertificate"]) -> str:
return b"\n".join(certificate.to_pem_bytes() for certificate in certificates).decode()

@classmethod
def from_der_bytes(cls, data: bytes) -> "SSLCertificate":
return cls(x509.load_der_x509_certificate(data))
Expand Down Expand Up @@ -175,6 +181,34 @@ def not_valid_after(self) -> arrow.Arrow:
return arrow.get(self._certificate.not_valid_after_utc)


def make_client_ssl_context(certificate: SSLCertificate) -> ssl.SSLContext:
# Cannot create a context with a certificate directly, see https://bugs.python.org/issue16487.
# So instead we do it via a temporary file, and negate the performance penalty by caching.
with temp_file(certificate.to_pem_bytes()) as certificate_filename:
return ssl.create_default_context(cafile=str(certificate_filename))


def fill_ssl_context(
context: ssl.SSLContext,
private_key: SSLPrivateKey,
certificate: SSLCertificate,
ca_chain: Iterable[SSLCertificate] | None,
) -> None:
if ca_chain:
context.load_verify_locations(cadata=SSLCertificate.list_to_pem_str(ca_chain))

# Encrypt the temporary file we create with an emphemeral password.
# Would be nice to zeroize it in the end, but we cannot zeroize `bytes`,
# and `cryptography` does not accept `bytearray` as a password.
keyfile_password = secrets.token_bytes(32)

with (
temp_file(certificate.to_pem_bytes()) as certfile,
temp_file(private_key.to_pem_bytes(keyfile_password)) as keyfile,
):
context.load_cert_chain(certfile=certfile, keyfile=keyfile, password=keyfile_password)


async def fetch_certificate(host: str, port: int) -> SSLCertificate:
# Do not verify the certificate, it is self-signed
context = ssl.create_default_context()
Expand Down
14 changes: 13 additions & 1 deletion nucypher_async/_drivers/time.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from abc import ABC, abstractmethod

import arrow

from ..base.time import BaseClock

class BaseClock(ABC):
"""
An abstract class for getting the current time.
A behavior different from just returning the system time may be needed for tests.
"""

# not a staticmethod since some implementations may need to maintain an internal state
@abstractmethod
def utcnow(self) -> arrow.Arrow:
pass


class SystemClock(BaseClock):
Expand Down
32 changes: 21 additions & 11 deletions nucypher_async/_mocks/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
LifespanStartupEvent,
)

from .._drivers.http_client import HTTPClient, HTTPClientSession, HTTPResponse
from .._drivers.http_client import HTTPClient, HTTPClientError, HTTPClientSession, HTTPResponse
from .._drivers.http_server import HTTPServable, HTTPServableApp
from .._drivers.ssl import SSLCertificate
from ..proxy import ProxyServer
Expand Down Expand Up @@ -93,26 +93,33 @@ def get_server(self, host: str, port: int) -> tuple[SSLCertificate, LifespanMana


class MockHTTPClient(HTTPClient):
def __init__(self, mock_network: MockHTTPNetwork, host: str | None = None):
# TODO: do we actually need to be able to specify the client's host?
def __init__(self, mock_network: MockHTTPNetwork, client_host: str | None = None):
self._mock_network = mock_network
self._host = host or "passive client"
# Since the nodes use HTTP for P2P messaging,
# we need to be able to report the client's hostname (used for DDoS protection).
self._client_host = client_host or "passive client"

async def fetch_certificate(self, host: str, port: int) -> SSLCertificate:
certificate, _manager = self._mock_network.get_server(host, port)
return certificate

@asynccontextmanager
async def session(
self, _certificate: SSLCertificate | None = None
self, certificate: SSLCertificate | None = None
) -> AsyncIterator["MockHTTPClientSession"]:
yield MockHTTPClientSession(self._mock_network, self._host)
yield MockHTTPClientSession(self._mock_network, self._client_host, certificate)


class MockHTTPClientSession(HTTPClientSession):
def __init__(self, mock_network: MockHTTPNetwork, host: str = "mock_hostname"):
def __init__(
self,
mock_network: MockHTTPNetwork,
client_host: str = "mock_hostname",
certificate: SSLCertificate | None = None,
):
self._mock_network = mock_network
self._host = host
self._client_host = client_host
self._certificate = certificate

async def get(self, url: str, params: Mapping[str, str] = {}) -> HTTPResponse:
response = await self._request("get", url, params=params)
Expand All @@ -126,11 +133,14 @@ async def _request(self, method: str, url: str, *args: Any, **kwargs: Any) -> ht
url_parts = urlparse(url)
assert url_parts.hostname is not None, "Hostname is missing from the url"
assert url_parts.port is not None, "Port is missing from the url"
_certificate, manager = self._mock_network.get_server(url_parts.hostname, url_parts.port)
# TODO: check the cerificate's validity here
certificate, manager = self._mock_network.get_server(url_parts.hostname, url_parts.port)

if self._certificate is not None and certificate != self._certificate:
raise HTTPClientError("Certificate mismatch")

# Unfortunately there are no unified types for hypercorn and httpx,
# so we have to cast manually.
app = cast("httpx._transports.asgi._ASGIApp", manager.app) # noqa: SLF001
transport = httpx.ASGITransport(app=app, client=(str(self._host), 9999))
transport = httpx.ASGITransport(app=app, client=(self._client_host, 9999))
async with httpx.AsyncClient(transport=transport) as client:
return await client.request(method, url, *args, **kwargs)
16 changes: 13 additions & 3 deletions nucypher_async/_mocks/eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,26 @@ async def call(self, call: BoundMethodCall) -> Any:
async def transact(
self, signer: Signer, call: BoundMethodCall, amount: Amount | None = None
) -> None:
# TODO: change the caller's balance appropriately
# TODO: check that the call is payable if amount is not 0

# Lower the type from specific currency
amount = Amount.wei(0 if amount is None else amount.as_wei())

signer_balance = self._balances[signer.address]
if signer_balance < amount:
raise ValueError(
f"Not enough funds for {signer.address.checksum}: "
f"need {amount}, got {signer_balance}"
)

if amount.as_wei() != 0 and not call.payable:
raise ValueError(
"The call is not payable, but the transaction has funds attached to it"
)

# Lower the signer address type
address = Address(bytes(signer.address))

self._contracts[call.contract_address].transact(address, amount, call.data_bytes)
self._balances[signer.address] -= amount

def set_balance(self, address: Address, amount: Amount) -> None:
# Lower the type from specific currency
Expand Down
2 changes: 1 addition & 1 deletion nucypher_async/_mocks/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def isOperatorConfirmed(self, operator_address: Address) -> bool: # noqa: N802
def getActiveStakingProviders( # noqa: N802
self, _start_index: int, _max_staking_providers: int
) -> tuple[int, list[bytes]]:
# TODO: support pagination
# TODO (#47): implement pagination
total = sum(amount.as_wei() for amount in self._stakes.values())
return total, [
bytes(address) + amount.as_wei().to_bytes(12, byteorder="big")
Expand Down
Loading