diff --git a/docs/git_go_agent.md b/docs/git_go_agent.md new file mode 100644 index 00000000..6ef1da94 --- /dev/null +++ b/docs/git_go_agent.md @@ -0,0 +1,153 @@ +# GitGoAgent Documentation + +`GitGoAgent` is a specialized execution agent for git-managed Go repositories. It combines git control, Go development tools, code linting, and file operations to enable autonomous development workflows on Go projects. + +## Basic Usage + +```python +from pathlib import Path + +from langchain.chat_models import init_chat_model +from ursa.agents import GitGoAgent + +# Initialize the agent +agent = GitGoAgent( + llm=init_chat_model("openai:gpt-5-mini"), + workspace=Path("/path/to/your/repo"), +) + +# Run a request +result = agent.invoke( + "Run tests, lint the code, and commit any fixes." +) +print(result["messages"][-1].text) +``` + +## Tools Available + +### Git Operations +- `git_status`: Show repository status. +- `git_diff`: Show diffs (staged or unstaged). +- `git_log`: Show recent commits. +- `git_ls_files`: List tracked files. +- `git_add`: Stage files for commit. +- `git_commit`: Create a commit. +- `git_switch`: Switch branches (optionally creating new ones). +- `git_create_branch`: Create a branch without switching. + +### Go Build and Test Tools +- `go_build`: Build the module using `go build ./...` +- `go_test`: Run tests with `go test ./...` (supports verbose mode) +- `go_vet`: Run Go vet for code pattern analysis +- `go_mod_tidy`: Validate and clean module dependencies + +### Code Quality Tools +- `golangci_lint`: Run golangci-lint on the repository + - Automatically detects and uses `.golangci.yml` if present + - Falls back to default linter configuration if config file missing + - Provides helpful error messages if golangci-lint is not installed +- `gofmt_files`: Format .go files in-place using gofmt + +### File Operations +- `read_file`: Read file contents +- `write_code`: Write new files (with optional path validation) +- `write_code_with_repo`: Write new files constrained to a repository path +- `edit_code`: Edit existing files (with optional path validation) + +## Configuration and Behavior + +### Timeouts +Operations use differentiated, operation-specific timeouts (not a unified timeout): + +| Operation | Timeout | Rationale | +|-----------|---------|-----------| +| Git commands | 30 seconds | Should be near-instant; timeout indicates hanging (waiting for input or wrong directory) | +| Code formatting (`gofmt`) | 30 seconds | Usually fast operation | +| Code analysis (`go vet`, `go mod tidy`) | 60 seconds | Analysis is typically quick | +| Go build | 5 minutes (300s) | Builds on large codebases can be slow | +| Go test | 10 minutes (600s) | Test suites can legitimately take time | +| Linting (`golangci-lint`) | 3 minutes (180s) | Comprehensive linting takes moderate time | + +**Design note:** If git commands timeout, it typically indicates: +- The agent is running commands in the wrong directory (should use `repo_path` parameter) +- Git is waiting for interactive input (e.g., passphrase, editor) +- Network issues when accessing remote repos + +If other operations timeout, try running them in smaller chunks or profiling the specific operation. + +### Path Safety +`write_code`, `write_code_with_repo`, and `edit_code` validate file paths to prevent: +- **Path traversal attacks** (e.g., `../../../etc/passwd` attempts are rejected) +- **Writes outside the workspace** (all files must be within the workspace directory) +- **Writes outside the repository** (`write_code_with_repo`, or `edit_code` when `repo_path` is used) + +Path validation is enabled by default. For trusted sandbox/container usage, you can opt in to unsafe writes by setting: + +```bash +export URSA_ALLOW_UNSAFE_WRITES=1 +``` + +When enabled, workspace and repository boundary checks are bypassed for `write_code`, `write_code_with_repo`, and `edit_code`. + +Example: Specifying a repo boundary ensures all file modifications stay within that repository. + +### Golangci-lint Integration + +The agent automatically integrates with `golangci-lint` for code quality checks: + +```python +# Agent detects .golangci.yml and uses it automatically +agent.invoke("Run linting and report all issues") + +# Linter configuration is respected while agent iterates on fixes +``` + +**Install golangci-lint** if not already present: +```bash +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +The linter supports: +- Custom linter configurations via `.golangci.yml` +- Extensibility to additional linters in future versions +- Clear error reporting when linter is misconfigured + +## Common Workflows + +### 1. Build and Test Validation +```python +agent.invoke( + "Build the module, run all tests, and report any failures." +) +``` + +### 2. Code Quality Check and Fix +```python +agent.invoke( + "Run golangci-lint, identify issues, and attempt to fix them automatically." +) +``` + +### 3. Feature Implementation with Git Integration +```python +agent.invoke( + "Create a new feature branch, implement the requested functionality, " + "run tests and linting, and commit the changes." +) +``` + +### 4. Dependency Management +```python +agent.invoke( + "Run go mod tidy to clean up dependencies, then commit the changes." +) +``` + +## Notes + +- Operates only inside the configured workspace +- All file writes are validated against workspace and optionally repository boundaries +- Avoids destructive git operations by design (no force pushes, rebases, etc.) +- Supports subdirectory repositories via `repo_path` parameter on tools +- Explicit timeout handling prevents the agent from hanging on slow operations +- All tool output (stdout/stderr) is captured and returned to the agent for analysis diff --git a/examples/single_agent_examples/git_go_agent/git_go_agent_example.py b/examples/single_agent_examples/git_go_agent/git_go_agent_example.py new file mode 100644 index 00000000..54198d1d --- /dev/null +++ b/examples/single_agent_examples/git_go_agent/git_go_agent_example.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from langchain.chat_models import init_chat_model + +from ursa.agents import GitGoAgent + + +def main(): + repo_root = Path.cwd() + agent = GitGoAgent( + llm=init_chat_model("openai:gpt-5-mini"), + workspace=repo_root, + ) + + prompt = "Show git status, list tracked .go files, and summarize any main.go you find." + result = agent.invoke(prompt) + print(result["messages"][-1].text) + + +if __name__ == "__main__": + main() diff --git a/examples/two_agent_examples/plan_execute/city_10_vowels.py b/examples/two_agent_examples/plan_execute/city_10_vowels.py index cdf949f0..8aea15ad 100644 --- a/examples/two_agent_examples/plan_execute/city_10_vowels.py +++ b/examples/two_agent_examples/plan_execute/city_10_vowels.py @@ -58,7 +58,7 @@ def main(): return final_results except Exception as e: - print(f"Error in example: {str(e)}") + print(f"Error in example: {e!s}") import traceback traceback.print_exc() diff --git a/examples/two_agent_examples/plan_execute/example_from_yaml.yaml b/examples/two_agent_examples/plan_execute/example_from_yaml.yaml new file mode 100644 index 00000000..9e938f69 --- /dev/null +++ b/examples/two_agent_examples/plan_execute/example_from_yaml.yaml @@ -0,0 +1,22 @@ +# Simple configuration for plan_execute_from_yaml.py +# Run with: python plan_execute_from_yaml.py --config example_from_yaml.yaml + +problem: | + Write a simple Python script that reads a CSV file, filters rows where a "score" + column is above 90, and writes the results to a new CSV file. + +project: csv_filter_project + +planning_mode: single + +models: + default: openai:gpt-4o-mini + choices: + - openai:gpt-4o-mini + - openai:gpt-3.5-turbo + - openai:gpt-5.2 + +logo: + enabled: false + +symlink: {} diff --git a/examples/two_agent_examples/plan_execute/example_multi_repo.yaml b/examples/two_agent_examples/plan_execute/example_multi_repo.yaml new file mode 100644 index 00000000..9369001d --- /dev/null +++ b/examples/two_agent_examples/plan_execute/example_multi_repo.yaml @@ -0,0 +1,58 @@ +# Simple configuration for plan_execute_multi_repo.py +# Run with: python plan_execute_multi_repo.py --config example_multi_repo.yaml + +problem: | + Create a simple three-part system: + 1. A shared library (utils_lib) with utility functions for CSV processing + 2. A CLI tool (csv_cli) that uses the shared library + 3. Documentation (docs) describing how to use the CLI + +repos: + - name: utils_lib + path: /tmp/ursa_test_repos/utils_lib + url: https://github.com/example/utils_lib.git + branch: main + checkout: false + description: Shared utility functions for CSV processing + language: python + + - name: csv_cli + path: /tmp/ursa_test_repos/csv_cli + url: https://github.com/example/csv_cli.git + branch: main + checkout: false + description: Command-line tool for CSV filtering + language: python + + - name: docs + path: /tmp/ursa_test_repos/docs + url: https://github.com/example/docs.git + branch: main + checkout: false + description: Documentation for the CSV processing system + language: markdown + +project: multi_repo_csv_system + +models: + default: openai:gpt-4o-mini + choices: + - openai:gpt-4o-mini + - openai:gpt-3.5-turbo + +planner: + reflection_steps: 1 + research: + github: + enabled: false + queries: [] + +execution: + max_parallel: 2 + recursion_limit: 2000 + resume: false + status_interval_sec: 5 + max_check_retries: 2 + step_timeout_sec: 0 + timeout_mode: pause + skip_failed_repos: false diff --git a/examples/two_agent_examples/plan_execute/openchami_boot_docs_example.yaml b/examples/two_agent_examples/plan_execute/openchami_boot_docs_example.yaml new file mode 100644 index 00000000..bc997244 --- /dev/null +++ b/examples/two_agent_examples/plan_execute/openchami_boot_docs_example.yaml @@ -0,0 +1,120 @@ +# Multi-repo example: Generate boot-service documentation for openchami.org website +# +# This configuration orchestrates changes across two repos: +# 1. boot-service: Extract usage examples, API endpoints, and configuration details +# 2. openchami.org: Generate and publish documentation pages based on boot-service + +project: openchami_boot_docs + +# Problem statement: What are we trying to accomplish? +problem: | + Generate comprehensive documentation for the OpenCHAMI boot-service and publish it + on the openchami.org website. The documentation should include: + + - Overview of what the boot-service does and its role in the OpenCHAMI ecosystem + - API endpoint reference with request/response examples + - Configuration guide (environment variables, config files, deployment options) + - Usage examples for common workflows (iPXE boot, cloud-init, etc.) + - Integration guide for administrators deploying the service + - One or more blog posts to describe the service to the OpenCHAMI community + + The boot-service repo should be analyzed to extract accurate, up-to-date information. + The openchami.org repo should have new documentation pages created in the appropriate + location with proper navigation links. + +# Repository definitions +repos: + - name: boot-service + url: https://github.com/openchami/boot-service + branch: main + checkout: true + language: go + description: OpenCHAMI boot orchestration service + checks: + # Verify the repo is in good state (don't break existing functionality) + - go mod verify + - go vet ./... + + - name: openchami.org + url: https://github.com/openchami/openchami.org + branch: main + checkout: true + language: markdown # Documentation site (likely Hugo, Jekyll, or similar) + description: OpenCHAMI project website and documentation + checks: + # Verify documentation builds successfully + - npm run build || echo "No build command found, skipping" + +# Model configuration +models: + # Available model providers + providers: + openai: + model_provider: openai + api_key_env: OPENAI_API_KEY + + # Default model choice + choices: + - openai:gpt-5.2 + - openai:gpt-5 + - openai:o3 + - openai:o3-mini + - my_endpoint:openai/gpt-oss-120b # <— example external endpoint + model + default: openai:gpt-5 + + # Override models for specific roles + planner: openai:gpt-5.2 + executor: openai:gpt-5 + + # Named parameter profiles for different capabilities + profiles: + standard: + temperature: 0.2 + max_completion_tokens: 8000 + + reasoning: + temperature: 0.1 + max_completion_tokens: 16000 + reasoning: + effort: medium + + # Global defaults + defaults: + profile: standard + params: + temperature: 0.2 + max_retries: 2 + + # Per-agent overrides + agents: + planner: + profile: reasoning + params: + max_completion_tokens: 16000 + + executor: + profile: standard + params: + max_completion_tokens: 10000 + +# Planning configuration +planner: + reflection_steps: 2 # Number of plan review iterations + + research: + # GitHub research: automatically fetch issues/PRs from repo URLs + github: + enabled: true + max_issues: 20 + max_prs: 10 + +# Execution configuration +execution: + max_parallel: 2 # Process both repos in parallel when possible + recursion_limit: 2000 # LangGraph recursion limit + resume: true # Support resuming from checkpoints + status_interval_sec: 10 # Progress table update frequency + max_check_retries: 2 # Retry failed verification checks + step_timeout_sec: 0 # timeout per step (0 = no limit) + timeout_mode: pause # On timeout: pause, skip, or fail + skip_failed_repos: false # Stop all repos if one fails diff --git a/examples/two_agent_examples/plan_execute/plan_execute_from_yaml.py b/examples/two_agent_examples/plan_execute/plan_execute_from_yaml.py index feee792c..3948a8e6 100644 --- a/examples/two_agent_examples/plan_execute/plan_execute_from_yaml.py +++ b/examples/two_agent_examples/plan_execute/plan_execute_from_yaml.py @@ -1,19 +1,12 @@ import argparse # needed for checkpoint / restart -import hashlib -import importlib import json -import os import sqlite3 import sys from pathlib import Path -from types import SimpleNamespace as NS from typing import Any -import randomname -import yaml -from langchain.chat_models import init_chat_model from langchain_core.messages import HumanMessage from langgraph.checkpoint.sqlite import SqliteSaver @@ -26,6 +19,17 @@ from ursa.agents import ExecutionAgent, PlanningAgent from ursa.observability.timing import render_session_summary from ursa.util.logo_generator import kickoff_logo +from ursa.util.plan_execute_utils import ( + generate_workspace_name, + hash_plan, + load_json_file, + load_yaml_config, + save_json_file, + setup_llm, + setup_workspace, + snapshot_sqlite_db, + timed_input_with_countdown, +) from ursa.util.plan_renderer import render_plan_steps_rich console = get_console() # always returns the same instance @@ -142,21 +146,9 @@ def _progress_file(workspace: str) -> Path: return Path(workspace) / "executor_progress.json" -def _hash_plan(plan_steps) -> str: - # hash the structure so we can detect if the plan changed between runs - return hashlib.sha256( - json.dumps(plan_steps, sort_keys=True, default=str).encode("utf-8") - ).hexdigest() - - def load_exec_progress(workspace: str) -> dict: p = _progress_file(workspace) - if p.exists(): - try: - return json.loads(p.read_text()) - except Exception: - return {} - return {} + return load_json_file(p, {}) # we have to save the last step in here too @@ -170,36 +162,7 @@ def save_exec_progress( payload = {"next_index": int(next_index), "plan_hash": plan_hash} if last_summary is not None: payload["last_summary"] = last_summary - p.write_text(json.dumps(payload, indent=2)) - - -# --- snapshot a consistent copy of a SQLite db (works even in WAL mode) --- -def snapshot_sqlite_db(src_path: Path, dst_path: Path) -> None: - """ - Make a consistent copy of the SQLite database at src_path into dst_path, - using the sqlite3 backup API. Safe with WAL; no need to copy -wal/-shm. - """ - import sqlite3 - - dst_path.parent.mkdir(parents=True, exist_ok=True) - src_uri = f"file:{Path(src_path).resolve().as_posix()}?mode=ro" - src = dst = None - try: - src = sqlite3.connect(src_uri, uri=True) - dst = sqlite3.connect(str(dst_path)) - with dst: - src.backup(dst) - finally: - try: - if dst: - dst.close() - except Exception: - pass - try: - if src: - src.close() - except Exception: - pass + save_json_file(p, payload) def step_to_text(step) -> str: @@ -278,64 +241,6 @@ def _ckpt_sort_key(p: Path): return (2, float("inf"), float("inf"), name) -# --- timed input with countdown (POSIX-friendly; auto-fallback if non-interactive) --- -def timed_input_with_countdown(prompt: str, timeout: int) -> str | None: - """ - Read a line with a per-second countdown. Returns: - - the user's input (str) if provided, - - None if timeout expires, - - None if non-interactive or timeout<=0. - No bracketed prefixes are printed (clean output for all prompts). - """ - import sys - import time - - # Non-interactive or disabled timeout → default immediately (no noisy prefix) - try: - is_tty = sys.stdin.isatty() - except Exception: - is_tty = False - - if not is_tty: - print("(non-interactive) selecting default . . .") - return None - if timeout <= 0: - print("(timeout disabled) selecting default . . .") - return None - - # Show prompt and run a 1s polling loop - deadline = time.time() + timeout - print(prompt, end="", flush=True) - - try: - import select - - while True: - remaining = int(max(0, deadline - time.time())) - if remaining in {30, 10, 5, 4, 3, 2, 1}: - # print a short tick line, then reprint the prompt - print( - f"\n{remaining} seconds left . . . (Ctrl-C to abort)", - flush=True, - ) - print(prompt, end="", flush=True) - if remaining <= 0: - print() # newline after prompt - return None - - rlist, _, _ = select.select([sys.stdin], [], [], 1.0) - if rlist: - line = sys.stdin.readline() - return None if line is None else line.strip() - - except Exception: - # Fallback if select is unavailable - try: - return input() - except KeyboardInterrupt: - raise - - def list_executor_checkpoints(workspace: str) -> list[Path]: ws = Path(workspace) ckdir = _ckpt_dir(workspace) @@ -460,25 +365,16 @@ def _hier_progress_file(workspace: str) -> Path: return Path(workspace) / "hier_progress.json" -def _read_json(path: Path, default): - if path.exists(): - try: - return json.loads(path.read_text()) - except Exception: - return default - return default - - def load_hier_progress(workspace: str) -> dict: # shape: {"main": {"next_index": int, "plan_hash": str}, "subs": {"": {"next_index": int, "plan_hash": str, "last_summary": str}}} - return _read_json( + return load_json_file( _hier_progress_file(workspace), {"main": {"next_index": 0, "plan_hash": None}, "subs": {}}, ) def save_hier_progress(workspace: str, data: dict) -> None: - _hier_progress_file(workspace).write_text(json.dumps(data, indent=2)) + save_json_file(_hier_progress_file(workspace), data) def save_hier_main_progress( @@ -546,10 +442,9 @@ def load_run_meta(workspace: str) -> dict: def save_run_meta(workspace: str, **fields) -> dict: p = _run_meta_file(workspace) - p.parent.mkdir(parents=True, exist_ok=True) # <-- ensure dir exists meta = load_run_meta(workspace) meta.update({k: v for k, v in fields.items() if v is not None}) - p.write_text(json.dumps(meta, indent=2)) + save_json_file(p, meta) return meta @@ -607,46 +502,6 @@ def _print_next_step(prefix: str, next_zero: int, total: int, workspace: str): ######################################################################### # END: Assorted other helpers ######################################################################### -_SECRET_KEY_SUBSTRS = ( - "api_key", - "apikey", - "access_token", - "refresh_token", - "secret", - "password", - "bearer", -) - - -def _looks_like_secret_key(name: str) -> bool: - n = name.lower() - return any(s in n for s in _SECRET_KEY_SUBSTRS) - - -def _mask_secret(value: str, keep_start: int = 6, keep_end: int = 4) -> str: - """ - Mask a secret-like string, keeping only the beginning and end. - Example: sk-proj-abc123456789xyz -> sk-proj-…9xyz - """ - if not isinstance(value, str): - return value - if len(value) <= keep_start + keep_end + 3: - return "…" # too short to safely show anything - return f"{value[:keep_start]}...{value[-keep_end:]}" - - -def _sanitize_for_logging(obj): - if isinstance(obj, dict): - out = {} - for k, v in obj.items(): - if _looks_like_secret_key(str(k)): - out[k] = _mask_secret(v) if isinstance(v, str) else "..." - else: - out[k] = _sanitize_for_logging(v) - return out - if isinstance(obj, list): - return [_sanitize_for_logging(v) for v in obj] - return obj def setup_agents( @@ -706,273 +561,6 @@ def setup_agents( ) -def _deep_merge_dicts(base: dict, override: dict) -> dict: - """ - Recursively merge override into base and return a new dict. - - dict + dict => deep merge - - otherwise => override wins - """ - base = dict(base or {}) - override = dict(override or {}) - out = dict(base) - for k, v in override.items(): - if k in out and isinstance(out[k], dict) and isinstance(v, dict): - out[k] = _deep_merge_dicts(out[k], v) - else: - out[k] = v - return out - - -def _resolve_llm_kwargs_for_agent( - models_cfg: dict | None, agent_name: str | None -) -> dict: - """ - Given the YAML `models:` dict, compute merged kwargs for init_chat_model(...) - for a specific agent ('planner' or 'executor'). - - Merge order (later wins): - 1) {} (empty) - 2) models.defaults.params (optional) - 3) models.profiles[defaults.profile] (optional) - 4) models.agents[agent_name].profile (optional; merges that profile on top) - 5) models.agents[agent_name].params (optional) - """ - models_cfg = models_cfg or {} - profiles = models_cfg.get("profiles") or {} - defaults = models_cfg.get("defaults") or {} - agents = models_cfg.get("agents") or {} - - # Start with global defaults - merged = {} - merged = _deep_merge_dicts(merged, defaults.get("params") or {}) - - # Apply default profile - default_profile_name = defaults.get("profile") - if default_profile_name and default_profile_name in profiles: - merged = _deep_merge_dicts(merged, profiles[default_profile_name] or {}) - - # Apply agent-specific profile + params - if agent_name and isinstance(agents, dict) and agent_name in agents: - a = agents.get(agent_name) or {} - agent_profile_name = a.get("profile") - if agent_profile_name and agent_profile_name in profiles: - merged = _deep_merge_dicts( - merged, profiles[agent_profile_name] or {} - ) - merged = _deep_merge_dicts(merged, a.get("params") or {}) - - return merged - - -def _print_llm_init_banner( - agent_name: str | None, - provider: str, - model_name: str, - provider_extra: dict, - llm_kwargs: dict, - model_obj=None, -) -> None: - who = agent_name or "llm" - - safe_provider_extra = _sanitize_for_logging(provider_extra or {}) - safe_llm_kwargs = _sanitize_for_logging(llm_kwargs or {}) - - console.print( - Panel.fit( - Text.from_markup( - f"[bold cyan]LLM init ({who})[/]\n" - f"[bold]provider[/]: {provider}\n" - f"[bold]model[/]: {model_name}\n\n" - f"[bold]provider kwargs[/]: {json.dumps(safe_provider_extra, indent=2)}\n\n" - f"[bold]llm kwargs (merged)[/]: {json.dumps(safe_llm_kwargs, indent=2)}" - ), - border_style="cyan", - ) - ) - - # Best-effort readback from the LangChain model object - if model_obj is None: - return - - readback = {} - for attr in ( - "model_name", - "model", - "reasoning", - "temperature", - "max_completion_tokens", - "max_tokens", - ): - if hasattr(model_obj, attr): - try: - readback[attr] = getattr(model_obj, attr) - except Exception: - pass - - for attr in ("model_kwargs", "kwargs"): - if hasattr(model_obj, attr): - try: - readback[attr] = getattr(model_obj, attr) - except Exception: - pass - - if readback: - console.print( - Panel.fit( - Text.from_markup( - "[bold green]LLM readback (best-effort from LangChain object)[/]\n" - + json.dumps(_sanitize_for_logging(readback), indent=2) - ), - border_style="green", - ) - ) - - effort = None - try: - effort = (llm_kwargs or {}).get("reasoning", {}).get("effort") - except Exception: - effort = None - - if effort: - console.print( - Panel.fit( - Text.from_markup( - f"[bold yellow]Reasoning effort requested[/]: {effort}\n" - "Note: This confirms what we sent to init_chat_model; actual enforcement is provider-side." - ), - border_style="yellow", - ) - ) - - -def _resolve_model_choice(model_choice: str, models_cfg: dict): - """ - Accepts strings like 'openai:gpt-5.2' or 'my_endpoint:openai/gpt-oss-120b'. - Looks up per-provider settings from cfg.models.providers. - - Returns: (model_provider, pure_model, provider_extra_kwargs_for_init) - """ - if ":" in model_choice: - alias, pure_model = model_choice.split(":", 1) - else: - alias, pure_model = "openai", model_choice # back-compat default - - providers = (models_cfg or {}).get("providers", {}) - prov = providers.get(alias, {}) - - # Which LangChain integration to use (e.g. "openai", "mistral", etc.) - model_provider = prov.get("model_provider", alias) - - # auth: prefer env var; optionally load via function if configured - api_key = None - if prov.get("api_key_env"): - api_key = os.getenv(prov["api_key_env"]) - if not api_key and prov.get("token_loader"): - mod, fn = prov["token_loader"].rsplit(".", 1) - api_key = getattr(importlib.import_module(mod), fn)() - - provider_extra = {} - if prov.get("base_url"): - provider_extra["base_url"] = prov["base_url"] - if api_key: - provider_extra["api_key"] = api_key - - return model_provider, pure_model, provider_extra - - -def setup_llm( - model_choice: str, - models_cfg: dict | None = None, - agent_name: str | None = None, -): - """ - Build a LangChain chat model via init_chat_model(...), optionally applying - YAML-driven params: - models.profiles - models.defaults - models.agents. - - Back-compat: if those blocks don't exist, you get your previous behavior. - """ - models_cfg = models_cfg or {} - - provider, pure_model, provider_extra = _resolve_model_choice( - model_choice, models_cfg - ) - - # Your existing hardcoded defaults (keep these so older YAML behaves the same) - base_llm_kwargs = { - "max_completion_tokens": 10000, - "max_retries": 2, - } - - # YAML-driven kwargs (safe if absent) - yaml_llm_kwargs = _resolve_llm_kwargs_for_agent(models_cfg, agent_name) - - # Merge: base defaults < YAML overrides - llm_kwargs = _deep_merge_dicts(base_llm_kwargs, yaml_llm_kwargs) - - # Initialize - model = init_chat_model( - model=pure_model, - model_provider=provider, - **llm_kwargs, - **(provider_extra or {}), - ) - - # Print confirmation early - _print_llm_init_banner( - agent_name=agent_name, - provider=provider, - model_name=pure_model, - provider_extra=provider_extra, - llm_kwargs=llm_kwargs, - model_obj=model, - ) - - return model - - -def setup_workspace( - user_specified_workspace: str | None, - project: str = "run", - model_name: str = "openai:gpt-5-mini", -) -> str: - if user_specified_workspace is None: - print("No workspace specified, creating one for this project!") - print( - "Make sure to pass this string to restart using --workspace " - ) - # https://pypi.org/project/randomname/ - workspace = f"{project}_{randomname.get_name(adj=('colors', 'emotions', 'character', 'speed', 'size', 'weather', 'appearance', 'sound', 'age', 'taste'), noun=('cats', 'dogs', 'apex_predators', 'birds', 'fish', 'fruit'))}" - else: - workspace = user_specified_workspace - print(f"User specified workspace: {workspace}") - - Path(workspace).mkdir(parents=True, exist_ok=True) - - # Choose a fun emoji based on the model family (swap / extend as you add more) - if model_name.startswith("openai"): - model_emoji = "🤖" # OpenAI - elif "llama" in model_name.lower(): - model_emoji = "🦙" # Llama - else: - model_emoji = "🧠" # Fallback / generic LLM - - # Print the panel with model info - console.print( - Panel.fit( - f":rocket: [bold bright_blue]{workspace}[/bold bright_blue] :rocket:\n" - f"{model_emoji} [bold cyan]{model_name}[/bold cyan]", - title="[bold green]ACTIVE WORKSPACE[/bold green]", - border_style="bright_magenta", - padding=(1, 4), - ) - ) - - return workspace - - def main_plan_load_or_perform( planner, planner_checkpointer, @@ -1042,14 +630,14 @@ def main_plan_load_or_perform( "\nRe-run this program with the SAME --workspace to resume the plan.\n" ) print("Planning done, exiting") - exit() + sys.exit() # NOTE: # This is where we figure out where we are in the execution of the plan, what step # we are on # unify the plan dict for both fresh and resumed paths plan_steps = plan_dict.get("plan_steps") or [] - plan_sig = _hash_plan(plan_steps) + plan_sig = hash_plan(plan_steps) save_run_meta( workspace, plan_sig=plan_sig, plan_steps_count=len(plan_steps) ) @@ -1070,7 +658,7 @@ def get_or_create_subplan( ): if not hierarchical: # Single mode: 1-item synthetic sub-plan - return {"plan_steps": [main_step]}, _hash_plan([main_step]), None, None + return {"plan_steps": [main_step]}, hash_plan([main_step]), None, None sub_tid = f"{thread_id}::detail::{m_idx}" sub_values, _, dbg = load_latest_planner_state_from_sqlite( @@ -1080,7 +668,7 @@ def get_or_create_subplan( if sub_values: sub_steps = sub_values.get("plan_steps") or [] - return sub_values, _hash_plan(sub_steps), sub_tid, None + return sub_values, hash_plan(sub_steps), sub_tid, None # Need to plan sub-steps detail_planner_prompt = "Flesh out this main step into concrete sub-steps to fully accomplish it." @@ -1115,7 +703,7 @@ def get_or_create_subplan( ] sub_steps = sub_output.get("plan_steps") or [] - sub_sig = _hash_plan(sub_steps) + sub_sig = hash_plan(sub_steps) # persist initial sub-progress (index=0) save_hier_sub_progress( @@ -1133,7 +721,7 @@ def get_or_create_subplan( print( "Re-run with the SAME --workspace to execute the first sub-step.\n" ) - exit() + sys.exit() return {"plan_steps": sub_steps}, sub_sig, sub_tid, sub_output @@ -1217,7 +805,7 @@ def run_substeps( total=total_sub, workspace=workspace, ) - exit() + sys.exit() prev_sub_summary = last_sub_summary sub_start_idx = next_sub_zero @@ -1229,7 +817,7 @@ def main( model_name: str, config: Any, planning_mode: str = "single", - user_specified_workspace: str = None, + user_specified_workspace: str | None = None, stepwise_exit: bool = False, resume_from: str | None = None, interactive_timeout: int = 60, @@ -1240,9 +828,10 @@ def main( symlinkdict = getattr(config, "symlink", {}) or None # sets up the workspace, run config json, etc. - workspace = setup_workspace( - user_specified_workspace, project, model_name + resolved_workspace = ( + user_specified_workspace or generate_workspace_name(project) ) + workspace = setup_workspace(resolved_workspace, project, model_name) print(workspace) print(user_specified_workspace) @@ -1264,7 +853,7 @@ def main( sys.exit(1) # lock planning_mode per workspace - planning_mode, mode_locked = lock_or_warn_planning_mode( + planning_mode, _mode_locked = lock_or_warn_planning_mode( workspace, planning_mode ) console.print( @@ -1373,7 +962,7 @@ def main( save_run_meta(workspace, thread_id=thread_id, model_name=model_name) # do the main planning step, or load it from checkpoint - plan_dict, plan_steps, plan_sig = main_plan_load_or_perform( + _plan_dict, plan_steps, plan_sig = main_plan_load_or_perform( planner, planner_checkpointer, pdb_path, @@ -1677,7 +1266,7 @@ def save_progress_single(m, _next_idx, last_summary): return answer, workspace except Exception as e: - print(f"Error: {str(e)}") + print(f"Error: {e!s}") import traceback traceback.print_exc() @@ -1717,19 +1306,7 @@ def parse_args_and_user_inputs(): ) args = parser.parse_args() - # --- load YAML -> dict -> shallow namespace (top-level keys only) --- - try: - with open(args.config, "r", encoding="utf-8") as f: - raw_cfg = yaml.safe_load(f) or {} - if not isinstance(raw_cfg, dict): - raise ValueError("Top-level YAML must be a mapping/object.") - cfg = NS(**raw_cfg) # top-level attrs; nested remain dicts - except FileNotFoundError: - print(f"Config file not found: {args.config}", file=sys.stderr) - sys.exit(2) - except Exception as e: - print(f"Error loading YAML: {e}", file=sys.stderr) - sys.exit(2) + cfg = load_yaml_config(args.config) # ── config-driven model choices ──────────── models_cfg = getattr(cfg, "models", {}) or {} diff --git a/examples/two_agent_examples/plan_execute/plan_execute_multi_repo.py b/examples/two_agent_examples/plan_execute/plan_execute_multi_repo.py new file mode 100644 index 00000000..d48bb4bc --- /dev/null +++ b/examples/two_agent_examples/plan_execute/plan_execute_multi_repo.py @@ -0,0 +1,2151 @@ +import argparse +import asyncio +import json +import shlex +import subprocess +import sys +import time +from pathlib import Path + +import aiosqlite +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver +from pydantic import BaseModel, Field +from rich import box, get_console +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from ursa.agents import WebSearchAgent, make_git_agent +from ursa.prompt_library.planning_prompts import reflection_prompt +from ursa.util.github_research import gather_github_context +from ursa.util.plan_execute_utils import ( + fmt_elapsed, + generate_workspace_name, + hash_plan, + load_json_file, + load_yaml_config, + save_json_file, + setup_llm, + timed_input_with_countdown, +) + +console = get_console() + + +class RepoStep(BaseModel): + repo: str = Field(description="Target repo name from the provided list") + name: str = Field(description="Short, specific step title") + description: str = Field(description="Detailed description of the step") + requires_code: bool = Field( + description="True if this step needs code to be written/run" + ) + expected_outputs: list[str] = Field( + description="Concrete artifacts or results produced by this step" + ) + success_criteria: list[str] = Field( + description="Measurable checks that indicate the step succeeded" + ) + depends_on_repos: list[str] = Field( + default_factory=list, + description=( + "Repo names that must complete ALL their steps before this step " + "can start. Use this when this step depends on changes made in " + "another repo (e.g. a library repo must finish before consumer " + "repos can integrate its changes)." + ), + ) + + +class RepoPlan(BaseModel): + steps: list[RepoStep] = Field( + description="Ordered list of steps to solve the problem" + ) + + +def _resolve_workspace(user_workspace: str | None, project: str) -> Path: + if user_workspace: + workspace = Path(user_workspace) + else: + workspace = Path(generate_workspace_name(project)) + + workspace.mkdir(parents=True, exist_ok=True) + (workspace / "repos").mkdir(exist_ok=True) + return workspace + + +def _validate_model(llm, model_name: str, role: str) -> None: + """Send a minimal chat completion to verify the model is reachable and + supports the chat completions endpoint. Raises ``RuntimeError`` with a + clear message on failure so the user can fix their config before burning + tokens on a full run. + """ + from langchain_core.messages import HumanMessage as _HM + + try: + llm.invoke([_HM(content="ping")]) + except Exception as exc: + msg = str(exc) + # Surface the most common issues clearly + if "not a chat model" in msg.lower() or "404" in msg: + raise RuntimeError( + f"Model '{model_name}' (role: {role}) is not available as a " + f"chat model. Check that the model ID is correct and that it " + f"supports the v1/chat/completions endpoint.\n" + f" Original error: {msg}" + ) from exc + if "401" in msg or "auth" in msg.lower(): + raise RuntimeError( + f"Authentication failed for model '{model_name}' (role: {role}). " + f"Check your API key.\n Original error: {msg}" + ) from exc + # Re-raise anything else as-is + raise RuntimeError( + f"Failed to reach model '{model_name}' (role: {role}).\n" + f" Original error: {msg}" + ) from exc + + +def _resolve_repos( + raw_repos: list[dict], config_dir: Path, workspace: Path +) -> list[dict]: + repos = [] + for raw in raw_repos: + if not isinstance(raw, dict): + raise ValueError("Each repo entry must be a mapping/object.") + + name = raw.get("name") + if not name: + raise ValueError("Each repo requires a 'name'.") + + path_value = raw.get("path") + if path_value: + path = Path(path_value) + if not path.is_absolute(): + path = (config_dir / path).resolve() + else: + # Default: clone into /repos/ + path = (workspace / "repos" / name).resolve() + + repos.append({ + "name": name, + "path": path, + "url": raw.get("url"), + "branch": raw.get("branch"), + "checkout": bool(raw.get("checkout", False)), + "checks": raw.get("checks") or [], + "description": raw.get("description") or "", + "language": raw.get("language", "generic"), + }) + return repos + + +def _run_command( + args: list[str], cwd: Path | None = None, timeout: int = 600 +) -> tuple[int, str, str]: + try: + result = subprocess.run( + args, + text=True, + capture_output=True, + timeout=timeout, + cwd=cwd, + check=False, + ) + except Exception as exc: + return 1, "", f"Error: {exc}" + return result.returncode, result.stdout, result.stderr + + +def _ensure_checkout(repo: dict) -> None: + path = repo["path"] + url = repo.get("url") + branch = repo.get("branch") + if not repo.get("checkout"): + return + + if not path.exists(): + if not url: + raise RuntimeError( + f"Repo {repo['name']} missing locally and no url provided." + ) + args = ["git", "clone"] + if branch: + args.extend(["--branch", branch]) + args.extend([url, str(path)]) + code, stdout, stderr = _run_command(args) + if code != 0: + raise RuntimeError( + f"git clone failed for {repo['name']}\n{stdout}\n{stderr}" + ) + return + + if not branch: + return + + # Check current branch -- skip checkout if already on the right one + code, current, _ = _run_command([ + "git", + "-C", + str(path), + "rev-parse", + "--abbrev-ref", + "HEAD", + ]) + current = current.strip() + if code == 0 and current == branch: + console.print( + f" [dim]{repo['name']}:[/dim] already on [cyan]{branch}[/cyan]" + ) + return + + # Attempt checkout + code, stdout, stderr = _run_command([ + "git", + "-C", + str(path), + "checkout", + branch, + ]) + if code == 0: + return + + # Checkout failed -- likely dirty working tree. Warn and continue on + # the current branch rather than crashing the entire run. + console.print( + Panel( + f"[bold yellow]{repo['name']}:[/bold yellow] " + f"Could not checkout [cyan]{branch}[/cyan] " + f"(staying on [cyan]{current}[/cyan]).\n\n" + f"[dim]{stderr.strip()}[/dim]\n\n" + "Tip: commit or stash local changes, or set " + "[bold]checkout: false[/bold] in the config.", + border_style="yellow", + expand=False, + ) + ) + + +def _ensure_repo_symlink(workspace: Path, repo: dict) -> Path: + repos_dir = workspace / "repos" + repos_dir.mkdir(exist_ok=True) + target = repos_dir / repo["name"] + source = repo["path"] + + # If the repo was cloned directly into workspace/repos/, + # it's already in the right place -- no symlink needed. + if target.resolve() == source.resolve(): + return target + + if target.exists() or target.is_symlink(): + if target.is_symlink() and target.resolve() == source.resolve(): + return target + raise RuntimeError(f"Repo link target already exists: {target}") + + target.symlink_to(source, target_is_directory=True) + return target + + +def _format_repo_list(repos: list[dict]) -> str: + lines = [] + for repo in repos: + desc = repo.get("description") + extra = f" - {desc}" if desc else "" + branch = repo.get("branch") + branch_note = f" (branch: {branch})" if branch else "" + lines.append( + f"- {repo['name']}: repos/{repo['name']}{branch_note}{extra}" + ) + return "\n".join(lines) + + +def _planner_prompt( + problem: str, repos: list[dict], research: str | None +) -> str: + repo_block = _format_repo_list(repos) + research_block = f"\n\nResearch notes:\n{research}\n" if research else "" + repo_names = ", ".join([repo["name"] for repo in repos]) + return ( + "You are planning changes across multiple git repositories.\n" + "Create a step-by-step plan that can be executed independently per repo.\n\n" + f"Available repos (use repo field from this list only): {repo_names}\n" + f"Repo details:\n{repo_block}\n\n" + f"Problem:\n{problem}\n" + f"{research_block}\n" + "Rules:\n" + "- Each step MUST include a 'repo' field matching one of the repo names.\n" + "- If a task affects multiple repos, split it into separate steps per repo.\n" + "- Prefer small, reviewable steps that can run in parallel across repos.\n" + "- Include expected outputs and success criteria for each step.\n" + "- Use 'depends_on_repos' to declare cross-repo dependencies.\n" + " When a step in repo B depends on changes made in repo A (e.g. repo B\n" + " consumes a library from repo A), set depends_on_repos: [A].\n" + " This means ALL steps in repo A must complete before this step starts.\n" + " Steps with no dependencies run in parallel. Only add dependencies\n" + " when there is a real build/import dependency, not just logical ordering.\n" + " CRITICAL RULES for depends_on_repos:\n" + " - NEVER create circular dependencies (A depends on B AND B depends on A).\n" + " Dependencies must form a DAG (directed acyclic graph).\n" + " - NEVER list a repo as depending on itself — steps within a repo are\n" + " already sequential and self-deps will be stripped.\n" + " - Dependencies flow one direction: libraries -> consumers, never back.\n" + ) + + +async def _gather_research( + llm, + workspace: Path, + research_cfg: dict | None, + problem: str, + repos: list[dict] | None = None, +) -> str | None: + sections: list[str] = [] + + # -- GitHub context: auto-fetch issues/PRs from repo URLs -- + if repos: + gh_cfg = (research_cfg or {}).get("github", {}) or {} + if gh_cfg.get("enabled", True): + max_issues = int(gh_cfg.get("max_issues", 10)) + max_prs = int(gh_cfg.get("max_prs", 10)) + gh_context = gather_github_context( + repos, max_issues=max_issues, max_prs=max_prs + ) + if gh_context: + sections.append("# GitHub Repository Context\n\n" + gh_context) + console.print( + "[green]Fetched GitHub issues/PRs for repos with GitHub URLs.[/green]" + ) + else: + console.print( + "[dim]No GitHub context available " + "(gh CLI missing or no GitHub URLs).[/dim]" + ) + + # -- Explicit web search queries (optional) -- + queries = (research_cfg or {}).get("queries") or [] + if queries: + try: + agent = WebSearchAgent( + llm=llm, + workspace=workspace, + summarize=True, + max_results=int((research_cfg or {}).get("max_results", 3)), + ) + except Exception as exc: + console.print( + f"[bold yellow]WebSearchAgent unavailable:[/bold yellow] {exc}" + ) + agent = None + + if agent: + for query in queries: + context = f"{problem}\n\nResearch focus: {query}" + result = await agent.ainvoke({ + "query": query, + "context": context, + }) + summary = result.get("final_summary") or "" + sections.append(f"Query: {query}\n{summary}") + + return "\n\n".join(sections) if sections else None + + +async def _plan( + llm, + problem: str, + repos: list[dict], + research: str | None, + reflection_steps: int, +) -> RepoPlan: + prompt = _planner_prompt(problem, repos, research) + messages = [SystemMessage(content=prompt)] + structured_llm = llm.with_structured_output(RepoPlan) + plan = structured_llm.invoke(messages) + + for _ in range(max(0, reflection_steps)): + review = llm.invoke([ + SystemMessage(content=reflection_prompt), + HumanMessage(content=plan.model_dump_json()), + ]) + review_text = (review.text or "").strip() + if "[APPROVED]" in review_text: + break + messages = [ + SystemMessage(content=prompt), + HumanMessage( + content=f"Reviewer notes:\n{review_text}\n\nRevise the plan." + ), + ] + plan = structured_llm.invoke(messages) + + return plan + + +def _write_plan(workspace: Path, plan: RepoPlan) -> None: + plan_path = workspace / "plan.json" + plan_path.write_text(plan.model_dump_json(indent=2)) + + +def _render_repo_plan(plan: RepoPlan) -> None: + """Display the multi-repo plan as a Rich table.""" + table = Table( + title="[bold]Multi-Repo Plan[/bold]", + box=box.ROUNDED, + show_lines=True, + header_style="bold magenta", + expand=True, + ) + table.add_column("#", style="bold cyan", no_wrap=True, width=3) + table.add_column("Repo", style="bold yellow", no_wrap=True) + table.add_column("Step", overflow="fold") + table.add_column("Description", overflow="fold") + table.add_column("Code?", justify="center", no_wrap=True) + + for i, step in enumerate(plan.steps, 1): + code_badge = Text.from_markup( + "[bold green]Yes[/]" if step.requires_code else "[dim]No[/dim]" + ) + table.add_row( + str(i), + step.repo, + step.name, + step.description, + code_badge, + ) + + console.print(table) + + # Summary: steps per repo + repo_counts: dict[str, int] = {} + for step in plan.steps: + repo_counts[step.repo] = repo_counts.get(step.repo, 0) + 1 + summary_parts = [ + f"[bold]{name}[/bold]: {count}" + for name, count in sorted(repo_counts.items()) + ] + console.print( + Panel( + " ".join(summary_parts), + title="[bold]Steps per repo[/bold]", + border_style="cyan", + expand=False, + ) + ) + + +def _group_steps_by_repo(plan: RepoPlan) -> dict[str, list[RepoStep]]: + grouped: dict[str, list[RepoStep]] = {} + for step in plan.steps: + grouped.setdefault(step.repo, []).append(step) + return grouped + + +def _validate_plan_repos(plan: RepoPlan, repos: list[dict]) -> None: + repo_names = {repo["name"] for repo in repos} + invalid = sorted({step.repo for step in plan.steps} - repo_names) + if invalid: + raise RuntimeError( + "Plan referenced unknown repos: " + ", ".join(invalid) + ) + # Validate and sanitize depends_on_repos + plan_repos = {step.repo for step in plan.steps} + for step in plan.steps: + bad_deps = sorted(set(step.depends_on_repos) - plan_repos) + if bad_deps: + console.print( + f"[bold yellow]Warning:[/bold yellow] step '{step.name}' in " + f"repo '{step.repo}' depends on repos not in the plan: " + + ", ".join(bad_deps) + ) + # Strip self-dependencies — steps within a repo are already sequential + if step.repo in step.depends_on_repos: + step.depends_on_repos = [ + d for d in step.depends_on_repos if d != step.repo + ] + console.print( + f"[dim]Stripped self-dependency from step '{step.name}' " + f"in repo '{step.repo}'[/dim]" + ) + + # Build repo-level dependency graph and break any cycles. + # A repo A depends on repo B if ANY step in A lists B in depends_on_repos. + dep_graph: dict[str, set[str]] = {name: set() for name in plan_repos} + for step in plan.steps: + for dep in step.depends_on_repos: + if dep in plan_repos and dep != step.repo: + dep_graph[step.repo].add(dep) + + # Detect and break cycles via DFS + UNVISITED, IN_PROGRESS, DONE = 0, 1, 2 + state: dict[str, int] = {name: UNVISITED for name in dep_graph} + broken_edges: list[tuple[str, str]] = [] + + def _visit(node: str, stack: set[str]) -> None: + state[node] = IN_PROGRESS + stack.add(node) + for dep in list(dep_graph.get(node, set())): + if dep in stack: + # Back-edge found — break the cycle by removing this edge + dep_graph[node].discard(dep) + broken_edges.append((node, dep)) + elif state[dep] == UNVISITED: + _visit(dep, stack) + stack.discard(node) + state[node] = DONE + + for repo in list(dep_graph): + if state[repo] == UNVISITED: + _visit(repo, set()) + + # Remove broken edges from the actual step data + if broken_edges: + broken_set = set(broken_edges) + for step in plan.steps: + removed = [ + d for d in step.depends_on_repos if (step.repo, d) in broken_set + ] + if removed: + step.depends_on_repos = [ + d + for d in step.depends_on_repos + if (step.repo, d) not in broken_set + ] + edge_strs = [f"{a} -> {b}" for a, b in broken_edges] + console.print( + Panel( + "Circular repo dependencies detected and broken:\n" + + "\n".join(f" {e}" for e in edge_strs) + + "\n\nThese dependency edges were removed to prevent deadlock. " + "Steps that lost dependencies will run without waiting.", + title="[bold yellow]Cycle detected[/bold yellow]", + border_style="yellow", + expand=False, + ) + ) + + +def _progress_path( + workspace: Path, + repo_name: str, + resume_dir: Path | None, + resume_files: dict[str, Path], +) -> Path: + if repo_name in resume_files: + return resume_files[repo_name] + + progress_dir = resume_dir or (workspace / "progress") + progress_dir.mkdir(exist_ok=True, parents=True) + return progress_dir / f"{repo_name}.json" + + +async def _repo_checkpointer(workspace: Path, repo_name: str): + ckpt_dir = workspace / "checkpoints" / repo_name + ckpt_dir.mkdir(parents=True, exist_ok=True) + db_path = ckpt_dir / "executor.db" + conn = await aiosqlite.connect(str(db_path)) + checkpointer = AsyncSqliteSaver(conn) + return checkpointer, conn, db_path + + +def _namespace_to_dict(value): + if isinstance(value, dict): + return {k: _namespace_to_dict(v) for k, v in value.items()} + if isinstance(value, list): + return [_namespace_to_dict(v) for v in value] + if isinstance(value, tuple): + return [_namespace_to_dict(v) for v in value] + if hasattr(value, "__dict__"): + return { + k: _namespace_to_dict(v) + for k, v in vars(value).items() + if not k.startswith("_") + } + return value + + +def _repo_token_snapshot(info: dict) -> dict[str, int]: + return { + "input_tokens": int(info.get("input_tokens", 0)), + "output_tokens": int(info.get("output_tokens", 0)), + "total_tokens": int(info.get("total_tokens", 0)), + } + + +def _step_token_record( + *, step_index: int, step_name: str, status: str, tokens: dict[str, int] +) -> dict: + return { + "step_index": step_index, + "step_name": step_name, + "status": status, + "input_tokens": int(tokens.get("input_tokens", 0)), + "output_tokens": int(tokens.get("output_tokens", 0)), + "total_tokens": int(tokens.get("total_tokens", 0)), + } + + +def _format_step_token_usage(tokens: dict[str, int]) -> str: + return ( + f"Step token usage: {_fmt_tokens(tokens.get('input_tokens', 0))} in / " + f"{_fmt_tokens(tokens.get('output_tokens', 0))} out " + f"({_fmt_tokens(tokens.get('total_tokens', 0))} total)" + ) + + +def _parse_resume_overrides( + paths: list[str] | None, config_dir: Path +) -> tuple[Path | None, dict[str, Path]]: + resume_dir: Path | None = None + resume_files: dict[str, Path] = {} + + for raw in paths or []: + path = Path(raw) + if not path.is_absolute(): + path = (config_dir / path).resolve() + + if path.is_dir(): + if resume_dir and resume_dir != path: + raise ValueError("Only one resume directory may be provided.") + resume_dir = path + continue + + if not path.exists(): + raise ValueError(f"Resume checkpoint not found: {path}") + resume_files[path.stem] = path + + return resume_dir, resume_files + + +def _list_progress_files(progress_dir: Path) -> list[Path]: + if not progress_dir.exists(): + return [] + return sorted(p for p in progress_dir.glob("*.json") if p.is_file()) + + +def _choose_resume_dir(workspace: Path, timeout: int) -> Path | None: + progress_dir = workspace / "progress" + progress_files = _list_progress_files(progress_dir) + if not progress_files: + return None + + files_preview = "\n".join(f"- {p.name}" for p in progress_files[:12]) + if len(progress_files) > 12: + files_preview += f"\n... ({len(progress_files) - 12} more)" + + console.print( + Panel( + f"Found saved progress files in {progress_dir}:\n{files_preview}", + title="[bold yellow]Resume available[/bold yellow]", + border_style="yellow", + expand=False, + ) + ) + prompt = f"Resume from {progress_dir.name}? [Y/n] (auto in {timeout}s) > " + answer = timed_input_with_countdown(prompt, timeout) + if answer and answer.strip().lower() in {"n", "no"}: + return None + return progress_dir + + +def _init_progress(repo_steps: dict[str, list[RepoStep]]) -> dict[str, dict]: + now = time.time() + return { + name: { + "state": "queued", + "step": 0, + "total": len(steps), + "current": None, + "error": None, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "_seen_input_tokens": 0, + "_seen_output_tokens": 0, + "_seen_total_tokens": 0, + "last_step_tokens": None, + "step_token_deltas": [], + "started": now, + "step_started": None, + "updated": now, + } + for name, steps in repo_steps.items() + } + + +def _extract_agent_tokens(agent) -> dict[str, int]: + """Read token usage from agent telemetry samples accumulated since last begin_run.""" + totals = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + try: + for sample in agent.telemetry.llm.samples: + rollup = (sample.get("metrics") or {}).get("usage_rollup") or {} + totals["input_tokens"] += int(rollup.get("input_tokens", 0)) + totals["output_tokens"] += int(rollup.get("output_tokens", 0)) + totals["total_tokens"] += int(rollup.get("total_tokens", 0)) + except Exception: + pass + return totals + + +def _token_delta(current: int, seen: int) -> int: + if current >= seen: + return current - seen + return current + + +def _fmt_tokens(n: int) -> str: + """Format token count with K/M suffix.""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + if n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(n) + + +_STATE_STYLES = { + "queued": ("dim", "..."), + "blocked": ("bold yellow", "||"), + "paused": ("yellow", "!!"), + "running": ("bold cyan", ">>"), + "done": ("bold green", "ok"), + "failed": ("bold red", "!!"), +} + + +def _build_progress_table( + snapshot: dict[str, dict], max_parallel: int +) -> Table: + counts: dict[str, int] = { + "queued": 0, + "blocked": 0, + "paused": 0, + "running": 0, + "done": 0, + "failed": 0, + } + grand_total_tokens = 0 + for info in snapshot.values(): + state = info.get("state", "queued") + counts[state] = counts.get(state, 0) + 1 + grand_total_tokens += info.get("total_tokens", 0) + info.get( + "_live_total", 0 + ) + + title = ( + f"[bold]active [cyan]{counts['running']}/{max_parallel}[/cyan] " + + ( + f"blocked [yellow]{counts['blocked']}[/yellow] " + if counts["blocked"] + else "" + ) + + ( + f"paused [yellow]{counts['paused']}[/yellow] " + if counts["paused"] + else "" + ) + + f"queued [dim]{counts['queued']}[/dim] " + f"done [green]{counts['done']}[/green] " + f"failed [red]{counts['failed']}[/red] " + f"tokens [yellow]{_fmt_tokens(grand_total_tokens)}[/yellow][/bold]" + ) + + table = Table( + title=title, + box=box.SIMPLE_HEAVY, + show_lines=False, + header_style="bold magenta", + expand=False, + padding=(0, 1), + ) + table.add_column("", no_wrap=True, width=2) + table.add_column("Repo", style="bold", no_wrap=True) + table.add_column("Status", no_wrap=True) + table.add_column("Progress", no_wrap=True) + table.add_column("Elapsed", no_wrap=True, justify="right") + table.add_column("Tokens", no_wrap=True, justify="right") + table.add_column("Step", overflow="fold") + + now = time.time() + for name in sorted(snapshot): + info = snapshot[name] + state = info.get("state", "queued") + step = info.get("step", 0) + total = info.get("total", 0) + current = info.get("current") or "" + error = info.get("error") + in_tok = info.get("input_tokens", 0) + info.get("_live_input", 0) + out_tok = info.get("output_tokens", 0) + info.get("_live_output", 0) + + style, icon = _STATE_STYLES.get(state, ("", "?")) + + progress_bar = "" + if total: + filled = int((step / total) * 10) + progress_bar = f"{'█' * filled}{'░' * (10 - filled)} {step}/{total}" + + # Elapsed time: total for repo + current step duration + elapsed_text = "" + if state not in ("queued",): + started = info.get("started") or now + total_elapsed = now - started + elapsed_text = f"[dim]{fmt_elapsed(total_elapsed)}[/dim]" + step_started = info.get("step_started") + if step_started and state in ("running", "blocked"): + step_elapsed = now - step_started + elapsed_text += f" [bold]({fmt_elapsed(step_elapsed)})[/bold]" + + token_text = "" + if in_tok or out_tok: + token_text = f"[dim]{_fmt_tokens(in_tok)}[/dim]/[bold]{_fmt_tokens(out_tok)}[/bold]" + + step_text = current + if error: + step_text = error[:60] + + elapsed_render = ( + Text.from_markup(elapsed_text) if elapsed_text else Text("") + ) + token_render = Text.from_markup(token_text) if token_text else Text("") + step_render = ( + Text(str(step_text), style="red") if error else Text(str(step_text)) + ) + + table.add_row( + Text(icon, style=style), + Text(str(name)), + Text(state, style=style), + Text(progress_bar), + elapsed_render, + token_render, + step_render, + ) + + return table + + +async def _snapshot_progress( + progress: dict[str, dict], lock: asyncio.Lock +) -> dict[str, dict]: + async with lock: + return {name: dict(info) for name, info in progress.items()} + + +async def _emit_progress( + progress: dict[str, dict], lock: asyncio.Lock, max_parallel: int +) -> None: + snapshot = await _snapshot_progress(progress, lock) + console.print(_build_progress_table(snapshot, max_parallel)) + + +def _executor_prompt( + problem: str, + repo: dict, + step: RepoStep, + step_index: int, + total_steps: int, + previous_summary: str | None, +) -> str: + prev = previous_summary or "None" + checks = repo.get("checks") or [] + checks_block = "" + if checks: + checks_list = "\n".join(f" - {c}" for c in checks) + checks_block = ( + f"\nVerification commands (will be run automatically after this step):\n" + f"{checks_list}\n" + f"You MUST ensure these commands pass before considering the step complete.\n" + f"Run the verification commands yourself using language-specific tools " + f"and fix any failures before finishing.\n" + ) + return ( + f"Working repo: {repo['name']} (path: repos/{repo['name']}).\n" + f"Overall goal:\n{problem}\n\n" + f"Step {step_index + 1} of {total_steps}: {step.name}\n" + f"Description: {step.description}\n\n" + f"Expected outputs:\n- " + "\n- ".join(step.expected_outputs) + "\n\n" + "Success criteria:\n- " + "\n- ".join(step.success_criteria) + "\n\n" + f"Previous step summary:\n{prev}\n" + f"{checks_block}\n" + "Use git tools with repo_path='repos/{repo_name}'.\n" + "Use language-specific tools to validate your changes.\n" + "Report the changes you made and the git status/diff summary." + ).replace("{repo_name}", repo["name"]) + + +def _fix_prompt( + repo: dict, + step: RepoStep, + check_output: str, + attempt: int, + max_attempts: int, +) -> str: + return ( + f"The verification checks FAILED after completing step '{step.name}' " + f"in repo {repo['name']} (attempt {attempt}/{max_attempts}).\n\n" + f"Failure output:\n{check_output}\n\n" + f"Fix the failing tests. Do NOT move on to other work -- focus entirely " + f"on making the checks pass. Run the tests again after your fix to confirm." + ) + + +async def _run_checks(repo: dict, workspace: Path) -> list[dict]: + results = [] + checks = repo.get("checks") or [] + if not checks: + return results + + log_dir = workspace / "checks" + log_dir.mkdir(exist_ok=True) + log_path = log_dir / f"{repo['name']}.log" + + with open(log_path, "w", encoding="utf-8") as log: + for command in checks: + args = shlex.split(command) + code, stdout, stderr = _run_command(args, cwd=repo["path"]) + log.write(f"$ {command}\n") + log.write(stdout) + if stderr: + log.write("\nSTDERR:\n") + log.write(stderr) + log.write("\n\n") + results.append({ + "command": command, + "exit_code": code, + "stdout": stdout, + "stderr": stderr, + "log": str(log_path), + }) + return results + + +def _checks_passed(check_results: list[dict]) -> bool: + return all(c.get("exit_code", 1) == 0 for c in check_results) + + +def _format_check_failures( + check_results: list[dict], max_output: int = 2000 +) -> str: + """Format failed check output for inclusion in a retry prompt.""" + lines = [] + for cr in check_results: + if cr.get("exit_code", 1) != 0: + lines.append( + f"FAILED: {cr['command']} (exit code {cr['exit_code']})" + ) + out = (cr.get("stdout") or "").strip() + err = (cr.get("stderr") or "").strip() + combined = f"{out}\n{err}".strip() + if len(combined) > max_output: + combined = combined[:max_output] + "\n... (truncated)" + lines.append(combined) + return "\n\n".join(lines) + + +async def _accumulate_tokens( + agent, progress_state: dict, repo_name: str, lock: asyncio.Lock +) -> dict[str, int]: + """Extract tokens from agent telemetry and add to progress state. + + Clears any live-preview counters set by the heartbeat so they aren't + double-counted in the progress table. + """ + cumulative = _extract_agent_tokens(agent) + async with lock: + info = progress_state[repo_name] + seen_in = int(info.get("_seen_input_tokens", 0)) + seen_out = int(info.get("_seen_output_tokens", 0)) + seen_total = int(info.get("_seen_total_tokens", 0)) + + step_tokens = { + "input_tokens": _token_delta(cumulative["input_tokens"], seen_in), + "output_tokens": _token_delta( + cumulative["output_tokens"], seen_out + ), + "total_tokens": _token_delta( + cumulative["total_tokens"], seen_total + ), + } + + info["input_tokens"] = ( + info.get("input_tokens", 0) + step_tokens["input_tokens"] + ) + info["output_tokens"] = ( + info.get("output_tokens", 0) + step_tokens["output_tokens"] + ) + info["total_tokens"] = ( + info.get("total_tokens", 0) + step_tokens["total_tokens"] + ) + info["_seen_input_tokens"] = cumulative["input_tokens"] + info["_seen_output_tokens"] = cumulative["output_tokens"] + info["_seen_total_tokens"] = cumulative["total_tokens"] + # Clear live counters — the real totals are now in the main fields + info.pop("_live_input", None) + info.pop("_live_output", None) + info.pop("_live_total", None) + return step_tokens + + +async def _ainvoke_with_heartbeat( + agent, + prompt: str, + recursion_limit: int, + repo_name: str, + step_name: str, + timeout_sec: int = 0, + heartbeat_sec: int = 60, + progress_state: dict[str, dict] | None = None, + progress_lock: asyncio.Lock | None = None, +) -> dict: + """Run agent.ainvoke with periodic heartbeat logs and optional timeout. + + Heartbeat logs print every *heartbeat_sec* seconds so the user can see + the agent is still alive even when no progress-table fields change. + If *timeout_sec* > 0, the call is cancelled after that many seconds. + Token counts in *progress_state* are updated live during heartbeats. + """ + + async def _heartbeat(stop: asyncio.Event): + elapsed = 0 + while not stop.is_set(): + await asyncio.sleep(heartbeat_sec) + if stop.is_set(): + break + elapsed += heartbeat_sec + # Read live token snapshot and update progress for display + if progress_state is not None and progress_lock is not None: + cumulative = _extract_agent_tokens(agent) + async with progress_lock: + info = progress_state.get(repo_name, {}) + seen_in = int(info.get("_seen_input_tokens", 0)) + seen_out = int(info.get("_seen_output_tokens", 0)) + seen_total = int(info.get("_seen_total_tokens", 0)) + # Store live step tokens separately so _accumulate_tokens + # can do the final authoritative add without double-counting + info["_live_input"] = _token_delta( + cumulative["input_tokens"], seen_in + ) + info["_live_output"] = _token_delta( + cumulative["output_tokens"], seen_out + ) + info["_live_total"] = _token_delta( + cumulative["total_tokens"], seen_total + ) + tok_str = "" + if progress_state and repo_name in progress_state: + info = progress_state[repo_name] + total = info.get("total_tokens", 0) + info.get("_live_total", 0) + if total: + tok_str = f" [{_fmt_tokens(total)} tokens]" + console.log( + f"[dim]{repo_name}[/dim] step [bold]{step_name}[/bold] " + f"still running ({fmt_elapsed(elapsed)}){tok_str}" + ) + + stop = asyncio.Event() + hb_task = asyncio.create_task(_heartbeat(stop)) + + try: + coro = agent.ainvoke( + prompt, config={"recursion_limit": recursion_limit} + ) + if timeout_sec > 0: + result = await asyncio.wait_for(coro, timeout=timeout_sec) + else: + result = await coro + return result + except asyncio.TimeoutError: + console.log( + f"[bold red]{repo_name}[/bold red] step [bold]{step_name}[/bold] " + f"timed out after {fmt_elapsed(timeout_sec)}" + ) + raise + finally: + stop.set() + hb_task.cancel() + try: + await hb_task + except asyncio.CancelledError: + pass + + +async def _wait_for_repos( + deps: list[str], + repo_done_events: dict[str, asyncio.Event], + repo_name: str, + progress_state: dict[str, dict], + progress_lock: asyncio.Lock, + max_parallel: int, +) -> None: + """Block until dependency repos reach a terminal state. + + If any dependency is terminal but not "done" (e.g. paused/failed), + raise so callers fail gracefully instead of waiting forever. + """ + # Filter out self-dependencies — a repo's own steps are already sequential + pending = [ + d + for d in deps + if d != repo_name + and d in repo_done_events + and not repo_done_events[d].is_set() + ] + if not pending: + return + + async with progress_lock: + info = progress_state[repo_name] + info["state"] = "blocked" + info["current"] = f"waiting on {', '.join(pending)}" + info["updated"] = time.time() + await _emit_progress(progress_state, progress_lock, max_parallel) + + await asyncio.gather(*(repo_done_events[d].wait() for d in pending)) + + async with progress_lock: + blocked = { + dep: progress_state.get(dep, {}).get("state", "unknown") + for dep in deps + if dep != repo_name + and dep in repo_done_events + and progress_state.get(dep, {}).get("state") != "done" + } + + if blocked: + blockers = ", ".join( + f"{name} ({state})" for name, state in sorted(blocked.items()) + ) + raise RuntimeError( + f"Dependency repo(s) not completed successfully: {blockers}" + ) + + async with progress_lock: + info = progress_state[repo_name] + info["state"] = "running" + info["updated"] = time.time() + await _emit_progress(progress_state, progress_lock, max_parallel) + + +async def _run_repo_steps( + repo: dict, + steps: list[RepoStep], + problem: str, + workspace: Path, + llm, + recursion_limit: int, + resume: bool, + progress_state: dict[str, dict], + progress_lock: asyncio.Lock, + max_parallel: int, + resume_dir: Path | None, + resume_files: dict[str, Path], + max_check_retries: int = 2, + repo_done_events: dict[str, asyncio.Event] | None = None, + step_timeout_sec: int = 0, + timeout_mode: str = "pause", + checkpointer=None, + thread_id: str | None = None, +) -> dict: + plan_hash = hash_plan(steps) + thread_id = thread_id or f"{repo['name']}-{plan_hash[:8]}" + agent = make_git_agent( + llm=llm, + language=repo.get("language", "generic"), + workspace=workspace, + checkpointer=checkpointer, + thread_id=thread_id, + ) + progress_path = _progress_path( + workspace, repo["name"], resume_dir, resume_files + ) + resume_progress = load_json_file(progress_path, {}) if resume else {} + start_index = int(resume_progress.get("next_index", 0)) if resume else 0 + has_checks = bool(repo.get("checks")) + step_token_deltas = list(resume_progress.get("step_token_deltas") or []) + last_step_tokens = resume_progress.get("last_step_tokens") + + if resume and resume_progress.get("plan_hash") != plan_hash: + start_index = 0 + step_token_deltas = [] + last_step_tokens = None + + last_summary = resume_progress.get("last_summary") if resume else None + step_outputs_dir = workspace / "step_outputs" / repo["name"] + step_outputs_dir.mkdir(parents=True, exist_ok=True) + + now = time.time() + async with progress_lock: + info = progress_state[repo["name"]] + info.update({ + "state": "running", + "step": start_index, + "total": len(steps), + "current": None, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "started": now, + "step_started": None, + "updated": now, + }) + await _emit_progress(progress_state, progress_lock, max_parallel) + + all_check_results: list[dict] = [] + + for idx in range(start_index, len(steps)): + step = steps[idx] + + # -- Wait for cross-repo dependencies -- + if step.depends_on_repos and repo_done_events: + await _wait_for_repos( + deps=step.depends_on_repos, + repo_done_events=repo_done_events, + repo_name=repo["name"], + progress_state=progress_state, + progress_lock=progress_lock, + max_parallel=max_parallel, + ) + + step_start = time.time() + async with progress_lock: + info = progress_state[repo["name"]] + info.update({ + "state": "running", + "step": idx + 1, + "current": step.name, + "step_started": step_start, + "updated": step_start, + }) + await _emit_progress(progress_state, progress_lock, max_parallel) + prompt = _executor_prompt( + problem=problem, + repo=repo, + step=step, + step_index=idx, + total_steps=len(steps), + previous_summary=last_summary, + ) + try: + result = await _ainvoke_with_heartbeat( + agent=agent, + prompt=prompt, + recursion_limit=recursion_limit, + repo_name=repo["name"], + step_name=step.name, + timeout_sec=step_timeout_sec, + progress_state=progress_state, + progress_lock=progress_lock, + ) + except asyncio.TimeoutError: + timeout_tokens = await _accumulate_tokens( + agent, progress_state, repo["name"], progress_lock + ) + timeout_msg = ( + f"Timed out after {fmt_elapsed(step_timeout_sec)}: {step.name}" + ) + step_record = _step_token_record( + step_index=idx + 1, + step_name=step.name, + status="timed_out", + tokens=timeout_tokens, + ) + step_token_deltas.append(step_record) + last_step_tokens = step_record + (step_outputs_dir / f"step_{idx + 1}.md").write_text( + f"{timeout_msg}\n\n{_format_step_token_usage(timeout_tokens)}", + encoding="utf-8", + ) + if timeout_mode == "skip": + console.log( + f"[bold yellow]{repo['name']}[/bold yellow] step {idx + 1} " + f"timed out; skipping and continuing" + ) + async with progress_lock: + info = progress_state[repo["name"]] + info["last_step_tokens"] = last_step_tokens + info["step_token_deltas"] = step_token_deltas + repo_tokens = _repo_token_snapshot(info) + save_json_file( + progress_path, + { + "next_index": idx + 1, + "plan_hash": plan_hash, + "last_summary": timeout_msg, + "state": "running", + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + }, + ) + continue + if timeout_mode == "fail": + console.log( + f"[bold red]{repo['name']}[/bold red] step {idx + 1} " + "timed out; failing repo" + ) + async with progress_lock: + info = progress_state[repo["name"]] + info["last_step_tokens"] = last_step_tokens + info["step_token_deltas"] = step_token_deltas + repo_tokens = _repo_token_snapshot(info) + save_json_file( + progress_path, + { + "next_index": idx, + "plan_hash": plan_hash, + "last_summary": timeout_msg, + "state": "failed", + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + }, + ) + raise + + async with progress_lock: + info = progress_state[repo["name"]] + info.update({ + "state": "paused", + "current": step.name, + "error": timeout_msg, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "updated": time.time(), + }) + repo_tokens = _repo_token_snapshot(info) + await _emit_progress(progress_state, progress_lock, max_parallel) + save_json_file( + progress_path, + { + "next_index": idx, + "plan_hash": plan_hash, + "last_summary": timeout_msg, + "state": "paused", + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + }, + ) + if repo_done_events and repo["name"] in repo_done_events: + repo_done_events[repo["name"]].set() + return { + "repo": repo["name"], + "steps": idx, + "checks": all_check_results, + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "state": "paused", + } + step_tokens = await _accumulate_tokens( + agent, progress_state, repo["name"], progress_lock + ) + step_total_tokens = dict(step_tokens) + + summary = result["messages"][-1].text + last_summary = summary + + # -- Run checks after each step -- + if has_checks: + check_results = await _run_checks(repo, workspace) + passed = _checks_passed(check_results) + + for cr in check_results: + status = ( + "[green]pass[/green]" + if cr["exit_code"] == 0 + else "[red]FAIL[/red]" + ) + console.log( + f"[bold]{repo['name']}[/bold] step {idx + 1} check {status}: {cr['command']}" + ) + + # Retry loop: give the agent a chance to fix failures + attempt = 0 + while not passed and attempt < max_check_retries: + attempt += 1 + failure_output = _format_check_failures(check_results) + console.log( + f"[bold yellow]{repo['name']}[/bold yellow] step {idx + 1} " + f"checks failed, retry {attempt}/{max_check_retries}" + ) + async with progress_lock: + info = progress_state[repo["name"]] + info["current"] = ( + f"{step.name} (fix {attempt}/{max_check_retries})" + ) + info["updated"] = time.time() + await _emit_progress( + progress_state, progress_lock, max_parallel + ) + + fix = _fix_prompt( + repo=repo, + step=step, + check_output=failure_output, + attempt=attempt, + max_attempts=max_check_retries, + ) + result = await _ainvoke_with_heartbeat( + agent=agent, + prompt=fix, + recursion_limit=recursion_limit, + repo_name=repo["name"], + step_name=f"{step.name} (fix {attempt})", + timeout_sec=step_timeout_sec, + progress_state=progress_state, + progress_lock=progress_lock, + ) + fix_tokens = await _accumulate_tokens( + agent, progress_state, repo["name"], progress_lock + ) + step_total_tokens["input_tokens"] += fix_tokens["input_tokens"] + step_total_tokens["output_tokens"] += fix_tokens[ + "output_tokens" + ] + step_total_tokens["total_tokens"] += fix_tokens["total_tokens"] + summary = result["messages"][-1].text + last_summary = summary + + check_results = await _run_checks(repo, workspace) + passed = _checks_passed(check_results) + for cr in check_results: + status = ( + "[green]pass[/green]" + if cr["exit_code"] == 0 + else "[red]FAIL[/red]" + ) + console.log( + f"[bold]{repo['name']}[/bold] step {idx + 1} retry {attempt} " + f"check {status}: {cr['command']}" + ) + + all_check_results = check_results # keep latest results + + if not passed: + console.log( + f"[bold red]{repo['name']}[/bold red] step {idx + 1} " + f"checks still failing after {max_check_retries} retries, " + f"continuing to next step" + ) + + step_record = _step_token_record( + step_index=idx + 1, + step_name=step.name, + status="completed", + tokens=step_total_tokens, + ) + step_token_deltas.append(step_record) + last_step_tokens = step_record + + (step_outputs_dir / f"step_{idx + 1}.md").write_text( + summary + "\n\n" + _format_step_token_usage(step_total_tokens), + encoding="utf-8", + ) + async with progress_lock: + info = progress_state[repo["name"]] + info["last_step_tokens"] = last_step_tokens + info["step_token_deltas"] = step_token_deltas + repo_tokens = _repo_token_snapshot(info) + save_json_file( + progress_path, + { + "next_index": idx + 1, + "plan_hash": plan_hash, + "last_summary": summary, + "state": "running", + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + }, + ) + step_tok_str = ( + _fmt_tokens(step_total_tokens["total_tokens"]) + if step_total_tokens["total_tokens"] + else "" + ) + console.log( + f"[bold]{repo['name']}[/bold] step {idx + 1}/{len(steps)} " + f"[green]complete[/green]: {step.name}" + + (f" [dim]({step_tok_str} tokens)[/dim]" if step_tok_str else "") + ) + + async with progress_lock: + info = progress_state[repo["name"]] + info.update({ + "state": "done", + "step": len(steps), + "current": None, + "updated": time.time(), + }) + # Signal that this repo is done so dependent repos can proceed + if repo_done_events and repo["name"] in repo_done_events: + repo_done_events[repo["name"]].set() + await _emit_progress(progress_state, progress_lock, max_parallel) + + # Final check run (catches anything the last step may have missed) + if has_checks: + all_check_results = await _run_checks(repo, workspace) + for cr in all_check_results: + status = ( + "[green]pass[/green]" + if cr["exit_code"] == 0 + else "[red]FAIL[/red]" + ) + console.log( + f"[bold]{repo['name']}[/bold] final check {status}: {cr['command']}" + ) + else: + console.log( + f"[bold]{repo['name']}[/bold] [dim]no checks configured[/dim]" + ) + + async with progress_lock: + info = progress_state[repo["name"]] + repo_tokens = _repo_token_snapshot(info) + + save_json_file( + progress_path, + { + "next_index": len(steps), + "plan_hash": plan_hash, + "last_summary": last_summary, + "state": "done", + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + }, + ) + + return { + "repo": repo["name"], + "steps": len(steps), + "checks": all_check_results, + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "state": "done", + } + + +async def _run_parallel( + repo_steps: dict[str, list[RepoStep]], + repos: list[dict], + problem: str, + workspace: Path, + models_cfg: dict, + executor_model: str, + recursion_limit: int, + max_parallel: int, + resume: bool, + status_interval_sec: int, + resume_dir: Path | None, + resume_files: dict[str, Path], + max_check_retries: int = 2, + step_timeout_sec: int = 0, + timeout_mode: str = "pause", + skip_failed_repos: bool = False, +) -> list[dict]: + sem = asyncio.Semaphore(max(1, max_parallel)) + repo_lookup = {repo["name"]: repo for repo in repos} + progress_state = _init_progress(repo_steps) + progress_lock = asyncio.Lock() + + # Create an event per repo so dependents can wait for completion + repo_done_events: dict[str, asyncio.Event] = { + name: asyncio.Event() for name in repo_steps + } + + # Log dependency info + dep_pairs: list[str] = [] + for name, steps in repo_steps.items(): + for step in steps: + dep_pairs.extend( + f"{name} -> {dep}" for dep in step.depends_on_repos + ) + if dep_pairs: + console.print( + Panel( + "\n".join(sorted(set(dep_pairs))), + title="[bold]Repo dependencies[/bold]", + border_style="cyan", + expand=False, + ) + ) + + async def status_loop(stop_event: asyncio.Event): + if status_interval_sec <= 0: + return + while not stop_event.is_set(): + try: + await asyncio.wait_for( + stop_event.wait(), timeout=status_interval_sec + ) + break + except asyncio.TimeoutError: + pass + try: + await _emit_progress( + progress_state, progress_lock, max_parallel + ) + except Exception as exc: + console.log( + "[bold yellow]status reporter encountered render error; " + f"continuing shutdown safely:[/bold yellow] {exc}" + ) + return + + async def run_one(repo_name: str, steps: list[RepoStep]) -> dict: + async with sem: + repo = repo_lookup[repo_name] + llm = setup_llm( + model_choice=executor_model, + models_cfg=models_cfg, + agent_name="executor", + ) + plan_hash = hash_plan(steps) + checkpointer, ckpt_conn, ckpt_path = await _repo_checkpointer( + workspace, repo_name + ) + thread_id = f"{repo_name}-{plan_hash[:8]}" + console.log(f"[dim]{repo_name}[/dim] checkpoint db: {ckpt_path}") + try: + return await _run_repo_steps( + repo=repo, + steps=steps, + problem=problem, + workspace=workspace, + llm=llm, + recursion_limit=recursion_limit, + resume=resume, + progress_state=progress_state, + progress_lock=progress_lock, + max_parallel=max_parallel, + resume_dir=resume_dir, + resume_files=resume_files, + max_check_retries=max_check_retries, + repo_done_events=repo_done_events, + step_timeout_sec=step_timeout_sec, + timeout_mode=timeout_mode, + checkpointer=checkpointer, + thread_id=thread_id, + ) + except asyncio.CancelledError: + async with progress_lock: + info = progress_state[repo_name] + if info.get("state") in {"running", "blocked"}: + info.update({ + "state": "failed", + "error": "Cancelled due to run interruption", + "updated": time.time(), + }) + if repo_name in repo_done_events: + repo_done_events[repo_name].set() + await _emit_progress( + progress_state, progress_lock, max_parallel + ) + raise + except Exception as exc: + async with progress_lock: + info = progress_state[repo_name] + info.update({ + "state": "failed", + "error": str(exc), + "updated": time.time(), + }) + if repo_name in repo_done_events: + repo_done_events[repo_name].set() + await _emit_progress( + progress_state, progress_lock, max_parallel + ) + if skip_failed_repos: + repo_tokens = _repo_token_snapshot(info) + last_step_tokens = info.get("last_step_tokens") + step_token_deltas = list( + info.get("step_token_deltas") or [] + ) + console.log( + f"[bold yellow]{repo_name}[/bold yellow] failed; " + f"continuing due to skip_failed_repos (checkpoint: {ckpt_path})" + ) + return { + "repo": repo_name, + "steps": 0, + "checks": [], + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "state": "failed", + "error": str(exc), + } + raise + finally: + try: + await ckpt_conn.close() + console.log(f"[dim]{repo_name}[/dim] checkpoint db closed") + except Exception as close_exc: + console.log( + f"[bold yellow]{repo_name}[/bold yellow] could not close " + f"checkpoint db: {close_exc}" + ) + + await _emit_progress(progress_state, progress_lock, max_parallel) + stop_event = asyncio.Event() + reporter = asyncio.create_task(status_loop(stop_event)) + try: + repo_names = list(repo_steps) + tasks = [ + asyncio.create_task(run_one(name, repo_steps[name])) + for name in repo_names + ] + raw_results = await asyncio.gather(*tasks, return_exceptions=True) + console.log( + f"[dim]runner gather complete: {len(raw_results)} repo result(s)[/dim]" + ) + + results: list[dict] = [] + for repo_name, raw in zip(repo_names, raw_results): + if isinstance(raw, dict): + results.append(raw) + continue + + if isinstance(raw, asyncio.CancelledError): + error = "Cancelled due to run interruption" + elif isinstance(raw, BaseException): + error = str(raw) or raw.__class__.__name__ + else: + error = "Unknown failure" + + async with progress_lock: + info = progress_state[repo_name] + if info.get("state") not in {"done", "paused", "failed"}: + info.update({ + "state": "failed", + "error": error, + "updated": time.time(), + }) + repo_tokens = _repo_token_snapshot(info) + last_step_tokens = info.get("last_step_tokens") + step_token_deltas = list(info.get("step_token_deltas") or []) + + if repo_name in repo_done_events: + repo_done_events[repo_name].set() + + results.append({ + "repo": repo_name, + "steps": 0, + "checks": [], + "tokens": repo_tokens, + "last_step_tokens": last_step_tokens, + "step_token_deltas": step_token_deltas, + "state": "failed", + "error": error, + }) + + await _emit_progress(progress_state, progress_lock, max_parallel) + console.log("[dim]runner returning structured results[/dim]") + return results + finally: + console.log("[dim]runner stopping status reporter[/dim]") + stop_event.set() + reporter.cancel() + try: + await asyncio.wait_for(reporter, timeout=3) + except asyncio.TimeoutError: + console.log( + "[bold yellow]status reporter did not stop within 3s; " + "continuing shutdown[/bold yellow]" + ) + except asyncio.CancelledError: + pass + except Exception as exc: + console.log( + "[bold yellow]status reporter exited with error during " + f"shutdown:[/bold yellow] {exc}" + ) + console.log("[dim]runner status reporter stopped[/dim]") + + +def main(): + parser = argparse.ArgumentParser( + description="Multi-repo plan/execute runner" + ) + parser.add_argument( + "--config", + required=True, + help="YAML config for planning and multi-repo execution", + ) + parser.add_argument( + "--workspace", + required=False, + help="Workspace directory for artifacts and logs", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing progress files in the workspace", + ) + parser.add_argument( + "--resume-from", + action="append", + dest="resume_from", + help=( + "Path to a repo progress file (repeatable) or a directory containing" + " progress files" + ), + ) + parser.add_argument( + "--interactive-timeout", + type=int, + default=60, + help="Seconds to wait for interactive resume prompts (0 disables)", + ) + parser.add_argument( + "--timeout-mode", + choices=["pause", "skip", "fail"], + default=None, + help="On step timeout: pause (default), skip step, or fail repo", + ) + parser.add_argument( + "--skip-failed-repos", + action="store_true", + help="Continue other repos if one fails", + ) + args = parser.parse_args() + + cfg = load_yaml_config(args.config) + initial_yaml_config = _namespace_to_dict(cfg) + config_dir = Path(args.config).parent.resolve() + + project = getattr(cfg, "project", "multi_repo_run") + problem = getattr(cfg, "problem", "").strip() + if not problem: + console.print( + "[bold red]Config must include a non-empty 'problem' field.[/bold red]" + ) + sys.exit(2) + + raw_repos = getattr(cfg, "repos", None) + if not raw_repos: + console.print( + "[bold red]Config must include a 'repos' list.[/bold red]" + ) + sys.exit(2) + + workspace = _resolve_workspace(args.workspace, project) + repos = _resolve_repos(raw_repos, config_dir, workspace) + + # -- Workspace banner -- + repo_lines = "\n".join( + f" [bold]{r['name']}[/bold] ({r.get('language', 'generic')})" + + (f" - {r['description']}" if r.get("description") else "") + for r in repos + ) + console.print( + Panel( + f"[bold bright_blue]{workspace}[/bold bright_blue]\n\n" + f"[bold]Repos:[/bold]\n{repo_lines}", + title="[bold green]MULTI-REPO WORKSPACE[/bold green]", + border_style="bright_magenta", + padding=(1, 2), + ) + ) + + with console.status("[bold green]Checking out repos...", spinner="point"): + for repo in repos: + _ensure_checkout(repo) + _ensure_repo_symlink(workspace, repo) + + models_cfg = getattr(cfg, "models", {}) or {} + default_model = (models_cfg.get("default") or None) or ( + models_cfg.get("choices") or ["openai:gpt-5-mini"] + )[0] + planner_model = models_cfg.get("planner") or default_model + executor_model = models_cfg.get("executor") or default_model + console.print( + Panel( + f"[bold]Planner model:[/bold] [cyan]{planner_model}[/cyan] " + f"[bold]Executor model:[/bold] [cyan]{executor_model}[/cyan]", + border_style="cyan", + expand=False, + ) + ) + + # -- Validate models before starting real work -- + with console.status("[bold green]Validating models...", spinner="point"): + planner_llm = setup_llm( + model_choice=planner_model, + models_cfg=models_cfg, + agent_name="planner", + ) + _validate_model(planner_llm, planner_model, "planner") + console.log( + f"[green]✓[/green] planner model [cyan]{planner_model}[/cyan]" + ) + + if executor_model != planner_model: + executor_test_llm = setup_llm( + model_choice=executor_model, + models_cfg=models_cfg, + agent_name="executor", + ) + _validate_model(executor_test_llm, executor_model, "executor") + console.log( + f"[green]✓[/green] executor model [cyan]{executor_model}[/cyan]" + ) + else: + console.log("[green]✓[/green] executor model (same as planner)") + + planner_cfg = getattr(cfg, "planner", {}) or {} + reflection_steps = int(planner_cfg.get("reflection_steps", 0)) + research_cfg = planner_cfg.get("research") or {} + + # -- Problem statement -- + console.print( + Panel( + Text.from_markup(f"[bold cyan]Problem:[/bold cyan]\n{problem}"), + border_style="cyan", + ) + ) + + # -- Research phase -- + with console.status("[bold green]Gathering research...", spinner="point"): + research = asyncio.run( + _gather_research( + llm=planner_llm, + workspace=workspace, + research_cfg=research_cfg, + problem=problem, + repos=repos, + ) + ) + if research: + console.print("[green]Research complete.[/green]") + else: + console.print("[dim]No research context gathered.[/dim]") + + # -- Planning phase -- + with console.status( + f"[bold green]Planning across {len(repos)} repos " + f"(reflection steps: {reflection_steps})...", + spinner="point", + ): + plan = asyncio.run( + _plan( + llm=planner_llm, + problem=problem, + repos=repos, + research=research, + reflection_steps=reflection_steps, + ) + ) + + _validate_plan_repos(plan, repos) + _write_plan(workspace, plan) + _render_repo_plan(plan) + repo_steps = _group_steps_by_repo(plan) + + missing = sorted({repo["name"] for repo in repos} - set(repo_steps)) + if missing: + console.print( + Panel( + "[bold yellow]Plan includes no steps for:[/bold yellow] " + + ", ".join(missing), + border_style="yellow", + expand=False, + ) + ) + + exec_cfg = getattr(cfg, "execution", {}) or {} + max_parallel = int(exec_cfg.get("max_parallel", len(repo_steps))) + recursion_limit = int(exec_cfg.get("recursion_limit", 2000)) + resume = bool(exec_cfg.get("resume", False)) + status_interval_sec = int(exec_cfg.get("status_interval_sec", 5)) + max_check_retries = int(exec_cfg.get("max_check_retries", 2)) + step_timeout_sec = int(exec_cfg.get("step_timeout_sec", 0)) # 0 = no limit + timeout_mode = exec_cfg.get("timeout_mode", "pause") + skip_failed_repos = bool(exec_cfg.get("skip_failed_repos", False)) + resume_dir = None + resume_files: dict[str, Path] = {} + + if args.resume_from: + resume = True + resume_dir, resume_files = _parse_resume_overrides( + args.resume_from, config_dir + ) + + if args.resume: + resume = True + + if args.timeout_mode: + timeout_mode = args.timeout_mode + if args.skip_failed_repos: + skip_failed_repos = True + + if resume and not args.resume_from: + resume_dir = _choose_resume_dir( + workspace, timeout=args.interactive_timeout + ) + if resume_dir is None: + resume = False + + unknown_resume = sorted( + set(resume_files) - {repo["name"] for repo in repos} + ) + if unknown_resume: + raise RuntimeError( + "Resume checkpoints do not match repos: " + + ", ".join(unknown_resume) + ) + + # -- Execution banner -- + console.rule("[bold cyan]Execution") + if resume: + console.print( + Panel( + "[bold yellow]Resuming[/bold yellow] from saved progress" + + (f" (dir: {resume_dir})" if resume_dir else ""), + border_style="yellow", + expand=False, + ) + ) + console.print( + Panel( + f"[bold]Checkpoints:[/bold] {workspace / 'checkpoints'}", + border_style="cyan", + expand=False, + ) + ) + timeout_str = f"{step_timeout_sec}s" if step_timeout_sec else "none" + + run_context = { + "config_path": str(Path(args.config).resolve()), + "initial_yaml_config": initial_yaml_config, + "effective_runtime": { + "workspace": str(workspace), + "planner_model": planner_model, + "executor_model": executor_model, + "max_parallel": max_parallel, + "recursion_limit": recursion_limit, + "resume": resume, + "status_interval_sec": status_interval_sec, + "max_check_retries": max_check_retries, + "step_timeout_sec": step_timeout_sec, + "timeout_mode": timeout_mode, + "skip_failed_repos": skip_failed_repos, + }, + "cli_args": { + "resume": args.resume, + "resume_from": args.resume_from, + "interactive_timeout": args.interactive_timeout, + "timeout_mode": args.timeout_mode, + "skip_failed_repos": args.skip_failed_repos, + }, + } + run_context_path = workspace / "run_context.json" + run_context_path.write_text(json.dumps(run_context, indent=2)) + + console.print( + f"[bold]Parallel workers:[/bold] {max_parallel} " + f"[bold]Status interval:[/bold] {status_interval_sec}s " + f"[bold]Check retries:[/bold] {max_check_retries} " + f"[bold]Step timeout:[/bold] {timeout_str} " + f"[bold]Timeout mode:[/bold] {timeout_mode} " + f"[bold]Skip failed repos:[/bold] {skip_failed_repos} " + f"[bold]Repos:[/bold] {len(repo_steps)}" + ) + + console.log("[dim]main starting parallel execution[/dim]") + results = asyncio.run( + _run_parallel( + repo_steps=repo_steps, + repos=repos, + problem=problem, + workspace=workspace, + models_cfg=models_cfg, + executor_model=executor_model, + recursion_limit=recursion_limit, + max_parallel=max_parallel, + resume=resume, + status_interval_sec=status_interval_sec, + resume_dir=resume_dir, + resume_files=resume_files, + max_check_retries=max_check_retries, + step_timeout_sec=step_timeout_sec, + timeout_mode=timeout_mode, + skip_failed_repos=skip_failed_repos, + ) + ) + console.log("[dim]main parallel execution returned[/dim]") + + summary_path = workspace / "run_summary.json" + console.log(f"[dim]main writing summary to {summary_path}[/dim]") + summary_path.write_text(json.dumps(results, indent=2)) + console.log("[dim]main summary write complete[/dim]") + + # -- Final summary -- + console.rule("[bold cyan]Run complete") + result_lines = [] + grand_in = 0 + grand_out = 0 + grand_total = 0 + for r in results: + name = r.get("repo", "?") + steps = r.get("steps", 0) + checks = r.get("checks") or [] + tokens = r.get("tokens") or {} + state = r.get("state", "done") + error = r.get("error") + r_in = tokens.get("input_tokens", 0) + r_out = tokens.get("output_tokens", 0) + r_total = tokens.get("total_tokens", 0) + grand_in += r_in + grand_out += r_out + grand_total += r_total + check_ok = all(c.get("exit_code", 1) == 0 for c in checks) + check_text = ( + "[green]passed[/green]" + if check_ok and checks + else "[red]failures[/red]" + if checks + else "[dim]none[/dim]" + ) + tok_text = ( + f"tokens: {_fmt_tokens(r_in)} in / {_fmt_tokens(r_out)} out" + if r_total + else "tokens: [dim]0[/dim]" + ) + state_text = ( + f"state: [yellow]{state}[/yellow]" + if state != "done" + else "state: [green]done[/green]" + ) + error_text = f" ([red]{error[:60]}[/red])" if error else "" + result_lines.append( + f" [bold]{name}[/bold]: {steps} steps, checks: {check_text}, {tok_text}, {state_text}{error_text}" + ) + result_lines.append("") + result_lines.append( + f" [bold]Total tokens:[/bold] [yellow]{_fmt_tokens(grand_total)}[/yellow] " + f"({_fmt_tokens(grand_in)} in / {_fmt_tokens(grand_out)} out)" + ) + console.print( + Panel( + "\n".join(result_lines) + + ( + f"\n\n[dim]Summary written to {summary_path}[/dim]" + f"\n[dim]Run context written to {run_context_path}[/dim]" + ), + title="[bold green]RESULTS[/bold green]", + border_style="bright_magenta", + padding=(1, 2), + ) + ) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + console.print( + "[bold yellow]Interrupted by user (Ctrl+C). " + "Exiting gracefully.[/bold yellow]" + ) + raise SystemExit(130) diff --git a/examples/two_agent_examples/plan_execute/quantum_Rabi_QuTiP.py b/examples/two_agent_examples/plan_execute/quantum_Rabi_QuTiP.py index 008313f8..cf5cc7ad 100644 --- a/examples/two_agent_examples/plan_execute/quantum_Rabi_QuTiP.py +++ b/examples/two_agent_examples/plan_execute/quantum_Rabi_QuTiP.py @@ -51,7 +51,7 @@ def main(): return final_results except Exception as e: - print(f"Error in example: {str(e)}") + print(f"Error in example: {e!s}") import traceback traceback.print_exc() diff --git a/examples/two_agent_examples/plan_execute/scrabble.py b/examples/two_agent_examples/plan_execute/scrabble.py index ecb8abaf..8a29fa4b 100644 --- a/examples/two_agent_examples/plan_execute/scrabble.py +++ b/examples/two_agent_examples/plan_execute/scrabble.py @@ -201,7 +201,7 @@ def main(mode: str): return answer except Exception as e: - print(f"Error: {str(e)}") + print(f"Error: {e!s}") import traceback traceback.print_exc() diff --git a/mkdocs.yml b/mkdocs.yml index 8091dc91..9a2b1018 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ nav: - arXiv Agent: arxiv_agent.md - DSIAgent: dsi_agent.md - Execution Agent: execution_agent.md + - Git Go Agent: git_go_agent.md - Hypothesizer Agent: hypothesizer_agent.md - Planning Agent: planning_agent.md - Web Search Agent: web_search_agent.md diff --git a/src/ursa/agents/__init__.py b/src/ursa/agents/__init__.py index c5e35fef..3fcbc7b7 100644 --- a/src/ursa/agents/__init__.py +++ b/src/ursa/agents/__init__.py @@ -13,8 +13,11 @@ "CodeReviewAgent": (".code_review_agent", "CodeReviewAgent"), "DSIAgent": (".dsi_agent", "DSIAgent"), "ExecutionAgent": (".execution_agent", "ExecutionAgent"), + "GitAgent": (".git_agent", "GitAgent"), + "GitGoAgent": (".git_go_agent", "GitGoAgent"), "HypothesizerAgent": (".hypothesizer_agent", "HypothesizerAgent"), "LammpsAgent": (".lammps_agent", "LammpsAgent"), + "make_git_agent": (".git_agent", "make_git_agent"), "MaterialsProjectAgent": (".mp_agent", "MaterialsProjectAgent"), "PlanningAgent": (".planning_agent", "PlanningAgent"), "RAGAgent": (".rag_agent", "RAGAgent"), diff --git a/src/ursa/agents/git_agent.py b/src/ursa/agents/git_agent.py new file mode 100644 index 00000000..8da143ed --- /dev/null +++ b/src/ursa/agents/git_agent.py @@ -0,0 +1,133 @@ +"""Git-aware coding agent with pluggable language support.""" + +from __future__ import annotations + +import logging + +from langchain.chat_models import BaseChatModel +from langchain_core.tools import BaseTool + +from ursa.agents.execution_agent import ExecutionAgent +from ursa.prompt_library.git_prompts import compose_git_prompt + +# Lazy import to avoid circular deps at module level +from ursa.prompt_library.go_prompts import go_language_prompt +from ursa.tools.git_tools import GIT_TOOLS +from ursa.tools.go_tools import GO_TOOLS +from ursa.tools.write_code_tool import write_code_with_repo + +LANGUAGE_REGISTRY: dict[str, dict] = { + "generic": { + "tools": None, + "prompt": None, + "safe_codes": [], + }, + "go": { + "tools": GO_TOOLS, + "prompt": go_language_prompt, + "safe_codes": ["go"], + }, + "markdown": { + "tools": None, + "prompt": None, + "safe_codes": ["markdown"], + }, +} + +LOGGER = logging.getLogger(__name__) + + +class GitAgent(ExecutionAgent): + """Execution agent with git tools and optional language-specific extensions. + + Use directly for language-agnostic git work, or pass ``language_tools``, + ``language_prompt``, and ``safe_codes`` for a language-specific variant. + """ + + def __init__( + self, + llm: BaseChatModel, + language_tools: list[BaseTool] | None = None, + language_prompt: str | None = None, + safe_codes: list[str] | None = None, + **kwargs, + ): + extra_tools: list[BaseTool] = [*GIT_TOOLS, write_code_with_repo] + if language_tools: + extra_tools.extend(language_tools) + + super().__init__( + llm=llm, + extra_tools=extra_tools, + safe_codes=safe_codes or [], + **kwargs, + ) + + self.executor_prompt = compose_git_prompt(language_prompt or "") + + self.remove_tool([ + "run_command", + "run_web_search", + "run_osti_search", + "run_arxiv_search", + ]) + + +def make_git_agent( + llm: BaseChatModel, + language: str | None = None, + language_tools: list[BaseTool] | None = None, + language_prompt: str | None = None, + safe_codes: list[str] | None = None, + **kwargs, +) -> GitAgent: + """Create a GitAgent, optionally with language-specific tools and prompts. + + Args: + llm: The language model to use. + language: Optional language name for registry lookup. If provided and + found in LANGUAGE_REGISTRY, its tools/prompt/safe_codes are used + as defaults (overridable by explicit parameters). Unknown languages + are logged and ignored, defaulting to git-only agent. + language_tools: Explicit language tools to add. Overrides registry. + language_prompt: Explicit language prompt. Overrides registry. + safe_codes: Explicit safe code list. Overrides registry. + **kwargs: Passed to GitAgent constructor. + + Returns: + A GitAgent configured with git tools and optionally language-specific + extensions. Works with any file type without requiring explicit + language registration. + """ + # Start with explicit parameters (highest priority) + tools = language_tools + prompt = language_prompt + codes = safe_codes + + # Fill in from registry if language is provided and found + if ( + language + and language not in (tools or []) + and language in LANGUAGE_REGISTRY + ): + config = LANGUAGE_REGISTRY[language] + if tools is None: + tools = config.get("tools") + if prompt is None: + prompt = config.get("prompt") + if codes is None: + codes = config.get("safe_codes") + elif language and language not in LANGUAGE_REGISTRY: + LOGGER.debug( + "Language %r not in registry; using git-only agent. Available: %s", + language, + sorted(LANGUAGE_REGISTRY), + ) + + return GitAgent( + llm=llm, + language_tools=tools, + language_prompt=prompt, + safe_codes=codes or [], + **kwargs, + ) diff --git a/src/ursa/agents/git_go_agent.py b/src/ursa/agents/git_go_agent.py new file mode 100644 index 00000000..2bcf0c98 --- /dev/null +++ b/src/ursa/agents/git_go_agent.py @@ -0,0 +1,29 @@ +"""Git-aware Go coding agent -- backward-compatible wrapper around GitAgent.""" + +from langchain.chat_models import BaseChatModel + +from ursa.agents.git_agent import GitAgent +from ursa.prompt_library.go_prompts import go_language_prompt +from ursa.tools.go_tools import GO_TOOLS + + +class GitGoAgent(GitAgent): + """Execution agent specialized for git-managed Go repositories. + + Tools: + - Git: status, diff, log, ls-files, add, commit, switch, create_branch + - Go: build, test, vet, mod tidy, linting (golangci-lint with .golangci.yml support) + - Code formatting: gofmt + + This is a convenience subclass of :class:`GitAgent` with the Go language + tools and prompt pre-configured. + """ + + def __init__(self, llm: BaseChatModel, **kwargs): + super().__init__( + llm=llm, + language_tools=GO_TOOLS, + language_prompt=go_language_prompt, + safe_codes=["go"], + **kwargs, + ) diff --git a/src/ursa/prompt_library/git_go_prompts.py b/src/ursa/prompt_library/git_go_prompts.py new file mode 100644 index 00000000..3ad551cc --- /dev/null +++ b/src/ursa/prompt_library/git_go_prompts.py @@ -0,0 +1,4 @@ +from ursa.prompt_library.git_prompts import compose_git_prompt +from ursa.prompt_library.go_prompts import go_language_prompt + +git_go_executor_prompt = compose_git_prompt(go_language_prompt) diff --git a/src/ursa/prompt_library/git_prompts.py b/src/ursa/prompt_library/git_prompts.py new file mode 100644 index 00000000..0591a74b --- /dev/null +++ b/src/ursa/prompt_library/git_prompts.py @@ -0,0 +1,29 @@ +git_base_prompt = """ +You are a coding agent working with git-managed repositories. + +Your responsibilities are as follows: + +1. Inspect existing files before changing them. +2. Use the git tools for repository operations (status, diff, log, add, commit, branch). +3. Use the file tools to read and update source files, keeping changes minimal and consistent. +4. Clearly document actions taken, including files changed and git operations performed. + +Constraints: +- Only operate inside the workspace and its subdirectories. +- Avoid destructive git commands (reset --hard, clean -fd, force push). +- Prefer small, reviewable diffs. +""" + + +def compose_git_prompt(*language_sections: str) -> str: + """Combine the git base prompt with language-specific sections. + + Each language_section is appended as a paragraph after the base prompt. + """ + parts = [git_base_prompt.strip()] + parts.extend( + section.strip() + for section in language_sections + if section and section.strip() + ) + return "\n\n".join(parts) diff --git a/src/ursa/prompt_library/go_prompts.py b/src/ursa/prompt_library/go_prompts.py new file mode 100644 index 00000000..e75a7c08 --- /dev/null +++ b/src/ursa/prompt_library/go_prompts.py @@ -0,0 +1,7 @@ +go_language_prompt = """ +Language-specific instructions (Go): +- Run gofmt on modified .go files when appropriate. +- Use go build, go test, go vet for validation. +- Use go mod tidy to clean up dependencies when needed. +- Run golangci-lint if a .golangci.yml config is present. +""" diff --git a/src/ursa/tools/__init__.py b/src/ursa/tools/__init__.py index 17e72e7d..ab85d2d9 100644 --- a/src/ursa/tools/__init__.py +++ b/src/ursa/tools/__init__.py @@ -4,3 +4,4 @@ from .run_command_tool import run_command as run_command from .write_code_tool import edit_code as edit_code from .write_code_tool import write_code as write_code +from .write_code_tool import write_code_with_repo as write_code_with_repo diff --git a/src/ursa/tools/git_tools.py b/src/ursa/tools/git_tools.py new file mode 100644 index 00000000..471d0387 --- /dev/null +++ b/src/ursa/tools/git_tools.py @@ -0,0 +1,221 @@ +import subprocess +from collections.abc import Iterable +from pathlib import Path + +from langchain.tools import ToolRuntime +from langchain_core.tools import tool + +from ursa.agents.base import AgentContext +from ursa.util.types import AsciiStr + +# Git commands are typically instant; timeout indicates hanging (waiting for input or wrong directory) +GIT_TIMEOUT = 30 # seconds - git ops should be near-instant + + +def _format_result(stdout: str | None, stderr: str | None) -> str: + return f"STDOUT:\n{stdout or ''}\nSTDERR:\n{stderr or ''}" + + +def _repo_path( + repo_path: str | None, runtime: ToolRuntime[AgentContext] +) -> Path: + base = Path(runtime.context.workspace).absolute() + if not repo_path: + candidate = base + else: + candidate = Path(repo_path) + if not candidate.is_absolute(): + candidate = base / candidate + candidate = candidate.absolute() + + try: + candidate.relative_to(base) + except ValueError as exc: + raise ValueError("repo_path must resolve inside the workspace") from exc + + return candidate + + +def _run_git(repo: Path, args: Iterable[str]) -> str: + try: + result = subprocess.run( + ["git", "-C", str(repo), *list(args)], + text=True, + capture_output=True, + timeout=GIT_TIMEOUT, + check=False, + ) + except Exception as exc: # noqa: BLE001 + return _format_result("", f"Error running git: {exc}") + + return _format_result(result.stdout, result.stderr) + + +def _check_ref_format(repo: Path, branch: str) -> str | None: + result = subprocess.run( + ["git", "-C", str(repo), "check-ref-format", "--branch", branch], + text=True, + capture_output=True, + timeout=GIT_TIMEOUT, + check=False, + ) + if result.returncode != 0: + return result.stderr or result.stdout or "Invalid branch name for git" + return None + + +@tool +def git_status( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, +) -> str: + """Return git status for a repository inside the workspace. + + Args: + repo_path: Path to repository relative to workspace. If None, uses workspace root. + Recommended to always specify a repo_path to avoid large untracked file lists. + """ + repo = _repo_path(repo_path, runtime) + + # Warn if using workspace root without explicit repo_path + workspace = Path(runtime.context.workspace).absolute() + if repo_path is None and repo == workspace: + return ( + "WARNING: git_status called on workspace root without specifying repo_path. " + "This may show many untracked files. " + "Please specify a specific repository path (e.g., repo_path='my-project'). " + "Use list_directory to see available repositories in the workspace first." + ) + + result = _run_git(repo, ["status", "-sb"]) + + # Limit output size for very large untracked file lists + if len(result) > 10000: + lines = result.split("\n") + if len(lines) > 100: + return ( + f"Git status output too large ({len(lines)} lines). " + f"Showing first 50 and last 50 lines:\n" + f"{''.join(lines[:50])}\n" + f"... ({len(lines) - 100} lines omitted) ...\n" + f"{''.join(lines[-50:])}\n" + f"Consider using git_status on a specific subdirectory." + ) + + return result + + +@tool +def git_diff( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, + staged: bool = False, + pathspecs: list[AsciiStr] | None = None, +) -> str: + """Return git diff for a repository inside the workspace.""" + repo = _repo_path(repo_path, runtime) + args = ["diff"] + if staged: + args.append("--staged") + if pathspecs: + args.append("--") + args.extend(list(pathspecs)) + return _run_git(repo, args) + + +@tool +def git_log( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, + limit: int = 20, +) -> str: + """Return recent git log entries for a repository.""" + repo = _repo_path(repo_path, runtime) + limit = max(1, int(limit)) + return _run_git(repo, ["log", f"-n{limit}", "--oneline", "--decorate"]) + + +@tool +def git_ls_files( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, + pathspecs: list[AsciiStr] | None = None, +) -> str: + """List tracked files, optionally filtered by pathspecs.""" + repo = _repo_path(repo_path, runtime) + args = ["ls-files"] + if pathspecs: + args.append("--") + args.extend(list(pathspecs)) + return _run_git(repo, args) + + +@tool +def git_add( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, + pathspecs: list[AsciiStr] | None = None, +) -> str: + """Stage files for commit using git add.""" + repo = _repo_path(repo_path, runtime) + if not pathspecs: + return _format_result("", "No pathspecs provided to git_add") + return _run_git(repo, ["add", "--", *list(pathspecs)]) + + +@tool +def git_commit( + runtime: ToolRuntime[AgentContext], + message: AsciiStr, + repo_path: AsciiStr | None = None, +) -> str: + """Create a git commit with the provided message.""" + repo = _repo_path(repo_path, runtime) + if not message.strip(): + return _format_result("", "Commit message must not be empty") + return _run_git(repo, ["commit", "--message", message]) + + +@tool +def git_switch( + runtime: ToolRuntime[AgentContext], + branch: AsciiStr, + repo_path: AsciiStr | None = None, + create: bool = False, +) -> str: + """Switch branches using git switch (optionally create).""" + repo = _repo_path(repo_path, runtime) + err = _check_ref_format(repo, branch) + if err: + return _format_result("", err) + args = ["switch"] + if create: + args.append("-c") + args.append(branch) + return _run_git(repo, args) + + +@tool +def git_create_branch( + runtime: ToolRuntime[AgentContext], + branch: AsciiStr, + repo_path: AsciiStr | None = None, +) -> str: + """Create a branch without switching to it.""" + repo = _repo_path(repo_path, runtime) + err = _check_ref_format(repo, branch) + if err: + return _format_result("", err) + return _run_git(repo, ["branch", branch]) + + +GIT_TOOLS = [ + git_status, + git_diff, + git_log, + git_ls_files, + git_add, + git_commit, + git_switch, + git_create_branch, +] diff --git a/src/ursa/tools/go_tools.py b/src/ursa/tools/go_tools.py new file mode 100644 index 00000000..81f85ac4 --- /dev/null +++ b/src/ursa/tools/go_tools.py @@ -0,0 +1,219 @@ +import subprocess + +from langchain.tools import ToolRuntime +from langchain_core.tools import tool + +from ursa.agents.base import AgentContext +from ursa.tools.git_tools import _format_result, _repo_path +from ursa.util.types import AsciiStr + +# Differentiated timeouts by operation type +GO_FORMAT_TIMEOUT = 30 # seconds - gofmt is usually fast +GO_ANALYSIS_TIMEOUT = 60 # seconds - go vet, go mod tidy +GO_BUILD_TIMEOUT = ( + 300 # seconds (5 min) - builds can take time for large projects +) +GO_TEST_TIMEOUT = 600 # seconds (10 min) - test suites can be slow +LINT_TIMEOUT = 180 # seconds (3 min) - linting is moderate speed + + +@tool +def gofmt_files( + runtime: ToolRuntime[AgentContext], + paths: list[AsciiStr], + repo_path: AsciiStr | None = None, +) -> str: + """Format Go files in-place using gofmt.""" + if not paths: + return _format_result("", "No paths provided to gofmt_files") + if any(not str(p).endswith(".go") for p in paths): + return _format_result("", "gofmt_files only accepts .go files") + repo = _repo_path(repo_path, runtime) + try: + result = subprocess.run( + ["gofmt", "-w", *list(paths)], + text=True, + capture_output=True, + timeout=GO_FORMAT_TIMEOUT, + cwd=repo, + check=False, + ) + except Exception as exc: + return _format_result("", f"Error running gofmt: {exc}") + return _format_result(result.stdout, result.stderr) + + +@tool +def go_build( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, +) -> str: + """Build a Go module using go build ./...""" + repo = _repo_path(repo_path, runtime) + try: + result = subprocess.run( + ["go", "build", "./..."], + text=True, + capture_output=True, + timeout=GO_BUILD_TIMEOUT, + cwd=repo, + check=False, + ) + except subprocess.TimeoutExpired: + return _format_result( + "", + f"go build timed out after {GO_BUILD_TIMEOUT}s (5 minutes). " + "Large builds may need to be run in smaller chunks.", + ) + except Exception as exc: + return _format_result("", f"Error running go build: {exc}") + return _format_result(result.stdout, result.stderr) + + +@tool +def go_test( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, + verbose: bool = True, +) -> str: + """Run Go tests using go test ./...""" + repo = _repo_path(repo_path, runtime) + args = ["go", "test"] + if verbose: + args.append("-v") + args.append("./...") + try: + result = subprocess.run( + args, + text=True, + capture_output=True, + timeout=GO_TEST_TIMEOUT, + cwd=repo, + check=False, + ) + except subprocess.TimeoutExpired: + return _format_result( + "", + f"go test timed out after {GO_TEST_TIMEOUT}s (10 minutes). " + "Large test suites may need to run selectively.", + ) + except Exception as exc: + return _format_result("", f"Error running go test: {exc}") + return _format_result(result.stdout, result.stderr) + + +@tool +def go_vet( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, +) -> str: + """Run Go vet for code pattern analysis using go vet ./...""" + repo = _repo_path(repo_path, runtime) + try: + result = subprocess.run( + ["go", "vet", "./..."], + text=True, + capture_output=True, + timeout=GO_ANALYSIS_TIMEOUT, + cwd=repo, + check=False, + ) + except subprocess.TimeoutExpired: + return _format_result( + "", f"go vet timed out after {GO_ANALYSIS_TIMEOUT}s." + ) + except Exception as exc: + return _format_result("", f"Error running go vet: {exc}") + return _format_result(result.stdout, result.stderr) + + +@tool +def go_mod_tidy( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, +) -> str: + """Clean up and validate Go module dependencies using go mod tidy.""" + repo = _repo_path(repo_path, runtime) + try: + result = subprocess.run( + ["go", "mod", "tidy"], + text=True, + capture_output=True, + timeout=GO_ANALYSIS_TIMEOUT, + cwd=repo, + check=False, + ) + except subprocess.TimeoutExpired: + return _format_result( + "", f"go mod tidy timed out after {GO_ANALYSIS_TIMEOUT}s." + ) + except Exception as exc: + return _format_result("", f"Error running go mod tidy: {exc}") + return _format_result(result.stdout, result.stderr) + + +@tool +def golangci_lint( + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, +) -> str: + """Run golangci-lint on the repository. + + Automatically detects and uses .golangci.yml if present, + otherwise uses sensible defaults. + """ + from ursa.tools.git_tools import GIT_TIMEOUT + + repo = _repo_path(repo_path, runtime) + + # Check if golangci-lint is installed + try: + subprocess.run( + ["golangci-lint", "--version"], + text=True, + capture_output=True, + timeout=GIT_TIMEOUT, + check=False, + ) + except FileNotFoundError: + return _format_result( + "", + "Error: golangci-lint is not installed. " + "Install it with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest", + ) + except Exception as exc: + return _format_result("", f"Error checking golangci-lint: {exc}") + + # Run golangci-lint + config_file = repo / ".golangci.yml" + args = ["golangci-lint", "run"] + if config_file.exists(): + args.extend(["--config", str(config_file)]) + + try: + result = subprocess.run( + args, + text=True, + capture_output=True, + timeout=LINT_TIMEOUT, + cwd=repo, + check=False, + ) + except subprocess.TimeoutExpired: + return _format_result( + "", f"golangci-lint timed out after {LINT_TIMEOUT}s (3 minutes)." + ) + except Exception as exc: + return _format_result("", f"Error running golangci-lint: {exc}") + + return _format_result(result.stdout, result.stderr) + + +GO_TOOLS = [ + gofmt_files, + go_build, + go_test, + go_vet, + go_mod_tidy, + golangci_lint, +] diff --git a/src/ursa/tools/write_code_tool.py b/src/ursa/tools/write_code_tool.py index ea1d8fcf..765afa76 100644 --- a/src/ursa/tools/write_code_tool.py +++ b/src/ursa/tools/write_code_tool.py @@ -1,3 +1,4 @@ +import os import time from pathlib import Path @@ -15,28 +16,49 @@ console = get_console() -@tool(description="Write source code to a file") -def write_code( - code: str, +def _resolve_repo_dir( + repo_path: AsciiStr, + workspace_dir: Path, + action: str, filename: AsciiStr, - runtime: ToolRuntime[AgentContext], -) -> str: - """Write source code to a file +) -> tuple[Path | None, str | None]: + repo = Path(repo_path) + if not repo.is_absolute(): + repo = workspace_dir / repo + repo = repo.resolve() + if not repo.exists(): + return ( + None, + f"Failed to {action} {filename}: Repository path not found.", + ) + if not repo.is_dir(): + return None, ( + f"Failed to {action} {filename}: Repository path is not a directory." + ) - Records successful file edits to the graph's store + return repo, None - Args: - code: The source code content to be written to disk. - filename: Name of the target file (including its extension). - """ - # Determine the full path to the target file +def _write_code_file( + code: str, + filename: AsciiStr, + runtime: ToolRuntime[AgentContext], + repo: Path | None = None, +) -> str: workspace_dir = runtime.context.workspace console.print("[cyan]Writing file:[/]", filename) - # Show syntax-highlighted preview before writing to file + code_file, error = _validate_file_path(filename, workspace_dir, repo) + if error: + console.print( + f"[bold bright_white on red] :heavy_multiplication_x: [/] [red]{error}[/]" + ) + return f"Failed to write {filename}: {error}" + if code_file is None: + return f"Failed to write {filename}: Invalid file path." + try: - lexer_name = Syntax.guess_lexer(filename, code) + lexer_name = Syntax.guess_lexer(str(code_file), code) except Exception: lexer_name = "text" @@ -48,9 +70,8 @@ def write_code( ) ) - # Write cleaned code to disk - code_file = workspace_dir.joinpath(filename) try: + code_file.parent.mkdir(parents=True, exist_ok=True) with open(code_file, "w", encoding="utf-8") as f: f.write(code) except Exception as exc: @@ -59,14 +80,13 @@ def write_code( "[red]Failed to write file:[/]", exc, ) - return f"Failed to write {filename}." + return f"Failed to write {filename}: {exc}" console.print( f"[bold bright_white on green] :heavy_check_mark: [/] " f"[green]File written:[/] {code_file}" ) - # Record the edit operation if (store := runtime.store) is not None: store.put( ("workspace", "file_edit"), @@ -82,12 +102,120 @@ def write_code( return f"File {filename} written successfully." +def _validate_file_path( + filename: str, + workspace_dir: Path, + repo_path: Path | None = None, + allow_unsafe_writes: bool | None = None, +) -> tuple[Path | None, str | None]: + """Validate that a filename is within workspace and optionally within a repo. + + Args: + filename: The requested filename to write to + workspace_dir: The workspace directory (all files must be under this) + repo_path: Optional repo directory (if provided, file must be under this) + allow_unsafe_writes: Permit writes outside workspace directory and repo directory. This is unsafe; users should use a sandbox or container when enabling this option. + + Returns: + Tuple of (resolved_path, error_message). If error_message is not None, + the resolution failed. + """ + # Resolve the file path + if Path(filename).is_absolute(): + file_path = Path(filename) + else: + file_path = workspace_dir / filename + + file_path = file_path.resolve() + + if allow_unsafe_writes is None: + allow_unsafe_writes = _allow_unsafe_writes_enabled() + # Validate it's within the workspace + if not allow_unsafe_writes and not file_path.is_relative_to( + workspace_dir.resolve() + ): + return None, ( + f"File path '{filename}' resolves outside workspace directory. " + "Files must be written within the workspace." + ) + + # If repo_path is specified, validate it's within the repo + if not allow_unsafe_writes and repo_path is not None: + # Resolve repo_path relative to workspace if it's not absolute + repo_resolved = ( + repo_path if repo_path.is_absolute() else workspace_dir / repo_path + ) + repo_resolved = repo_resolved.resolve() + + if not file_path.is_relative_to(repo_resolved): + return None, ( + f"File path '{filename}' resolves outside repository directory. " + "Files must be written within the repository." + ) + + return file_path, None + + +def _allow_unsafe_writes_enabled() -> bool: + """Return whether unsafe writes are explicitly enabled via environment. + + Set URSA_ALLOW_UNSAFE_WRITES to one of: 1, true, yes, on. + """ + return os.getenv("URSA_ALLOW_UNSAFE_WRITES", "0").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +@tool(description="Write source code to a file") +def write_code( + code: str, + filename: AsciiStr, + runtime: ToolRuntime[AgentContext], +) -> str: + """Write source code to a file + + Records successful file edits to the graph's store + + Args: + code: The source code content to be written to disk. + filename: Name of the target file (including its extension). + + """ + return _write_code_file(code, filename, runtime) + + +@tool(description="Write source code to a file within a repository boundary") +def write_code_with_repo( + code: str, + filename: AsciiStr, + runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr, +) -> str: + """Write source code to a file constrained to a repository path. + + Args: + code: The source code content to be written to disk. + filename: Name of the target file (including its extension). + repo_path: Repo path - file must resolve within this directory. + """ + workspace_dir = runtime.context.workspace + repo, error = _resolve_repo_dir(repo_path, workspace_dir, "write", filename) + if error: + return error + + return _write_code_file(code, filename, runtime, repo) + + @tool def edit_code( old_code: str, new_code: str, filename: AsciiStr, runtime: ToolRuntime[AgentContext], + repo_path: AsciiStr | None = None, ) -> str: """Replace the **first** occurrence of *old_code* with *new_code* in *filename*. @@ -95,6 +223,7 @@ def edit_code( old_code: Code fragment to search for. new_code: Replacement fragment. filename: Target file inside the workspace. + repo_path: Optional repo path - if provided, file must be within this repo. Returns: Success / failure message. @@ -102,7 +231,31 @@ def edit_code( workspace_dir = runtime.context.workspace console.print("[cyan]Editing file:[/cyan]", filename) - code_file = Path(workspace_dir, filename) + # Validate file path + repo = None + if repo_path: + repo, error = _resolve_repo_dir( + repo_path, + workspace_dir, + "edit", + filename, + ) + if error: + return error + + code_file, error = _validate_file_path( + filename, + workspace_dir, + repo, + ) + if error: + console.print( + f"[bold bright_white on red] :heavy_multiplication_x: [/] [red]{error}[/]" + ) + return f"Failed to edit {filename}: {error}" + if code_file is None: + return f"Failed to edit {filename}: Invalid file path." + try: content = read_text_file(code_file) except FileNotFoundError: @@ -111,6 +264,10 @@ def edit_code( "[red]File not found:[/]", ) return f"Failed: {filename} not found." + except ValueError as exc: + return f"Failed to edit {filename}: {exc}" + except OSError as exc: + return f"Failed to edit {filename}: Could not read file: {exc}" # Clean up markdown fences old_code_clean = old_code @@ -141,7 +298,7 @@ def edit_code( "[red]Failed to write file:[/]", exc, ) - return f"Failed to edit {filename}." + return f"Failed to edit {filename}: {exc}" console.print( f"[bold bright_white on green] :heavy_check_mark: [/] " diff --git a/src/ursa/util/diff_renderer.py b/src/ursa/util/diff_renderer.py index 19bef30b..8ebbd820 100644 --- a/src/ursa/util/diff_renderer.py +++ b/src/ursa/util/diff_renderer.py @@ -85,7 +85,7 @@ def __rich_console__( code = raw[1:] else: style = _STYLE["ctx"] - code = raw[1:] if raw.startswith(" ") else raw + code = raw.lstrip() # compute line numbers if raw.startswith("+"): diff --git a/src/ursa/util/github_research.py b/src/ursa/util/github_research.py new file mode 100644 index 00000000..0673a4a3 --- /dev/null +++ b/src/ursa/util/github_research.py @@ -0,0 +1,176 @@ +"""Fetch recent issues & PRs from GitHub repos for planning context. + +Uses the ``gh`` CLI (https://cli.github.com/) which handles authentication +transparently. Falls back gracefully when ``gh`` is not installed or when a +repo URL does not point at GitHub. +""" + +from __future__ import annotations + +import json +import re +import shutil +import subprocess +from typing import Any + +_GH_URL_RE = re.compile( + r"github\.com[:/](?P[^/]+)/(?P[^/.]+?)(?:\.git)?$" +) + + +def parse_github_owner_repo(url: str) -> tuple[str, str] | None: + """Extract ``(owner, repo)`` from a GitHub clone URL. + + Supports both HTTPS and SSH URLs. Returns ``None`` for non-GitHub URLs. + """ + m = _GH_URL_RE.search(url or "") + if m: + return m.group("owner"), m.group("repo") + return None + + +def _gh_available() -> bool: + return shutil.which("gh") is not None + + +def _gh_api(endpoint: str, timeout: int = 30) -> Any: + """Call ``gh api`` and return parsed JSON.""" + result = subprocess.run( + ["gh", "api", endpoint, "--paginate"], + capture_output=True, + text=True, + timeout=timeout, + check=False, + ) + if result.returncode != 0: + raise RuntimeError(result.stderr.strip()) + return json.loads(result.stdout) + + +def _format_issue(item: dict) -> str: + number = item.get("number", "?") + title = item.get("title", "") + state = item.get("state", "") + labels = ", ".join( + label.get("name", "") for label in (item.get("labels") or []) + ) + created = (item.get("created_at") or "")[:10] + body = (item.get("body") or "")[:300] + parts = [f" #{number} [{state}] {title}"] + if labels: + parts.append(f" Labels: {labels}") + if created: + parts.append(f" Created: {created}") + if body: + parts.append(f" {body}") + return "\n".join(parts) + + +def _format_pr(item: dict) -> str: + number = item.get("number", "?") + title = item.get("title", "") + state = item.get("state", "") + created = (item.get("created_at") or "")[:10] + body = (item.get("body") or "")[:300] + parts = [f" #{number} [{state}] {title}"] + if created: + parts.append(f" Created: {created}") + if body: + parts.append(f" {body}") + return "\n".join(parts) + + +def fetch_repo_context( + owner: str, + repo: str, + *, + max_issues: int = 10, + max_prs: int = 10, + issue_state: str = "all", + pr_state: str = "all", +) -> str: + """Fetch recent issues and PRs for a single GitHub repo. + + Returns a formatted text block suitable for inclusion in a planner prompt. + """ + sections: list[str] = [] + sections.append(f"## {owner}/{repo}") + + # Recent issues + try: + issues = _gh_api( + f"/repos/{owner}/{repo}/issues?state={issue_state}" + f"&per_page={max_issues}&sort=updated&direction=desc" + ) + # gh api may return PRs mixed with issues; filter them out + pure_issues = [i for i in issues if "pull_request" not in i][ + :max_issues + ] + if pure_issues: + sections.append(f"### Recent issues ({len(pure_issues)})") + sections.extend(_format_issue(issue) for issue in pure_issues) + else: + sections.append("### Recent issues: none") + except Exception as exc: # noqa: BLE001 + sections.append(f"### Issues: could not fetch ({exc})") + + # Recent PRs + try: + prs = _gh_api( + f"/repos/{owner}/{repo}/pulls?state={pr_state}" + f"&per_page={max_prs}&sort=updated&direction=desc" + ) + if prs: + sections.append(f"### Recent pull requests ({len(prs[:max_prs])})") + sections.extend(_format_pr(pr) for pr in prs[:max_prs]) + else: + sections.append("### Recent pull requests: none") + except Exception as exc: # noqa: BLE001 + sections.append(f"### PRs: could not fetch ({exc})") + + return "\n".join(sections) + + +def gather_github_context( + repos: list[dict], + *, + max_issues: int = 10, + max_prs: int = 10, +) -> str | None: + """Gather GitHub context for all repos that have GitHub URLs. + + Parameters + ---------- + repos: + List of repo config dicts (each must have at least ``url`` and ``name``). + max_issues: + Maximum recent issues to fetch per repo. + max_prs: + Maximum recent PRs to fetch per repo. + + Returns + ------- + Formatted text block with issues/PRs across repos, or ``None`` if nothing + was fetched (e.g. no GitHub URLs, ``gh`` not installed). + """ + if not _gh_available(): + return None + + blocks: list[str] = [] + for repo in repos: + parsed = parse_github_owner_repo(repo.get("url", "")) + if not parsed: + continue + owner, name = parsed + try: + block = fetch_repo_context( + owner, name, max_issues=max_issues, max_prs=max_prs + ) + blocks.append(block) + except Exception: # noqa: BLE001, S112 + # Network issue, auth issue, etc. -- skip silently + continue + + if not blocks: + return None + return "\n\n".join(blocks) diff --git a/src/ursa/util/parse.py b/src/ursa/util/parse.py index 1648b89a..3b6754a0 100644 --- a/src/ursa/util/parse.py +++ b/src/ursa/util/parse.py @@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET import zipfile from pathlib import Path -from typing import Any, Optional +from typing import Any from urllib.parse import urljoin, urlparse import justext @@ -23,14 +23,14 @@ from docx import Document docx_installed = True -except Exception: +except Exception: # noqa: BLE001, S110 pass try: from pptx import Presentation pptx_installed = True -except Exception: +except Exception: # noqa: BLE001, S110 pass @@ -39,6 +39,7 @@ # plain text & docs ".txt", ".md", + ".markdown", ".rst", ".rtf", ".tex", @@ -50,6 +51,8 @@ ".xml", ".html", ".htm", + ".adoc", + ".asciidoc", # source code (common) ".py", ".pyi", @@ -59,6 +62,7 @@ ".cpp", ".hpp", ".cc", + ".cxx", ".java", ".kt", ".scala", @@ -74,6 +78,49 @@ ".bash", ".zsh", ".ps1", + ".r", + ".R", + ".jl", + ".lua", + ".pl", + ".swift", + ".m", + ".mm", + # config files + ".yaml", + ".yml", + ".toml", + ".ini", + ".cfg", + ".conf", + ".config", + ".properties", + ".env", + ".editorconfig", + # build & project files + ".gradle", + ".cmake", + ".bazel", + ".bzl", + # systemd & podman quadlet files + ".service", + ".socket", + ".timer", + ".target", + ".mount", + ".automount", + ".path", + ".slice", + ".container", + ".volume", + ".network", + ".kube", + ".spec", + # other markup & data + ".proto", + ".graphql", + ".gql", + ".sql", } SPECIAL_TEXT_FILENAMES = { @@ -116,7 +163,7 @@ def extract_json(text: str) -> list[dict]: generic_block = re.search(r"```(.*?)```", text, re.DOTALL) if generic_block: json_str = generic_block.group(1).strip() - if json_str.startswith("{") or json_str.startswith("["): + if json_str.startswith(("{", "[")): try: return json.loads(json_str) except json.JSONDecodeError: @@ -233,14 +280,14 @@ def _download_stream_to(path: str, resp: requests.Response) -> str: def _get_soup( - url: str, timeout: int = 20, headers: Optional[dict[str, str]] = None + url: str, timeout: int = 20, headers: dict[str, str] | None = None ) -> BeautifulSoup: r = requests.get(url, timeout=timeout, headers=headers or {}) r.raise_for_status() return BeautifulSoup(r.text, "html.parser") -def _find_pdf_on_landing(soup: BeautifulSoup, base_url: str) -> Optional[str]: +def _find_pdf_on_landing(soup: BeautifulSoup, base_url: str) -> str | None: # 1) meta citation_pdf_url meta = soup.find("meta", attrs={"name": "citation_pdf_url"}) if meta and meta.get("content"): @@ -270,7 +317,7 @@ def _pdf_page_count(path: Path) -> int: loader = PyPDFLoader(path) pages = loader.load() return len(pages) - except Exception as e: + except Exception as e: # noqa: BLE001 print("[Error]: ", e) return 0 @@ -323,10 +370,10 @@ def _ocr_to_searchable_pdf( def resolve_pdf_from_osti_record( rec: dict[str, Any], *, - headers: Optional[dict[str, str]] = None, - unpaywall_email: Optional[str] = None, + headers: dict[str, str] | None = None, + unpaywall_email: str | None = None, timeout: int = 25, -) -> tuple[Optional[str], Optional[str], str]: +) -> tuple[str | None, str | None, str]: """ Returns (pdf_url, landing_used, note) - pdf_url: direct downloadable PDF URL if found (or a strong candidate) @@ -373,7 +420,7 @@ def resolve_pdf_from_osti_record( "found PDF via meta/anchor on fulltext landing" ) return (candidate, fulltext, " | ".join(note_parts)) - except Exception as e: + except Exception as e: # noqa: BLE001 note_parts.append(f"fulltext failed: {e}") # 2) Try DOE PAGES landing (citation_doe_pages) @@ -403,14 +450,14 @@ def resolve_pdf_from_osti_record( note_parts.append("citation_doe_pages → direct PDF") return (r2.url, doe_pages, " | ".join(note_parts)) r2.close() - except Exception: + except Exception: # noqa: BLE001, S110 pass # If not clearly PDF, still return as a candidate (agent will fetch & parse) note_parts.append( "citation_doe_pages → PDF-like candidate (not confirmed by headers)" ) return (candidate, doe_pages, " | ".join(note_parts)) - except Exception as e: + except Exception as e: # noqa: BLE001 note_parts.append(f"citation_doe_pages failed: {e}") # # 3) Optional: DOI → Unpaywall OA @@ -433,8 +480,7 @@ def _normalize_ws(text: str) -> str: text = re.sub(r"[ \t\r\f\v]+", " ", text) text = re.sub(r"\s*\n\s*", "\n", text) text = re.sub(r"\n{3,}", "\n\n", text) - text = text.strip() - return text + return text.strip() def _dedupe_lines(text: str, min_len: int = 40) -> str: @@ -477,7 +523,7 @@ def extract_main_text_only(html: str, *, max_chars: int = 250_000) -> str: txt = _normalize_ws(txt) txt = _dedupe_lines(txt) return txt[:max_chars] - except Exception: + except Exception: # noqa: BLE001, S110 pass # 2) jusText @@ -488,7 +534,7 @@ def extract_main_text_only(html: str, *, max_chars: int = 250_000) -> str: txt = _normalize_ws("\n\n".join(body_paras)) txt = _dedupe_lines(txt) return txt[:max_chars] - except Exception: + except Exception: # noqa: BLE001, S110 pass # 4) last-resort: BS4 paragraphs/headings only @@ -600,18 +646,18 @@ def read_pdf(path: str | Path) -> str: except (FileNotFoundError, subprocess.CalledProcessError) as e: # Missing ocrmypdf or OCR failed: keep original extraction print(f"[OCR Error]: {e}") - except Exception as e: + except Exception as e: # noqa: BLE001 # Any other OCR-related failure: keep original extraction print(f"[OCR Error]: {e}") - return text + return text # noqa: TRY300 except subprocess.CalledProcessError as e: # OCR failed; return whatever we got from normal extraction err = (e.stderr or "")[:500] print(f"[OCR Error]: {err}") return text if text else f"[Error]: OCR failed: {err}" - except Exception as e: + except Exception as e: # noqa: BLE001 print(f"[Error]: {e}") return f"[Error]: {e}" @@ -623,9 +669,12 @@ def read_text_file(path: str | Path) -> str: Args: path: string filename, with path, to read in """ - with open(path, "r", encoding="utf-8") as file: - file_contents = file.read() - return file_contents + try: + with open(path, "r", encoding="utf-8") as file: + return file.read() + except UnicodeDecodeError: + # If UTF-8 fails, it's likely binary + raise ValueError(f"File appears to be binary: {path}") # helper to extract text from OpenDocument formats (.odt/.odp) @@ -662,7 +711,7 @@ def read_docx(path: Path) -> str: return "\n".join(parts) else: return ( - f"No DOCX reader so skipping {str(path)}.\n", + f"No DOCX reader so skipping {path!s}.\n", "Consider installing via `pip install 'ursa-ai[office_readers]'`.", ) @@ -681,7 +730,7 @@ def read_pptx(path: Path) -> str: return "\n".join(parts) else: return ( - f"No PPTX reader so skipping {str(path)}.\n", + f"No PPTX reader so skipping {path!s}.\n", "Consider installing via `pip install 'ursa-ai[office_readers]'`.", ) @@ -715,7 +764,11 @@ def read_text_from_file(path): ): full_text = read_text_file(path) else: - full_text = f"Unsupported file type: {path.name}" - except Exception as e: + # Gracefully attempt to read unknown extensions as text + try: + full_text = read_text_file(path) + except (UnicodeDecodeError, ValueError): + full_text = f"Unsupported file type (binary or non-UTF-8): {path.name}" + except Exception as e: # noqa: BLE001 full_text = f"Error loading {path.name}: {e}" return full_text diff --git a/src/ursa/util/plan_execute_utils.py b/src/ursa/util/plan_execute_utils.py new file mode 100644 index 00000000..c398981b --- /dev/null +++ b/src/ursa/util/plan_execute_utils.py @@ -0,0 +1,539 @@ +""" +Shared utilities for plan_execute workflows. + +This module contains common functionality used by both single-repo and multi-repo +plan/execute workflows to reduce duplication and improve maintainability. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import select +import sqlite3 +import sys +import time +from pathlib import Path +from types import SimpleNamespace as NS +from typing import Any + +import randomname +import yaml +from langchain.chat_models import init_chat_model +from rich import get_console +from rich.panel import Panel +from rich.text import Text + +console = get_console() + +_RANDOMNAME_ADJ = ( + "colors", + "emotions", + "character", + "speed", + "size", + "weather", + "appearance", + "sound", + "age", + "taste", + "physics", +) + +_RANDOMNAME_NOUN = ( + "cats", + "dogs", + "apex_predators", + "birds", + "fish", + "fruit", + "seasonings", +) + + +# ============================================================================ +# YAML Configuration Loading +# ============================================================================ + + +def generate_workspace_name(project: str = "run") -> str: + """Generate a workspace name using randomname, with timestamp fallback.""" + try: + suffix = randomname.get_name(adj=_RANDOMNAME_ADJ, noun=_RANDOMNAME_NOUN) + except Exception: + suffix = time.strftime("%Y%m%d-%H%M%S") + return f"{project}_{suffix}" + + +def load_yaml_config(path: str) -> NS: + """Load a YAML config file and return as a SimpleNamespace.""" + try: + with open(path, encoding="utf-8") as f: + raw_cfg = yaml.safe_load(f) or {} + if not isinstance(raw_cfg, dict): + raise ValueError("Top-level YAML must be a mapping/object.") + return NS(**raw_cfg) + except FileNotFoundError: + print(f"Config file not found: {path}", file=sys.stderr) + sys.exit(2) + except Exception as exc: + print(f"Failed to load config {path}: {exc}", file=sys.stderr) + sys.exit(2) + + +def load_json_file(path: str | Path, default: Any): + """Load JSON from a file path, returning default on missing/invalid JSON.""" + p = Path(path) + if not p.exists(): + return default + try: + return json.loads(p.read_text()) + except Exception: + return default + + +def save_json_file( + path: str | Path, + payload: Any, + *, + indent: int = 2, + ensure_parent: bool = True, +) -> None: + """Write JSON payload to disk with optional parent directory creation.""" + p = Path(path) + if ensure_parent: + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(json.dumps(payload, indent=indent)) + + +# ============================================================================ +# Dictionary Merging +# ============================================================================ + + +def deep_merge_dicts(base: dict, override: dict) -> dict: + """ + Recursively merge override into base and return a new dict. + - dict + dict => deep merge + - otherwise => override wins + """ + base = dict(base or {}) + override = dict(override or {}) + out = dict(base) + for k, v in override.items(): + if isinstance(v, dict) and isinstance(out.get(k), dict): + out[k] = deep_merge_dicts(out[k], v) + else: + out[k] = v + return out + + +# ============================================================================ +# Plan Hashing +# ============================================================================ + + +def hash_plan(plan_steps: list | tuple) -> str: + """Generate a stable hash of plan steps for change detection.""" + serial = json.dumps( + [ + step.model_dump() if hasattr(step, "model_dump") else step + for step in plan_steps + ], + sort_keys=True, + default=str, + ) + return hashlib.sha256(serial.encode("utf-8")).hexdigest() + + +# ============================================================================ +# Secret Masking for Logging +# ============================================================================ + +_SECRET_KEY_SUBSTRS = ( + "api_key", + "apikey", + "access_token", + "refresh_token", + "secret", + "password", + "bearer", +) + + +def looks_like_secret_key(name: str) -> bool: + """Check if a parameter name looks like it contains sensitive data.""" + n = name.lower() + return any(s in n for s in _SECRET_KEY_SUBSTRS) + + +def mask_secret(value: str, keep_start: int = 6, keep_end: int = 4) -> str: + """ + Mask a secret-like string, keeping only the beginning and end. + Example: sk-proj-abc123456789xyz -> sk-proj-...9xyz + """ + if not isinstance(value, str): + return value + if len(value) <= keep_start + keep_end + 3: + return "..." + return f"{value[:keep_start]}...{value[-keep_end:]}" + + +def sanitize_for_logging(obj: Any) -> Any: + """Recursively sanitize secrets from config objects for safe logging.""" + if isinstance(obj, dict): + out = {} + for k, v in obj.items(): + if looks_like_secret_key(str(k)): + out[k] = mask_secret(v) if isinstance(v, str) else "..." + else: + out[k] = sanitize_for_logging(v) + return out + if isinstance(obj, list): + return [sanitize_for_logging(v) for v in obj] + return obj + + +# ============================================================================ +# Model Resolution & LLM Setup +# ============================================================================ + + +def resolve_llm_kwargs_for_agent( + models_cfg: dict | None, agent_name: str | None +) -> dict: + """ + Given the YAML `models:` dict, compute merged kwargs for init_chat_model(...) + for a specific agent ('planner' or 'executor'). + + Merge order (later wins): + 1) {} (empty) + 2) models.defaults.params (optional) + 3) models.profiles[defaults.profile] (optional) + 4) models.agents[agent_name].profile (optional; merges that profile on top) + 5) models.agents[agent_name].params (optional) + """ + models_cfg = models_cfg or {} + profiles = models_cfg.get("profiles") or {} + defaults = models_cfg.get("defaults") or {} + agents = models_cfg.get("agents") or {} + + # Start with global defaults + merged = {} + merged = deep_merge_dicts(merged, defaults.get("params") or {}) + + # Apply default profile + default_profile_name = defaults.get("profile") + if default_profile_name and default_profile_name in profiles: + merged = deep_merge_dicts(merged, profiles[default_profile_name]) + + # Apply agent-specific profile + params + if agent_name and isinstance(agents, dict) and agent_name in agents: + agent_cfg = agents[agent_name] + agent_profile = agent_cfg.get("profile") + if agent_profile and agent_profile in profiles: + merged = deep_merge_dicts(merged, profiles[agent_profile]) + merged = deep_merge_dicts(merged, agent_cfg.get("params") or {}) + + return merged + + +def resolve_model_choice(model_choice: str, models_cfg: dict): + """ + Accepts strings like 'openai:gpt-5.2' or 'my_endpoint:openai/gpt-oss-120b'. + Looks up per-provider settings from cfg.models.providers. + + Returns: (model_provider, pure_model, provider_extra_kwargs_for_init) + """ + if ":" in model_choice: + alias, pure_model = model_choice.split(":", 1) + else: + alias, pure_model = model_choice, model_choice + + providers = (models_cfg or {}).get("providers", {}) + prov = providers.get(alias, {}) + + # Which LangChain integration to use (e.g. "openai", "mistral", etc.) + model_provider = prov.get("model_provider", alias) + + # auth: prefer env var; optionally load via function if configured + api_key = None + if prov.get("api_key_env"): + api_key = os.getenv(prov["api_key_env"]) + if not api_key and prov.get("token_loader"): + # Dynamic token loading (omitted for brevity; can import if needed) + pass + + provider_extra = {} + if prov.get("base_url"): + provider_extra["base_url"] = prov["base_url"] + if api_key: + provider_extra["api_key"] = api_key + + return model_provider, pure_model, provider_extra + + +def print_llm_init_banner( + agent_name: str | None, + provider: str, + model_name: str, + provider_extra: dict, + llm_kwargs: dict, + model_obj=None, +) -> None: + """Print a Rich panel showing LLM initialization details.""" + who = agent_name or "llm" + + safe_provider_extra = sanitize_for_logging(provider_extra or {}) + safe_llm_kwargs = sanitize_for_logging(llm_kwargs or {}) + + console.print( + Panel.fit( + Text.from_markup( + f"[bold cyan]LLM init ({who})[/]\n" + f"[bold]provider[/]: {provider}\n" + f"[bold]model[/]: {model_name}\n\n" + f"[bold]provider kwargs[/]: {json.dumps(safe_provider_extra, indent=2)}\n\n" + f"[bold]llm kwargs (merged)[/]: {json.dumps(safe_llm_kwargs, indent=2)}" + ), + border_style="cyan", + ) + ) + + # Best-effort readback from the LangChain model object + if model_obj is None: + return + + readback = {} + for attr in ( + "model_name", + "model", + "reasoning", + "temperature", + "max_completion_tokens", + "max_tokens", + ): + if hasattr(model_obj, attr): + val = getattr(model_obj, attr, None) + if val is not None: + readback[attr] = val + + for attr in ("model_kwargs", "kwargs"): + if hasattr(model_obj, attr): + val = getattr(model_obj, attr, {}) + if isinstance(val, dict) and val: + readback[attr] = val + + if readback: + safe_readback = sanitize_for_logging(readback) + console.print( + Panel.fit( + Text.from_markup( + f"[dim]Model object readback:[/]\n{json.dumps(safe_readback, indent=2)}" + ), + border_style="dim", + ) + ) + + # Attempt a minimal test call + effort = None + try: + from langchain_core.messages import HumanMessage as _HM + + effort = model_obj.invoke([_HM(content="test")]) + except Exception: + pass + + if effort: + console.print("[dim]✓ Test invocation succeeded[/dim]") + + +def setup_llm( + model_choice: str, + models_cfg: dict | None = None, + agent_name: str | None = None, +): + """ + Build a LangChain chat model via init_chat_model(...), optionally applying + YAML-driven params from models.profiles, models.defaults, models.agents. + """ + models_cfg = models_cfg or {} + + provider, pure_model, provider_extra = resolve_model_choice( + model_choice, models_cfg + ) + + # Hardcoded defaults for backward compatibility + base_llm_kwargs = { + "max_completion_tokens": 10000, + "max_retries": 2, + } + + # YAML-driven kwargs (safe if absent) + yaml_llm_kwargs = resolve_llm_kwargs_for_agent(models_cfg, agent_name) + + # Merge: base defaults < YAML overrides + llm_kwargs = deep_merge_dicts(base_llm_kwargs, yaml_llm_kwargs) + + # Initialize + model = init_chat_model( + model=pure_model, + model_provider=provider, + **llm_kwargs, + **(provider_extra or {}), + ) + + # Print confirmation + print_llm_init_banner( + agent_name=agent_name, + provider=provider, + model_name=pure_model, + provider_extra=provider_extra, + llm_kwargs=llm_kwargs, + model_obj=model, + ) + + return model + + +# ============================================================================ +# Workspace Setup +# ============================================================================ + + +def setup_workspace( + user_specified_workspace: str | None, + project: str = "run", + model_name: str = "openai:gpt-5-mini", +) -> str: + """ + Set up a workspace directory for a plan/execute run. + Returns the workspace path as a string. + """ + if user_specified_workspace is None: + workspace = generate_workspace_name(project) + else: + workspace = user_specified_workspace + + Path(workspace).mkdir(parents=True, exist_ok=True) + + # Choose a fun emoji based on the model family + if model_name.startswith("openai"): + model_emoji = "🤖" + elif "llama" in model_name.lower(): + model_emoji = "🦙" + else: + model_emoji = "🧠" + + # Print the panel with model info + console.print( + Panel.fit( + f":rocket: [bold bright_blue]{workspace}[/bold bright_blue] :rocket:\n" + f"{model_emoji} [bold cyan]{model_name}[/bold cyan]", + title="[bold green]ACTIVE WORKSPACE[/bold green]", + border_style="bright_magenta", + padding=(1, 4), + ) + ) + + return workspace + + +# ============================================================================ +# Interactive Input with Timeout +# ============================================================================ + + +def timed_input_with_countdown(prompt: str, timeout: int) -> str | None: + """ + Read a line with a per-second countdown. Returns: + - the user's input (str) if provided, + - None if timeout expires, + - None if non-interactive or timeout<=0. + """ + try: + is_tty = sys.stdin.isatty() + except Exception: + is_tty = False + + if not is_tty: + # Non-interactive: default immediately + return None + if timeout <= 0: + # Timeout disabled: default immediately + return None + + deadline = time.time() + timeout + print(prompt, end="", flush=True) + + try: + while True: + remaining = int(deadline - time.time()) + if remaining <= 0: + print() + return None + + # Poll stdin with a 1-second timeout + ready, _, _ = select.select([sys.stdin], [], [], 1.0) + if ready: + line = sys.stdin.readline() + return line.rstrip("\n") if line else None + + # Update countdown display (clear to EOL to avoid ghost text) + print(f"\r{prompt}({remaining}s) \x1b[K", end="", flush=True) + + except Exception: + print() + return None + + +# ============================================================================ +# Checkpoint Snapshotting (SQLite) +# ============================================================================ + + +def snapshot_sqlite_db(src_path: Path, dst_path: Path) -> None: + """ + Make a consistent copy of the SQLite database at src_path into dst_path, + using the sqlite3 backup API. Safe with WAL; no need to copy -wal/-shm. + """ + if not src_path.exists(): + raise FileNotFoundError(f"Source database not found: {src_path}") + + dst_path.parent.mkdir(parents=True, exist_ok=True) + src_uri = f"file:{Path(src_path).resolve().as_posix()}?mode=ro" + src = dst = None + try: + src = sqlite3.connect(src_uri, uri=True) + dst = sqlite3.connect(str(dst_path)) + with dst: + src.backup(dst) + finally: + try: + if src: + src.close() + except Exception: + pass + try: + if dst: + dst.close() + except Exception: + pass + + +# ============================================================================ +# Formatted Elapsed Time +# ============================================================================ + + +def fmt_elapsed(seconds: float) -> str: + """Format elapsed seconds as compact h:mm:ss or m:ss.""" + s = int(seconds) + if s < 60: + return f"{s}s" + m, s = divmod(s, 60) + if m < 60: + return f"{m}m{s:02d}s" + h, m = divmod(m, 60) + return f"{h}h{m:02d}m" diff --git a/tests/agents/test_git_go_agent/test_git_go_agent.py b/tests/agents/test_git_go_agent/test_git_go_agent.py new file mode 100644 index 00000000..c1735f66 --- /dev/null +++ b/tests/agents/test_git_go_agent/test_git_go_agent.py @@ -0,0 +1,153 @@ +import shutil +from collections.abc import Iterator +from pathlib import Path + +import pytest +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage + +if shutil.which("git") is None: + pytest.skip( + "Skipping git agent tests: `git` executable is not available on PATH. " + "Install git and ensure it is available in the active shell.", + allow_module_level=True, + ) + +try: + from ursa.agents import GitAgent, GitGoAgent, make_git_agent +except (ImportError, ModuleNotFoundError) as exc: + pytest.skip( + "Skipping git agent tests: git-related Python tooling could not be imported. " + "Install the project test dependencies and verify git tool integrations are available. " + f"Import error: {exc}", + allow_module_level=True, + ) + + +class ToolReadyFakeChatModel(GenericFakeChatModel): + def bind_tools(self, tools, **kwargs): + return self + + +def _message_stream(content: str) -> Iterator[AIMessage]: + while True: + yield AIMessage(content=content) + + +def test_git_go_agent_tools(tmpdir): + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + agent = GitGoAgent(llm=chat_model, workspace=workspace) + + tool_names = set(agent.tools.keys()) + # Git tools + assert "git_status" in tool_names + assert "git_diff" in tool_names + assert "git_commit" in tool_names + assert "git_add" in tool_names + assert "git_switch" in tool_names + assert "git_create_branch" in tool_names + assert "git_log" in tool_names + assert "git_ls_files" in tool_names + # Go tools + assert "go_build" in tool_names + assert "go_test" in tool_names + assert "go_vet" in tool_names + assert "go_mod_tidy" in tool_names + assert "golangci_lint" in tool_names + # Code formatting + assert "gofmt_files" in tool_names + # Removed tools + assert "run_command" not in tool_names + assert "run_web_search" not in tool_names + assert "run_osti_search" not in tool_names + assert "run_arxiv_search" not in tool_names + # Configuration + assert "go" in agent.safe_codes + + +def test_git_agent_generic_has_only_git_tools(tmpdir): + """GitAgent with no language tools should only have git tools.""" + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + agent = GitAgent(llm=chat_model, workspace=workspace) + + tool_names = set(agent.tools.keys()) + # Git tools present + assert "git_status" in tool_names + assert "git_diff" in tool_names + assert "git_commit" in tool_names + assert "git_add" in tool_names + assert "git_switch" in tool_names + assert "git_create_branch" in tool_names + assert "git_log" in tool_names + assert "git_ls_files" in tool_names + # No Go tools + assert "go_build" not in tool_names + assert "go_test" not in tool_names + assert "gofmt_files" not in tool_names + assert "golangci_lint" not in tool_names + # Removed tools + assert "run_command" not in tool_names + assert "run_web_search" not in tool_names + + +def test_make_git_agent_go_matches_git_go_agent(tmpdir): + """make_git_agent(language='go') should produce the same tool set as GitGoAgent.""" + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + + go_agent = GitGoAgent(llm=chat_model, workspace=workspace) + factory_agent = make_git_agent( + llm=chat_model, language="go", workspace=workspace + ) + + assert set(go_agent.tools.keys()) == set(factory_agent.tools.keys()) + assert go_agent.safe_codes == factory_agent.safe_codes + + +def test_make_git_agent_generic(tmpdir): + """make_git_agent(language='generic') should produce a git-only agent.""" + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + + agent = make_git_agent( + llm=chat_model, language="generic", workspace=workspace + ) + + tool_names = set(agent.tools.keys()) + assert "git_status" in tool_names + assert "go_build" not in tool_names + + +def test_make_git_agent_unknown_language_defaults_to_git_only(tmpdir): + """make_git_agent with an unknown language should fall back to git-only.""" + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + + agent = make_git_agent(llm=chat_model, language="rust", workspace=workspace) + + tool_names = set(agent.tools.keys()) + # Should have git tools + assert "git_status" in tool_names + # Should not have language-specific tools + assert "go_build" not in tool_names + + +def test_make_git_agent_explicit_tools_bypass_registry(tmpdir): + """make_git_agent accepts explicit tools/prompt/safe_codes, bypassing registry.""" + chat_model = ToolReadyFakeChatModel(messages=_message_stream("ok")) + workspace = Path(str(tmpdir)) + + # Create agent with explicit safe codes for a hypothetical "adoc" language + agent = make_git_agent( + llm=chat_model, + language="adoc", # Not in registry + safe_codes=["asciidoc", "podman"], # Explicit safe codes + workspace=workspace, + ) + + # Should be git-only (adoc not in registry), but with custom safe codes + assert agent.safe_codes == {"asciidoc", "podman"} + tool_names = set(agent.tools.keys()) + assert "git_status" in tool_names diff --git a/tests/agents/test_git_go_agent/test_go_tools.py b/tests/agents/test_git_go_agent/test_go_tools.py new file mode 100644 index 00000000..ef32d053 --- /dev/null +++ b/tests/agents/test_git_go_agent/test_go_tools.py @@ -0,0 +1,291 @@ +"""Tests for Go tooling functions in git_tools module.""" + +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +try: + from ursa.tools.go_tools import ( + go_build, + go_mod_tidy, + go_test, + go_vet, + golangci_lint, + ) +except (ImportError, ModuleNotFoundError) as exc: + pytest.skip( + "Skipping legacy Go tooling tests: ursa.tools.go_tools is unavailable. " + f"Import error: {exc}", + allow_module_level=True, + ) + +from ursa.agents.base import AgentContext + + +@pytest.fixture +def mock_runtime(): + """Create a mock ToolRuntime with AgentContext.""" + runtime = MagicMock() + context = MagicMock(spec=AgentContext) + context.workspace = Path("/tmp/workspace") + runtime.context = context + runtime.store = None + return runtime + + +@pytest.fixture +def go_repo(tmpdir): + """Create a minimal Go repo for testing.""" + repo_dir = Path(tmpdir) / "test_repo" + repo_dir.mkdir(parents=True) + + # Create minimal go.mod + mod_file = repo_dir / "go.mod" + mod_file.write_text("module github.com/test/example\n\ngo 1.21\n") + + # Create a simple Go file + main_file = repo_dir / "main.go" + main_file.write_text( + 'package main\n\nimport "fmt"\n\nfunc main() {\n fmt.Println("Hello")\n}\n' + ) + + # Create a test file + test_file = repo_dir / "main_test.go" + test_file.write_text( + 'package main\n\nimport "testing"\n\nfunc TestHello(t *testing.T) {\n t.Log("Test")\n}\n' + ) + + return repo_dir + + +class TestGoTools: + """Test suite for Go tooling functions.""" + + def test_go_build_success(self, mock_runtime, go_repo): + """Test successful go build execution.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="", stderr="" + ) + + result = go_build.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "STDOUT:" in result + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "go" in args + assert "build" in args + assert "./..." in args + + def test_go_test_with_verbose(self, mock_runtime, go_repo): + """Test go test with verbose flag.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="ok\n", stderr="" + ) + + result = go_test.func( + mock_runtime, repo_path=str(go_repo.name), verbose=True + ) + + assert "STDOUT:" in result + assert "ok\n" in result + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "-v" in args + + def test_go_test_without_verbose(self, mock_runtime, go_repo): + """Test go test without verbose flag.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="", stderr="" + ) + + result = go_test.func( + mock_runtime, repo_path=str(go_repo.name), verbose=False + ) + + assert "STDOUT:" in result + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "-v" not in args + + def test_go_vet_success(self, mock_runtime, go_repo): + """Test successful go vet execution.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="", stderr="" + ) + + result = go_vet.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "STDOUT:" in result + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "go" in args + assert "vet" in args + + def test_go_mod_tidy_success(self, mock_runtime, go_repo): + """Test successful go mod tidy execution.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, stdout="", stderr="" + ) + + result = go_mod_tidy.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "STDOUT:" in result + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "go" in args + assert "mod" in args + assert "tidy" in args + + def test_go_build_timeout(self, mock_runtime, go_repo): + """Test go build timeout handling.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = subprocess.TimeoutExpired( + ["go", "build", "./..."], 300 + ) + + result = go_build.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "timed out" in result.lower() + assert "300" in result + assert "5 minutes" in result + + def test_go_test_timeout(self, mock_runtime, go_repo): + """Test go test timeout handling.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = subprocess.TimeoutExpired( + ["go", "test", "./..."], 600 + ) + + result = go_test.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "timed out" in result.lower() + assert "600" in result + + def test_golangci_lint_not_installed(self, mock_runtime, go_repo): + """Test golangci_lint when linter is not installed.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = FileNotFoundError() + + result = golangci_lint.func( + mock_runtime, repo_path=str(go_repo.name) + ) + + assert "not installed" in result.lower() + assert "go install" in result.lower() + + def test_golangci_lint_with_config(self, mock_runtime, go_repo): + """Test golangci_lint detects and uses .golangci.yml.""" + # Create .golangci.yml in the repo + config_file = go_repo / ".golangci.yml" + config_file.write_text("linters:\n enable:\n - gofmt\n") + + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + # First call for version check, second for actual linting + mock_run.side_effect = [ + MagicMock( + returncode=0, + stdout="golangci-lint version 1.0.0\n", + stderr="", + ), + MagicMock(returncode=0, stdout="", stderr=""), + ] + + result = golangci_lint.func( + mock_runtime, repo_path=str(go_repo.name) + ) + + assert "STDOUT:" in result + # Verify the second call used the config file + second_call_args = mock_run.call_args_list[1][0][0] + assert "--config" in second_call_args + assert str( + config_file + ) in second_call_args or ".golangci.yml" in str(second_call_args) + + def test_golangci_lint_without_config(self, mock_runtime, go_repo): + """Test golangci_lint runs without config file.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + # First call for version check, second for actual linting + mock_run.side_effect = [ + MagicMock( + returncode=0, + stdout="golangci-lint version 1.0.0\n", + stderr="", + ), + MagicMock(returncode=0, stdout="", stderr=""), + ] + + result = golangci_lint.func( + mock_runtime, repo_path=str(go_repo.name) + ) + + assert "STDOUT:" in result + # Verify the second call did not use a config file + second_call_args = mock_run.call_args_list[1][0][0] + assert "--config" not in second_call_args + + +class TestGoToolsErrorHandling: + """Test error handling in Go tools.""" + + def test_go_build_general_error(self, mock_runtime, go_repo): + """Test go build general exception handling.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = Exception("Some error") + + result = go_build.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "Error running go build" in result + + def test_go_test_general_error(self, mock_runtime, go_repo): + """Test go test general exception handling.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = Exception("Some error") + + result = go_test.func(mock_runtime, repo_path=str(go_repo.name)) + + assert "Error running go test" in result + + def test_golangci_lint_version_check_error(self, mock_runtime, go_repo): + """Test golangci_lint version check failure.""" + mock_runtime.context.workspace = go_repo.parent + + with patch("ursa.tools.go_tools.subprocess.run") as mock_run: + mock_run.side_effect = Exception("Version check failed") + + result = golangci_lint.func( + mock_runtime, repo_path=str(go_repo.name) + ) + + assert "Error checking golangci-lint" in result diff --git a/tests/agents/test_planning_agent/test_planning_agent.py b/tests/agents/test_planning_agent/test_planning_agent.py index 8c6b0187..c2ae3e16 100644 --- a/tests/agents/test_planning_agent/test_planning_agent.py +++ b/tests/agents/test_planning_agent/test_planning_agent.py @@ -3,9 +3,31 @@ from ursa.agents.planning_agent import Plan, PlanningAgent -async def test_planning_agent_creates_structured_plan(chat_model, tmpdir): +class FakePlanningChatModel: + def model_copy(self, update=None): + return self + + def with_structured_output(self, schema, **kwargs): + class _Runner: + def invoke(self, messages): + return schema( + steps=[ + { + "name": "Add numbers", + "description": "Add 1 and 2.", + "requires_code": False, + "expected_outputs": ["sum"], + "success_criteria": ["sum equals 3"], + } + ] + ) + + return _Runner() + + +async def test_planning_agent_creates_structured_plan(tmpdir): planning_agent = PlanningAgent( - llm=chat_model.model_copy(update={"max_tokens": 4000}), + llm=FakePlanningChatModel(), workspace=tmpdir, max_reflection_steps=0, ) diff --git a/tests/conftest.py b/tests/conftest.py index 083a551c..d514ebd4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,14 @@ +import os + import pytest from dotenv import load_dotenv from langchain.chat_models import init_chat_model from langchain.embeddings import init_embeddings +from langchain_core.embeddings import FakeEmbeddings +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from pydantic import BaseModel @pytest.fixture(scope="session", autouse=True) @@ -16,8 +23,132 @@ def bind_kwargs(func, **kwargs): return model +def _message_stream(content: str): + while True: + yield AIMessage(content=content) + + +def _fake_structured(schema): + if isinstance(schema, type) and issubclass(schema, BaseModel): + if "steps" in schema.model_fields: + return schema( + steps=[ + { + "name": "Stub step", + "description": "Stubbed plan step for tests.", + "requires_code": False, + "expected_outputs": ["stub"], + "success_criteria": ["stub"], + } + ] + ) + return schema() + + annotations = getattr(schema, "__annotations__", {}) or {} + if "is_safe" in annotations and "reason" in annotations: + return { + "is_safe": True, + "reason": "Stubbed safety check for tests", + } + return {key: "stub" for key in annotations} + + +class FakeChatModel(GenericFakeChatModel): + @property + def model_name(self) -> str: + return "fake-chat" + + @property + def model(self) -> str: + return "fake-chat" + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + text = " ".join( + msg.content + for msg in messages + if hasattr(msg, "content") and isinstance(msg.content, str) + ).lower() + if "latex" in text or "\\documentclass" in text: + content = "\\documentclass{article}\n\\begin{document}\nStub\n\\end{document}" + else: + content = "ok" + + message = AIMessage( + content=content, + usage_metadata={ + "input_tokens": 1, + "output_tokens": 1, + "total_tokens": 2, + }, + ) + return ChatResult(generations=[ChatGeneration(message=message)]) + + def bind_tools(self, tools, **kwargs): + return self + + def with_structured_output(self, schema, **kwargs): + output = _fake_structured(schema) + + class _Runner: + def invoke(self, messages): + return output + + async def ainvoke(self, messages): + return output + + return _Runner() + + +class FakeEmbeddingModel(FakeEmbeddings): + @property + def model_name(self) -> str: + return "fake-embeddings" + + @property + def model(self) -> str: + return "fake-embeddings" + + +@pytest.fixture(autouse=True) +def _stub_model_init(monkeypatch): + if os.getenv("OPENAI_API_KEY"): + return + + def fake_init_chat_model(*args, **kwargs): + return FakeChatModel(messages=_message_stream("ok")) + + def fake_init_embeddings(*args, **kwargs): + return FakeEmbeddingModel(size=12) + + monkeypatch.setattr( + "langchain.chat_models.init_chat_model", fake_init_chat_model + ) + monkeypatch.setattr( + "langchain.embeddings.init_embeddings", fake_init_embeddings + ) + monkeypatch.setattr( + "ursa.cli.hitl.init_chat_model", fake_init_chat_model, raising=False + ) + monkeypatch.setattr( + "ursa.cli.hitl.init_embeddings", fake_init_embeddings, raising=False + ) + monkeypatch.setattr( + "ursa.agents.rag_agent.init_embeddings", + fake_init_embeddings, + raising=False, + ) + + @pytest.fixture(scope="function") def chat_model(): + if not os.getenv("OPENAI_API_KEY"): + model = FakeChatModel(messages=_message_stream("ok")) + model._testing_only_kwargs = { + "model": "fake:chat", + "max_tokens": 3000, + "temperature": 0.0, + } + return model return bind_kwargs( init_chat_model, model="openai:gpt-5-nano", @@ -28,6 +159,12 @@ def chat_model(): @pytest.fixture(scope="function") def embedding_model(): + if not os.getenv("OPENAI_API_KEY"): + model = FakeEmbeddingModel(size=12) + model._testing_only_kwargs = { + "model": "fake:embeddings", + } + return model return bind_kwargs( init_embeddings, model="openai:text-embedding-3-small", diff --git a/tests/tools/test_write_code_tool.py b/tests/tools/test_write_code_tool.py index b45687c0..0aaf531b 100644 --- a/tests/tools/test_write_code_tool.py +++ b/tests/tools/test_write_code_tool.py @@ -5,7 +5,11 @@ from langgraph.store.memory import InMemoryStore from tests.tools.utils import make_runtime -from ursa.tools.write_code_tool import edit_code, write_code +from ursa.tools.write_code_tool import ( + edit_code, + write_code, + write_code_with_repo, +) def test_write_code_records_store_entry( @@ -85,6 +89,34 @@ def test_edit_code_noop_when_old_code_missing( assert store.get(("workspace", "file_edit"), "script.py") is None +def test_write_code_with_repo_records_store_entry( + tmp_path: Path, chat_model: BaseChatModel +): + repo = tmp_path / "repo" + repo.mkdir() + store = InMemoryStore() + runtime = make_runtime( + tmp_path, + llm=chat_model, + store=store, + tool_call_id="tc-2", + thread_id="thread-2", + ) + + result = write_code_with_repo.func( + code="print(7)", + filename="repo/sample.py", + runtime=runtime, + repo_path=str(repo), + ) + + assert "written successfully" in result + item = store.get(("workspace", "file_edit"), "repo/sample.py") + assert item is not None + assert item.value["tool_call_id"] == "tc-2" + assert item.value["thread_id"] == "thread-2" + + def test_edit_code_missing_file(tmp_path: Path, chat_model: BaseChatModel): store = InMemoryStore() runtime = make_runtime( diff --git a/tests/tools/test_write_code_tool_validation.py b/tests/tools/test_write_code_tool_validation.py new file mode 100644 index 00000000..f991a2db --- /dev/null +++ b/tests/tools/test_write_code_tool_validation.py @@ -0,0 +1,399 @@ +"""Tests for write_code_tool path validation and file operations.""" + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from ursa.agents.base import AgentContext +from ursa.tools.write_code_tool import ( + _allow_unsafe_writes_enabled, + _validate_file_path, + edit_code, + write_code, + write_code_with_repo, +) + + +@pytest.fixture +def mock_runtime(): + """Create a mock ToolRuntime with AgentContext.""" + runtime = MagicMock() + context = MagicMock(spec=AgentContext) + runtime.context = context + runtime.store = None + runtime.tool_call_id = "test_tool_call" + runtime.config = {"metadata": {"thread_id": "test_thread"}} + return runtime + + +class TestPathValidation: + """Test path validation in write_code and edit_code.""" + + def test_valid_file_within_workspace(self, tmpdir): + """Test that a file within workspace is accepted.""" + workspace = Path(tmpdir) + filename = "test.py" + + result_path, error = _validate_file_path(filename, workspace) + + assert error is None + assert result_path is not None + assert ( + workspace in result_path.parents or result_path.parent == workspace + ) + + def test_valid_nested_file_within_workspace(self, tmpdir): + """Test that nested files within workspace are accepted.""" + workspace = Path(tmpdir) + filename = "src/main/test.py" + + result_path, error = _validate_file_path(filename, workspace) + + assert error is None + assert result_path is not None + + def test_path_traversal_attempt_rejected(self, tmpdir): + """Test that path traversal attempts are rejected.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + filename = "../../../etc/passwd" + + result_path, error = _validate_file_path(filename, workspace) + + assert error is not None + assert "outside workspace" in error.lower() + assert result_path is None + + def test_absolute_path_outside_workspace_rejected(self, tmpdir): + """Test that absolute paths outside workspace are rejected.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + filename = "/etc/passwd" + + _result_path, error = _validate_file_path(filename, workspace) + + assert error is not None + assert "outside workspace" in error.lower() + + def test_file_within_repo_path(self, tmpdir): + """Test that files within specified repo are accepted.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + filename = "myrepo/test.py" + + result_path, error = _validate_file_path(filename, workspace, repo) + + assert error is None + assert result_path is not None + + def test_file_outside_repo_path_rejected(self, tmpdir): + """Test that files outside specified repo are rejected.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + other = workspace / "other" + other.mkdir(parents=True) + filename = "other/test.py" + + _result_path, error = _validate_file_path(filename, workspace, repo) + + assert error is not None + assert "outside repository" in error.lower() + + def test_relative_repo_path(self, tmpdir): + """Test that relative repo paths are resolved correctly.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + filename = "myrepo/test.py" + + # Test with relative repo path + result_path, error = _validate_file_path( + filename, workspace, Path("myrepo") + ) + + assert error is None + assert result_path is not None + + def test_allow_unsafe_writes_allows_outside_workspace(self, tmpdir): + """Test unsafe writes can bypass workspace validation when enabled.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + filename = "../../../etc/passwd" + + result_path, error = _validate_file_path( + filename, + workspace, + allow_unsafe_writes=True, + ) + + assert error is None + assert result_path is not None + + +class TestUnsafeWriteEnvToggle: + """Test env var parsing for unsafe write mode.""" + + @pytest.mark.parametrize("value", ["1", "true", "TRUE", "yes", "on"]) + def test_truthy_values_enable_unsafe_writes(self, monkeypatch, value): + monkeypatch.setenv("URSA_ALLOW_UNSAFE_WRITES", value) + assert _allow_unsafe_writes_enabled() is True + + @pytest.mark.parametrize("value", ["0", "false", "no", "off", ""]) + def test_falsey_values_disable_unsafe_writes(self, monkeypatch, value): + monkeypatch.setenv("URSA_ALLOW_UNSAFE_WRITES", value) + assert _allow_unsafe_writes_enabled() is False + + def test_missing_env_defaults_to_safe_mode(self, monkeypatch): + monkeypatch.delenv("URSA_ALLOW_UNSAFE_WRITES", raising=False) + assert _allow_unsafe_writes_enabled() is False + + +class TestWriteCodePathValidation: + """Test write_code function with path validation.""" + + def test_write_code_within_workspace(self, mock_runtime, tmpdir): + """Test write_code works for files within workspace.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + code = "print('hello')" + filename = "test.py" + + result = write_code.func(code, filename, mock_runtime) + + assert "successfully" in result.lower() + assert (workspace / filename).exists() + + def test_write_code_path_traversal_rejected(self, mock_runtime, tmpdir): + """Test write_code rejects path traversal.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + code = "print('hello')" + filename = "../../../etc/passwd" + + result = write_code.func(code, filename, mock_runtime) + + assert "failed" in result.lower() + assert "outside workspace" in result.lower() + + def test_write_code_nested_directory_creation(self, mock_runtime, tmpdir): + """Test write_code creates nested directories.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + code = "print('hello')" + filename = "src/main/test.py" + + result = write_code.func(code, filename, mock_runtime) + + assert "successfully" in result.lower() + assert (workspace / filename).exists() + assert (workspace / "src" / "main").is_dir() + + def test_write_code_with_repo_boundary(self, mock_runtime, tmpdir): + """Test write_code_with_repo respects repo boundary when specified.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + code = "print('hello')" + filename = "myrepo/test.py" + + result = write_code_with_repo.func( + code, filename, mock_runtime, repo_path=str(repo) + ) + + assert "successfully" in result.lower() + + def test_write_code_outside_repo_boundary_rejected( + self, mock_runtime, tmpdir + ): + """Test write_code_with_repo rejects files outside repo boundary.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + other = workspace / "other" + other.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + code = "print('hello')" + filename = "other/test.py" + + result = write_code_with_repo.func( + code, filename, mock_runtime, repo_path=str(repo) + ) + + assert "failed" in result.lower() + assert "outside repository" in result.lower() + + def test_write_code_outside_workspace_allowed_when_env_enabled( + self, mock_runtime, tmpdir, monkeypatch + ): + """Test write_code allows unsafe writes when env toggle is enabled.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + monkeypatch.setenv("URSA_ALLOW_UNSAFE_WRITES", "1") + target = Path(tmpdir) / "outside.py" + + result = write_code.func("print('hello')", str(target), mock_runtime) + + assert "successfully" in result.lower() + assert target.exists() + + def test_write_code_repo_path_not_found_rejected( + self, mock_runtime, tmpdir + ): + """Test write_code_with_repo fails clearly when repo path does not exist.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + missing_repo = workspace / "does-not-exist" + result = write_code_with_repo.func( + "print('hello')", + "test.py", + mock_runtime, + repo_path=str(missing_repo), + ) + + assert "failed" in result.lower() + assert "repository path not found" in result.lower() + + +class TestEditCodePathValidation: + """Test edit_code function with path validation.""" + + def test_edit_code_within_workspace(self, mock_runtime, tmpdir): + """Test edit_code works for files within workspace.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + # Create initial file + filename = "test.py" + test_file = workspace / filename + test_file.write_text("x = 1\n") + + old_code = "x = 1" + new_code = "x = 2" + + result = edit_code.func(old_code, new_code, filename, mock_runtime) + + assert "successfully" in result.lower() + assert test_file.read_text() == "x = 2\n" + + def test_edit_code_path_traversal_rejected(self, mock_runtime, tmpdir): + """Test edit_code rejects path traversal.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + old_code = "x = 1" + new_code = "x = 2" + filename = "../../../etc/passwd" + + result = edit_code.func(old_code, new_code, filename, mock_runtime) + + assert "failed" in result.lower() + assert "outside workspace" in result.lower() + + def test_edit_code_with_repo_boundary(self, mock_runtime, tmpdir): + """Test edit_code respects repo boundary when specified.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + # Create initial file + filename = "myrepo/test.py" + test_file = workspace / filename + test_file.parent.mkdir(parents=True, exist_ok=True) + test_file.write_text("x = 1\n") + + old_code = "x = 1" + new_code = "x = 2" + + result = edit_code.func( + old_code, new_code, filename, mock_runtime, repo_path=str(repo) + ) + + assert "successfully" in result.lower() + + def test_edit_code_outside_repo_boundary_rejected( + self, mock_runtime, tmpdir + ): + """Test edit_code rejects files outside repo boundary.""" + workspace = Path(tmpdir) + repo = workspace / "myrepo" + repo.mkdir(parents=True) + other = workspace / "other" + other.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + old_code = "x = 1" + new_code = "x = 2" + filename = "other/test.py" + + result = edit_code.func( + old_code, new_code, filename, mock_runtime, repo_path=str(repo) + ) + + assert "failed" in result.lower() + assert "outside repository" in result.lower() + + def test_edit_code_outside_workspace_allowed_when_env_enabled( + self, mock_runtime, tmpdir, monkeypatch + ): + """Test edit_code allows unsafe edits when env toggle is enabled.""" + workspace = Path(tmpdir) / "workspace" + workspace.mkdir(parents=True) + mock_runtime.context.workspace = workspace + + monkeypatch.setenv("URSA_ALLOW_UNSAFE_WRITES", "1") + target = Path(tmpdir) / "outside.py" + target.write_text("x = 1\n") + + result = edit_code.func("x = 1", "x = 2", str(target), mock_runtime) + + assert "successfully" in result.lower() + assert target.read_text() == "x = 2\n" + + def test_edit_code_binary_file_returns_failure(self, mock_runtime, tmpdir): + """Test edit_code handles binary files without raising.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + filename = "binary.bin" + test_file = workspace / filename + test_file.write_bytes(b"\xff\xfe\x00\x01") + + result = edit_code.func("x", "y", filename, mock_runtime) + + assert "failed" in result.lower() + assert "binary" in result.lower() + + def test_edit_code_repo_path_not_found_rejected(self, mock_runtime, tmpdir): + """Test edit_code fails clearly when repo path does not exist.""" + workspace = Path(tmpdir) + mock_runtime.context.workspace = workspace + + filename = "test.py" + (workspace / filename).write_text("x = 1\n") + missing_repo = workspace / "does-not-exist" + + result = edit_code.func( + "x = 1", + "x = 2", + filename, + mock_runtime, + repo_path=str(missing_repo), + ) + + assert "failed" in result.lower() + assert "repository path not found" in result.lower()