diff --git a/tokio-util/src/cfg.rs b/tokio-util/src/cfg.rs
index 4035255aff0..7957f482409 100644
--- a/tokio-util/src/cfg.rs
+++ b/tokio-util/src/cfg.rs
@@ -60,6 +60,15 @@ macro_rules! cfg_rt {
     }
 }
 
+macro_rules! cfg_not_rt {
+    ($($item:item)*) => {
+        $(
+            #[cfg(not(feature = "rt"))]
+            $item
+        )*
+    }
+}
+
 macro_rules! cfg_time {
     ($($item:item)*) => {
         $(
diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs
index f5a182b5ce8..15a3b4011ee 100644
--- a/tokio-util/src/io/mod.rs
+++ b/tokio-util/src/io/mod.rs
@@ -14,6 +14,7 @@ mod copy_to_bytes;
 mod inspect;
 mod read_buf;
 mod reader_stream;
+pub mod simplex;
 mod sink_writer;
 mod stream_reader;
 
diff --git a/tokio-util/src/io/simplex.rs b/tokio-util/src/io/simplex.rs
new file mode 100644
index 00000000000..bdf803fbad2
--- /dev/null
+++ b/tokio-util/src/io/simplex.rs
@@ -0,0 +1,343 @@
+//! Unidirectional byte-oriented channel.
+
+use crate::util::poll_proceed;
+
+use bytes::Buf;
+use bytes::BytesMut;
+use futures_core::ready;
+use std::io::Error as IoError;
+use std::io::ErrorKind as IoErrorKind;
+use std::io::IoSlice;
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll, Waker};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+type IoResult<T> = Result<T, IoError>;
+
+const CLOSED_ERROR_MSG: &str = "simplex has been closed";
+
+#[derive(Debug)]
+struct Inner {
+    /// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached
+    backpressure_boundary: usize,
+
+    /// either [`Sender`] or [`Receiver`] is closed
+    is_closed: bool,
+
+    /// Waker used to wake the [`Receiver`]
+    receiver_waker: Option<Waker>,
+
+    /// Waker used to wake the [`Sender`]
+    sender_waker: Option<Waker>,
+
+    /// Buffer used to read and write data
+    buf: BytesMut,
+}
+
+impl Inner {
+    fn with_capacity(capacity: usize) -> Self {
+        Self {
+            backpressure_boundary: capacity,
+            is_closed: false,
+            receiver_waker: None,
+            sender_waker: None,
+            buf: BytesMut::with_capacity(capacity),
+        }
+    }
+
+    fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
+        match self.receiver_waker.as_mut() {
+            Some(old) if old.will_wake(waker) => None,
+            _ => self.receiver_waker.replace(waker.clone()),
+        }
+    }
+
+    fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
+        match self.sender_waker.as_mut() {
+            Some(old) if old.will_wake(waker) => None,
+            _ => self.sender_waker.replace(waker.clone()),
+        }
+    }
+
+    fn take_receiver_waker(&mut self) -> Option<Waker> {
+        self.receiver_waker.take()
+    }
+
+    fn take_sender_waker(&mut self) -> Option<Waker> {
+        self.sender_waker.take()
+    }
+
+    fn is_closed(&self) -> bool {
+        self.is_closed
+    }
+
+    fn close_receiver(&mut self) -> Option<Waker> {
+        self.is_closed = true;
+        self.take_sender_waker()
+    }
+
+    fn close_sender(&mut self) -> Option<Waker> {
+        self.is_closed = true;
+        self.take_receiver_waker()
+    }
+}
+
+/// Receiver of the simplex channel.
+///
+/// You can still read the remaining data from the buffer
+/// even if the write half has been dropped.
+/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
+#[derive(Debug)]
+pub struct Receiver {
+    inner: Arc<Mutex<Inner>>,
+}
+
+impl Drop for Receiver {
+    /// This also wakes up the [`Sender`].
+    fn drop(&mut self) {
+        let maybe_waker = {
+            let mut inner = self.inner.lock().unwrap();
+            inner.close_receiver()
+        };
+
+        if let Some(waker) = maybe_waker {
+            waker.wake();
+        }
+    }
+}
+
+impl AsyncRead for Receiver {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<IoResult<()>> {
+        let coop = ready!(poll_proceed(cx));
+
+        let mut inner = self.inner.lock().unwrap();
+
+        let to_read = buf.remaining().min(inner.buf.remaining());
+        if to_read == 0 {
+            if inner.is_closed() || buf.remaining() == 0 {
+                return Poll::Ready(Ok(()));
+            }
+
+            let old_waker = inner.register_receiver_waker(cx.waker());
+            let maybe_waker = inner.take_sender_waker();
+
+            // unlock before waking up and dropping old waker
+            drop(inner);
+            drop(old_waker);
+            if let Some(waker) = maybe_waker {
+                waker.wake();
+            }
+            return Poll::Pending;
+        }
+
+        // this is to avoid starving other tasks
+        coop.made_progress();
+
+        buf.put_slice(&inner.buf[..to_read]);
+        inner.buf.advance(to_read);
+
+        let waker = inner.take_sender_waker();
+        drop(inner); // unlock before waking up
+        if let Some(waker) = waker {
+            waker.wake();
+        }
+
+        Poll::Ready(Ok(()))
+    }
+}
+
+/// Sender of the simplex channel.
+///
+/// ## Shutdown
+///
+/// See [`Sender::poll_shutdown`].
+#[derive(Debug)]
+pub struct Sender {
+    inner: Arc<Mutex<Inner>>,
+}
+
+impl Drop for Sender {
+    /// This also wakes up the [`Receiver`].
+    fn drop(&mut self) {
+        let maybe_waker = {
+            let mut inner = self.inner.lock().unwrap();
+            inner.close_sender()
+        };
+
+        if let Some(waker) = maybe_waker {
+            waker.wake();
+        }
+    }
+}
+
+impl AsyncWrite for Sender {
+    /// # Errors
+    ///
+    /// This method will return [`IoErrorKind::BrokenPipe`]
+    /// if the channel has been closed.
+    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
+        let coop = ready!(poll_proceed(cx));
+
+        let mut inner = self.inner.lock().unwrap();
+
+        if inner.is_closed() {
+            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
+        }
+
+        let free = inner
+            .backpressure_boundary
+            .checked_sub(inner.buf.len())
+            .expect("backpressure boundary overflow");
+        let to_write = buf.len().min(free);
+        if to_write == 0 {
+            if buf.is_empty() {
+                return Poll::Ready(Ok(0));
+            }
+
+            let old_waker = inner.register_sender_waker(cx.waker());
+            let waker = inner.take_receiver_waker();
+
+            // unlock before waking up and dropping old waker
+            drop(inner);
+            drop(old_waker);
+            if let Some(waker) = waker {
+                waker.wake();
+            }
+
+            return Poll::Pending;
+        }
+
+        // this is to avoid starving other tasks
+        coop.made_progress();
+
+        inner.buf.extend_from_slice(&buf[..to_write]);
+
+        let waker = inner.take_receiver_waker();
+        drop(inner); // unlock before waking up
+        if let Some(waker) = waker {
+            waker.wake();
+        }
+
+        Poll::Ready(Ok(to_write))
+    }
+
+    /// # Errors
+    ///
+    /// This method will return [`IoErrorKind::BrokenPipe`]
+    /// if the channel has been closed.
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
+        let inner = self.inner.lock().unwrap();
+        if inner.is_closed() {
+            Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    /// After this method returns [`Poll::Ready`], all following calls to
+    /// [`Sender::poll_write`] and [`Sender::poll_flush`]
+    /// will return an error.
+    ///
+    /// The [`Receiver`] can still be used to read remaining data
+    /// until all bytes have been consumed.
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
+        let maybe_waker = {
+            let mut inner = self.inner.lock().unwrap();
+            inner.close_sender()
+        };
+
+        if let Some(waker) = maybe_waker {
+            waker.wake();
+        }
+
+        Poll::Ready(Ok(()))
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        true
+    }
+
+    fn poll_write_vectored(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[IoSlice<'_>],
+    ) -> Poll<IoResult<usize>> {
+        let coop = ready!(poll_proceed(cx));
+
+        let mut inner = self.inner.lock().unwrap();
+        if inner.is_closed() {
+            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
+        }
+
+        let free = inner
+            .backpressure_boundary
+            .checked_sub(inner.buf.len())
+            .expect("backpressure boundary overflow");
+        if free == 0 {
+            let old_waker = inner.register_sender_waker(cx.waker());
+            let maybe_waker = inner.take_receiver_waker();
+
+            // unlock before waking up and dropping old waker
+            drop(inner);
+            drop(old_waker);
+            if let Some(waker) = maybe_waker {
+                waker.wake();
+            }
+
+            return Poll::Pending;
+        }
+
+        // this is to avoid starving other tasks
+        coop.made_progress();
+
+        let mut rem = free;
+        for buf in bufs {
+            if rem == 0 {
+                break;
+            }
+
+            let to_write = buf.len().min(rem);
+            if to_write == 0 {
+                assert_ne!(rem, 0);
+                assert_eq!(buf.len(), 0);
+                continue;
+            }
+
+            inner.buf.extend_from_slice(&buf[..to_write]);
+            rem -= to_write;
+        }
+
+        let waker = inner.take_receiver_waker();
+        drop(inner); // unlock before waking up
+        if let Some(waker) = waker {
+            waker.wake();
+        }
+
+        Poll::Ready(Ok(free - rem))
+    }
+}
+
+/// Create a simplex channel.
+///
+/// The `capacity` parameter specifies the maximum number of bytes that can be
+/// stored in the channel without making the [`Sender::poll_write`]
+/// return [`Poll::Pending`].
+///
+/// # Panics
+///
+/// This function will panic if `capacity` is zero.
+pub fn new(capacity: usize) -> (Sender, Receiver) {
+    assert_ne!(capacity, 0, "capacity must be greater than zero");
+
+    let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
+    let tx = Sender {
+        inner: Arc::clone(&inner),
+    };
+    let rx = Receiver { inner };
+    (tx, rx)
+}
diff --git a/tokio-util/src/util/mod.rs b/tokio-util/src/util/mod.rs
index a17f25a6b91..aaba542c289 100644
--- a/tokio-util/src/util/mod.rs
+++ b/tokio-util/src/util/mod.rs
@@ -6,3 +6,26 @@ pub(crate) use maybe_dangling::MaybeDangling;
 #[cfg(any(feature = "io", feature = "codec"))]
 #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
 pub use poll_buf::{poll_read_buf, poll_write_buf};
+
+cfg_rt! {
+    #[cfg_attr(not(feature = "io"), allow(unused))]
+    pub(crate) use tokio::task::coop::poll_proceed;
+}
+
+cfg_not_rt! {
+    #[cfg_attr(not(feature = "io"), allow(unused))]
+    use std::task::{Context, Poll};
+
+    #[cfg_attr(not(feature = "io"), allow(unused))]
+    pub(crate) struct RestoreOnPending;
+
+    #[cfg_attr(not(feature = "io"), allow(unused))]
+    impl RestoreOnPending {
+        pub(crate) fn made_progress(&self) {}
+    }
+
+    #[cfg_attr(not(feature = "io"), allow(unused))]
+    pub(crate) fn poll_proceed(_cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
+        Poll::Ready(RestoreOnPending)
+    }
+}
diff --git a/tokio-util/tests/io_simplex.rs b/tokio-util/tests/io_simplex.rs
new file mode 100644
index 00000000000..0b54b79864a
--- /dev/null
+++ b/tokio-util/tests/io_simplex.rs
@@ -0,0 +1,356 @@
+use futures::pin_mut;
+use futures_test::task::noop_context;
+use std::io::IoSlice;
+use std::task::Poll;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+use tokio_test::task::spawn;
+use tokio_test::{assert_pending, assert_ready};
+use tokio_util::io::simplex;
+
+/// Sanity check for single-threaded operation.
+#[tokio::test]
+async fn single_thread() {
+    const N: usize = 64;
+    const MSG: &[u8] = b"Hello, world!";
+    const CAPS: &[usize] = &[1, MSG.len() / 2, MSG.len() - 1, MSG.len(), MSG.len() + 1];
+
+    // test different buffer capacities to cover edge cases
+    for &capacity in CAPS {
+        let (mut tx, mut rx) = simplex::new(capacity);
+
+        for _ in 0..N {
+            let mut read = 0;
+            let mut write = 0;
+            let mut buf = [0; MSG.len()];
+
+            while read < MSG.len() || write < MSG.len() {
+                if write < MSG.len() {
+                    let n = tx.write(&MSG[write..]).await.unwrap();
+                    write += n;
+                }
+
+                if read < MSG.len() {
+                    let n = rx.read(&mut buf[read..]).await.unwrap();
+                    read += n;
+                }
+            }
+
+            assert_eq!(&buf[..], MSG);
+        }
+    }
+}
+
+/// Sanity check for multi-threaded operation.
+#[test]
+#[cfg(not(target_os = "wasi"))] // No thread on wasi.
+fn multi_thread() {
+    use futures::executor::block_on;
+    use std::thread;
+
+    const N: usize = 64;
+    const MSG: &[u8] = b"Hello, world!";
+    const CAPS: &[usize] = &[1, MSG.len() / 2, MSG.len() - 1, MSG.len(), MSG.len() + 1];
+
+    // test different buffer capacities to cover edge cases
+    for &capacity in CAPS {
+        let (mut tx, mut rx) = simplex::new(capacity);
+
+        let jh0 = thread::spawn(move || {
+            block_on(async {
+                let mut buf = vec![0; MSG.len()];
+                for _ in 0..N {
+                    rx.read_exact(&mut buf).await.unwrap();
+                    assert_eq!(&buf[..], MSG);
+                    buf.clear();
+                    buf.resize(MSG.len(), 0);
+                }
+            });
+        });
+
+        let jh1 = thread::spawn(move || {
+            block_on(async {
+                for _ in 0..N {
+                    tx.write_all(MSG).await.unwrap();
+                }
+            });
+        });
+
+        jh0.join().unwrap();
+        jh1.join().unwrap();
+    }
+}
+
+#[test]
+#[should_panic(expected = "capacity must be greater than zero")]
+fn zero_capacity() {
+    let _ = simplex::new(0);
+}
+
+/// The `Receiver::poll_read` should return `Poll::Ready(Ok(()))`
+/// if the `ReadBuf` has zero remaining capacity.
+#[tokio::test]
+async fn read_buf_is_full() {
+    let (_tx, rx) = simplex::new(32);
+    let mut buf = ReadBuf::new(&mut []);
+    tokio::pin!(rx);
+    assert_ready!(rx.as_mut().poll_read(&mut noop_context(), &mut buf)).unwrap();
+    assert_eq!(buf.filled().len(), 0);
+}
+
+/// The `Sender::poll_write` should return `Poll::Ready(Ok(0))`
+/// if the input buffer has zero length.
+#[tokio::test]
+async fn write_buf_is_empty() {
+    let (tx, _rx) = simplex::new(32);
+    tokio::pin!(tx);
+    let n = assert_ready!(tx.as_mut().poll_write(&mut noop_context(), &[])).unwrap();
+    assert_eq!(n, 0);
+}
+
+/// The `Sender` should return an error if the `Receiver` has been dropped.
+#[tokio::test]
+async fn drop_receiver_0() {
+    let (mut tx, rx) = simplex::new(32);
+    drop(rx);
+
+    tx.write_u8(1).await.unwrap_err();
+}
+
+/// The `Sender` should be woken up if the `Receiver` has been dropped.
+#[tokio::test]
+async fn drop_receiver_1() {
+    let (mut tx, rx) = simplex::new(1);
+    let mut write_task = spawn(tx.write_u16(1));
+    assert_pending!(write_task.poll());
+
+    assert!(!write_task.is_woken());
+    drop(rx);
+    assert!(write_task.is_woken());
+}
+
+/// The `Receiver` should return an error if:
+///
+/// - The `Sender` has been dropped.
+/// - AND there is no remaining data in the buffer.
+#[tokio::test]
+async fn drop_sender_0() {
+    const MSG: &[u8] = b"Hello, world!";
+
+    let (tx, mut rx) = simplex::new(32);
+    drop(tx);
+
+    let mut buf = vec![0; MSG.len()];
+    rx.read_exact(&mut buf).await.unwrap_err();
+}
+
+/// The `Receiver` should be woken up if:
+///
+/// - The `Sender` has been dropped.
+/// - AND there is still remaining data in the buffer.
+#[tokio::test]
+async fn drop_sender_1() {
+    let (mut tx, mut rx) = simplex::new(2);
+    let mut buf = vec![];
+    let mut read_task = spawn(rx.read_to_end(&mut buf));
+    assert_pending!(read_task.poll());
+
+    tx.write_u8(1).await.unwrap();
+    assert_pending!(read_task.poll());
+
+    assert!(!read_task.is_woken());
+    drop(tx);
+    assert!(read_task.is_woken());
+
+    read_task.await.unwrap();
+    assert_eq!(buf, vec![1]);
+}
+
+/// All following calls to `Sender::poll_write` and `Sender::poll_flush`
+/// should return an error after `shutdown` has been called.
+#[tokio::test]
+async fn shutdown_sender_0() {
+    const MSG: &[u8] = b"Hello, world!";
+
+    let (mut tx, _rx) = simplex::new(32);
+    tx.shutdown().await.unwrap();
+
+    tx.write_all(MSG).await.unwrap_err();
+    tx.flush().await.unwrap_err();
+}
+
+/// The `Sender::poll_shutdown` can be called multiple times
+/// without error.
+#[tokio::test]
+async fn shutdown_sender_1() {
+    let (mut tx, _rx) = simplex::new(32);
+    tx.shutdown().await.unwrap();
+    tx.shutdown().await.unwrap();
+}
+
+/// The `Sender::poll_shutdown` should wake up the `Receiver`.
+#[tokio::test]
+async fn shutdown_sender_2() {
+    let (mut tx, mut rx) = simplex::new(32);
+
+    let mut buf = vec![];
+    let mut read_task = spawn(rx.read_to_end(&mut buf));
+    assert_pending!(read_task.poll());
+
+    tx.write_u8(1).await.unwrap();
+    assert_pending!(read_task.poll());
+
+    assert!(!read_task.is_woken());
+    tx.shutdown().await.unwrap();
+    assert!(read_task.is_woken());
+
+    read_task.await.unwrap();
+    assert_eq!(buf, vec![1]);
+}
+
+/// Both `Sender` and `Receiver` should yield periodically
+/// in a tight loop.
+#[tokio::test]
+#[cfg(feature = "rt")]
+async fn cooperative_scheduling() {
+    // this magic number is copied from
+    // https://github.com/tokio-rs/tokio/blob/925c614c89d0a26777a334612e2ed6ad0e7935c3/tokio/src/task/coop/mod.rs#L116
+    const INITIAL_BUDGET: usize = 128;
+
+    let (tx, _rx) = simplex::new(INITIAL_BUDGET * 2);
+    pin_mut!(tx);
+    let mut is_pending = false;
+    for _ in 0..INITIAL_BUDGET + 1 {
+        match tx.as_mut().poll_write(&mut noop_context(), &[0u8; 1]) {
+            Poll::Pending => {
+                is_pending = true;
+                break;
+            }
+            Poll::Ready(Ok(1)) => {}
+            Poll::Ready(Ok(n)) => panic!("wrote too many bytes: {n}"),
+            Poll::Ready(Err(e)) => panic!("{e}"),
+        }
+    }
+    assert!(is_pending);
+
+    let (tx, _rx) = simplex::new(INITIAL_BUDGET * 2);
+    pin_mut!(tx);
+    let mut is_pending = false;
+    let io_slices = &[IoSlice::new(&[0u8; 1])];
+    for _ in 0..INITIAL_BUDGET + 1 {
+        match tx
+            .as_mut()
+            .poll_write_vectored(&mut noop_context(), io_slices)
+        {
+            Poll::Pending => {
+                is_pending = true;
+                break;
+            }
+            Poll::Ready(Ok(1)) => {}
+            Poll::Ready(Ok(n)) => panic!("wrote too many bytes: {n}"),
+            Poll::Ready(Err(e)) => panic!("{e}"),
+        }
+    }
+    assert!(is_pending);
+
+    let (mut tx, rx) = simplex::new(INITIAL_BUDGET * 2);
+    tx.write_all(&[0u8; INITIAL_BUDGET + 2]).await.unwrap();
+    pin_mut!(rx);
+    let mut is_pending = false;
+    for _ in 0..INITIAL_BUDGET + 1 {
+        let mut buf = [0u8; 1];
+        let mut buf = ReadBuf::new(&mut buf);
+        match rx.as_mut().poll_read(&mut noop_context(), &mut buf) {
+            Poll::Pending => {
+                is_pending = true;
+                break;
+            }
+            Poll::Ready(Ok(())) => assert_eq!(buf.filled().len(), 1),
+            Poll::Ready(Err(e)) => panic!("{e}"),
+        }
+    }
+    assert!(is_pending);
+}
+
+/// The capacity is exactly the same as the total length of the vectored buffers.
+#[tokio::test]
+async fn poll_write_vectored_0() {
+    const MSG1: &[u8] = b"1";
+    const MSG2: &[u8] = b"22";
+    const MSG3: &[u8] = b"333";
+    const MSG_LEN: usize = MSG1.len() + MSG2.len() + MSG3.len();
+
+    let io_slices = &[IoSlice::new(MSG1), IoSlice::new(MSG2), IoSlice::new(MSG3)];
+
+    let (tx, mut rx) = simplex::new(MSG_LEN);
+    tokio::pin!(tx);
+    let res = tx.poll_write_vectored(&mut noop_context(), io_slices);
+    let n = assert_ready!(res).unwrap();
+    assert_eq!(n, MSG_LEN);
+    let mut buf = [0; MSG_LEN];
+    let n = rx.read_exact(&mut buf).await.unwrap();
+    assert_eq!(n, MSG_LEN);
+    assert_eq!(&buf, b"122333");
+}
+
+/// The capacity is smaller than the total length of the vectored buffers.
+#[tokio::test]
+async fn poll_write_vectored_1() {
+    const MSG1: &[u8] = b"1";
+    const MSG2: &[u8] = b"22";
+    const MSG3: &[u8] = b"333";
+    const CAPACITY: usize = MSG1.len() + MSG2.len() + 1;
+
+    let io_slices = &[IoSlice::new(MSG1), IoSlice::new(MSG2), IoSlice::new(MSG3)];
+
+    let (tx, mut rx) = simplex::new(CAPACITY);
+    tokio::pin!(tx);
+
+    // ==== The poll_write_vectored should write MSG1 and MSG2 fully, and MSG3 partially. ====
+    let res = tx.poll_write_vectored(&mut noop_context(), io_slices);
+    let n = assert_ready!(res).unwrap();
+    assert_eq!(n, CAPACITY);
+    let mut buf = [0; CAPACITY];
+    let n = rx.read_exact(&mut buf).await.unwrap();
+    assert_eq!(n, CAPACITY);
+    assert_eq!(&buf, b"1223");
+}
+
+/// There are two empty buffers in the vectored buffers.
+#[tokio::test]
+async fn poll_write_vectored_2() {
+    const MSG1: &[u8] = b"1";
+    const MSG2: &[u8] = b"";
+    const MSG3: &[u8] = b"22";
+    const MSG4: &[u8] = b"";
+    const MSG5: &[u8] = b"333";
+    const MSG_LEN: usize = MSG1.len() + MSG2.len() + MSG3.len() + MSG4.len() + MSG5.len();
+
+    let io_slices = &[
+        IoSlice::new(MSG1),
+        IoSlice::new(MSG2),
+        IoSlice::new(MSG3),
+        IoSlice::new(MSG4),
+        IoSlice::new(MSG5),
+    ];
+
+    let (tx, mut rx) = simplex::new(MSG_LEN);
+    tokio::pin!(tx);
+    let res = tx.poll_write_vectored(&mut noop_context(), io_slices);
+    let n = assert_ready!(res).unwrap();
+    assert_eq!(n, MSG_LEN);
+    let mut buf = [0; MSG_LEN];
+    let n = rx.read_exact(&mut buf).await.unwrap();
+    assert_eq!(n, MSG_LEN);
+    assert_eq!(&buf, b"122333");
+}
+
+/// The `Sender::poll_write_vectored` should return `Poll::Ready(Ok(0))`
+/// if all the input buffers have zero length.
+#[tokio::test]
+async fn poll_write_vectored_3() {
+    let io_slices = &[IoSlice::new(&[]), IoSlice::new(&[]), IoSlice::new(&[])];
+    let (tx, _rx) = simplex::new(32);
+    tokio::pin!(tx);
+    let n = assert_ready!(tx.poll_write_vectored(&mut noop_context(), io_slices)).unwrap();
+    assert_eq!(n, 0);
+}