diff --git a/aioquic/__init__.py b/aioquic/__init__.py index f3c316ad6..978074420 100644 --- a/aioquic/__init__.py +++ b/aioquic/__init__.py @@ -1,3 +1 @@ -from .client import connect # noqa from .connection import QuicConnection # noqa -from .server import serve # noqa diff --git a/aioquic/asyncio/__init__.py b/aioquic/asyncio/__init__.py new file mode 100644 index 000000000..0871f12af --- /dev/null +++ b/aioquic/asyncio/__init__.py @@ -0,0 +1,2 @@ +from .client import connect # noqa +from .server import serve # noqa diff --git a/aioquic/client.py b/aioquic/asyncio/client.py similarity index 80% rename from aioquic/client.py rename to aioquic/asyncio/client.py index 47caa0491..9da683984 100644 --- a/aioquic/client.py +++ b/aioquic/asyncio/client.py @@ -3,10 +3,11 @@ import socket from typing import AsyncGenerator, List, Optional, TextIO, cast +from ..configuration import QuicConfiguration +from ..connection import QuicConnection +from ..tls import SessionTicket, SessionTicketHandler from .compat import asynccontextmanager -from .configuration import QuicConfiguration -from .connection import QuicConnection, QuicStreamHandler -from .tls import SessionTicket, SessionTicketHandler +from .protocol import QuicConnectionProtocol, QuicStreamHandler __all__ = ["connect"] @@ -23,7 +24,7 @@ async def connect( session_ticket: Optional[SessionTicket] = None, session_ticket_handler: Optional[SessionTicketHandler] = None, stream_handler: Optional[QuicStreamHandler] = None, -) -> AsyncGenerator[QuicConnection, None]: +) -> AsyncGenerator[QuicConnectionProtocol, None]: """ Connect to a QUIC server at the given `host` and `port`. @@ -69,20 +70,21 @@ async def connect( if idle_timeout is not None: configuration.idle_timeout = idle_timeout + connection = QuicConnection( + configuration=configuration, session_ticket_handler=session_ticket_handler + ) + # connect _, protocol = await loop.create_datagram_endpoint( - lambda: QuicConnection( - configuration=configuration, - session_ticket_handler=session_ticket_handler, - stream_handler=stream_handler, - ), + lambda: QuicConnectionProtocol(connection, stream_handler=stream_handler), local_addr=("::", 0), ) - protocol = cast(QuicConnection, protocol) - protocol.connect(addr, protocol_version=protocol_version) + protocol = cast(QuicConnectionProtocol, protocol) + protocol.connect(addr, protocol_version) await protocol.wait_connected() try: yield protocol finally: protocol.close() + protocol._send_pending() await protocol.wait_closed() diff --git a/aioquic/compat.py b/aioquic/asyncio/compat.py similarity index 100% rename from aioquic/compat.py rename to aioquic/asyncio/compat.py diff --git a/aioquic/asyncio/protocol.py b/aioquic/asyncio/protocol.py new file mode 100644 index 000000000..22f933edd --- /dev/null +++ b/aioquic/asyncio/protocol.py @@ -0,0 +1,189 @@ +import asyncio +from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast + +from .. import events +from ..connection import NetworkAddress, QuicConnection + +QuicConnectionIdHandler = Callable[[bytes], None] +QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None] + + +class QuicConnectionProtocol(asyncio.DatagramProtocol): + def __init__( + self, + connection: QuicConnection, + stream_handler: Optional[QuicStreamHandler] = None, + ): + self._connection = connection + self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None + self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None + + if stream_handler is not None: + self._stream_handler = stream_handler + else: + self._stream_handler = lambda r, w: None + + def close(self) -> None: + self._connection.close() + self._send_pending() + + def connect(self, addr: NetworkAddress, protocol_version: int) -> None: + self._connection.connect( + addr, now=self._loop.time(), protocol_version=protocol_version + ) + self._send_pending() + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + loop = asyncio.get_event_loop() + + self._closed = asyncio.Event() + self._connected_waiter = loop.create_future() + self._loop = loop + self._ping_waiter: Optional[asyncio.Future[None]] = None + self._send_task: Optional[asyncio.Handle] = None + self._stream_readers: Dict[int, asyncio.StreamReader] = {} + self._timer: Optional[asyncio.TimerHandle] = None + self._timer_at: Optional[float] = None + self._transport = cast(asyncio.DatagramTransport, transport) + + def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: + self._connection.receive_datagram( + cast(bytes, data), addr, now=self._loop.time() + ) + self._send_pending() + + async def create_stream( + self, is_unidirectional: bool = False + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """ + Create a QUIC stream and return a pair of (reader, writer) objects. + + The returned reader and writer objects are instances of :class:`asyncio.StreamReader` + and :class:`asyncio.StreamWriter` classes. + """ + stream = self._connection.create_stream(is_unidirectional=is_unidirectional) + return self._create_stream(stream.stream_id) + + def request_key_update(self) -> None: + """ + Request an update of the encryption keys. + """ + self._connection.request_key_update() + self._send_pending() + + async def ping(self) -> None: + """ + Pings the remote host and waits for the response. + """ + assert self._ping_waiter is None, "already await a ping" + self._ping_waiter = self._loop.create_future() + self._connection.send_ping(id(self._ping_waiter)) + self._send_soon() + await asyncio.shield(self._ping_waiter) + + async def wait_closed(self) -> None: + """ + Wait for the connection to be closed. + """ + await self._closed.wait() + + async def wait_connected(self) -> None: + """ + Wait for the TLS handshake to complete. + """ + await asyncio.shield(self._connected_waiter) + + def _create_stream( + self, stream_id: int + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + adapter = QuicStreamAdapter(self, stream_id) + reader = asyncio.StreamReader() + writer = asyncio.StreamWriter(adapter, None, reader, None) + self._stream_readers[stream_id] = reader + return reader, writer + + def _handle_timer(self) -> None: + now = max(self._timer_at, self._loop.time()) + + self._timer = None + self._timer_at = None + + self._connection.handle_timer(now=now) + + self._send_pending() + + def _send_pending(self) -> None: + self._send_task = None + + # process events + event = self._connection.next_event() + while event is not None: + if isinstance(event, events.ConnectionIdIssued): + self._connection_id_issued_handler(event.connection_id) + elif isinstance(event, events.ConnectionIdRetired): + self._connection_id_retired_handler(event.connection_id) + elif isinstance(event, events.ConnectionTerminated): + for reader in self._stream_readers.values(): + reader.feed_eof() + if not self._connected_waiter.done(): + self._connected_waiter.set_exception(ConnectionError) + self._closed.set() + elif isinstance(event, events.HandshakeCompleted): + self._connected_waiter.set_result(None) + elif isinstance(event, events.PongReceived): + waiter = self._ping_waiter + self._ping_waiter = None + waiter.set_result(None) + elif isinstance(event, events.StreamDataReceived): + reader = self._stream_readers.get(event.stream_id, None) + if reader is None: + reader, writer = self._create_stream(event.stream_id) + self._stream_handler(reader, writer) + reader.feed_data(event.data) + if event.end_stream: + reader.feed_eof() + + event = self._connection.next_event() + + # send datagrams + for data, addr in self._connection.datagrams_to_send(now=self._loop.time()): + self._transport.sendto(data, addr) + + # re-arm timer + timer_at = self._connection.get_timer() + if self._timer is not None and self._timer_at != timer_at: + self._timer.cancel() + self._timer = None + if self._timer is None and timer_at is not None: + self._timer = self._loop.call_at(timer_at, self._handle_timer) + self._timer_at = timer_at + + def _send_soon(self) -> None: + if self._send_task is None: + self._send_task = self._loop.call_soon(self._send_pending) + + +class QuicStreamAdapter(asyncio.Transport): + def __init__(self, protocol: QuicConnectionProtocol, stream_id: int): + self.protocol = protocol + self.stream_id = stream_id + + def can_write_eof(self) -> bool: + return True + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """ + Get information about the underlying QUIC stream. + """ + if name == "connection": + return self.protocol._connection + elif name == "stream_id": + return self.stream_id + + def write(self, data): + self.protocol._connection.send_stream_data(self.stream_id, data) + self.protocol._send_soon() + + def write_eof(self): + self.protocol._connection.send_stream_data(self.stream_id, b"", end_stream=True) + self.protocol._send_soon() diff --git a/aioquic/server.py b/aioquic/asyncio/server.py similarity index 80% rename from aioquic/server.py rename to aioquic/asyncio/server.py index 40da7f953..e56ded542 100644 --- a/aioquic/server.py +++ b/aioquic/asyncio/server.py @@ -1,26 +1,28 @@ import asyncio import ipaddress import os +from functools import partial from typing import Any, Callable, Dict, List, Optional, Text, TextIO, Union, cast from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding, rsa -from .buffer import Buffer -from .configuration import QuicConfiguration -from .connection import NetworkAddress, QuicConnection, QuicStreamHandler -from .packet import ( +from ..buffer import Buffer +from ..configuration import QuicConfiguration +from ..connection import NetworkAddress, QuicConnection +from ..packet import ( PACKET_TYPE_INITIAL, encode_quic_retry, encode_quic_version_negotiation, pull_quic_header, ) -from .tls import SessionTicketFetcher, SessionTicketHandler +from ..tls import SessionTicketFetcher, SessionTicketHandler +from .protocol import QuicConnectionProtocol, QuicStreamHandler __all__ = ["serve"] -QuicConnectionHandler = Callable[[QuicConnection], None] +QuicConnectionHandler = Callable[[QuicConnectionProtocol], None] def encode_address(addr: NetworkAddress) -> bytes: @@ -39,7 +41,8 @@ def __init__( stream_handler: Optional[QuicStreamHandler] = None, ) -> None: self._configuration = configuration - self._connections: Dict[bytes, QuicConnection] = {} + self._protocols: Dict[bytes, QuicConnectionProtocol] = {} + self._loop = asyncio.get_event_loop() self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler self._transport: Optional[asyncio.DatagramTransport] = None @@ -81,9 +84,9 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N ) return - connection = self._connections.get(header.destination_cid, None) + protocol = self._protocols.get(header.destination_cid, None) original_connection_id: Optional[bytes] = None - if connection is None and header.packet_type == PACKET_TYPE_INITIAL: + if protocol is None and header.packet_type == PACKET_TYPE_INITIAL: # stateless retry if self._retry_key is not None: if not header.token: @@ -131,25 +134,34 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N original_connection_id=original_connection_id, session_ticket_fetcher=self._session_ticket_fetcher, session_ticket_handler=self._session_ticket_handler, - stream_handler=self._stream_handler, ) - self._connections[header.destination_cid] = connection + protocol = QuicConnectionProtocol( + connection, stream_handler=self._stream_handler + ) + protocol._connection_id_issued_handler = partial( + self._connection_id_issued, protocol=protocol + ) + protocol._connection_id_retired_handler = partial( + self._connection_id_retired, protocol=protocol + ) - def connection_id_issued(cid: bytes) -> None: - self._connections[cid] = connection + self._protocols[header.destination_cid] = protocol + protocol.connection_made(self._transport) - def connection_id_retired(cid: bytes) -> None: - del self._connections[cid] + self._protocols[connection.host_cid] = protocol + self._connection_handler(protocol) - connection._connection_id_issued_handler = connection_id_issued - connection._connection_id_retired_handler = connection_id_retired - connection.connection_made(self._transport) + if protocol is not None: + protocol.datagram_received(data, addr) - self._connections[connection.host_cid] = connection - self._connection_handler(connection) + def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol): + self._protocols[cid] = protocol - if connection is not None: - connection.datagram_received(data, addr) + def _connection_id_retired( + self, cid: bytes, protocol: QuicConnectionProtocol + ) -> None: + assert self._protocols[cid] == protocol + del self._protocols[cid] async def serve( @@ -181,7 +193,7 @@ async def serve( * ``connection_handler`` is a callback which is invoked whenever a connection is created. It must be a a function accepting a single - argument: a :class:`~aioquic.QuicConnection`. + argument: a :class:`~aioquic.asyncio.protocol.QuicConnectionProtocol`. * ``secrets_log_file`` is a file-like object in which to log traffic secrets. This is useful to analyze traffic captures with Wireshark. * ``stateless_retry`` specifies whether a stateless retry should be diff --git a/aioquic/connection.py b/aioquic/connection.py index 1cf2fcb3c..9896224d7 100644 --- a/aioquic/connection.py +++ b/aioquic/connection.py @@ -1,23 +1,12 @@ -import asyncio import binascii import logging import os +from collections import deque from dataclasses import dataclass from enum import Enum -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - List, - Optional, - Text, - Tuple, - Union, - cast, -) +from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Tuple, cast -from . import tls +from . import events, tls from .buffer import Buffer, BufferReadError, size_uint_var from .configuration import QuicConfiguration from .crypto import CryptoError, CryptoPair @@ -246,13 +235,10 @@ class QuicReceiveContext: time: float -QuicConnectionIdHandler = Callable[[bytes], None] -QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None] - END_STATES = frozenset([QuicConnectionState.CLOSING, QuicConnectionState.DRAINING]) -class QuicConnection(asyncio.DatagramProtocol): +class QuicConnection: """ A QUIC connection. """ @@ -264,7 +250,6 @@ def __init__( original_connection_id: Optional[bytes] = None, session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None, session_ticket_handler: Optional[tls.SessionTicketHandler] = None, - stream_handler: Optional[QuicStreamHandler] = None, ) -> None: if configuration.is_client: assert ( @@ -278,8 +263,6 @@ def __init__( configuration.private_key is not None ), "SSL private key is required for a server" - loop = asyncio.get_event_loop() - # counters for debugging self._stateless_retry_count = 0 self._version_negotiation_count = 0 @@ -288,14 +271,14 @@ def __init__( self._configuration = configuration self._is_client = configuration.is_client + self._ack_delay = K_GRANULARITY self._close_at: Optional[float] = None - self._close_exception: Optional[Exception] = None - self._closed = asyncio.Event() + self._close_event: Optional[events.ConnectionTerminated] = None self._connect_called = False - self._connected_waiter = loop.create_future() self._cryptos: Dict[tls.Epoch, CryptoPair] = {} self._crypto_buffers: Dict[tls.Epoch, Buffer] = {} self._crypto_streams: Dict[tls.Epoch, QuicStream] = {} + self._events: Deque[events.Event] = deque() self._handshake_complete = False self._handshake_confirmed = False self._host_cids = [ @@ -316,22 +299,19 @@ def __init__( self._local_max_stream_data_uni = MAX_DATA_WINDOW self._local_max_streams_bidi = 128 self._local_max_streams_uni = 128 - self._loop = loop self._logger = QuicConnectionAdapter( logger, {"host_cid": dump_cid(self.host_cid)} ) self._loss = QuicPacketRecovery(send_probe=self._send_probe) - self._loss_detection_at: Optional[float] = None + self._loss_at: Optional[float] = None self._network_paths: List[QuicNetworkPath] = [] self._original_connection_id = original_connection_id self._packet_number = 0 - self._parameters_available = asyncio.Event() self._parameters_received = False self._peer_cid = os.urandom(8) self._peer_cid_seq: Optional[int] = None self._peer_cid_available: List[QuicConnectionId] = [] self._peer_token = b"" - self._ping_waiter: Optional[asyncio.Future[None]] = None self._remote_idle_timeout = 0.0 # seconds self._remote_max_data = 0 self._remote_max_data_used = 0 @@ -346,27 +326,17 @@ def __init__( self._spin_highest_pn = 0 self._state = QuicConnectionState.FIRSTFLIGHT self._streams: Dict[int, QuicStream] = {} - self._timer: Optional[asyncio.TimerHandle] = None - self._timer_at: Optional[float] = None - self._transport: Optional[asyncio.DatagramTransport] = None self._version: Optional[int] = None # things to send self._close_pending: Optional[Dict] = None - self._ping_pending = False + self._ping_pending: List[int] = [] self._probe_pending = False self._retire_connection_ids: List[int] = [] - self._send_task: Optional[asyncio.Handle] = None # callbacks - self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None - self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler - if stream_handler is not None: - self._stream_handler = stream_handler - else: - self._stream_handler = lambda r, w: None # frame handlers self.__frame_handlers = [ @@ -419,20 +389,27 @@ def close( Close the connection. """ if self._state not in END_STATES: + self._close_event = events.ConnectionTerminated( + error_code=error_code, + frame_type=frame_type, + reason_phrase=reason_phrase, + ) self._close_pending = { "error_code": error_code, "frame_type": frame_type, "reason_phrase": reason_phrase, } - self._send_pending() def connect( - self, addr: NetworkAddress, protocol_version: Optional[int] = None + self, addr: NetworkAddress, now: float, protocol_version: Optional[int] = None ) -> None: """ Initiate the TLS handshake. This method can only be called for clients and a single time. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. """ assert ( self._is_client and not self._connect_called @@ -444,22 +421,18 @@ def connect( self._version = protocol_version else: self._version = max(self._configuration.supported_versions) - self._connect() + self._connect(now=now) - async def create_stream( - self, is_unidirectional: bool = False - ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + def create_stream( + self, is_unidirectional: bool = False, stream_id: Optional[int] = None + ) -> QuicStream: """ - Create a QUIC stream and return a pair of (reader, writer) objects. - - The returned reader and writer objects are instances of :class:`asyncio.StreamReader` - and :class:`asyncio.StreamWriter` classes. + Create a QUIC stream and return it. """ - await self._parameters_available.wait() - - stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) - while stream_id in self._streams: - stream_id += 4 + if stream_id is None: + stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) + while stream_id in self._streams: + stream_id += 4 # determine limits if is_unidirectional: @@ -482,54 +455,140 @@ async def create_stream( max_stream_data_local=max_stream_data_local, max_stream_data_remote=max_stream_data_remote, ) + return stream - return stream.reader, stream.writer - - async def ping(self) -> None: + def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: """ - Pings the remote host and waits for the response. + Return (data, addr) tuples of datagrams which need to be sent. """ - assert self._ping_waiter is None, "already await a ping" - self._ping_pending = True - self._ping_waiter = self._loop.create_future() - self._send_soon() - await asyncio.shield(self._ping_waiter) + network_path = self._network_paths[0] - def request_key_update(self) -> None: - """ - Request an update of the encryption keys. - """ - assert self._handshake_complete, "cannot change key before handshake completes" - self._cryptos[tls.Epoch.ONE_RTT].update_key() + if self._state in END_STATES: + return [] - async def wait_closed(self) -> None: + # build datagrams + builder = QuicPacketBuilder( + host_cid=self.host_cid, + packet_number=self._packet_number, + pad_first_datagram=( + self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT + ), + peer_cid=self._peer_cid, + peer_token=self._peer_token, + spin_bit=self._spin_bit, + version=self._version, + ) + if self._close_pending: + for epoch, packet_type in ( + (tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT), + (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), + (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), + ): + crypto = self._cryptos[epoch] + if crypto.send.is_valid(): + builder.start_packet(packet_type, crypto) + write_close_frame(builder, **self._close_pending) + builder.end_packet() + self._close_pending = None + break + self._close(is_initiator=True, now=now) + else: + # congestion control + builder.max_flight_bytes = ( + self._loss.congestion_window - self._loss.bytes_in_flight + ) + if not network_path.is_validated: + # limit data on un-validated network paths + builder.max_total_bytes = ( + network_path.bytes_received * 3 - network_path.bytes_sent + ) + + try: + if not self._handshake_confirmed: + for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: + self._write_handshake(builder, epoch) + self._write_application(builder, network_path, now) + except QuicPacketBuilderStop: + pass + + datagrams, packets = builder.flush() + + if datagrams: + self._packet_number = builder.packet_number + + # register packets + sent_handshake = False + for packet in packets: + packet.sent_time = now + self._loss.on_packet_sent( + packet=packet, space=self._spaces[packet.epoch] + ) + if packet.epoch == tls.Epoch.HANDSHAKE: + sent_handshake = True + + # check if we can discard initial keys + if sent_handshake and self._is_client: + self._discard_epoch(tls.Epoch.INITIAL) + + # return datagrams to send and the destination network address + ret = [] + for datagram in datagrams: + network_path.bytes_sent += len(datagram) + ret.append((datagram, network_path.addr)) + return ret + + def get_timer(self) -> Optional[float]: """ - Wait for the connection to be closed. + Return the time at which the timer should fire or None if no timer is needed. """ - await self._closed.wait() + timer_at = self._close_at + if self._state not in END_STATES: + # ack timer + for space in self._loss.spaces: + if space.ack_at is not None and space.ack_at < timer_at: + timer_at = space.ack_at + + # loss detection timer + self._loss_at = self._loss.get_loss_detection_time() + if self._loss_at is not None and self._loss_at < timer_at: + timer_at = self._loss_at + return timer_at - async def wait_connected(self) -> None: + def handle_timer(self, now: float) -> None: """ - Wait for the TLS handshake to complete. + Handle the timer. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. """ - await asyncio.shield(self._connected_waiter) + # idle timeout + if now >= self._close_at: + if self._close_event is None: + self._close_event = events.ConnectionTerminated( + error_code=QuicErrorCode.INTERNAL_ERROR, + frame_type=None, + reason_phrase="Idle timeout", + ) + self._close_complete() + return - # asyncio.DatagramProtocol + # loss detection timeout + if self._loss_at is not None and now >= self._loss_at: + self._logger.debug("Loss detection triggered") + self._loss.on_loss_detection_timeout(now=now) - def connection_lost(self, exc: Exception) -> None: - self._logger.info("Connection closed") - for epoch in self._spaces.keys(): - self._discard_epoch(epoch) - for stream in self._streams.values(): - stream.connection_lost(exc) - if not self._connected_waiter.done(): - self._connected_waiter.set_exception(exc or ConnectionError) - self._closed.set() + def next_event(self) -> Optional[events.Event]: + """ + Retrieve the next event from the event buffer. - def connection_made(self, transport: asyncio.BaseTransport) -> None: - self._transport = cast(asyncio.DatagramTransport, transport) + Returns `None` if there are no buffered events. + """ + try: + return self._events.popleft() + except IndexError: + return None - def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: + def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None: """ Handle an incoming datagram. """ @@ -539,7 +598,6 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N data = cast(bytes, data) buf = Buffer(data=data) - now = self._loop.time() while not buf.eof(): start_off = buf.tell() header = pull_quic_header(buf, host_cid_length=len(self.host_cid)) @@ -564,18 +622,17 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N ) if not common: self._logger.error("Could not find a common protocol version") - self.connection_lost( - QuicConnectionError( - error_code=QuicErrorCode.INTERNAL_ERROR, - frame_type=None, - reason_phrase="Could not find a common protocol version", - ) + self._close_event = events.ConnectionTerminated( + error_code=QuicErrorCode.INTERNAL_ERROR, + frame_type=None, + reason_phrase="Could not find a common protocol version", ) + self._close_complete() return self._version = QuicProtocolVersion(max(common)) self._version_negotiation_count += 1 self._logger.info("Retrying with %s", self._version) - self._connect() + self._connect(now=now) return elif ( header.version is not None @@ -596,7 +653,7 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N self._peer_token = header.token self._stateless_retry_count += 1 self._logger.info("Performing stateless retry") - self._connect() + self._connect(now=now) return network_path = self._find_network_path(addr) @@ -667,13 +724,12 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N ) except QuicConnectionError as exc: self._logger.warning(exc) - self._close_exception = exc self.close( error_code=exc.error_code, frame_type=exc.frame_type, reason_phrase=exc.reason_phrase, ) - if self._state in END_STATES: + if self._state in END_STATES or self._close_pending: return # update idle timeout @@ -713,12 +769,37 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N space.largest_received_packet = packet_number space.ack_queue.add(packet_number) if is_ack_eliciting and space.ack_at is None: - space.ack_at = now + K_GRANULARITY + space.ack_at = now + self._ack_delay + + def request_key_update(self) -> None: + """ + Request an update of the encryption keys. + """ + assert self._handshake_complete, "cannot change key before handshake completes" + self._cryptos[tls.Epoch.ONE_RTT].update_key() - self._send_pending() + def send_ping(self, uid: int) -> None: + """ + Send a PING frame to the peer. + """ + self._ping_pending.append(uid) + + def send_stream_data( + self, stream_id: int, data: bytes, end_stream: bool = False + ) -> None: + """ + Send data on the specific stream. - def error_received(self, exc: Exception) -> None: - self._logger.warning(exc) + If `end_stream` is `True`, the FIN bit will be set. + """ + try: + stream = self._streams[stream_id] + except KeyError: + self.create_stream(stream_id=stream_id) + stream = self._streams[stream_id] + stream.write(data) + if end_stream: + stream.write_eof() # Private @@ -744,28 +825,37 @@ def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None: reason_phrase="Stream is receive-only", ) - def _close(self, is_initiator: bool) -> None: + def _close(self, is_initiator: bool, now: float) -> None: """ Start the close procedure. """ - self._close_at = self._loop.time() + 3 * self._loss.get_probe_timeout() + self._close_at = now + 3 * self._loss.get_probe_timeout() if is_initiator: self._set_state(QuicConnectionState.CLOSING) else: self._set_state(QuicConnectionState.DRAINING) - def _connect(self) -> None: + def _close_complete(self) -> None: + """ + Finish the close procedure. + """ + self._logger.info("Connection closed") + self._close_at = None + for epoch in self._spaces.keys(): + self._discard_epoch(epoch) + self._events.append(self._close_event) + + def _connect(self, now: float) -> None: """ Start the client handshake. """ assert self._is_client - self._close_at = self._loop.time() + self._configuration.idle_timeout + self._close_at = now + self._configuration.idle_timeout self._initialize(self._peer_cid) self.tls.handle_message(b"", self._crypto_buffers) self._push_crypto_data() - self._send_pending() def _consume_connection_id(self) -> None: """ @@ -840,7 +930,6 @@ def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream: max_stream_data_local=max_stream_data_local, max_stream_data_remote=max_stream_data_remote, ) - self._stream_handler(stream.reader, stream.writer) return stream def _initialize(self, peer_cid: bytes) -> None: @@ -955,15 +1044,10 @@ def _handle_connection_close_frame( self._logger.info( "Connection close code 0x%X, reason %s", error_code, reason_phrase ) - if error_code != QuicErrorCode.NO_ERROR: - self._close_exception = QuicConnectionError( - error_code=error_code, - frame_type=frame_type, - reason_phrase=reason_phrase, - ) - - self._close(is_initiator=False) - self._set_timer() + self._close_event = events.ConnectionTerminated( + error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase + ) + self._close(is_initiator=False, now=context.time) def _handle_crypto_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -1008,8 +1092,7 @@ def _handle_crypto_frame( ]: self._handshake_complete = True self._replenish_connection_ids() - # wakeup waiter - self._connected_waiter.set_result(None) + self._events.append(events.HandshakeCompleted()) def _handle_data_blocked_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -1161,8 +1244,8 @@ def _handle_reset_stream_frame( error_code, final_size, ) - stream = self._get_or_create_stream(frame_type, stream_id) - stream.connection_lost(None) + # stream = self._get_or_create_stream(frame_type, stream_id) + self._events.append(events.StreamReset(stream_id=stream_id)) def _handle_retire_connection_id_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer @@ -1182,7 +1265,9 @@ def _handle_retire_connection_id_frame( reason_phrase="Cannot retire current connection ID", ) del self._host_cids[index] - self._connection_id_retired_handler(connection_id.cid) + self._events.append( + events.ConnectionIdRetired(connection_id=connection_id.cid) + ) break # issue a new connection ID @@ -1290,43 +1375,28 @@ def _on_new_connection_id_delivery( if delivery != QuicDeliveryState.ACKED: connection_id.was_sent = False - def _on_ping_delivery(self, delivery: QuicDeliveryState) -> None: + def _on_ping_delivery( + self, delivery: QuicDeliveryState, uids: Sequence[int] + ) -> None: """ - Callback when a PING frame is is acknowledged or lost. + Callback when a PING frame is acknowledged or lost. """ if delivery == QuicDeliveryState.ACKED: self._logger.info("Received PING response") - waiter = self._ping_waiter - self._ping_waiter = None - waiter.set_result(None) + for uid in uids: + self._events.append(events.PongReceived(uid=uid)) else: - self._ping_pending = True + self._ping_pending.extend(uids) def _on_retire_connection_id_delivery( self, delivery: QuicDeliveryState, sequence_number: int ) -> None: """ - Callback when a RETIRE_CONNECTION_ID frame is is acknowledged or lost. + Callback when a RETIRE_CONNECTION_ID frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: self._retire_connection_ids.append(sequence_number) - def _on_timeout(self) -> None: - now = self._loop.time() + K_GRANULARITY - self._timer = None - self._timer_at = None - - # idle timeout - if now >= self._close_at: - self.connection_lost(self._close_exception) - return - - # loss detection timeout - if self._loss_at is not None and now >= self._loss_at: - self._logger.debug("Loss detection triggered") - self._loss.on_loss_detection_timeout(now=now) - self._send_pending() - def _payload_received( self, context: QuicReceiveContext, plain: bytes ) -> Tuple[bool, bool]: @@ -1399,93 +1469,9 @@ def _push_crypto_data(self) -> None: self._crypto_streams[epoch].write(buf.data) buf.seek(0) - def _send_pending(self) -> None: - network_path = self._network_paths[0] - - self._send_task = None - if self._state in END_STATES: - return - - # build datagrams - builder = QuicPacketBuilder( - host_cid=self.host_cid, - packet_number=self._packet_number, - pad_first_datagram=( - self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT - ), - peer_cid=self._peer_cid, - peer_token=self._peer_token, - spin_bit=self._spin_bit, - version=self._version, - ) - if self._close_pending: - for epoch, packet_type in ( - (tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT), - (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), - (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), - ): - crypto = self._cryptos[epoch] - if crypto.send.is_valid(): - builder.start_packet(packet_type, crypto) - write_close_frame(builder, **self._close_pending) - builder.end_packet() - self._close_pending = None - break - self._close(is_initiator=True) - else: - # congestion control - builder.max_flight_bytes = ( - self._loss.congestion_window - self._loss.bytes_in_flight - ) - if not network_path.is_validated: - # limit data on un-validated network paths - builder.max_total_bytes = ( - network_path.bytes_received * 3 - network_path.bytes_sent - ) - - try: - if not self._handshake_confirmed: - for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: - self._write_handshake(builder, epoch) - self._write_application(builder, network_path, self._loop.time()) - except QuicPacketBuilderStop: - pass - - datagrams, packets = builder.flush() - - if datagrams: - self._packet_number = builder.packet_number - - # send datagrams - for datagram in datagrams: - self._transport.sendto(datagram, network_path.addr) - network_path.bytes_sent += len(datagram) - - # register packets - now = self._loop.time() - sent_handshake = False - for packet in packets: - packet.sent_time = now - self._loss.on_packet_sent( - packet=packet, space=self._spaces[packet.epoch] - ) - if packet.epoch == tls.Epoch.HANDSHAKE: - sent_handshake = True - - # check if we can discard initial keys - if sent_handshake and self._is_client: - self._discard_epoch(tls.Epoch.INITIAL) - - # arm timer - self._set_timer() - def _send_probe(self) -> None: self._probe_pending = True - def _send_soon(self) -> None: - if self._send_task is None: - self._send_task = self._loop.call_soon(self._send_pending) - def _parse_transport_parameters( self, data: bytes, from_session_ticket: bool = False ) -> None: @@ -1525,10 +1511,6 @@ def _parse_transport_parameters( if value is not None: setattr(self, "_remote_" + param, value) - # wakeup waiters - if not self._parameters_available.is_set(): - self._parameters_available.set() - def _serialize_transport_parameters(self) -> bytes: quic_transport_parameters = QuicTransportParameters( idle_timeout=int(self._configuration.idle_timeout * 1000), @@ -1553,28 +1535,6 @@ def _set_state(self, state: QuicConnectionState) -> None: self._logger.debug("%s -> %s", self._state, state) self._state = state - def _set_timer(self) -> None: - # determine earliest timeout - timer_at = self._close_at - if self._state not in END_STATES: - # ack timer - for space in self._loss.spaces: - if space.ack_at is not None and space.ack_at < timer_at: - timer_at = space.ack_at - - # loss detection timer - self._loss_at = self._loss.get_loss_detection_time() - if self._loss_at is not None and self._loss_at < timer_at: - timer_at = self._loss_at - - # re-arm timer - if self._timer is not None and self._timer_at != timer_at: - self._timer.cancel() - self._timer = None - if self._timer is None and timer_at is not None: - self._timer = self._loop.call_at(timer_at, self._on_timeout) - self._timer_at = timer_at - def _stream_can_receive(self, stream_id: int) -> bool: return stream_is_client_initiated( stream_id @@ -1676,7 +1636,9 @@ def _write_application( connection_id.stateless_reset_token, ) connection_id.was_sent = True - self._connection_id_issued_handler(connection_id.cid) + self._events.append( + events.ConnectionIdIssued(connection_id=connection_id.cid) + ) # RETIRE_CONNECTION_ID while self._retire_connection_ids: @@ -1698,8 +1660,12 @@ def _write_application( # PING (user-request) if self._ping_pending: self._logger.info("Sending PING in packet %d", builder.packet_number) - builder.start_frame(QuicFrameType.PING, self._on_ping_delivery) - self._ping_pending = False + builder.start_frame( + QuicFrameType.PING, + self._on_ping_delivery, + (tuple(self._ping_pending),), + ) + self._ping_pending.clear() # PING (probe) if self._probe_pending: diff --git a/aioquic/events.py b/aioquic/events.py new file mode 100644 index 000000000..b9c12f644 --- /dev/null +++ b/aioquic/events.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass + + +class Event: + pass + + +@dataclass +class ConnectionIdIssued(Event): + connection_id: bytes + + +@dataclass +class ConnectionIdRetired(Event): + connection_id: bytes + + +@dataclass +class ConnectionTerminated(Event): + error_code: int + frame_type: int + reason_phrase: str + + +class HandshakeCompleted(Event): + pass + + +@dataclass +class PongReceived(Event): + uid: int + + +@dataclass +class StreamDataReceived(Event): + data: bytes + end_stream: bool + stream_id: int + + +@dataclass +class StreamReset(Event): + stream_id: int diff --git a/aioquic/stream.py b/aioquic/stream.py index 1fce7e6a9..a4c60097e 100644 --- a/aioquic/stream.py +++ b/aioquic/stream.py @@ -1,12 +1,12 @@ -import asyncio from typing import Any, Optional +from . import events from .packet import QuicStreamFrame from .packet_builder import QuicDeliveryState from .rangeset import RangeSet -class QuicStream(asyncio.Transport): +class QuicStream: def __init__( self, stream_id: Optional[int] = None, @@ -20,13 +20,6 @@ def __init__( self.max_stream_data_remote = max_stream_data_remote self.send_buffer_is_empty = True - if stream_id is not None: - self.reader = asyncio.StreamReader() - self.writer = asyncio.StreamWriter(self, None, self.reader, None) - else: - self.reader = None - self.writer = None - self._recv_buffer = bytearray() self._recv_buffer_fin: Optional[int] = None self._recv_buffer_start = 0 # the offset for the start of the buffer @@ -48,13 +41,6 @@ def __init__( def stream_id(self) -> Optional[int]: return self.__stream_id - def connection_lost(self, exc: Exception) -> None: - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) - # reader def add_frame(self, frame: QuicStreamFrame) -> None: @@ -92,12 +78,15 @@ def add_frame(self, frame: QuicStreamFrame) -> None: if frame.fin: self._recv_buffer_fin = frame_end - if self.reader: + if self._connection: data = self.pull_data() - if data: - self.reader.feed_data(data) - if self._recv_buffer_start == self._recv_buffer_fin: - self.reader.feed_eof() + self._connection._events.append( + events.StreamDataReceived( + data=data, + end_stream=(self._recv_buffer_start == self._recv_buffer_fin), + stream_id=self.__stream_id, + ) + ) def pull_data(self) -> bytes: """ @@ -203,18 +192,6 @@ def on_data_delivery( # asyncio.Transport - def can_write_eof(self) -> bool: - return True - - def get_extra_info(self, name: str, default: Any = None) -> Any: - """ - Get information about the underlying QUIC stream. - """ - if name == "connection": - return self._connection - elif name == "stream_id": - return self.stream_id - def get_write_buffer_size(self) -> int: """ Return the current size of the write buffer. @@ -235,8 +212,6 @@ def write(self, data: bytes) -> None: ) self._send_buffer += data self._send_buffer_stop += size - if self._connection is not None: - self._connection._send_soon() def write_eof(self) -> None: assert self._send_buffer_fin is None, "cannot call write_eof() after FIN" @@ -244,5 +219,3 @@ def write_eof(self) -> None: self.send_buffer_is_empty = False self._send_buffer_fin = self._send_buffer_stop self._send_pending_eof = True - if self._connection is not None: - self._connection._send_soon() diff --git a/docs/api.rst b/docs/api.rst index 1e04f2315..84cb0dbb8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,7 +1,7 @@ API Reference ============= -.. automodule:: aioquic +.. automodule:: aioquic.asyncio Client ------ @@ -16,7 +16,7 @@ Server Common ------ - .. autoclass:: QuicConnection + .. autoclass:: QuicConnectionProtocol .. automethod:: close() .. automethod:: create_stream() diff --git a/docs/http_client.py b/docs/http_client.py index 7bb57f526..9c1c3c420 100644 --- a/docs/http_client.py +++ b/docs/http_client.py @@ -1,10 +1,10 @@ #!/usr/bin/env python import asyncio -import aioquic +from aioquic.asyncio import connect async def http_client(host, port): - async with aioquic.connect(host, port) as connection: + async with connect(host, port) as connection: reader, writer = await connection.create_stream() writer.write(b"GET /\r\n") writer.write_eof() diff --git a/examples/client.py b/examples/client.py index 9d1ca52e8..b0529af0c 100644 --- a/examples/client.py +++ b/examples/client.py @@ -4,7 +4,7 @@ import pickle import time -import aioquic +from aioquic.asyncio import connect try: import uvloop @@ -27,7 +27,7 @@ def save_session_ticket(ticket): async def run(host, port, path, **kwargs): - async with aioquic.connect(host, port, **kwargs) as connection: + async with connect(host, port, **kwargs) as connection: # perform HTTP/0.9 request reader, writer = await connection.create_stream() writer.write(("GET %s\r\n" % path).encode("utf8")) diff --git a/examples/interop.py b/examples/interop.py index cd6529ab7..7d0aa882d 100644 --- a/examples/interop.py +++ b/examples/interop.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from enum import Flag -import aioquic +from aioquic.asyncio import connect class Result(Flag): @@ -84,17 +84,17 @@ async def http_request(connection, path): async def test_version_negotiation(config, **kwargs): - async with aioquic.connect( + async with connect( config.host, config.port, protocol_version=0x1A2A3A4A, **kwargs ) as connection: - if connection._version_negotiation_count == 1: + if connection._connection._version_negotiation_count == 1: config.result |= Result.V async def test_handshake_and_close(config, **kwargs): - async with aioquic.connect(config.host, config.port, **kwargs) as connection: + async with connect(config.host, config.port, **kwargs) as connection: config.result |= Result.H - if connection._stateless_retry_count == 1: + if connection._connection._stateless_retry_count == 1: config.result |= Result.S config.result |= Result.C @@ -103,7 +103,7 @@ async def test_data_transfer(config, **kwargs): if config.path is None: return - async with aioquic.connect(config.host, config.port, **kwargs) as connection: + async with connect(config.host, config.port, **kwargs) as connection: response1 = await http_request(connection, config.path) response2 = await http_request(connection, config.path) @@ -119,7 +119,7 @@ def session_ticket_handler(ticket): saved_ticket = ticket # connect a first time, receive a ticket - async with aioquic.connect( + async with connect( config.host, config.port, session_ticket_handler=session_ticket_handler, @@ -129,22 +129,22 @@ def session_ticket_handler(ticket): # connect a second time, with the ticket if saved_ticket is not None: - async with aioquic.connect( + async with connect( config.host, config.port, session_ticket=saved_ticket, **kwargs ) as connection: await connection.ping() # check session was resumed - if connection.tls.session_resumed: + if connection._connection.tls.session_resumed: config.result |= Result.R # check early data was accepted - if connection.tls.early_data_accepted: + if connection._connection.tls.early_data_accepted: config.result |= Result.Z async def test_key_update(config, **kwargs): - async with aioquic.connect(config.host, config.port, **kwargs) as connection: + async with connect(config.host, config.port, **kwargs) as connection: # cause some traffic await connection.ping() @@ -158,11 +158,11 @@ async def test_key_update(config, **kwargs): async def test_spin_bit(config, **kwargs): - async with aioquic.connect(config.host, config.port, **kwargs) as connection: + async with connect(config.host, config.port, **kwargs) as connection: spin_bits = set() for i in range(5): await connection.ping() - spin_bits.add(connection._spin_bit_peer) + spin_bits.add(connection._connection._spin_bit_peer) if len(spin_bits) == 2: config.result |= Result.P @@ -181,8 +181,8 @@ async def run(only=None, **kwargs): print("\n=== %s %s ===\n" % (config.name, test_name)) try: await asyncio.wait_for(test_func(config, **kwargs), timeout=5) - except Exception: - pass + except Exception as exc: + print(exc) print("") print_result(config) diff --git a/examples/server.py b/examples/server.py index a5268877b..f66f52063 100644 --- a/examples/server.py +++ b/examples/server.py @@ -7,7 +7,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -import aioquic +from aioquic.asyncio import serve try: import uvloop @@ -148,7 +148,7 @@ def pop(self, label): uvloop.install() loop = asyncio.get_event_loop() protocol = loop.run_until_complete( - aioquic.serve( + serve( host=args.host, port=args.port, alpn_protocols=["hq-20"], diff --git a/tests/test_high_level.py b/tests/test_asyncio.py similarity index 84% rename from tests/test_high_level.py rename to tests/test_asyncio.py index 7615d2146..656584c5a 100644 --- a/tests/test_high_level.py +++ b/tests/test_asyncio.py @@ -1,8 +1,8 @@ import asyncio from unittest import TestCase -from aioquic.client import connect -from aioquic.server import serve +from aioquic.asyncio.client import connect +from aioquic.asyncio.server import serve from .utils import SERVER_CERTIFICATE, SERVER_PRIVATE_KEY, run @@ -21,6 +21,7 @@ def pop(self, label): async def run_client(host, port=4433, request=b"ping", **kwargs): async with connect(host, port, **kwargs) as client: reader, writer = await client.create_stream() + assert writer.can_write_eof() is True writer.write(request) writer.write_eof() @@ -63,6 +64,22 @@ def test_connect_and_serve_large(self): ) self.assertEqual(response, data) + def test_connect_and_serve_writelines(self): + async def run_client_writelines(host, port=4433, **kwargs): + async with connect(host, port, **kwargs) as client: + reader, writer = await client.create_stream() + assert writer.can_write_eof() is True + + writer.writelines([b"01234567", b"89012345"]) + writer.write_eof() + + return await reader.read() + + _, response = run( + asyncio.gather(run_server(), run_client_writelines("127.0.0.1")) + ) + self.assertEqual(response, b"5432109876543210") + def test_connect_and_serve_with_session_ticket(self): client_ticket = None store = SessionTicketStore() diff --git a/tests/test_connection.py b/tests/test_connection.py index ce3b070c0..5948fc1c9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,10 +2,10 @@ import binascii import contextlib import io -import random -from unittest import TestCase +import time +from unittest import TestCase, skip -from aioquic import tls +from aioquic import events, tls from aioquic.buffer import Buffer from aioquic.configuration import QuicConfiguration from aioquic.connection import ( @@ -23,12 +23,10 @@ encode_quic_retry, encode_quic_version_negotiation, ) -from aioquic.packet_builder import QuicPacketBuilder +from aioquic.packet_builder import QuicDeliveryState, QuicPacketBuilder from .utils import SERVER_CERTIFICATE, SERVER_PRIVATE_KEY, run -RTT = 0.005 - CLIENT_ADDR = ("1.2.3.4", 1234) SERVER_ADDR = ("2.3.4.5", 4433) @@ -40,21 +38,6 @@ def encode_uint_var(v): return buf.data -class FakeTransport: - sent = 0 - target = None - - def __init__(self, local_addr, loss=0): - self.local_addr = local_addr - self.loop = asyncio.get_event_loop() - self.loss = loss - - def sendto(self, data, addr): - self.sent += 1 - if self.target is not None and random.random() >= self.loss: - self.loop.call_soon(self.target.datagram_received, data, self.local_addr) - - def client_receive_context(client, epoch=tls.Epoch.ONE_RTT): return QuicReceiveContext( epoch=epoch, @@ -64,31 +47,21 @@ def client_receive_context(client, epoch=tls.Epoch.ONE_RTT): ) -def create_standalone_client(): - client = QuicConnection(configuration=QuicConfiguration(is_client=True)) - client_transport = FakeTransport(CLIENT_ADDR, loss=0) - - # kick-off handshake - client.connection_made(client_transport) - client.connect(SERVER_ADDR) - - return client, client_transport - +def consume_events(connection): + while True: + event = connection.next_event() + if event is None: + break -def create_transport(client, server, loss=0): - client_transport = FakeTransport(CLIENT_ADDR, loss=loss) - client_transport.target = server - server_transport = FakeTransport(SERVER_ADDR, loss=loss) - server_transport.target = client +def create_standalone_client(self): + client = QuicConnection(configuration=QuicConfiguration(is_client=True)) # kick-off handshake - server.connection_made(server_transport) - client.connection_made(client_transport) - client.connect(SERVER_ADDR) - run(asyncio.sleep(0)) + client.connect(SERVER_ADDR, now=time.time()) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 1) - return client_transport, server_transport + return client @contextlib.contextmanager @@ -97,12 +70,12 @@ def client_and_server( client_patch=lambda x: None, server_options={}, server_patch=lambda x: None, - server_stream_handler=None, transport_options={}, ): client = QuicConnection( configuration=QuicConfiguration(is_client=True, **client_options) ) + client._ack_delay = 0 client_patch(client) server = QuicConnection( @@ -111,13 +84,15 @@ def client_and_server( certificate=SERVER_CERTIFICATE, private_key=SERVER_PRIVATE_KEY, **server_options - ), - stream_handler=server_stream_handler, + ) ) + server._ack_delay = 0 server_patch(server) # perform handshake - create_transport(client, server, **transport_options) + client.connect(SERVER_ADDR, now=time.time()) + for i in range(4): + tick(client, server) yield client, server @@ -130,50 +105,60 @@ def sequence_numbers(connection_ids): return list(map(lambda x: x.sequence_number, connection_ids)) -class QuicConnectionTest(TestCase): - def _test_connect_with_version(self, client_versions, server_versions): - async def serve_request(reader, writer): - # check request - request = await reader.read(1024) - self.assertEqual(request, b"ping") - - # send response - writer.write(b"pong") +def tick(client, server): + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, CLIENT_ADDR, now=time.time()) - # receives EOF - request = await reader.read() - self.assertEqual(request, b"") + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) - # sends EOF - writer.write(b"done") - writer.write_eof() +class QuicConnectionTest(TestCase): + def _test_connect_with_version(self, client_versions, server_versions): with client_and_server( client_options={"supported_versions": client_versions}, server_options={"supported_versions": server_versions}, - server_stream_handler=lambda reader, writer: asyncio.ensure_future( - serve_request(reader, writer) - ), ) as (client, server): - run(asyncio.gather(client.wait_connected(), server.wait_connected())) + # check handshake completed + self.assertEqual(type(client.next_event()), events.HandshakeCompleted) + for i in range(7): + self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) + self.assertIsNone(client.next_event()) + + self.assertEqual(type(server.next_event()), events.HandshakeCompleted) + for i in range(7): + self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) + self.assertIsNone(server.next_event()) # check each endpoint has available connection IDs for the peer self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) + self.assertEqual( + sequence_numbers(server._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) - # clients sends data over stream - client_reader, client_writer = run(client.create_stream()) - client_writer.write(b"ping") - - # client receives pong - self.assertEqual(run(client_reader.read(1024)), b"pong") - - # client writes EOF - client_writer.write_eof() - - # client receives EOF - self.assertEqual(run(client_reader.read()), b"done") + # client closes the connection + client.close() + tick(client, server) + + # check connection closes on the client side + client.handle_timer(client.get_timer()) + event = client.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(event.frame_type, None) + self.assertEqual(event.reason_phrase, "") + self.assertIsNone(client.next_event()) + + # check connection closes on the server side + server.handle_timer(server.get_timer()) + event = server.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(event.frame_type, None) + self.assertEqual(event.reason_phrase, "") + self.assertIsNone(server.next_event()) def test_connect_draft_19(self): self._test_connect_with_version( @@ -194,8 +179,6 @@ def test_connect_with_log(self): client_options={"secrets_log_file": client_log_file}, server_options={"secrets_log_file": server_log_file}, ) as (client, server): - run(client.wait_connected()) - # check secrets were logged client_log = client_log_file.getvalue() server_log = server_log_file.getvalue() @@ -213,53 +196,23 @@ def test_connect_with_log(self): ], ) - def test_connection_lost(self): - with client_and_server() as (client, server): - run(client.wait_connected()) - - # send data over stream - client_reader, client_writer = run(client.create_stream()) - client_writer.write(b"ping") - run(asyncio.sleep(0)) - - # break connection - client.connection_lost(None) - self.assertEqual(run(client_reader.read()), b"") - - def test_connection_lost_with_exception(self): - with client_and_server() as (client, server): - run(client.wait_connected()) - - # send data over stream - client_reader, client_writer = run(client.create_stream()) - client_writer.write(b"ping") - run(asyncio.sleep(0)) - - # break connection - exc = Exception("some error") - client.connection_lost(exc) - with self.assertRaises(Exception) as cm: - run(client_reader.read()) - self.assertEqual(cm.exception, exc) - def test_consume_connection_id(self): with client_and_server() as (client, server): - run(client.wait_connected()) self.assertEqual( sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] ) - # change connection ID + # the client changes connection ID client._consume_connection_id() - client._send_pending() + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, CLIENT_ADDR, now=time.time()) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] ) - # wait one RTT - run(asyncio.sleep(RTT)) - # the server provides a new connection ID + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) self.assertEqual( sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] ) @@ -267,57 +220,53 @@ def test_consume_connection_id(self): def test_create_stream(self): with client_and_server() as (client, server): # client - reader, writer = run(client.create_stream()) - self.assertEqual(writer.get_extra_info("stream_id"), 0) - self.assertIsNotNone(writer.get_extra_info("connection")) + stream = client.create_stream() + self.assertEqual(stream.stream_id, 0) - reader, writer = run(client.create_stream()) - self.assertEqual(writer.get_extra_info("stream_id"), 4) + stream = client.create_stream() + self.assertEqual(stream.stream_id, 4) - reader, writer = run(client.create_stream(is_unidirectional=True)) - self.assertEqual(writer.get_extra_info("stream_id"), 2) + stream = client.create_stream(is_unidirectional=True) + self.assertEqual(stream.stream_id, 2) - reader, writer = run(client.create_stream(is_unidirectional=True)) - self.assertEqual(writer.get_extra_info("stream_id"), 6) + stream = client.create_stream(is_unidirectional=True) + self.assertEqual(stream.stream_id, 6) # server - reader, writer = run(server.create_stream()) - self.assertEqual(writer.get_extra_info("stream_id"), 1) + stream = server.create_stream() + self.assertEqual(stream.stream_id, 1) - reader, writer = run(server.create_stream()) - self.assertEqual(writer.get_extra_info("stream_id"), 5) + stream = server.create_stream() + self.assertEqual(stream.stream_id, 5) - reader, writer = run(server.create_stream(is_unidirectional=True)) - self.assertEqual(writer.get_extra_info("stream_id"), 3) + stream = server.create_stream(is_unidirectional=True) + self.assertEqual(stream.stream_id, 3) - reader, writer = run(server.create_stream(is_unidirectional=True)) - self.assertEqual(writer.get_extra_info("stream_id"), 7) + stream = server.create_stream(is_unidirectional=True) + self.assertEqual(stream.stream_id, 7) def test_create_stream_over_max_streams(self): with client_and_server() as (client, server): - run(client.wait_connected()) - # create streams for i in range(128): - client_reader, client_writer = run(client.create_stream()) + client.send_stream_data(i * 4, b"") # create one too many with self.assertRaises(ValueError) as cm: - client_reader, client_writer = run(client.create_stream()) + client.send_stream_data(128 * 4, b"") self.assertEqual(str(cm.exception), "Too many streams open") def test_decryption_error(self): with client_and_server() as (client, server): - run(client.wait_connected()) - # mess with encryption key server._cryptos[tls.Epoch.ONE_RTT].send.setup( tls.CipherSuite.AES_128_GCM_SHA256, bytes(48) ) - # close + # server sends close server.close(error_code=QuicErrorCode.NO_ERROR) - run(server.wait_closed()) + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) def test_tls_error(self): def patch(client): @@ -331,17 +280,17 @@ def patched_initialize(peer_cid: bytes): # handshake fails with client_and_server(client_patch=patch) as (client, server): - with self.assertRaises(QuicConnectionError) as cm: - run(asyncio.gather(client.wait_connected(), server.wait_connected())) - self.assertEqual(cm.exception.error_code, 326) - self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) - self.assertEqual( - cm.exception.reason_phrase, "No supported protocol version" - ) + timer_at = server.get_timer() + server.handle_timer(timer_at) - def test_datagram_received_wrong_version(self): - client, client_transport = create_standalone_client() - self.assertEqual(client_transport.sent, 1) + event = server.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, 326) + self.assertEqual(event.frame_type, QuicFrameType.CRYPTO) + self.assertEqual(event.reason_phrase, "No supported protocol version") + + def test_receive_datagram_wrong_version(self): + client = create_standalone_client(self) builder = QuicPacketBuilder( host_cid=client._peer_cid, @@ -355,14 +304,13 @@ def test_datagram_received_wrong_version(self): builder.end_packet() for datagram in builder.flush()[0]: - client.datagram_received(datagram, SERVER_ADDR) - self.assertEqual(client_transport.sent, 1) + client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 0) - def test_datagram_received_retry(self): - client, client_transport = create_standalone_client() - self.assertEqual(client_transport.sent, 1) + def test_receive_datagram_retry(self): + client = create_standalone_client(self) - client.datagram_received( + client.receive_datagram( encode_quic_retry( version=QuicProtocolVersion.DRAFT_20, source_cid=binascii.unhexlify("85abb547bf28be97"), @@ -371,14 +319,14 @@ def test_datagram_received_retry(self): retry_token=bytes(16), ), SERVER_ADDR, + now=time.time(), ) - self.assertEqual(client_transport.sent, 2) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 1) - def test_datagram_received_retry_wrong_destination_cid(self): - client, client_transport = create_standalone_client() - self.assertEqual(client_transport.sent, 1) + def test_receive_datagram_retry_wrong_destination_cid(self): + client = create_standalone_client(self) - client.datagram_received( + client.receive_datagram( encode_quic_retry( version=QuicProtocolVersion.DRAFT_20, source_cid=binascii.unhexlify("85abb547bf28be97"), @@ -387,15 +335,12 @@ def test_datagram_received_retry_wrong_destination_cid(self): retry_token=bytes(16), ), SERVER_ADDR, + now=time.time(), ) - self.assertEqual(client_transport.sent, 1) - - def test_error_received(self): - client, _ = create_standalone_client() - client.error_received(OSError("foo")) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 0) def test_handle_ack_frame_ecn(self): - client, client_transport = create_standalone_client() + client = create_standalone_client(self) client._handle_ack_frame( client_receive_context(client), @@ -405,15 +350,15 @@ def test_handle_ack_frame_ecn(self): def test_handle_connection_close_frame(self): with client_and_server() as (client, server): - run(server.wait_connected()) server.close( error_code=QuicErrorCode.NO_ERROR, frame_type=QuicFrameType.PADDING ) + tick(server, client) def test_handle_connection_close_frame_app(self): with client_and_server() as (client, server): - run(server.wait_connected()) server.close(error_code=QuicErrorCode.NO_ERROR) + tick(server, client) def test_handle_data_blocked_frame(self): with client_and_server() as (client, server): @@ -439,8 +384,7 @@ def test_handle_max_data_frame(self): def test_handle_max_stream_data_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 - reader, writer = run(client.create_stream()) - stream = writer.transport + stream = client.create_stream() self.assertEqual(stream.max_stream_data_remote, 1048576) # client receives MAX_STREAM_DATA raising limit @@ -462,7 +406,7 @@ def test_handle_max_stream_data_frame(self): def test_handle_max_stream_data_frame_receive_only(self): with client_and_server() as (client, server): # server creates unidirectional stream 3 - run(server.create_stream(is_unidirectional=True)) + server.create_stream(is_unidirectional=True) # client receives MAX_STREAM_DATA: 3, 1 with self.assertRaises(QuicConnectionError) as cm: @@ -527,21 +471,26 @@ def test_handle_new_token_frame(self): def test_handle_path_challenge_frame(self): with client_and_server() as (client, server): # client changes address and sends some data - client._transport.local_addr = ("1.2.3.4", 2345) - reader, writer = run(client.create_stream()) - writer.write(b"01234567") + client.send_stream_data(0, b"01234567") + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) - # wait one RTT - run(asyncio.sleep(RTT)) + # check paths + self.assertEqual(len(server._network_paths), 2) + self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) + self.assertFalse(server._network_paths[0].is_validated) + self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) + self.assertTrue(server._network_paths[1].is_validated) # server sends PATH_CHALLENGE and receives PATH_RESPONSE - self.assertEqual(len(server._network_paths), 2) + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) - # check new path + # check paths self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) self.assertTrue(server._network_paths[0].is_validated) - - # check old path self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) self.assertTrue(server._network_paths[1].is_validated) @@ -560,7 +509,7 @@ def test_handle_path_response_frame_bad(self): def test_handle_reset_stream_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 - run(client.create_stream()) + client.create_stream() # client receives RESET_STREAM client._handle_reset_stream_frame( @@ -572,7 +521,7 @@ def test_handle_reset_stream_frame(self): def test_handle_reset_stream_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 2 - run(client.create_stream(is_unidirectional=True)) + client.create_stream(is_unidirectional=True) # client receives RESET_STREAM with self.assertRaises(QuicConnectionError) as cm: @@ -587,7 +536,6 @@ def test_handle_reset_stream_frame_send_only(self): def test_handle_retire_connection_id_frame(self): with client_and_server() as (client, server): - run(client.wait_connected()) self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) @@ -604,7 +552,6 @@ def test_handle_retire_connection_id_frame(self): def test_handle_retire_connection_id_frame_current_cid(self): with client_and_server() as (client, server): - run(client.wait_connected()) self.assertEqual( sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] ) @@ -630,7 +577,7 @@ def test_handle_retire_connection_id_frame_current_cid(self): def test_handle_stop_sending_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 - run(client.create_stream()) + client.create_stream() # client receives STOP_SENDING client._handle_stop_sending_frame( @@ -642,7 +589,7 @@ def test_handle_stop_sending_frame(self): def test_handle_stop_sending_frame_receive_only(self): with client_and_server() as (client, server): # server creates unidirectional stream 3 - run(server.create_stream(is_unidirectional=True)) + server.create_stream(is_unidirectional=True) # client receives STOP_SENDING with self.assertRaises(QuicConnectionError) as cm: @@ -709,7 +656,7 @@ def test_handle_stream_frame_over_max_streams(self): def test_handle_stream_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 2 - run(client.create_stream(is_unidirectional=True)) + client.create_stream(is_unidirectional=True) # client receives STREAM frame with self.assertRaises(QuicConnectionError) as cm: @@ -738,7 +685,7 @@ def test_handle_stream_frame_wrong_initiator(self): def test_handle_stream_data_blocked_frame(self): with client_and_server() as (client, server): # client creates bidirectional stream 0 - run(client.create_stream()) + client.create_stream() # client receives STREAM_DATA_BLOCKED client._handle_stream_data_blocked_frame( @@ -750,7 +697,7 @@ def test_handle_stream_data_blocked_frame(self): def test_handle_stream_data_blocked_frame_send_only(self): with client_and_server() as (client, server): # client creates unidirectional stream 2 - run(client.create_stream(is_unidirectional=True)) + client.create_stream(is_unidirectional=True) # client receives STREAM_DATA_BLOCKED with self.assertRaises(QuicConnectionError) as cm: @@ -814,6 +761,27 @@ def test_payload_received_malformed_frame(self): self.assertEqual(cm.exception.frame_type, 0x1C) self.assertEqual(cm.exception.reason_phrase, "Failed to parse frame") + def test_send_ping(self): + with client_and_server() as (client, server): + consume_events(client) + + # client sends ping, server ACKs it + client.send_ping(uid=12345) + tick(client, server) + + # check event + event = client.next_event() + self.assertEqual(type(event), events.PongReceived) + self.assertEqual(event.uid, 12345) + + # client sends another ping + client.send_ping(uid=23456) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 1) + + # ping is lost + client._on_ping_delivery(QuicDeliveryState.LOST, (23456,)) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 1) + def test_stream_direction(self): with client_and_server() as (client, server): for off in [0, 4, 8]: @@ -842,43 +810,44 @@ def test_stream_direction(self): self.assertTrue(server._stream_can_send(off + 3)) def test_version_negotiation_fail(self): - client, client_transport = create_standalone_client() - self.assertEqual(client_transport.sent, 1) + client = create_standalone_client(self) # no common version, no retry - client.datagram_received( + client.receive_datagram( encode_quic_version_negotiation( source_cid=client._peer_cid, destination_cid=client.host_cid, supported_versions=[0xFF000011], # DRAFT_16 ), SERVER_ADDR, + now=time.time(), ) - self.assertEqual(client_transport.sent, 1) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 0) - with self.assertRaises(QuicConnectionError) as cm: - run(client.wait_connected()) - self.assertEqual(cm.exception.error_code, QuicErrorCode.INTERNAL_ERROR) - self.assertEqual(cm.exception.frame_type, None) + event = client.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) + self.assertEqual(event.frame_type, None) self.assertEqual( - cm.exception.reason_phrase, "Could not find a common protocol version" + event.reason_phrase, "Could not find a common protocol version" ) def test_version_negotiation_ok(self): - client, client_transport = create_standalone_client() - self.assertEqual(client_transport.sent, 1) + client = create_standalone_client(self) # found a common version, retry - client.datagram_received( + client.receive_datagram( encode_quic_version_negotiation( source_cid=client._peer_cid, destination_cid=client.host_cid, supported_versions=[QuicProtocolVersion.DRAFT_19], ), SERVER_ADDR, + now=time.time(), ) - self.assertEqual(client_transport.sent, 2) + self.assertEqual(len(client.datagrams_to_send(now=time.time())), 1) + @skip def test_with_packet_loss_during_app_data(self): """ This test ensures stream data is successfully sent and received @@ -908,13 +877,14 @@ async def serve_request(reader, writer): server._transport.loss = 0.25 # create stream and send data - client_reader, client_writer = run(client.create_stream()) + client_reader, client_writer = client.create_stream() client_writer.write(client_data) client_writer.write_eof() # check response self.assertEqual(run(client_reader.read()), server_data) + @skip def test_with_packet_loss_during_handshake(self): """ This test ensures handshake success and stream data is successfully sent @@ -938,7 +908,7 @@ async def serve_request(reader, writer): run(asyncio.gather(client.wait_connected(), server.wait_connected())) # create stream and send data - client_reader, client_writer = run(client.create_stream()) + client_reader, client_writer = client.create_stream() client_writer.write(client_data) client_writer.write_eof() diff --git a/tests/test_stream.py b/tests/test_stream.py index ab969ba62..0b550bc14 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,17 +1,9 @@ -import asyncio from unittest import TestCase from aioquic.packet import QuicStreamFrame from aioquic.packet_builder import QuicDeliveryState from aioquic.stream import QuicStream -from .utils import run - - -async def delay(coro): - await asyncio.sleep(0.1) - await coro() - class QuicStreamTest(TestCase): def test_recv_empty(self): @@ -74,15 +66,6 @@ def test_recv_ordered_2(self): self.assertEqual(list(stream._recv_ranges), []) self.assertEqual(stream._recv_buffer_start, 16) - def test_recv_ordered_3(self): - stream = QuicStream(stream_id=0) - - async def add_frame(): - stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")) - - data, _ = run(asyncio.gather(stream.reader.read(1024), delay(add_frame))) - self.assertEqual(data, b"01234567") - def test_recv_unordered(self): stream = QuicStream() @@ -156,14 +139,14 @@ def test_recv_fin(self): stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")) stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)) - self.assertEqual(run(stream.reader.read()), b"0123456789012345") + self.assertEqual(stream.pull_data(), b"0123456789012345") def test_recv_fin_out_of_order(self): stream = QuicStream(stream_id=0) stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)) stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")) - self.assertEqual(run(stream.reader.read()), b"0123456789012345") + self.assertEqual(stream.pull_data(), b"0123456789012345") def test_recv_fin_then_data(self): stream = QuicStream(stream_id=0) @@ -178,17 +161,16 @@ def test_recv_fin_twice(self): stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)) stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)) - self.assertEqual(run(stream.reader.read()), b"0123456789012345") + self.assertEqual(stream.pull_data(), b"0123456789012345") def test_recv_fin_without_data(self): stream = QuicStream(stream_id=0) stream.add_frame(QuicStreamFrame(offset=0, data=b"", fin=True)) - self.assertEqual(run(stream.reader.read()), b"") + self.assertEqual(stream.pull_data(), b"") def test_send_data(self): stream = QuicStream() - self.assertTrue(stream.can_write_eof()) # nothing to send yet frame = stream.get_frame(8) @@ -420,31 +402,3 @@ def test_send_fin_only_despite_blocked(self): # nothing more to send frame = stream.get_frame(8) self.assertIsNone(frame) - - def test_send_data_using_writelines(self): - stream = QuicStream() - - # nothing to send yet - frame = stream.get_frame(8) - self.assertIsNone(frame) - - # write data, send a chunk - stream.writelines([b"01234567", b"89012345"]) - self.assertEqual(list(stream._send_pending), [range(0, 16)]) - frame = stream.get_frame(8) - self.assertEqual(frame.data, b"01234567") - self.assertFalse(frame.fin) - self.assertEqual(frame.offset, 0) - self.assertEqual(list(stream._send_pending), [range(8, 16)]) - - # send another chunk - frame = stream.get_frame(8) - self.assertEqual(frame.data, b"89012345") - self.assertFalse(frame.fin) - self.assertEqual(frame.offset, 8) - self.assertEqual(list(stream._send_pending), []) - - # nothing more to send - frame = stream.get_frame(8) - self.assertIsNone(frame) - self.assertEqual(list(stream._send_pending), [])