From 140b45bf069deeed569d63d034da3727f2b1f1b0 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 13 Dec 2025 03:24:42 -0500 Subject: [PATCH 1/5] feat(mlx): add thread-safe LRU prompt cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port mlx-lm's LRUPromptCache to fix race condition where concurrent requests corrupt shared KV cache state. The previous implementation used a single prompt_cache instance shared across all requests. Changes: - Add backend/python/common/mlx_cache.py with ThreadSafeLRUPromptCache - Modify backend.py to use per-request cache isolation via fetch/insert - Add prefix matching for cache reuse across similar prompts - Add LRU eviction (default 10 entries, configurable) - Add concurrency and cache unit tests The cache uses a trie-based structure for efficient prefix matching, allowing prompts that share common prefixes to reuse cached KV states. Thread safety is provided via threading.Lock. New configuration options: - max_cache_entries: Maximum LRU cache entries (default: 10) - max_kv_size: Maximum KV cache size per entry (default: None) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: Blightbow --- backend/python/common/mlx_cache.py | 262 +++++++++++++++++++++++++++++ backend/python/mlx/backend.py | 135 +++++++++++---- backend/python/mlx/test.py | 195 ++++++++++++++++++++- 3 files changed, 557 insertions(+), 35 deletions(-) create mode 100644 backend/python/common/mlx_cache.py diff --git a/backend/python/common/mlx_cache.py b/backend/python/common/mlx_cache.py new file mode 100644 index 000000000000..245639ff0cd0 --- /dev/null +++ b/backend/python/common/mlx_cache.py @@ -0,0 +1,262 @@ +""" +Thread-safe LRU prompt cache for MLX-based backends. + +Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.) +with thread-safety additions for LocalAI's gRPC backend. + +Usage: + from mlx_cache import ThreadSafeLRUPromptCache + + # In LoadModel: + self.lru_cache = ThreadSafeLRUPromptCache(max_size=10) + + # In Predict/PredictStream: + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens) + # ... generate ... + self.lru_cache.insert_cache(model_key, tokens, prompt_cache) +""" +import copy +import threading +from collections import deque +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + + +@dataclass +class CacheEntry: + """A cache entry with reference counting.""" + prompt_cache: List[Any] + count: int + + +@dataclass +class SearchResult: + """Result of searching the cache trie.""" + model: Any + exact: Optional[List[int]] + shorter: Optional[List[int]] + longer: Optional[List[int]] + common_prefix: int + + +class ThreadSafeLRUPromptCache: + """ + Thread-safe LRU cache with prefix matching for prompt KV caches. + + This cache stores KV caches keyed by token sequences and supports: + - Exact match: Return the cache for the exact token sequence + - Shorter prefix match: Return a cache for a prefix of the tokens + - Longer prefix match: If a longer sequence is cached and can be trimmed + - LRU eviction: When max_size is exceeded, evict least recently used + + Thread safety is provided via a threading.Lock that protects all + cache operations. + + Args: + max_size: Maximum number of cache entries (default: 10) + can_trim_fn: Optional function to check if a cache can be trimmed + trim_fn: Optional function to trim a cache + """ + + def __init__( + self, + max_size: int = 10, + can_trim_fn: Optional[Any] = None, + trim_fn: Optional[Any] = None, + ): + self.max_size = max_size + self._cache = {} + self._lru = deque() + self._lock = threading.Lock() + + # Optional trim functions (for longer prefix reuse) + self._can_trim_fn = can_trim_fn + self._trim_fn = trim_fn + + def _search(self, model, tokens: List[int]) -> SearchResult: + """ + Search the cache for a prompt cache. Return exact or close match. + + The cache is organized as a trie where each node is keyed by a token. + This allows efficient prefix matching. + """ + if model not in self._cache: + return SearchResult(model, None, None, None, 0) + + current = self._cache[model] + last_cache_index = -1 + index = 0 + + # Traverse the trie following the token sequence + while index < len(tokens) and tokens[index] in current: + current = current[tokens[index]] + if "cache" in current: + last_cache_index = index + index += 1 + + # Exact match - no need to search for longer or shorter caches + if last_cache_index == len(tokens) - 1: + return SearchResult(model, tuple(tokens), None, None, 0) + + # Find the shorter cache (a prefix that has a cache) + shorter = None + if last_cache_index > 0: + shorter = tuple(tokens[: last_cache_index + 1]) + + # Check for caches that are longer than our token sequence + longer = None + common_prefix = index + if index > 0 and last_cache_index <= 0: + best = None + stack = [(current, [])] + while stack: + current, extra = stack.pop() + if "cache" in current: + if best is None or len(extra) < len(best): + best = extra + else: + for tok in current: + stack.append((current[tok], extra + [tok])) + if best is not None: + longer = tuple(tokens[:index] + best) + + return SearchResult(model, None, shorter, longer, common_prefix) + + def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """Get a cache entry by traversing the trie.""" + current = self._cache[model] + for tok in tokens: + current = current[tok] + return current["cache"] + + def _delete(self, model, tokens: Tuple[int, ...]) -> None: + """Delete a cache entry and clean up empty trie nodes.""" + path = [self._cache[model]] + for tok in tokens: + path.append(path[-1][tok]) + del path[-1]["cache"] + + # Clean up empty nodes bottom-up + for i in reversed(range(len(tokens))): + d_prev, d, t = path[i], path[i + 1], tokens[i] + if len(d) > 0: + break + del d_prev[t] + + def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry: + """ + Extract a cache entry for exclusive use. + + If the entry has count > 1, deep copy and decrement. + If count == 1, remove from cache entirely. + """ + cache_entry = self._get(model, tokens) + if cache_entry.count == 1: + self._delete(model, tokens) + self._lru.remove((model, tokens)) + return cache_entry + + cache_entry.count -= 1 + return CacheEntry( + copy.deepcopy(cache_entry.prompt_cache), + 1, + ) + + def fetch_nearest_cache( + self, model, tokens: List[int] + ) -> Tuple[Optional[List[Any]], List[int]]: + """ + Fetch the nearest cache for the given token sequence. + + Thread-safe. Returns (cache, remaining_tokens) where: + - cache: The KV cache to use (or None if no cache found) + - remaining_tokens: Tokens that still need to be processed + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence for the prompt + + Returns: + Tuple of (prompt_cache, remaining_tokens) + """ + with self._lock: + tokens_tuple = tuple(tokens) + result = self._search(model, tokens) + + # Exact match - extract and return + if result.exact is not None: + cache_entry = self._extract(result.model, result.exact) + return cache_entry.prompt_cache, [] + + # Shorter prefix match - extract and return remaining + if result.shorter is not None: + cache_entry = self._extract(result.model, result.shorter) + prefix_len = len(result.shorter) + return cache_entry.prompt_cache, list(tokens[prefix_len:]) + + # Longer prefix match - try to trim if possible + if result.longer is not None and self._can_trim_fn is not None: + cache_entry = self._get(result.model, result.longer) + if self._can_trim_fn(cache_entry.prompt_cache): + # Deep copy and trim + trimmed_cache = copy.deepcopy(cache_entry.prompt_cache) + prefix = min(len(tokens) - 1, result.common_prefix) + num_to_trim = len(result.longer) - prefix + if self._trim_fn is not None: + self._trim_fn(trimmed_cache, num_to_trim) + return trimmed_cache, list(tokens[prefix:]) + + # No match found + return None, list(tokens) + + def insert_cache( + self, model, tokens: List[int], prompt_cache: List[Any] + ) -> None: + """ + Insert a cache entry after generation completes. + + Thread-safe. Handles LRU eviction if max_size is exceeded. + + Args: + model: Model identifier (used to namespace caches) + tokens: The full token sequence (prompt + generated) + prompt_cache: The KV cache to store + """ + with self._lock: + tokens_tuple = tuple(tokens) + + if model not in self._cache: + self._cache[model] = {} + current = self._cache[model] + + # Build trie path + for tok in tokens_tuple: + if tok not in current: + current[tok] = {} + current = current[tok] + + # Update or create entry + if "cache" in current: + current["cache"].count += 1 + self._lru.remove((model, tokens_tuple)) + else: + current["cache"] = CacheEntry(prompt_cache, 1) + + # Update LRU order + self._lru.append((model, tokens_tuple)) + + # Evict if over capacity + if len(self._lru) > self.max_size: + evict_model, evict_tokens = self._lru.popleft() + self._delete(evict_model, evict_tokens) + + def clear(self) -> None: + """Clear all cache entries. Thread-safe.""" + with self._lock: + self._cache.clear() + self._lru.clear() + + def __len__(self) -> int: + """Return the number of cache entries. Thread-safe.""" + with self._lock: + return len(self._lru) diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index 072f8a0b0bba..2459ad1b5d6c 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -14,11 +14,15 @@ import grpc from mlx_lm import load, generate, stream_generate from mlx_lm.sample_utils import make_sampler -from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache import mlx.core as mx import base64 import io +# Add common module to path for shared cache +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) +from mlx_cache import ThreadSafeLRUPromptCache + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 @@ -118,10 +122,16 @@ async def LoadModel(self, request, context): self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) else: self.model, self.tokenizer = load(request.Model) - - # Initialize prompt cache for efficient generation - max_kv_size = self.options.get("max_kv_size", None) - self.prompt_cache = make_prompt_cache(self.model, max_kv_size) + + # Initialize thread-safe LRU prompt cache for efficient generation + max_cache_entries = self.options.get("max_cache_entries", 10) + self.max_kv_size = self.options.get("max_kv_size", None) + self.model_key = request.Model + self.lru_cache = ThreadSafeLRUPromptCache( + max_size=max_cache_entries, + can_trim_fn=can_trim_prompt_cache, + trim_fn=trim_prompt_cache, + ) except Exception as err: print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr) @@ -134,6 +144,8 @@ async def Predict(self, request, context): """ Generates text based on the given prompt and sampling parameters using MLX. + Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. + Args: request: The predict request. context: The gRPC context. @@ -141,31 +153,48 @@ async def Predict(self, request, context): Returns: backend_pb2.Reply: The predict result. """ + prompt_cache = None + cache_key = None + try: - # Prepare the prompt - prompt = self._prepare_prompt(request) - + # Prepare the prompt and tokenize for cache key + prompt_text = self._prepare_prompt(request) + cache_key = self._get_tokens_from_prompt(prompt_text) + + # Fetch nearest cache (exact, shorter prefix, or create new) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + # Build generation parameters using request attributes and options max_tokens, sampler_params = self._build_generation_params(request) - - print(f"Generating text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr) - + + print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + # Create sampler with parameters sampler = make_sampler(**sampler_params) - - # Generate text using MLX with proper parameters - response = generate( + + # Use stream_generate to track generated tokens for cache key + generated_text = [] + for response in stream_generate( self.model, self.tokenizer, - prompt=prompt, + prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, - prompt_cache=self.prompt_cache, - verbose=False - ) - - return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) - + prompt_cache=prompt_cache, + ): + generated_text.append(response.text) + cache_key.append(response.token) + + # Insert completed cache + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + + return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8')) + except Exception as e: print(f"Error in MLX Predict: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) @@ -194,6 +223,8 @@ async def PredictStream(self, request, context): """ Generates text based on the given prompt and sampling parameters, and streams the results using MLX. + Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. + Args: request: The predict stream request. context: The gRPC context. @@ -201,35 +232,56 @@ async def PredictStream(self, request, context): Yields: backend_pb2.Reply: Streaming predict results. """ + prompt_cache = None + cache_key = None + try: - # Prepare the prompt - prompt = self._prepare_prompt(request) - + # Prepare the prompt and tokenize for cache key + prompt_text = self._prepare_prompt(request) + cache_key = self._get_tokens_from_prompt(prompt_text) + + # Fetch nearest cache (exact, shorter prefix, or create new) + prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( + self.model_key, cache_key + ) + if prompt_cache is None: + prompt_cache = make_prompt_cache(self.model, self.max_kv_size) + remaining_tokens = cache_key + # Build generation parameters using request attributes and options max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) - - print(f"Streaming text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr) - + + print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr) + # Create sampler with parameters sampler = make_sampler(**sampler_params) - + # Stream text generation using MLX with proper parameters for response in stream_generate( self.model, self.tokenizer, - prompt=prompt, + prompt=remaining_tokens if remaining_tokens else cache_key, max_tokens=max_tokens, sampler=sampler, - prompt_cache=self.prompt_cache, + prompt_cache=prompt_cache, ): + cache_key.append(response.token) yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) - + except Exception as e: print(f"Error in MLX PredictStream: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Streaming generation failed: {str(e)}") yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + finally: + # Always insert cache, even on interruption + if prompt_cache is not None and cache_key is not None: + try: + self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) + except Exception as e: + print(f"Error inserting cache: {e}", file=sys.stderr) + def _prepare_prompt(self, request): """ Prepare the prompt for MLX generation, handling chat templates if needed. @@ -246,16 +298,31 @@ def _prepare_prompt(self, request): messages = [] for msg in request.Messages: messages.append({"role": msg.role, "content": msg.content}) - + prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, + messages, + tokenize=False, add_generation_prompt=True ) return prompt else: return request.Prompt + def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: + """ + Tokenize prompt text for cache key generation. + + Args: + prompt_text: The prompt string to tokenize. + + Returns: + List[int]: List of token IDs. + """ + tokens = self.tokenizer.encode(prompt_text) + if hasattr(tokens, 'tolist'): + return tokens.tolist() + return list(tokens) + diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index 827aa71a3e33..f6047cd1fbd2 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -143,4 +143,197 @@ def test_embedding(self): print(err) self.fail("Embedding service failed") finally: - self.tearDown() \ No newline at end of file + self.tearDown() + + def test_concurrent_requests(self): + """ + This method tests that concurrent requests don't corrupt each other's cache state. + This is a regression test for the race condition in the original implementation. + """ + import concurrent.futures + + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + + def make_request(prompt): + req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20) + return stub.Predict(req) + + # Run 5 concurrent requests with different prompts + prompts = [ + "The capital of France is", + "The capital of Germany is", + "The capital of Italy is", + "The capital of Spain is", + "The capital of Portugal is", + ] + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(make_request, p) for p in prompts] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + # All results should be non-empty + messages = [r.message for r in results] + self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses") + print(f"Concurrent test passed: {len(messages)} responses received") + + except Exception as err: + print(err) + self.fail("Concurrent requests test failed") + finally: + self.tearDown() + + def test_cache_reuse(self): + """ + This method tests that repeated prompts reuse cached KV states. + The second request should benefit from the cached prompt processing. + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + + prompt = "The quick brown fox jumps over the lazy dog. " + + # First request - populates cache + req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10) + resp1 = stub.Predict(req1) + self.assertIsNotNone(resp1.message) + + # Second request with same prompt - should reuse cache + req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10) + resp2 = stub.Predict(req2) + self.assertIsNotNone(resp2.message) + + print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes") + + except Exception as err: + print(err) + self.fail("Cache reuse test failed") + finally: + self.tearDown() + + def test_prefix_cache_reuse(self): + """ + This method tests that prompts sharing a common prefix benefit from cached KV states. + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + + # First request with base prompt + prompt_base = "Once upon a time in a land far away, " + req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10) + resp1 = stub.Predict(req1) + self.assertIsNotNone(resp1.message) + + # Second request with extended prompt (same prefix) + prompt_extended = prompt_base + "there lived a brave knight who " + req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10) + resp2 = stub.Predict(req2) + self.assertIsNotNone(resp2.message) + + print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes") + + except Exception as err: + print(err) + self.fail("Prefix cache reuse test failed") + finally: + self.tearDown() + + +class TestThreadSafeLRUPromptCache(unittest.TestCase): + """ + Unit tests for the ThreadSafeLRUPromptCache class. + These tests don't require the gRPC server. + """ + + def setUp(self): + import sys + import os + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) + from mlx_cache import ThreadSafeLRUPromptCache + self.cache = ThreadSafeLRUPromptCache(max_size=3) + + def test_insert_and_fetch_exact(self): + """Test inserting and fetching an exact match.""" + tokens = [1, 2, 3, 4, 5] + mock_cache = ["mock_kv_cache"] + + self.cache.insert_cache("model1", tokens, mock_cache) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + + self.assertEqual(result_cache, mock_cache) + self.assertEqual(remaining, []) + + def test_fetch_shorter_prefix(self): + """Test fetching a shorter prefix match.""" + # Insert a short sequence + short_tokens = [1, 2, 3] + mock_cache = ["mock_kv_cache"] + self.cache.insert_cache("model1", short_tokens, mock_cache) + + # Fetch with a longer sequence + long_tokens = [1, 2, 3, 4, 5] + result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) + + self.assertEqual(result_cache, mock_cache) + self.assertEqual(remaining, [4, 5]) + + def test_lru_eviction(self): + """Test that LRU eviction works when max_size is exceeded.""" + # Insert 3 entries (max_size) + self.cache.insert_cache("model1", [1], ["cache1"]) + self.cache.insert_cache("model1", [2], ["cache2"]) + self.cache.insert_cache("model1", [3], ["cache3"]) + + self.assertEqual(len(self.cache), 3) + + # Insert a 4th entry - should evict the oldest (tokens=[1]) + self.cache.insert_cache("model1", [4], ["cache4"]) + + self.assertEqual(len(self.cache), 3) + + # The first entry should be evicted + result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1]) + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1]) + + def test_thread_safety(self): + """Test that concurrent access doesn't cause errors.""" + import concurrent.futures + import random + + def random_operation(op_id): + tokens = [random.randint(1, 100) for _ in range(random.randint(1, 10))] + if random.random() < 0.5: + self.cache.insert_cache(f"model{op_id % 3}", tokens, [f"cache_{op_id}"]) + else: + self.cache.fetch_nearest_cache(f"model{op_id % 3}", tokens) + return op_id + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(random_operation, i) for i in range(100)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + self.assertEqual(len(results), 100) + + def test_clear(self): + """Test that clear() removes all entries.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache1"]) + self.cache.insert_cache("model2", [4, 5, 6], ["cache2"]) + + self.assertEqual(len(self.cache), 2) + + self.cache.clear() + + self.assertEqual(len(self.cache), 0) \ No newline at end of file From 022eededc570bf9bbb33135d738ab229cb381afa Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 13 Dec 2025 04:21:26 -0500 Subject: [PATCH 2/5] feat(mlx): add min_p and top_k sampler support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add MinP field to proto (field 52) following the precedent set by other non-OpenAI sampling parameters like TopK, TailFreeSamplingZ, TypicalP, and Mirostat. Changes: - backend.proto: Add float MinP field for min-p sampling - backend.py: Extract and pass min_p and top_k to mlx_lm sampler (top_k was in proto but not being passed) - test.py: Fix test_sampling_params to use valid proto fields and switch to MLX-compatible model (mlx-community/Llama-3.2-1B-Instruct) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: Blightbow --- backend/backend.proto | 1 + backend/python/mlx/backend.py | 14 ++++++++++++-- backend/python/mlx/test.py | 25 ++++++------------------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/backend/backend.proto b/backend/backend.proto index 187294236862..cf213387209e 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -158,6 +158,7 @@ message PredictOptions { string ToolChoice = 49; // JSON string or object specifying tool choice behavior int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter) int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter) + float MinP = 52; // Min-p sampling: minimum probability threshold scaled by top token probability } // The response message containing the result diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index 2459ad1b5d6c..54089ffbac8c 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -351,11 +351,19 @@ def _build_generation_params(self, request, default_max_tokens=200): top_p = getattr(request, 'TopP', 0.0) if top_p == 0.0: top_p = 1.0 # Default top_p - + + min_p = getattr(request, 'MinP', 0.0) + # min_p default of 0.0 means disabled (no filtering) + + top_k = getattr(request, 'TopK', 0) + # top_k default of 0 means disabled (no filtering) + # Initialize sampler parameters sampler_params = { 'temp': temp, 'top_p': top_p, + 'min_p': min_p, + 'top_k': top_k, 'xtc_threshold': 0.0, 'xtc_probability': 0.0, } @@ -375,7 +383,9 @@ def _build_generation_params(self, request, default_max_tokens=200): sampler_option_mapping = { 'temp': 'temp', 'temperature': 'temp', # alias - 'top_p': 'top_p', + 'top_p': 'top_p', + 'min_p': 'min_p', + 'top_k': 'top_k', 'xtc_threshold': 'xtc_threshold', 'xtc_probability': 'xtc_probability', } diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index f6047cd1fbd2..d8be83fdf8ee 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -47,7 +47,7 @@ def test_load_model(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: @@ -64,7 +64,7 @@ def test_text(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) @@ -84,7 +84,7 @@ def test_sampling_params(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions( @@ -95,26 +95,13 @@ def test_sampling_params(self): TopK=40, PresencePenalty=0.1, FrequencyPenalty=0.2, - RepetitionPenalty=1.1, MinP=0.05, Seed=42, StopPrompts=["\n"], - StopTokenIds=[50256], - BadWords=["badword"], - IncludeStopStrInOutput=True, IgnoreEOS=True, - MinTokens=5, - Logprobs=5, - PromptLogprobs=5, - SkipSpecialTokens=True, - SpacesBetweenSpecialTokens=True, - TruncatePromptTokens=10, - GuidedDecoding=True, - N=2, ) resp = stub.Predict(req) self.assertIsNotNone(resp.message) - self.assertIsNotNone(resp.logprobs) except Exception as err: print(err) self.fail("sampling params service failed") @@ -156,7 +143,7 @@ def test_concurrent_requests(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) def make_request(prompt): @@ -196,7 +183,7 @@ def test_cache_reuse(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) prompt = "The quick brown fox jumps over the lazy dog. " @@ -227,7 +214,7 @@ def test_prefix_cache_reuse(self): self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) # First request with base prompt From 06f012b366ef75e10b9d0a8244e0abab518df81f Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 13 Dec 2025 06:49:30 -0500 Subject: [PATCH 3/5] refactor(mlx): move mlx_cache.py from common to mlx backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ThreadSafeLRUPromptCache is only used by the mlx backend. After evaluating mlx-vlm, it was determined that the cache cannot be shared because mlx-vlm's generate/stream_generate functions don't support the prompt_cache parameter that mlx_lm provides. - Move mlx_cache.py from backend/python/common/ to backend/python/mlx/ - Remove sys.path manipulation from backend.py and test.py - Fix test assertion to expect "MLX model loaded successfully" 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: Blightbow --- backend/python/mlx/backend.py | 2 -- backend/python/{common => mlx}/mlx_cache.py | 0 backend/python/mlx/test.py | 5 +---- 3 files changed, 1 insertion(+), 6 deletions(-) rename backend/python/{common => mlx}/mlx_cache.py (100%) diff --git a/backend/python/mlx/backend.py b/backend/python/mlx/backend.py index 54089ffbac8c..aaa0d6f347f8 100644 --- a/backend/python/mlx/backend.py +++ b/backend/python/mlx/backend.py @@ -19,8 +19,6 @@ import base64 import io -# Add common module to path for shared cache -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) from mlx_cache import ThreadSafeLRUPromptCache _ONE_DAY_IN_SECONDS = 60 * 60 * 24 diff --git a/backend/python/common/mlx_cache.py b/backend/python/mlx/mlx_cache.py similarity index 100% rename from backend/python/common/mlx_cache.py rename to backend/python/mlx/mlx_cache.py diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index d8be83fdf8ee..2d4f48700400 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -49,7 +49,7 @@ def test_load_model(self): stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) - self.assertEqual(response.message, "Model loaded successfully") + self.assertEqual(response.message, "MLX model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") @@ -245,9 +245,6 @@ class TestThreadSafeLRUPromptCache(unittest.TestCase): """ def setUp(self): - import sys - import os - sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common')) from mlx_cache import ThreadSafeLRUPromptCache self.cache = ThreadSafeLRUPromptCache(max_size=3) From a796742e74313afa643b58742cf55f5ff3d617e2 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Sat, 13 Dec 2025 06:57:55 -0500 Subject: [PATCH 4/5] test(mlx): add comprehensive cache tests and document upstream behavior MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added comprehensive unit tests (test_mlx_cache.py) covering all cache operation modes: - Exact match - Shorter prefix match - Longer prefix match with trimming - No match scenarios - LRU eviction and access order - Reference counting and deep copy behavior - Multi-model namespacing - Thread safety with data integrity verification Documents upstream mlx_lm/server.py behavior: single-token prefixes are deliberately not matched (uses > 0, not >= 0) to allow longer cached sequences to be preferred for trimming. This is acceptable because real prompts with chat templates are always many tokens. Removed weak unit tests from test.py that only verified "no exception thrown" rather than correctness. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: Blightbow --- backend/python/mlx/mlx_cache.py | 4 + backend/python/mlx/test.py | 93 +----- backend/python/mlx/test_mlx_cache.py | 480 +++++++++++++++++++++++++++ 3 files changed, 486 insertions(+), 91 deletions(-) create mode 100644 backend/python/mlx/test_mlx_cache.py diff --git a/backend/python/mlx/mlx_cache.py b/backend/python/mlx/mlx_cache.py index 245639ff0cd0..6ec2bb9baabb 100644 --- a/backend/python/mlx/mlx_cache.py +++ b/backend/python/mlx/mlx_cache.py @@ -99,6 +99,10 @@ def _search(self, model, tokens: List[int]) -> SearchResult: return SearchResult(model, tuple(tokens), None, None, 0) # Find the shorter cache (a prefix that has a cache) + # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior. + # Single-token prefixes are not matched, which allows longer cached + # sequences to be preferred for trimming. This is acceptable because + # real prompts with chat templates are always many tokens. shorter = None if last_cache_index > 0: shorter = tuple(tokens[: last_cache_index + 1]) diff --git a/backend/python/mlx/test.py b/backend/python/mlx/test.py index 2d4f48700400..53d7bc7ec1b4 100644 --- a/backend/python/mlx/test.py +++ b/backend/python/mlx/test.py @@ -1,17 +1,10 @@ import unittest import subprocess import time -import backend_pb2 -import backend_pb2_grpc import grpc - -import unittest -import subprocess -import time -import grpc -import backend_pb2_grpc import backend_pb2 +import backend_pb2_grpc class TestBackendServicer(unittest.TestCase): """ @@ -238,86 +231,4 @@ def test_prefix_cache_reuse(self): self.tearDown() -class TestThreadSafeLRUPromptCache(unittest.TestCase): - """ - Unit tests for the ThreadSafeLRUPromptCache class. - These tests don't require the gRPC server. - """ - - def setUp(self): - from mlx_cache import ThreadSafeLRUPromptCache - self.cache = ThreadSafeLRUPromptCache(max_size=3) - - def test_insert_and_fetch_exact(self): - """Test inserting and fetching an exact match.""" - tokens = [1, 2, 3, 4, 5] - mock_cache = ["mock_kv_cache"] - - self.cache.insert_cache("model1", tokens, mock_cache) - result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) - - self.assertEqual(result_cache, mock_cache) - self.assertEqual(remaining, []) - - def test_fetch_shorter_prefix(self): - """Test fetching a shorter prefix match.""" - # Insert a short sequence - short_tokens = [1, 2, 3] - mock_cache = ["mock_kv_cache"] - self.cache.insert_cache("model1", short_tokens, mock_cache) - - # Fetch with a longer sequence - long_tokens = [1, 2, 3, 4, 5] - result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) - - self.assertEqual(result_cache, mock_cache) - self.assertEqual(remaining, [4, 5]) - - def test_lru_eviction(self): - """Test that LRU eviction works when max_size is exceeded.""" - # Insert 3 entries (max_size) - self.cache.insert_cache("model1", [1], ["cache1"]) - self.cache.insert_cache("model1", [2], ["cache2"]) - self.cache.insert_cache("model1", [3], ["cache3"]) - - self.assertEqual(len(self.cache), 3) - - # Insert a 4th entry - should evict the oldest (tokens=[1]) - self.cache.insert_cache("model1", [4], ["cache4"]) - - self.assertEqual(len(self.cache), 3) - - # The first entry should be evicted - result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1]) - self.assertIsNone(result_cache) - self.assertEqual(remaining, [1]) - - def test_thread_safety(self): - """Test that concurrent access doesn't cause errors.""" - import concurrent.futures - import random - - def random_operation(op_id): - tokens = [random.randint(1, 100) for _ in range(random.randint(1, 10))] - if random.random() < 0.5: - self.cache.insert_cache(f"model{op_id % 3}", tokens, [f"cache_{op_id}"]) - else: - self.cache.fetch_nearest_cache(f"model{op_id % 3}", tokens) - return op_id - - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - futures = [executor.submit(random_operation, i) for i in range(100)] - results = [f.result() for f in concurrent.futures.as_completed(futures)] - - self.assertEqual(len(results), 100) - - def test_clear(self): - """Test that clear() removes all entries.""" - self.cache.insert_cache("model1", [1, 2, 3], ["cache1"]) - self.cache.insert_cache("model2", [4, 5, 6], ["cache2"]) - - self.assertEqual(len(self.cache), 2) - - self.cache.clear() - - self.assertEqual(len(self.cache), 0) \ No newline at end of file +# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py \ No newline at end of file diff --git a/backend/python/mlx/test_mlx_cache.py b/backend/python/mlx/test_mlx_cache.py new file mode 100644 index 000000000000..c888782e9ddf --- /dev/null +++ b/backend/python/mlx/test_mlx_cache.py @@ -0,0 +1,480 @@ +""" +Comprehensive unit tests for ThreadSafeLRUPromptCache. + +Tests all cache operation modes: +- Exact match +- Shorter prefix match +- Longer prefix match (with trimming) +- No match +- LRU eviction +- Reference counting +- Multi-model namespacing +- Thread safety with data integrity verification +""" +import unittest +import concurrent.futures +import threading +import copy +from mlx_cache import ThreadSafeLRUPromptCache + + +class TestCacheExactMatch(unittest.TestCase): + """Tests for exact match cache behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_exact_match_returns_cache_and_empty_remaining(self): + """Exact match should return the cache with no remaining tokens.""" + tokens = [1, 2, 3, 4, 5] + mock_cache = ["kv_cache_data"] + + self.cache.insert_cache("model1", tokens, mock_cache) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + + self.assertEqual(result_cache, mock_cache) + self.assertEqual(remaining, []) + + def test_exact_match_extracts_and_removes_from_cache(self): + """Fetching exact match with count=1 should remove entry from cache.""" + tokens = [1, 2, 3] + self.cache.insert_cache("model1", tokens, ["cache"]) + + self.assertEqual(len(self.cache), 1) + + # First fetch extracts the entry + self.cache.fetch_nearest_cache("model1", tokens) + + # Cache should now be empty + self.assertEqual(len(self.cache), 0) + + # Second fetch should return None (no match) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + self.assertIsNone(result_cache) + self.assertEqual(remaining, tokens) + + +class TestCacheShorterPrefix(unittest.TestCase): + """Tests for shorter prefix match behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_shorter_prefix_returns_cache_with_remaining_tokens(self): + """When cached prefix is shorter, return cache and remaining suffix.""" + short_tokens = [1, 2, 3] + long_tokens = [1, 2, 3, 4, 5, 6] + mock_cache = ["prefix_cache"] + + self.cache.insert_cache("model1", short_tokens, mock_cache) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) + + self.assertEqual(result_cache, mock_cache) + self.assertEqual(remaining, [4, 5, 6]) + + def test_shorter_prefix_correct_remaining_calculation(self): + """Verify remaining tokens are calculated correctly for various prefix lengths.""" + # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched + # to allow longer cached sequences to be preferred for trimming. + # This matches upstream mlx_lm/server.py behavior. + test_cases = [ + # (cached_tokens, requested_tokens, expected_remaining) + ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]), + ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]), + ] + + for cached, requested, expected_remaining in test_cases: + with self.subTest(cached=cached, requested=requested): + cache = ThreadSafeLRUPromptCache(max_size=10) + cache.insert_cache("model", cached, ["cache"]) + result_cache, remaining = cache.fetch_nearest_cache("model", requested) + + self.assertIsNotNone(result_cache) + self.assertEqual(remaining, expected_remaining) + + def test_single_token_prefix_not_matched(self): + """Single-token prefixes are not matched (by design, matches upstream). + + This allows longer cached sequences to be preferred for trimming, + which provides better KV cache reuse. Single-token caches are rare + in practice since real prompts with chat templates are many tokens. + """ + cache = ThreadSafeLRUPromptCache(max_size=10) + cache.insert_cache("model", [1], ["cache"]) + + result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3]) + + # Single-token prefix is NOT matched + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheLongerPrefix(unittest.TestCase): + """Tests for longer prefix match behavior (trimming).""" + + def setUp(self): + # Track trim calls for verification + self.trim_calls = [] + + def mock_can_trim(cache): + return True + + def mock_trim(cache, num_to_trim): + self.trim_calls.append(num_to_trim) + # Simulate trimming by modifying the cache + cache.append(f"trimmed_{num_to_trim}") + + self.cache = ThreadSafeLRUPromptCache( + max_size=10, + can_trim_fn=mock_can_trim, + trim_fn=mock_trim, + ) + + def test_longer_prefix_triggers_trim(self): + """When cached sequence is longer, should trim to match requested prefix.""" + long_tokens = [1, 2, 3, 4, 5] + short_tokens = [1, 2, 3] + + self.cache.insert_cache("model1", long_tokens, ["original_cache"]) + result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens) + + # Should have called trim + self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called") + # Result should be a trimmed copy, not the original + self.assertIn("trimmed_", str(result_cache)) + + def test_longer_prefix_without_trim_fn_returns_no_match(self): + """Without trim functions, longer prefix should not match.""" + cache_no_trim = ThreadSafeLRUPromptCache(max_size=10) + + long_tokens = [1, 2, 3, 4, 5] + short_tokens = [1, 2, 3] + + cache_no_trim.insert_cache("model1", long_tokens, ["cache"]) + result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens) + + # Without trim_fn, should return no match + self.assertIsNone(result_cache) + self.assertEqual(remaining, short_tokens) + + def test_longer_prefix_can_trim_false_returns_no_match(self): + """When can_trim_fn returns False, should not attempt trim.""" + cache = ThreadSafeLRUPromptCache( + max_size=10, + can_trim_fn=lambda c: False, + trim_fn=lambda c, n: None, + ) + + cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"]) + result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheNoMatch(unittest.TestCase): + """Tests for no match behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_empty_cache_returns_none(self): + """Empty cache should return None and all tokens as remaining.""" + tokens = [1, 2, 3] + result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, tokens) + + def test_different_prefix_returns_none(self): + """Tokens with different prefix should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + # Completely different tokens + result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [4, 5, 6]) + + def test_partial_prefix_mismatch_returns_none(self): + """Tokens that diverge mid-sequence should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + # Same start but diverges + result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 99]) + + def test_wrong_model_returns_none(self): + """Different model key should not match.""" + self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) + + result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3]) + + self.assertIsNone(result_cache) + self.assertEqual(remaining, [1, 2, 3]) + + +class TestCacheLRUEviction(unittest.TestCase): + """Tests for LRU eviction behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=3) + + def test_evicts_oldest_when_full(self): + """Should evict least recently used entry when capacity exceeded.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.insert_cache("model", [2], ["cache2"]) + self.cache.insert_cache("model", [3], ["cache3"]) + + self.assertEqual(len(self.cache), 3) + + # Insert 4th entry - should evict [1] + self.cache.insert_cache("model", [4], ["cache4"]) + + self.assertEqual(len(self.cache), 3) + + # [1] should be evicted + result, _ = self.cache.fetch_nearest_cache("model", [1]) + self.assertIsNone(result) + + # [2], [3], [4] should still exist + for tokens in [[2], [3], [4]]: + # Re-insert since fetch extracts + self.cache.insert_cache("model", tokens, [f"cache{tokens[0]}"]) + + result2, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertIsNotNone(result2) + + def test_access_updates_lru_order(self): + """Accessing an entry should move it to most recently used.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.insert_cache("model", [2], ["cache2"]) + self.cache.insert_cache("model", [3], ["cache3"]) + + # Access [1] to make it most recently used + cache1, _ = self.cache.fetch_nearest_cache("model", [1]) + # Re-insert it (simulating normal usage pattern) + self.cache.insert_cache("model", [1], cache1) + + # Now insert two more entries - should evict [2] then [3], not [1] + self.cache.insert_cache("model", [4], ["cache4"]) + self.cache.insert_cache("model", [5], ["cache5"]) + + # [1] should still exist (was accessed, so not evicted) + result1, _ = self.cache.fetch_nearest_cache("model", [1]) + self.assertIsNotNone(result1) + + # [2] should be evicted (was oldest after [1] was accessed) + result2, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertIsNone(result2) + + +class TestCacheReferenceCount(unittest.TestCase): + """Tests for reference counting behavior.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_multiple_inserts_increment_count(self): + """Inserting same tokens multiple times should increment count.""" + tokens = [1, 2, 3] + + self.cache.insert_cache("model", tokens, ["cache"]) + self.cache.insert_cache("model", tokens, ["cache"]) + self.cache.insert_cache("model", tokens, ["cache"]) + + # Should still be one entry (with count=3 internally) + self.assertEqual(len(self.cache), 1) + + # First two fetches should return copies (count decremented) + result1, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result1) + + result2, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result2) + + # Third fetch extracts the last reference + result3, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNotNone(result3) + + # Fourth fetch should return None (entry fully extracted) + result4, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertIsNone(result4) + + def test_extract_with_high_count_returns_deep_copy(self): + """When count > 1, extract should return a deep copy.""" + tokens = [1, 2, 3] + original_cache = [{"nested": "data"}] + + self.cache.insert_cache("model", tokens, original_cache) + self.cache.insert_cache("model", tokens, original_cache) # count=2 + + result1, _ = self.cache.fetch_nearest_cache("model", tokens) + + # Modify the returned cache + result1[0]["nested"] = "modified" + + # Second fetch should get unmodified copy + result2, _ = self.cache.fetch_nearest_cache("model", tokens) + self.assertEqual(result2[0]["nested"], "data") + + +class TestCacheMultiModel(unittest.TestCase): + """Tests for multi-model namespacing.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_same_tokens_different_models_are_separate(self): + """Same token sequence under different models should be independent.""" + tokens = [1, 2, 3] + + self.cache.insert_cache("model_a", tokens, ["cache_a"]) + self.cache.insert_cache("model_b", tokens, ["cache_b"]) + + self.assertEqual(len(self.cache), 2) + + result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens) + result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens) + + self.assertEqual(result_a, ["cache_a"]) + self.assertEqual(result_b, ["cache_b"]) + + def test_eviction_across_models(self): + """LRU eviction should work across different models.""" + cache = ThreadSafeLRUPromptCache(max_size=3) + + cache.insert_cache("model_a", [1], ["a1"]) + cache.insert_cache("model_b", [1], ["b1"]) + cache.insert_cache("model_a", [2], ["a2"]) + + self.assertEqual(len(cache), 3) + + # Insert 4th - should evict model_a:[1] (oldest) + cache.insert_cache("model_b", [2], ["b2"]) + + result, _ = cache.fetch_nearest_cache("model_a", [1]) + self.assertIsNone(result) + + +class TestCacheThreadSafety(unittest.TestCase): + """Tests for thread safety with data integrity verification.""" + + def test_concurrent_inserts_no_data_loss(self): + """Concurrent inserts should not lose data.""" + cache = ThreadSafeLRUPromptCache(max_size=100) + num_threads = 10 + inserts_per_thread = 20 + + def insert_entries(thread_id): + for i in range(inserts_per_thread): + tokens = [thread_id, i] + cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)] + concurrent.futures.wait(futures) + + # Verify expected number of entries (may be less due to LRU eviction with max_size=100) + # But should be exactly 100 since we inserted exactly 200 and max_size is 100 + self.assertEqual(len(cache), 100) + + def test_concurrent_fetch_and_insert_no_corruption(self): + """Concurrent fetches and inserts should not corrupt data.""" + cache = ThreadSafeLRUPromptCache(max_size=50) + errors = [] + lock = threading.Lock() + + # Pre-populate with known data + for i in range(20): + cache.insert_cache("model", [i], [f"original_{i}"]) + + def fetch_and_verify(thread_id): + try: + for _ in range(50): + token_id = thread_id % 20 + result, remaining = cache.fetch_nearest_cache("model", [token_id]) + + if result is not None: + # Verify data integrity + expected_prefix = f"original_{token_id}" + if not str(result[0]).startswith("original_"): + with lock: + errors.append(f"Corrupted data: {result}") + + # Re-insert to keep cache populated + cache.insert_cache("model", [token_id], result) + + except Exception as e: + with lock: + errors.append(str(e)) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)] + concurrent.futures.wait(futures) + + self.assertEqual(errors, [], f"Thread safety errors: {errors}") + + def test_concurrent_operations_maintain_cache_bounds(self): + """Cache size should never exceed max_size under concurrent operations.""" + max_size = 10 + cache = ThreadSafeLRUPromptCache(max_size=max_size) + size_violations = [] + lock = threading.Lock() + + def random_operations(thread_id): + import random + for i in range(100): + tokens = [random.randint(0, 50)] + if random.random() < 0.7: + cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) + else: + cache.fetch_nearest_cache("model", tokens) + + current_size = len(cache) + if current_size > max_size: + with lock: + size_violations.append(current_size) + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(random_operations, tid) for tid in range(10)] + concurrent.futures.wait(futures) + + self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}") + self.assertLessEqual(len(cache), max_size) + + +class TestCacheClear(unittest.TestCase): + """Tests for cache clear operation.""" + + def setUp(self): + self.cache = ThreadSafeLRUPromptCache(max_size=10) + + def test_clear_removes_all_entries(self): + """Clear should remove all entries.""" + self.cache.insert_cache("model1", [1, 2], ["cache1"]) + self.cache.insert_cache("model2", [3, 4], ["cache2"]) + self.cache.insert_cache("model1", [5, 6], ["cache3"]) + + self.assertEqual(len(self.cache), 3) + + self.cache.clear() + + self.assertEqual(len(self.cache), 0) + + def test_clear_allows_new_inserts(self): + """After clear, new inserts should work normally.""" + self.cache.insert_cache("model", [1], ["cache1"]) + self.cache.clear() + self.cache.insert_cache("model", [2], ["cache2"]) + + self.assertEqual(len(self.cache), 1) + + result, _ = self.cache.fetch_nearest_cache("model", [2]) + self.assertEqual(result, ["cache2"]) + + +if __name__ == "__main__": + unittest.main() From 60c3b3539b457d2da40d49af6602fb6188a0ea11 Mon Sep 17 00:00:00 2001 From: Blightbow Date: Mon, 15 Dec 2025 18:35:03 -0500 Subject: [PATCH 5/5] chore(mlx): remove unused MinP proto field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MinP field was added to PredictOptions but is not populated by the Go frontend/API. The MLX backend uses getattr with a default value, so it works without the proto field. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: Blightbow --- backend/backend.proto | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/backend.proto b/backend/backend.proto index cf213387209e..187294236862 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -158,7 +158,6 @@ message PredictOptions { string ToolChoice = 49; // JSON string or object specifying tool choice behavior int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter) int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter) - float MinP = 52; // Min-p sampling: minimum probability threshold scaled by top token probability } // The response message containing the result