Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/shared_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,25 @@ impl SharedCachedFunction {
}
}

/// Cache lookup returning (hit, value) to distinguish cached None from miss.
#[pyo3(signature = (*args, **kwargs))]
fn _probe<'py>(
&self,
py: Python<'py>,
args: Bound<'py, PyTuple>,
kwargs: Option<Bound<'py, PyDict>>,
) -> PyResult<(bool, Py<PyAny>)> {
let (key_hash, key_bytes) = self.make_key(py, &args, &kwargs)?;

match self.cache.get(key_hash, &key_bytes) {
ShmGetResult::Hit(vb) => {
let value = self.deserialize_value(py, &vb)?;
Ok((true, value))
}
ShmGetResult::Miss => Ok((false, py.None())),
}
}

/// Store a value in the cache for the given arguments.
#[pyo3(signature = (value, *args, **kwargs))]
fn set<'py>(
Expand Down
33 changes: 33 additions & 0 deletions src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,39 @@ impl CachedFunction {
Ok(None)
}

/// Cache lookup returning (hit, value) to distinguish cached None from miss.
#[pyo3(signature = (*args, **kwargs))]
fn _probe<'py>(
&self,
py: Python<'py>,
args: Bound<'py, PyTuple>,
kwargs: Option<Bound<'py, PyDict>>,
) -> PyResult<(bool, Py<PyAny>)> {
let (hash, key_ptr, _key_owner) = Self::hash_args(py, &args, &kwargs)?;
let borrowed = BorrowedArgs { hash, ptr: key_ptr };
let shard_idx = hash as usize & self.shard_mask;

let shard = self.shards[shard_idx].read();
if let Some(entry) = shard.map.get(&borrowed) {
if let Some(ttl) = self.ttl {
if entry.created_at.elapsed() > ttl {
drop(shard);
self.misses.fetch_add(1, Ordering::Relaxed);
return Ok((false, py.None()));
}
}
entry.visited.store(true, Ordering::Relaxed);
let val = entry.value.clone_ref(py);
drop(shard);
self.hits.fetch_add(1, Ordering::Relaxed);
return Ok((true, val));
}

drop(shard);
self.misses.fetch_add(1, Ordering::Relaxed);
Ok((false, py.None()))
}

/// Store a value in the cache for the given arguments.
#[pyo3(signature = (value, *args, **kwargs))]
fn set<'py>(
Expand Down
50 changes: 50 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,56 @@ def add(a, b):
assert info.hits == 1


# ── None return value ────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_async_none_return_value():
"""Verify that async functions returning None are cached correctly."""
call_count = 0

@cache(max_size=128)
async def returns_none(x):
nonlocal call_count
call_count += 1
return None

result = await returns_none(1)
assert result is None
assert call_count == 1

result = await returns_none(1)
assert result is None
assert call_count == 1 # cached, not recomputed

info = returns_none.cache_info()
assert info.hits == 1
assert info.misses == 1


@pytest.mark.skipif(sys.platform == "win32", reason="shared memory is Unix-only")
@pytest.mark.asyncio
async def test_async_none_return_value_shared():
"""Verify that async functions returning None are cached with shared backend."""
call_count = 0

@cache(max_size=128, backend="shared")
async def returns_none(x):
nonlocal call_count
call_count += 1
return None

returns_none.cache_clear()

result = await returns_none(1)
assert result is None
assert call_count == 1

result = await returns_none(1)
assert result is None
assert call_count == 1 # cached, not recomputed


# ── Single-flight (dogpile prevention) ───────────────────────────────────


Expand Down
8 changes: 4 additions & 4 deletions warp_cache/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def _make_inflight_key(
return args

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
cached = self._inner.get(*args, **kwargs)
if cached is not None:
hit, cached = self._inner._probe(*args, **kwargs) # type: ignore[unresolved-attribute]
if hit:
return cached

key = self._make_inflight_key(args, kwargs or None)
Expand All @@ -58,8 +58,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
event = self._inflight.get(key)
if event is not None:
await event.wait()
cached = self._inner.get(*args, **kwargs)
if cached is not None:
hit, cached = self._inner._probe(*args, **kwargs) # type: ignore[unresolved-attribute]
if hit:
return cached
# Leader failed — loop back to check for a new leader
continue
Expand Down
Loading