Add single-flight coalescing to AsyncCachedFunction.

Concurrent cache misses for the same key now share one computation: the first
caller registers an asyncio.Event in ``_inflight`` and awaits the wrapped
coroutine, while later callers wait on that event and re-check the cache.
If the leader raises or is cancelled, the event is still set (``finally``),
the cache is still empty, and one of the waiters becomes the new leader.

Review note: a ``None`` returned by ``_inner.get`` is treated as a miss, so
for a wrapped function whose result is ``None`` the waiters will each take a
turn as leader instead of sharing the result.

diff --git a/tests/test_async.py b/tests/test_async.py
index 2c2ddfd..552b872 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -119,8 +119,8 @@ async def slow_fn(x):
         slow_fn(3),
     )
     assert results == [2, 4, 6, 2, 4, 6]
-    # At least 3 unique calls, possibly more due to concurrent misses
-    assert call_count >= 3
+    # Exactly 3 unique keys — single-flight coalescing prevents redundant calls
+    assert call_count == 3
 
 
 # ── Eviction ──────────────────────────────────────────────────────────────
@@ -193,3 +193,112 @@ def add(a, b):
     assert add(1, 2) == 3
     info = add.cache_info()
     assert info.hits == 1
+
+
+# ── Single-flight (dogpile prevention) ───────────────────────────────────
+
+
+@pytest.mark.asyncio
+async def test_async_single_flight():
+    """Multiple concurrent coroutines for the same key: only one computes."""
+    call_count = 0
+
+    @cache(max_size=128)
+    async def slow_fn(x):
+        nonlocal call_count
+        call_count += 1
+        await asyncio.sleep(0.05)
+        return x * 10
+
+    results = await asyncio.gather(*(slow_fn(1) for _ in range(10)))
+    assert results == [10] * 10
+    assert call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_single_flight_different_keys():
+    """Concurrent calls with different keys compute independently."""
+    call_count = 0
+
+    @cache(max_size=128)
+    async def slow_fn(x):
+        nonlocal call_count
+        call_count += 1
+        await asyncio.sleep(0.05)
+        return x * 10
+
+    results = await asyncio.gather(
+        slow_fn(1), slow_fn(2), slow_fn(3),
+        slow_fn(1), slow_fn(2), slow_fn(3),
+    )
+    assert results == [10, 20, 30, 10, 20, 30]
+    assert call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_async_single_flight_error_recovery():
+    """If the leader fails, a waiter becomes the new leader and retries."""
+    call_count = 0
+
+    @cache(max_size=128)
+    async def flaky_fn(x):
+        nonlocal call_count
+        call_count += 1
+        await asyncio.sleep(0.02)
+        if call_count == 1:
+            raise ValueError("transient error")
+        return x * 10
+
+    # First batch: leader fails, waiters should retry
+    tasks = [asyncio.create_task(flaky_fn(1)) for _ in range(5)]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # The leader raised; all waiters recovered via a new leader
+    successes = [r for r in results if r == 10]
+    errors = [r for r in results if isinstance(r, ValueError)]
+    assert len(errors) == 1  # only the original leader failed
+    assert len(successes) == 4  # all waiters got the result
+    assert call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_async_single_flight_cancellation():
+    """If the leader is cancelled, waiters recover and compute."""
+    call_count = 0
+
+    @cache(max_size=128)
+    async def slow_fn(x):
+        nonlocal call_count
+        call_count += 1
+        await asyncio.sleep(0.1)
+        return x * 10
+
+    leader = asyncio.create_task(slow_fn(1))
+    await asyncio.sleep(0.01)  # let the leader start
+    waiters = [asyncio.create_task(slow_fn(1)) for _ in range(3)]
+    await asyncio.sleep(0.01)  # let waiters register
+
+    leader.cancel()
+    results = await asyncio.gather(*waiters)
+    assert results == [10, 10, 10]
+    # Leader was cancelled (count 1), then one waiter recomputed (count 2)
+    assert call_count == 2
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="shared memory is Unix-only")
+@pytest.mark.asyncio
+async def test_async_single_flight_shared():
+    """Single-flight works with the shared backend too."""
+    call_count = 0
+
+    @cache(max_size=128, backend="shared")
+    async def slow_fn(x):
+        nonlocal call_count
+        call_count += 1
+        await asyncio.sleep(0.05)
+        return x * 10
+
+    slow_fn.cache_clear()
+    results = await asyncio.gather(*(slow_fn(1) for _ in range(10)))
+    assert results == [10] * 10
+    assert call_count == 1
diff --git a/warp_cache/_decorator.py b/warp_cache/_decorator.py
index 9702a55..694c88c 100644
--- a/warp_cache/_decorator.py
+++ b/warp_cache/_decorator.py
@@ -21,6 +21,10 @@ class AsyncCachedFunction:
 
     Uses the Rust get/set methods for cache lookup/store so that the async
     function is only awaited on cache miss.
+
+    Implements single-flight coalescing: when multiple coroutines miss the
+    cache for the same key concurrently, only one computes the result and
+    the rest wait for it.
     """
 
     def __init__(
@@ -28,19 +32,48 @@ def __init__(
     ) -> None:
         self._fn = fn
         self._inner = inner
+        self._inflight: dict[Any, asyncio.Event] = {}
         self.__wrapped__ = fn
         self.__name__ = getattr(fn, "__name__", repr(fn))
         self.__qualname__ = getattr(fn, "__qualname__", self.__name__)
         self.__module__ = getattr(fn, "__module__", __name__)
         self.__doc__ = getattr(fn, "__doc__", None)
 
+    @staticmethod
+    def _make_inflight_key(
+        args: tuple[Any, ...], kwargs: dict[str, Any] | None
+    ) -> Any:
+        if kwargs:
+            return (args, tuple(sorted(kwargs.items())))
+        return args
+
     async def __call__(self, *args: Any, **kwargs: Any) -> Any:
         cached = self._inner.get(*args, **kwargs)
         if cached is not None:
             return cached
-        result = await self._fn(*args, **kwargs)
-        self._inner.set(result, *args, **kwargs)
-        return result
+
+        key = self._make_inflight_key(args, kwargs or None)
+
+        while True:
+            event = self._inflight.get(key)
+            if event is not None:
+                await event.wait()
+                cached = self._inner.get(*args, **kwargs)
+                if cached is not None:
+                    return cached
+                # Leader failed — loop back to check for a new leader
+                continue
+
+            # We're the first: register our intent
+            event = asyncio.Event()
+            self._inflight[key] = event
+            try:
+                result = await self._fn(*args, **kwargs)
+                self._inner.set(result, *args, **kwargs)
+                return result
+            finally:
+                event.set()
+                self._inflight.pop(key, None)
 
     def cache_info(self) -> CacheInfo | SharedCacheInfo:
         return self._inner.cache_info()