diff --git a/clawteam/store/file.py b/clawteam/store/file.py index 4dcb3a9..e706eb9 100644 --- a/clawteam/store/file.py +++ b/clawteam/store/file.py @@ -3,19 +3,11 @@ from __future__ import annotations import json -import os -import sys -import tempfile -from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path from typing import Any -if sys.platform == "win32": - import msvcrt -else: - import fcntl - +from clawteam.fileutil import atomic_write_text, file_locked from clawteam.paths import ensure_within_root, validate_identifier from clawteam.store.base import BaseTaskStore, TaskLockError from clawteam.team.models import TaskItem, TaskPriority, TaskStatus, get_data_dir @@ -34,9 +26,6 @@ def _task_path(team_name: str, task_id: str) -> Path: return _tasks_root(team_name) / f"task-{task_id}.json" -def _tasks_lock_path(team_name: str) -> Path: - return _tasks_root(team_name) / ".tasks.lock" - def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() @@ -51,28 +40,8 @@ class FileTaskStore(BaseTaskStore): Concurrent access is serialised with an OS-specific advisory lock. """ - @contextmanager def _write_lock(self): - lock_path = _tasks_lock_path(self.team_name) - lock_path.parent.mkdir(parents=True, exist_ok=True) - with lock_path.open("a+", encoding="utf-8") as lock_file: - if sys.platform == "win32": - pos = lock_file.tell() - lock_file.seek(0) - msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) - lock_file.seek(pos) - else: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - try: - yield - finally: - if sys.platform == "win32": - pos = lock_file.tell() - lock_file.seek(0) - msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) - lock_file.seek(pos) - else: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + return file_locked(_tasks_root(self.team_name) / ".tasks") def create( self, @@ -320,19 +289,7 @@ def _visit(node: str) -> bool: def _save_unlocked(self, task: TaskItem) -> None: path = _task_path(self.team_name, task.id) - path.parent.mkdir(parents=True, exist_ok=True) - fd, tmp_name = tempfile.mkstemp( - dir=path.parent, - prefix=f"{path.stem}-", - suffix=".tmp", - ) - try: - with os.fdopen(fd, "w", encoding="utf-8") as tmp_file: - tmp_file.write(task.model_dump_json(indent=2, by_alias=True)) - os.replace(tmp_name, str(path)) - except BaseException: - Path(tmp_name).unlink(missing_ok=True) - raise + atomic_write_text(path, task.model_dump_json(indent=2, by_alias=True)) def _resolve_dependents_unlocked(self, completed_task_id: str) -> None: root = _tasks_root(self.team_name) diff --git a/clawteam/team/router.py b/clawteam/team/router.py index f3e566b..083d733 100644 --- a/clawteam/team/router.py +++ b/clawteam/team/router.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json from datetime import datetime from clawteam.spawn.tmux_backend import TmuxBackend @@ -50,7 +49,7 @@ def normalize_message(self, message: TeamMessage) -> RuntimeEnvelope: evidence.append(f"requestId: {message.request_id}") summary = (message.content or "").strip() or f"{message.type.value} from {source}" - payload = json.loads(message.model_dump_json(by_alias=True, exclude_none=True)) + payload = message.model_dump(by_alias=True, exclude_none=True) return RuntimeEnvelope( source=source, diff --git a/clawteam/team/routing_policy.py b/clawteam/team/routing_policy.py index 161f714..b8035c9 100644 --- a/clawteam/team/routing_policy.py +++ b/clawteam/team/routing_policy.py @@ -3,14 +3,13 @@ from __future__ import annotations import json -import os -import tempfile from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any +from clawteam.fileutil import atomic_write_text, file_locked from clawteam.team.models import get_data_dir _RECENT_EVENT_LIMIT = 50 @@ -97,8 +96,15 @@ def __init__(self, team_name: str, throttle_seconds: int = 30): self.team_name = team_name self.throttle_seconds = throttle_seconds + def _state_lock(self): + return file_locked(_runtime_state_path(self.team_name)) + def decide(self, envelope: RuntimeEnvelope, now: datetime | str | None = None) -> RouteDecision: now_dt = _ensure_datetime(now) + with self._state_lock(): + return self._decide_locked(envelope, now_dt) + + def _decide_locked(self, envelope: RuntimeEnvelope, now_dt: datetime) -> RouteDecision: state = self.read_state() route_key = self._route_key(envelope.source, envelope.target) route = state["routes"].setdefault(route_key, self._empty_route(envelope)) @@ -162,6 +168,10 @@ def flush_due( now: datetime | str | None = None, ) -> list[RouteDecision]: now_dt = _ensure_datetime(now) + with self._state_lock(): + return self._flush_due_locked(target_agent, now_dt) + + def _flush_due_locked(self, target_agent: str | None, now_dt: datetime) -> list[RouteDecision]: state = self.read_state() decisions: list[RouteDecision] = [] @@ -211,6 +221,17 @@ def record_dispatch_result( error: str = "", ) -> None: now_dt = _ensure_datetime(now) + with self._state_lock(): + self._record_dispatch_result_locked(decision, success=success, now_dt=now_dt, error=error) + + def _record_dispatch_result_locked( + self, + decision: RouteDecision, + *, + success: bool, + now_dt: datetime, + error: str = "", + ) -> None: state = self.read_state() route = state["routes"].setdefault( decision.route_key, @@ -273,21 +294,10 @@ def read_state(self) -> dict[str, Any]: def _save_state(self, state: dict[str, Any]) -> None: path = _runtime_state_path(self.team_name) - path.parent.mkdir(parents=True, exist_ok=True) state["team"] = self.team_name state["throttleSeconds"] = self.throttle_seconds state["updatedAt"] = _utcnow().isoformat() - - fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp") - try: - with os.fdopen(fd, "w", encoding="utf-8") as handle: - json.dump(state, handle, indent=2, ensure_ascii=False) - Path(tmp_name).replace(path) - finally: - try: - Path(tmp_name).unlink(missing_ok=True) - except OSError: - pass + atomic_write_text(path, json.dumps(state, indent=2, ensure_ascii=False)) @staticmethod def _route_key(source: str, target: str) -> str: