diff --git a/.gitignore b/.gitignore index d1bd249..31cf8db 100644 --- a/.gitignore +++ b/.gitignore @@ -26,9 +26,11 @@ Thumbs.db # Logs *.log -# Note: Cargo.lock is intentionally NOT ignored. -# Because this project contains an executable (server-native), +# Note: Cargo.lock is intentionally NOT ignored. +# Because this project contains an executable (server-native), # committing Cargo.lock is best practice to ensure reproducible builds! PLAN.md NOTES.md + +scripts/ diff --git a/CHANGELOG.md b/CHANGELOG.md index d5d5a43..2e404d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,86 @@ All notable changes to Recached are documented here. --- +## [0.1.7] — 2026-06-12 + +### Fixed + +**Security** +- Replication auth password was compared with `!=` (byte-by-byte), leaking timing information an attacker could use to brute-force the password one character at a time. Replaced with a constant-time XOR-fold comparison. (`server-native/src/main.rs`) +- The client `AUTH` command compared the supplied password with `==` (`String` equality, short-circuiting on the first mismatched byte) — the same timing side-channel the replication path was already hardened against. `process_auth` now uses the constant-time comparison. (`server-native/src/main.rs`) +- Replication frame length prefixes (snapshot and per-command) were read and allocated without an upper bound. Because the replication port may be unauthenticated and plaintext, a peer or MITM could send a 4 GB length prefix and force a matching allocation per frame (memory DoS). Frames are now capped at 512 MB before allocation. (`server-native/src/main.rs`) +- `auth()` in the WASM SDK now emits a `console.warn` when the active connection is an unencrypted `ws://` URL, alerting developers that the password is sent in plaintext. Production deployments should use `wss://`. 
(`wasm-edge/src/lib.rs`) + +**Correctness** +- WebSocket `connect()` followed by `auth()` raced the socket handshake: `createCache` sent `AUTH` (and any early `set`/`del`/`subscribe`/`publish`) while the socket was still `CONNECTING`, so the frames were silently dropped — server sync was completely broken whenever `RECACHED_PASSWORD` was set, and early writes were lost otherwise. Commands issued before the socket opens are now buffered and flushed in FIFO order by an `onopen` handler. (`wasm-edge/src/lib.rs`, `wasm-edge/sdk.ts`) +- `MULTI`/`EXEC` did not honour `WATCH`: a watched key changing before `EXEC` did not abort the transaction, so the standard Redis optimistic-locking (compare-and-swap) pattern silently lost updates. `EXEC` now returns a nil array when any watched key has changed since `WATCH`; `EXEC` and `DISCARD` clear all watches; and `WATCH`/`UNWATCH` inside `MULTI` are rejected. Works over both the TCP and WebSocket ports. (`server-native/src/main.rs`) +- AOF replay restored nothing on the live server. Writes are recorded via `on_write` in RESP3 Push (`>`) form, but `replay_aof` passed parsed frames straight to `Command::from_value`, which only accepts arrays — so every replayed frame was rejected and skipped (the existing test masked this by feeding `*`-array frames). Replay now normalises Push→Array, matching the replica stream path. (`server-native/src/main.rs`) +- `SPOP` and `SRANDMEMBER` returned members in `HashMap` iteration order rather than randomly, and positive `SRANDMEMBER count` was non-random while negative count was fully deterministic (`members[i % len]`). All now sample randomly, matching Redis. (`core-engine/src/store.rs`) +- `allkeys-lru` / `volatile-lru` eviction ranked entries by last *write* time and never updated it on reads, so a hot, frequently-read key could be evicted as if it were cold. 
Entries now carry an atomic last-access timestamp refreshed on the main read paths (`GET`, `MGET`, `HGET`/`HGETALL`, `LRANGE`, `SMEMBERS`, `SISMEMBER`, `ZSCORE`, and the sorted-set range reads), giving true access-based LRU. (`core-engine/src/store.rs`) +- `SCAN` ignored its `COUNT` argument and returned the entire matching keyspace in one reply at cursor `0`, defeating its purpose as the non-blocking alternative to `KEYS`. It now returns at most `COUNT` keys per call (default 10) with a real next-cursor for incremental iteration. (`core-engine/src/store.rs`) +- A read-only replica applied writes streamed from the primary but never re-broadcast them, so the replica's own WebSocket clients received no live updates and multi-tier (chained) replication was impossible. Replicas now relay each applied write to their local WebSocket clients and run a replication server so they can serve sub-replicas. (`server-native/src/main.rs`) +- Local writes through the SDK fired the mutation callback twice — once from the Rust layer and again in `sdk.ts` — causing redundant `useSyncExternalStore` re-renders. The duplicate notification was removed. (`wasm-edge/sdk.ts`) +- `SRANDMEMBER key -N` panicked with a divide-by-zero when the target key did not exist or the set was empty. An early-return guard now produces an empty array, matching Redis semantics. (`core-engine/src/store.rs`) +- `ZINCRBY` did not validate the resulting score for NaN or Infinity before writing it into the sorted set, corrupting subsequent range queries when called with `+inf`/`-inf` deltas. The result is now pre-computed and rejected with `ERR increment would produce NaN or Infinity` if invalid — consistent with `HINCRBYFLOAT`. (`core-engine/src/store.rs`) +- `DECRBY` used `extract_string` (no size limit) for key parsing while `INCRBY` used `extract_key` (≤ 512 KB). Keys larger than 512 KB sent via `DECRBY` now return an error, consistent with all other key-bearing commands. 
(`core-engine/src/cmd.rs`) +- `SET … EX ` with a TTL value large enough to overflow `u64` when multiplied by 1 000 silently saturated to `u64::MAX`, making the key effectively immortal. Such values now return `ERR TTL overflow`. (`core-engine/src/store.rs`) +- Vue `useKey` and `useKeyJSON` read the initial value before subscribing to mutations, leaving a narrow window where a write between `get()` and `onMutation()` was missed. The subscription is now registered first, then the initial value is read. (`recached-vue/src/useKey.ts`) +- React `usePubSub` captured the `handler` closure at subscribe time and held it for the lifetime of the effect. Inline handlers (redefined each render) would go stale and never receive updated closure state. The hook now stores the latest handler in a `useRef` and calls through it — no re-subscribe needed when the handler changes. (`recached-react/src/usePubSub.ts`) + +**Performance** +- Memory-limit eviction (`RECACHED_MAX_MEMORY`) was O(N²): `try_evict_for_memory` re-scanned the entire keyspace to recompute total memory after every single eviction, on the 1-second background sweep — stalling the server under exactly the memory pressure it was meant to relieve. It now measures total memory once and maintains it incrementally by subtracting each evicted entry's measured size, with a periodic re-sync to correct drift. (`core-engine/src/store.rs`) + +**DoS / resilience** +- The RESP array parser bounded each bulk string at 64 MB but applied no limit to the cumulative size of an array's elements, making it possible to stream 1 million strings of up to 64 MB each and force ~64 TB of cumulative allocation before rejection. A 64 MB cumulative-bytes check is now applied across the entire array parse loop. (`core-engine/src/resp.rs`) +- The RESP3 Push (`>`) parser lacked the cumulative-size guard the array parser has, so the replica and AOF parse paths would accept arbitrarily large push frames. The same 64 MB cumulative check is now applied. 
(`core-engine/src/resp.rs`) +- The glob matcher used by `KEYS` and `SCAN` was a recursive function with no memoization or depth limit. Patterns such as `*.*.*.*x` against a long non-matching string caused exponential backtracking (ReDoS). The implementation is replaced with an iterative two-row DP algorithm that is strictly O(m × n). (`core-engine/src/store.rs`) + +**Portability** +- On non-Unix platforms the server bound one listener socket per CPU core to the same port, relying on `SO_REUSEPORT` (Unix-only). Without it the second bind failed and the process exited at startup. Non-Unix builds now fall back to a single accept loop. (`server-native/src/main.rs`) + +**Resource management** +- Calling `connect()` a second time on a `RecachedCache` instance replaced the internal `WebSocket` field without closing the previous socket. The old connection remained open, receiving stale messages. `connect()` now calls `.close()` on the existing socket before creating the new one. (`wasm-edge/src/lib.rs`) + +### Added + +- **`RECACHED_BIND`** — new env var controlling the network interface every listener (TCP, WebSocket, replication, metrics) binds to. Defaults to `0.0.0.0` for backwards compatibility; set `127.0.0.1` (or a specific private interface) to keep the server off public interfaces. A startup warning is logged when bound to all interfaces. (`server-native/src/main.rs`) +- **`WATCH` / `UNWATCH` over TCP** — optimistic-lock (CAS) `WATCH` is now available on the RESP/TCP port (6379), not just WebSocket. TCP clients receive no keychange push (it would break the request/response protocol); they use `WATCH` purely for the `EXEC` abort guarantee. (`server-native/src/main.rs`) + +### Changed + +- `WATCH`/`UNWATCH` semantics: previously WebSocket-only "observable keys" with no transactional effect, they now provide Redis-compatible optimistic locking on both transports. Over WebSocket, `WATCH` additionally pushes live keychange notifications as before. 
The store no longer returns `ERR WATCH/UNWATCH only supported over WebSocket`. (`server-native/src/main.rs`, `docs/server/commands.md`) +- The replication server now runs on every node, including replicas, so a replica can in turn serve sub-replicas (multi-tier replication). (`server-native/src/main.rs`) +- The `Entry` struct's `written_at_ms` field was replaced with an atomic `last_access_ms`, refreshed on reads, to back access-based LRU eviction. (`core-engine/src/store.rs`) +- Documented that WebSocket uses text frames, so values must be valid UTF-8 (non-UTF-8 bytes are replaced lossily); raw binary values are fully round-trippable only over the TCP port. The SDK's string-typed `set` API is unaffected. (`server-native/src/main.rs`) + +--- + +## [0.1.6] — 2026-05-11 + +### Fixed + +**Correctness** +- `TTL` / `PTTL`: replaced `exp - now` with `exp.saturating_sub(now)` — a tight race between the expiry check and the subtraction could panic in debug builds or wrap to `u64::MAX` in release, returning a wildly incorrect TTL. (`core-engine/src/store.rs`) +- `DEL` / `UNLINK`: switched from `data.remove(k)` to `data.remove_if(k, |_, e| !e.is_expired(now))` — expired-but-not-yet-swept keys were counted as deleted, violating Redis semantics which returns 0 for missing/expired keys. (`core-engine/src/store.rs`) +- `ZADD GT` / `LT` flags were parsed and silently discarded. They are now fully enforced: `GT` updates an existing member only if the new score is greater; `LT` only if lower; new members are always inserted regardless of the flag. Incompatible combinations (`GT`+`LT`, `GT`/`LT`+`NX`) return errors matching Redis. (`core-engine/src/cmd.rs`, `core-engine/src/store.rs`) + +**Security** +- `Command::Auth` reached `store.execute()` and unconditionally returned `+OK`, bypassing authentication during AOF replay and any other path that calls the store directly. 
`store.execute()` now returns an error for `Auth` — authentication is handled exclusively by the connection-layer `process_auth` function. (`core-engine/src/store.rs`) + +**Performance / reliability** +- `PubSubHub::unsubscribe` left an empty `Vec` in `channel_subs` after the last subscriber left a channel. Over time, high-churn subscriber patterns leaked memory proportional to the total number of unique channels ever seen. Empty entries are now removed immediately in `unsubscribe`, `unsubscribe_all`, and `publish`. (`server-native/src/main.rs`) +- `SharedPubSub` and `WatchRegistry` used `std::sync::Mutex` (blocking) in async connection handlers. Holding a blocking lock across `.await` points starves the Tokio thread pool under high pub/sub publish rates. Both types now use `tokio::sync::Mutex`; `notify_watchers` is now `async`. (`server-native/src/main.rs`) + +### Added + +- Key-length validation in `Command::from_value`: keys larger than 512 KB or empty keys are rejected at parse time with a descriptive `ERR` before reaching the store. Validation is applied to all primary-key positions in `GET`, `SET`, `DEL`, `UNLINK`, `MGET`, `MSET`, `EXISTS`, `APPEND`, `STRLEN`, `GETSET`, `SETNX`, `SETEX`, `PSETEX`, `INCR`, `DECR`, `INCRBY`, and all commands that assign `let key = …`. (`core-engine/src/cmd.rs`) + +### Changed + +- `format_score` (f64 → Redis score string) is now `pub` and exported from `core-engine::store`. The identical private `format_zset_score` function in `wasm-edge` has been removed in favour of the shared implementation. 
(`core-engine/src/store.rs`, `wasm-edge/src/lib.rs`) + +--- + ## [0.1.5] — 2026-05-10 ### Added diff --git a/Cargo.lock b/Cargo.lock index d4e5ec8..9cca6f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,7 +107,7 @@ dependencies = [ [[package]] name = "core-engine" -version = "0.1.5" +version = "0.1.7" dependencies = [ "dashmap", "rand", @@ -367,6 +367,12 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "http" version = "1.4.0" @@ -464,7 +470,7 @@ dependencies = [ "hyper", "libc", "pin-project-lite", - "socket2", + "socket2 0.6.3", "tokio", "tower-service", "tracing", @@ -632,6 +638,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -975,15 +991,18 @@ dependencies = [ [[package]] name = "server-native" -version = "0.1.5" +version = "0.1.7" dependencies = [ "core-engine", "futures-util", "metrics", "metrics-exporter-prometheus", + "num_cpus", "rmp-serde", "rustls-pemfile", "serde", + "socket2 0.5.10", + "tikv-jemallocator", "tokio", "tokio-rustls", "tokio-tungstenite", @@ -1045,6 +1064,16 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = 
[ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.3" @@ -1121,6 +1150,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + [[package]] name = "tokio" version = "1.52.1" @@ -1133,7 +1182,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.6.3", "tokio-macros", "windows-sys 0.61.2", ] @@ -1384,7 +1433,7 @@ dependencies = [ [[package]] name = "wasm-edge" -version = "0.1.5" +version = "0.1.7" dependencies = [ "core-engine", "getrandom 0.3.4", diff --git a/Cargo.toml b/Cargo.toml index 1283898..28ad792 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ resolver = "2" # ── Single source of truth for all crate versions ──────────────────────────── # Members inherit with: version.workspace = true / edition.workspace = true [workspace.package] -version = "0.1.5" +version = "0.1.7" edition = "2024" license = "MIT" authors = ["ThinkGrid Labs"] @@ -44,3 +44,15 @@ rmp-serde = "1" wasm-bindgen = "0.2.92" js-sys = "0.3.69" web-sys = "0.3.69" + +# perf +socket2 = { version = "0.5", features = ["all"] } +num_cpus = "1" +tikv-jemallocator = "0.6" + +# ── Release profile ─────────────────────────────────────────────────────────── +[profile.release] +opt-level = 3 +lto = "thin" +codegen-units = 1 +strip = "symbols" diff --git a/core-engine/src/cmd.rs b/core-engine/src/cmd.rs index ca68b1f..c2ff45a 100644 --- a/core-engine/src/cmd.rs +++ b/core-engine/src/cmd.rs 
@@ -35,6 +35,8 @@ pub enum ZAddCondition { #[derive(Debug, Clone, PartialEq, Default)] pub struct ZAddOptions { pub condition: Option<ZAddCondition>, + pub gt: bool, + pub lt: bool, pub ch: bool, pub incr: bool, } @@ -198,7 +200,7 @@ impl Command { // ── Strings ─────────────────────────────────────────────── "SET" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let val = extract_string(&arr[2]).unwrap_or_default(); let mut opts = SetOptions::default(); let mut i = 3usize; @@ -271,43 +273,37 @@ impl Command { } "GET" => { need!(2); - Ok(Command::Get(extract_string(&arr[1]).unwrap_or_default())) + Ok(Command::Get(extract_key(&arr[1])?)) } "DEL" => { need!(2); - Ok(Command::Del( - arr[1..].iter().filter_map(extract_string).collect(), - )) + Ok(Command::Del(extract_keys(&arr[1..])?)) } "UNLINK" => { need!(2); - Ok(Command::Unlink( - arr[1..].iter().filter_map(extract_string).collect(), - )) + Ok(Command::Unlink(extract_keys(&arr[1..])?)) } "APPEND" => { need!(3); Ok(Command::Append( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, extract_string(&arr[2]).unwrap_or_default(), )) } "STRLEN" => { need!(2); - Ok(Command::Strlen(extract_string(&arr[1]).unwrap_or_default())) + Ok(Command::Strlen(extract_key(&arr[1])?)) } "GETSET" => { need!(3); Ok(Command::GetSet( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, extract_string(&arr[2]).unwrap_or_default(), )) } "MGET" => { need!(2); - Ok(Command::MGet( - arr[1..].iter().filter_map(extract_string).collect(), - )) + Ok(Command::MGet(extract_keys(&arr[1..])?)) } "MSET" => { if arr.len() < 3 || (arr.len() - 1) % 2 != 0 { @@ -318,18 +314,18 @@ impl Command { let pairs = arr[1..] 
.chunks(2) .map(|c| { - ( - extract_string(&c[0]).unwrap_or_default(), + Ok(( + extract_key(&c[0])?, extract_string(&c[1]).unwrap_or_default(), - ) + )) }) - .collect(); + .collect::<Result<Vec<_>, String>>()?; Ok(Command::MSet(pairs)) } "SETNX" => { need!(3); Ok(Command::SetNx( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, extract_string(&arr[2]).unwrap_or_default(), )) } @@ -340,7 +336,7 @@ impl Command { return Err("ERR invalid expire time in 'setex' command".to_string()); } Ok(Command::SetEx( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, secs as u64, extract_string(&arr[3]).unwrap_or_default(), )) @@ -352,30 +348,30 @@ impl Command { return Err("ERR invalid expire time in 'psetex' command".to_string()); } Ok(Command::PSetEx( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, ms as u64, extract_string(&arr[3]).unwrap_or_default(), )) } "INCR" => { need!(2); - Ok(Command::Incr(extract_string(&arr[1]).unwrap_or_default())) + Ok(Command::Incr(extract_key(&arr[1])?)) } "DECR" => { need!(2); - Ok(Command::Decr(extract_string(&arr[1]).unwrap_or_default())) + Ok(Command::Decr(extract_key(&arr[1])?)) } "INCRBY" => { need!(3); Ok(Command::IncrBy( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, extract_int(&arr[2])?, )) } "DECRBY" => { need!(3); Ok(Command::DecrBy( - extract_string(&arr[1]).unwrap_or_default(), + extract_key(&arr[1])?, extract_int(&arr[2])?, )) } @@ -445,9 +441,7 @@ impl Command { // ── Keys ─────────────────────────────────────────────────── "EXISTS" => { need!(2); - Ok(Command::Exists( - arr[1..].iter().filter_map(extract_string).collect(), - )) + Ok(Command::Exists(extract_keys(&arr[1..])?)) } "KEYS" => { need!(2); @@ -509,7 +503,7 @@ impl Command { cmd_name.to_lowercase() )); } - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let pairs = arr[2..] 
.chunks(2) .map(|c| { @@ -536,7 +530,7 @@ impl Command { } "HDEL" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let fields = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::HDel(key, fields)) } @@ -586,7 +580,7 @@ impl Command { } "HMGET" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let fields = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::HMGet(key, fields)) } @@ -594,31 +588,31 @@ impl Command { // ── List ─────────────────────────────────────────────────── "LPUSH" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let vals = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::LPush(key, vals)) } "RPUSH" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let vals = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::RPush(key, vals)) } "LPUSHX" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let vals = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::LPushX(key, vals)) } "RPUSHX" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let vals = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::RPushX(key, vals)) } "LPOP" => { need!(2); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let count = if arr.len() > 2 { Some(extract_int(&arr[2])? as u64) } else { @@ -628,7 +622,7 @@ impl Command { } "RPOP" => { need!(2); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let count = if arr.len() > 2 { Some(extract_int(&arr[2])? 
as u64) } else { @@ -683,7 +677,7 @@ impl Command { // ── Set ──────────────────────────────────────────────────── "SADD" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let members = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::SAdd(key, members)) } @@ -695,7 +689,7 @@ impl Command { } "SREM" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let members = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::SRem(key, members)) } @@ -712,7 +706,7 @@ impl Command { } "SMISMEMBER" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let members = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::SMIsMember(key, members)) } @@ -754,7 +748,7 @@ impl Command { } "SPOP" => { need!(2); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let count = if arr.len() > 2 { Some(extract_int(&arr[2])? as u64) } else { @@ -764,7 +758,7 @@ impl Command { } "SRANDMEMBER" => { need!(2); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let count = if arr.len() > 2 { Some(extract_int(&arr[2])?) 
} else { @@ -784,7 +778,7 @@ impl Command { // ── Sorted Set ───────────────────────────────────────────── "ZADD" => { need!(4); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let mut opts = ZAddOptions::default(); let mut i = 2usize; @@ -800,8 +794,13 @@ impl Command { opts.condition = Some(ZAddCondition::Xx); i += 1; } - "GT" | "LT" => { - i += 1; // recognised but not yet enforced + "GT" => { + opts.gt = true; + i += 1; + } + "LT" => { + opts.lt = true; + i += 1; } "CH" => { opts.ch = true; @@ -827,6 +826,19 @@ impl Command { i += 2; } + if opts.gt && opts.lt { + return Err( + "ERR GT and LT options at the same time are not compatible" + .to_string(), + ); + } + if (opts.gt || opts.lt) && opts.condition == Some(ZAddCondition::Nx) { + return Err( + "ERR GT, LT, and NX options at the same time are not compatible" + .to_string(), + ); + } + if opts.incr && pairs.len() != 1 { return Err("ERR INCR option supports a single increment-element pair" .to_string()); @@ -836,7 +848,7 @@ impl Command { } "ZRANGE" => { need!(4); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let start = extract_int(&arr[2])?; let stop = extract_int(&arr[3])?; let withscores = arr @@ -848,7 +860,7 @@ impl Command { } "ZREVRANGE" => { need!(4); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let start = extract_int(&arr[2])?; let stop = extract_int(&arr[3])?; let withscores = arr @@ -860,7 +872,7 @@ impl Command { } "ZRANGEBYSCORE" => { need!(4); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let min = extract_string(&arr[2]).unwrap_or_default(); let max = extract_string(&arr[3]).unwrap_or_default(); let (withscores, limit) = parse_zrange_opts(&arr[4..])?; @@ -868,7 +880,7 @@ impl Command { } "ZREVRANGEBYSCORE" => { need!(4); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; 
let max = extract_string(&arr[2]).unwrap_or_default(); let min = extract_string(&arr[3]).unwrap_or_default(); let (withscores, limit) = parse_zrange_opts(&arr[4..])?; @@ -883,7 +895,7 @@ impl Command { } "ZMSCORE" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let members = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::ZMScore(key, members)) } @@ -903,7 +915,7 @@ impl Command { } "ZREM" => { need!(3); - let key = extract_string(&arr[1]).unwrap_or_default(); + let key = extract_key(&arr[1])?; let members = arr[2..].iter().filter_map(extract_string).collect(); Ok(Command::ZRem(key, members)) } @@ -1000,6 +1012,32 @@ impl Command { // ── helpers ─────────────────────────────────────────────────────────────────── +const MAX_KEY_BYTES: usize = 512 * 1024; // 512 KB + +fn validate_key(key: &str) -> Result<(), String> { + if key.is_empty() { + return Err("ERR key cannot be empty".to_string()); + } + if key.len() > MAX_KEY_BYTES { + return Err(format!( + "ERR key too large ({} > {} bytes)", + key.len(), + MAX_KEY_BYTES + )); + } + Ok(()) +} + +fn extract_key(val: &Value) -> Result<String, String> { + let key = extract_string(val).unwrap_or_default(); + validate_key(&key)?; + Ok(key) +} + +fn extract_keys(vals: &[Value]) -> Result<Vec<String>, String> { + vals.iter().map(extract_key).collect() +} + fn extract_string(val: &Value) -> Option<String> { match val { Value::BulkString(Some(data)) => Some(String::from_utf8_lossy(data).into_owned()), @@ -1225,6 +1263,8 @@ mod tests { "z".into(), ZAddOptions { condition: Some(ZAddCondition::Nx), + gt: false, + lt: false, ch: true, incr: false }, diff --git a/core-engine/src/resp.rs b/core-engine/src/resp.rs index 8cdd2c3..4f1fe6f 100644 --- a/core-engine/src/resp.rs +++ b/core-engine/src/resp.rs @@ -1,6 +1,7 @@ const MAX_ARRAY_DEPTH: usize = 16; const MAX_ARRAY_ELEMENTS: usize = 1_000_000; const MAX_BULK_STRING_BYTES: usize = 64 * 1024 * 1024; // 64 MB +const MAX_TOTAL_MESSAGE_BYTES: usize 
= 64 * 1024 * 1024; // 64 MB total per message #[derive(Debug, Clone, PartialEq)] pub enum Value { @@ -176,6 +177,9 @@ impl Value { let (val, len) = Self::parse_inner(&buffer[offset..], depth + 1)?; arr.push(val); offset += len; + if offset > MAX_TOTAL_MESSAGE_BYTES { + return Err("ERR message too large".to_string()); + } } Ok((Value::Push(arr), offset)) } @@ -211,6 +215,9 @@ impl Value { let (val, len) = Self::parse_inner(&buffer[offset..], depth + 1)?; arr.push(val); offset += len; + if offset > MAX_TOTAL_MESSAGE_BYTES { + return Err("ERR message too large".to_string()); + } } Ok((Value::Array(Some(arr)), offset)) diff --git a/core-engine/src/store.rs b/core-engine/src/store.rs index 4307009..424e245 100644 --- a/core-engine/src/store.rs +++ b/core-engine/src/store.rs @@ -113,11 +113,23 @@ impl EntryValue { // ── Entry ───────────────────────────────────────────────────────────────────── -#[derive(Clone)] struct Entry { value: EntryValue, expires_at_ms: Option<u64>, - written_at_ms: u64, + /// Last time this entry was read or written, in ms. Drives LRU eviction. + /// Atomic so reads can refresh recency while holding only a shared + /// (DashMap read-lock) reference — no writer lock on the GET path. + last_access_ms: AtomicU64, +} + +impl Clone for Entry { + fn clone(&self) -> Self { + Self { + value: self.value.clone(), + expires_at_ms: self.expires_at_ms, + last_access_ms: AtomicU64::new(self.last_access_ms.load(Ordering::Relaxed)), + } + } } impl Entry { @@ -125,7 +137,7 @@ impl Entry { Self { value: EntryValue::Str(value), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), } } @@ -133,13 +145,18 @@ impl Entry { Self { value: EntryValue::Str(value), expires_at_ms: Some(expires_at_ms), - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), } } fn is_expired(&self, now: u64) -> bool { self.expires_at_ms.is_some_and(|exp| now >= exp) } + + /// Mark this entry as just-used so LRU eviction treats it as recent. 
+ fn touch(&self, now: u64) { + self.last_access_ms.store(now, Ordering::Relaxed); + } } // ── resolve list range helpers ──────────────────────────────────────────────── @@ -213,7 +230,7 @@ fn encode_zrange(items: &[(&str, f64)], withscores: bool) -> Value { Value::Array(Some(out)) } -fn format_score(s: f64) -> String { +pub fn format_score(s: f64) -> String { if s == f64::INFINITY { "inf".to_string() } else if s == f64::NEG_INFINITY { @@ -349,35 +366,40 @@ impl KeyValueStore { pub fn approximate_memory_bytes(&self) -> usize { self.data .iter() - .map(|r| { - let val_size = match &r.value().value { - EntryValue::Str(s) => s.len(), - EntryValue::Hash(m) => m.iter().map(|(k, v)| k.len() + v.len()).sum(), - EntryValue::List(l) => l.iter().map(|s| s.len()).sum(), - EntryValue::Set(s) => s.iter().map(|m| m.len()).sum::<usize>(), - EntryValue::ZSet(z) => z.scores.keys().map(|m| m.len() + 8).sum(), - }; - r.key().len() + val_size + 64 - }) + .map(|r| entry_size(r.key(), r.value())) .sum() } /// Evict entries until memory usage is below `max_memory_bytes`, or the /// eviction policy cannot free any more. Returns true if under limit. + /// + /// The total is scanned once up front, then maintained incrementally by + /// subtracting each evicted entry's measured size. The previous version + /// re-scanned the whole keyspace after every single eviction (O(N) per + /// eviction → O(N²) overall) — catastrophic under memory pressure. We + /// periodically re-scan to correct for drift from concurrent writes. 
pub fn try_evict_for_memory(&self) -> bool { let limit = match self.max_memory_bytes { Some(l) => l, None => return true, }; let now = now_ms(); - loop { - if self.approximate_memory_bytes() <= limit { - return true; - } - if !self.evict_one(now) { - return false; + let mut current = self.approximate_memory_bytes(); + let mut since_resync = 0u32; + while current > limit { + match self.evict_one(now) { + None => return false, // policy can't free anything more + Some(freed) => { + current = current.saturating_sub(freed); + since_resync += 1; + if since_resync >= 64 { + current = self.approximate_memory_bytes(); + since_resync = 0; + } + } } } + true } /// Returns the current value of a key for watch push notifications. @@ -404,49 +426,42 @@ impl KeyValueStore { self.data.retain(|_, e| !e.is_expired(now)); } - fn evict_one(&self, now: u64) -> bool { + /// Evict a single entry per the configured policy. Returns the number of + /// bytes freed (`Some`), or `None` if nothing could be evicted. 
+ fn evict_one(&self, now: u64) -> Option<usize> { const SAMPLE: usize = 10; let mut rng = rand::rng(); - match self.eviction_policy { - EvictionPolicy::NoEviction => false, + let chosen: Option<String> = match self.eviction_policy { + EvictionPolicy::NoEviction => None, EvictionPolicy::AllKeysLru => { let sample: Vec<(String, u64)> = self .data .iter() - .map(|r| (r.key().clone(), r.value().written_at_ms)) + .map(|r| { + ( + r.key().clone(), + r.value().last_access_ms.load(Ordering::Relaxed), + ) + }) .choose_multiple(&mut rng, SAMPLE); - match sample.into_iter().min_by_key(|(_, w)| *w) { - Some((k, _)) => { - self.data.remove(&k); - true - } - None => false, - } + sample.into_iter().min_by_key(|(_, w)| *w).map(|(k, _)| k) } EvictionPolicy::AllKeysRandom => { - let key = self.data.iter().map(|r| r.key().clone()).choose(&mut rng); - match key { - Some(k) => { - self.data.remove(&k); - true - } - None => false, - } + self.data.iter().map(|r| r.key().clone()).choose(&mut rng) } EvictionPolicy::VolatileLru => { let sample: Vec<(String, u64)> = self .data .iter() .filter(|r| r.value().expires_at_ms.is_some() && !r.value().is_expired(now)) - .map(|r| (r.key().clone(), r.value().written_at_ms)) + .map(|r| { + ( + r.key().clone(), + r.value().last_access_ms.load(Ordering::Relaxed), + ) + }) .choose_multiple(&mut rng, SAMPLE); - match sample.into_iter().min_by_key(|(_, w)| *w) { - Some((k, _)) => { - self.data.remove(&k); - true - } - None => false, - } + sample.into_iter().min_by_key(|(_, w)| *w).map(|(k, _)| k) } EvictionPolicy::VolatileTtl => { let sample: Vec<(String, u64)> = self @@ -461,15 +476,20 @@ impl KeyValueStore { } }) .choose_multiple(&mut rng, SAMPLE); - match sample.into_iter().min_by_key(|(_, exp)| *exp) { - Some((k, _)) => { - self.data.remove(&k); - true - } - None => false, - } + sample + .into_iter() + .min_by_key(|(_, exp)| *exp) + .map(|(k, _)| k) } - } + }; + let key = chosen?; + // Treat a lost race (key already gone) as a successful eviction that + // freed
nothing, so callers don't spin. + Some( + self.data + .remove(&key) + .map_or(0, |(k, e)| entry_size(&k, &e)), + ) } pub fn snapshot(&self) -> Vec { @@ -518,7 +538,7 @@ impl KeyValueStore { Entry { value, expires_at_ms: e.expires_at_ms, - written_at_ms: now, + last_access_ms: AtomicU64::new(now), }, ); } @@ -531,7 +551,9 @@ impl KeyValueStore { Some(m) => Value::BulkString(Some(m.into_bytes())), None => Value::SimpleString("PONG".to_string()), }, - Command::Auth(_) => Value::SimpleString("OK".to_string()), + Command::Auth(_) => Value::Error( + "ERR AUTH is handled by the connection layer, not the store".to_string(), + ), // ── Strings ─────────────────────────────────────────────────────── Command::Set(key, val, opts) => { @@ -571,14 +593,19 @@ impl KeyValueStore { if let Some(max) = self.max_keys && self.data.len() >= max && !self.data.contains_key(&key) - && !self.evict_one(now) + && self.evict_one(now).is_none() { return Value::Error("ERR max keys limit reached".to_string()); } let expires_at_ms = match &opts.expiry { None => None, - Some(SetExpiry::Ex(s)) => Some(now.saturating_add(s.saturating_mul(1000))), + Some(SetExpiry::Ex(s)) => { + if *s > u64::MAX / 1000 { + return Value::Error("ERR TTL overflow".to_string()); + } + Some(now.saturating_add(s * 1000)) + } Some(SetExpiry::Px(ms)) => Some(now.saturating_add(*ms)), Some(SetExpiry::Exat(ts)) => Some(ts.saturating_mul(1000)), Some(SetExpiry::Pxat(ts_ms)) => Some(*ts_ms), @@ -590,7 +617,7 @@ impl KeyValueStore { Entry { value: EntryValue::Str(val), expires_at_ms, - written_at_ms: now, + last_access_ms: AtomicU64::new(now), }, ); @@ -607,7 +634,10 @@ impl KeyValueStore { let now = now_ms(); match self.data.get(&key) { Some(e) if !e.is_expired(now) => match &e.value { - EntryValue::Str(s) => Value::BulkString(Some(s.clone().into_bytes())), + EntryValue::Str(s) => { + e.touch(now); + Value::BulkString(Some(s.clone().into_bytes())) + } _ => Value::Error(WRONGTYPE.to_string()), }, _ => Value::BulkString(None), @@ 
-615,9 +645,10 @@ impl KeyValueStore { } Command::Del(keys) | Command::Unlink(keys) => { + let now = now_ms(); let count = keys .into_iter() - .filter(|k| self.data.remove(k).is_some()) + .filter(|k| self.data.remove_if(k, |_, e| !e.is_expired(now)).is_some()) .count(); Value::Integer(count as i64) } @@ -672,7 +703,10 @@ impl KeyValueStore { .iter() .map(|k| match self.data.get(k) { Some(e) if !e.is_expired(now) => match &e.value { - EntryValue::Str(s) => Value::BulkString(Some(s.clone().into_bytes())), + EntryValue::Str(s) => { + e.touch(now); + Value::BulkString(Some(s.clone().into_bytes())) + } _ => Value::BulkString(None), }, _ => Value::BulkString(None), @@ -692,7 +726,7 @@ impl KeyValueStore { if new_count > available { let needed = new_count - available; for _ in 0..needed { - if !self.evict_one(now) { + if self.evict_one(now).is_none() { return Value::Error("ERR max keys limit reached".to_string()); } } @@ -713,7 +747,7 @@ impl KeyValueStore { if let Some(max) = self.max_keys && self.data.len() >= max && !self.data.contains_key(&key) - && !self.evict_one(now) + && self.evict_one(now).is_none() { return Value::Error("ERR max keys limit reached".to_string()); } @@ -727,7 +761,7 @@ impl KeyValueStore { if let Some(max) = self.max_keys && self.data.len() >= max && !self.data.contains_key(&key) - && !self.evict_one(now) + && self.evict_one(now).is_none() { return Value::Error("ERR max keys limit reached".to_string()); } @@ -741,7 +775,7 @@ impl KeyValueStore { if let Some(max) = self.max_keys && self.data.len() >= max && !self.data.contains_key(&key) - && !self.evict_one(now) + && self.evict_one(now).is_none() { return Value::Error("ERR max keys limit reached".to_string()); } @@ -771,7 +805,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Integer(-2), Some(e) => match e.expires_at_ms { None => Value::Integer(-1), - Some(exp) => Value::Integer(((exp - now) / 1000) as i64), + Some(exp) => Value::Integer((exp.saturating_sub(now) / 1000) as i64), }, 
} } @@ -783,7 +817,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Integer(-2), Some(e) => match e.expires_at_ms { None => Value::Integer(-1), - Some(exp) => Value::Integer((exp - now) as i64), + Some(exp) => Value::Integer(exp.saturating_sub(now) as i64), }, } } @@ -834,24 +868,38 @@ impl KeyValueStore { Value::Array(Some(keys)) } - Command::Scan(cursor, pattern, _count) => { - if cursor != 0 { - return Value::Array(Some(vec![ - Value::BulkString(Some(b"0".to_vec())), - Value::Array(Some(vec![])), - ])); - } + Command::Scan(cursor, pattern, count) => { let now = now_ms(); let pat = pattern.as_deref().unwrap_or("*"); - let keys: Vec = self + let batch = count.unwrap_or(10).max(1); + // Sort for a stable order so the numeric cursor is a meaningful + // offset across calls and COUNT actually paginates. This is + // O(N log N) per call (like KEYS), but each reply is bounded to + // `batch` keys instead of dumping the whole keyspace at once. + // As with Redis, concurrent inserts/deletes between calls may + // cause a key to be skipped or returned twice. 
+ let mut all: Vec = self .data .iter() .filter(|r| !r.value().is_expired(now) && glob_match(pat, r.key())) - .map(|r| Value::BulkString(Some(r.key().as_bytes().to_vec()))) + .map(|r| r.key().clone()) + .collect(); + all.sort_unstable(); + let start = cursor as usize; + let end = start.saturating_add(batch).min(all.len()); + let page: &[String] = if start < all.len() { + &all[start..end] + } else { + &[] + }; + let next_cursor = if end >= all.len() { 0 } else { end as u64 }; + let out: Vec = page + .iter() + .map(|k| Value::BulkString(Some(k.as_bytes().to_vec()))) .collect(); Value::Array(Some(vec![ - Value::BulkString(Some(b"0".to_vec())), - Value::Array(Some(keys)), + Value::BulkString(Some(next_cursor.to_string().into_bytes())), + Value::Array(Some(out)), ])) } @@ -901,7 +949,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::Hash(HashMap::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::Hash(HashMap::new()); @@ -927,10 +975,12 @@ impl KeyValueStore { None => Value::BulkString(None), Some(e) if e.is_expired(now) => Value::BulkString(None), Some(e) => match &e.value { - EntryValue::Hash(h) => h - .get(&field) - .map(|v| Value::BulkString(Some(v.clone().into_bytes()))) - .unwrap_or(Value::BulkString(None)), + EntryValue::Hash(h) => { + e.touch(now); + h.get(&field) + .map(|v| Value::BulkString(Some(v.clone().into_bytes()))) + .unwrap_or(Value::BulkString(None)) + } _ => Value::Error(WRONGTYPE.to_string()), }, } @@ -943,6 +993,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Array(Some(vec![])), Some(e) => match &e.value { EntryValue::Hash(h) => { + e.touch(now); let mut pairs: Vec<(&str, &str)> = h.iter().map(|(f, v)| (f.as_str(), v.as_str())).collect(); pairs.sort_unstable_by_key(|(f, _)| *f); @@ -1058,7 +1109,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry 
{ value: EntryValue::Hash(HashMap::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::Hash(HashMap::new()); @@ -1108,7 +1159,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::List(VecDeque::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::List(VecDeque::new()); @@ -1130,7 +1181,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::List(VecDeque::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::List(VecDeque::new()); @@ -1239,6 +1290,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Array(Some(vec![])), Some(e) => match &e.value { EntryValue::List(list) => { + e.touch(now); let slice: Vec<&String> = list.iter().collect(); match resolve_range(start, stop, slice.len()) { None => Value::Array(Some(vec![])), @@ -1370,7 +1422,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::Set(HashSet::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::Set(HashSet::new()); @@ -1394,6 +1446,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Array(Some(vec![])), Some(e) => match &e.value { EntryValue::Set(s) => { + e.touch(now); let mut members: Vec<&str> = s.iter().map(|m| m.as_str()).collect(); members.sort_unstable(); Value::Array(Some( @@ -1442,6 +1495,7 @@ impl KeyValueStore { Some(e) if e.is_expired(now) => Value::Integer(0), Some(e) => match &e.value { EntryValue::Set(s) => { + e.touch(now); Value::Integer(if s.contains(&member) { 1 } else { 0 }) } _ => Value::Error(WRONGTYPE.to_string()), @@ -1490,7 
+1544,7 @@ impl KeyValueStore { Entry { value: EntryValue::Set(result), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }, ); Value::Integer(len as i64) @@ -1518,7 +1572,7 @@ impl KeyValueStore { Entry { value: EntryValue::Set(result), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }, ); Value::Integer(len as i64) @@ -1546,7 +1600,7 @@ impl KeyValueStore { Entry { value: EntryValue::Set(result), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }, ); Value::Integer(len as i64) @@ -1560,7 +1614,10 @@ impl KeyValueStore { Some(mut e) => match &mut e.value { EntryValue::Set(s) => { let n = count.unwrap_or(1) as usize; - let popped: Vec = s.iter().take(n).cloned().collect(); + let mut rng = rand::rng(); + // SPOP removes *random* members, not iteration-order ones. + let popped: Vec = + s.iter().cloned().choose_multiple(&mut rng, n); for m in &popped { s.remove(m); } @@ -1597,15 +1654,20 @@ impl KeyValueStore { }, Some(e) => match &e.value { EntryValue::Set(s) => match count { - None => s - .iter() - .next() - .map(|m| Value::BulkString(Some(m.as_bytes().to_vec()))) - .unwrap_or(Value::BulkString(None)), + None => { + let mut rng = rand::rng(); + s.iter() + .choose(&mut rng) + .map(|m| Value::BulkString(Some(m.as_bytes().to_vec()))) + .unwrap_or(Value::BulkString(None)) + } Some(n) if n >= 0 => { - let mut members: Vec<&str> = - s.iter().map(|m| m.as_str()).take(n as usize).collect(); - members.sort_unstable(); + // Positive count: up to n *distinct* random members. + let mut rng = rand::rng(); + let members: Vec<&str> = s + .iter() + .map(|m| m.as_str()) + .choose_multiple(&mut rng, n as usize); Value::Array(Some( members .into_iter() @@ -1614,13 +1676,18 @@ impl KeyValueStore { )) } Some(n) => { - // Negative: allow repetition, return |n| elements + // Negative: allow repetition, return |n| random elements. 
let members: Vec<&str> = s.iter().map(|m| m.as_str()).collect(); + if members.is_empty() { + return Value::Array(Some(vec![])); + } + let mut rng = rand::rng(); let abs = n.unsigned_abs() as usize; Value::Array(Some( (0..abs) - .map(|i| { - let m = members[i % members.len()]; + .map(|_| { + let m = + members.iter().copied().choose(&mut rng).unwrap(); Value::BulkString(Some(m.as_bytes().to_vec())) }) .collect(), @@ -1667,7 +1734,7 @@ impl KeyValueStore { let mut dst_entry = self.data.entry(dst).or_insert_with(|| Entry { value: EntryValue::Set(HashSet::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired_dst { dst_entry.value = EntryValue::Set(HashSet::new()); @@ -1686,7 +1753,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::ZSet(ZSetInner::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::ZSet(ZSetInner::new()); @@ -1756,11 +1823,13 @@ impl KeyValueStore { None => Value::BulkString(None), Some(e) if e.is_expired(now) => Value::BulkString(None), Some(e) => match &e.value { - EntryValue::ZSet(z) => z - .scores - .get(&member) - .map(|s| Value::BulkString(Some(format_score(*s).into_bytes()))) - .unwrap_or(Value::BulkString(None)), + EntryValue::ZSet(z) => { + e.touch(now); + z.scores + .get(&member) + .map(|s| Value::BulkString(Some(format_score(*s).into_bytes()))) + .unwrap_or(Value::BulkString(None)) + } _ => Value::Error(WRONGTYPE.to_string()), }, } @@ -1867,7 +1936,7 @@ impl KeyValueStore { let mut entry = self.data.entry(key).or_insert_with(|| Entry { value: EntryValue::ZSet(ZSetInner::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::ZSet(ZSetInner::new()); @@ -1877,12 +1946,12 @@ impl KeyValueStore { EntryValue::ZSet(z) => z, _ => unreachable!(), }; 
- let score = zset - .scores - .entry(member) - .and_modify(|s| *s += delta) - .or_insert(delta); - let new_score = *score; + let prev_score = zset.scores.get(&member).copied().unwrap_or(0.0); + let new_score = prev_score + delta; + if new_score.is_nan() || new_score.is_infinite() { + return Value::Error("ERR increment would produce NaN or Infinity".to_string()); + } + zset.scores.insert(member, new_score); Value::BulkString(Some(format_score(new_score).into_bytes())) } @@ -1928,6 +1997,20 @@ impl KeyValueStore { // ── Free helpers ────────────────────────────────────────────────────────────── +/// Approximate heap footprint of a single entry: key + value bytes plus a fixed +/// per-entry overhead. Shared by `approximate_memory_bytes` and the eviction +/// loop so both agree on what a key "costs". +fn entry_size(key: &str, e: &Entry) -> usize { + let val_size = match &e.value { + EntryValue::Str(s) => s.len(), + EntryValue::Hash(m) => m.iter().map(|(k, v)| k.len() + v.len()).sum(), + EntryValue::List(l) => l.iter().map(|s| s.len()).sum(), + EntryValue::Set(s) => s.iter().map(|m| m.len()).sum::<usize>(), + EntryValue::ZSet(z) => z.scores.keys().map(|m| m.len() + 8).sum(), + }; + key.len() + val_size + 64 +} + fn incr_by(data: &DashMap, key: String, delta: i64) -> Value { let now = now_ms(); let was_expired = match data.get(&key) { @@ -1973,21 +2056,32 @@ fn set_expiry(data: &DashMap, key: String, ts_ms: u64) -> Value { } fn glob_match(pattern: &str, s: &str) -> bool { - glob_helper(pattern.as_bytes(), s.as_bytes()) -} - -fn glob_helper(pat: &[u8], s: &[u8]) -> bool { - match (pat.first(), s.first()) { - (None, None) => true, - (None, Some(_)) => false, - (Some(b'*'), _) => { - glob_helper(&pat[1..], s) || (!s.is_empty() && glob_helper(pat, &s[1..])) + let pat = pattern.as_bytes(); + let text = s.as_bytes(); + let (m, n) = (pat.len(), text.len()); + + // Iterative DP: prev[j] = pat[..i] matches text[..j].
+ // This replaces a recursive matcher that had exponential worst-case + // backtracking on patterns like "*.*.*x" against long non-matching strings. + let mut prev = vec![false; n + 1]; + let mut curr = vec![false; n + 1]; + prev[0] = true; + + for i in 1..=m { + curr[0] = pat[i - 1] == b'*' && prev[0]; + for j in 1..=n { + curr[j] = if pat[i - 1] == b'*' { + prev[j] || curr[j - 1] + } else if pat[i - 1] == b'?' || pat[i - 1] == text[j - 1] { + prev[j - 1] + } else { + false + }; } - (Some(b'?'), Some(_)) => glob_helper(&pat[1..], &s[1..]), - (Some(b'?'), None) => false, - (Some(p), Some(c)) if p == c => glob_helper(&pat[1..], &s[1..]), - _ => false, + std::mem::swap(&mut prev, &mut curr); } + + prev[n] } fn no_list_response(count: Option) -> Value { @@ -2116,7 +2210,7 @@ fn hash_incr_int(data: &DashMap, key: String, field: String, delt let mut entry = data.entry(key).or_insert_with(|| Entry { value: EntryValue::Hash(HashMap::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::Hash(HashMap::new()); @@ -2149,7 +2243,7 @@ fn hash_incr_float(data: &DashMap, key: String, field: String, de let mut entry = data.entry(key).or_insert_with(|| Entry { value: EntryValue::Hash(HashMap::new()), expires_at_ms: None, - written_at_ms: now_ms(), + last_access_ms: AtomicU64::new(now_ms()), }); if was_expired { entry.value = EntryValue::Hash(HashMap::new()); @@ -2179,7 +2273,10 @@ where None => f(&empty), Some(e) if e.is_expired(now) => f(&empty), Some(e) => match &e.value { - EntryValue::ZSet(z) => f(z), + EntryValue::ZSet(z) => { + e.touch(now); + f(z) + } _ => return Value::Error(WRONGTYPE.to_string()), }, }; @@ -2216,24 +2313,53 @@ fn zadd_exec(zset: &mut ZSetInner, opts: ZAddOptions, pairs: Vec<(f64, String)>) } } Some(ZAddCondition::Xx) => { - if let Some(s) = zset.scores.get_mut(&member) - && (*s - score).abs() > f64::EPSILON - { - *s = score; - changed += 1; + if let 
Some(old_score) = zset.scores.get_mut(&member) { + let should_update = if opts.gt { + score > *old_score + } else if opts.lt { + score < *old_score + } else { + (score - *old_score).abs() > f64::EPSILON + }; + if should_update { + *old_score = score; + changed += 1; + } } } None => { - let old = zset.scores.insert(member, score); - match old { - None => { - added += 1; - changed += 1; + if opts.gt || opts.lt { + match zset.scores.entry(member) { + std::collections::hash_map::Entry::Vacant(e) => { + e.insert(score); + added += 1; + changed += 1; + } + std::collections::hash_map::Entry::Occupied(mut e) => { + let old_score = *e.get(); + let should_update = if opts.gt { + score > old_score + } else { + score < old_score + }; + if should_update { + e.insert(score); + changed += 1; + } + } } - Some(old_score) if (old_score - score).abs() > f64::EPSILON => { - changed += 1; + } else { + let old = zset.scores.insert(member, score); + match old { + None => { + added += 1; + changed += 1; + } + Some(old_score) if (old_score - score).abs() > f64::EPSILON => { + changed += 1; + } + _ => {} } - _ => {} } } } @@ -3032,6 +3158,49 @@ mod tests { } } + #[test] + fn scan_paginates_with_count() { + let s = store(); + for i in 0..5 { + s.execute(Command::Set( + format!("k{i}"), + "v".into(), + SetOptions::default(), + )); + } + // Walk the cursor in pages of 2, collecting every key exactly once. 
+ let mut seen: Vec<String> = Vec::new(); + let mut cursor = 0u64; + let mut iterations = 0; + loop { + let res = s.execute(Command::Scan(cursor, None, Some(2))); + let Value::Array(Some(parts)) = res else { + panic!("expected array") + }; + let Value::BulkString(Some(c)) = &parts[0] else { + panic!("expected cursor bulk") + }; + let next: u64 = String::from_utf8_lossy(c).parse().unwrap(); + let Value::Array(Some(keys)) = &parts[1] else { + panic!("expected keys array") + }; + assert!(keys.len() <= 2, "page must honour COUNT"); + for k in keys { + if let Value::BulkString(Some(d)) = k { + seen.push(String::from_utf8_lossy(d).into_owned()); + } + } + cursor = next; + iterations += 1; + if cursor == 0 { + break; + } + assert!(iterations < 10, "cursor should terminate"); + } + seen.sort(); + assert_eq!(seen, vec!["k0", "k1", "k2", "k3", "k4"]); + } + // ── Expiry ──────────────────────────────────────────────────────────────── #[test] diff --git a/docs/roadmap.md b/docs/roadmap.md index 45118c5..e4999c0 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -21,19 +21,3 @@ The `core-engine` crate is already `wasm32`-compatible. The main work is adaptin Run `.wasm` stored procedures in place of Lua scripts. The scripting VM would be sandboxed (no network, no file I/O, bounded execution time), accept any WASM module that exports a specific entry function, and execute it against the cache store. This supports any language that compiles to WASM: Rust, Go (TinyGo), AssemblyScript, Python (via Pyodide). --- - -## Intentionally out of scope - -These features will not be added to Recached. If you need them, Redis is the right tool. - -**RDB / AOF persistence.** Recached is an in-memory cache. Durability is the responsibility of the database behind it. Repopulate on startup from your source of truth. - -**`REPLICAOF` / leader-follower replication.** The native→browser WebSocket sync is Recached's replication story. Multi-server Redis-style replication does not fit the architecture.
- -**Full Redis command parity.** Recached implements ~80 commands — the ones most applications actually use. The remaining 170+ Redis commands include cluster management, server introspection (`INFO`, `SLOWLOG`, `DEBUG`), and commands that assume RDB persistence (`BGSAVE`, `BGREWRITEAOF`, `SAVE`). These are not planned. - -**RESP3.** RESP2 is sufficient for Recached's scope and keeps the parser simple. RESP3 adds type hints that matter more for complex Redis use cases than for Recached's. - -**Cluster mode.** Recached is a single-node cache server. Horizontal scaling is not a goal for the current architecture. - -**Lua scripting.** WASM scripting (see roadmap above) is the planned scripting story. Lua will not be added. diff --git a/docs/server/commands.md b/docs/server/commands.md index cb54f5c..71ce724 100644 --- a/docs/server/commands.md +++ b/docs/server/commands.md @@ -58,7 +58,7 @@ The most common data type. Values are always stored as byte strings; numeric ope | `TYPE key` | Returns the type of the value stored at key: `string`, `hash`, `list`, `set`, `zset`, or `none` if the key does not exist. | | `RENAME key newkey` | Renames a key. Returns an error if the source key does not exist. Overwrites `newkey` if it already exists. | | `KEYS pattern` | Returns all keys matching the glob pattern. `*` matches any sequence of characters, `?` matches a single character, `[abc]` matches a character class. Warning: `KEYS *` on a large store is slow — prefer `SCAN`. | -| `SCAN cursor [MATCH pattern] [COUNT count]` | Iterates keys incrementally. Returns the next cursor and a batch of keys. Start with cursor `0`; continue until the returned cursor is `0`. `MATCH` filters results. `COUNT` is a hint for batch size. | +| `SCAN cursor [MATCH pattern] [COUNT count]` | Iterates keys incrementally, returning at most `COUNT` keys per call (default 10) plus the next cursor. Start with cursor `0` and continue until the returned cursor is `0`. 
`MATCH` filters results by glob pattern. As in Redis, keys inserted or deleted mid-iteration may be missed or returned twice. | | `DBSIZE` | Returns the total number of keys in the store. | | `FLUSHDB [ASYNC]` | Removes all keys from the store. `ASYNC` is accepted but does not change behavior (the flush is always synchronous). | @@ -180,8 +180,8 @@ Transactions queue commands and execute them atomically. No other client can int | Command | Description | |---|---| | `MULTI` | Begins a transaction. Subsequent commands are queued, not executed. Returns `OK`. | -| `EXEC` | Executes all queued commands atomically. Returns an array of results, one per queued command. | -| `DISCARD` | Abandons the transaction queue. Returns `OK`. | +| `EXEC` | Executes all queued commands. Returns an array of results, one per queued command — or a nil array if a `WATCH`ed key changed since `WATCH` was issued (optimistic-lock abort). | +| `DISCARD` | Abandons the transaction queue. Returns `OK`. Also clears any `WATCH`ed keys. | ### Example @@ -196,7 +196,7 @@ EXEC # 3) 2 ``` -Note: Recached transactions do not support optimistic locking (`WATCH` in the Redis sense). `WATCH` in Recached is the key observation command — see [Observable Keys](#observable-keys-websocket-only) below. +Optimistic locking: `WATCH key [key ...]` before `MULTI` marks those keys. If any watched key is modified by **any** client before `EXEC`, the transaction is aborted and `EXEC` returns a nil array (Redis `WATCH`/`MULTI`/`EXEC` CAS semantics). This works over both the TCP (6379) and WebSocket (6380) ports. `EXEC` and `DISCARD` both clear all watches. Over WebSocket, `WATCH` *additionally* pushes live keychange notifications — see [Observable Keys](#observable-keys) below. --- @@ -258,16 +258,17 @@ SAVE # +OK --- -## Observable Keys (WebSocket-only) +## Observable Keys -`WATCH` and `UNWATCH` are Recached-specific commands available only over WebSocket connections (port 6380). 
They have different semantics from Redis's `WATCH` (which is used for optimistic locking with transactions). +`WATCH` and `UNWATCH` serve two roles in Recached: -In Recached, `WATCH` subscribes the connection to change notifications for a specific key. Whenever the key is mutated — by any client, from any connection — the server sends a push message to all watching connections. +1. **Optimistic locking (both transports).** Over TCP (6379) and WebSocket (6380), `WATCH` participates in `MULTI`/`EXEC` exactly like Redis: if a watched key changes before `EXEC`, the transaction aborts (nil array). See [Transactions](#transactions). +2. **Live change notifications (WebSocket only).** Over WebSocket, `WATCH` *additionally* subscribes the connection to keychange pushes: whenever a watched key is mutated by any client, the server sends a push frame to every watching WS connection. TCP connections receive no such push (it would violate the request/response protocol) — they use `WATCH` purely for the CAS guarantee above. | Command | Description | |---|---| -| `WATCH key [key ...]` | Registers the WebSocket connection to receive push notifications whenever the given key(s) change. | -| `UNWATCH [key ...]` | Stops watching the given keys. With no arguments, clears all watches for this connection. | +| `WATCH key [key ...]` | Marks the given key(s) for optimistic locking, and (over WebSocket) registers the connection for keychange push notifications. Not allowed once `MULTI` has started. | +| `UNWATCH [key ...]` | Stops watching the given keys. With no arguments, clears all watches for this connection. `EXEC` and `DISCARD` also clear all watches. | ### Push message format diff --git a/docs/server/configuration.md b/docs/server/configuration.md index b3b0a76..3365c15 100644 --- a/docs/server/configuration.md +++ b/docs/server/configuration.md @@ -6,7 +6,8 @@ Recached is configured entirely through environment variables. 
There is no confi
 | Variable | Default | Description |
 |---|---|---|
-| `RECACHED_PASSWORD` | _(none)_ | Require clients to authenticate with `AUTH `. If unset, the server accepts connections without authentication. After 5 consecutive failed `AUTH` attempts, the connection is closed. |
+| `RECACHED_BIND` | `0.0.0.0` | Network interface all listeners (TCP, WebSocket, replication, metrics) bind to. Defaults to `0.0.0.0` (all interfaces). Set to `127.0.0.1` to restrict the server to localhost — strongly recommended unless the server is deliberately public. |
+| `RECACHED_PASSWORD` | _(none)_ | Require clients to authenticate with `AUTH `. If unset, the server accepts connections without authentication. After 5 consecutive failed `AUTH` attempts, the connection is closed. The password is compared in constant time. |
 | `RECACHED_ALLOW_IPS` | _(allow all)_ | Comma-separated list of IP addresses allowed to connect. Any connection from an IP not in the list is immediately closed. Invalid entries are logged and skipped. |
 | `RECACHED_MAX_KEYS` | _(unlimited)_ | Maximum number of keys in the store. When this limit is reached, behavior depends on `RECACHED_EVICTION`. If set to `noeviction` (the default), write commands that would exceed the cap return an error. |
 | `RECACHED_EVICTION` | `noeviction` | Eviction policy when `RECACHED_MAX_KEYS` is reached. See eviction policies below. |
@@ -236,6 +237,7 @@ For most web applications, 1024 concurrent connections to the cache server is mo
 
 ## Notes on sensitive configuration
 
+- By default every listener binds `0.0.0.0` (all interfaces). On a shared or internet-facing host, set `RECACHED_BIND=127.0.0.1` (or a specific private interface) **and** `RECACHED_PASSWORD`, or place the server behind a firewall. The Prometheus metrics port is unauthenticated, so it should never be exposed publicly.
 - Never commit `RECACHED_PASSWORD` to source control. Use an environment file (see [Installation — systemd service](/server/installation#systemd-service)) or a secrets manager (Vault, AWS Secrets Manager, Doppler).
 - The password is compared in constant time to prevent timing attacks, but the brute-force lockout (5 failed attempts → disconnect) is the primary protection. Use a long random password.
 - TLS is strongly recommended for any deployment where the cache server is reachable over a network that you do not fully control. Without TLS, `RECACHED_PASSWORD` is sent in plaintext on initial `AUTH`.
diff --git a/recached-react/package.json b/recached-react/package.json
index 2f04362..376e18f 100644
--- a/recached-react/package.json
+++ b/recached-react/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@recached/react",
-  "version": "0.1.5",
+  "version": "0.1.7",
   "description": "Official React hooks for Recached \u2014 zero-latency reactive cache",
   "type": "module",
   "main": "./dist/index.js",
diff --git a/recached-react/src/usePubSub.ts b/recached-react/src/usePubSub.ts
index 0a4421d..dbe8473 100644
--- a/recached-react/src/usePubSub.ts
+++ b/recached-react/src/usePubSub.ts
@@ -1,4 +1,4 @@
-import { useEffect } from 'react';
+import { useEffect, useRef } from 'react';
 import { useRecached } from './context';
 
 /**
@@ -22,16 +22,15 @@ import { useRecached } from './context';
  */
 export function usePubSub(channel: string, handler: (msg: string) => void): void {
   const cache = useRecached();
+  const handlerRef = useRef(handler);
+  handlerRef.current = handler;
 
   useEffect(() => {
     cache.subscribe(channel);
-    const unsub = cache.onMessage(channel, handler);
+    const unsub = cache.onMessage(channel, (msg) => handlerRef.current(msg));
     return () => {
       unsub();
       cache.unsubscribe(channel);
     };
-    // channel is the stable dependency; handler intentionally excluded to
-    // avoid re-subscribing on every render when defined inline
-    // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [channel]);
 }
diff --git a/recached-vue/package.json b/recached-vue/package.json
index af06503..2806092 100644
--- a/recached-vue/package.json
+++ b/recached-vue/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@recached/vue",
-  "version": "0.1.5",
+  "version": "0.1.7",
   "description": "Official Vue 3 composables for Recached \u2014 zero-latency reactive cache",
   "type": "module",
   "main": "./dist/index.js",
diff --git a/recached-vue/src/useKey.ts b/recached-vue/src/useKey.ts
index f28be84..e582b1c 100644
--- a/recached-vue/src/useKey.ts
+++ b/recached-vue/src/useKey.ts
@@ -29,10 +29,11 @@ import { useRecached } from './plugin';
  */
 export function useKey(key: string): Ref {
   const cache = useRecached();
-  const value = ref(cache.get(key));
+  const value = ref(null);
   const unsub = cache.onMutation(() => {
     value.value = cache.get(key);
   });
+  value.value = cache.get(key);
   onUnmounted(unsub);
   return value;
 }
@@ -59,10 +60,11 @@ export function useKey(key: string): Ref {
  */
 export function useKeyJSON(key: string): Ref {
   const cache = useRecached();
-  const value = ref(cache.getJSON(key)) as Ref;
+  const value = ref(null) as Ref;
   const unsub = cache.onMutation(() => {
     value.value = cache.getJSON(key);
   });
+  value.value = cache.getJSON(key);
   onUnmounted(unsub);
   return value;
 }
diff --git a/server-native/Cargo.toml b/server-native/Cargo.toml
index 8bbd2d2..019a3ba 100644
--- a/server-native/Cargo.toml
+++ b/server-native/Cargo.toml
@@ -16,3 +16,6 @@ tokio-rustls.workspace = true
 rustls-pemfile.workspace = true
 serde.workspace = true
 rmp-serde.workspace = true
+socket2.workspace = true
+num_cpus.workspace = true
+tikv-jemallocator.workspace = true
diff --git a/server-native/src/main.rs b/server-native/src/main.rs
index 5bb915c..126c1a1 100644
--- a/server-native/src/main.rs
+++ b/server-native/src/main.rs
@@ -1,3 +1,6 @@
+#[global_allocator]
+static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
+
 use core_engine::cmd::{Command, SetExpiry, ZAddCondition};
 use core_engine::resp::Value;
 use
core_engine::store::{EvictionPolicy, KeyValueStore, SnapshotEntry}; @@ -9,8 +12,8 @@ use std::io::ErrorKind; use std::net::IpAddr; use std::path::PathBuf; use std::str::FromStr; +use std::sync::Arc; use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -171,6 +174,42 @@ fn execute_and_record(store: &KeyValueStore, cmd: &Command) -> Value { response } +// ── TCP listeners ───────────────────────────────────────────────────────────── + +/// Binds `n` TCP sockets on `addr`, all with `SO_REUSEPORT`, so the OS can +/// distribute incoming connections across multiple accept loops — one per +/// Tokio worker thread. Falls back to a single plain `TcpListener::bind` on +/// platforms that don't support `SO_REUSEPORT`. +fn make_tcp_listeners(addr: &str, n: usize) -> std::io::Result> { + use socket2::{Domain, Socket, Type}; + let socket_addr: std::net::SocketAddr = addr + .parse() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?; + let domain = if socket_addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + // SO_REUSEPORT — which lets multiple sockets share one port — is Unix-only. + // Without it, binding a second socket to the same port fails, so fall back + // to a single accept loop on non-Unix platforms. 
+ #[cfg(not(unix))] + let n = 1; + let mut out = Vec::with_capacity(n); + for _ in 0..n { + let sock = Socket::new(domain, Type::STREAM, None)?; + sock.set_reuse_address(true)?; + #[cfg(unix)] + sock.set_reuse_port(true)?; + sock.set_nonblocking(true)?; + sock.bind(&socket_addr.into())?; + sock.listen(4096)?; + let std_listener: std::net::TcpListener = sock.into(); + out.push(TcpListener::from_std(std_listener)?); + } + Ok(out) +} + // ── TLS ─────────────────────────────────────────────────────────────────────── fn load_certs(path: &str) -> std::io::Result>> { @@ -205,7 +244,7 @@ fn load_tls_acceptor() -> Option { // ── tunables ──────────────────────────────────────────────────────────────── -const TCP_READ_BUFFER_BYTES: usize = 4096; +const TCP_READ_BUFFER_BYTES: usize = 16 * 1024; // 16 KB — matches Redis default const MAX_TCP_READ_BUFFER_BYTES: usize = 64 * 1024 * 1024; // 64 MB per connection const MAX_MULTI_QUEUE_LEN: usize = 10_000; const MAX_WATCHES_PER_CONN: usize = 1_024; @@ -358,7 +397,13 @@ async fn replay_aof(store: &KeyValueStore, path: &std::path::Path) -> usize { match Value::parse(&bytes[offset..]) { Ok((value, consumed)) => { offset += consumed; - if let Ok(cmd) = Command::from_value(value) { + // Writes are recorded via `on_write` in RESP3 Push form (`>N`); + // normalise to Array so Command::from_value can parse them. + let normalised = match value { + Value::Push(inner) => Value::Array(Some(inner)), + other => other, + }; + if let Ok(cmd) = Command::from_value(normalised) { store.execute(cmd); replayed += 1; } @@ -387,6 +432,12 @@ type ReplRegistry = Arc>>; /// is never blocked. const DEFAULT_REPL_CHANNEL_CAPACITY: usize = 4096; +/// Upper bound on a single length-prefixed replication frame (snapshot or +/// command). The replication port may be unauthenticated and plaintext, so an +/// untrusted peer could otherwise send a 4 GB length prefix and force a matching +/// allocation. 512 MB comfortably covers a large snapshot while bounding abuse. 
+const MAX_REPL_FRAME_BYTES: usize = 512 * 1024 * 1024; + // ── Server state ────────────────────────────────────────────────────────────── struct ServerState { @@ -512,6 +563,7 @@ fn parse_save_conditions(s: &str) -> Vec { // ── Replication server (primary side) ──────────────────────────────────────── async fn run_repl_server( + bind_host: String, port: u16, store: Arc, snap_cfg: Arc, @@ -519,14 +571,14 @@ async fn run_repl_server( repl_password: Option>, repl_channel_capacity: usize, ) { - let listener = match TcpListener::bind(format!("0.0.0.0:{}", port)).await { + let listener = match TcpListener::bind(format!("{}:{}", bind_host, port)).await { Ok(l) => l, Err(e) => { warn!("Replication listener failed to bind :{}: {}", port, e); return; } }; - info!("Replication server listening on 0.0.0.0:{}", port); + info!("Replication server listening on {}:{}", bind_host, port); loop { match listener.accept().await { Ok((socket, addr)) => { @@ -569,7 +621,7 @@ async fn handle_replica( socket.read_exact(&mut auth_buf).await?; let received_pwd = &auth_buf[..pwd.len()]; let terminator = auth_buf[pwd.len()]; - if received_pwd != pwd.as_bytes() || terminator != b'\n' { + if !ct_eq_bytes(received_pwd, pwd.as_bytes()) || terminator != b'\n' { let _ = socket .write_all(b"-ERR invalid replication password\n") .await; @@ -612,6 +664,7 @@ async fn run_repl_client( state: Arc, repl_password: Option, failover_timeout_secs: Option, + tx: broadcast::Sender<(u64, String)>, ) { let mut backoff_secs = 2u64; let mut unreachable_since: Option = None; @@ -633,7 +686,8 @@ async fn run_repl_client( unreachable_since = None; backoff_secs = 2; if let Err(e) = - sync_from_primary(&mut socket, &store, repl_password.as_deref()).await + sync_from_primary(&mut socket, &store, repl_password.as_deref(), &tx, &state) + .await { warn!("Replica: sync ended: {}", e); // Sync dropped — primary may be gone; start tracking if not already. 
@@ -668,6 +722,8 @@ async fn sync_from_primary( socket: &mut TcpStream, store: &KeyValueStore, repl_password: Option<&str>, + tx: &broadcast::Sender<(u64, String)>, + state: &ServerState, ) -> std::io::Result<()> { // 0. Send auth password if configured if let Some(pwd) = repl_password { @@ -689,6 +745,12 @@ async fn sync_from_primary( let mut len_buf = [0u8; 4]; socket.read_exact(&mut len_buf).await?; let snap_len = u32::from_le_bytes(len_buf) as usize; + if snap_len > MAX_REPL_FRAME_BYTES { + return Err(std::io::Error::new( + ErrorKind::InvalidData, + format!("snapshot frame too large ({snap_len} > {MAX_REPL_FRAME_BYTES} bytes)"), + )); + } let mut snap_bytes = vec![0u8; snap_len]; socket.read_exact(&mut snap_bytes).await?; @@ -708,6 +770,12 @@ async fn sync_from_primary( let mut len_buf = [0u8; 4]; socket.read_exact(&mut len_buf).await?; let cmd_len = u32::from_le_bytes(len_buf) as usize; + if cmd_len > MAX_REPL_FRAME_BYTES { + return Err(std::io::Error::new( + ErrorKind::InvalidData, + format!("command frame too large ({cmd_len} > {MAX_REPL_FRAME_BYTES} bytes)"), + )); + } let mut cmd_bytes = vec![0u8; cmd_len]; socket.read_exact(&mut cmd_bytes).await?; @@ -721,6 +789,12 @@ async fn sync_from_primary( }; if let Ok(cmd) = Command::from_value(normalised) { store.execute(cmd); + // Relay the applied write so this replica's own WebSocket + // clients see it, and any sub-replicas / AOF get it too + // (enables multi-tier replication and replica WS push). + let frame = String::from_utf8_lossy(&cmd_bytes).into_owned(); + let _ = tx.send((0, frame.clone())); + state.on_write(&frame).await; } } Err(e) => warn!("Replica: bad command from primary: {}", e), @@ -728,6 +802,19 @@ async fn sync_from_primary( } } +// ── security helpers ───────────────────────────────────────────────────────── + +/// Constant-time byte slice equality to prevent timing-based password leaks. 
+fn ct_eq_bytes(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + a.iter() + .zip(b.iter()) + .fold(0u8, |acc, (x, y)| acc | (x ^ y)) + == 0 +} + // ── connection identity ────────────────────────────────────────────────────── // TCP mutation broadcasts use id=0; WS/TCP pubsub connections get ids ≥ 1. @@ -780,6 +867,9 @@ impl PubSubHub { fn unsubscribe(&mut self, conn_id: u64, channel: &str) { if let Some(v) = self.channel_subs.get_mut(channel) { v.retain(|(id, _)| *id != conn_id); + if v.is_empty() { + self.channel_subs.remove(channel); + } } } @@ -789,9 +879,10 @@ impl PubSubHub { } fn unsubscribe_all(&mut self, conn_id: u64) { - for v in self.channel_subs.values_mut() { + self.channel_subs.retain(|_, v| { v.retain(|(id, _)| *id != conn_id); - } + !v.is_empty() + }); self.pattern_subs.retain(|(_, id, _)| *id != conn_id); } @@ -799,7 +890,10 @@ impl PubSubHub { fn publish(&mut self, channel: &str, message: &str) -> i64 { let mut count = 0i64; - if let Some(subs) = self.channel_subs.get_mut(channel) { + if let std::collections::hash_map::Entry::Occupied(mut e) = + self.channel_subs.entry(channel.to_string()) + { + let subs = e.get_mut(); subs.retain(|(_, tx)| { let ok = tx .send(PubSubMsg::Message { @@ -812,6 +906,9 @@ impl PubSubHub { } ok }); + if subs.is_empty() { + e.remove(); + } } let pattern_txs: Vec<(String, PubSubSender)> = self @@ -837,12 +934,13 @@ impl PubSubHub { } } -type SharedPubSub = Arc>; +type SharedPubSub = Arc>; // ── observable keys ─────────────────────────────────────────────────────────── type WatchNotif = (String, Value); -type WatchRegistry = Arc)>>>>; +type WatchRegistry = + Arc)>>>>; /// Extract the key(s) that `cmd` writes to, without inspecting the response. 
/// Used together with `broadcast_for()` — only call this when `broadcast_for` @@ -905,7 +1003,7 @@ fn encode_keychange(key: &str, value: &Value) -> Vec { .serialize() } -fn notify_watchers( +async fn notify_watchers( registry: &WatchRegistry, cmd: &Command, response: &Value, @@ -924,7 +1022,7 @@ fn notify_watchers( .iter() .map(|k| (k.clone(), store.get_current(k))) .collect(); - let mut reg = registry.lock().unwrap(); + let mut reg = registry.lock().await; for (key, value) in &key_values { if let Some(subs) = reg.get_mut(key) { subs.retain(|(_, tx)| tx.send((key.clone(), value.clone())).is_ok()); @@ -935,6 +1033,28 @@ fn notify_watchers( } } +/// Drop all of `conn_id`'s WATCH registrations and clear `watched_keys`. +/// Called at every transaction boundary (EXEC, DISCARD) and on connection close, +/// matching Redis semantics that WATCH state is flushed by EXEC/DISCARD. +async fn unregister_all_watches( + registry: &WatchRegistry, + conn_id: u64, + watched_keys: &mut HashSet, +) { + if watched_keys.is_empty() { + return; + } + let mut reg = registry.lock().await; + for key in watched_keys.drain() { + if let Some(subs) = reg.get_mut(&key) { + subs.retain(|(id, _)| *id != conn_id); + if subs.is_empty() { + reg.remove(&key); + } + } + } +} + // ── helpers ────────────────────────────────────────────────────────────────── fn encode_pubsub_msg(msg: PubSubMsg) -> Vec { @@ -1351,7 +1471,9 @@ fn process_auth( failures: &mut u32, ) -> (bool, Vec) { match expected.as_ref() { - Some(pwd) if provided == pwd => { + // Constant-time compare so a network attacker can't recover the password + // byte-by-byte from response-timing differences. + Some(pwd) if ct_eq_bytes(provided.as_bytes(), pwd.as_bytes()) => { *is_authenticated = true; *failures = 0; (false, b"+OK\r\n".to_vec()) @@ -1382,12 +1504,27 @@ async fn main() -> Result<(), Box> { ) .init(); + // ── bind address ────────────────────────────────────────────────────── + // Host/interface all listeners bind to. 
Defaults to 0.0.0.0 (all + // interfaces) for backwards compatibility; set RECACHED_BIND=127.0.0.1 to + // restrict to localhost, which — together with RECACHED_PASSWORD — is + // strongly recommended unless the server is deliberately public. + let bind_host = std::env::var("RECACHED_BIND").unwrap_or_else(|_| "0.0.0.0".to_string()); + if bind_host == "0.0.0.0" { + warn!( + "Binding all interfaces (0.0.0.0). Set RECACHED_BIND=127.0.0.1 and RECACHED_PASSWORD before exposing this host." + ); + } else { + info!("Binding interface {}", bind_host); + } + // ── Prometheus metrics ──────────────────────────────────────────────── let metrics_port: u16 = std::env::var("RECACHED_METRICS_PORT") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(9091); - let metrics_addr: std::net::SocketAddr = format!("0.0.0.0:{}", metrics_port).parse().unwrap(); + let metrics_addr: std::net::SocketAddr = + format!("{}:{}", bind_host, metrics_port).parse().unwrap(); metrics_exporter_prometheus::PrometheusBuilder::new() .with_http_listener(metrics_addr) .install() @@ -1607,23 +1744,34 @@ async fn main() -> Result<(), Box> { ); } + // ── broadcast channel (mutation sync) ──────────────────────────────── + // Carries (sender_conn_id, resp_encoded_mutation). WS receivers skip their + // own messages. Created before replication so a replica can push the writes + // it receives from the primary to its own local WebSocket clients. + let (tx, _rx) = broadcast::channel::<(u64, String)>(BROADCAST_CHANNEL_CAPACITY); + // ── start replication ───────────────────────────────────────────────── - if !is_replica_start { + // The replication server runs on every node — including replicas — so a + // replica can in turn serve sub-replicas (multi-tier replication). 
+ { let store_r = Arc::clone(&store); let snap_r = Arc::clone(&snap_cfg); let reg_r = Arc::clone(&replicas); let pwd_r = repl_password.clone().map(Arc::new); let cap_r = repl_channel_capacity; + let host_r = bind_host.clone(); tokio::spawn(async move { - run_repl_server(repl_port, store_r, snap_r, reg_r, pwd_r, cap_r).await; + run_repl_server(host_r, repl_port, store_r, snap_r, reg_r, pwd_r, cap_r).await; }); - } else if let Some(primary_addr) = replicaof { + } + if is_replica_start && let Some(primary_addr) = replicaof { let store_r = Arc::clone(&store); let state_r = Arc::clone(&state); let pwd_r = repl_password.clone(); let fo_r = failover_timeout_secs; + let tx_r = tx.clone(); tokio::spawn(async move { - run_repl_client(primary_addr, store_r, state_r, pwd_r, fo_r).await; + run_repl_client(primary_addr, store_r, state_r, pwd_r, fo_r, tx_r).await; }); if let Some(t) = failover_timeout_secs { info!( @@ -1651,15 +1799,11 @@ async fn main() -> Result<(), Box> { }); } - // ── broadcast channel (mutation sync) ──────────────────────────────── - // Carries (sender_conn_id, resp_encoded_mutation). WS receivers skip their own messages. 
- let (tx, _rx) = broadcast::channel::<(u64, String)>(BROADCAST_CHANNEL_CAPACITY); - // ── pub/sub hub ─────────────────────────────────────────────────────── - let pubsub: SharedPubSub = Arc::new(Mutex::new(PubSubHub::new())); + let pubsub: SharedPubSub = Arc::new(tokio::sync::Mutex::new(PubSubHub::new())); // ── watch registry ──────────────────────────────────────────────────── - let watch_registry: WatchRegistry = Arc::new(Mutex::new(HashMap::new())); + let watch_registry: WatchRegistry = Arc::new(tokio::sync::Mutex::new(HashMap::new())); // ── connection limiter ──────────────────────────────────────────────── let max_connections = std::env::var("RECACHED_MAX_CONNECTIONS") @@ -1683,62 +1827,75 @@ async fn main() -> Result<(), Box> { let tls_acceptor = Arc::new(tls_acceptor); // ── listeners ───────────────────────────────────────────────────────── - let tcp_listener = TcpListener::bind("0.0.0.0:6379").await?; - info!("TCP server listening on 0.0.0.0:6379"); - - let ws_listener = TcpListener::bind("0.0.0.0:6380").await?; - info!("WebSocket server listening on 0.0.0.0:6380"); - - let store_tcp = Arc::clone(&store); - let tx_tcp = tx.clone(); - let pass_tcp = Arc::clone(&global_password); - let allowed_tcp = allowed_ips.clone(); - let sem_tcp = Arc::clone(&semaphore); - let pubsub_tcp = Arc::clone(&pubsub); - let tls_tcp = Arc::clone(&tls_acceptor); - let watch_tcp = Arc::clone(&watch_registry); - let snap_tcp = Arc::clone(&state); + let n_accept = num_cpus::get(); + let tcp_listeners = make_tcp_listeners(&format!("{}:6379", bind_host), n_accept)?; + info!( + "TCP server listening on {}:6379 ({} accept loop(s))", + bind_host, n_accept + ); - tokio::spawn(async move { - loop { - match tcp_listener.accept().await { - Ok((socket, addr)) => { - if let Some(allowed) = &allowed_tcp - && !allowed.contains(&addr.ip()) - { - debug!("TCP: rejected IP {}", addr.ip()); - continue; - } - let permit = match Arc::clone(&sem_tcp).try_acquire_owned() { - Ok(p) => p, - Err(_) => { 
- warn!("TCP: connection limit reached, dropping {}", addr); + let ws_listener = TcpListener::bind(format!("{}:6380", bind_host)).await?; + info!("WebSocket server listening on {}:6380", bind_host); + + // Spawn one accept loop per CPU core, each with its own SO_REUSEPORT socket. + // The OS load-balances incoming connections across all loops. + for tcp_listener in tcp_listeners { + let store_tcp = Arc::clone(&store); + let tx_tcp = tx.clone(); + let pass_tcp = Arc::clone(&global_password); + let allowed_tcp = allowed_ips.clone(); + let sem_tcp = Arc::clone(&semaphore); + let pubsub_tcp = Arc::clone(&pubsub); + let tls_tcp = Arc::clone(&tls_acceptor); + let watch_tcp = Arc::clone(&watch_registry); + let snap_tcp = Arc::clone(&state); + + tokio::spawn(async move { + loop { + match tcp_listener.accept().await { + Ok((socket, addr)) => { + let _ = socket.set_nodelay(true); + if let Some(allowed) = &allowed_tcp + && !allowed.contains(&addr.ip()) + { + debug!("TCP: rejected IP {}", addr.ip()); continue; } - }; - let s = Arc::clone(&store_tcp); - let t = tx_tcp.clone(); - let p = Arc::clone(&pass_tcp); - let ps = Arc::clone(&pubsub_tcp); - let wr = Arc::clone(&watch_tcp); - let tls = Arc::clone(&tls_tcp); - let sc = Arc::clone(&snap_tcp); - tokio::spawn(async move { - let _permit = permit; - if let Some(acc) = tls.as_ref() { - match acc.accept(socket).await { - Ok(tls_stream) => handle_tcp(tls_stream, s, t, p, ps, wr, sc).await, - Err(e) => warn!("TCP TLS handshake failed from {}: {}", addr, e), + let permit = match Arc::clone(&sem_tcp).try_acquire_owned() { + Ok(p) => p, + Err(_) => { + warn!("TCP: connection limit reached, dropping {}", addr); + continue; } - } else { - handle_tcp(socket, s, t, p, ps, wr, sc).await; - } - }); + }; + let s = Arc::clone(&store_tcp); + let t = tx_tcp.clone(); + let p = Arc::clone(&pass_tcp); + let ps = Arc::clone(&pubsub_tcp); + let wr = Arc::clone(&watch_tcp); + let tls = Arc::clone(&tls_tcp); + let sc = Arc::clone(&snap_tcp); + 
tokio::spawn(async move { + let _permit = permit; + if let Some(acc) = tls.as_ref() { + match acc.accept(socket).await { + Ok(tls_stream) => { + handle_tcp(tls_stream, s, t, p, ps, wr, sc).await + } + Err(e) => { + warn!("TCP TLS handshake failed from {}: {}", addr, e) + } + } + } else { + handle_tcp(socket, s, t, p, ps, wr, sc).await; + } + }); + } + Err(e) => warn!("TCP accept error: {}", e), } - Err(e) => warn!("TCP accept error: {}", e), } - } - }); + }); + } // ── graceful shutdown via oneshot channel ──────────────────────────── let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); @@ -1767,6 +1924,7 @@ async fn main() -> Result<(), Box> { res = ws_listener.accept() => { match res { Ok((socket, addr)) => { + let _ = socket.set_nodelay(true); if let Some(allowed) = &allowed_ips && !allowed.contains(&addr.ip()) { @@ -1830,8 +1988,10 @@ async fn handle_tcp( S: AsyncRead + AsyncWrite + Unpin + Send, { let _guard = ConnectionGuard::tcp(); - let (mut reader, mut writer) = tokio::io::split(socket); + let (mut reader, raw_writer) = tokio::io::split(socket); + let mut writer = tokio::io::BufWriter::with_capacity(32 * 1024, raw_writer); let mut buf = Vec::::new(); + let mut read_pos: usize = 0; let mut read_buf = [0u8; TCP_READ_BUFFER_BYTES]; let mut is_authenticated = password.is_none(); let mut auth_failures: u32 = 0; @@ -1839,6 +1999,11 @@ async fn handle_tcp( let mut subscribed_channels: HashSet = HashSet::new(); let mut subscribed_patterns: HashSet = HashSet::new(); let (ps_tx, mut ps_rx) = mpsc::unbounded_channel::(); + // WATCH state for optimistic-lock transactions over TCP. Unlike the WS + // handler, TCP clients are not sent keychange pushes — WATCH is pure CAS. 
+ let mut watched_keys: HashSet = HashSet::new(); + let mut watch_dirty = false; + let (watch_tx, mut watch_rx) = mpsc::unbounded_channel::(); let conn_id = next_conn_id(); 'outer: loop { @@ -1849,15 +2014,15 @@ async fn handle_tcp( match result { Ok(0) => break, Ok(n) => { - if buf.len() + n > MAX_TCP_READ_BUFFER_BYTES { + if (buf.len() - read_pos) + n > MAX_TCP_READ_BUFFER_BYTES { warn!("TCP connection exceeded max buffer size, closing"); break 'outer; } buf.extend_from_slice(&read_buf[..n]); 'parse: loop { - match Value::parse(&buf) { + match Value::parse(&buf[read_pos..]) { Ok((value, consumed)) => { - buf.drain(..consumed); + read_pos += consumed; let cmd = match Command::from_value(value) { Ok(c) => c, Err(e) => { @@ -1873,7 +2038,10 @@ async fn handle_tcp( pwd, &password, &mut is_authenticated, &mut auth_failures, ); if writer.write_all(&resp).await.is_err() { break 'outer; } - if disconnect { break 'outer; } + if disconnect { + let _ = writer.flush().await; + break 'outer; + } continue 'parse; } @@ -1898,6 +2066,10 @@ async fn handle_tcp( } Command::Discard => { let resp = if multi_queue.take().is_some() { + // DISCARD also flushes WATCH state. + unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; b"+OK\r\n".to_vec() } else { b"-ERR DISCARD without MULTI\r\n".to_vec() @@ -1911,18 +2083,33 @@ async fn handle_tcp( if writer.write_all(b"-ERR EXEC without MULTI\r\n").await.is_err() { break 'outer; } } Some(queue) => { - let mut results = Vec::with_capacity(queue.len()); - for qcmd in queue { - let resp = execute_and_record(&store, &qcmd); - if let Some(msg) = broadcast_for(&qcmd, &resp) { - let _ = tx.send((0, msg.clone())); - state.on_write(&msg).await; + // Drain pending notifications so the CAS check isn't racy. + while watch_rx.try_recv().is_ok() { + watch_dirty = true; + } + if watch_dirty { + // A watched key changed since WATCH — abort with nil array. 
+ unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; + if writer.write_all(&Value::Array(None).serialize()).await.is_err() { break 'outer; } + } else { + let mut results = Vec::with_capacity(queue.len()); + for qcmd in queue { + let resp = execute_and_record(&store, &qcmd); + if let Some(msg) = broadcast_for(&qcmd, &resp) { + let _ = tx.send((0, msg.clone())); + state.on_write(&msg).await; + } + notify_watchers(&watch_registry, &qcmd, &resp, &store).await; + results.push(resp); } - notify_watchers(&watch_registry, &qcmd, &resp, &store); - results.push(resp); + unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; + let out = Value::Array(Some(results)).serialize(); + if writer.write_all(&out).await.is_err() { break 'outer; } } - let out = Value::Array(Some(results)).serialize(); - if writer.write_all(&out).await.is_err() { break 'outer; } } } continue 'parse; @@ -1932,11 +2119,12 @@ async fn handle_tcp( // If inside MULTI, queue the command if let Some(ref mut queue) = multi_queue { - // Pub/sub commands cannot be queued + // Pub/sub and WATCH commands cannot be queued match &cmd { Command::Subscribe(_) | Command::Unsubscribe(_) | Command::PSubscribe(_) | Command::PUnsubscribe(_) - | Command::Publish(_, _) => { + | Command::Publish(_, _) + | Command::Watch(_) | Command::Unwatch(_) => { let err = b"-ERR Command not allowed inside a transaction\r\n"; if writer.write_all(err).await.is_err() { break 'outer; } } @@ -1958,7 +2146,7 @@ async fn handle_tcp( Command::Subscribe(channels) => { for ch in channels { subscribed_channels.insert(ch.clone()); - pubsub.lock().unwrap().subscribe(conn_id, &ch, ps_tx.clone()); + pubsub.lock().await.subscribe(conn_id, &ch, ps_tx.clone()); let count = subscribed_channels.len() + subscribed_patterns.len(); let ack = resp_subscribe_ack("subscribe", &ch, count); if 
writer.write_all(&ack).await.is_err() { break 'outer; } @@ -1971,7 +2159,7 @@ async fn handle_tcp( channels.into_iter().filter(|c| subscribed_channels.remove(c)).collect() }; for ch in &targets { - pubsub.lock().unwrap().unsubscribe(conn_id, ch); + pubsub.lock().await.unsubscribe(conn_id, ch); let count = subscribed_channels.len() + subscribed_patterns.len(); let ack = resp_subscribe_ack("unsubscribe", ch, count); if writer.write_all(&ack).await.is_err() { break 'outer; } @@ -1984,7 +2172,7 @@ async fn handle_tcp( Command::PSubscribe(patterns) => { for pat in patterns { subscribed_patterns.insert(pat.clone()); - pubsub.lock().unwrap().psubscribe(conn_id, &pat, ps_tx.clone()); + pubsub.lock().await.psubscribe(conn_id, &pat, ps_tx.clone()); let count = subscribed_channels.len() + subscribed_patterns.len(); let ack = resp_subscribe_ack("psubscribe", &pat, count); if writer.write_all(&ack).await.is_err() { break 'outer; } @@ -1997,7 +2185,7 @@ async fn handle_tcp( patterns.into_iter().filter(|p| subscribed_patterns.remove(p)).collect() }; for pat in &targets { - pubsub.lock().unwrap().punsubscribe(conn_id, pat); + pubsub.lock().await.punsubscribe(conn_id, pat); let count = subscribed_channels.len() + subscribed_patterns.len(); let ack = resp_subscribe_ack("punsubscribe", pat, count); if writer.write_all(&ack).await.is_err() { break 'outer; } @@ -2008,11 +2196,49 @@ async fn handle_tcp( } } Command::Publish(channel, message) => { - let count = pubsub.lock().unwrap().publish(&channel, &message); + let count = pubsub.lock().await.publish(&channel, &message); let resp = Value::Integer(count).serialize(); if writer.write_all(&resp).await.is_err() { break 'outer; } } + Command::Watch(keys) => { + let new_count = keys.iter().filter(|k| !watched_keys.contains(*k)).count(); + if watched_keys.len() + new_count > MAX_WATCHES_PER_CONN { + if writer.write_all(b"-ERR watch limit per connection reached\r\n").await.is_err() { break 'outer; } + } else { + { + let mut reg = 
watch_registry.lock().await; + for key in &keys { + if watched_keys.insert(key.clone()) { + reg.entry(key.clone()).or_default().push((conn_id, watch_tx.clone())); + } + } + } + if writer.write_all(b"+OK\r\n").await.is_err() { break 'outer; } + } + } + Command::Unwatch(keys) => { + let targets: Vec = if keys.is_empty() { + watched_keys.drain().collect() + } else { + keys.into_iter().filter(|k| watched_keys.remove(k)).collect() + }; + { + let mut reg = watch_registry.lock().await; + for key in &targets { + if let Some(subs) = reg.get_mut(key) { + subs.retain(|(id, _)| *id != conn_id); + if subs.is_empty() { reg.remove(key); } + } + } + } + if watched_keys.is_empty() { + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; + } + if writer.write_all(b"+OK\r\n").await.is_err() { break 'outer; } + } + cmd => { // In subscribe mode only ping is allowed if is_subscribed && !matches!(cmd, Command::Ping(_)) { @@ -2057,22 +2283,32 @@ async fn handle_tcp( let _ = tx.send((0, msg.clone())); state.on_write(&msg).await; } - notify_watchers(&watch_registry, &cmd, &response, &store); + notify_watchers(&watch_registry, &cmd, &response, &store).await; if writer.write_all(&response.serialize()).await.is_err() { break 'outer; } } } } - Err(ref e) if e == "Incomplete" => break 'parse, + Err(ref e) if e == "Incomplete" => { + // Compact: drop already-parsed bytes, reset cursor. + buf.drain(..read_pos); + read_pos = 0; + break 'parse; + } Err(e) => { warn!("TCP protocol error: {}", e); let _ = writer.write_all(b"-ERR Protocol error\r\n").await; buf.clear(); + read_pos = 0; break 'parse; } } } + // Flush all responses for this read batch in one syscall. + if writer.flush().await.is_err() { + break 'outer; + } } Err(e) => { warn!("TCP read error: {}", e); @@ -2091,12 +2327,21 @@ async fn handle_tcp( None => break, } } + + // A watched key changed: mark the transaction dirty so a following + // EXEC aborts. TCP clients get no keychange push (WATCH is pure CAS). 
+ notif = watch_rx.recv(), if !watched_keys.is_empty() => { + if notif.is_some() { + watch_dirty = true; + } + } } } if !subscribed_channels.is_empty() || !subscribed_patterns.is_empty() { - pubsub.lock().unwrap().unsubscribe_all(conn_id); + pubsub.lock().await.unsubscribe_all(conn_id); } + unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; } // ── WebSocket handler ───────────────────────────────────────────────────────── @@ -2132,8 +2377,14 @@ async fn handle_ws( let mut subscribed_patterns: HashSet = HashSet::new(); let (ps_tx, mut ps_rx) = mpsc::unbounded_channel::(); let mut watched_keys: HashSet = HashSet::new(); + // Set when any watched key changes; EXEC aborts (returns nil) if true. + let mut watch_dirty = false; let (watch_tx, mut watch_rx) = mpsc::unbounded_channel::(); + // NOTE: the WebSocket transport uses *text* frames, so values must be valid + // UTF-8. Non-UTF-8 bytes are replaced (lossy) on the way out. This is safe + // for the SDK, whose `set(key, value)` API only accepts `&str` values; raw + // binary values are only fully round-trippable over the TCP (RESP) port. macro_rules! ws_send { ($bytes:expr) => {{ let text = String::from_utf8_lossy($bytes).into_owned(); @@ -2198,6 +2449,10 @@ async fn handle_ws( } Command::Discard => { let resp = if multi_queue.take().is_some() { + // DISCARD also flushes WATCH state. 
+ unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} // drop stale notifications + watch_dirty = false; b"+OK\r\n".to_vec() } else { b"-ERR DISCARD without MULTI\r\n".to_vec() @@ -2211,18 +2466,38 @@ async fn handle_ws( ws_send!(b"-ERR EXEC without MULTI\r\n"); } Some(queue) => { - let mut results = Vec::with_capacity(queue.len()); - for qcmd in queue { - let resp = execute_and_record(&store, &qcmd); - if let Some(msg) = broadcast_for(&qcmd, &resp) { - let _ = tx.send((conn_id, msg.clone())); - state.on_write(&msg).await; + // Catch watched-key changes that arrived but the select + // loop hasn't drained yet, so the CAS check isn't racy. + while watch_rx.try_recv().is_ok() { + watch_dirty = true; + } + if watch_dirty { + // A watched key changed since WATCH — abort: return + // a nil array and run nothing (Redis CAS semantics). + unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} // drop stale notifications + watch_dirty = false; + ws_send!(&Value::Array(None).serialize()); + } else { + let mut results = Vec::with_capacity(queue.len()); + for qcmd in queue { + let resp = execute_and_record(&store, &qcmd); + if let Some(msg) = broadcast_for(&qcmd, &resp) { + let _ = tx.send((conn_id, msg.clone())); + state.on_write(&msg).await; + } + notify_watchers(&watch_registry, &qcmd, &resp, &store).await; + results.push(resp); } - notify_watchers(&watch_registry, &qcmd, &resp, &store); - results.push(resp); + // EXEC always flushes WATCH state. Drain any + // self-notifications the queued writes produced so + // they can't dirty a later transaction. 
+ unregister_all_watches(&watch_registry, conn_id, &mut watched_keys).await; + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; + let out = Value::Array(Some(results)).serialize(); + ws_send!(&out); } - let out = Value::Array(Some(results)).serialize(); - ws_send!(&out); } } continue; @@ -2235,7 +2510,8 @@ async fn handle_ws( match &cmd { Command::Subscribe(_) | Command::Unsubscribe(_) | Command::PSubscribe(_) | Command::PUnsubscribe(_) - | Command::Publish(_, _) => { + | Command::Publish(_, _) + | Command::Watch(_) | Command::Unwatch(_) => { ws_send!(b"-ERR Command not allowed inside a transaction\r\n"); } _ => { @@ -2255,7 +2531,7 @@ async fn handle_ws( Command::Subscribe(channels) => { for ch in channels { subscribed_channels.insert(ch.clone()); - pubsub.lock().unwrap().subscribe(conn_id, &ch, ps_tx.clone()); + pubsub.lock().await.subscribe(conn_id, &ch, ps_tx.clone()); let count = subscribed_channels.len() + subscribed_patterns.len(); ws_send!(&resp_subscribe_ack("subscribe", &ch, count)); } @@ -2267,7 +2543,7 @@ async fn handle_ws( channels.into_iter().filter(|c| subscribed_channels.remove(c)).collect() }; for ch in &targets { - pubsub.lock().unwrap().unsubscribe(conn_id, ch); + pubsub.lock().await.unsubscribe(conn_id, ch); let count = subscribed_channels.len() + subscribed_patterns.len(); ws_send!(&resp_subscribe_ack("unsubscribe", ch, count)); } @@ -2278,7 +2554,7 @@ async fn handle_ws( Command::PSubscribe(patterns) => { for pat in patterns { subscribed_patterns.insert(pat.clone()); - pubsub.lock().unwrap().psubscribe(conn_id, &pat, ps_tx.clone()); + pubsub.lock().await.psubscribe(conn_id, &pat, ps_tx.clone()); let count = subscribed_channels.len() + subscribed_patterns.len(); ws_send!(&resp_subscribe_ack("psubscribe", &pat, count)); } @@ -2290,7 +2566,7 @@ async fn handle_ws( patterns.into_iter().filter(|p| subscribed_patterns.remove(p)).collect() }; for pat in &targets { - pubsub.lock().unwrap().punsubscribe(conn_id, pat); + 
pubsub.lock().await.punsubscribe(conn_id, pat); let count = subscribed_channels.len() + subscribed_patterns.len(); ws_send!(&resp_subscribe_ack("punsubscribe", pat, count)); } @@ -2299,7 +2575,7 @@ async fn handle_ws( } } Command::Publish(channel, message) => { - let count = pubsub.lock().unwrap().publish(&channel, &message); + let count = pubsub.lock().await.publish(&channel, &message); ws_send!(&Value::Integer(count).serialize()); } @@ -2312,7 +2588,7 @@ async fn handle_ws( ws_send!(b"-ERR watch limit per connection reached\r\n"); } else { { - let mut reg = watch_registry.lock().unwrap(); + let mut reg = watch_registry.lock().await; for key in &keys { if watched_keys.insert(key.clone()) { reg.entry(key.clone()) @@ -2331,7 +2607,7 @@ async fn handle_ws( keys.into_iter().filter(|k| watched_keys.remove(k)).collect() }; { - let mut reg = watch_registry.lock().unwrap(); + let mut reg = watch_registry.lock().await; for key in &targets { if let Some(subs) = reg.get_mut(key) { subs.retain(|(id, _)| *id != conn_id); @@ -2341,6 +2617,12 @@ async fn handle_ws( } } } + // Once nothing is watched, clear the dirty flag and drop any + // queued notifications so a later WATCH/MULTI/EXEC starts clean. + if watched_keys.is_empty() { + while watch_rx.try_recv().is_ok() {} + watch_dirty = false; + } ws_send!(b"+OK\r\n"); } @@ -2387,7 +2669,7 @@ async fn handle_ws( } state.on_write(&b_msg).await; } - notify_watchers(&watch_registry, &cmd, &response, &store); + notify_watchers(&watch_registry, &cmd, &response, &store).await; ws_send!(&response.serialize()); } } @@ -2432,6 +2714,10 @@ async fn handle_ws( notif = watch_rx.recv(), if !watched_keys.is_empty() => { if let Some((key, value)) = notif { + // A watched key changed: mark the transaction dirty (so a + // following EXEC aborts) and still push the keychange to the + // client for the observable-keys feature. 
+ watch_dirty = true; let bytes = encode_keychange(&key, &value); let text = String::from_utf8_lossy(&bytes).into_owned(); if ws_sender.send(Message::Text(text.into())).await.is_err() { @@ -2443,10 +2729,10 @@ async fn handle_ws( } if !subscribed_channels.is_empty() || !subscribed_patterns.is_empty() { - pubsub.lock().unwrap().unsubscribe_all(conn_id); + pubsub.lock().await.unsubscribe_all(conn_id); } if !watched_keys.is_empty() { - let mut reg = watch_registry.lock().unwrap(); + let mut reg = watch_registry.lock().await; for key in &watched_keys { if let Some(subs) = reg.get_mut(key) { subs.retain(|(id, _)| *id != conn_id); @@ -2498,8 +2784,8 @@ mod tests { ) -> TestServer { let store = Arc::new(KeyValueStore::new()); let (tx, _rx) = broadcast::channel::<(u64, String)>(256); - let pubsub: SharedPubSub = Arc::new(Mutex::new(PubSubHub::new())); - let watch_registry: WatchRegistry = Arc::new(Mutex::new(HashMap::new())); + let pubsub: SharedPubSub = Arc::new(tokio::sync::Mutex::new(PubSubHub::new())); + let watch_registry: WatchRegistry = Arc::new(tokio::sync::Mutex::new(HashMap::new())); let semaphore = Arc::new(Semaphore::new(64)); let snap_cfg = Arc::new(SnapshotConfig { path: snap_path.unwrap_or_else(|| tmp_path("test.rdb")), @@ -2677,6 +2963,23 @@ mod tests { let _ = tokio::fs::remove_file(&path).await; } + #[tokio::test] + async fn replay_aof_push_frames() { + // The live server records writes via `on_write`, which stores them in + // RESP3 Push (`>`) form. Replay must accept those, not just `*` arrays. 
+ let store = KeyValueStore::new(); + let path = tmp_path("aof_push.aof"); + let resp = ">3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"; + tokio::fs::write(&path, resp.as_bytes()).await.unwrap(); + let count = replay_aof(&store, &path).await; + assert_eq!(count, 1); + assert_eq!( + store.execute(Command::Get("foo".into())), + Value::BulkString(Some(b"bar".to_vec())) + ); + let _ = tokio::fs::remove_file(&path).await; + } + // ── Snapshot save / load ────────────────────────────────────────────────── #[tokio::test] @@ -3103,8 +3406,8 @@ mod tests { let primary2 = { let store = Arc::clone(&primary.store); let (tx, _rx) = broadcast::channel::<(u64, String)>(256); - let pubsub: SharedPubSub = Arc::new(Mutex::new(PubSubHub::new())); - let wr: WatchRegistry = Arc::new(Mutex::new(HashMap::new())); + let pubsub: SharedPubSub = Arc::new(tokio::sync::Mutex::new(PubSubHub::new())); + let wr: WatchRegistry = Arc::new(tokio::sync::Mutex::new(HashMap::new())); let sem = Arc::new(Semaphore::new(64)); let snap = Arc::clone(&primary.state.snap); let state = Arc::new(ServerState { @@ -3162,8 +3465,9 @@ mod tests { let rs = Arc::clone(&replica_store); let rst = Arc::clone(&replica_state); let repl_addr = format!("127.0.0.1:{repl_port}"); + let rtx = broadcast::channel::<(u64, String)>(16).0; tokio::spawn(async move { - run_repl_client(repl_addr, rs, rst, None, None).await; + run_repl_client(repl_addr, rs, rst, None, None, rtx).await; }); // Give replica time to connect and receive initial snapshot @@ -3217,8 +3521,8 @@ mod tests { // Small semaphore: only 3 concurrent connections let store = Arc::new(KeyValueStore::new()); let (tx, _rx) = broadcast::channel::<(u64, String)>(16); - let pubsub: SharedPubSub = Arc::new(Mutex::new(PubSubHub::new())); - let watch_registry: WatchRegistry = Arc::new(Mutex::new(HashMap::new())); + let pubsub: SharedPubSub = Arc::new(tokio::sync::Mutex::new(PubSubHub::new())); + let watch_registry: WatchRegistry = 
Arc::new(tokio::sync::Mutex::new(HashMap::new())); let semaphore = Arc::new(Semaphore::new(3)); let state = Arc::new(ServerState { snap: Arc::new(SnapshotConfig { @@ -3352,8 +3656,9 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let dead_addr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port()); drop(listener); + let rtx = broadcast::channel::<(u64, String)>(16).0; tokio::spawn(async move { - run_repl_client(dead_addr, rs, rst, None, Some(1)).await; + run_repl_client(dead_addr, rs, rst, None, Some(1), rtx).await; }); // Wait for 2 backoff cycles (initial fail + 2s sleep + retry fail → promote) @@ -3364,4 +3669,198 @@ mod tests { "replica should have promoted after primary was unreachable for >1s" ); } + + // ── WebSocket WATCH/EXEC harness ────────────────────────────────────────── + + /// Spawn a WebSocket server sharing one store + watch registry across all + /// connections, so WATCH notifications fan out between clients. + async fn spawn_ws_server() -> TestServer { + let store = Arc::new(KeyValueStore::new()); + let (tx, _rx) = broadcast::channel::<(u64, String)>(256); + let pubsub: SharedPubSub = Arc::new(tokio::sync::Mutex::new(PubSubHub::new())); + let watch_registry: WatchRegistry = Arc::new(tokio::sync::Mutex::new(HashMap::new())); + let snap_cfg = Arc::new(SnapshotConfig { + path: tmp_path("ws_test.rdb"), + last_save: AtomicI64::new(now_unix_secs()), + }); + let state = Arc::new(ServerState { + snap: snap_cfg, + aof: None, + replicas: Arc::new(tokio::sync::Mutex::new(vec![])), + is_replica: AtomicBool::new(false), + }); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let store2 = Arc::clone(&store); + let state2 = Arc::clone(&state); + + let task = tokio::spawn(async move { + loop { + let Ok((socket, _)) = listener.accept().await else { + return; + }; + let (s, t, ps, wr, st) = ( + Arc::clone(&store2), + tx.clone(), + Arc::clone(&pubsub), + 
Arc::clone(&watch_registry), + Arc::clone(&state2), + ); + let id = next_conn_id(); + tokio::spawn(async move { + handle_ws(socket, s, t, Arc::new(None), id, ps, wr, st).await; + }); + } + }); + + TestServer { + tcp_addr: addr, + store, + state, + _task: task, + } + } + + struct WsClient { + ws: tokio_tungstenite::WebSocketStream>, + } + + impl WsClient { + async fn connect(addr: std::net::SocketAddr) -> Self { + let (ws, _) = tokio_tungstenite::connect_async(format!("ws://{addr}")) + .await + .unwrap(); + Self { ws } + } + + async fn cmd(&mut self, args: &[&str]) -> Value { + let mut req = format!("*{}\r\n", args.len()); + for a in args { + req.push_str(&format!("${}\r\n{}\r\n", a.len(), a)); + } + self.ws.send(Message::Text(req.into())).await.unwrap(); + self.next_reply().await + } + + /// Read the next *command reply*, skipping server-initiated frames + /// (RESP3 Push broadcasts and `keychange` observable-key pushes). + async fn next_reply(&mut self) -> Value { + loop { + match self.ws.next().await { + Some(Ok(Message::Text(t))) => { + let Ok((v, _)) = Value::parse(t.as_bytes()) else { + continue; + }; + if matches!(v, Value::Push(_)) { + continue; + } + if let Value::Array(Some(items)) = &v + && matches!(items.first(), Some(Value::BulkString(Some(k))) if k == b"keychange") + { + continue; + } + return v; + } + Some(Ok(_)) => continue, + _ => panic!("ws closed unexpectedly"), + } + } + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn integration_ws_watch_exec_aborts_on_change() { + let srv = spawn_ws_server().await; + let mut watcher = WsClient::connect(srv.tcp_addr).await; + let mut writer = WsClient::connect(srv.tcp_addr).await; + + assert_eq!(watcher.cmd(&["SET", "k", "v0"]).await, ok()); + assert_eq!(watcher.cmd(&["WATCH", "k"]).await, ok()); + + // Another client mutates the watched key. + assert_eq!(writer.cmd(&["SET", "k", "v1"]).await, ok()); + // Give the notification time to reach the watcher's registry channel. 
+ tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + assert_eq!( + watcher.cmd(&["MULTI"]).await, + Value::SimpleString("OK".into()) + ); + assert_eq!( + watcher.cmd(&["SET", "k", "v2"]).await, + Value::SimpleString("QUEUED".into()) + ); + // EXEC must abort with a nil array because k changed since WATCH. + assert_eq!(watcher.cmd(&["EXEC"]).await, Value::Array(None)); + // The transaction did not run. + assert_eq!(srv.store.execute(Command::Get("k".into())), bulk("v1")); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn integration_ws_watch_exec_runs_when_unchanged() { + let srv = spawn_ws_server().await; + let mut c = WsClient::connect(srv.tcp_addr).await; + + assert_eq!(c.cmd(&["WATCH", "k"]).await, ok()); + assert_eq!(c.cmd(&["MULTI"]).await, ok()); + assert_eq!( + c.cmd(&["SET", "k", "v1"]).await, + Value::SimpleString("QUEUED".into()) + ); + // No one touched k → EXEC runs and returns the queued results. + assert_eq!( + c.cmd(&["EXEC"]).await, + Value::Array(Some(vec![Value::SimpleString("OK".into())])) + ); + assert_eq!(srv.store.execute(Command::Get("k".into())), bulk("v1")); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn integration_tcp_watch_exec_aborts_on_change() { + let srv = spawn_server().await; + let mut watcher = RespClient::connect(srv.tcp_addr).await; + let mut writer = RespClient::connect(srv.tcp_addr).await; + + assert_eq!(watcher.cmd(&["SET", "k", "v0"]).await, ok()); + assert_eq!(watcher.cmd(&["WATCH", "k"]).await, ok()); + // Another client mutates the watched key (reply awaited → notification queued). + assert_eq!(writer.cmd(&["SET", "k", "v1"]).await, ok()); + + assert_eq!(watcher.cmd(&["MULTI"]).await, ok()); + assert_eq!( + watcher.cmd(&["SET", "k", "v2"]).await, + Value::SimpleString("QUEUED".into()) + ); + // k changed since WATCH → EXEC aborts with a nil array. 
+ assert_eq!(watcher.cmd(&["EXEC"]).await, Value::Array(None)); + assert_eq!(watcher.cmd(&["GET", "k"]).await, bulk("v1")); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn integration_tcp_watch_exec_runs_when_unchanged() { + let srv = spawn_server().await; + let mut c = RespClient::connect(srv.tcp_addr).await; + + assert_eq!(c.cmd(&["WATCH", "k"]).await, ok()); + assert_eq!(c.cmd(&["MULTI"]).await, ok()); + assert_eq!( + c.cmd(&["SET", "k", "v1"]).await, + Value::SimpleString("QUEUED".into()) + ); + assert_eq!( + c.cmd(&["EXEC"]).await, + Value::Array(Some(vec![Value::SimpleString("OK".into())])) + ); + assert_eq!(c.cmd(&["GET", "k"]).await, bulk("v1")); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn integration_tcp_watch_inside_multi_rejected() { + let srv = spawn_server().await; + let mut c = RespClient::connect(srv.tcp_addr).await; + assert_eq!(c.cmd(&["MULTI"]).await, ok()); + // WATCH is not allowed once a transaction has started. + assert!(matches!(c.cmd(&["WATCH", "k"]).await, Value::Error(_))); + } } diff --git a/wasm-edge/package.json b/wasm-edge/package.json index 299115f..5ac3004 100644 --- a/wasm-edge/package.json +++ b/wasm-edge/package.json @@ -1,7 +1,7 @@ { "name": "recached-edge", "description": "Browser and edge WebAssembly client for Recached \u2014 zero-latency local cache with automatic server sync", - "version": "0.1.5", + "version": "0.1.7", "type": "module", "main": "sdk.js", "module": "sdk.js", diff --git a/wasm-edge/sdk.ts b/wasm-edge/sdk.ts index 7fc5610..f69abc7 100644 --- a/wasm-edge/sdk.ts +++ b/wasm-edge/sdk.ts @@ -215,8 +215,9 @@ export class Cache { * Store a string value. Syncs to the server and other tabs when connected. */ set(key: string, value: string): void { + // raw.set fires the mutation callback registered in the constructor, so + // listeners are already notified — no second _notifyMutation() here. 
this.raw.set(key, value); - this._notifyMutation(); } /** @@ -225,7 +226,6 @@ export class Cache { */ setEx(key: string, value: string, seconds: number): void { this.raw.set_ex(key, value, seconds); - this._notifyMutation(); } /** @@ -243,7 +243,6 @@ export class Cache { } else { this.raw.set(key, serialized); } - this._notifyMutation(); } /** @@ -254,7 +253,6 @@ export class Cache { */ del(key: string): boolean { const existed = this.raw.del(key) === 1; - this._notifyMutation(); return existed; } diff --git a/wasm-edge/src/lib.rs b/wasm-edge/src/lib.rs index 6590f51..28f7efd 100644 --- a/wasm-edge/src/lib.rs +++ b/wasm-edge/src/lib.rs @@ -1,13 +1,13 @@ use core_engine::cmd::{Command, SetOptions}; use core_engine::resp::Value; -use core_engine::store::{KeyValueStore, SnapshotEntry, SnapshotValue}; +use core_engine::store::{KeyValueStore, SnapshotEntry, SnapshotValue, format_score}; use js_sys::Promise; use std::cell::{Cell, RefCell}; use std::rc::Rc; use std::sync::Arc; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::{JsFuture, spawn_local}; -use web_sys::{BroadcastChannel, MessageEvent, WebSocket}; +use web_sys::{BroadcastChannel, Event, MessageEvent, WebSocket}; // ── IndexedDB JS glue ───────────────────────────────────────────────────────── // @@ -84,18 +84,6 @@ fn to_resp_owned(parts: &[String]) -> String { /// snapshot commands so the next replay is fast regardless of write history. const WAL_COMPACT_THRESHOLD: u32 = 1000; -fn format_zset_score(s: f64) -> String { - if s == f64::INFINITY { - "inf".into() - } else if s == f64::NEG_INFINITY { - "-inf".into() - } else if s.fract() == 0.0 && s.abs() < 1e15 { - format!("{}", s as i64) - } else { - format!("{}", s) - } -} - /// Convert snapshot entries into minimal RESP command strings suitable for /// storing in the WAL. Each entry produces one command; entries with a TTL on /// collection types produce an extra PEXPIREAT command. 
@@ -159,7 +147,7 @@ fn snapshot_to_resp_cmds(entries: &[SnapshotEntry]) -> Vec { } let mut parts = vec!["ZADD".to_string(), e.key.clone()]; for (member, score) in pairs { - parts.push(format_zset_score(*score)); + parts.push(format_score(*score)); parts.push(member.clone()); } parts @@ -167,14 +155,14 @@ fn snapshot_to_resp_cmds(entries: &[SnapshotEntry]) -> Vec { }; out.push(to_resp_owned(&data_parts)); // Non-string types with a TTL need a separate PEXPIREAT command. - if !matches!(&e.value, SnapshotValue::Str(_)) { - if let Some(exp) = e.expires_at_ms { - out.push(to_resp_owned(&[ - "PEXPIREAT".to_string(), - e.key.clone(), - exp.to_string(), - ])); - } + if !matches!(&e.value, SnapshotValue::Str(_)) + && let Some(exp) = e.expires_at_ms + { + out.push(to_resp_owned(&[ + "PEXPIREAT".to_string(), + e.key.clone(), + exp.to_string(), + ])); } } out @@ -197,6 +185,12 @@ pub struct RecachedCache { on_message: Rc>>, _onmessage: Option>, _onbc: Option>, + _onopen: Option>, + /// Commands enqueued while the socket is still CONNECTING. Flushed in order + /// by the `onopen` handler so AUTH and early writes are never dropped. + ws_pending: Rc>>, + /// True when connected via unencrypted ws:// (not wss://). + ws_is_plaintext: bool, } impl Default for RecachedCache { @@ -241,6 +235,23 @@ impl RecachedCache { on_message: Rc::new(RefCell::new(None)), _onmessage: None, _onbc: None, + _onopen: None, + ws_pending: Rc::new(RefCell::new(Vec::new())), + ws_is_plaintext: false, + } + } + + /// Send `encoded` to the server if the socket is open, otherwise buffer it + /// until `onopen` fires. Without this, anything sent during the CONNECTING + /// window (notably the AUTH that `createCache` issues right after connect) + /// would be silently dropped. 
+ fn ws_enqueue(&self, encoded: &str) { + if let Some(ws) = &self.ws { + if ws.ready_state() == WebSocket::OPEN { + let _ = ws.send_with_str(encoded); + } else { + self.ws_pending.borrow_mut().push(encoded.to_string()); + } } } @@ -298,6 +309,9 @@ impl RecachedCache { // If the WAL grew large, compact: rewrite it as minimal snapshot // commands. This keeps startup replay fast regardless of how many // writes accumulated between refreshes. + // Note: there is a brief data-loss window between idb_clear_js and + // writing the new snapshot. If the tab is closed during compaction, + // the WAL will be empty on next load and in-memory state is lost. let next_seq = if entry_count > WAL_COMPACT_THRESHOLD { JsFuture::from(idb_clear_js(&db)).await?; let cmds = snapshot_to_resp_cmds(&store.snapshot()); @@ -403,6 +417,10 @@ impl RecachedCache { /// Connect to the native Recached backend via WebSockets. /// Calling this a second time cleanly replaces the previous connection. pub fn connect(&mut self, url: &str) -> Result<(), JsValue> { + if let Some(old_ws) = self.ws.take() { + let _ = old_ws.close(); + } + self.ws_is_plaintext = url.starts_with("ws://"); let ws = WebSocket::new(url)?; let store_clone = Arc::clone(&self.store); let on_mut = Rc::clone(&self.on_mutation); @@ -477,6 +495,18 @@ impl RecachedCache { ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); + // Flush commands buffered while the socket was CONNECTING (AUTH first, + // then any early writes), preserving FIFO order. + let pending = Rc::clone(&self.ws_pending); + let ws_for_open = ws.clone(); + let onopen = Closure::wrap(Box::new(move |_e: Event| { + for msg in pending.borrow_mut().drain(..) { + let _ = ws_for_open.send_with_str(&msg); + } + }) as Box); + ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); + + self._onopen = Some(onopen); self._onmessage = Some(onmessage); self.ws = Some(ws); Ok(()) @@ -484,11 +514,12 @@ impl RecachedCache { /// Send an AUTH command to the server. 
The response arrives asynchronously via onmessage. pub fn auth(&self, password: &str) -> String { - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&to_resp(&["AUTH", password])); + if self.ws_is_plaintext { + let _ = web_sys::console::warn_1(&JsValue::from_str( + "recached: AUTH over unencrypted ws:// exposes the password in plaintext; use wss://", + )); } + self.ws_enqueue(&to_resp(&["AUTH", password])); "OK".to_string() } @@ -501,11 +532,7 @@ impl RecachedCache { )); let encoded = to_resp(&["SET", key, value]); - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&encoded); - } + self.ws_enqueue(&encoded); if let Some(bc) = &self.bc { let _ = bc.post_message(&JsValue::from_str(&encoded)); } @@ -530,11 +557,7 @@ impl RecachedCache { .execute(Command::Set(key.to_string(), value.to_string(), opts)); let encoded = to_resp(&["SET", key, value, "EX", &seconds.to_string()]); - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&encoded); - } + self.ws_enqueue(&encoded); if let Some(bc) = &self.bc { let _ = bc.post_message(&JsValue::from_str(&encoded)); } @@ -561,11 +584,7 @@ impl RecachedCache { let resp = self.store.execute(Command::Del(vec![key.to_string()])); let encoded = to_resp(&["DEL", key]); - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&encoded); - } + self.ws_enqueue(&encoded); if let Some(bc) = &self.bc { let _ = bc.post_message(&JsValue::from_str(&encoded)); } @@ -596,28 +615,16 @@ impl RecachedCache { /// Publish a message to a channel on the server. pub fn publish(&self, channel: &str, message: &str) { - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&to_resp(&["PUBLISH", channel, message])); - } + self.ws_enqueue(&to_resp(&["PUBLISH", channel, message])); } /// Subscribe to a channel on the server. 
Push messages arrive via the `onmessage` callback. pub fn subscribe(&self, channel: &str) { - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&to_resp(&["SUBSCRIBE", channel])); - } + self.ws_enqueue(&to_resp(&["SUBSCRIBE", channel])); } /// Unsubscribe from a channel on the server. pub fn unsubscribe(&self, channel: &str) { - if let Some(ws) = &self.ws - && ws.ready_state() == WebSocket::OPEN - { - let _ = ws.send_with_str(&to_resp(&["UNSUBSCRIBE", channel])); - } + self.ws_enqueue(&to_resp(&["UNSUBSCRIBE", channel])); } }
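
The `ws_enqueue` change in `wasm-edge/src/lib.rs` can be sketched in isolation, without `web_sys`. The block below is a minimal, std-only illustration of the buffer-until-open idea, not the SDK's actual code. `ReadyState`, `BufferedSender`, and the `sent` log are hypothetical stand-ins for the browser socket. It assumes only what the diff shows: commands issued while the socket is still connecting are queued, then flushed in FIFO order when the open event fires.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the browser WebSocket ready state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ReadyState {
    Connecting,
    Open,
}

/// Hypothetical model of the sender: `sent` records frames that would
/// have gone out on the wire, `pending` holds frames queued before open.
struct BufferedSender {
    state: ReadyState,
    pending: VecDeque<String>,
    sent: Vec<String>,
}

impl BufferedSender {
    /// A fresh sender starts in the CONNECTING state with nothing queued.
    fn new() -> Self {
        Self {
            state: ReadyState::Connecting,
            pending: VecDeque::new(),
            sent: Vec::new(),
        }
    }

    /// Send immediately when open; otherwise buffer instead of dropping.
    fn enqueue(&mut self, encoded: &str) {
        if self.state == ReadyState::Open {
            self.sent.push(encoded.to_string());
        } else {
            self.pending.push_back(encoded.to_string());
        }
    }

    /// Models the open handler: mark the socket open, then flush the
    /// buffered frames oldest-first so AUTH precedes any early writes.
    fn on_open(&mut self) {
        self.state = ReadyState::Open;
        while let Some(msg) = self.pending.pop_front() {
            self.sent.push(msg);
        }
    }
}
```

In this model, calling `enqueue("AUTH")` and `enqueue("SET")` before `on_open()` leaves `sent` empty until the open event, after which both frames appear in the order they were issued. A later `enqueue` goes straight to `sent`.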