diff --git a/src/shared_store.rs b/src/shared_store.rs index 764e95f..fc1f513 100644 --- a/src/shared_store.rs +++ b/src/shared_store.rs @@ -155,6 +155,25 @@ impl SharedCachedFunction { } } + /// Cache lookup returning (hit, value) to distinguish cached None from miss. + #[pyo3(signature = (*args, **kwargs))] + fn _probe<'py>( + &self, + py: Python<'py>, + args: Bound<'py, PyTuple>, + kwargs: Option>, + ) -> PyResult<(bool, Py)> { + let (key_hash, key_bytes) = self.make_key(py, &args, &kwargs)?; + + match self.cache.get(key_hash, &key_bytes) { + ShmGetResult::Hit(vb) => { + let value = self.deserialize_value(py, &vb)?; + Ok((true, value)) + } + ShmGetResult::Miss => Ok((false, py.None())), + } + } + /// Store a value in the cache for the given arguments. #[pyo3(signature = (value, *args, **kwargs))] fn set<'py>( diff --git a/src/store.rs b/src/store.rs index ca0497f..2c1742d 100644 --- a/src/store.rs +++ b/src/store.rs @@ -467,6 +467,39 @@ impl CachedFunction { Ok(None) } + /// Cache lookup returning (hit, value) to distinguish cached None from miss. + #[pyo3(signature = (*args, **kwargs))] + fn _probe<'py>( + &self, + py: Python<'py>, + args: Bound<'py, PyTuple>, + kwargs: Option>, + ) -> PyResult<(bool, Py)> { + let (hash, key_ptr, _key_owner) = Self::hash_args(py, &args, &kwargs)?; + let borrowed = BorrowedArgs { hash, ptr: key_ptr }; + let shard_idx = hash as usize & self.shard_mask; + + let shard = self.shards[shard_idx].read(); + if let Some(entry) = shard.map.get(&borrowed) { + if let Some(ttl) = self.ttl { + if entry.created_at.elapsed() > ttl { + drop(shard); + self.misses.fetch_add(1, Ordering::Relaxed); + return Ok((false, py.None())); + } + } + entry.visited.store(true, Ordering::Relaxed); + let val = entry.value.clone_ref(py); + drop(shard); + self.hits.fetch_add(1, Ordering::Relaxed); + return Ok((true, val)); + } + + drop(shard); + self.misses.fetch_add(1, Ordering::Relaxed); + Ok((false, py.None())) + } + /// Store a value in the cache for the given arguments. #[pyo3(signature = (value, *args, **kwargs))] fn set<'py>( diff --git a/tests/test_async.py b/tests/test_async.py index 552b872..4737b15 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -195,6 +195,56 @@ def add(a, b): assert info.hits == 1 +# ── None return value ──────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_async_none_return_value(): + """Verify that async functions returning None are cached correctly.""" + call_count = 0 + + @cache(max_size=128) + async def returns_none(x): + nonlocal call_count + call_count += 1 + return None + + result = await returns_none(1) + assert result is None + assert call_count == 1 + + result = await returns_none(1) + assert result is None + assert call_count == 1 # cached, not recomputed + + info = returns_none.cache_info() + assert info.hits == 1 + assert info.misses == 1 + + +@pytest.mark.skipif(sys.platform == "win32", reason="shared memory is Unix-only") +@pytest.mark.asyncio +async def test_async_none_return_value_shared(): + """Verify that async functions returning None are cached with shared backend.""" + call_count = 0 + + @cache(max_size=128, backend="shared") + async def returns_none(x): + nonlocal call_count + call_count += 1 + return None + + returns_none.cache_clear() + + result = await returns_none(1) + assert result is None + assert call_count == 1 + + result = await returns_none(1) + assert result is None + assert call_count == 1 # cached, not recomputed + + # ── Single-flight (dogpile prevention) ─────────────────────────────────── diff --git a/warp_cache/_decorator.py b/warp_cache/_decorator.py index 694c88c..97f7e6e 100644 --- a/warp_cache/_decorator.py +++ b/warp_cache/_decorator.py @@ -48,8 +48,8 @@ def _make_inflight_key( return args async def __call__(self, *args: Any, **kwargs: Any) -> Any: - cached = self._inner.get(*args, **kwargs) - if cached is not None: + hit, cached = self._inner._probe(*args, **kwargs) # type: ignore[unresolved-attribute] + if hit: return cached key = self._make_inflight_key(args, kwargs or None) @@ -58,8 +58,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: event = self._inflight.get(key) if event is not None: await event.wait() - cached = self._inner.get(*args, **kwargs) - if cached is not None: + hit, cached = self._inner._probe(*args, **kwargs) # type: ignore[unresolved-attribute] + if hit: return cached # Leader failed — loop back to check for a new leader continue