From 0e46df7f1e18f65d7440c98f4fb4fd42e2d85f76 Mon Sep 17 00:00:00 2001 From: Tim Vilgot Mikael Fredenberg Date: Sat, 1 Feb 2025 15:27:48 +0000 Subject: [PATCH] feat(http-ratelimiting)!: rewrite crate Support buckets and global limit. Work in progress... --- examples/http-proxy.rs | 2 +- twilight-http-ratelimiting/Cargo.toml | 4 +- .../src/in_memory/bucket.rs | 324 --------- .../src/in_memory/mod.rs | 161 ----- twilight-http-ratelimiting/src/lib.rs | 669 ++++++++++++++---- twilight-http-ratelimiting/src/request.rs | 120 ++++ twilight-http-ratelimiting/src/ticket.rs | 175 ----- twilight-http/src/client/builder.rs | 23 +- twilight-http/src/client/mod.rs | 16 +- twilight-http/src/response/future.rs | 29 +- twilight-http/src/routing.rs | 2 +- 11 files changed, 690 insertions(+), 835 deletions(-) delete mode 100644 twilight-http-ratelimiting/src/in_memory/bucket.rs delete mode 100644 twilight-http-ratelimiting/src/in_memory/mod.rs delete mode 100644 twilight-http-ratelimiting/src/ticket.rs diff --git a/examples/http-proxy.rs b/examples/http-proxy.rs index 1fccea21a64..4bde6c3b344 100644 --- a/examples/http-proxy.rs +++ b/examples/http-proxy.rs @@ -10,7 +10,7 @@ async fn main() -> anyhow::Result<()> { let client = Client::builder() .proxy("localhost:3000".to_owned(), true) - .ratelimiter(None) + .ratelimiter(false) .build(); let channel_id = Id::new(620_980_184_606_048_278); diff --git a/twilight-http-ratelimiting/Cargo.toml b/twilight-http-ratelimiting/Cargo.toml index 0bab3e93d71..2f2c602b975 100644 --- a/twilight-http-ratelimiting/Cargo.toml +++ b/twilight-http-ratelimiting/Cargo.toml @@ -14,7 +14,9 @@ rust-version.workspace = true version = "0.16.0" [dependencies] -tokio = { version = "1", default-features = false, features = ["rt", "sync", "time"] } +hashbrown = { default-features = false, version = "0.15"} +tokio = { default-features = false, features = ["macros", "rt", "sync", "time"], version = "1" } +tokio-util = { default-features = false, features = ["time"], version = "0.7.11" } tracing = { default-features = false, features = ["std", "attributes"], version = "0.1.23" } [dev-dependencies] diff --git a/twilight-http-ratelimiting/src/in_memory/bucket.rs b/twilight-http-ratelimiting/src/in_memory/bucket.rs deleted file mode 100644 index f02041d7a0e..00000000000 --- a/twilight-http-ratelimiting/src/in_memory/bucket.rs +++ /dev/null @@ -1,324 +0,0 @@ -//! [`Bucket`] management used by the [`super::InMemoryRatelimiter`] internally. -//! Each bucket has an associated [`BucketQueue`] to queue an API request, which is -//! consumed by the [`BucketQueueTask`] that manages the ratelimit for the bucket -//! and respects the global ratelimit. - -use super::GlobalLockPair; -use crate::{headers::RatelimitHeaders, request::Path, ticket::TicketNotifier}; -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, Mutex, - }, - time::{Duration, Instant}, -}; -use tokio::{ - sync::{ - mpsc::{self, UnboundedReceiver, UnboundedSender}, - Mutex as AsyncMutex, - }, - time::{sleep, timeout}, -}; - -/// Time remaining until a bucket will reset. -#[derive(Clone, Debug)] -pub enum TimeRemaining { - /// Bucket has already reset. - Finished, - /// Bucket's ratelimit refresh countdown has not started yet. - NotStarted, - /// Amount of time until the bucket resets. - Some(Duration), -} - -/// Ratelimit information for a bucket used in the [`super::InMemoryRatelimiter`]. -/// -/// A generic version not specific to this ratelimiter is [`crate::Bucket`]. 
-#[derive(Debug)] -pub struct Bucket { - /// Total number of tickets allotted in a cycle. - pub limit: AtomicU64, - /// Path this ratelimit applies to. - // This is dead code, but it is useful for debugging. - #[allow(dead_code)] - pub path: Path, - /// Queue associated with this bucket. - pub queue: BucketQueue, - /// Number of tickets remaining. - pub remaining: AtomicU64, - /// Duration after the [`Self::started_at`] time the bucket will refresh. - pub reset_after: AtomicU64, - /// When the bucket's ratelimit refresh countdown started. - pub started_at: Mutex>, -} - -impl Bucket { - /// Create a new bucket for the specified [`Path`]. - pub fn new(path: Path) -> Self { - Self { - limit: AtomicU64::new(u64::MAX), - path, - queue: BucketQueue::default(), - remaining: AtomicU64::new(u64::MAX), - reset_after: AtomicU64::new(u64::MAX), - started_at: Mutex::new(None), - } - } - - /// Total number of tickets allotted in a cycle. - pub fn limit(&self) -> u64 { - self.limit.load(Ordering::Relaxed) - } - - /// Number of tickets remaining. - pub fn remaining(&self) -> u64 { - self.remaining.load(Ordering::Relaxed) - } - - /// Duration after the [`started_at`] time the bucket will refresh. - /// - /// [`started_at`]: Self::started_at - pub fn reset_after(&self) -> u64 { - self.reset_after.load(Ordering::Relaxed) - } - - /// Time remaining until this bucket will reset. - pub fn time_remaining(&self) -> TimeRemaining { - let reset_after = self.reset_after(); - let maybe_started_at = *self.started_at.lock().expect("bucket poisoned"); - - let Some(started_at) = maybe_started_at else { - return TimeRemaining::NotStarted; - }; - - let elapsed = started_at.elapsed(); - - if elapsed > Duration::from_millis(reset_after) { - return TimeRemaining::Finished; - } - - TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed) - } - - /// Try to reset this bucket's [`started_at`] value if it has finished. - /// - /// Returns whether resetting was possible. - /// - /// [`started_at`]: Self::started_at - pub fn try_reset(&self) -> bool { - if self.started_at.lock().expect("bucket poisoned").is_none() { - return false; - } - - if let TimeRemaining::Finished = self.time_remaining() { - self.remaining.store(self.limit(), Ordering::Relaxed); - *self.started_at.lock().expect("bucket poisoned") = None; - - true - } else { - false - } - } - - /// Update this bucket's ratelimit data after a request has been made. - pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) { - let bucket_limit = self.limit(); - - { - let mut started_at = self.started_at.lock().expect("bucket poisoned"); - - if started_at.is_none() { - started_at.replace(Instant::now()); - } - } - - if let Some((limit, remaining, reset_after)) = ratelimits { - if bucket_limit != limit && bucket_limit == u64::MAX { - self.reset_after.store(reset_after, Ordering::SeqCst); - self.limit.store(limit, Ordering::SeqCst); - } - - self.remaining.store(remaining, Ordering::Relaxed); - } else { - self.remaining.fetch_sub(1, Ordering::Relaxed); - } - } -} - -/// Queue of ratelimit requests for a bucket. -#[derive(Debug)] -pub struct BucketQueue { - /// Receiver for the ratelimit requests. - rx: AsyncMutex>, - /// Sender for the ratelimit requests. - tx: UnboundedSender, -} - -impl BucketQueue { - /// Add a new ratelimit request to the queue. - pub fn push(&self, tx: TicketNotifier) { - let _sent = self.tx.send(tx); - } - - /// Receive the first incoming ratelimit request. 
- pub async fn pop(&self, timeout_duration: Duration) -> Option { - let mut rx = self.rx.lock().await; - - timeout(timeout_duration, rx.recv()).await.ok().flatten() - } -} - -impl Default for BucketQueue { - fn default() -> Self { - let (tx, rx) = mpsc::unbounded_channel(); - - Self { - rx: AsyncMutex::new(rx), - tx, - } - } -} - -/// A background task that handles ratelimit requests to a [`Bucket`] -/// and processes them in order, keeping track of both the global and -/// the [`Path`]-specific ratelimits. -pub(super) struct BucketQueueTask { - /// The [`Bucket`] managed by this task. - bucket: Arc, - /// All buckets managed by the associated [`super::InMemoryRatelimiter`]. - buckets: Arc>>>, - /// Global ratelimit data. - global: Arc, - /// The [`Path`] this [`Bucket`] belongs to. - path: Path, -} - -impl BucketQueueTask { - /// Timeout to wait for response headers after initiating a request. - const WAIT: Duration = Duration::from_secs(10); - - /// Create a new task to manage the ratelimit for a [`Bucket`]. - pub const fn new( - bucket: Arc, - buckets: Arc>>>, - global: Arc, - path: Path, - ) -> Self { - Self { - bucket, - buckets, - global, - path, - } - } - - /// Process incoming ratelimit requests to this bucket and update the state - /// based on received [`RatelimitHeaders`]. - #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))] - pub async fn run(self) { - while let Some(queue_tx) = self.next().await { - if self.global.is_locked() { - drop(self.global.0.lock().await); - } - - let Some(ticket_headers) = queue_tx.available() else { - continue; - }; - - tracing::debug!("starting to wait for response headers"); - - match timeout(Self::WAIT, ticket_headers).await { - Ok(Ok(Some(headers))) => self.handle_headers(&headers).await, - Ok(Ok(None)) => { - tracing::debug!("request aborted"); - } - Ok(Err(_)) => { - tracing::debug!("ticket channel closed"); - } - Err(_) => { - tracing::debug!("receiver timed out"); - } - } - } - - tracing::debug!("bucket appears finished, removing"); - - self.buckets - .lock() - .expect("ratelimit buckets poisoned") - .remove(&self.path); - } - - /// Update the bucket's ratelimit state. - async fn handle_headers(&self, headers: &RatelimitHeaders) { - let ratelimits = match headers { - RatelimitHeaders::Global(global) => { - self.lock_global(Duration::from_secs(global.retry_after())) - .await; - - None - } - RatelimitHeaders::None => return, - RatelimitHeaders::Present(present) => { - Some((present.limit(), present.remaining(), present.reset_after())) - } - }; - - tracing::debug!(path=?self.path, "updating bucket"); - self.bucket.update(ratelimits); - } - - /// Lock the global ratelimit for a specified duration. - async fn lock_global(&self, wait: Duration) { - tracing::debug!(path=?self.path, "request got global ratelimited"); - self.global.lock(); - let lock = self.global.0.lock().await; - sleep(wait).await; - self.global.unlock(); - - drop(lock); - } - - /// Get the next [`TicketNotifier`] in the queue. - async fn next(&self) -> Option { - tracing::debug!(path=?self.path, "starting to get next in queue"); - - self.wait_if_needed().await; - - self.bucket.queue.pop(Self::WAIT).await - } - - /// Wait for this bucket to refresh if it isn't ready yet. 
- #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))] - async fn wait_if_needed(&self) { - let wait = { - if self.bucket.remaining() > 0 { - return; - } - - tracing::debug!("0 tickets remaining, may have to wait"); - - match self.bucket.time_remaining() { - TimeRemaining::Finished => { - self.bucket.try_reset(); - - return; - } - TimeRemaining::NotStarted => return, - TimeRemaining::Some(dur) => dur, - } - }; - - tracing::debug!( - milliseconds=%wait.as_millis(), - "waiting for ratelimit to pass", - ); - - sleep(wait).await; - - tracing::debug!("done waiting for ratelimit to pass"); - - self.bucket.try_reset(); - } -} diff --git a/twilight-http-ratelimiting/src/in_memory/mod.rs b/twilight-http-ratelimiting/src/in_memory/mod.rs deleted file mode 100644 index e4d64a5ea7a..00000000000 --- a/twilight-http-ratelimiting/src/in_memory/mod.rs +++ /dev/null @@ -1,161 +0,0 @@ -//! In-memory based default [`Ratelimiter`] implementation used in `twilight-http`. - -mod bucket; - -use self::bucket::{Bucket, BucketQueueTask}; -use super::{ - ticket::{self, TicketNotifier}, - Bucket as InfoBucket, Ratelimiter, -}; -use crate::{ - request::Path, GetBucketFuture, GetTicketFuture, HasBucketFuture, IsGloballyLockedFuture, -}; -use std::{ - collections::hash_map::{Entry, HashMap}, - future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; -use tokio::sync::Mutex as AsyncMutex; - -/// Global lock. We use a pair to avoid actually locking the mutex every check. -/// This allows futures to only wait on the global lock when a global ratelimit -/// is in place by, in turn, waiting for a guard, and then each immediately -/// dropping it. -#[derive(Debug, Default)] -struct GlobalLockPair(AsyncMutex<()>, AtomicBool); - -impl GlobalLockPair { - /// Set the global ratelimit as exhausted. - pub fn lock(&self) { - self.1.store(true, Ordering::Release); - } - - /// Set the global ratelimit as no longer exhausted. - pub fn unlock(&self) { - self.1.store(false, Ordering::Release); - } - - /// Whether the global ratelimit is exhausted. - pub fn is_locked(&self) -> bool { - self.1.load(Ordering::Relaxed) - } -} - -/// Default ratelimiter implementation used in twilight that -/// stores ratelimit information in an in-memory mapping. -/// -/// This will meet most users' needs for simple ratelimiting, -/// but for multi-processed bots, consider either implementing -/// your own [`Ratelimiter`] that uses a shared storage backend -/// or use the [HTTP proxy]. -/// -/// [HTTP proxy]: https://twilight.rs/chapter_2_multi-serviced_approach.html#http-proxy-ratelimiting -#[derive(Clone, Debug, Default)] -pub struct InMemoryRatelimiter { - /// Mapping of [`Path`]s to their associated [`Bucket`]s. - buckets: Arc>>>, - /// Global ratelimit data. - global: Arc, -} - -impl InMemoryRatelimiter { - /// Create a new in-memory ratelimiter. - /// - /// This is used by HTTP client to queue requests in order to avoid - /// hitting the API's ratelimits. - #[must_use] - pub fn new() -> Self { - Self::default() - } - - /// Enqueue the [`TicketNotifier`] to the [`Path`]'s [`Bucket`]. - /// - /// Returns the new [`Bucket`] if none existed. 
- fn entry(&self, path: Path, tx: TicketNotifier) -> Option> { - let mut buckets = self.buckets.lock().expect("buckets poisoned"); - - match buckets.entry(path.clone()) { - Entry::Occupied(bucket) => { - tracing::debug!("got existing bucket: {path:?}"); - - bucket.get().queue.push(tx); - - tracing::debug!("added request into bucket queue: {path:?}"); - - None - } - Entry::Vacant(entry) => { - tracing::debug!("making new bucket for path: {path:?}"); - - let bucket = Bucket::new(path); - bucket.queue.push(tx); - - let bucket = Arc::new(bucket); - entry.insert(Arc::clone(&bucket)); - - Some(bucket) - } - } - } -} - -impl Ratelimiter for InMemoryRatelimiter { - fn bucket(&self, path: &Path) -> GetBucketFuture { - self.buckets - .lock() - .expect("buckets poisoned") - .get(path) - .map_or_else( - || Box::pin(future::ready(Ok(None))), - |bucket| { - let started_at = bucket.started_at.lock().expect("bucket poisoned"); - let reset_after = Duration::from_millis(bucket.reset_after()); - - Box::pin(future::ready(Ok(Some(InfoBucket::new( - bucket.limit(), - bucket.remaining(), - reset_after, - *started_at, - ))))) - }, - ) - } - - fn is_globally_locked(&self) -> IsGloballyLockedFuture { - Box::pin(future::ready(Ok(self.global.is_locked()))) - } - - fn has(&self, path: &Path) -> HasBucketFuture { - let has = self - .buckets - .lock() - .expect("buckets poisoned") - .contains_key(path); - - Box::pin(future::ready(Ok(has))) - } - - fn ticket(&self, path: Path) -> GetTicketFuture { - tracing::debug!("getting bucket for path: {path:?}"); - - let (tx, rx) = ticket::channel(); - - if let Some(bucket) = self.entry(path.clone(), tx) { - tokio::spawn( - BucketQueueTask::new( - bucket, - Arc::clone(&self.buckets), - Arc::clone(&self.global), - path, - ) - .run(), - ); - } - - Box::pin(future::ready(Ok(rx))) - } -} diff --git a/twilight-http-ratelimiting/src/lib.rs b/twilight-http-ratelimiting/src/lib.rs index bf28174ab7c..c9f13684176 100644 --- a/twilight-http-ratelimiting/src/lib.rs +++ b/twilight-http-ratelimiting/src/lib.rs @@ -6,171 +6,586 @@ missing_docs, unsafe_code )] -#![allow( - clippy::module_name_repetitions, - clippy::must_use_candidate, - clippy::unnecessary_wraps -)] +#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)] pub mod headers; -pub mod in_memory; pub mod request; -pub mod ticket; pub use self::{ headers::RatelimitHeaders, - in_memory::InMemoryRatelimiter, request::{Method, Path}, }; -use self::ticket::{TicketReceiver, TicketSender}; +use hashbrown::hash_table; use std::{ - error::Error, - fmt::Debug, - future::Future, - pin::Pin, - time::{Duration, Instant}, + collections::{hash_map::Entry, HashMap, VecDeque}, + future::{poll_fn, Future}, + hash::{BuildHasher as _, Hash, Hasher as _, RandomState}, + mem, + num::ParseIntError, + pin::{self, Pin}, + str::{self, FromStr}, + task::{Context, Poll}, +}; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinSet, + time::{self, Duration, Instant}, }; +use tokio_util::time::delay_queue::{DelayQueue, Key}; -/// A bucket containing ratelimiting information for a [`Path`]. -pub struct Bucket { - /// Total number of tickets allotted in a cycle. - limit: u64, - /// Number of tickets remaining. - remaining: u64, - /// Duration after [`Self::started_at`] time the bucket will refresh. - reset_after: Duration, - /// When the bucket's ratelimit refresh countdown started. 
- started_at: Option, +use crate::headers::{HeaderName, HeaderParsingError, HeaderParsingErrorType, HeaderType}; + +/// +#[derive(Debug)] +pub struct Headers { + bucket: Box, + limit: u16, + remaining: u16, + reset_after: u32, } -impl Bucket { - /// Create a representation of a ratelimiter bucket. +impl Headers { /// - /// Buckets are returned by ratelimiters via [`Ratelimiter::bucket`] method. - /// Its primary use is for informational purposes, including information - /// such as the [number of remaining tickets][`Self::limit`] or determining - /// how much time remains - /// [until the bucket interval resets][`Self::time_remaining`]. - #[must_use] - pub const fn new( - limit: u64, - remaining: u64, - reset_after: Duration, - started_at: Option, - ) -> Self { - Self { - limit, - remaining, - reset_after, - started_at, + pub fn from_pairs<'a>( + headers: impl Iterator, + ) -> Result, HeaderParsingError> { + /// Parse a value expected to be a float. + fn header_float(name: HeaderName, value: &[u8]) -> Result { + let text = header_str(name, value)?; + + let end = text.parse().map_err(|source| HeaderParsingError { + kind: HeaderParsingErrorType::Parsing { + kind: HeaderType::Float, + name, + value: text.to_owned(), + }, + source: Some(Box::new(source)), + })?; + + Ok(end) } - } - /// Total number of tickets allotted in a cycle. - #[must_use] - pub const fn limit(&self) -> u64 { - self.limit - } + /// Parse a value expected to be an integer. + fn header_int>( + name: HeaderName, + value: &[u8], + ) -> Result { + let text = header_str(name, value)?; - /// Number of tickets remaining. - #[must_use] - pub const fn remaining(&self) -> u64 { - self.remaining - } + let end = text.parse().map_err(|source| HeaderParsingError { + kind: HeaderParsingErrorType::Parsing { + kind: HeaderType::Integer, + name, + value: text.to_owned(), + }, + source: Some(Box::new(source)), + })?; - /// Duration after the [`Self::started_at`] time the bucket will - /// refresh. - #[must_use] - pub const fn reset_after(&self) -> Duration { - self.reset_after - } + Ok(end) + } + + /// Parse a value expected to be a UTF-8 valid string. + fn header_str(name: HeaderName, value: &[u8]) -> Result<&str, HeaderParsingError> { + let text = str::from_utf8(value) + .map_err(|source| HeaderParsingError::not_utf8(name, value.to_owned(), source))?; - /// When the bucket's ratelimit refresh countdown started. - #[must_use] - pub const fn started_at(&self) -> Option { - self.started_at + Ok(text) + } + + let mut bucket = None; + let mut limit = None; + let mut remaining = None; + let mut reset_after = None; + + for (name, value) in headers { + match name { + HeaderName::BUCKET => { + bucket.replace(header_str(HeaderName::Bucket, value)?); + } + HeaderName::LIMIT => { + limit.replace(header_int(HeaderName::Limit, value)?); + } + HeaderName::REMAINING => { + remaining.replace(header_int(HeaderName::Remaining, value)?); + } + HeaderName::RESET_AFTER => { + let reset_after_value = header_float(HeaderName::ResetAfter, value)?; + + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + reset_after.replace((reset_after_value * 1000.).ceil() as u32); + } + _ => {} + } + } + + if let Some(bucket) = bucket { + if let Some(limit) = limit { + if let Some(remaining) = remaining { + if let Some(reset_after) = reset_after { + return Ok(Some(Headers { + bucket: bucket.to_owned().into_boxed_str(), + limit, + remaining, + reset_after, + })); + } + } + } + } + + Ok(None) } +} - /// How long until the bucket will refresh. 
- /// - /// May return `None` if the refresh timer has not been started yet or - /// the bucket has already refreshed. - #[must_use] - pub fn time_remaining(&self) -> Option { - let reset_at = self.started_at? + self.reset_after; +/// Permit to send a Discord HTTP API request. +#[derive(Debug)] +#[must_use = "dropping the permit immediately cancels itself"] +pub struct Permit(oneshot::Sender>); - reset_at.checked_duration_since(Instant::now()) +impl Permit { + /// Update the rate limiter based on the response headers. + /// + /// Non-completed permits are regarded as cancelled, so only call this + /// on receiving a response. + pub fn complete(self, headers: Option) { + _ = self.0.send(headers); } } -/// A generic error type that implements [`Error`]. -pub type GenericError = Box; +/// Future that completes when a permit is ready. +#[derive(Debug)] +pub struct PermitFuture(oneshot::Receiver>>); -/// Future returned by [`Ratelimiter::bucket`]. -pub type GetBucketFuture = - Pin, GenericError>> + Send + 'static>>; +impl Future for PermitFuture { + type Output = Permit; -/// Future returned by [`Ratelimiter::is_globally_locked`]. -pub type IsGloballyLockedFuture = - Pin> + Send + 'static>>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0) + .poll(cx) + .map(|r| Permit(r.expect("actor is alive"))) + } +} -/// Future returned by [`Ratelimiter::has`]. -pub type HasBucketFuture = - Pin> + Send + 'static>>; +/// Future that completes when a permit is ready if it passed the predicate. +#[derive(Debug)] +pub struct MaybePermitFuture(oneshot::Receiver>>); -/// Future returned by [`Ratelimiter::ticket`]. -pub type GetTicketFuture = - Pin> + Send + 'static>>; +impl Future for MaybePermitFuture { + type Output = Option; -/// Future returned by [`Ratelimiter::wait_for_ticket`]. -pub type WaitForTicketFuture = - Pin> + Send + 'static>>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit)) + } +} -/// An implementation of a ratelimiter for the Discord REST API. +/// Pending permit state. +#[derive(Debug)] +struct Request { + /// Path the permit is for, mapping to a [`Queue`]. + path: Path, + /// Completion handle for the associated [`PermitFuture`]. + notifier: oneshot::Sender>>, +} + +/// Grouped pending permits holder snapshot. /// -/// A default implementation can be found in [`InMemoryRatelimiter`]. +/// Grouping may be done by path or bucket, based on previous permits' response +/// headers. +#[non_exhaustive] +#[derive(Debug)] +pub struct QueueSnapshot { + /// Number of already queued permits. + pub len: usize, + /// Time at which the bucket resets. + pub reset_at: Instant, + /// Total number of permits until the queue becomes exhausted. + pub limit: u16, + /// Number of remaining permits until the queue becomes exhausted. + pub remaining: u16, +} + +/// Grouped pending permits holder. /// -/// All operations are asynchronous to allow for custom implementations to -/// use different storage backends, for example databases. +/// Grouping may be done by path or bucket, based on previous permits' response +/// headers. /// -/// Ratelimiters should keep track of two kids of ratelimits: -/// * The global ratelimit status -/// * [`Path`]-specific ratelimits +/// Queue may not be rate limited, in which case the values of [`limit`][Self::limit], +/// [`reset`][Self::reset], and [`remaining`][Self::remaining] are unused. 
+#[derive(Debug)] +struct Queue { + /// Whether the queue is handling outstanding permits. + /// + /// Note that this is `true` when globally exhausted and `false` when + /// the queue is exhausted. + idle: bool, + /// List of pending permit requests. + inner: VecDeque, + /// Total number of permits until the queue becomes exhausted. + limit: u16, + /// Key mapping to an [`Instant`] when the queue resets, if rate limited. + reset: Option, + /// Number of remaining permits until the queue becomes exhausted. + remaining: u16, +} + +impl Queue { + /// Create a new non rate limited queue. + const fn new() -> Self { + Self { + idle: true, + inner: VecDeque::new(), + limit: 0, + reset: None, + remaining: 0, + } + } + + /// Completes and returns the first queued permit, unless the queue is + /// globally exhausted. + fn pop( + &mut self, + globally_exhausted: bool, + ) -> Option<(Path, oneshot::Receiver>)> { + let (mut tx, rx) = oneshot::channel(); + while self + .inner + .front() + .is_some_and(|req| req.path.is_interaction() || !globally_exhausted) + { + let req = self.inner.pop_front().unwrap(); + match req.notifier.send(tx) { + Ok(()) => return Some((req.path, rx)), + Err(recover) => tx = recover, + } + } + self.idle = true; + + None + } +} + +/// Discord HTTP client API rate limiter. /// -/// To do this, clients utilizing a ratelimiter will send back response -/// ratelimit headers via a [`TicketSender`]. +/// The rate limiter runs an associated actor task to concurrently handle permit +/// requests and responses. /// -/// The ratelimiter itself will hand a [`TicketReceiver`] to the caller -/// when a ticket is being requested. -pub trait Ratelimiter: Debug + Send + Sync { - /// Retrieve the basic information of the bucket for a given path. - fn bucket(&self, path: &Path) -> GetBucketFuture; - - /// Whether the ratelimiter is currently globally locked. - fn is_globally_locked(&self) -> IsGloballyLockedFuture; - - /// Determine if the ratelimiter has a bucket for the given path. - fn has(&self, path: &Path) -> HasBucketFuture; - - /// Retrieve a ticket to know when to send a request. - /// The provided future will be ready when a ticket in the bucket is - /// available. Tickets are ready in order of retrieval. - fn ticket(&self, path: Path) -> GetTicketFuture; - - /// Retrieve a ticket to send a request. - /// Other than [`Self::ticket`], this method will return - /// a [`TicketSender`]. +/// Cloning a rate limiter increments just the amount of senders for the actor. +/// The actor completes when there are no senders and pending permits left. +#[derive(Clone, Debug)] +pub struct RateLimiter { + /// Actor message sender. + tx: mpsc::UnboundedSender<( + Request, + Option) -> bool + Send>>, + )>, +} + +impl RateLimiter { + /// Create a new rate limiter with custom settings. + pub fn new(global_limit: u16) -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + tokio::spawn(runner(global_limit, rx)); + + Self { tx } + } + + /// Await a single permit for this path. /// - /// This is identical to calling [`Self::ticket`] and then - /// awaiting the [`TicketReceiver`]. - fn wait_for_ticket(&self, path: Path) -> WaitForTicketFuture { - let get_ticket = self.ticket(path); - Box::pin(async move { - match get_ticket.await { - Ok(rx) => rx.await.map_err(From::from), - Err(e) => Err(e), - } + /// Permits are queued per path in the order they were requested. 
+ #[allow(clippy::missing_panics_doc)] + pub fn acquire(&self, path: Path) -> PermitFuture { + let (tx, rx) = oneshot::channel(); + self.tx + .send((Request { path, notifier: tx }, None)) + .expect("actor is alive"); + + PermitFuture(rx) + } + + /// Await a single permit for this path, but only if the predicate evaluates + /// to `true`. + /// + /// Permits are queued per path in the order they were requested. + /// + /// Note that the predicate is asynchronously called in the actor task. + #[allow(clippy::missing_panics_doc)] + pub fn acquire_if

(&self, path: Path, predicate: P) -> MaybePermitFuture + where + P: FnOnce(Option) -> bool + Send + 'static, + { + let (tx, rx) = oneshot::channel(); + self.tx + .send((Request { path, notifier: tx }, Some(Box::new(predicate)))) + .expect("actor is alive"); + + MaybePermitFuture(rx) + } + + /// Retrieve the [`QueueSnapshot`] for this path. + /// + /// The snapshot is internally retrieved via [`acquire_if`][Self::acquire_if]. + pub async fn snapshot(&self, path: Path) -> Option { + let (tx, rx) = oneshot::channel(); + self.acquire_if(path, |snapshot| { + _ = tx.send(snapshot); + false }) + .await; + + rx.await.unwrap() + } +} + +impl Default for RateLimiter { + /// Create a new rate limiter with Discord's default global limit. + /// + /// Currently this is `50`. + fn default() -> Self { + Self::new(50) + } +} + +/// Duration from the first globally limited request until the remaining count +/// resets to the global limit count. +const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_secs(1); + +/// Rate limiter actor runner. +#[allow(clippy::too_many_lines)] +async fn runner( + global_limit: u16, + mut rx: mpsc::UnboundedReceiver<( + Request, + Option) -> bool + Send>>, + )>, +) { + let mut global_remaining = global_limit; + let mut global_timer = pin::pin!(time::sleep(Duration::ZERO)); + + let mut buckets = HashMap::>::new(); + let mut in_flight = + JoinSet::<(Path, Result, oneshot::error::RecvError>)>::new(); + + let mut reset = DelayQueue::::new(); + let mut queues = hashbrown::HashTable::<(u64, Queue)>::new(); + let hasher = RandomState::new(); + + macro_rules! on_permit { + () => { + // Global must be decremented before sending the message as, unlike the bucket, + // it is not blocked until this request receives response headers. + global_remaining -= 1; + if global_remaining == global_limit - 1 { + global_timer + .as_mut() + .reset(Instant::now() + GLOBAL_LIMIT_PERIOD); + } else if global_remaining == 0 { + let now = Instant::now(); + let reset_after = now.saturating_duration_since(global_timer.deadline()); + if reset_after.is_zero() { + global_remaining = global_limit - 1; + global_timer.as_mut().reset(now + GLOBAL_LIMIT_PERIOD); + } else { + tracing::info!(?reset_after, "globally exhausted"); + } + } + }; + } + + #[allow(clippy::ignored_unit_patterns)] + loop { + tokio::select! 
{ + biased; + _ = &mut global_timer, if global_remaining == 0 => { + global_remaining = global_limit; + for (_, queue) in queues.iter_mut().filter(|(_, queue)| queue.idle) { + if let Some((path, rx)) = queue.pop(global_remaining == 0) { + queue.idle = false; + tracing::debug!(?path, "permitted"); + on_permit!(); + in_flight.spawn(async move { (path, rx.await) }); + } + } + } + Some(hash) = poll_fn(|cx| reset.poll_expired(cx)) => { + let hash = hash.into_inner(); + let (_, queue) = queues.find_mut(hash, |val| val.0 == hash).expect("hash is unchanged"); + queue.reset = None; + let maybe_in_flight = queue.remaining != 0; + if maybe_in_flight { continue; } + + if let Some((path, rx)) = queue.pop(global_remaining == 0) { + tracing::debug!(?path, "permitted"); + if !path.is_interaction() { + on_permit!(); + } + in_flight.spawn(async move { (path, rx.await) }); + } + } + Some(response) = in_flight.join_next() => { + let (path, headers) = response.expect("task should not fail"); + + let mut builder = hasher.build_hasher(); + path.hash_components(&mut builder); + + let queue = match headers { + Ok(Some(headers)) => { + let _span = tracing::info_span!("headers", ?path).entered(); + tracing::trace!(?headers); + let bucket = headers.bucket; + + bucket.hash(&mut builder); + let hash = builder.finish(); + let queue = match buckets.entry(path) { + Entry::Occupied(mut entry) if *entry.get() != bucket => { + let mut old_builder = hasher.build_hasher(); + entry.key().hash_components(&mut old_builder); + entry.get().hash(&mut old_builder); + let old_hash = old_builder.finish(); + + tracing::debug!(new = hash, previous = old_hash, "bucket changed"); + + *entry.get_mut() = bucket; + let path = entry.key(); + + let mut entry = queues.find_entry(old_hash, |a| a.0 == old_hash).expect("hash is unchanged"); + let shared = entry.get().1.inner.iter().any(|req| req.path != *path); + let queue = if shared { + let mut inner = VecDeque::new(); + for req in mem::take(&mut entry.get_mut().1.inner) { + if req.path == *path { + inner.push_back(req); + } else { + entry.get_mut().1.inner.push_back(req); + } + } + + let old_queue = &mut entry.get_mut().1; + if let Some((path, rx)) = old_queue.pop(global_remaining == 0) { + tracing::debug!(?path, "permitted"); + if !path.is_interaction() { + on_permit!(); + } + in_flight.spawn(async move { (path, rx.await) }); + } + + Queue { + idle: false, + inner, + limit: 0, + reset: None, + remaining: 0, + } + } else { + entry.remove().0.1 + }; + + match queues.entry(hash, |a| a.0 == hash, |a| a.0) { + hash_table::Entry::Occupied(mut entry) => { + entry.get_mut().1.inner.extend(queue.inner); + &mut entry.into_mut().1 + } + hash_table::Entry::Vacant(entry) => &mut entry.insert((hash, queue)).into_mut().1, + } + } + Entry::Occupied(_) => &mut queues.find_mut(hash, |a| a.0 == hash).unwrap().1, + Entry::Vacant(entry) => { + let mut old_builder = hasher.build_hasher(); + entry.key().hash_components(&mut old_builder); + let old_hash = old_builder.finish(); + + tracing::debug!(hash, "bucket assigned"); + entry.insert(bucket); + + let ((_, queue), _) = queues.find_entry(old_hash, |a| a.0 == old_hash).expect("hash is unchanged").remove(); + &mut queues.insert_unique(hash, (hash, queue), |a| a.0).into_mut().1 + }, + }; + + queue.limit = headers.limit; + queue.remaining = headers.remaining; + let reset_after = Duration::from_millis(headers.reset_after.into()); + if let Some(key) = &queue.reset { + reset.reset(key, reset_after); + } else { + queue.reset = Some(reset.insert(hash, reset_after)); + } + if 
queue.remaining == 0 { + tracing::info!(?reset_after, "exhausted"); + queue.idle = true; + continue; + } + + queue + } + Ok(None) => { + if let Some(bucket) = buckets.get(&path) { + bucket.hash(&mut builder); + } + let hash = builder.finish(); + + &mut queues.find_mut(hash, |a| a.0 == hash).expect("hash is unchanged").1 + } + Err(_) => { + tracing::debug!(?path, "cancelled"); + if global_remaining != global_limit { + global_remaining += 1; + } + + if let Some(bucket) = buckets.get(&path) { + bucket.hash(&mut builder); + } + let hash = builder.finish(); + + &mut queues.find_mut(hash, |a| a.0 == hash).expect("hash is unchanged").1 + } + }; + + if let Some((path, rx)) = queue.pop(global_remaining == 0) { + tracing::debug!(?path, "permitted"); + if !path.is_interaction() { + on_permit!(); + } + in_flight.spawn(async move { (path, rx.await) }); + } + } + Some((msg, predicate)) = rx.recv() => { + let mut builder = hasher.build_hasher(); + msg.path.hash_components(&mut builder); + + let (_, queue) = if let Some(bucket) = buckets.get(&msg.path) { + bucket.hash(&mut builder); + let hash = builder.finish(); + queues.find_mut(hash, |a| a.0 == hash).unwrap() + } else { + let hash = builder.finish(); + queues.entry(hash, |a| a.0 == hash, |a| a.0).or_insert_with(|| (hash, Queue::new())).into_mut() + }; + + let snapshot = queue.reset.map(|key| QueueSnapshot { + len: queue.inner.len(), + reset_at: reset.deadline(&key), + limit: queue.limit, + remaining: queue.remaining, + }); + + if predicate.is_some_and(|p| !p(snapshot)) { + drop(msg); + } else if !queue.idle || (!msg.path.is_interaction() && global_remaining == 0) { + queue.inner.push_back(msg); + } else { + let (tx, rx) = oneshot::channel(); + if msg.notifier.send(tx).is_ok() { + queue.idle = false; + tracing::debug!(path = ?msg.path, "permitted"); + if !msg.path.is_interaction() { + on_permit!(); + } + in_flight.spawn(async move { (msg.path, rx.await) }); + } + } + } + else => break, + } } } diff --git a/twilight-http-ratelimiting/src/request.rs b/twilight-http-ratelimiting/src/request.rs index b7f70f0c7c7..80c56c9d9c2 100644 --- a/twilight-http-ratelimiting/src/request.rs +++ b/twilight-http-ratelimiting/src/request.rs @@ -11,6 +11,7 @@ use std::{ error::Error, fmt::{Display, Formatter, Result as FmtResult}, + hash::{Hash, Hasher}, str::FromStr, }; @@ -306,6 +307,125 @@ pub enum Path { WebhooksIdTokenMessagesId(u64, String), } +impl Path { + /// Whether the path is an interaction path. + pub(crate) const fn is_interaction(&self) -> bool { + matches!( + self, + Self::InteractionCallback(_) + | Self::WebhooksId(_) + | Self::WebhooksIdToken(_, _) + | Self::WebhooksIdTokenMessagesId(_, _) + ) + } + + /// Feeds the top level components of this path into the given [`Hasher`]. 
+ #[allow(clippy::too_many_lines)] + pub(crate) fn hash_components(&self, state: &mut impl Hasher) { + match self { + Path::ApplicationsMe + | Path::Gateway + | Path::GatewayBot + | Path::Guilds + | Path::InvitesCode + | Path::OauthApplicationsMe + | Path::OauthMe + | Path::StageInstances + | Path::StickerPacks + | Path::Stickers + | Path::UsersId + | Path::UsersIdChannels + | Path::UsersIdConnections + | Path::UsersIdGuilds + | Path::UsersIdGuildsId + | Path::UsersIdGuildsIdMember + | Path::VoiceRegions => {} + Path::ApplicationCommand(id) + | Path::ApplicationCommandId(id) + | Path::ApplicationEmojis(id) + | Path::ApplicationEmoji(id) + | Path::ApplicationGuildCommand(id) + | Path::ChannelsId(id) + | Path::ChannelsIdFollowers(id) + | Path::ChannelsIdInvites(id) + | Path::ChannelsIdMessages(id) + | Path::ChannelsIdMessagesBulkDelete(id) + | Path::ChannelsIdMessagesIdCrosspost(id) + | Path::ChannelsIdMessagesIdReactions(id) + | Path::ChannelsIdMessagesIdReactionsUserIdType(id) + | Path::ChannelsIdMessagesIdThreads(id) + | Path::ChannelsIdPermissionsOverwriteId(id) + | Path::ChannelsIdPins(id) + | Path::ChannelsIdPinsMessageId(id) + | Path::ChannelsIdPolls(id) + | Path::ChannelsIdRecipients(id) + | Path::ChannelsIdThreadMembers(id) + | Path::ChannelsIdThreadMembersId(id) + | Path::ChannelsIdThreads(id) + | Path::ChannelsIdTyping(id) + | Path::ChannelsIdWebhooks(id) + | Path::ApplicationIdEntitlements(id) + | Path::ApplicationIdSKUs(id) + | Path::GuildsId(id) + | Path::GuildsIdAuditLogs(id) + | Path::GuildsIdAutoModerationRules(id) + | Path::GuildsIdAutoModerationRulesId(id) + | Path::GuildsIdBans(id) + | Path::GuildsIdBansId(id) + | Path::GuildsIdBansUserId(id) + | Path::GuildsIdChannels(id) + | Path::GuildsIdEmojis(id) + | Path::GuildsIdEmojisId(id) + | Path::GuildsIdIntegrations(id) + | Path::GuildsIdIntegrationsId(id) + | Path::GuildsIdIntegrationsIdSync(id) + | Path::GuildsIdInvites(id) + | Path::GuildsIdMembers(id) + | Path::GuildsIdMembersId(id) + | Path::GuildsIdMembersIdRolesId(id) + | Path::GuildsIdMembersMeNick(id) + | Path::GuildsIdMembersSearch(id) + | Path::GuildsIdMfa(id) + | Path::GuildsIdOnboarding(id) + | Path::GuildsIdPreview(id) + | Path::GuildsIdPrune(id) + | Path::GuildsIdRegions(id) + | Path::GuildsIdRoles(id) + | Path::GuildsIdRolesId(id) + | Path::GuildsIdScheduledEvents(id) + | Path::GuildsIdScheduledEventsId(id) + | Path::GuildsIdScheduledEventsIdUsers(id) + | Path::GuildsIdStickers(id) + | Path::GuildsIdTemplates(id) + | Path::GuildsIdThreads(id) + | Path::GuildsIdVanityUrl(id) + | Path::GuildsIdVoiceStates(id) + | Path::GuildsIdWebhooks(id) + | Path::GuildsIdWelcomeScreen(id) + | Path::GuildsIdWidget(id) + | Path::GuildsIdWidgetJson(id) + | Path::InteractionCallback(id) + | Path::WebhooksId(id) + | Path::ApplicationGuildCommandId(id) => id.hash(state), + Path::ChannelsIdMessagesId(method, id) => { + method.hash(state); + id.hash(state); + } + Path::GuildsIdTemplatesCode(id, code) => { + id.hash(state); + code.hash(state); + } + Path::GuildsTemplatesCode(code) => { + code.hash(state); + } + Path::WebhooksIdToken(id, token) | Path::WebhooksIdTokenMessagesId(id, token) => { + id.hash(state); + token.hash(state); + } + } + } +} + impl FromStr for Path { type Err = PathParseError; diff --git a/twilight-http-ratelimiting/src/ticket.rs b/twilight-http-ratelimiting/src/ticket.rs deleted file mode 100644 index 1e68fbe6908..00000000000 --- a/twilight-http-ratelimiting/src/ticket.rs +++ /dev/null @@ -1,175 +0,0 @@ -//! Flow for managing ratelimit tickets. -//! -//! 
Tickets are the [`Ratelimiter`]'s method of managing approval for a consumer -//! to be able to send a request. -//! -//! # Ratelimit Consumer -//! -//! ## 1. Requesting a ticket -//! -//! Consumers of a ratelimiter will call [`Ratelimiter::ticket`]. -//! -//! ## 2. Waiting for approval -//! -//! In return consumers will receive a [`TicketReceiver`]. This must be polled -//! in order to know when the ratelimiter has approved a ticket. -//! -//! ## 3. Receiving approval -//! -//! When a ticket is approved and the future resolves, a [`TicketSender`] is -//! provided. This must be used to provide the ratelimiter with the response's -//! ratelimit headers. -//! -//! ## 4. Performing the request -//! -//! Consumers may now execute the HTTP request associated with the ticket. Once -//! a response (or lack of one) is received, the headers [must be parsed] and -//! sent to the ratelimiter via [`TicketSender::headers`]. This completes the -//! cycle. -//! -//! # Ratelimiter -//! -//! ## 1. Initializing a ticket's channels -//! -//! Ratelimiters will accept a request for a ticket when [`Ratelimiter::ticket`] -//! is called. You must call [`channel`] to create a channel between the -//! ratelimiter and the consumer. -//! -//! ## 2. Keeping the consumer waiting -//! -//! [`channel`] will return two halves: [`TicketNotifier`] and -//! [`TicketReceiver`]. Ratelimiters must keep the notifier and give the user -//! the receiver in return. -//! -//! ## 3. Notifying the consumer of ticket approval -//! -//! When any ratelimits have passed and a user is free to perform their request, -//! call [`TicketNotifier::available`]. If the user hasn't canceled their -//! request for a ticket, you will receive a [`TicketHeaders`]. -//! -//! ## 4. Receiving the response's headers -//! -//! The consumer will perform their HTTP request and parse the response's -//! headers. Once the headers (or lack of headers) are available the user will -//! send them along the channel. Poll the provided [`TicketHeaders`] for those -//! headers to complete the cycle. -//! -//! [`Ratelimiter::ticket`]: super::Ratelimiter::ticket -//! [`Ratelimiter`]: super::Ratelimiter -//! [must be parsed]: super::headers - -use crate::headers::RatelimitHeaders; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::sync::oneshot::{self, error::RecvError, Receiver, Sender}; - -/// Receiver to wait for the headers sent by the API consumer. -/// -/// You must poll the future in order to process the headers. If the future -/// results to an error, then the API consumer dropped the sernding half of the -/// channel. You should treat this as if the request happened. -#[derive(Debug)] -pub struct TicketHeaders(Receiver>); - -impl Future for TicketHeaders { - type Output = Result, RecvError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) - } -} - -/// Indicate to the ratelimit consumer that their ticket has been granted and -/// they may now send a request. -#[derive(Debug)] -pub struct TicketNotifier(Sender>>); - -impl TicketNotifier { - /// Signal to the ratelimiter consumer (an HTTP client) that a request may - /// now be performed. - /// - /// A receiver is returned. This must be stored and awaited so that - /// ratelimiting backends can handle the headers that the API consumer will - /// send back, thus completing the cycle. - /// - /// Returns a `None` if the consumer has dropped their - /// [`TicketReceiver`] half. The ticket is considered canceled. 
- #[must_use] - pub fn available(self) -> Option { - let (tx, rx) = oneshot::channel(); - - self.0.send(tx).ok()?; - - Some(TicketHeaders(rx)) - } -} - -/// Channel receiver to wait for availability of a ratelimit ticket. -/// -/// This is used by the ratelimiter consumer (such as an API client) to wait for -/// an available ratelimit ticket. -/// -/// Once one is available, a [`TicketSender`] will be produced which can be used to -/// send the associated HTTP response's ratelimit headers. -#[derive(Debug)] -pub struct TicketReceiver(Receiver>>); - -impl Future for TicketReceiver { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx).map_ok(TicketSender) - } -} - -/// Channel sender to send response ratelimit information to the ratelimiter. -/// -/// This is used by the ratelimiter consumer (such as an API client) once a -/// request has been granted via [`TicketReceiver`]. -/// -/// If a response results in available ratelimit headers, send them via -/// [`headers`] to the ratelimiter backend. If a response results in an -/// error - such as a server error or request cancellation - send `None`. -/// -/// [`headers`]: Self::headers -#[derive(Debug)] -pub struct TicketSender(Sender>); - -impl TicketSender { - /// Send the response's ratelimit headers to the ratelimiter. - /// - /// This will allow the ratelimiter to complete the cycle and acknowledge - /// that the request has been completed. This must be done so that the - /// ratelimiter can process information such as whether there's a global - /// ratelimit. - /// - /// # Errors - /// - /// Returns the input headers if the ratelimiter has dropped the receiver - /// half. This may happen if the ratelimiter is dropped or if a timeout has - /// occurred. - pub fn headers( - self, - headers: Option, - ) -> Result<(), Option> { - self.0.send(headers) - } -} - -/// Produce a new channel consisting of a sender and receiver. -/// -/// The notifier is to be used by the ratelimiter while the receiver is to be -/// provided to the consumer. -/// -/// Refer to the [module-level] documentation for more information. -/// -/// [module-level]: self -#[must_use] -pub fn channel() -> (TicketNotifier, TicketReceiver) { - let (tx, rx) = oneshot::channel(); - - (TicketNotifier(tx), TicketReceiver(rx)) -} diff --git a/twilight-http/src/client/builder.rs b/twilight-http/src/client/builder.rs index db0f937128a..759476c8aae 100644 --- a/twilight-http/src/client/builder.rs +++ b/twilight-http/src/client/builder.rs @@ -6,7 +6,7 @@ use std::{ sync::{atomic::AtomicBool, Arc}, time::Duration, }; -use twilight_http_ratelimiting::{InMemoryRatelimiter, Ratelimiter}; +use twilight_http_ratelimiting::RateLimiter; use twilight_model::channel::message::AllowedMentions; /// A builder for [`Client`]. 
@@ -15,7 +15,7 @@ use twilight_model::channel::message::AllowedMentions; pub struct ClientBuilder { pub(crate) default_allowed_mentions: Option, pub(crate) proxy: Option>, - pub(crate) ratelimiter: Option>, + ratelimit: bool, remember_invalid_token: bool, pub(crate) default_headers: Option, pub(crate) timeout: Duration, @@ -46,7 +46,7 @@ impl ClientBuilder { http, default_headers: self.default_headers, proxy: self.proxy, - ratelimiter: self.ratelimiter, + ratelimiter: (self.ratelimit && !cfg!(test)).then(RateLimiter::default), timeout: self.timeout, token_invalidated, token: self.token, @@ -72,7 +72,7 @@ impl ClientBuilder { /// /// Set the proxy to `twilight_http_proxy.internal`: /// - /// ``` + /// ```no_run /// use twilight_http::Client; /// /// # fn main() -> Result<(), Box> { @@ -90,16 +90,12 @@ impl ClientBuilder { self } - /// Set a ratelimiter to use. - /// - /// If the argument is `None` then the client's ratelimiter will be skipped - /// before making a request. + /// Whether to rate limit requests. /// - /// If this method is not called at all then a default [`InMemoryRatelimiter`] will be - /// created by [`ClientBuilder::build`]. + /// Defaults to true. #[allow(clippy::missing_const_for_fn)] - pub fn ratelimiter(mut self, ratelimiter: Option>) -> Self { - self.ratelimiter = ratelimiter; + pub fn ratelimiter(mut self, ratelimit: bool) -> Self { + self.ratelimit = ratelimit; self } @@ -152,12 +148,11 @@ impl ClientBuilder { impl Default for ClientBuilder { fn default() -> Self { - #[allow(clippy::box_default)] Self { default_allowed_mentions: None, default_headers: None, proxy: None, - ratelimiter: Some(Box::new(InMemoryRatelimiter::default())), + ratelimit: true, remember_invalid_token: true, timeout: Duration::from_secs(10), token: None, diff --git a/twilight-http/src/client/mod.rs b/twilight-http/src/client/mod.rs index 0379c914381..1ea37a10963 100644 --- a/twilight-http/src/client/mod.rs +++ b/twilight-http/src/client/mod.rs @@ -114,7 +114,7 @@ use std::{ time::Duration, }; use tokio::time; -use twilight_http_ratelimiting::Ratelimiter; +use twilight_http_ratelimiting::RateLimiter; use twilight_model::{ channel::{message::AllowedMentions, ChannelType}, guild::{ @@ -246,7 +246,7 @@ pub struct Client { default_headers: Option, http: HyperClient>, proxy: Option>, - ratelimiter: Option>, + ratelimiter: Option, timeout: Duration, /// Whether the token has been invalidated. /// @@ -329,14 +329,6 @@ impl Client { self.default_allowed_mentions.as_ref() } - /// Get the Ratelimiter used by the client internally. - /// - /// This will return `None` only if ratelimit handling - /// has been explicitly disabled in the [`ClientBuilder`]. - pub fn ratelimiter(&self) -> Option<&dyn Ratelimiter> { - self.ratelimiter.as_ref().map(AsRef::as_ref) - } - /// Get an auto moderation rule in a guild. /// /// Requires the [`MANAGE_GUILD`] permission. 
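A minimal sketch (not part of the patch) of the rate limiter wiring the builder change above relies on: with `ratelimiter(true)` left at its default, `ClientBuilder::build` spawns `RateLimiter::default()`, whose global limit is 50 permits per second; a limiter with another global limit, here an arbitrary 500, can only be created by using `twilight-http-ratelimiting` directly, since the builder now merely toggles the default one.

```rust
use twilight_http_ratelimiting::RateLimiter;

#[tokio::main]
async fn main() {
    // Equivalent to what `ClientBuilder::build` constructs when rate limiting
    // is enabled (the default): Discord's global limit of 50 permits per second.
    let _default_limiter = RateLimiter::default();

    // A limiter with a raised global limit (illustrative value), for consumers
    // driving this crate directly; the HTTP client no longer accepts a custom
    // limiter instance.
    let _custom_limiter = RateLimiter::new(500);
}
```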
@@ -2999,9 +2991,9 @@ impl Client { .flatten(); Ok(if let Some(ratelimiter) = &self.ratelimiter { - let tx_future = ratelimiter.wait_for_ticket(ratelimit_path); + let rx = ratelimiter.acquire(ratelimit_path); - ResponseFuture::ratelimit(invalid_token, inner, self.timeout, tx_future) + ResponseFuture::ratelimit(invalid_token, inner, self.timeout, rx) } else { ResponseFuture::new(Box::pin(time::timeout(self.timeout, inner)), invalid_token) }) diff --git a/twilight-http/src/response/future.rs b/twilight-http/src/response/future.rs index 74f63d3c11e..51987524c46 100644 --- a/twilight-http/src/response/future.rs +++ b/twilight-http/src/response/future.rs @@ -18,7 +18,7 @@ use std::{ time::Duration, }; use tokio::time::{self, Timeout}; -use twilight_http_ratelimiting::{ticket::TicketSender, RatelimitHeaders, WaitForTicketFuture}; +use twilight_http_ratelimiting::{Headers, Permit, PermitFuture}; type Output = Result, Error>; @@ -75,7 +75,7 @@ impl Failed { struct InFlight { future: Pin>>, invalid_token: Option>, - tx: Option, + tx: Option, } impl InFlight { @@ -112,14 +112,12 @@ impl InFlight { .iter() .map(|(key, value)| (key.as_str(), value.as_bytes())); - match RatelimitHeaders::from_pairs(headers) { - Ok(v) => { - let _res = tx.headers(Some(v)); - } + match Headers::from_pairs(headers) { + Ok(v) => tx.complete(v), Err(source) => { tracing::warn!("header parsing failed: {source:?}; {resp:?}"); - let _res = tx.headers(None); + tx.complete(None); } } } @@ -171,20 +169,13 @@ struct RatelimitQueue { response_future: HyperResponseFuture, timeout: Duration, pre_flight_check: Option bool + Send + 'static>>, - wait_for_sender: WaitForTicketFuture, + rx: PermitFuture, } impl RatelimitQueue { fn poll(mut self, cx: &mut Context<'_>) -> InnerPoll { - let tx = match Pin::new(&mut self.wait_for_sender).poll(cx) { - Poll::Ready(Ok(tx)) => tx, - Poll::Ready(Err(source)) => { - return InnerPoll::Ready(Err(Error { - kind: ErrorType::RatelimiterTicket, - source: Some(source), - })) - } - Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::RatelimitQueue(self)), + let Poll::Ready(tx) = Pin::new(&mut self.rx).poll(cx) else { + return InnerPoll::Pending(ResponseFutureStage::RatelimitQueue(self)); }; if let Some(pre_flight_check) = self.pre_flight_check { @@ -351,7 +342,7 @@ impl ResponseFuture { invalid_token: Option>, response_future: HyperResponseFuture, timeout: Duration, - wait_for_sender: WaitForTicketFuture, + wait_for_sender: PermitFuture, ) -> Self { Self { phantom: PhantomData, @@ -360,7 +351,7 @@ impl ResponseFuture { response_future, timeout, pre_flight_check: None, - wait_for_sender, + rx: wait_for_sender, }), } } diff --git a/twilight-http/src/routing.rs b/twilight-http/src/routing.rs index cffa735f22b..1f1719ca3f8 100644 --- a/twilight-http/src/routing.rs +++ b/twilight-http/src/routing.rs @@ -1421,7 +1421,7 @@ impl Route<'_> { /// /// Use a route's path to retrieve a ratelimiter ticket: /// - /// ``` + /// ```ignore /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// use twilight_http::routing::Route;
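Finally, a hedged consumer-side sketch of the rewritten crate's permit flow; the path and the literal header names and values below are illustrative only. A permit is queued per `Path`, the HTTP request is performed once the future resolves, and the response's rate limit headers (if any) are reported back through `Permit::complete`; dropping the permit instead counts as a cancelled request.

```rust
use twilight_http_ratelimiting::{Headers, Path, RateLimiter};

#[tokio::main]
async fn main() {
    let limiter = RateLimiter::default();

    // Queue for a permit; the future resolves once the actor decides the
    // request may be sent for this path.
    let permit = limiter.acquire(Path::ChannelsIdMessages(123)).await;

    // ... perform the HTTP request here; these pairs stand in for the
    // response's rate limit header names and values ...
    let response_headers: Vec<(&str, &[u8])> = vec![
        ("x-ratelimit-bucket", b"abcd1234".as_slice()),
        ("x-ratelimit-limit", b"5".as_slice()),
        ("x-ratelimit-remaining", b"4".as_slice()),
        ("x-ratelimit-reset-after", b"60.000".as_slice()),
    ];

    // Report the parsed headers so the actor can update the path's queue;
    // parsing failures are reported as `None`, mirroring the HTTP client.
    match Headers::from_pairs(response_headers.into_iter()) {
        Ok(headers) => permit.complete(headers),
        Err(_source) => permit.complete(None),
    }
}
```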