diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index f955892f..f7a7e4c3 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -22,6 +22,96 @@ from .state import AgentState +def _log_live_agent_created( + agent_id: str, + agent_name: str, + task: str, + parent_id: str | None, + agent_type: str | None, +) -> None: + """Log agent creation to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_agent_created( + agent_id=agent_id, + agent_name=agent_name, + task=task, + parent_id=parent_id, + agent_type=agent_type, + ) + except Exception: # noqa: BLE001, S110 + pass + + +def _log_live_agent_completed( + agent_id: str, + status: str, + result: dict[str, Any] | None = None, + error_message: str | None = None, +) -> None: + """Log agent completion to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_agent_completed( + agent_id=agent_id, + status=status, + result=result, + error_message=error_message, + ) + except Exception: # noqa: BLE001, S110 + pass + + +def _log_live_state_change( + agent_id: str, + field: str, + old_value: Any, + new_value: Any, +) -> None: + """Log agent state change to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_agent_state_change( + agent_id=agent_id, + field=field, + old_value=old_value, + new_value=new_value, + ) + except Exception: # noqa: BLE001, S110 + pass + + +def _log_live_message( + agent_id: str | None, + role: str, + content: str, + metadata: dict[str, Any] | None = None, +) -> None: + """Log a chat message to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_message( + agent_id=agent_id, + role=role, + content=content if isinstance(content, str) else str(content), + metadata=metadata, + ) + except Exception: # noqa: BLE001, S110 + pass + + logger = logging.getLogger(__name__) @@ -112,6 +202,15 @@ def __init__(self, config: dict[str, Any]): ) tracer.update_tool_execution(execution_id=exec_id, status="completed", result={}) + # Log to live tracer + _log_live_agent_created( + agent_id=self.state.agent_id, + agent_name=self.state.agent_name, + task=self.state.task, + parent_id=self.state.parent_id, + agent_type=self.__class__.__name__, + ) + self._add_to_agents_graph() def _add_to_agents_graph(self) -> None: @@ -284,17 +383,38 @@ async def _enter_waiting_state( error_occurred: bool = False, was_cancelled: bool = False, ) -> None: + old_waiting = self.state.waiting_for_input self.state.enter_waiting_state() + # Determine status for logging + if task_completed: + status = "completed" + elif error_occurred: + status = "error" + elif was_cancelled: + status = "stopped" + else: + status = "stopped" + if tracer: - if task_completed: - tracer.update_agent_status(self.state.agent_id, "completed") - elif error_occurred: - tracer.update_agent_status(self.state.agent_id, "error") - elif was_cancelled: - tracer.update_agent_status(self.state.agent_id, "stopped") - else: - tracer.update_agent_status(self.state.agent_id, "stopped") + tracer.update_agent_status(self.state.agent_id, status) + + # Log state change to live tracer + _log_live_state_change( + agent_id=self.state.agent_id, + field="waiting_for_input", + old_value=old_waiting, + 
new_value=True, + ) + + # Log completion to live tracer if task completed or error + if task_completed or error_occurred: + _log_live_agent_completed( + agent_id=self.state.agent_id, + status=status, + result=self.state.final_result, + error_message=self.state.errors[-1] if self.state.errors else None, + ) if task_completed: self.state.add_message( @@ -344,6 +464,13 @@ async def _initialize_sandbox_and_state(self, task: str) -> None: self.state.add_message("user", task) + # Log user message to live tracer + _log_live_message( + agent_id=self.state.agent_id, + role="user", + content=task, + ) + async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: final_response = None @@ -381,6 +508,14 @@ async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: agent_id=self.state.agent_id, ) + # Log message to live tracer + _log_live_message( + agent_id=self.state.agent_id, + role="assistant", + content=final_response.content or "", + metadata={"has_thinking": bool(thinking_blocks)}, + ) + actions = ( final_response.tool_invocations if hasattr(final_response, "tool_invocations") and final_response.tool_invocations @@ -418,6 +553,14 @@ async def _execute_actions(self, actions: list[Any], tracer: Optional["Tracer"]) self.state.set_completed({"success": True}) if tracer: tracer.update_agent_status(self.state.agent_id, "completed") + + # Log agent completion to live tracer + _log_live_agent_completed( + agent_id=self.state.agent_id, + status="completed", + result={"success": True}, + ) + if self.non_interactive and self.state.parent_id is None: return True return True diff --git a/strix/config/config.py b/strix/config/config.py index 387834be..2f1c0455 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -45,6 +45,11 @@ class Config: # Telemetry strix_telemetry = "1" + # Live Tracing + strix_trace = None # Enable live tracing (set to "1" or "true") + strix_trace_output = None # Custom trace output path (defaults to strix_runs//trace.jsonl) + strix_redact_secrets = None # Redact secrets in trace output (set to "1" or "true") + # Config file override (set via --config CLI arg) _config_file_override: Path | None = None diff --git a/strix/interface/cli.py b/strix/interface/cli.py index 4b5d109f..54d2c30f 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -107,12 +107,29 @@ def display_vulnerability(report: dict[str, Any]) -> None: def cleanup_on_exit() -> None: from strix.runtime import cleanup_runtime + from strix.telemetry.live_tracer import get_live_tracer, set_live_tracer tracer.cleanup() + + # Clean up live tracer + live_tracer = get_live_tracer() + if live_tracer: + live_tracer.close() + set_live_tracer(None) + cleanup_runtime() def signal_handler(_signum: int, _frame: Any) -> None: + from strix.telemetry.live_tracer import get_live_tracer, set_live_tracer + tracer.cleanup() + + # Clean up live tracer + live_tracer = get_live_tracer() + if live_tracer: + live_tracer.close() + set_live_tracer(None) + sys.exit(1) atexit.register(cleanup_on_exit) diff --git a/strix/interface/main.py b/strix/interface/main.py index edd7dd5f..e4e4ef4c 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -40,6 +40,7 @@ ) from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402 from strix.telemetry import posthog # noqa: E402 +from strix.telemetry.live_tracer import LiveTracer, set_live_tracer # noqa: E402 from strix.telemetry.tracer import get_global_tracer # noqa: E402 @@ -363,6 +364,45 @@ def parse_arguments() -> 
argparse.Namespace: help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json", ) + parser.add_argument( + "--trace", + action="store_true", + help=( + "Enable live tracing mode. Creates a complete JSONL audit trail of the run " + "including LLM requests/responses, tool calls, and agent events. " + "Output defaults to strix_runs//trace.jsonl. " + "Can also be enabled via STRIX_TRACE=1 environment variable." + ), + ) + + parser.add_argument( + "--trace-output", + type=str, + help=( + "Custom path for trace output file (requires --trace). " + "Defaults to strix_runs//trace.jsonl." + ), + ) + + parser.add_argument( + "--redact-secrets", + action="store_true", + help=( + "Redact sensitive information (API keys, tokens, passwords) in trace output. " + "Recommended when sharing traces externally. " + "Can also be enabled via STRIX_REDACT_SECRETS=1 environment variable." + ), + ) + + parser.add_argument( + "--trace-verbose", + action="store_true", + help=( + "Output human-readable trace to console (requires --non-interactive). " + "Shows tool calls, agent actions, and LLM activity in real-time." + ), + ) + args = parser.parse_args() if args.instruction and args.instruction_file: @@ -370,6 +410,25 @@ def parse_arguments() -> argparse.Namespace: "Cannot specify both --instruction and --instruction-file. Use one or the other." ) + if args.trace_output and not args.trace: + parser.error("--trace-output requires --trace to be enabled.") + + if args.trace_verbose and not args.non_interactive: + parser.error("--trace-verbose requires --non-interactive mode.") + + # Check environment variables for tracing options + trace_env = Config.get("strix_trace") + if trace_env and trace_env.lower() in ("1", "true", "yes"): + args.trace = True + + redact_env = Config.get("strix_redact_secrets") + if redact_env and redact_env.lower() in ("1", "true", "yes"): + args.redact_secrets = True + + # Check environment variable for trace output path + if not args.trace_output: + args.trace_output = Config.get("strix_trace_output") + if args.instruction_file: instruction_path = Path(args.instruction_file) try: @@ -555,6 +614,25 @@ def main() -> None: has_instructions=bool(args.instruction), ) + # Initialize live tracer if enabled + live_tracer: LiveTracer | None = None + trace_verbose = getattr(args, "trace_verbose", False) + if getattr(args, "trace", False) or trace_verbose: + live_tracer = LiveTracer( + output_path=getattr(args, "trace_output", None), + run_name=args.run_name, + redact_secrets=getattr(args, "redact_secrets", False), + verbose=trace_verbose, + ) + set_live_tracer(live_tracer) + + console = Console() + if not trace_verbose: + console.print(f"[dim]Live trace enabled:[/] {live_tracer.output_path}") + if getattr(args, "redact_secrets", False): + console.print("[dim]Secret redaction:[/] enabled") + console.print() + exit_reason = "user_exit" try: if args.non_interactive: @@ -572,6 +650,11 @@ def main() -> None: if tracer: posthog.end(tracer, exit_reason=exit_reason) + # Close live tracer + if live_tracer: + live_tracer.close() + set_live_tracer(None) + results_path = Path("strix_runs") / args.run_name display_completion_message(args, results_path) diff --git a/strix/interface/tui.py b/strix/interface/tui.py index 102693b8..a2fb643b 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -743,12 +743,29 @@ def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: def _setup_cleanup_handlers(self) -> None: def cleanup_on_exit() -> None: from strix.runtime 
import cleanup_runtime + from strix.telemetry.live_tracer import get_live_tracer, set_live_tracer self.tracer.cleanup() + + # Clean up live tracer + live_tracer = get_live_tracer() + if live_tracer: + live_tracer.close() + set_live_tracer(None) + cleanup_runtime() def signal_handler(_signum: int, _frame: Any) -> None: + from strix.telemetry.live_tracer import get_live_tracer, set_live_tracer + self.tracer.cleanup() + + # Clean up live tracer + live_tracer = get_live_tracer() + if live_tracer: + live_tracer.close() + set_live_tracer(None) + sys.exit(0) atexit.register(cleanup_on_exit) diff --git a/strix/llm/llm.py b/strix/llm/llm.py index 311de35e..c99f8017 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,4 +1,5 @@ import asyncio +import time from collections.abc import AsyncIterator from dataclasses import dataclass from typing import Any @@ -129,6 +130,10 @@ async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResp accumulated = "" chunks: list[Any] = [] done_streaming = 0 + start_time = time.perf_counter() + + # Log LLM request to live tracer + self._log_llm_request(messages) self._total_stats.requests += 1 response = await acompletion(**self._build_completion_args(messages), stream=True) @@ -155,11 +160,18 @@ async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResp if chunks: self._update_usage_stats(stream_chunk_builder(chunks)) + duration_ms = (time.perf_counter() - start_time) * 1000 accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated)) + tool_invocations = parse_tool_invocations(accumulated) + thinking_blocks = self._extract_thinking(chunks) + + # Log LLM response to live tracer + self._log_llm_response(accumulated, tool_invocations, thinking_blocks, duration_ms) + yield LLMResponse( content=accumulated, - tool_invocations=parse_tool_invocations(accumulated), - thinking_blocks=self._extract_thinking(chunks), + tool_invocations=tool_invocations, + thinking_blocks=thinking_blocks, ) def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -271,8 +283,80 @@ def _raise_error(self, e: Exception) -> None: from strix.telemetry import posthog posthog.error("llm_error", type(e).__name__) + + # Log error to live tracer + self._log_llm_error(e) + raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e + def _log_llm_request(self, messages: list[dict[str, Any]]) -> None: + """Log LLM request to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_llm_request( + agent_id=self.agent_id, + model=self.config.model_name, + messages=messages, + metadata={ + "agent_name": self.agent_name, + "reasoning_effort": self._reasoning_effort, + }, + ) + except Exception: # noqa: BLE001, S110 + pass + + def _log_llm_response( + self, + content: str, + tool_invocations: list[dict[str, Any]] | None, + thinking_blocks: list[dict[str, Any]] | None, + duration_ms: float, + ) -> None: + """Log LLM response to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + usage = None + if self._total_stats: + usage = { + "input_tokens": self._total_stats.input_tokens, + "output_tokens": self._total_stats.output_tokens, + "cached_tokens": self._total_stats.cached_tokens, + } + + tracer.log_llm_response( + agent_id=self.agent_id, + content=content, + usage=usage, + 
tool_invocations=tool_invocations, + thinking_blocks=thinking_blocks, + duration_ms=duration_ms, + ) + except Exception: # noqa: BLE001, S110 + pass + + def _log_llm_error(self, error: Exception) -> None: + """Log LLM error to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + retryable = self._should_retry(error) + tracer.log_llm_error( + agent_id=self.agent_id, + error_type=type(error).__name__, + error_message=str(error), + retryable=retryable, + ) + except Exception: # noqa: BLE001, S110 + pass + def _is_anthropic(self) -> bool: if not self.config.model_name: return False diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index b783dccc..f71c8ab2 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -78,6 +78,26 @@ def _recover_container_state(self, container: Container) -> None: if port_bindings.get(port_key): self._tool_server_port = int(port_bindings[port_key][0]["HostPort"]) + def _start_tool_server(self, container: Container) -> None: + """Start the tool server inside the container.""" + try: + token = self._tool_server_token + port = CONTAINER_TOOL_SERVER_PORT + cmd = ( + f"cd /app && " + f"PYTHONPATH=/app STRIX_SANDBOX_MODE=true " + f"/app/venv/bin/python -m strix.runtime.tool_server " + f"--token={token} --host=0.0.0.0 --port={port}" + ) + container.exec_run( + ["/bin/bash", "-c", cmd], + detach=True, + user="pentester", + environment={"PYTHONUNBUFFERED": "1"}, + ) + except (DockerException, RequestsConnectionError, RequestsTimeout): + pass + def _wait_for_tool_server(self, max_retries: int = 30, timeout: int = 5) -> None: host = self._resolve_docker_host() health_url = f"http://{host}:{self._tool_server_port}/health" @@ -135,6 +155,7 @@ def _create_container(self, scan_id: str, max_retries: int = 2) -> Container: labels={"strix-scan-id": scan_id}, environment={ "PYTHONUNBUFFERED": "1", + "CAIDO_PORT": "48080", "TOOL_SERVER_PORT": str(CONTAINER_TOOL_SERVER_PORT), "TOOL_SERVER_TOKEN": self._tool_server_token, "STRIX_SANDBOX_EXECUTION_TIMEOUT": str(execution_timeout), @@ -145,6 +166,7 @@ def _create_container(self, scan_id: str, max_retries: int = 2) -> Container: ) self._scan_container = container + self._start_tool_server(container) self._wait_for_tool_server() except (DockerException, RequestsConnectionError, RequestsTimeout) as e: @@ -184,10 +206,11 @@ def _get_or_create_container(self, scan_id: str) -> Container: self._scan_container = container self._recover_container_state(container) + self._start_tool_server(container) + self._wait_for_tool_server() + return container except NotFound: pass - else: - return container try: containers = self.client.containers.list( @@ -201,6 +224,8 @@ def _get_or_create_container(self, scan_id: str) -> Container: self._scan_container = container self._recover_container_state(container) + self._start_tool_server(container) + self._wait_for_tool_server() return container except DockerException: pass diff --git a/strix/telemetry/__init__.py b/strix/telemetry/__init__.py index 0537f61f..ba931b60 100644 --- a/strix/telemetry/__init__.py +++ b/strix/telemetry/__init__.py @@ -1,10 +1,18 @@ from . 
import posthog +from .console_tracer import ConsoleTracer +from .live_tracer import LiveTracer, get_live_tracer, set_live_tracer +from .redactor import SecretRedactor from .tracer import Tracer, get_global_tracer, set_global_tracer __all__ = [ + "ConsoleTracer", + "LiveTracer", + "SecretRedactor", "Tracer", "get_global_tracer", + "get_live_tracer", "posthog", "set_global_tracer", + "set_live_tracer", ] diff --git a/strix/telemetry/console_tracer.py b/strix/telemetry/console_tracer.py new file mode 100644 index 00000000..8f191bec --- /dev/null +++ b/strix/telemetry/console_tracer.py @@ -0,0 +1,340 @@ +"""Console tracer for human-readable verbose output.""" + +from datetime import datetime +from typing import Any + +from rich.console import Console +from rich.text import Text + + +class ConsoleTracer: + """Outputs human-readable trace events to console in real-time.""" + + def __init__(self, console: Console | None = None): + self.console = console or Console() + + def _timestamp(self) -> str: + """Get current timestamp in HH:MM:SS format.""" + return datetime.now().strftime("%H:%M:%S") + + def _print_event(self, text: Text) -> None: + """Print an event with timestamp prefix.""" + output = Text() + output.append(f"[{self._timestamp()}] ", style="dim") + output.append_text(text) + self.console.print(output) + + def log_trace_start(self, run_id: str, target: str | None = None) -> None: + """Log trace start event.""" + text = Text() + text.append("▶ ", style="bold #22c55e") + text.append("Trace started", style="bold #22c55e") + if run_id: + text.append(f" ({run_id})", style="dim") + self._print_event(text) + + def log_trace_end(self, run_id: str) -> None: + """Log trace end event.""" + text = Text() + text.append("■ ", style="bold #6b7280") + text.append("Trace ended", style="#6b7280") + self._print_event(text) + + def log_llm_request( + self, + model: str, + message_count: int, + agent_id: str | None = None, + ) -> None: + """Log LLM request event.""" + text = Text() + text.append("🤖 ", style="") + text.append("LLM Request", style="bold #60a5fa") + text.append(" → ", style="dim") + text.append(model, style="#60a5fa") + text.append(f" ({message_count} messages)", style="dim") + if agent_id: + text.append(f" [{agent_id[:12]}]", style="dim italic") + self._print_event(text) + + def log_llm_response( + self, + model: str, + tokens: dict[str, int] | None = None, + agent_id: str | None = None, + ) -> None: + """Log LLM response event.""" + text = Text() + text.append("✓ ", style="#22c55e") + text.append("LLM Response", style="#22c55e") + if tokens: + input_tokens = tokens.get("input", 0) + output_tokens = tokens.get("output", 0) + text.append(f" ({input_tokens}→{output_tokens} tokens)", style="dim") + self._print_event(text) + + def log_llm_error(self, error: str, model: str | None = None) -> None: + """Log LLM error event.""" + text = Text() + text.append("✗ ", style="#ef4444") + text.append("LLM Error", style="bold #ef4444") + text.append(f": {error[:100]}", style="#ef4444") + self._print_event(text) + + def log_tool_call( + self, + tool_name: str, + args: dict[str, Any] | None = None, + agent_id: str | None = None, + ) -> None: + """Log tool call event.""" + text = Text() + + # Use different icons/styles based on tool type + if tool_name == "terminal_execute": + text.append(">_ ", style="dim") + command = (args or {}).get("command", "") + is_input = (args or {}).get("is_input", False) + if is_input: + text.append(">>> ", style="#3b82f6") + else: + text.append("$ ", style="#22c55e") + # Truncate 
long commands + if len(command) > 100: + command = command[:97] + "..." + text.append(command, style="bold") + + elif tool_name == "think": + text.append("🧠 ", style="") + text.append("Thinking", style="bold #a855f7") + thought = (args or {}).get("thought", "") + if thought: + # Truncate long thoughts + if len(thought) > 80: + thought = thought[:77] + "..." + text.append(f": {thought}", style="italic dim") + + elif tool_name == "create_agent": + text.append("◈ ", style="#a78bfa") + text.append("spawning ", style="dim") + name = (args or {}).get("name", "Agent") + text.append(name, style="bold #a78bfa") + task = (args or {}).get("task", "") + if task: + text.append("\n ") + if len(task) > 80: + task = task[:77] + "..." + text.append(task, style="dim") + + elif tool_name == "send_message_to_agent": + text.append("→ ", style="#60a5fa") + agent_target = (args or {}).get("agent_id", "") + if agent_target: + text.append(f"to {agent_target[:12]}", style="dim") + message = (args or {}).get("message", "") + if message: + text.append("\n ") + if len(message) > 80: + message = message[:77] + "..." + text.append(message, style="dim") + + elif tool_name == "wait_for_message": + text.append("○ ", style="#6b7280") + text.append("waiting", style="dim") + reason = (args or {}).get("reason", "") + if reason: + text.append(f": {reason[:60]}", style="dim italic") + + elif tool_name == "agent_finish": + success = (args or {}).get("success", True) + if success: + text.append("◆ ", style="#22c55e") + text.append("Agent completed", style="bold #22c55e") + else: + text.append("◆ ", style="#ef4444") + text.append("Agent failed", style="bold #ef4444") + summary = (args or {}).get("result_summary", "") + if summary: + text.append("\n ") + if len(summary) > 80: + summary = summary[:77] + "..." + text.append(summary, style="bold") + + elif tool_name == "browser_action": + text.append("🌐 ", style="") + action = (args or {}).get("action", "browse") + text.append(f"Browser: {action}", style="bold #f59e0b") + url = (args or {}).get("url", "") + if url: + if len(url) > 60: + url = url[:57] + "..." + text.append(f" → {url}", style="dim") + + elif tool_name == "send_request": + text.append("📡 ", style="") + method = (args or {}).get("method", "GET") + url = (args or {}).get("url", "") + text.append(f"{method}", style="bold #f59e0b") + if url: + if len(url) > 60: + url = url[:57] + "..." + text.append(f" {url}", style="dim") + + elif tool_name == "file_edit": + text.append("📝 ", style="") + text.append("Edit file", style="bold #22d3ee") + path = (args or {}).get("path", "") + if path: + text.append(f": {path}", style="dim") + + elif tool_name == "web_search": + text.append("🔍 ", style="") + text.append("Web search", style="bold #a855f7") + query = (args or {}).get("query", "") + if query: + if len(query) > 60: + query = query[:57] + "..." + text.append(f": {query}", style="dim italic") + + else: + # Generic tool format + text.append("⚡ ", style="dim") + text.append(tool_name, style="bold") + if args: + # Show first arg value as preview + for key, value in args.items(): + if isinstance(value, str) and value: + preview = value[:40] + "..." 
if len(value) > 40 else value + text.append(f" ({key}={preview})", style="dim") + break + + self._print_event(text) + + def log_tool_result( + self, + tool_name: str, + result: Any, + error: str | None = None, + agent_id: str | None = None, + ) -> None: + """Log tool result event.""" + # Only log errors or significant results + if error: + text = Text() + text.append(" ✗ ", style="#ef4444") + text.append("Error: ", style="#ef4444") + if len(error) > 80: + error = error[:77] + "..." + text.append(error, style="dim #ef4444") + self._print_event(text) + elif tool_name == "terminal_execute" and result: + # Show terminal output preview + content = "" + if isinstance(result, dict): + content = result.get("content", "") + elif isinstance(result, str): + content = result + if content: + lines = content.strip().split("\n") + if lines: + text = Text() + # Show first few lines of output + for i, line in enumerate(lines[:3]): + if len(line) > 80: + line = line[:77] + "..." + text.append(f" {line}\n", style="dim") + if len(lines) > 3: + text.append(f" ... ({len(lines) - 3} more lines)", style="dim italic") + self.console.print(text) + + def log_agent_created( + self, + agent_id: str, + agent_name: str, + task: str | None = None, + ) -> None: + """Log agent created event.""" + text = Text() + text.append("● ", style="#a78bfa") + text.append("Agent created: ", style="dim") + text.append(agent_name, style="bold #a78bfa") + text.append(f" [{agent_id[:12]}]", style="dim italic") + self._print_event(text) + + def log_agent_completed( + self, + agent_id: str, + agent_name: str, + success: bool = True, + ) -> None: + """Log agent completed event.""" + text = Text() + if success: + text.append("● ", style="#22c55e") + text.append("Agent finished: ", style="dim") + text.append(agent_name, style="#22c55e") + else: + text.append("● ", style="#ef4444") + text.append("Agent failed: ", style="dim") + text.append(agent_name, style="#ef4444") + self._print_event(text) + + def log_state_change( + self, + agent_id: str, + state: str, + details: str | None = None, + ) -> None: + """Log agent state change event.""" + # Only log significant state changes + if state in ("waiting", "error", "completed"): + text = Text() + text.append("◇ ", style="dim") + text.append(f"State: {state}", style="dim") + if details: + text.append(f" - {details[:50]}", style="dim italic") + self._print_event(text) + + def log_message( + self, + from_agent: str, + to_agent: str | None = None, + content: str | None = None, + ) -> None: + """Log inter-agent message event.""" + text = Text() + text.append("💬 ", style="") + text.append(f"{from_agent[:12]}", style="bold") + if to_agent: + text.append(" → ", style="dim") + text.append(f"{to_agent[:12]}", style="bold") + if content: + text.append("\n ") + if len(content) > 80: + content = content[:77] + "..." 
+ text.append(content, style="dim") + self._print_event(text) + + def log_vulnerability_found( + self, + vuln_id: str, + title: str, + severity: str | None = None, + ) -> None: + """Log vulnerability found event.""" + text = Text() + text.append("🚨 ", style="") + text.append("VULNERABILITY", style="bold #ef4444") + text.append(f" [{vuln_id}]", style="#ef4444") + if severity: + severity_colors = { + "critical": "#dc2626", + "high": "#ef4444", + "medium": "#f59e0b", + "low": "#22c55e", + "info": "#6b7280", + } + color = severity_colors.get(severity.lower(), "#6b7280") + text.append(f" ({severity})", style=f"bold {color}") + text.append(f"\n {title}", style="bold") + self._print_event(text) diff --git a/strix/telemetry/live_tracer.py b/strix/telemetry/live_tracer.py new file mode 100644 index 00000000..84321b4b --- /dev/null +++ b/strix/telemetry/live_tracer.py @@ -0,0 +1,563 @@ +""" +Live Tracer Module - Real-time JSONL audit trail for pentesting runs. + +Streams structured events to disk as they happen, capturing: +- LLM requests and responses +- Tool calls and results +- Agent lifecycle events +- State changes + +Enable via --trace flag or STRIX_TRACE=1 environment variable. +Use --trace-verbose for human-readable console output. +""" + +import json +import logging +import threading +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from strix.telemetry.console_tracer import ConsoleTracer +from strix.telemetry.redactor import SecretRedactor + + +if TYPE_CHECKING: + from io import TextIOWrapper + + +logger = logging.getLogger(__name__) + +_global_live_tracer: "LiveTracer | None" = None + + +def get_live_tracer() -> "LiveTracer | None": + """Get the global live tracer instance.""" + return _global_live_tracer + + +def set_live_tracer(tracer: "LiveTracer | None") -> None: + """Set the global live tracer instance.""" + global _global_live_tracer # noqa: PLW0603 + _global_live_tracer = tracer + + +class LiveTracer: + """ + Real-time JSONL tracer that streams events to disk. 
+ + Each event is a JSON object written as a single line with: + - timestamp: ISO 8601 timestamp + - trace_id: Unique ID for the entire run + - event_id: Unique ID for each event + - sequence: Monotonically increasing sequence number + - event_type: Type of event + - agent_id: Which agent this event belongs to (if applicable) + - data: Event-specific payload + """ + + def __init__( + self, + output_path: Path | str | None = None, + run_name: str | None = None, + redact_secrets: bool = False, + verbose: bool = False, + ): + self.trace_id = f"trace-{uuid4().hex[:12]}" + self.run_name = run_name or self.trace_id + self.start_time = datetime.now(UTC).isoformat() + self.redact_secrets = redact_secrets + self.verbose = verbose + + self._sequence = 0 + self._sequence_lock = threading.Lock() + self._file_lock = threading.Lock() + self._file: "TextIOWrapper | None" = None + self._closed = False + + # Initialize redactor + self._redactor = SecretRedactor() if redact_secrets else None + + # Initialize console tracer for verbose output + self._console_tracer = ConsoleTracer() if verbose else None + + # Determine output path + if output_path: + self._output_path = Path(output_path) + else: + runs_dir = Path.cwd() / "strix_runs" + runs_dir.mkdir(exist_ok=True) + run_dir = runs_dir / self.run_name + run_dir.mkdir(exist_ok=True) + self._output_path = run_dir / "trace.jsonl" + + # Open file for append + self._open_file() + + # Write trace start event + self._emit_event( + event_type="trace_start", + data={ + "run_name": self.run_name, + "redact_secrets": self.redact_secrets, + }, + ) + + # Console output + if self._console_tracer: + self._console_tracer.log_trace_start(self.run_name) + + def _open_file(self) -> None: + """Open the trace file for writing.""" + try: + self._output_path.parent.mkdir(parents=True, exist_ok=True) + self._file = self._output_path.open("a", encoding="utf-8") + logger.info(f"Live trace output: {self._output_path}") + except OSError as e: + logger.error(f"Failed to open trace file: {e}") + raise + + def _get_next_sequence(self) -> int: + """Get the next sequence number (thread-safe).""" + with self._sequence_lock: + self._sequence += 1 + return self._sequence + + def _emit_event( + self, + event_type: str, + data: dict[str, Any] | None = None, + agent_id: str | None = None, + ) -> str: + """ + Emit an event to the trace file. + + Returns the event_id. 
+        """
+        if self._closed or self._file is None:
+            return ""
+
+        event_id = f"evt-{uuid4().hex[:8]}"
+        sequence = self._get_next_sequence()
+
+        event = {
+            "timestamp": datetime.now(UTC).isoformat(),
+            "trace_id": self.trace_id,
+            "event_id": event_id,
+            "sequence": sequence,
+            "event_type": event_type,
+        }
+
+        if agent_id:
+            event["agent_id"] = agent_id
+
+        if data:
+            # Apply redaction if enabled
+            if self._redactor:
+                data = self._redactor.redact(data)
+            event["data"] = data
+
+        try:
+            with self._file_lock:
+                if self._file and not self._closed:
+                    self._file.write(json.dumps(event, default=str) + "\n")
+                    self._file.flush()
+        except OSError as e:
+            logger.error(f"Failed to write trace event: {e}")
+
+        return event_id
+
+    # -------------------------------------------------------------------------
+    # LLM Events
+    # -------------------------------------------------------------------------
+
+    def log_llm_request(
+        self,
+        agent_id: str | None,
+        model: str,
+        messages: list[dict[str, Any]],
+        metadata: dict[str, Any] | None = None,
+    ) -> str:
+        """Log an LLM request being sent."""
+        data: dict[str, Any] = {
+            "model": model,
+            "message_count": len(messages),
+            "messages": self._summarize_messages(messages),
+        }
+        if metadata:
+            data["metadata"] = metadata
+
+        # Console output
+        if self._console_tracer:
+            self._console_tracer.log_llm_request(model, len(messages), agent_id)
+
+        return self._emit_event(
+            event_type="llm_request",
+            agent_id=agent_id,
+            data=data,
+        )
+
+    def log_llm_response(
+        self,
+        agent_id: str | None,
+        content: str,
+        usage: dict[str, Any] | None = None,
+        tool_invocations: list[dict[str, Any]] | None = None,
+        thinking_blocks: list[dict[str, Any]] | None = None,
+        duration_ms: float | None = None,
+        model: str | None = None,
+    ) -> str:
+        """Log an LLM response received."""
+        data: dict[str, Any] = {
+            "content_length": len(content) if content else 0,
+            "content_preview": (content[:500] + "...") if content and len(content) > 500 else content,
+        }
+
+        if usage:
+            data["usage"] = usage
+        if tool_invocations:
+            data["tool_invocations"] = tool_invocations
+        if thinking_blocks:
+            data["has_thinking"] = True
+            data["thinking_count"] = len(thinking_blocks)
+        if duration_ms is not None:
+            data["duration_ms"] = round(duration_ms, 2)
+
+        # Console output
+        if self._console_tracer:
+            tokens = None
+            if usage:
+                tokens = {
+                    "input": usage.get("input_tokens", usage.get("prompt_tokens", 0)),
+                    "output": usage.get("output_tokens", usage.get("completion_tokens", 0)),
+                }
+            self._console_tracer.log_llm_response(model or "unknown", tokens, agent_id)
+
+        return self._emit_event(
+            event_type="llm_response",
+            agent_id=agent_id,
+            data=data,
+        )
+
+    def log_llm_error(
+        self,
+        agent_id: str | None,
+        error_type: str,
+        error_message: str,
+        retryable: bool = False,
+        model: str | None = None,
+    ) -> str:
+        """Log an LLM error."""
+        # Console output
+        if self._console_tracer:
+            self._console_tracer.log_llm_error(error_message, model)
+
+        return self._emit_event(
+            event_type="llm_error",
+            agent_id=agent_id,
+            data={
+                "error_type": error_type,
+                "error_message": error_message,
+                "retryable": retryable,
+            },
+        )
+
+    # -------------------------------------------------------------------------
+    # Tool Events
+    # -------------------------------------------------------------------------
+
+    def log_tool_call(
+        self,
+        agent_id: str,
+        tool_name: str,
+        args: dict[str, Any],
+        execution_id: int | None = None,
+    ) -> str:
+        """Log a tool being called."""
+        # Console output
+        if self._console_tracer:
+            self._console_tracer.log_tool_call(tool_name,
args, agent_id) + + return self._emit_event( + event_type="tool_call", + agent_id=agent_id, + data={ + "tool_name": tool_name, + "args": args, + "execution_id": execution_id, + }, + ) + + def log_tool_result( + self, + agent_id: str, + tool_name: str, + status: str, + result: Any, + execution_id: int | None = None, + duration_ms: float | None = None, + error: str | None = None, + ) -> str: + """Log a tool execution result.""" + # Summarize large results + result_summary = self._summarize_result(result) + + data: dict[str, Any] = { + "tool_name": tool_name, + "status": status, + "result": result_summary, + "execution_id": execution_id, + } + if duration_ms is not None: + data["duration_ms"] = round(duration_ms, 2) + + # Console output + if self._console_tracer: + self._console_tracer.log_tool_result(tool_name, result, error, agent_id) + + return self._emit_event( + event_type="tool_result", + agent_id=agent_id, + data=data, + ) + + # ------------------------------------------------------------------------- + # Agent Events + # ------------------------------------------------------------------------- + + def log_agent_created( + self, + agent_id: str, + agent_name: str, + task: str, + parent_id: str | None = None, + agent_type: str | None = None, + ) -> str: + """Log an agent being created.""" + # Console output + if self._console_tracer: + self._console_tracer.log_agent_created(agent_id, agent_name, task) + + return self._emit_event( + event_type="agent_created", + agent_id=agent_id, + data={ + "agent_name": agent_name, + "task": task, + "parent_id": parent_id, + "agent_type": agent_type, + }, + ) + + def log_agent_completed( + self, + agent_id: str, + status: str, + result: dict[str, Any] | None = None, + error_message: str | None = None, + agent_name: str | None = None, + ) -> str: + """Log an agent completing its task.""" + data: dict[str, Any] = {"status": status} + if result: + data["result"] = self._summarize_result(result) + if error_message: + data["error_message"] = error_message + + # Console output + if self._console_tracer: + success = status in ("completed", "success") + self._console_tracer.log_agent_completed( + agent_id, agent_name or "Agent", success + ) + + return self._emit_event( + event_type="agent_completed", + agent_id=agent_id, + data=data, + ) + + def log_agent_state_change( + self, + agent_id: str, + field: str, + old_value: Any, + new_value: Any, + ) -> str: + """Log a significant state change in an agent.""" + # Console output + if self._console_tracer: + self._console_tracer.log_state_change( + agent_id, str(new_value), str(field) + ) + + return self._emit_event( + event_type="state_change", + agent_id=agent_id, + data={ + "field": field, + "old_value": str(old_value) if old_value is not None else None, + "new_value": str(new_value) if new_value is not None else None, + }, + ) + + # ------------------------------------------------------------------------- + # Message Events + # ------------------------------------------------------------------------- + + def log_message( + self, + agent_id: str | None, + role: str, + content: str, + metadata: dict[str, Any] | None = None, + ) -> str: + """Log a chat message.""" + data: dict[str, Any] = { + "role": role, + "content_length": len(content) if content else 0, + "content_preview": (content[:1000] + "...") if content and len(content) > 1000 else content, + } + if metadata: + data["metadata"] = metadata + + return self._emit_event( + event_type="message", + agent_id=agent_id, + data=data, + ) + + # 
------------------------------------------------------------------------- + # Vulnerability Events + # ------------------------------------------------------------------------- + + def log_vulnerability_found( + self, + agent_id: str | None, + vuln_id: str, + title: str, + severity: str, + target: str | None = None, + ) -> str: + """Log a vulnerability being discovered.""" + # Console output + if self._console_tracer: + self._console_tracer.log_vulnerability_found(vuln_id, title, severity) + + return self._emit_event( + event_type="vulnerability_found", + agent_id=agent_id, + data={ + "vuln_id": vuln_id, + "title": title, + "severity": severity, + "target": target, + }, + ) + + # ------------------------------------------------------------------------- + # Helper Methods + # ------------------------------------------------------------------------- + + def _summarize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Create a summary of messages for logging.""" + summaries = [] + for msg in messages: + role = msg.get("role", "unknown") + content = msg.get("content", "") + + if isinstance(content, str): + content_len = len(content) + preview = (content[:200] + "...") if len(content) > 200 else content + elif isinstance(content, list): + # Handle multi-part content (text + images) + content_len = sum( + len(p.get("text", "")) if isinstance(p, dict) else 0 + for p in content + ) + preview = f"[{len(content)} parts]" + else: + content_len = 0 + preview = str(content)[:200] + + summaries.append({ + "role": role, + "content_length": content_len, + "preview": preview, + }) + + return summaries + + def _summarize_result(self, result: Any) -> Any: + """Summarize a result for logging (truncate large values).""" + if result is None: + return None + + if isinstance(result, str): + if len(result) > 2000: + return result[:1000] + f"\n... [truncated {len(result) - 2000} chars] ...\n" + result[-1000:] + return result + + if isinstance(result, dict): + # Remove screenshot data if present + result = dict(result) + if "screenshot" in result: + result["screenshot"] = "[screenshot data removed]" + + # Truncate large string values + for key, value in result.items(): + if isinstance(value, str) and len(value) > 500: + result[key] = value[:250] + f"... [truncated {len(value) - 500} chars] ..." + value[-250:] + + return result + + if isinstance(result, list) and len(result) > 50: + return result[:25] + [f"... 
[{len(result) - 50} items truncated] ..."] + result[-25:] + + return result + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + def close(self) -> None: + """Close the trace file.""" + if self._closed: + return + + # Console output + if self._console_tracer: + self._console_tracer.log_trace_end(self.run_name) + + # Write trace end event + self._emit_event( + event_type="trace_end", + data={ + "total_events": self._sequence, + "end_time": datetime.now(UTC).isoformat(), + }, + ) + + self._closed = True + + with self._file_lock: + if self._file: + try: + self._file.close() + except OSError: + pass + self._file = None + + logger.info(f"Live trace completed: {self._output_path} ({self._sequence} events)") + + @property + def output_path(self) -> Path: + """Get the trace output path.""" + return self._output_path + + def __enter__(self) -> "LiveTracer": + return self + + def __exit__(self, *_: Any) -> None: + self.close() diff --git a/strix/telemetry/redactor.py b/strix/telemetry/redactor.py new file mode 100644 index 00000000..8151e020 --- /dev/null +++ b/strix/telemetry/redactor.py @@ -0,0 +1,209 @@ +""" +Secret Redaction Utility for Live Tracing. + +Identifies and redacts sensitive information like: +- API keys and tokens +- Passwords and credentials +- Authorization headers +- Environment variable values containing secrets +""" + +import re +from typing import Any + + +class SecretRedactor: + """ + Redacts sensitive information from trace data. + + Identifies secrets through: + - Known key patterns (api_key, password, token, etc.) + - Value patterns (Bearer tokens, base64-looking strings, etc.) + - Environment variable patterns + """ + + # Patterns for keys that typically contain secrets + SECRET_KEY_PATTERNS = [ + re.compile(r"(?i)(api[_-]?key|apikey)"), + re.compile(r"(?i)(secret[_-]?key|secretkey)"), + re.compile(r"(?i)(access[_-]?token|accesstoken)"), + re.compile(r"(?i)(auth[_-]?token|authtoken)"), + re.compile(r"(?i)(bearer[_-]?token)"), + re.compile(r"(?i)(refresh[_-]?token)"), + re.compile(r"(?i)^password$"), + re.compile(r"(?i)^passwd$"), + re.compile(r"(?i)(private[_-]?key|privatekey)"), + re.compile(r"(?i)(client[_-]?secret|clientsecret)"), + re.compile(r"(?i)(aws[_-]?secret)"), + re.compile(r"(?i)(database[_-]?password|db[_-]?password)"), + re.compile(r"(?i)(encryption[_-]?key)"), + re.compile(r"(?i)(signing[_-]?key)"), + re.compile(r"(?i)(jwt[_-]?secret)"), + re.compile(r"(?i)(session[_-]?secret)"), + re.compile(r"(?i)(cookie[_-]?secret)"), + re.compile(r"(?i)^credentials?$"), + re.compile(r"(?i)(sandbox[_-]?token)"), + re.compile(r"(?i)(perplexity[_-]?api[_-]?key)"), + re.compile(r"(?i)(openai[_-]?api[_-]?key)"), + re.compile(r"(?i)(anthropic[_-]?api[_-]?key)"), + ] + + # Patterns for values that look like secrets + SECRET_VALUE_PATTERNS = [ + # Bearer tokens + re.compile(r"Bearer\s+[A-Za-z0-9\-_=]+\.?[A-Za-z0-9\-_=]*\.?[A-Za-z0-9\-_=]*"), + # API keys (common formats) + re.compile(r"sk-[A-Za-z0-9]{20,}"), # OpenAI style + re.compile(r"sk-ant-[A-Za-z0-9\-]{20,}"), # Anthropic style + re.compile(r"pplx-[A-Za-z0-9]{20,}"), # Perplexity style + re.compile(r"gsk_[A-Za-z0-9]{20,}"), # Groq style + # AWS keys + re.compile(r"AKIA[0-9A-Z]{16}"), + # Generic long alphanumeric strings (likely tokens) + re.compile(r"[A-Za-z0-9]{40,}"), + ] + + # Environment variables that contain secrets + SECRET_ENV_VARS = { + "LLM_API_KEY", + "OPENAI_API_KEY", + 
"ANTHROPIC_API_KEY", + "PERPLEXITY_API_KEY", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "GITHUB_TOKEN", + "DATABASE_PASSWORD", + "DB_PASSWORD", + "SECRET_KEY", + "JWT_SECRET", + "SESSION_SECRET", + } + + REDACTED = "[REDACTED]" + + def __init__(self, additional_patterns: list[re.Pattern[str]] | None = None): + """ + Initialize the redactor. + + Args: + additional_patterns: Additional regex patterns for secret keys + """ + self._key_patterns = list(self.SECRET_KEY_PATTERNS) + if additional_patterns: + self._key_patterns.extend(additional_patterns) + + def redact(self, data: Any) -> Any: + """ + Recursively redact secrets from data. + + Args: + data: Data to redact (dict, list, str, or primitive) + + Returns: + Data with secrets redacted + """ + if data is None: + return None + + if isinstance(data, dict): + return self._redact_dict(data) + + if isinstance(data, list): + return [self.redact(item) for item in data] + + if isinstance(data, str): + return self._redact_string(data) + + # Primitives (int, float, bool) pass through unchanged + return data + + def _redact_dict(self, data: dict[str, Any]) -> dict[str, Any]: + """Redact secrets from a dictionary.""" + result = {} + + for key, value in data.items(): + # Check if the key indicates a secret + if self._is_secret_key(key): + result[key] = self.REDACTED + else: + result[key] = self.redact(value) + + return result + + def _redact_string(self, value: str) -> str: + """Redact secrets from a string value.""" + if not value: + return value + + result = value + + # Check for Bearer tokens and API keys in the string + for pattern in self.SECRET_VALUE_PATTERNS: + result = pattern.sub(self.REDACTED, result) + + # Check for Authorization headers + result = re.sub( + r'(Authorization["\']?\s*[:=]\s*["\']?)(Bearer\s+)?[A-Za-z0-9\-_.=]+', + rf"\1{self.REDACTED}", + result, + flags=re.IGNORECASE, + ) + + # Redact env var values that look like secrets + for env_var in self.SECRET_ENV_VARS: + # Match patterns like: ENV_VAR=value or "ENV_VAR": "value" + result = re.sub( + rf'({env_var}["\']?\s*[:=]\s*["\']?)([^"\'\s,}}]+)', + rf"\1{self.REDACTED}", + result, + ) + + return result + + def _is_secret_key(self, key: str) -> bool: + """Check if a key name indicates it contains a secret.""" + if not key: + return False + + key_lower = key.lower() + + # Check against known secret env vars + if key.upper() in self.SECRET_ENV_VARS: + return True + + # Check against patterns + for pattern in self._key_patterns: + if pattern.search(key_lower): + return True + + return False + + def redact_headers(self, headers: dict[str, str]) -> dict[str, str]: + """ + Specifically redact HTTP headers. 
+ + Args: + headers: HTTP headers dict + + Returns: + Headers with sensitive values redacted + """ + sensitive_headers = { + "authorization", + "x-api-key", + "api-key", + "x-auth-token", + "cookie", + "set-cookie", + "x-csrf-token", + "x-access-token", + } + + result = {} + for key, value in headers.items(): + if key.lower() in sensitive_headers: + result[key] = self.REDACTED + else: + result[key] = value + + return result diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 25af62c8..e2a279f0 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -143,9 +143,35 @@ def add_vulnerability_report( # noqa: PLR0912 if self.vulnerability_found_callback: self.vulnerability_found_callback(report) + # Log to live tracer + self._log_vulnerability_to_live_tracer(report_id, title, severity, target) + self.save_run_data() return report_id + def _log_vulnerability_to_live_tracer( + self, + vuln_id: str, + title: str, + severity: str, + target: str | None, + ) -> None: + """Log vulnerability to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_vulnerability_found( + agent_id=None, # Vulnerability reports don't always have agent context + vuln_id=vuln_id, + title=title, + severity=severity, + target=target, + ) + except Exception: # noqa: BLE001, S110 + pass + def get_existing_vulnerabilities(self) -> list[dict[str, Any]]: return list(self.vulnerability_reports) diff --git a/strix/tools/executor.py b/strix/tools/executor.py index 1c240877..7429c8e7 100644 --- a/strix/tools/executor.py +++ b/strix/tools/executor.py @@ -1,5 +1,6 @@ import inspect import os +import time from typing import Any import httpx @@ -266,10 +267,14 @@ async def _execute_single_tool( args = tool_inv.get("args", {}) execution_id = None should_agent_finish = False + start_time = time.perf_counter() if tracer: execution_id = tracer.log_tool_execution_start(agent_id, tool_name, args) + # Log tool call to live tracer + _log_live_tool_call(agent_id, tool_name, args, execution_id) + try: result = await execute_tool_invocation(tool_inv, agent_state) @@ -287,10 +292,20 @@ async def _execute_single_tool( _update_tracer_with_result(tracer, execution_id, is_error, result, error_payload) + # Log tool result to live tracer + duration_ms = (time.perf_counter() - start_time) * 1000 + status = "error" if is_error else "completed" + _log_live_tool_result(agent_id, tool_name, status, result, execution_id, duration_ms) + except (ConnectionError, RuntimeError, ValueError, TypeError, OSError) as e: error_msg = str(e) if tracer and execution_id: tracer.update_tool_execution(execution_id, "error", error_msg) + + # Log tool error to live tracer + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_live_tool_result(agent_id, tool_name, "error", {"error": error_msg}, execution_id, duration_ms) + raise observation_xml, images = _format_tool_result(tool_name, result) @@ -362,3 +377,51 @@ def remove_screenshot_from_result(result: Any) -> Any: result_copy["screenshot"] = "[Image data extracted - see attached image]" return result_copy + + +def _log_live_tool_call( + agent_id: str, + tool_name: str, + args: dict[str, Any], + execution_id: int | None, +) -> None: + """Log tool call to live tracer if enabled.""" + try: + from strix.telemetry.live_tracer import get_live_tracer + + tracer = get_live_tracer() + if tracer: + tracer.log_tool_call( + agent_id=agent_id, + tool_name=tool_name, + args=args, + 
execution_id=execution_id,
+            )
+    except Exception:  # noqa: BLE001, S110
+        pass
+
+
+def _log_live_tool_result(
+    agent_id: str,
+    tool_name: str,
+    status: str,
+    result: Any,
+    execution_id: int | None,
+    duration_ms: float,
+) -> None:
+    """Log tool result to live tracer if enabled."""
+    try:
+        from strix.telemetry.live_tracer import get_live_tracer
+
+        tracer = get_live_tracer()
+        if tracer:
+            tracer.log_tool_result(
+                agent_id=agent_id,
+                tool_name=tool_name,
+                status=status,
+                result=result,
+                execution_id=execution_id,
+                duration_ms=duration_ms,
+            )
+    except Exception:  # noqa: BLE001, S110
+        pass
diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py
index 8f6aa4f8..21040cf4 100644
--- a/tests/telemetry/__init__.py
+++ b/tests/telemetry/__init__.py
@@ -1 +1 @@
-"""Tests for strix.telemetry module."""
+# Telemetry tests
diff --git a/tests/telemetry/test_live_tracer.py b/tests/telemetry/test_live_tracer.py
new file mode 100644
index 00000000..ff3b5c7f
--- /dev/null
+++ b/tests/telemetry/test_live_tracer.py
@@ -0,0 +1,444 @@
+"""Tests for the LiveTracer module."""
+
+import json
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from strix.telemetry.live_tracer import LiveTracer, get_live_tracer, set_live_tracer
+
+
+class TestLiveTracer:
+    """Tests for LiveTracer class."""
+
+    def test_initialization_creates_file(self, tmp_path: Path) -> None:
+        """Test that LiveTracer creates a trace file on initialization."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path, run_name="test-run")
+
+        assert trace_path.exists()
+        tracer.close()
+
+    def test_trace_start_event(self, tmp_path: Path) -> None:
+        """Test that trace_start event is written on initialization."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path, run_name="test-run")
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        assert len(lines) >= 2  # At least trace_start and trace_end
+
+        # Check trace_start event
+        start_event = json.loads(lines[0])
+        assert start_event["event_type"] == "trace_start"
+        assert start_event["data"]["run_name"] == "test-run"
+        assert "trace_id" in start_event
+        assert "timestamp" in start_event
+        assert "event_id" in start_event
+        assert start_event["sequence"] == 1
+
+    def test_trace_end_event(self, tmp_path: Path) -> None:
+        """Test that trace_end event is written on close."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path, run_name="test-run")
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        # Check trace_end event (last line)
+        end_event = json.loads(lines[-1])
+        assert end_event["event_type"] == "trace_end"
+        assert "total_events" in end_event["data"]
+
+    def test_log_llm_request(self, tmp_path: Path) -> None:
+        """Test logging an LLM request."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ]
+
+        event_id = tracer.log_llm_request(
+            agent_id="agent_123",
+            model="test-model",
+            messages=messages,
+            metadata={"test": True},
+        )
+        tracer.close()
+
+        assert event_id.startswith("evt-")
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        # Find the llm_request event
+        llm_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "llm_request":
+                llm_event = event
+                break
+
+        assert llm_event is not None
+        assert llm_event["agent_id"] == "agent_123"
+        assert llm_event["data"]["model"] == "test-model"
+        assert llm_event["data"]["message_count"] == 2
+
+    def test_log_llm_response(self, tmp_path: Path) -> None:
+        """Test logging an LLM response."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        event_id = tracer.log_llm_response(
+            agent_id="agent_123",
+            content="Hello! How can I help you?",
+            usage={"input_tokens": 100, "output_tokens": 50},
+            duration_ms=1234.56,
+        )
+        tracer.close()
+
+        assert event_id.startswith("evt-")
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        # Find the llm_response event
+        response_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "llm_response":
+                response_event = event
+                break
+
+        assert response_event is not None
+        assert response_event["agent_id"] == "agent_123"
+        assert response_event["data"]["content_length"] == 26
+        assert response_event["data"]["duration_ms"] == 1234.56
+
+    def test_log_tool_call(self, tmp_path: Path) -> None:
+        """Test logging a tool call."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_tool_call(
+            agent_id="agent_123",
+            tool_name="terminal",
+            args={"command": "ls -la"},
+            execution_id=42,
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        tool_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "tool_call":
+                tool_event = event
+                break
+
+        assert tool_event is not None
+        assert tool_event["data"]["tool_name"] == "terminal"
+        assert tool_event["data"]["args"]["command"] == "ls -la"
+        assert tool_event["data"]["execution_id"] == 42
+
+    def test_log_tool_result(self, tmp_path: Path) -> None:
+        """Test logging a tool result."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_tool_result(
+            agent_id="agent_123",
+            tool_name="terminal",
+            status="completed",
+            result={"output": "file1.txt\nfile2.txt"},
+            execution_id=42,
+            duration_ms=100.5,
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        result_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "tool_result":
+                result_event = event
+                break
+
+        assert result_event is not None
+        assert result_event["data"]["tool_name"] == "terminal"
+        assert result_event["data"]["status"] == "completed"
+        assert result_event["data"]["duration_ms"] == 100.5
+
+    def test_log_agent_created(self, tmp_path: Path) -> None:
+        """Test logging agent creation."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_agent_created(
+            agent_id="agent_123",
+            agent_name="Root Agent",
+            task="Perform security scan",
+            parent_id=None,
+            agent_type="StrixAgent",
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        agent_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "agent_created":
+                agent_event = event
+                break
+
+        assert agent_event is not None
+        assert agent_event["agent_id"] == "agent_123"
+        assert agent_event["data"]["agent_name"] == "Root Agent"
+        assert agent_event["data"]["task"] == "Perform security scan"
+        assert agent_event["data"]["agent_type"] == "StrixAgent"
+
+    def test_log_agent_completed(self, tmp_path: Path) -> None:
+        """Test logging agent completion."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_agent_completed(
+            agent_id="agent_123",
+            status="completed",
+            result={"success": True, "findings": 5},
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        completed_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "agent_completed":
+                completed_event = event
+                break
+
+        assert completed_event is not None
+        assert completed_event["data"]["status"] == "completed"
+        assert completed_event["data"]["result"]["success"] is True
+
+    def test_log_state_change(self, tmp_path: Path) -> None:
+        """Test logging state changes."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_agent_state_change(
+            agent_id="agent_123",
+            field="iteration",
+            old_value=5,
+            new_value=6,
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        state_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "state_change":
+                state_event = event
+                break
+
+        assert state_event is not None
+        assert state_event["data"]["field"] == "iteration"
+        assert state_event["data"]["old_value"] == "5"
+        assert state_event["data"]["new_value"] == "6"
+
+    def test_log_vulnerability_found(self, tmp_path: Path) -> None:
+        """Test logging vulnerability discovery."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        tracer.log_vulnerability_found(
+            agent_id="agent_123",
+            vuln_id="vuln-0001",
+            title="SQL Injection",
+            severity="high",
+            target="https://example.com/api",
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        vuln_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "vulnerability_found":
+                vuln_event = event
+                break
+
+        assert vuln_event is not None
+        assert vuln_event["data"]["vuln_id"] == "vuln-0001"
+        assert vuln_event["data"]["title"] == "SQL Injection"
+        assert vuln_event["data"]["severity"] == "high"
+
+    def test_sequence_numbers_increment(self, tmp_path: Path) -> None:
+        """Test that sequence numbers increment monotonically."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        # Generate multiple events
+        for i in range(5):
+            tracer.log_message(
+                agent_id="agent_123",
+                role="user",
+                content=f"Message {i}",
+            )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        sequences = [json.loads(line)["sequence"] for line in lines]
+        assert sequences == sorted(sequences)  # Must be monotonically increasing
+        assert len(set(sequences)) == len(sequences)  # All unique
+
+    def test_context_manager(self, tmp_path: Path) -> None:
+        """Test LiveTracer as context manager."""
+        trace_path = tmp_path / "trace.jsonl"
+
+        with LiveTracer(output_path=trace_path) as tracer:
+            tracer.log_message(
+                agent_id="agent_123",
+                role="user",
+                content="Test message",
+            )
+
+        # File should be properly closed after context exits
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        assert len(lines) >= 3  # trace_start, message, trace_end
+
+    def test_output_path_property(self, tmp_path: Path) -> None:
+        """Test that output_path property returns correct path."""
+        trace_path = tmp_path / "custom" / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        assert tracer.output_path == trace_path
+        tracer.close()
+
+    def test_global_tracer_functions(self, tmp_path: Path) -> None:
+        """Test get_live_tracer and set_live_tracer functions."""
+        # Initially should be None
+        original = get_live_tracer()
+
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        set_live_tracer(tracer)
+        assert get_live_tracer() is tracer
+
+        # Clean up
+        set_live_tracer(original)
+        tracer.close()
+
+    def test_large_content_truncation(self, tmp_path: Path) -> None:
+        """Test that large content is properly truncated."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path)
+
+        large_content = "x" * 10000  # 10KB of content
+
+        tracer.log_llm_response(
+            agent_id="agent_123",
+            content=large_content,
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        response_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "llm_response":
+                response_event = event
+                break
+
+        assert response_event is not None
+        # Content preview should be truncated
+        assert len(response_event["data"]["content_preview"]) < len(large_content)
+        assert "..." in response_event["data"]["content_preview"]
+        # But content_length should report actual size
+        assert response_event["data"]["content_length"] == 10000
+
+
+class TestLiveTracerWithRedaction:
+    """Tests for LiveTracer with secret redaction enabled."""
+
+    def test_redaction_enabled(self, tmp_path: Path) -> None:
+        """Test that redaction is applied when enabled."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path, redact_secrets=True)
+
+        tracer.log_tool_call(
+            agent_id="agent_123",
+            tool_name="http_request",
+            args={
+                "url": "https://api.example.com",
+                "headers": {"Authorization": "Bearer sk-secret-key-12345"},
+                "api_key": "sk-openai-secret-key-abcdef",
+            },
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        tool_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "tool_call":
+                tool_event = event
+                break
+
+        assert tool_event is not None
+        # API key should be redacted
+        assert tool_event["data"]["args"]["api_key"] == "[REDACTED]"
+
+    def test_redaction_disabled(self, tmp_path: Path) -> None:
+        """Test that redaction is not applied when disabled."""
+        trace_path = tmp_path / "trace.jsonl"
+        tracer = LiveTracer(output_path=trace_path, redact_secrets=False)
+
+        test_value = "not-really-secret"
+        tracer.log_tool_call(
+            agent_id="agent_123",
+            tool_name="test_tool",
+            args={"url": test_value},
+        )
+        tracer.close()
+
+        with trace_path.open() as f:
+            lines = f.readlines()
+
+        tool_event = None
+        for line in lines:
+            event = json.loads(line)
+            if event["event_type"] == "tool_call":
+                tool_event = event
+                break
+
+        assert tool_event is not None
+        assert tool_event["data"]["args"]["url"] == test_value
diff --git a/tests/telemetry/test_redactor.py b/tests/telemetry/test_redactor.py
new file mode 100644
index 00000000..46d97587
--- /dev/null
+++ b/tests/telemetry/test_redactor.py
@@ -0,0 +1,257 @@
+"""Tests for the SecretRedactor module."""
+
+import pytest
+
+from strix.telemetry.redactor import SecretRedactor
+
+
+class TestSecretRedactor:
+    """Tests for SecretRedactor class."""
+
+    def test_redact_api_key(self) -> None:
+        """Test redaction of API key fields."""
+        redactor = SecretRedactor()
+
+        data = {"api_key": "sk-secret-12345", "name": "test"}
+        result = redactor.redact(data)
+
+        assert result["api_key"] == "[REDACTED]"
+        assert result["name"] == "test"
+
+    def test_redact_password(self) -> None:
+        """Test redaction of password fields."""
+        redactor = SecretRedactor()
+
+        data = {"password": "supersecret", "username": "admin"}
+        result = redactor.redact(data)
+
+        assert result["password"] == "[REDACTED]"
+        assert result["username"] == "admin"
+
+    def test_redact_access_token(self) -> None:
+        """Test redaction of access token fields."""
+        redactor = SecretRedactor()
+
+        data = {
+            "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
+            "refresh_token": "refresh-token-value",
+            "data": "normal",
+        }
+        result = redactor.redact(data)
+
+        assert result["access_token"] == "[REDACTED]"
+        assert result["refresh_token"] == "[REDACTED]"
+        assert result["data"] == "normal"
+
+    def test_redact_nested_dict(self) -> None:
+        """Test redaction in nested dictionaries."""
+        redactor = SecretRedactor()
+
+        data = {
+            "config": {
+                "database": {"password": "dbpass123", "host": "localhost"},
+                "api": {"secret_key": "my-secret"},
+            },
+            "name": "app",
+        }
+        result = redactor.redact(data)
+
+        assert result["config"]["database"]["password"] == "[REDACTED]"
+        assert result["config"]["database"]["host"] == "localhost"
+        assert result["config"]["api"]["secret_key"] == "[REDACTED]"
+        assert result["name"] == "app"
+
+    def test_redact_list_of_dicts(self) -> None:
+        """Test redaction in lists of dictionaries."""
+        redactor = SecretRedactor()
+
+        data = [
+            {"api_key": "key1", "name": "service1"},
+            {"api_key": "key2", "name": "service2"},
+        ]
+        result = redactor.redact(data)
+
+        assert result[0]["api_key"] == "[REDACTED]"
+        assert result[0]["name"] == "service1"
+        assert result[1]["api_key"] == "[REDACTED]"
+        assert result[1]["name"] == "service2"
+
+    def test_redact_bearer_token_in_string(self) -> None:
+        """Test redaction of Bearer tokens in string values."""
+        redactor = SecretRedactor()
+
+        data = {
+            "headers": "Authorization: Bearer sk-very-long-secret-token-12345678901234567890",
+            "method": "GET",
+        }
+        result = redactor.redact(data)
+
+        assert "Bearer" not in result["headers"] or "[REDACTED]" in result["headers"]
+        assert result["method"] == "GET"
+
+    def test_redact_openai_api_key_pattern(self) -> None:
+        """Test redaction of OpenAI API key pattern."""
+        redactor = SecretRedactor()
+
+        value = "Using API key: sk-abcdefghijklmnopqrstuvwxyz1234567890"
+        result = redactor.redact(value)
+
+        assert "sk-" not in result or "[REDACTED]" in result
+
+    def test_redact_anthropic_api_key_pattern(self) -> None:
+        """Test redaction of Anthropic API key pattern."""
+        redactor = SecretRedactor()
+
+        value = "Anthropic key is sk-ant-api03-abcdefghijklmnopqrstuvwxyz"
+        result = redactor.redact(value)
+
+        assert "sk-ant-" not in result or "[REDACTED]" in result
+
+    def test_redact_aws_access_key_pattern(self) -> None:
+        """Test redaction of AWS access key pattern."""
+        redactor = SecretRedactor()
+
+        value = "AWS key: AKIAIOSFODNN7EXAMPLE"
+        result = redactor.redact(value)
+
+        assert "AKIAIOSFODNN7EXAMPLE" not in result
+
+    def test_redact_env_vars_in_string(self) -> None:
+        """Test redaction of environment variable values in strings."""
+        redactor = SecretRedactor()
+
+        value = 'export LLM_API_KEY="my-secret-key-value"'
+        result = redactor.redact(value)
+
+        assert "my-secret-key-value" not in result or "[REDACTED]" in result
+
+    def test_redact_headers_method(self) -> None:
+        """Test the specialized redact_headers method."""
+        redactor = SecretRedactor()
+
+        headers = {
+            "Authorization": "Bearer secret-token",
+            "X-API-Key": "another-secret",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        result = redactor.redact_headers(headers)
+
+        assert result["Authorization"] == "[REDACTED]"
+        assert result["X-API-Key"] == "[REDACTED]"
+        assert result["Content-Type"] == "application/json"
+        assert result["Accept"] == "application/json"
+
+    def test_redact_cookie_header(self) -> None:
+        """Test redaction of cookie headers."""
+        redactor = SecretRedactor()
+
+        headers = {"Cookie": "session=abc123; auth_token=secret123"}
+        result = redactor.redact_headers(headers)
+
+        assert result["Cookie"] == "[REDACTED]"
+
+    def test_redact_none_value(self) -> None:
+        """Test that None values pass through unchanged."""
+        redactor = SecretRedactor()
+
+        result = redactor.redact(None)
+        assert result is None
+
+    def test_redact_primitive_values(self) -> None:
+        """Test that primitive values (int, float, bool) pass through unchanged."""
+        redactor = SecretRedactor()
+
+        assert redactor.redact(42) == 42
+        assert redactor.redact(3.14) == 3.14
+        assert redactor.redact(True) is True
+        assert redactor.redact(False) is False
+
+    def test_redact_empty_string(self) -> None:
+        """Test that empty strings pass through unchanged."""
+        redactor = SecretRedactor()
+
+        assert redactor.redact("") == ""
+
+    def test_redact_preserves_dict_structure(self) -> None:
+        """Test that dictionary structure is preserved during redaction."""
+        redactor = SecretRedactor()
+
+        data = {
+            "level1": {
+                "level2": {
+                    "level3": {"api_key": "secret", "data": [1, 2, 3]},
+                },
+            },
+        }
+        result = redactor.redact(data)
+
+        assert isinstance(result["level1"]["level2"]["level3"]["data"], list)
+        assert result["level1"]["level2"]["level3"]["data"] == [1, 2, 3]
+
+    def test_redact_case_insensitive_keys(self) -> None:
+        """Test that key matching is case-insensitive."""
+        redactor = SecretRedactor()
+
+        data = {
+            "API_KEY": "secret1",
+            "Api_Key": "secret2",
+            "api_key": "secret3",
+            "APIKEY": "secret4",
+        }
+        result = redactor.redact(data)
+
+        for key in data:
+            assert result[key] == "[REDACTED]"
+
+    def test_redact_various_secret_patterns(self) -> None:
+        """Test redaction of various secret key patterns."""
+        redactor = SecretRedactor()
+
+        data = {
+            "private_key": "-----BEGIN RSA PRIVATE KEY-----",
+            "client_secret": "oauth-client-secret",
+            "database_password": "dbpass",
+            "jwt_secret": "jwt-secret-key",
+            "session_secret": "session-secret",
+            "encryption_key": "enc-key",
+            "signing_key": "sig-key",
+        }
+        result = redactor.redact(data)
+
+        for key in data:
+            assert result[key] == "[REDACTED]", f"Expected {key} to be redacted"
+
+    def test_custom_additional_patterns(self) -> None:
+        """Test adding custom patterns to the redactor."""
+        import re
+
+        custom_pattern = re.compile(r"(?i)my_custom_secret")
+        redactor = SecretRedactor(additional_patterns=[custom_pattern])
+
+        data = {"my_custom_secret": "should-be-redacted", "other": "visible"}
+        result = redactor.redact(data)
+
+        assert result["my_custom_secret"] == "[REDACTED]"
+        assert result["other"] == "visible"
+
+    def test_redact_long_alphanumeric_in_string(self) -> None:
+        """Test redaction of long alphanumeric strings that look like tokens."""
+        redactor = SecretRedactor()
+
+        # 40+ character alphanumeric strings are likely tokens
+        value = "Token: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+        result = redactor.redact(value)
+
+        # The long string should be redacted
+        assert "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" not in result
+
+    def test_redact_sandbox_token(self) -> None:
+        """Test redaction of sandbox token fields."""
+        redactor = SecretRedactor()
+
+        data = {"sandbox_token": "sandbox-auth-token-123", "sandbox_id": "sbx-123"}
+        result = redactor.redact(data)
+
+        assert result["sandbox_token"] == "[REDACTED]"
+        assert result["sandbox_id"] == "sbx-123"
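The tests above pin down the on-disk trace format: one JSON object per line, each carrying event_type, sequence, agent_id, and data, bracketed by trace_start and trace_end events. As a rough sketch of how such a file might be consumed after a run (not part of this change; the script, its path argument, and the summary logic are illustrative only):

# Illustrative consumer sketch, assuming only the JSONL fields asserted in
# test_live_tracer.py (event_type, sequence, agent_id, data).
import json
from collections import Counter
from pathlib import Path


def summarize_trace(trace_path: Path) -> Counter:
    """Count trace events by type, e.g. tool_call, llm_response, agent_created."""
    counts: Counter = Counter()
    with trace_path.open() as f:
        for line in f:
            event = json.loads(line)
            counts[event["event_type"]] += 1
    return counts


if __name__ == "__main__":
    # Hypothetical path; point this at whatever trace.jsonl the run produced.
    print(summarize_trace(Path("trace.jsonl")))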