Skip to content

Commit

Permalink
throttle: add gcra_throttle() mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Sep 19, 2024
1 parent 430515f commit 1bab348
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 11 deletions.
2 changes: 1 addition & 1 deletion crates/kumo-server-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ pub fn register(lua: &Lua) -> anyhow::Result<()> {
let key: RedisConnKey = from_lua_value(lua, params)?;
let conn = key.open().map_err(any_err)?;
conn.ping().await.map_err(any_err)?;
throttle::use_redis(conn).map_err(any_err)
throttle::use_redis(conn).await.map_err(any_err)
})?,
)?;

Expand Down
33 changes: 29 additions & 4 deletions crates/throttle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
//! to support using a redis-cell equipped redis server to share the throttles
//! among multiple machines.
#[cfg(feature = "redis")]
use mod_redis::{RedisConnection, RedisError};
use mod_redis::{Cmd, RedisConnection, RedisError};

#[cfg(feature = "redis")]
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::ops::Deref;
use std::time::Duration;
use thiserror::Error;

Expand All @@ -17,8 +18,21 @@ pub mod limit;
#[cfg(feature = "redis")]
mod throttle;

#[derive(Clone)]
struct RedisContext {
connection: RedisConnection,
cl_throttle_command: bool,
}

impl Deref for RedisContext {
type Target = RedisConnection;
fn deref(&self) -> &Self::Target {
&self.connection
}
}

#[cfg(feature = "redis")]
static REDIS: OnceCell<RedisConnection> = OnceCell::new();
static REDIS: OnceCell<RedisContext> = OnceCell::new();

#[derive(Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -184,9 +198,20 @@ pub struct ThrottleResult {
}

#[cfg(feature = "redis")]
pub fn use_redis(conn: RedisConnection) -> Result<(), Error> {
pub async fn use_redis(conn: RedisConnection) -> Result<(), Error> {
let mut cmd = Cmd::new();
cmd.arg("COMMAND").arg("INFO").arg("CL.THROTTLE");

let cl_throttle_command = match conn.query(cmd).await {
Ok(_) => true,
Err(_) => false,
};

REDIS
.set(conn)
.set(RedisContext {
connection: conn,
cl_throttle_command,
})
.map_err(|_| Error::Generic("redis already configured for throttles".to_string()))?;
Ok(())
}
Expand Down
7 changes: 4 additions & 3 deletions crates/throttle/src/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ enum Backend {
impl LimitSpec {
pub async fn acquire_lease<S: AsRef<str>>(&self, key: S) -> Result<LimitLease, Error> {
if let Some(redis) = REDIS.get().cloned() {
self.acquire_lease_redis(redis, key.as_ref()).await
self.acquire_lease_redis(redis.connection, key.as_ref())
.await
} else {
self.acquire_lease_memory(key.as_ref()).await
}
Expand Down Expand Up @@ -140,7 +141,7 @@ impl LimitLease {
Backend::Memory => self.release_memory().await,
Backend::Redis => {
if let Some(redis) = REDIS.get().cloned() {
self.release_redis(redis).await;
self.release_redis(redis.connection).await;
} else {
eprintln!("LimitLease::release: backend is Redis but REDIS is not set");
}
Expand All @@ -153,7 +154,7 @@ impl LimitLease {
Backend::Memory => self.extend_memory(duration).await,
Backend::Redis => {
if let Some(redis) = REDIS.get().cloned() {
self.extend_redis(redis, duration).await
self.extend_redis(redis.connection, duration).await
} else {
Err(anyhow::anyhow!(
"LimitLease::extend: backend is Redis but REDIS is not set"
Expand Down
108 changes: 105 additions & 3 deletions crates/throttle/src/throttle.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,74 @@
use crate::{Error, ThrottleResult, REDIS};
use mod_redis::{Cmd, FromRedisValue, RedisConnection};
use once_cell::sync::OnceCell;
use anyhow::Context;
use mod_redis::{Cmd, FromRedisValue, RedisConnection, Script};
use once_cell::sync::{Lazy, OnceCell};
use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
use std::sync::Mutex;
use std::time::Duration;

static MEMORY: OnceCell<Mutex<MemoryStore>> = OnceCell::new();

// Adapted from https://github.com/Losant/redis-gcra/blob/master/lib/gcra.lua
static GCRA_SCRIPT: Lazy<Script> = Lazy::new(|| {
Script::new(
r#"
local key = KEYS[1]
local limit = ARGV[2]
local period = ARGV[3]
local max_burst = ARGV[4]
local quantity = ARGV[5]
local interval = period / limit
local increment = interval * quantity
local burst_offset = interval * burst
local time = tonumber(redis.call("TIME")[1])
local tat = redis.call("GET", key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local throttled
local reset_after
local retry_after
local remaining = math.floor(diff / interval) -- poor man's round
if remaining < 0 then
throttled = 1
-- calculate how many tokens there actually are, since
-- remaining is how many there would have been if we had been able to limit
-- and we did not limit
remaining = math.floor((now - (tat - burst_offset)) / interval)
reset_after = math.ceil(tat - now)
retry_after = math.ceil(diff * -1)
elseif remaining == 0 and increment <= 0 then
-- request with cost of 0
-- cost of 0 with remaining 0 is still limited
throttled = 1
remaining = 0
reset_after = math.ceil(tat - now)
retry_in = 0 -- retry_in is meaningless when quantity is 0
else
throttled = 0
reset_after = math.ceil(new_tat - now)
retry_after = 0
redis.call("SET", key, new_tat, "PX", reset_after)
end
return {throttled, remaining, reset_after, retry_after, tostring(diff), tostring(interval)}
"#,
)
});

fn local_throttle(
key: &str,
limit: u64,
Expand Down Expand Up @@ -89,6 +151,41 @@ async fn redis_throttle(
})
}

async fn gcra_throttle(
conn: RedisConnection,
key: &str,
limit: u64,
period: Duration,
max_burst: u64,
quantity: Option<u64>,
) -> Result<ThrottleResult, Error> {
let mut script = GCRA_SCRIPT.prepare_invoke();
script
.key(key)
.arg(limit)
.arg(period.as_secs())
.arg(max_burst)
.arg(quantity.unwrap_or(1));

let result = conn
.invoke_script(script)
.await
.context("error invoking redis GCRA script")?;
let result =
<(u64, u64, u64, u64, String, String) as FromRedisValue>::from_redis_value(&result)?;

Ok(ThrottleResult {
throttled: result.0 == 1,
limit: max_burst + 1,
remaining: result.1,
reset_after: Duration::from_secs(result.2),
retry_after: match result.3 {
0 => None,
n => Some(Duration::from_secs(n as u64)),
},
})
}

/// It is very important for `key` to be used with the same `limit`,
/// `period` and `max_burst` values in order to produce meaningful
/// results.
Expand Down Expand Up @@ -120,7 +217,12 @@ pub async fn throttle(
if force_local {
local_throttle(key, limit, period, max_burst, quantity)
} else if let Some(redis) = REDIS.get().cloned() {
redis_throttle(redis, key, limit, period, max_burst, quantity).await
match redis.cl_throttle_command {
true => gcra_throttle(redis.connection, key, limit, period, max_burst, quantity).await,
false => {
redis_throttle(redis.connection, key, limit, period, max_burst, quantity).await
}
}
} else {
local_throttle(key, limit, period, max_burst, quantity)
}
Expand Down

0 comments on commit 1bab348

Please sign in to comment.