Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ async def slow_fn(x):
slow_fn(3),
)
assert results == [2, 4, 6, 2, 4, 6]
# At least 3 unique calls, possibly more due to concurrent misses
assert call_count >= 3
# Exactly 3 unique keys — single-flight coalescing prevents redundant calls
assert call_count == 3


# ── Eviction ──────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -193,3 +193,112 @@ def add(a, b):
assert add(1, 2) == 3
info = add.cache_info()
assert info.hits == 1


# ── Single-flight (dogpile prevention) ───────────────────────────────────


@pytest.mark.asyncio
async def test_async_single_flight():
"""Multiple concurrent coroutines for the same key: only one computes."""
call_count = 0

@cache(max_size=128)
async def slow_fn(x):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05)
return x * 10

results = await asyncio.gather(*(slow_fn(1) for _ in range(10)))
assert results == [10] * 10
assert call_count == 1


@pytest.mark.asyncio
async def test_async_single_flight_different_keys():
"""Concurrent calls with different keys compute independently."""
call_count = 0

@cache(max_size=128)
async def slow_fn(x):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05)
return x * 10

results = await asyncio.gather(
slow_fn(1), slow_fn(2), slow_fn(3),
slow_fn(1), slow_fn(2), slow_fn(3),
)
assert results == [10, 20, 30, 10, 20, 30]
assert call_count == 3


@pytest.mark.asyncio
async def test_async_single_flight_error_recovery():
"""If the leader fails, a waiter becomes the new leader and retries."""
call_count = 0

@cache(max_size=128)
async def flaky_fn(x):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.02)
if call_count == 1:
raise ValueError("transient error")
return x * 10

# First batch: leader fails, waiters should retry
tasks = [asyncio.create_task(flaky_fn(1)) for _ in range(5)]
results = await asyncio.gather(*tasks, return_exceptions=True)

# The leader raised; all waiters recovered via a new leader
successes = [r for r in results if r == 10]
errors = [r for r in results if isinstance(r, ValueError)]
assert len(errors) == 1 # only the original leader failed
assert len(successes) == 4 # all waiters got the result
assert call_count == 2


@pytest.mark.asyncio
async def test_async_single_flight_cancellation():
"""If the leader is cancelled, waiters recover and compute."""
call_count = 0

@cache(max_size=128)
async def slow_fn(x):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1)
return x * 10

leader = asyncio.create_task(slow_fn(1))
await asyncio.sleep(0.01) # let the leader start
waiters = [asyncio.create_task(slow_fn(1)) for _ in range(3)]
await asyncio.sleep(0.01) # let waiters register

leader.cancel()
results = await asyncio.gather(*waiters)
assert results == [10, 10, 10]
# Leader was cancelled (count 1), then one waiter recomputed (count 2)
assert call_count == 2


@pytest.mark.skipif(sys.platform == "win32", reason="shared memory is Unix-only")
@pytest.mark.asyncio
async def test_async_single_flight_shared():
"""Single-flight works with the shared backend too."""
call_count = 0

@cache(max_size=128, backend="shared")
async def slow_fn(x):
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05)
return x * 10

slow_fn.cache_clear()
results = await asyncio.gather(*(slow_fn(1) for _ in range(10)))
assert results == [10] * 10
assert call_count == 1
39 changes: 36 additions & 3 deletions warp_cache/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,59 @@ class AsyncCachedFunction:

Uses the Rust get/set methods for cache lookup/store so that the
async function is only awaited on cache miss.

Implements single-flight coalescing: when multiple coroutines miss the
cache for the same key concurrently, only one computes the result and
the rest wait for it.
"""

def __init__(
self, fn: Callable[..., Any], inner: CachedFunction | SharedCachedFunction
) -> None:
self._fn = fn
self._inner = inner
self._inflight: dict[Any, asyncio.Event] = {}
self.__wrapped__ = fn
self.__name__ = getattr(fn, "__name__", repr(fn))
self.__qualname__ = getattr(fn, "__qualname__", self.__name__)
self.__module__ = getattr(fn, "__module__", __name__)
self.__doc__ = getattr(fn, "__doc__", None)

@staticmethod
def _make_inflight_key(
args: tuple[Any, ...], kwargs: dict[str, Any] | None
) -> Any:
if kwargs:
return (args, tuple(sorted(kwargs.items())))
return args

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
cached = self._inner.get(*args, **kwargs)
if cached is not None:
return cached
result = await self._fn(*args, **kwargs)
self._inner.set(result, *args, **kwargs)
return result

key = self._make_inflight_key(args, kwargs or None)

while True:
event = self._inflight.get(key)
if event is not None:
await event.wait()
cached = self._inner.get(*args, **kwargs)
if cached is not None:
return cached
# Leader failed — loop back to check for a new leader
continue

# We're the first: register our intent
event = asyncio.Event()
self._inflight[key] = event
try:
result = await self._fn(*args, **kwargs)
self._inner.set(result, *args, **kwargs)
return result
finally:
event.set()
self._inflight.pop(key, None)

def cache_info(self) -> CacheInfo | SharedCacheInfo:
return self._inner.cache_info()
Expand Down
Loading