diff --git a/scripts/lib/rewind/store.py b/scripts/lib/rewind/store.py
index 1222763..afc8ba2 100644
--- a/scripts/lib/rewind/store.py
+++ b/scripts/lib/rewind/store.py
@@ -16,9 +16,11 @@ """
 from __future__ import annotations
 
 import hashlib
+import json
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional
 
 
@@ -87,3 +89,77 @@ def size(self) -> int:
 
     def clear(self) -> None:
         self._cache.clear()
+
+    # ------------------------------------------------------------------
+    # Persistence
+    # ------------------------------------------------------------------
+
+    def save(self, path: str | Path) -> int:
+        """Persist the current cache to a JSON file.
+
+        Entries that have exceeded their TTL are not saved. Returns the
+        number of entries written.
+
+        Args:
+            path: Destination file path (will be created or overwritten).
+
+        Returns:
+            Number of entries persisted.
+        """
+        now = time.monotonic()
+        entries: dict[str, dict] = {}
+        for hash_id, entry in self._cache.items():
+            age = now - entry.stored_at
+            if age <= self.ttl_seconds:
+                entries[hash_id] = {
+                    "original": entry.original,
+                    "compressed": entry.compressed,
+                    "remaining_ttl": round(self.ttl_seconds - age, 1),
+                    "original_tokens": entry.original_tokens,
+                    "compressed_tokens": entry.compressed_tokens,
+                }
+        Path(path).write_text(
+            json.dumps(entries, ensure_ascii=False, indent=2),
+            encoding="utf-8",
+        )
+        return len(entries)
+
+    def load(self, path: str | Path) -> int:
+        """Restore cache entries from a previously saved JSON file.
+
+        Loaded entries are merged into the current cache (existing entries
+        are preserved). Each loaded entry's TTL is set to its
+        ``remaining_ttl`` from the save file — entries that would already
+        be expired are skipped.
+
+        Args:
+            path: Source file path.
+
+        Returns:
+            Number of entries loaded.
+        """
+        data = json.loads(Path(path).read_text(encoding="utf-8"))
+        loaded = 0
+        now = time.monotonic()
+        for hash_id, blob in data.items():
+            # Clamp to this store's TTL: a save file written under a larger
+            # ttl_seconds would otherwise produce a stored_at in the future
+            # and let the entry outlive the configured TTL.
+            remaining = min(blob.get("remaining_ttl", 0), self.ttl_seconds)
+            if remaining <= 0:
+                continue
+            if hash_id in self._cache:
+                continue  # don't overwrite live entries
+            entry = CacheEntry(
+                original=blob["original"],
+                compressed=blob.get("compressed", ""),
+                stored_at=now - (self.ttl_seconds - remaining),
+                original_tokens=blob.get("original_tokens", 0),
+                compressed_tokens=blob.get("compressed_tokens", 0),
+            )
+            self._cache[hash_id] = entry
+            loaded += 1
+        # Enforce max_entries after load
+        while len(self._cache) > self.max_entries:
+            self._cache.popitem(last=False)
+        return loaded