Skip to content

Commit

Permalink
Implement AsyncBufRead (#100)
Browse files Browse the repository at this point in the history
* Implement AsyncBufRead for TlsStream types & Stream

* Implement AsyncRead using AsyncBufRead

* Reimplement AsyncRead for {server,client}::TlsStream in terms of AsyncBufRead
  • Loading branch information
goffrie authored Feb 11, 2025
1 parent 276625b commit 710cf25
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 99 deletions.
26 changes: 10 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ rust-version = "1.71"
exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
rustls = { version = "0.23.22", default-features = false, features = ["std"] }
tokio = "1.0"

[features]
Expand Down
77 changes: 50 additions & 27 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io;
use std::io::{self, BufRead as _};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
Expand All @@ -9,7 +9,7 @@ use std::task::Waker;
use std::task::{Context, Poll};

use rustls::ClientConnection;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

use crate::common::{IoSession, Stream, TlsState};

Expand Down Expand Up @@ -82,50 +82,69 @@ impl<IO> IoSession for TlsStream<IO> {
}
}

#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
fn poll_early_data(&mut self, cx: &mut Context<'_>) {
// In the EarlyData state, we have not really established a Tls connection.
// Before writing data through `AsyncWrite` and completing the tls handshake,
// we ignore read readiness and return to pending.
//
// In order to avoid event loss,
// we need to register a waker and wake it up after tls is connected.
if self
.early_waker
.as_ref()
.filter(|waker| cx.waker().will_wake(waker))
.is_none()
{
self.early_waker = Some(cx.waker().clone());
}
}
}

impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let data = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = data.len().min(buf.remaining());
buf.put_slice(&data[..len]);
self.consume(len);
Poll::Ready(Ok(()))
}
}

impl<IO> AsyncBufRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(..) => {
let this = self.get_mut();

// In the EarlyData state, we have not really established a Tls connection.
// Before writing data through `AsyncWrite` and completing the tls handshake,
// we ignore read readiness and return to pending.
//
// In order to avoid event loss,
// we need to register a waker and wake it up after tls is connected.
if this
.early_waker
.as_ref()
.filter(|waker| cx.waker().will_wake(waker))
.is_none()
{
this.early_waker = Some(cx.waker().clone());
}

self.get_mut().poll_early_data(cx);
Poll::Pending
}
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream =
let stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
let prev = buf.remaining();

match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
if prev == buf.remaining() || stream.eof {
match stream.poll_fill_buf(cx) {
Poll::Ready(Ok(buf)) => {
if buf.is_empty() {
this.state.shutdown_read();
}

Poll::Ready(Ok(()))
Poll::Ready(Ok(buf))
}
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
Expand All @@ -134,9 +153,13 @@ where
output => output,
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
}
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
self.session.reader().consume(amt);
}
}

impl<IO> AsyncWrite for TlsStream<IO>
Expand Down
80 changes: 51 additions & 29 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::io::{self, IoSlice, Read, Write};
use std::io::{self, BufRead as _, IoSlice, Read, Write};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};

use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

mod handshake;
pub(crate) use handshake::{IoSession, MidHandshake};
Expand Down Expand Up @@ -180,18 +180,11 @@ where
};
}
}
}

impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'_, IO, C>
where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
where
SD: 'a,
{
let mut io_pending = false;

// read a packet
Expand All @@ -209,22 +202,13 @@ where
}
}

match self.session.reader().read(buf.initialize_unfilled()) {
// If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
// connection with a `CloseNotify` message and no more data will be forthcoming.
//
// Rustls yielded more data: advance the buffer, then see if more data is coming.
//
// We don't need to modify `self.eof` here, because it is only a temporary mark.
// rustls will only return 0 if is has received `CloseNotify`,
// in which case no additional processing is required.
Ok(n) => {
buf.advance(n);
Poll::Ready(Ok(()))
match self.session.reader().into_first_chunk() {
Ok(buf) => {
// Note that this could be empty (i.e. EOF) if a `CloseNotify` has been
// received and there is no more buffered data.
Poll::Ready(Ok(buf))
}

// Rustls doesn't have more data to yield, but it believes the connection is open.
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
if !io_pending {
// If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
// but if it does, we can try again.
Expand All @@ -236,9 +220,47 @@ where

Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}

Err(err) => Poll::Ready(Err(err)),
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData + 'a,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let data = ready!(self.as_mut().poll_fill_buf(cx))?;
let amount = buf.remaining().min(data.len());
buf.put_slice(&data[..amount]);
self.session.reader().consume(amount);
Poll::Ready(Ok(()))
}
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
where
C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
SD: SideData + 'a,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let this = self.get_mut();
Stream {
// reborrow
io: this.io,
session: this.session,
..*this
}
.poll_fill_buf(cx)
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
self.session.reader().consume(amt);
}
}

Expand Down
17 changes: 13 additions & 4 deletions src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,20 @@ impl AsyncWrite for Eof {

#[tokio::test]
async fn stream_good() -> io::Result<()> {
stream_good_impl(false).await
stream_good_impl(false, false).await
}

#[tokio::test]
async fn stream_good_vectored() -> io::Result<()> {
stream_good_impl(true).await
stream_good_impl(true, false).await
}

async fn stream_good_impl(vectored: bool) -> io::Result<()> {
#[tokio::test]
async fn stream_good_bufread() -> io::Result<()> {
stream_good_impl(false, true).await
}

async fn stream_good_impl(vectored: bool, bufread: bool) -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../../README.md");

let (server, mut client) = make_pair();
Expand All @@ -177,7 +182,11 @@ async fn stream_good_impl(vectored: bool) -> io::Result<()> {
let mut stream = Stream::new(&mut good, &mut client);

let mut buf = Vec::new();
dbg!(stream.read_to_end(&mut buf).await)?;
if bufread {
dbg!(tokio::io::copy_buf(&mut stream, &mut buf).await)?;
} else {
dbg!(stream.read_to_end(&mut buf).await)?;
}
assert_eq!(buf, FILE);

dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;
Expand Down
23 changes: 22 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub use rustls;
use rustls::pki_types::ServerName;
use rustls::server::AcceptedAlert;
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

macro_rules! ready {
( $e:expr ) => {
Expand Down Expand Up @@ -545,6 +545,27 @@ where
}
}

impl<T> AsyncBufRead for TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
match self.get_mut() {
TlsStream::Client(x) => Pin::new(x).poll_fill_buf(cx),
TlsStream::Server(x) => Pin::new(x).poll_fill_buf(cx),
}
}

#[inline]
fn consume(self: Pin<&mut Self>, amt: usize) {
match self.get_mut() {
TlsStream::Client(x) => Pin::new(x).consume(amt),
TlsStream::Server(x) => Pin::new(x).consume(amt),
}
}
}

impl<T> AsyncWrite for TlsStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
Expand Down
Loading

0 comments on commit 710cf25

Please sign in to comment.