diff --git a/.gitignore b/.gitignore
index 94adf289..2f8bdc7e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -97,3 +97,5 @@ Thumbs.db
 schema.graphql
 
 .opencode/
+
+.claude/
diff --git a/README.md b/README.md
index 12740cd2..94ce6fd8 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,6 @@
 
 [![Python](https://img.shields.io/pypi/pyversions/strix-agent?color=3776AB)](https://pypi.org/project/strix-agent/)
 [![PyPI](https://img.shields.io/pypi/v/strix-agent?color=10b981)](https://pypi.org/project/strix-agent/)
-![PyPI Downloads](https://static.pepy.tech/personalized-badge/strix-agent?period=total&units=INTERNATIONAL_SYSTEM&left_color=GREY&right_color=RED&left_text=Downloads)
 [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
 
 [![GitHub Stars](https://img.shields.io/github/stars/usestrix/strix)](https://github.com/usestrix/strix)
@@ -167,6 +166,9 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
 
 # Provide detailed instructions through file (e.g., rules of engagement, scope, exclusions)
 strix --target api.your-app.com --instruction-file ./instruction.md
+
+# Resume an interrupted scan
+strix --target https://your-app.com --run-name my-scan --resume
 ```
 
 ### 🤖 Headless Mode
diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py
index 67aeb383..53eded8e 100644
--- a/strix/agents/base_agent.py
+++ b/strix/agents/base_agent.py
@@ -205,6 +205,24 @@ async def agent_loop(self, task: str) -> dict[str, Any]:  # noqa: PLR0912, PLR09
 
             try:
                 should_finish = await self._process_iteration(tracer)
+
+                # Save checkpoint after successful iteration
+                try:
+                    from strix.telemetry.checkpoint import save_checkpoint
+                    from strix.telemetry.tracer import get_global_tracer
+
+                    tracer_instance = get_global_tracer()
+                    if tracer_instance and hasattr(self, "state"):
+                        run_dir = tracer_instance.get_run_dir()
+                        scan_config = tracer_instance.scan_config or {}
+                        save_checkpoint(run_dir, self.state, scan_config)
+                except Exception as exc:  # noqa: BLE001
+                    logger.debug(
+                        "Checkpoint save failed (non-fatal): %s",
+                        exc,
+                        exc_info=True,
+                    )
+
                 if should_finish:
                     if self.non_interactive:
                         self.state.set_completed({"success": True})
diff --git a/strix/interface/cli.py b/strix/interface/cli.py
index 582f8116..81276700 100644
--- a/strix/interface/cli.py
+++ b/strix/interface/cli.py
@@ -17,7 +17,7 @@
 from .utils import build_final_stats_text, build_live_stats_text, get_severity_color
 
 
-async def run_cli(args: Any) -> None:  # noqa: PLR0915
+async def run_cli(args: Any) -> None:  # noqa: PLR0915, PLR0912
     console = Console()
 
     start_text = Text()
@@ -85,6 +85,64 @@ async def run_cli(args: Any) -> None:  # noqa: PLR0915
     if getattr(args, "local_sources", None):
         agent_config["local_sources"] = args.local_sources
 
+    # Check for resume from checkpoint
+    from pathlib import Path
+
+    from pydantic import ValidationError
+
+    from strix.agents.state import AgentState
+    from strix.telemetry.checkpoint import can_resume, load_checkpoint
+
+    resume_from_checkpoint = False
+    restored_state = None
+
+    if getattr(args, "resume", False):
+        run_dir = Path.cwd() / "strix_runs" / args.run_name
+
+        if run_dir.exists() and can_resume(run_dir, scan_config):
+            checkpoint = load_checkpoint(run_dir)
+
+            if checkpoint:
+                try:
+                    agent_state_data = checkpoint["agent_state"]
+                    restored_state = AgentState(**agent_state_data)
+                    resume_from_checkpoint = True
+
+                    console.print()
+                    resume_text = Text()
+                    resume_text.append("✓ ", style="bold green")
+                    resume_text.append("Resuming from checkpoint at iteration ", style="green")
+                    resume_text.append(
+                        f"{restored_state.iteration}/{restored_state.max_iterations}",
+                        style="bold green",
+                    )
+                    console.print(resume_text)
+                    console.print()
+
+                except ValidationError as e:
+                    warn_text = Text()
+                    warn_text.append("⚠ ", style="bold yellow")
+                    warn_text.append(
+                        f"Checkpoint validation failed: {e}. Starting fresh scan.", style="yellow"
+                    )
+                    console.print()
+                    console.print(warn_text)
+                    console.print()
+        else:
+            warn_text = Text()
+            warn_text.append("⚠ ", style="bold yellow")
+            warn_text.append(
+                "--resume flag provided but no valid checkpoint found. Starting fresh scan.",
+                style="yellow",
+            )
+            console.print()
+            console.print(warn_text)
+            console.print()
+
+    # Add restored state to agent config if resuming
+    if resume_from_checkpoint and restored_state:
+        agent_config["state"] = restored_state
+
     tracer = Tracer(args.run_name)
     tracer.set_scan_config(scan_config)
 
diff --git a/strix/interface/main.py b/strix/interface/main.py
index 1da6e54b..a6ad0c21 100644
--- a/strix/interface/main.py
+++ b/strix/interface/main.py
@@ -343,6 +343,16 @@ def parse_arguments() -> argparse.Namespace:
         ),
     )
 
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help=(
+            "Resume an interrupted scan from checkpoint. "
+            "Requires --run-name to match the interrupted scan. "
+            "If no valid checkpoint is found, starts a fresh scan."
+        ),
+    )
+
     args = parser.parse_args()
 
     if args.instruction and args.instruction_file:
diff --git a/strix/interface/tui.py b/strix/interface/tui.py
index 69f0fc9d..b31c750e 100644
--- a/strix/interface/tui.py
+++ b/strix/interface/tui.py
@@ -81,12 +81,15 @@ class SplashScreen(Static):  # type: ignore[misc]
         " ╚══════╝   ╚═╝   ╚═╝  ╚═╝╚═╝╚═╝  ╚═╝"
     )
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self, *args: Any, resume_info: dict[str, Any] | None = None, **kwargs: Any
+    ) -> None:
         super().__init__(*args, **kwargs)
         self._animation_step = 0
         self._animation_timer: Timer | None = None
         self._panel_static: Static | None = None
         self._version = "dev"
+        self._resume_info = resume_info
 
     def compose(self) -> ComposeResult:
         self._version = get_package_version()
@@ -116,15 +119,22 @@ def _animate_start_line(self) -> None:
         self._panel_static.update(panel)
 
     def _build_panel(self, start_line: Text) -> Panel:
-        content = Group(
+        content_parts = [
             Align.center(Text(self.BANNER.strip("\n"), style=self.PRIMARY_GREEN, justify="center")),
             Align.center(Text(" ")),
             Align.center(self._build_welcome_text()),
             Align.center(self._build_version_text()),
             Align.center(self._build_tagline_text()),
-            Align.center(Text(" ")),
-            Align.center(start_line.copy()),
-        )
+        ]
+
+        if self._resume_info:
+            content_parts.append(Align.center(Text(" ")))
+            content_parts.append(Align.center(self._build_resume_text()))
+
+        content_parts.append(Align.center(Text(" ")))
+        content_parts.append(Align.center(start_line.copy()))
+
+        content = Group(*content_parts)
 
         return Panel.fit(content, border_style=self.PRIMARY_GREEN, padding=(1, 6))
 
@@ -140,6 +150,17 @@ def _build_version_text(self) -> Text:
     def _build_tagline_text(self) -> Text:
         return Text("Open-source AI hackers for your apps", style=Style(color="white", dim=True))
 
+    def _build_resume_text(self) -> Text:
+        if not self._resume_info:
+            return Text("")
+
+        text = Text("✓ Resuming from iteration ", style=Style(color="#fbbf24", bold=True))
+        text.append(
+            f"{self._resume_info['iteration']}/{self._resume_info['max_iterations']}",
+            style=Style(color=self.PRIMARY_GREEN, bold=True),
+        )
+        return text
+
     def _build_start_line_text(self, phase: int) -> Text:
         emphasize = phase % 2 == 1
         base_style = Style(color="white", dim=not emphasize, bold=emphasize)
@@ -280,6 +301,51 @@ def __init__(self, args: argparse.Namespace):
         self.scan_config = self._build_scan_config(args)
         self.agent_config = self._build_agent_config(args)
 
+        # Check for resume from checkpoint
+        from pathlib import Path
+
+        from pydantic import ValidationError
+
+        from strix.agents.state import AgentState
+        from strix.telemetry.checkpoint import can_resume, load_checkpoint
+
+        self.resume_info: dict[str, Any] | None = None
+
+        if getattr(args, "resume", False):
+            run_dir = Path.cwd() / "strix_runs" / args.run_name
+
+            if run_dir.exists() and can_resume(run_dir, self.scan_config):
+                checkpoint = load_checkpoint(run_dir)
+
+                if checkpoint:
+                    try:
+                        agent_state_data = checkpoint["agent_state"]
+                        restored_state = AgentState(**agent_state_data)
+                        self.agent_config["state"] = restored_state
+
+                        self.resume_info = {
+                            "iteration": restored_state.iteration,
+                            "max_iterations": restored_state.max_iterations,
+                        }
+
+                        import logging
+
+                        logging.info(
+                            f"Resuming from checkpoint at iteration "
+                            f"{restored_state.iteration}/{restored_state.max_iterations}"
+                        )
+
+                    except ValidationError as e:
+                        import logging
+
+                        logging.warning(f"Checkpoint validation failed: {e}. Starting fresh scan.")
+            else:
+                import logging
+
+                logging.warning(
+                    "--resume flag provided but no valid checkpoint found. Starting fresh scan."
+                )
+
         self.tracer = Tracer(self.scan_config["run_name"])
         self.tracer.set_scan_config(self.scan_config)
         set_global_tracer(self.tracer)
@@ -348,7 +414,7 @@ def signal_handler(_signum: int, _frame: Any) -> None:
 
     def compose(self) -> ComposeResult:
         if self.show_splash:
-            yield SplashScreen(id="splash_screen")
+            yield SplashScreen(id="splash_screen", resume_info=self.resume_info)
 
     def watch_show_splash(self, show_splash: bool) -> None:
         if not show_splash and self.is_mounted:
diff --git a/strix/telemetry/checkpoint.py b/strix/telemetry/checkpoint.py
new file mode 100644
index 00000000..a6b1e70c
--- /dev/null
+++ b/strix/telemetry/checkpoint.py
@@ -0,0 +1,209 @@
+"""
+Checkpoint management for scan resumption.
+
+This module provides functionality to save and restore agent state,
+enabling scans to resume after interruption.
+"""
+
+import contextlib
+import json
+import logging
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import Any
+
+
+logger = logging.getLogger(__name__)
+
+# Checkpoint format version - increment when schema changes
+CHECKPOINT_VERSION = 1
+
+
+class CheckpointError(Exception):
+    """Raised when checkpoint operations fail."""
+
+
+def get_checkpoint_path(run_dir: Path) -> Path:
+    """
+    Get path to checkpoint file for a run.
+
+    Args:
+        run_dir: Run directory (e.g., strix_runs/my-scan)
+
+    Returns:
+        Path to checkpoint.json
+    """
+    return run_dir / "checkpoint.json"
+
+
+def save_checkpoint(
+    run_dir: Path,
+    agent_state: Any,  # AgentState from state.py
+    scan_config: dict[str, Any],
+    tracer_data: dict[str, Any] | None = None,
+) -> None:
+    """
+    Save checkpoint for resumption.
+
+    Args:
+        run_dir: Run directory
+        agent_state: Current AgentState instance
+        scan_config: Scan configuration dict
+        tracer_data: Optional tracer metadata
+
+    Note:
+        - Never raises on I/O or serialization failure (logs the error)
+        - Writes to a temp file and renames it into place, so a failed
+          write never clobbers the previous checkpoint
+    """
+    temp_path: Path | None = None
+    try:
+        checkpoint_path = get_checkpoint_path(run_dir)
+        temp_path = checkpoint_path.with_suffix(".tmp")
+
+        # Build checkpoint data
+        checkpoint = {
+            "version": CHECKPOINT_VERSION,
+            "created_at": datetime.now(UTC).isoformat(),
+            "scan_config": scan_config,
+            # Pydantic serialization in JSON mode
+            "agent_state": agent_state.model_dump(mode="json"),
+            "tracer_data": tracer_data or {},
+        }
+
+        # Write atomically: write to temp file, then rename
+        with temp_path.open("w", encoding="utf-8") as f:
+            json.dump(checkpoint, f, indent=2)
+
+        # Atomic rename (overwrites existing checkpoint)
+        temp_path.replace(checkpoint_path)
+
+        logger.info("Saved checkpoint at iteration %s", agent_state.iteration)
+
+    except (OSError, RuntimeError, TypeError, ValueError):
+        logger.exception("Failed to save checkpoint")
+        # Don't leave a half-written temp file next to the real checkpoint
+        if temp_path is not None:
+            with contextlib.suppress(OSError):
+                temp_path.unlink(missing_ok=True)
+
+
+def load_checkpoint(run_dir: Path) -> dict[str, Any] | None:
+    """
+    Load checkpoint for resumption.
+
+    Args:
+        run_dir: Run directory
+
+    Returns:
+        Checkpoint dict if valid, None if missing/invalid
+
+    Note:
+        - Returns None on any read/parse error (graceful degradation)
+        - Validates checkpoint version and structure
+    """
+    try:
+        checkpoint_path = get_checkpoint_path(run_dir)
+
+        if not checkpoint_path.exists():
+            logger.debug("No checkpoint found at %s", checkpoint_path)
+            return None
+
+        with checkpoint_path.open("r", encoding="utf-8") as f:
+            checkpoint = json.load(f)
+
+    except (OSError, ValueError, TypeError):
+        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
+        logger.exception("Failed to load checkpoint")
+        return None
+
+    # Valid JSON is not necessarily a JSON object
+    if not isinstance(checkpoint, dict):
+        logger.warning("Checkpoint is not a JSON object. Cannot resume.")
+        return None
+
+    # Validate version
+    version = checkpoint.get("version")
+    if version != CHECKPOINT_VERSION:
+        logger.warning(
+            "Checkpoint version mismatch: expected %s, got %s. Cannot resume.",
+            CHECKPOINT_VERSION,
+            version,
+        )
+        return None
+
+    # Validate required fields; callers treat both as mappings
+    required = ("scan_config", "agent_state")
+    invalid = [field for field in required if not isinstance(checkpoint.get(field), dict)]
+    if invalid:
+        logger.warning("Checkpoint missing or malformed fields: %s. Cannot resume.", invalid)
+        return None
+
+    logger.info("Loaded valid checkpoint")
+    return checkpoint
+
+
+def _normalize_targets(targets: Any) -> list[str]:
+    """
+    Build an order-independent, comparable form of a target list.
+
+    Each target is rendered as canonical JSON so that a list restored from a
+    checkpoint compares equal to the equivalent in-memory list.
+    """
+    if not isinstance(targets, (list, tuple)):
+        return []
+    return sorted(json.dumps(target, sort_keys=True, default=str) for target in targets)
+
+
+def can_resume(run_dir: Path, current_scan_config: dict[str, Any]) -> bool:
+    """
+    Check if scan can be resumed from checkpoint.
+
+    Args:
+        run_dir: Run directory
+        current_scan_config: Configuration for current scan attempt
+
+    Returns:
+        True if checkpoint exists and is compatible
+    """
+    checkpoint = load_checkpoint(run_dir)
+    if not checkpoint:
+        return False
+
+    # Check if scan already completed
+    if checkpoint["agent_state"].get("completed", False):
+        logger.info("Checkpoint found but scan already completed")
+        return False
+
+    # Validate target compatibility: the targets themselves must match, not
+    # just their count, or saved state would be replayed against another target
+    saved_targets = _normalize_targets(checkpoint["scan_config"].get("targets", []))
+    current_targets = _normalize_targets(current_scan_config.get("targets", []))
+
+    if saved_targets != current_targets:
+        logger.warning(
+            "Target mismatch: checkpoint was created for different targets "
+            "than the current scan. Cannot resume."
+        )
+        return False
+
+    return True
+
+
+def delete_checkpoint(run_dir: Path) -> None:
+    """
+    Delete checkpoint file.
+
+    Args:
+        run_dir: Run directory
+
+    Note:
+        - Silently succeeds if checkpoint doesn't exist
+    """
+    try:
+        checkpoint_path = get_checkpoint_path(run_dir)
+        if checkpoint_path.exists():
+            checkpoint_path.unlink()
+            logger.info("Deleted checkpoint")
+    except OSError as e:
+        logger.warning("Failed to delete checkpoint: %s", e)
diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py
index 6da30d53..1ce8a607 100644
--- a/strix/telemetry/tracer.py
+++ b/strix/telemetry/tracer.py
@@ -208,6 +208,19 @@ def save_run_data(self, mark_complete: bool = False) -> None:
             if mark_complete:
                 self.end_time = datetime.now(UTC).isoformat()
 
+                # Only delete checkpoint if scan actually completed successfully
+                # (not just cleanup on exit/interrupt)
+                if self.scan_results and self.scan_results.get("scan_completed"):
+                    try:
+                        from strix.telemetry.checkpoint import delete_checkpoint
+
+                        delete_checkpoint(run_dir)
+                        logger.info("Deleted checkpoint after successful scan completion")
+                    except Exception:  # noqa: BLE001
+                        logger.debug(
+                            "Checkpoint cleanup failed (non-fatal)", exc_info=True
+                        )  # Checkpoint cleanup failure is not critical
+
             if self.final_scan_result:
                 penetration_test_report_file = run_dir / "penetration_test_report.md"
                 with penetration_test_report_file.open("w", encoding="utf-8") as f: