Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

throttle: add support for Redis installs without redis-cell module #282

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/kumo-server-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ pub fn register(lua: &Lua) -> anyhow::Result<()> {
let key: RedisConnKey = from_lua_value(lua, params)?;
let conn = key.open().map_err(any_err)?;
conn.ping().await.map_err(any_err)?;
throttle::use_redis(conn).map_err(any_err)
throttle::use_redis(conn).await.map_err(any_err)
})?,
)?;

Expand Down
2 changes: 1 addition & 1 deletion crates/mod-redis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Manager for ClientManager {
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct RedisConnection(Arc<RedisConnKey>);

impl RedisConnection {
Expand Down
5 changes: 2 additions & 3 deletions crates/throttle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["impl"]
impl = ["dep:redis-cell-impl", "dep:mod-redis"]
default = ["redis"]
redis = ["dep:redis-cell-impl", "dep:mod-redis"]

[dependencies]
anyhow = "1.0"
mod-redis = {path="../mod-redis", optional=true}
once_cell = "1.17"
redis-cell-impl = { git = "https://github.com/wez/redis-cell.git", rev="97d409c3a62f2a0f5518c31fc9b4b65afbce2053" , optional=true}
serde = {version="1.0", features=["derive"]}
thiserror = "1.0"
Expand Down
266 changes: 59 additions & 207 deletions crates/throttle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,77 @@
//! The implementation uses an in-memory store, but can be adjusted in the future
//! to support using a redis-cell equipped redis server to share the throttles
//! among multiple machines.
#[cfg(feature = "impl")]
use mod_redis::{Cmd, FromRedisValue, RedisConnection, RedisError};
#[cfg(feature = "impl")]
use once_cell::sync::OnceCell;
#[cfg(feature = "impl")]
use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
#[cfg(feature = "redis")]
use mod_redis::RedisError;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
#[cfg(feature = "impl")]
use std::sync::Mutex;
use std::time::Duration;
use thiserror::Error;

#[cfg(feature = "impl")]
#[cfg(feature = "redis")]
pub mod limit;
#[cfg(feature = "redis")]
mod throttle;
djc marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(feature = "impl")]
static MEMORY: OnceCell<Mutex<MemoryStore>> = OnceCell::new();
#[cfg(feature = "impl")]
static REDIS: OnceCell<RedisConnection> = OnceCell::new();
#[cfg(feature = "redis")]
mod redis {
use super::*;
use mod_redis::{Cmd, RedisConnection, RedisValue};
use std::ops::Deref;
use std::sync::OnceLock;

#[derive(Debug)]
pub(crate) struct RedisContext {
pub(crate) connection: RedisConnection,
pub(crate) has_redis_cell: bool,
}

impl RedisContext {
pub async fn try_from(connection: RedisConnection) -> anyhow::Result<Self> {
let mut cmd = Cmd::new();
cmd.arg("COMMAND").arg("INFO").arg("CL.THROTTLE");

let rsp = connection.query(cmd).await?;
let has_redis_cell = rsp
.as_sequence()
.map_or(false, |arr| arr.iter().any(|v| v != &RedisValue::Nil));

Ok(Self {
has_redis_cell,
connection,
})
}
}

impl Deref for RedisContext {
type Target = RedisConnection;
fn deref(&self) -> &Self::Target {
&self.connection
}
}

pub(crate) static REDIS: OnceLock<RedisContext> = OnceLock::new();

pub async fn use_redis(conn: RedisConnection) -> Result<(), Error> {
REDIS
.set(RedisContext::try_from(conn).await?)
.map_err(|_| Error::Generic("redis already configured for throttles".to_string()))?;
Ok(())
}
}

#[cfg(feature = "redis")]
pub use redis::use_redis;
#[cfg(feature = "redis")]
pub(crate) use redis::REDIS;

#[derive(Error, Debug)]
pub enum Error {
#[error("{0}")]
Generic(String),
#[error("{0}")]
AnyHow(#[from] anyhow::Error),
#[cfg(feature = "impl")]
#[cfg(feature = "redis")]
#[error("{0}")]
Redis(#[from] RedisError),
#[error("TooManyLeases, try again in {0:?}")]
Expand All @@ -48,7 +91,7 @@ pub struct ThrottleSpec {
pub force_local: bool,
}

#[cfg(feature = "impl")]
#[cfg(feature = "redis")]
impl ThrottleSpec {
pub async fn throttle<S: AsRef<str>>(&self, key: S) -> Result<ThrottleResult, Error> {
self.throttle_quantity(key, 1).await
Expand All @@ -64,7 +107,7 @@ impl ThrottleSpec {
let period = self.period;
let max_burst = self.max_burst.unwrap_or(limit);
let key = format!("{key}:{limit}:{max_burst}:{period}");
throttle(
throttle::throttle(
&key,
limit,
Duration::from_secs(period),
Expand Down Expand Up @@ -186,201 +229,10 @@ pub struct ThrottleResult {
pub retry_after: Option<Duration>,
}

#[cfg(feature = "impl")]
fn local_throttle(
key: &str,
limit: u64,
period: Duration,
max_burst: u64,
quantity: Option<u64>,
) -> Result<ThrottleResult, Error> {
let mut store = MEMORY
.get_or_init(|| Mutex::new(MemoryStore::new()))
.lock()
.unwrap();
let max_rate = Rate::per_period(
limit as i64,
time::Duration::try_from(period).map_err(|err| Error::Generic(format!("{err:#}")))?,
);
let mut limiter = RateLimiter::new(
&mut *store,
&RateQuota {
max_burst: max_burst.min(limit - 1) as i64,
max_rate,
},
);
let quantity = quantity.unwrap_or(1) as i64;
let (throttled, rate_limit_result) = limiter
.rate_limit(key, quantity)
.map_err(|err| Error::Generic(format!("{err:#}")))?;

// If either time had a partial component, bump it up to the next full
// second because otherwise a fast-paced caller could try again too
// early.
let mut retry_after = rate_limit_result.retry_after.whole_seconds();
if rate_limit_result.retry_after.subsec_milliseconds() > 0 {
retry_after += 1
}
let mut reset_after = rate_limit_result.reset_after.whole_seconds();
if rate_limit_result.reset_after.subsec_milliseconds() > 0 {
reset_after += 1
}

Ok(ThrottleResult {
throttled,
limit: rate_limit_result.limit as u64,
remaining: rate_limit_result.remaining as u64,
reset_after: Duration::from_secs(reset_after.max(0) as u64),
retry_after: if retry_after == -1 {
None
} else {
Some(Duration::from_secs(retry_after.max(0) as u64))
},
})
}

#[cfg(feature = "impl")]
async fn redis_throttle(
conn: RedisConnection,
key: &str,
limit: u64,
period: Duration,
max_burst: u64,
quantity: Option<u64>,
) -> Result<ThrottleResult, Error> {
let mut cmd = Cmd::new();
cmd.arg("CL.THROTTLE")
.arg(key)
.arg(max_burst)
.arg(limit)
.arg(period.as_secs())
.arg(quantity.unwrap_or(1));
let result = conn.query(cmd).await?;
let result = <Vec<i64> as FromRedisValue>::from_redis_value(&result)?;

Ok(ThrottleResult {
throttled: result[0] != 0,
limit: result[1] as u64,
remaining: result[2] as u64,
retry_after: match result[3] {
n if n < 0 => None,
n => Some(Duration::from_secs(n as u64)),
},
reset_after: Duration::from_secs(result[4].max(0) as u64),
})
}

/// It is very important for `key` to be used with the same `limit`,
/// `period` and `max_burst` values in order to produce meaningful
/// results.
///
/// This interface cannot detect or report that kind of misuse.
/// It is recommended that those parameters be encoded into the
/// key to make it impossible to misuse.
///
/// * `limit` - specifies the maximum number of tokens allow
/// over the specified `period`
/// * `period` - the time period over which `limit` is allowed.
/// * `max_burst` - the maximum initial burst that will be permitted.
/// set this smaller than `limit` to prevent using
/// up the entire budget immediately and force it
/// to spread out across time.
/// * `quantity` - how many tokens to add to the throttle. If omitted,
/// 1 token is added.
/// * `force_local` - if true, always use the in-memory store on the local
/// machine even if the redis backend has been configured.
#[cfg(feature = "impl")]
pub async fn throttle(
key: &str,
limit: u64,
period: Duration,
max_burst: u64,
quantity: Option<u64>,
force_local: bool,
) -> Result<ThrottleResult, Error> {
if force_local {
local_throttle(key, limit, period, max_burst, quantity)
} else if let Some(redis) = REDIS.get().cloned() {
redis_throttle(redis, key, limit, period, max_burst, quantity).await
} else {
local_throttle(key, limit, period, max_burst, quantity)
}
}

#[cfg(feature = "impl")]
pub fn use_redis(conn: RedisConnection) -> Result<(), Error> {
REDIS
.set(conn)
.map_err(|_| Error::Generic("redis already configured for throttles".to_string()))?;
Ok(())
}

#[cfg(feature = "impl")]
#[cfg(test)]
mod test {
use super::*;

fn test_big_limits(limit: u64, max_burst: Option<u64>, permitted_tolerance: f64) {
let period = Duration::from_secs(60);
let max_burst = max_burst.unwrap_or(limit);
let key = format!("test_big_limits-{limit}-{max_burst}");

let mut throttled_iter = None;

for i in 0..limit * 2 {
let result = local_throttle(&key, limit, period, max_burst, None).unwrap();
if result.throttled {
println!("iter: {i} -> {result:?}");
throttled_iter.replace(i);
break;
}
}

let throttled_iter = throttled_iter.expect("to hit the throttle limit");
let diff = ((max_burst as f64) - (throttled_iter as f64)).abs();
let tolerance = (max_burst as f64) * permitted_tolerance;
println!(
"throttled after {throttled_iter} iterations for \
limit {limit}. diff={diff}. tolerance {tolerance}"
);
let max_rate = Rate::per_period(limit as i64, time::Duration::try_from(period).unwrap());
println!("max_rate: {max_rate:?}");

assert!(
diff < tolerance,
"throttled after {throttled_iter} iterations for \
limit {limit}. diff={diff} is not within tolerance {tolerance}"
);
}

#[test]
fn basic_throttle_100() {
test_big_limits(100, None, 0.01);
}

#[test]
fn basic_throttle_1_000() {
test_big_limits(1_000, None, 0.02);
}

#[test]
fn basic_throttle_6_000() {
test_big_limits(6_000, None, 0.02);
}

#[test]
fn basic_throttle_60_000() {
test_big_limits(60_000, None, 0.05);
}

#[test]
fn basic_throttle_60_000_burst_30k() {
// Note that the 5% tolerance here is the same as the basic_throttle_60_000
// test case because the variance is due to timing issues with very small
// time periods produced by the overally limit, rather than the burst.
test_big_limits(60_000, Some(30_000), 0.05);
}

#[test]
fn throttle_spec_parse() {
assert_eq!(
Expand Down
Loading