diff --git a/tokio/src/loom/std/atomic_ptr.rs b/tokio/src/loom/std/atomic_ptr.rs
index eb8e47557a2..f7fd56cc69b 100644
--- a/tokio/src/loom/std/atomic_ptr.rs
+++ b/tokio/src/loom/std/atomic_ptr.rs
@@ -11,10 +11,6 @@ impl<T> AtomicPtr<T> {
         let inner = std::sync::atomic::AtomicPtr::new(ptr);
         AtomicPtr { inner }
     }
-
-    pub(crate) fn with_mut<R>(&mut self, f: impl FnOnce(&mut *mut T) -> R) -> R {
-        f(self.inner.get_mut())
-    }
 }
 
 impl<T> Deref for AtomicPtr<T> {
diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs
index 698e908ec59..f5bcc1b9418 100644
--- a/tokio/src/sync/batch_semaphore.rs
+++ b/tokio/src/sync/batch_semaphore.rs
@@ -186,7 +186,7 @@ impl Semaphore {
 
     /// Release `rem` permits to the semaphore's wait list, starting from the
     /// end of the queue.
-    /// 
+    ///
     /// If `rem` exceeds the number of permits needed by the wait list, the
     /// remainder are assigned back to the semaphore.
    fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) {
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 9873dcb7214..0c8716f7795 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -109,12 +109,15 @@
 //! }
 
 use crate::loom::cell::UnsafeCell;
-use crate::loom::future::AtomicWaker;
-use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
+use crate::loom::sync::atomic::AtomicUsize;
 use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
+use crate::util::linked_list::{self, LinkedList};
 
 use std::fmt;
-use std::ptr;
+use std::future::Future;
+use std::marker::PhantomPinned;
+use std::pin::Pin;
+use std::ptr::NonNull;
 use std::sync::atomic::Ordering::SeqCst;
 use std::task::{Context, Poll, Waker};
 use std::usize;
@@ -192,8 +195,8 @@ pub struct Receiver<T> {
     /// Next position to read from
     next: u64,
 
-    /// Waiter state
-    wait: Arc<WaitNode>,
+    /// Used to support the deprecated `poll_recv` fn
+    waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
 }
 
 /// Error returned by [`Sender::send`][Sender::send].
@@ -251,12 +254,9 @@ struct Shared<T> {
     /// Mask a position -> index
     mask: usize,
 
-    /// Tail of the queue
+    /// Tail of the queue. Includes the rx wait list.
     tail: Mutex<Tail>,
 
-    /// Stack of pending waiters
-    wait_stack: AtomicPtr<WaitNode>,
-
     /// Number of outstanding Sender handles
     num_tx: AtomicUsize,
 }
@@ -271,6 +271,9 @@ struct Tail {
 
     /// True if the channel is closed
     closed: bool,
+
+    /// Receivers waiting for a value
+    waiters: LinkedList<Waiter>,
 }
 
 /// Slot in the buffer
@@ -296,23 +299,59 @@ struct Slot<T> {
     val: UnsafeCell<Option<T>>,
 }
 
-/// Tracks a waiting receiver
-#[derive(Debug)]
-struct WaitNode {
-    /// `true` if queued
-    queued: AtomicBool,
+/// An entry in the wait queue
+struct Waiter {
+    /// True if queued
+    queued: bool,
+
+    /// Task waiting on the broadcast channel.
+    waker: Option<Waker>,
 
-    /// Task to wake when a permit is made available.
-    waker: AtomicWaker,
+    /// Intrusive linked-list pointers.
+    pointers: linked_list::Pointers<Waiter>,
 
-    /// Next pointer in the stack of waiting senders.
-    next: UnsafeCell<*const WaitNode>,
+    /// Should not be `Unpin`.
+    _p: PhantomPinned,
 }
 
 struct RecvGuard<'a, T> {
     slot: RwLockReadGuard<'a, Slot<T>>,
 }
 
+/// A future that completes when the next value is received
+struct Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    /// Receiver being waited on
+    receiver: R,
+
+    /// Entry in the waiter `LinkedList`
+    waiter: UnsafeCell<Waiter>,
+
+    _p: std::marker::PhantomData<T>,
+}
+
+/// `AsMut<T>` is not implemented for `T` (coherence). Explicitly implementing
+/// `AsMut<Receiver<T>>` for `Receiver<T>` would make the impl part of the
+/// receiver's public API. Instead, `Borrow` is used internally to bridge the
+/// gap.
+struct Borrow<T>(T);
+
+impl<T> AsMut<Receiver<T>> for Borrow<Receiver<T>> {
+    fn as_mut(&mut self) -> &mut Receiver<T> {
+        &mut self.0
+    }
+}
+
+impl<'a, T> AsMut<Receiver<T>> for Borrow<&'a mut Receiver<T>> {
+    fn as_mut(&mut self) -> &mut Receiver<T> {
+        &mut *self.0
+    }
+}
+
+unsafe impl<R: AsMut<Receiver<T>> + Send, T: Send> Send for Recv<R, T> {}
+unsafe impl<R: AsMut<Receiver<T>> + Sync, T: Send> Sync for Recv<R, T> {}
+
 /// Max number of receivers. Reserve space to lock.
 const MAX_RECEIVERS: usize = usize::MAX >> 2;
@@ -386,19 +425,15 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
             pos: 0,
             rx_cnt: 1,
             closed: false,
+            waiters: LinkedList::new(),
         }),
-        wait_stack: AtomicPtr::new(ptr::null_mut()),
         num_tx: AtomicUsize::new(1),
     });
 
     let rx = Receiver {
         shared: shared.clone(),
         next: 0,
-        wait: Arc::new(WaitNode {
-            queued: AtomicBool::new(false),
-            waker: AtomicWaker::new(),
-            next: UnsafeCell::new(ptr::null()),
-        }),
+        waiter: None,
     };
 
     let tx = Sender { shared };
@@ -508,11 +543,7 @@ impl<T> Sender<T> {
         Receiver {
             shared,
             next,
-            wait: Arc::new(WaitNode {
-                queued: AtomicBool::new(false),
-                waker: AtomicWaker::new(),
-                next: UnsafeCell::new(ptr::null()),
-            }),
+            waiter: None,
         }
     }
 
@@ -589,34 +620,31 @@ impl<T> Sender<T> {
             slot.val.with_mut(|ptr| unsafe { *ptr = value });
         }
 
-        // Release the slot lock before the tail lock
+        // Release the slot lock before notifying the receivers.
         drop(slot);
 
+        tail.notify_rx();
+
         // Release the mutex. This must happen after the slot lock is released,
         // otherwise the writer lock bit could be cleared while another thread
         // is in the critical section.
         drop(tail);
 
-        // Notify waiting receivers
-        self.notify_rx();
-
         Ok(rem)
     }
+}
 
-    fn notify_rx(&self) {
-        let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode;
-
-        while !curr.is_null() {
-            let waiter = unsafe { Arc::from_raw(curr) };
-
-            // Update `curr` before toggling `queued` and waking
-            curr = waiter.next.with(|ptr| unsafe { *ptr });
+impl Tail {
+    fn notify_rx(&mut self) {
+        while let Some(mut waiter) = self.waiters.pop_back() {
+            // Safety: `waiters` lock is still held.
+            let waiter = unsafe { waiter.as_mut() };
 
-            // Unset queued
-            waiter.queued.store(false, SeqCst);
+            assert!(waiter.queued);
+            waiter.queued = false;
 
-            // Wake
-            waiter.waker.wake();
+            let waker = waiter.waker.take().unwrap();
+            waker.wake();
         }
     }
 }
@@ -640,15 +668,21 @@ impl<T> Drop for Sender<T> {
 
 impl<T> Receiver<T> {
     /// Locks the next value if there is one.
-    fn recv_ref(&mut self) -> Result<RecvGuard<'_, T>, TryRecvError> {
+    fn recv_ref(
+        &mut self,
+        waiter: Option<(&UnsafeCell<Waiter>, &Waker)>,
+    ) -> Result<RecvGuard<'_, T>, TryRecvError> {
         let idx = (self.next & self.shared.mask as u64) as usize;
 
         // The slot holding the next value to read
         let mut slot = self.shared.buffer[idx].read().unwrap();
 
         if slot.pos != self.next {
-            // The receiver has read all current values in the channel
-            if slot.pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
+            let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+            // The receiver has read all current values in the channel and there
+            // is no waiter to register
+            if waiter.is_none() && next_pos == self.next {
                 return Err(TryRecvError::Empty);
             }
 
@@ -661,35 +695,83 @@ impl<T> Receiver<T> {
             // the slot lock.
             drop(slot);
 
-            let tail = self.shared.tail.lock().unwrap();
+            let mut tail = self.shared.tail.lock().unwrap();
 
             // Acquire slot lock again
             slot = self.shared.buffer[idx].read().unwrap();
 
-            // `tail.pos` points to the slot that the **next** send writes to. If
-            // the channel is closed, the previous slot is the oldest value.
-            let mut adjust = 0;
-            if tail.closed {
-                adjust = 1
-            }
-            let next = tail
-                .pos
-                .wrapping_sub(self.shared.buffer.len() as u64 + adjust);
+            // Make sure the position did not change. This could happen in the
+            // unlikely event that the buffer wraps around between dropping the
+            // read lock and acquiring the tail lock.
+            if slot.pos != self.next {
+                let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+                if next_pos == self.next {
+                    // Store the waker
+                    if let Some((waiter, waker)) = waiter {
+                        // Safety: called while locked.
+                        unsafe {
+                            // Only queue if not already queued
+                            waiter.with_mut(|ptr| {
+                                // If there is no waker **or** if the currently
+                                // stored waker references a **different** task,
+                                // track the task's waker to be notified on
+                                // receipt of a new value.
+                                match (*ptr).waker {
+                                    Some(ref w) if w.will_wake(waker) => {}
+                                    _ => {
+                                        (*ptr).waker = Some(waker.clone());
+                                    }
+                                }
+
+                                if !(*ptr).queued {
+                                    (*ptr).queued = true;
+                                    tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
+                                }
+                            });
+                        }
+                    }
+
+                    return Err(TryRecvError::Empty);
+                }
 
-            let missed = next.wrapping_sub(self.next);
+                // At this point, the receiver has lagged behind the sender by
+                // more than the channel capacity. The receiver will attempt to
+                // catch up by skipping dropped messages and setting the
+                // internal cursor to the **oldest** message stored by the
+                // channel.
+                //
+                // However, finding the oldest position is a bit more
+                // complicated than `tail-position - buffer-size`. When
+                // the channel is closed, the tail position is incremented to
+                // signal a new `None` message, but `None` is not stored in the
+                // channel itself (see issue #2425 for why).
+                //
+                // To account for this, if the channel is closed, the tail
+                // position is decremented by `buffer-size + 1`.
+                let mut adjust = 0;
+                if tail.closed {
+                    adjust = 1
+                }
+                let next = tail
+                    .pos
+                    .wrapping_sub(self.shared.buffer.len() as u64 + adjust);
 
-            drop(tail);
+                let missed = next.wrapping_sub(self.next);
 
-            // The receiver is slow but no values have been missed
-            if missed == 0 {
-                self.next = self.next.wrapping_add(1);
+                drop(tail);
 
-                return Ok(RecvGuard { slot });
-            }
+                // The receiver is slow but no values have been missed
+                if missed == 0 {
+                    self.next = self.next.wrapping_add(1);
 
-            self.next = next;
+                    return Ok(RecvGuard { slot });
+                }
+
+                self.next = next;
 
-            return Err(TryRecvError::Lagged(missed));
+                return Err(TryRecvError::Lagged(missed));
+            }
         }
 
         self.next = self.next.wrapping_add(1);
@@ -746,22 +828,59 @@ where
     ///     }
     /// }
     /// ```
     pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
-        let guard = self.recv_ref()?;
+        let guard = self.recv_ref(None)?;
         guard.clone_value().ok_or(TryRecvError::Closed)
     }
 
-    #[doc(hidden)] // TODO: document
+    #[doc(hidden)]
+    #[deprecated(since = "0.2.21", note = "use async fn recv()")]
     pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
-        if let Some(value) = ok_empty(self.try_recv())? {
-            return Poll::Ready(Ok(value));
-        }
+        use Poll::{Pending, Ready};
+
+        // The borrow checker prohibits calling `self.recv_ref` while passing in
+        // a mutable ref to a field (as it should). To work around this,
+        // `waiter` is first *removed* from `self` and then `recv_ref` is called.
+        //
+        // However, for safety, we must ensure that `waiter` is **not** dropped.
+        // It could be contained in the intrusive linked list. The `Receiver`
+        // drop implementation handles cleanup.
+        //
+        // The guard pattern is used to ensure that, on return, even on panic,
+        // the waiter node is replaced on `self`.
+
+        struct Guard<'a, T> {
+            waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
+            receiver: &'a mut Receiver<T>,
+        }
 
-        self.register_waker(cx.waker());
+        impl<'a, T> Drop for Guard<'a, T> {
+            fn drop(&mut self) {
+                self.receiver.waiter = self.waiter.take();
+            }
+        }
 
-        if let Some(value) = ok_empty(self.try_recv())? {
-            Poll::Ready(Ok(value))
-        } else {
-            Poll::Pending
-        }
+        let waiter = self.waiter.take().or_else(|| {
+            Some(Box::pin(UnsafeCell::new(Waiter {
+                queued: false,
+                waker: None,
+                pointers: linked_list::Pointers::new(),
+                _p: PhantomPinned,
+            })))
+        });
+
+        let guard = Guard {
+            waiter,
+            receiver: self,
+        };
+        let res = guard
+            .receiver
+            .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker())));
+
+        match res {
+            Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)),
+            Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)),
+            Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))),
+            Err(TryRecvError::Empty) => Pending,
        }
     }
@@ -830,44 +949,14 @@ where
     ///     assert_eq!(30, rx.recv().await.unwrap());
     /// }
     pub async fn recv(&mut self) -> Result<T, RecvError> {
-        use crate::future::poll_fn;
-
-        poll_fn(|cx| self.poll_recv(cx)).await
-    }
-
-    fn register_waker(&self, cx: &Waker) {
-        self.wait.waker.register_by_ref(cx);
-
-        if !self.wait.queued.load(SeqCst) {
-            // Set `queued` before queuing.
-            self.wait.queued.store(true, SeqCst);
-
-            let mut curr = self.shared.wait_stack.load(SeqCst);
-
-            // The ref count is decremented in `notify_rx` when all nodes are
-            // removed from the waiter stack.
-            let node = Arc::into_raw(self.wait.clone()) as *mut _;
-
-            loop {
-                // Safety: `queued == false` means the caller has exclusive
-                // access to `self.wait.next`.
-                self.wait.next.with_mut(|ptr| unsafe { *ptr = curr });
-
-                let res = self
-                    .shared
-                    .wait_stack
-                    .compare_exchange(curr, node, SeqCst, SeqCst);
-
-                match res {
-                    Ok(_) => return,
-                    Err(actual) => curr = actual,
-                }
-            }
-        }
+        let fut = Recv::<_, T>::new(Borrow(self));
+        fut.await
     }
 }
 
 #[cfg(feature = "stream")]
+#[doc(hidden)]
+#[deprecated(since = "0.2.21", note = "use `into_stream()`")]
 impl<T> crate::stream::Stream for Receiver<T>
 where
     T: Clone,
@@ -878,6 +967,7 @@ where
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
+        #[allow(deprecated)]
         self.poll_recv(cx).map(|v| match v {
             Ok(v) => Some(Ok(v)),
             lag @ Err(RecvError::Lagged(_)) => Some(lag),
@@ -890,13 +980,30 @@ impl<T> Drop for Receiver<T> {
     fn drop(&mut self) {
         let mut tail = self.shared.tail.lock().unwrap();
 
+        if let Some(waiter) = &self.waiter {
+            // safety: tail lock is held
+            let queued = waiter.with(|ptr| unsafe { (*ptr).queued });
+
+            if queued {
+                // Remove the node
+                //
+                // safety: tail lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
+            }
+        }
+
         tail.rx_cnt -= 1;
         let until = tail.pos;
 
         drop(tail);
 
         while self.next != until {
-            match self.recv_ref() {
+            match self.recv_ref(None) {
                 Ok(_) => {}
                 // The channel is closed
                 Err(TryRecvError::Closed) => break,
@@ -909,18 +1016,170 @@
         }
     }
 }
 
-impl<T> Drop for Shared<T> {
-    fn drop(&mut self) {
-        // Clear the wait stack
-        let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode);
-
-        while !curr.is_null() {
-            let waiter = unsafe { Arc::from_raw(curr) };
-            curr = waiter.next.with(|ptr| unsafe { *ptr });
-        }
-    }
-}
+impl<R, T> Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    fn new(receiver: R) -> Recv<R, T> {
+        Recv {
+            receiver,
+            waiter: UnsafeCell::new(Waiter {
+                queued: false,
+                waker: None,
+                pointers: linked_list::Pointers::new(),
+                _p: PhantomPinned,
+            }),
+            _p: std::marker::PhantomData,
+        }
+    }
+
+    /// A custom `project` implementation is used in place of `pin-project-lite`
+    /// as a custom drop implementation is needed.
+    fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) {
+        unsafe {
+            // Safety: Receiver is Unpin
+            is_unpin::<&mut Receiver<T>>();
+
+            let me = self.get_unchecked_mut();
+            (me.receiver.as_mut(), &me.waiter)
+        }
+    }
+}
+
+impl<R, T> Future for Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+    T: Clone,
+{
+    type Output = Result<T, RecvError>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let (receiver, waiter) = self.project();
+
+        let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+            Ok(value) => value,
+            Err(TryRecvError::Empty) => return Poll::Pending,
+            Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
+            Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)),
+        };
+
+        Poll::Ready(guard.clone_value().ok_or(RecvError::Closed))
+    }
+}
+
+cfg_stream! {
+    use futures_core::Stream;
+
+    impl<T: Clone> Receiver<T> {
+        /// Convert the receiver into a `Stream`.
+        ///
+        /// The conversion allows using `Receiver` with APIs that require stream
+        /// values.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// use tokio::stream::StreamExt;
+        /// use tokio::sync::broadcast;
+        ///
+        /// #[tokio::main]
+        /// async fn main() {
+        ///     let (tx, rx) = broadcast::channel(128);
+        ///
+        ///     tokio::spawn(async move {
+        ///         for i in 0..10_i32 {
+        ///             tx.send(i).unwrap();
+        ///         }
+        ///     });
+        ///
+        ///     // Streams must be pinned to iterate.
+        ///     tokio::pin! {
+        ///         let stream = rx
+        ///             .into_stream()
+        ///             .filter(Result::is_ok)
+        ///             .map(Result::unwrap)
+        ///             .filter(|v| v % 2 == 0)
+        ///             .map(|v| v + 1);
+        ///     }
+        ///
+        ///     while let Some(i) = stream.next().await {
+        ///         println!("{}", i);
+        ///     }
+        /// }
+        /// ```
+        pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
+            Recv::new(Borrow(self))
+        }
+    }
+
+    impl<R, T> Stream for Recv<R, T>
+    where
+        R: AsMut<Receiver<T>>,
+        T: Clone,
+    {
+        type Item = Result<T, RecvError>;
+
+        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+            let (receiver, waiter) = self.project();
+
+            let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+                Ok(value) => value,
+                Err(TryRecvError::Empty) => return Poll::Pending,
+                Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))),
+                Err(TryRecvError::Closed) => return Poll::Ready(None),
+            };
+
+            Poll::Ready(guard.clone_value().map(Ok))
+        }
+    }
+}
+
+impl<R, T> Drop for Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    fn drop(&mut self) {
+        // Acquire the tail lock. This is required for safety before accessing
+        // the waiter node.
+        let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap();
+
+        // safety: tail lock is held
+        let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
+
+        if queued {
+            // Remove the node
+            //
+            // safety: tail lock is held and the wait node is verified to be in
+            // the list.
+            unsafe {
+                self.waiter.with_mut(|ptr| {
+                    tail.waiters.remove((&mut *ptr).into());
+                });
+            }
+        }
+    }
+}
+
+/// # Safety
+///
+/// `Waiter` is forced to be !Unpin.
+unsafe impl linked_list::Link for Waiter {
+    type Handle = NonNull<Waiter>;
+    type Target = Waiter;
+
+    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
+        *handle
+    }
+
+    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
+        ptr
+    }
+
+    unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
+        NonNull::from(&mut target.as_mut().pointers)
+    }
+}
+
 impl<T> fmt::Debug for Sender<T> {
     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(fmt, "broadcast::Sender")
@@ -952,15 +1211,6 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
     }
 }
 
-fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> {
-    match res {
-        Ok(value) => Ok(Some(value)),
-        Err(TryRecvError::Empty) => Ok(None),
-        Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
-        Err(TryRecvError::Closed) => Err(RecvError::Closed),
-    }
-}
-
 impl fmt::Display for RecvError {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
@@ -983,3 +1233,5 @@ impl fmt::Display for TryRecvError {
 }
 
 impl std::error::Error for TryRecvError {}
+
+fn is_unpin<T: Unpin>() {}
diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs
index 4d756f91975..e37695b37d9 100644
--- a/tokio/tests/sync_broadcast.rs
+++ b/tokio/tests/sync_broadcast.rs
@@ -90,10 +90,13 @@ fn send_two_recv() {
 }
 
 #[tokio::test]
-async fn send_recv_stream() {
+async fn send_recv_into_stream_ready() {
     use tokio::stream::StreamExt;
 
-    let (tx, mut rx) = broadcast::channel::<i32>(8);
+    let (tx, rx) = broadcast::channel::<i32>(8);
+    tokio::pin! {
+        let rx = rx.into_stream();
+    }
 
     assert_ok!(tx.send(1));
     assert_ok!(tx.send(2));
@@ -106,6 +109,26 @@
     assert_eq!(None, rx.next().await);
 }
 
+#[tokio::test]
+async fn send_recv_into_stream_pending() {
+    use tokio::stream::StreamExt;
+
+    let (tx, rx) = broadcast::channel::<i32>(8);
+
+    tokio::pin! {
+        let rx = rx.into_stream();
+    }
+
+    let mut recv = task::spawn(rx.next());
+    assert_pending!(recv.poll());
+
+    assert_ok!(tx.send(1));
+
+    assert!(recv.is_woken());
+    let val = assert_ready!(recv.poll());
+    assert_eq!(val, Some(Ok(1)));
+}
+
 #[test]
 fn send_recv_bounded() {
     let (tx, mut rx) = broadcast::channel(16);
@@ -160,6 +183,23 @@ fn send_two_recv_bounded() {
     assert_eq!(val2, "world");
 }
 
+#[test]
+fn change_tasks() {
+    let (tx, mut rx) = broadcast::channel(1);
+
+    let mut recv = Box::pin(rx.recv());
+
+    let mut task1 = task::spawn(&mut recv);
+    assert_pending!(task1.poll());
+
+    let mut task2 = task::spawn(&mut recv);
+    assert_pending!(task2.poll());
+
+    tx.send("hello").unwrap();
+
+    assert!(task2.is_woken());
+}
+
 #[test]
 fn send_slow_rx() {
     let (tx, mut rx1) = broadcast::channel(16);
@@ -451,6 +491,39 @@ fn lagging_receiver_recovers_after_wrap_open() {
     assert_empty!(rx);
 }
 
+#[tokio::test]
+async fn send_recv_stream_ready_deprecated() {
+    use tokio::stream::StreamExt;
+
+    let (tx, mut rx) = broadcast::channel::<i32>(8);
+
+    assert_ok!(tx.send(1));
+    assert_ok!(tx.send(2));
+
+    assert_eq!(Some(Ok(1)), rx.next().await);
+    assert_eq!(Some(Ok(2)), rx.next().await);
+
+    drop(tx);
+
+    assert_eq!(None, rx.next().await);
+}
+
+#[tokio::test]
+async fn send_recv_stream_pending_deprecated() {
+    use tokio::stream::StreamExt;
+
+    let (tx, mut rx) = broadcast::channel::<i32>(8);
+
+    let mut recv = task::spawn(rx.next());
+    assert_pending!(recv.poll());
+
+    assert_ok!(tx.send(1));
+
+    assert!(recv.is_woken());
+    let val = assert_ready!(recv.poll());
+    assert_eq!(val, Some(Ok(1)));
+}
+
 fn is_closed(err: broadcast::RecvError) -> bool {
     match err {
         broadcast::RecvError::Closed => true,
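
Supplementary usage sketch (not part of the patch): the snippet below exercises the `into_stream()` API this change adds, mirroring the doc example in `broadcast.rs`. It assumes tokio 0.2.21 with the `sync`, `stream`, `macros`, and `rt-threaded` features enabled; the variable names are illustrative.

use tokio::stream::StreamExt;
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast::channel::<i32>(16);

    tokio::spawn(async move {
        for i in 0..3 {
            // `send` only fails once no receivers remain.
            tx.send(i).unwrap();
        }
        // Dropping `tx` closes the channel, which ends the stream.
    });

    // Streams must be pinned before they can be iterated.
    tokio::pin! {
        let stream = rx.into_stream();
    }

    while let Some(item) = stream.next().await {
        match item {
            Ok(v) => println!("got {}", v),
            // A slow receiver observes how many values it skipped.
            Err(broadcast::RecvError::Lagged(n)) => println!("lagged by {}", n),
            Err(broadcast::RecvError::Closed) => break,
        }
    }
}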
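
The lag-recovery comment added to `recv_ref` implies the following observable behavior, which the patch preserves from the old implementation. A sketch, assuming a capacity of 2 (already a power of two, so it is not rounded up): a third send overwrites the oldest value, the next `recv` reports the number of missed messages, and subsequent calls resume at the oldest value still buffered.

use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel(2);

    // Overflow the channel: value 1 is dropped to make room for 3.
    tx.send(1).unwrap();
    tx.send(2).unwrap();
    tx.send(3).unwrap();

    // The receiver first learns how many values it missed...
    assert!(matches!(
        rx.recv().await,
        Err(broadcast::RecvError::Lagged(1))
    ));

    // ...then resumes at the oldest value still in the buffer.
    assert_eq!(rx.recv().await.unwrap(), 2);
    assert_eq!(rx.recv().await.unwrap(), 3);
}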
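
The `Borrow`/`AsMut` bridge in the patch exists so one `Recv` future type can serve both `recv(&mut self)` (borrowed receiver) and `into_stream(self)` (owned receiver) without adding `AsMut` impls to `Receiver`'s public API. The same pattern in isolation, using hypothetical names (`Target`, `bump`):

struct Target(u32);

// Private newtype, so the `AsMut` impls stay out of `Target`'s public API.
struct Borrow<T>(T);

impl AsMut<Target> for Borrow<Target> {
    fn as_mut(&mut self) -> &mut Target {
        &mut self.0
    }
}

impl<'a> AsMut<Target> for Borrow<&'a mut Target> {
    fn as_mut(&mut self) -> &mut Target {
        &mut *self.0
    }
}

// One operation, generic over owned or borrowed access.
fn bump<R: AsMut<Target>>(mut r: R) -> u32 {
    let t = r.as_mut();
    t.0 += 1;
    t.0
}

fn main() {
    let mut t = Target(0);
    assert_eq!(bump(Borrow(&mut t)), 1); // borrowed, like `recv()`
    assert_eq!(bump(Borrow(t)), 2); // owned, like `into_stream()`
}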