diff --git a/examples/single_agent_examples/summarizing_agent/README.md b/examples/single_agent_examples/summarizing_agent/README.md new file mode 100644 index 00000000..bf47016b --- /dev/null +++ b/examples/single_agent_examples/summarizing_agent/README.md @@ -0,0 +1,93 @@ +# Summarizing Agent Example — Summarize a Directory + +This example demonstrates `SummarizingAgent`, which summarizes **all eligible documents in a directory** into a **single unified output**. + +Key features: +- Summarizes a directory of docs (N inputs → 1 summary) +- Paragraph-aware chunking with configurable overlap to fit model context +- Map/Reduce summarization: chunk notes → iterative merges → final structured summary +- Optional rewrite pass to scrub meta-language/segmentation if the model violates constraints + +## How it works (high level) + +1) **Select files** in `--input-dir` (optionally recursive), filtered by extension +2) **Read** file contents via `read_file` +3) **Chunk** text into deterministic, paragraph-aware chunks (with overlap) +4) **Map:** each chunk → compact bullet notes +5) **Reduce:** merge notes in batches until one summary remains +6) **Rewrite (optional):** if output contains forbidden references/headings, rewrite to clean it up + +### Chunking + overlap (as implemented) + +Chunks are built from paragraphs when possible. If overlap is enabled, each chunk after the first is prefixed with the last `overlap_chars` from the previous chunk: + +``` +chunk1: [-------------- A ---------------] +chunk2: [--- tail(A) ---][------ B ------] +chunk3: [--- tail(B) ---][------ C ------] +``` + +Oversized paragraphs are split into fixed-size segments (also with overlap). + +## Running the example +The examples below show how to run this script. 
+

From this directory:

```bash
python summarize_a_directory.py --help
```

Some common options:

```
python summarize_a_directory.py --recurse
python summarize_a_directory.py --input-dir ./inputs --output-path ./out/summary.txt
python summarize_a_directory.py --mode synthesis
python summarize_a_directory.py --max-files 50 --chunk-size-chars 10000 --chunk-overlap-chars 800 --reduce-batch-size 8
python summarize_a_directory.py --show-tool-output
```

### Using a Gateway / Custom Endpoint
By default, the script expects an OpenAI-compatible endpoint and reads the API key from `OPENAI_API_KEY`.
To use a different base URL and API key env var:

```
python summarize_a_directory.py \
    --base-url https://your.gateway.example/v1 \
    --api-key-env YOUR_GATEWAY_API_KEY
```

### Get sample input docs (public, safe downloads)

If you don’t have documents handy, you can grab a few short, public-domain texts from Project Gutenberg
and summarize them as a directory.

These three are small, varied, and interesting to synthesize (satire + horror + short fiction), and each is
listed as **Public domain in the USA** on Project Gutenberg.

```bash
# From the examples directory
mkdir -p summarizing_agent_example_inputs/gutenberg
cd summarizing_agent_example_inputs/gutenberg

# A Modest Proposal (Jonathan Swift) — eBook #1080
wget -O a_modest_proposal_1080.txt https://www.gutenberg.org/ebooks/1080.txt.utf-8

# The Yellow Wallpaper (Charlotte Perkins Gilman) — eBook #1952
wget -O the_yellow_wallpaper_1952.txt https://www.gutenberg.org/ebooks/1952.txt.utf-8

# The Gift of the Magi (O. 
Henry) — eBook #7256 +wget -O the_gift_of_the_magi_7256.txt https://www.gutenberg.org/ebooks/7256.txt.utf-8 +``` + +Then, return to this examples directory and run: +```bash +python summarize_a_directory.py --input-dir summarizing_agent_example_inputs/gutenberg +``` + +Note that **summarizing** three works of fiction with the SummarizingAgent's +default prompts (all configurable) may produce odd results — for example, an +"Executive Summary" section for short stories. + +This example is intended only as a starting point to demonstrate the functionality. diff --git a/examples/single_agent_examples/summarizing_agent/summarize_a_directory.py b/examples/single_agent_examples/summarizing_agent/summarize_a_directory.py new file mode 100644 index 00000000..9eb15948 --- /dev/null +++ b/examples/single_agent_examples/summarizing_agent/summarize_a_directory.py @@ -0,0 +1,560 @@ +""" +SummarizingAgent example runner (CLI-first, human-friendly). + +- Self-contained: no external example harness helpers +- CLI args with --help (no env var dependency for config) +- Uses ursa.util.llm_factory.setup_llm for provider wiring +- Rich console/progress when available; plain prints otherwise +- Reads inputs from ./summarizing_agent_example_inputs/ by default +- Writes output to ./summarizing_agent_example_output/summary.txt by default + +Example usage: + python summarizing_agent_example.py + python summarizing_agent_example.py --mode synthesis + python summarizing_agent_example.py --input-dir ./docs --recurse --max-files 50 + python summarizing_agent_example.py --output-path ./out/summary.txt + +Using a non-default gateway/base URL: + python summarizing_agent_example.py --base-url https://your.gateway.example/v1 --api-key-env YOUR_GATEWAY_API_KEY +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Final, Optional, Sequence + +from ursa.agents.summarizing_agent import SummarizingAgent +from 
ursa.util.llm_factory import setup_llm + +# ---------------------------- +# Task presets +# ---------------------------- + +_TASK_SUMMARY: Final[str] = ( + "Produce: (1) Executive Summary (<=200 words), (2) Required Actions (5–10 bullets), " + "then a main synthesis (600–900 words). " + "Do not segment by input." +) + +_TASK_SYNTHESIS: Final[str] = ( + "Write one integrated synthesis that merges all material into a single coherent narrative. " + "Organize by themes, tensions, and shared ideas, and highlight meaningful contrasts where they exist—" + "but do NOT structure the output as separate summaries of each input. " + "Keep one consistent voice throughout." +) + +# Default provider wiring: OpenAI-compatible settings. +_DEFAULT_MODEL_CHOICE: Final[str] = "openai:gpt-4o-mini" +_DEFAULT_BASE_URL: Final[str] = "https://api.openai.com/v1" +_DEFAULT_API_KEY_ENV: Final[str] = "OPENAI_API_KEY" + + +# ---------------------------- +# Rich helpers (optional) +# ---------------------------- + + +def _get_console(): + """Return a rich console if installed; otherwise None.""" + try: + from rich import get_console # type: ignore + + return get_console() + except Exception: + return None + + +def _panel(console, title: str, lines: Sequence[str]) -> None: + """Render a small info panel; fallback to plain prints if Rich is unavailable.""" + body = "\n".join(lines) + if console is None: + print(f"\n== {title} ==\n{body}\n") + return + try: + from rich.panel import Panel # type: ignore + + console.print(Panel.fit(body, title=title)) + except Exception: + print(f"\n== {title} ==\n{body}\n") + + +class _ProgressTracker: + """ + Receives progress events emitted by SummarizingAgent via the `on_event` callback. + Uses Rich progress bars if available; otherwise prints a small set of milestones. 
+ """ + + _PRINT_EVENTS: Final[set[str]] = { + "start", + "discover_done", + "read_start", + "chunking_done", + "reduce_round", + "rewrite_start", + "rewrite_done", + "done", + } + + def __init__(self, console): + self.console = console + self.progress = None + self._tasks: dict[str, Any] = {} + + if console is None: + return + + try: + from rich.progress import ( # type: ignore + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + ) + + self.progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("{task.completed}/{task.total}"), + TimeElapsedColumn(), + console=console, + transient=False, + ) + except Exception: + self.progress = None + + def __enter__(self): + if self.progress is not None: + self.progress.__enter__() + return self + + def __exit__(self, exc_type, exc, tb): + if self.progress is not None: + return self.progress.__exit__(exc_type, exc, tb) + return False + + def on_event(self, event: str, data: dict[str, Any]) -> None: + """Dispatch progress events from the agent.""" + if self.progress is None: + if event in self._PRINT_EVENTS: + print(f"[{event}] {data}") + return + + if event == "discover_done": + total = int(data.get("count", 0)) + self._tasks["read"] = self.progress.add_task( + "Reading files", total=total + ) + return + + if event == "read_start": + f = data.get("file") + if f: + self.progress.console.print(f"[READING]: {f}") + return + + if event in {"read_ok", "read_skip"}: + task_id = self._tasks.get("read") + if task_id is not None: + self.progress.advance(task_id, 1) + return + + if event == "chunking_done": + total = int(data.get("chunks", 0)) + self._tasks["map"] = self.progress.add_task( + "Summarizing chunks (map)", total=total + ) + return + + if event == "map_done": + task_id = self._tasks.get("map") + if task_id is not None: + self.progress.advance(task_id, 1) + return + + if event == "reduce_start": + self._tasks["reduce"] = 
self.progress.add_task( + "Reducing summaries", total=1 + ) + return + + if event == "reduce_round": + batches = int(data.get("batches", 1)) + task_id = self._tasks.get("reduce") + if task_id is not None: + self.progress.reset(task_id, total=batches) + + self.progress.console.print( + f"[dim]Reduce round {data.get('round')}: {data.get('items')} items → {batches} merges[/dim]" + ) + return + + if event == "reduce_batch_done": + task_id = self._tasks.get("reduce") + if task_id is not None: + self.progress.advance(task_id, 1) + return + + if event == "rewrite_start": + self.progress.console.print("[dim]Rewrite pass...[/dim]") + return + + if event == "rewrite_done": + self.progress.console.print( + f"[dim]Rewrite done (changed={data.get('changed')}).[/dim]" + ) + return + + if event == "done": + self.progress.console.print("[bold green]Done.[/bold green]") + return + + +# ---------------------------- +# CLI parsing +# ---------------------------- + + +def _comma_list(value: str) -> list[str]: + """Parse comma-separated list values.""" + return [v.strip() for v in value.split(",") if v.strip()] + + +def _normalize_exts(exts: Sequence[str]) -> tuple[str, ...]: + """Normalize extensions to begin with '.'.""" + out: list[str] = [] + for e in exts: + e = e.strip() + if not e: + continue + out.append(e if e.startswith(".") else f".{e}") + return tuple(out) + + +def _build_parser( + default_input_dir: Path, default_output_path: Path +) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="summarizing_agent_example", + description="Run SummarizingAgent over a directory of input documents.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # I/O + parser.add_argument( + "--input-dir", + type=Path, + default=default_input_dir, + help="Directory containing input documents to summarize.", + ) + parser.add_argument( + "--recurse", + action="store_true", + help="Recurse into subdirectories under --input-dir.", + ) + parser.add_argument( + 
"--allowed-exts", + type=_comma_list, + default=[".txt", ".md", ".rst", ".pdf"], + help="Comma-separated list of allowed extensions.", + ) + parser.add_argument( + "--max-files", + type=int, + default=None, + help="Optional cap on number of files (deterministic after sorting).", + ) + parser.add_argument( + "--output-path", + type=Path, + default=default_output_path, + help="File path for the generated summary.", + ) + + # Task / behavior + parser.add_argument( + "--mode", + choices=["summary", "synthesis"], + default="summary", + help="Select built-in task behavior; can be overridden by --task.", + ) + parser.add_argument( + "--task", + type=str, + default="", + help="Override the task text entirely (if set, ignores --mode presets).", + ) + parser.add_argument( + "--show-tool-output", + action="store_true", + help="Do not silence tool stdout/stderr.", + ) + parser.add_argument( + "--strict", + action="store_true", + help="Treat tool errors/empty reads as hard failures.", + ) + + # Provider wiring (OpenAI-compatible by default) + parser.add_argument( + "--model-choice", + type=str, + default=_DEFAULT_MODEL_CHOICE, + help=( + "Model choice string used by setup_llm (e.g., provider_alias:model_id). " + "Default targets an OpenAI-compatible provider." + ), + ) + parser.add_argument( + "--base-url", + type=str, + default=_DEFAULT_BASE_URL, + help="Base URL for an OpenAI-compatible endpoint (can be a gateway).", + ) + parser.add_argument( + "--api-key-env", + type=str, + default=_DEFAULT_API_KEY_ENV, + help="Name of env var that contains the API key (read by setup_llm).", + ) + + # LLM params + parser.add_argument( + "--temperature", type=float, default=0.2, help="Sampling temperature." + ) + parser.add_argument( + "--max-retries", type=int, default=2, help="Retry attempts." 
+ ) + parser.add_argument( + "--max-tokens", + type=int, + default=2500, + help="Max tokens for completion (ChatOpenAI style).", + ) + + # Summarizer knobs + parser.add_argument( + "--chunk-size-chars", + type=int, + default=10000, + help="Chunk size in characters.", + ) + parser.add_argument( + "--chunk-overlap-chars", + type=int, + default=800, + help="Chunk overlap in characters.", + ) + parser.add_argument( + "--max-chunks-per-file", + type=int, + default=200, + help="Cap chunks produced per file.", + ) + parser.add_argument( + "--reduce-batch-size", + type=int, + default=8, + help="Batch size per reduce round.", + ) + + # Agent framework knobs + parser.add_argument( + "--thread-id", + type=str, + default="summarizer_example", + help="Thread/run identifier passed to the agent.", + ) + parser.add_argument( + "--workspace", + type=Path, + default=None, + help="Workspace directory used by tools; default is ./workspace_summarizer next to this script.", + ) + parser.add_argument( + "--disable-metrics", + action="store_true", + help="Disable agent metrics collection.", + ) + + return parser + + +def _resolve_task(mode: str, override: str) -> str: + """Choose the task prompt string from mode/override.""" + if override and override.strip(): + return override.strip() + return _TASK_SYNTHESIS if mode == "synthesis" else _TASK_SUMMARY + + +def _build_models_cfg( + base_url: str, + api_key_env: str, + temperature: float, + max_retries: int, + max_tokens: int, +) -> dict[str, Any]: + """Build the inline models_cfg consumed by setup_llm.""" + return { + "providers": { + "openai": { + "model_provider": "openai", + "base_url": base_url, + "api_key_env": api_key_env, + } + }, + "defaults": { + "params": { + "temperature": float(temperature), + "max_retries": int(max_retries), + "max_tokens": int(max_tokens), + "use_responses_api": False, + } + }, + "agents": {"summarizer": {"params": {}}}, + } + + +# ---------------------------- +# Main +# ---------------------------- + + +def 
main(argv: Optional[Sequence[str]] = None) -> int: + console = _get_console() + + here = Path(__file__).resolve().parent + default_input_dir = here / "summarizing_agent_example_inputs" + default_output_path = ( + here / "summarizing_agent_example_output" / "summary.txt" + ) + + parser = _build_parser( + default_input_dir=default_input_dir, + default_output_path=default_output_path, + ) + args = parser.parse_args(argv) + + input_dir: Path = args.input_dir + output_path: Path = args.output_path + output_path.parent.mkdir(parents=True, exist_ok=True) + + workspace = ( + args.workspace + if args.workspace is not None + else (here / "workspace_summarizer") + ) + allowed_exts = _normalize_exts(args.allowed_exts) + task = _resolve_task(args.mode, args.task) + + models_cfg = _build_models_cfg( + base_url=args.base_url, + api_key_env=args.api_key_env, + temperature=args.temperature, + max_retries=args.max_retries, + max_tokens=args.max_tokens, + ) + + _panel( + console, + "SummarizingAgent Example", + [ + f"Input: {input_dir}", + f"Output: {output_path}", + f"Mode: {args.mode}", + f"Recurse: {bool(args.recurse)}", + f"Allowed exts: {', '.join(allowed_exts)}", + f"Max files: {args.max_files}", + f"Tool output: {'ON' if args.show_tool_output else 'OFF'}", + f"Strict: {bool(args.strict)}", + f"Model choice: {args.model_choice}", + f"Base URL: {args.base_url}", + f"API key env: {args.api_key_env}", + ], + ) + + if not input_dir.exists() or not input_dir.is_dir(): + msg = f"input-dir is not a directory: {input_dir}" + if console is not None: + console.print(f"[bold red]✖ Error:[/bold red] {msg}") + else: + print(f"Error: {msg}") + return 2 + + try: + llm = setup_llm( + model_choice=args.model_choice, + models_cfg=models_cfg, + agent_name="summarizer", + base_llm_kwargs={ + "max_tokens": int(args.max_tokens), + # Prevent older defaults from winning if both exist. 
+ "max_completion_tokens": None, + }, + console=console, + ) + + agent = SummarizingAgent( + llm=llm, + thread_id=args.thread_id, + workspace=Path(workspace), + enable_metrics=(not args.disable_metrics), + ) + + with _ProgressTracker(console) as tracker: + result = agent.invoke({ + "input_docs_dir": str(input_dir), + "recurse": bool(args.recurse), + "allowed_extensions": allowed_exts, + "max_files": args.max_files, + "task": task, + "silent_tools": (not args.show_tool_output), + "strict": bool(args.strict), + "chunk_size_chars": int(args.chunk_size_chars), + "chunk_overlap_chars": int(args.chunk_overlap_chars), + "max_chunks_per_file": int(args.max_chunks_per_file), + "reduce_batch_size": int(args.reduce_batch_size), + "on_event": tracker.on_event, + }) + + summary = (result.get("summary") or "").strip() + output_path.write_text(summary + "\n", encoding="utf-8") + + if console is not None: + console.print(f"[bold green]✔ Wrote[/bold green] {output_path}") + else: + print(f"OK: wrote {output_path}") + return 0 + + except Exception as e: + if console is not None: + console.print( + f"[bold red]✖ Error:[/bold red] {type(e).__name__}: {e}" + ) + try: + from rich.traceback import Traceback # type: ignore + + console.print( + Traceback.from_exception( + type(e), e, e.__traceback__, show_locals=False + ) + ) + except Exception: + pass + else: + print(f"Error: {type(e).__name__}: {e}") + + try: + output_path.write_text( + f"[Error] {type(e).__name__}: {e}\n", encoding="utf-8" + ) + except Exception: + pass + + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/ursa/agents/__init__.py b/src/ursa/agents/__init__.py index 3b44c8b9..bf72a940 100644 --- a/src/ursa/agents/__init__.py +++ b/src/ursa/agents/__init__.py @@ -19,6 +19,7 @@ "RAGAgent": (".rag_agent", "RAGAgent"), "RecallAgent": (".recall_agent", "RecallAgent"), "WebSearchAgentLegacy": (".websearch_agent", "WebSearchAgentLegacy"), + "SummarizingAgent": (".summarizing_agent", 
"SummarizingAgent"), } __all__ = list(_lazy_attrs.keys()) diff --git a/src/ursa/agents/summarizing_agent.py b/src/ursa/agents/summarizing_agent.py new file mode 100644 index 00000000..c76afb11 --- /dev/null +++ b/src/ursa/agents/summarizing_agent.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +import contextlib +import io +import math +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated, Any, Callable, Optional, Sequence, TypedDict + +from langchain.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.graph import END +from langgraph.graph.message import add_messages + +from ursa.agents.base import BaseAgent +from ursa.prompt_library.summarizing_prompts import ( + FINAL_COVERAGE_INSTRUCTION, + MAP_USER_INSTRUCTIONS, + SYSTEM_MAP_PROMPT, + SYSTEM_NOTES_REDUCE_PROMPT, + SYSTEM_REDUCE_PROMPT, + SYSTEM_REWRITE_PROMPT, +) +from ursa.tools import read_file + +_DENY_RE = re.compile( + r"\b(doc|docs|document|documents|file|files|filename|filenames|source|sources|excerpt|excerpts|chunk|chunks)\b" + r"|\b(the text above|the passage above|this text|this document|this file|above text)\b" + r"|\b(url|urls|link|links)\b" + r"|https?://\S+" + r"|\bwww\.\S+", + re.IGNORECASE, +) +_HEADING_RE = re.compile(r"(?m)^\s*#{1,6}\s+") +_RULE_RE = re.compile(r"(?m)^\s*(-{3,}|={3,})\s*$") + + +def _needs_rewrite(text: str) -> bool: + """Return True if the generated text likely violates formatting/meta constraints.""" + if not text.strip(): + return False + return bool( + _DENY_RE.search(text) + or _HEADING_RE.search(text) + or _RULE_RE.search(text) + ) + + +def _normalize_whitespace(s: str) -> str: + """Normalize line endings and collapse excessive blank lines for stable prompting/output.""" + s = s.replace("\r\n", "\n").replace("\r", "\n") + s = re.sub(r"[ \t]+\n", "\n", s) + s = re.sub(r"\n{3,}", "\n\n", s) + return s.strip() + + +def _chunk_text( + text: str, + *, + 
chunk_size_chars: int, + overlap_chars: int, + max_chunks: int, +) -> list[str]: + """ + Deterministic, paragraph-aware chunking with optional overlap. + + - Keeps paragraphs intact where possible. + - Splits oversized paragraphs into fixed-size segments. + - Adds overlap between consecutive chunks to preserve continuity. + """ + text = _normalize_whitespace(text) + if len(text) <= chunk_size_chars: + return [text] + + paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] + chunks: list[str] = [] + buf: list[str] = [] + n = 0 + + def flush() -> None: + """Flush current buffered paragraphs to a new chunk.""" + nonlocal buf, n + if buf: + chunks.append("\n\n".join(buf).strip()) + buf = [] + n = 0 + + for p in paras: + if len(p) > chunk_size_chars: + flush() + start = 0 + while start < len(p): + end = min(len(p), start + chunk_size_chars) + seg = p[start:end].strip() + if seg: + chunks.append(seg) + if end >= len(p): + break + start = max(0, end - overlap_chars) + continue + + add = len(p) + (2 if buf else 0) + if n + add <= chunk_size_chars: + buf.append(p) + n += add + else: + flush() + buf.append(p) + n = len(p) + + flush() + + if overlap_chars > 0 and len(chunks) > 1: + out = [chunks[0]] + for prev, nxt in zip(chunks, chunks[1:]): + out.append((prev[-overlap_chars:] + "\n\n" + nxt).strip()) + chunks = out + + return chunks[:max_chunks] + + +@dataclass(frozen=True) +class _Doc: + """In-memory representation of an input document.""" + + name: str + text: str + + +@contextlib.contextmanager +def _silence_stdio(enabled: bool): + """ + Optionally suppress stdout/stderr. + + Some tools can be chatty; silencing keeps logs and notebooks clean. + """ + if not enabled: + yield + return + sink = io.StringIO() + with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink): + yield + + +class SummarizingState(TypedDict, total=False): + """ + Agent state schema. + + Inputs control selection, reading, chunking, and summarization behavior. 
+ Outputs include the final summary plus file selection diagnostics. + """ + + messages: Annotated[list, add_messages] + + input_docs_dir: str + recurse: bool + allowed_extensions: Sequence[str] + max_files: Optional[int] + task: Optional[str] + silent_tools: bool + strict: bool + chunk_size_chars: int + chunk_overlap_chars: int + max_chunks_per_file: int + reduce_batch_size: int + on_event: Optional[Callable[[str, dict[str, Any]], None]] + + summary: str + selected_files: list[str] + skipped_files: list[str] + + +class SummarizingAgent(BaseAgent[SummarizingState]): + """ + Summarize a directory of documents using a balanced map-reduce strategy. + + Strategy: + - Per-file MAP: summarize each chunk into compact notes. + - Per-file REDUCE: merge chunk notes into a single “doc notes” artifact per file. + - Global REDUCE: merge doc notes into a single structured summary. + - Optional REWRITE: scrub meta-language if the model violates constraints. + + This two-level approach prevents long files from dominating the final output purely by chunk count. 
+ """ + + state_type = SummarizingState + + def __init__(self, llm: BaseChatModel, **kwargs): + super().__init__(llm=llm, **kwargs) + + def format_result(self, state: SummarizingState) -> str: + """Return the final summary text for the framework's result handling.""" + return state.get("summary", "") + + def _build_graph(self): + """Single-node graph: summarize and finish.""" + self.add_node(self._summarize_node, "summarize") + self.graph.set_entry_point("summarize") + self.graph.add_edge("summarize", END) + self.graph.set_finish_point("summarize") + + def _summarize_node(self, state: SummarizingState) -> SummarizingState: + """Read, chunk, summarize per file, then globally synthesize into one output.""" + + def emit(event: str, **data: Any) -> None: + cb = state.get("on_event") + if cb is not None: + cb(event, data) + + input_dir = Path(str(state.get("input_docs_dir", ""))) + if not input_dir.exists() or not input_dir.is_dir(): + raise ValueError(f"input_docs_dir is not a directory: {input_dir}") + + recurse = bool(state.get("recurse", False)) + allowed_exts = tuple( + str(e).lower() + for e in ( + state.get("allowed_extensions") + or (".txt", ".md", ".rst", ".pdf") + ) + ) + max_files = state.get("max_files") + silent_tools = bool(state.get("silent_tools", True)) + strict = bool(state.get("strict", False)) + + chunk_size = int(state.get("chunk_size_chars", 10_000)) + chunk_overlap = int(state.get("chunk_overlap_chars", 800)) + max_chunks_per_file = int(state.get("max_chunks_per_file", 200)) + reduce_batch = int(state.get("reduce_batch_size", 8)) + + task = state.get("task") or ( + "Write a single unified summary in one consistent voice. " + "Do not segment the output by input, and do not produce separate per-item summaries." 
+ ) + + emit("start", input=str(input_dir), recurse=recurse) + + selected, skipped = self._select_files( + root=input_dir, + recurse=recurse, + allowed_exts=allowed_exts, + max_files=max_files, + ) + emit("discover_done", count=len(selected)) + + docs: list[_Doc] = [] + tool_state = {"workspace": str(input_dir)} + + for rel in selected: + emit("read_start", file=rel) + try: + with _silence_stdio(silent_tools): + txt = ( + read_file.func(filename=rel, state=tool_state) # type: ignore[attr-defined] + if hasattr(read_file, "func") + else read_file(filename=rel, state=tool_state) # type: ignore[misc] + ) + + s = _normalize_whitespace("" if txt is None else str(txt)) + + if s.startswith("[Error]") or s.startswith("[Error]:"): + if strict: + raise RuntimeError(f"read_file error for {rel}: {s}") + emit("read_skip", file=rel, reason="tool_error") + continue + + if not s: + if strict: + raise RuntimeError(f"Empty file content for {rel}") + emit("read_skip", file=rel, reason="empty") + continue + + docs.append(_Doc(name=rel, text=s)) + emit("read_ok", file=rel) + + except Exception as e: + if strict: + raise + emit( + "read_skip", + file=rel, + reason=f"exception:{type(e).__name__}", + ) + continue + + if not docs: + return { + "summary": "", + "selected_files": selected, + "skipped_files": skipped, + } + + # --- Per-file processing --- + # Each file produces exactly one doc-notes artifact, preventing long files from dominating. + doc_notes: list[str] = [] + total_chunks = 0 + + for d in docs: + chunks = _chunk_text( + d.text, + chunk_size_chars=chunk_size, + overlap_chars=chunk_overlap, + max_chunks=max_chunks_per_file, + ) + chunks = [c for c in chunks if c.strip()] + total_chunks += len(chunks) + + emit("file_chunking_done", file=d.name, chunks=len(chunks)) + + if not chunks: + continue + + # MAP: each chunk -> compact bullet notes. 
+ partials: list[str] = [] + for i, ch in enumerate(chunks, start=1): + prompt_user = ( + f"Task:\n{task}\n\n" + f"{MAP_USER_INSTRUCTIONS}\n" + "Passage:\n" + "```\n" + f"{ch}\n" + "```" + ) + + msg = self.llm.invoke( + [ + SystemMessage(content=SYSTEM_MAP_PROMPT), + HumanMessage(content=prompt_user), + ], + self.build_config(tags=["summarizer", "map"]), + ) + partials.append(msg.text.strip()) + emit("map_done", file=d.name, i=i, total=len(chunks)) + + # Per-file REDUCE: merge chunk-notes into one bullet list (doc-notes). + note_text = self._reduce_notes_list( + partials, + task=task, + reduce_batch=reduce_batch, + emit=emit, + scope=f"file:{d.name}", + ) + + if note_text.strip(): + doc_notes.append(note_text.strip()) + emit("file_notes_done", file=d.name) + + if not doc_notes: + return { + "summary": "", + "selected_files": selected, + "skipped_files": skipped, + } + + emit("chunking_done", chunks=total_chunks, files=len(docs)) + + # --- Global REDUCE: doc-notes -> final structured output --- + summary = self._reduce_final( + doc_notes, + task=task, + reduce_batch=reduce_batch, + emit=emit, + ) + + summary = _normalize_whitespace(summary) + + if _needs_rewrite(summary): + emit("rewrite_start") + prompt_user = ( + "Rewrite the text below to remove forbidden references and meta-language.\n\n" + "Text:\n" + "```\n" + f"{summary}\n" + "```" + ) + msg = self.llm.invoke( + [ + SystemMessage(content=SYSTEM_REWRITE_PROMPT), + HumanMessage(content=prompt_user), + ], + self.build_config(tags=["summarizer", "rewrite"]), + ) + summary2 = _normalize_whitespace(msg.text.strip()) + if summary2: + summary = summary2 + emit("rewrite_done", changed=True) + else: + emit("rewrite_done", changed=False) + + emit("done", chars=len(summary)) + + return { + "summary": summary, + "selected_files": selected, + "skipped_files": skipped, + } + + def _reduce_notes_list( + self, + items: list[str], + *, + task: str, + reduce_batch: int, + emit: Callable[..., None], + scope: str, + ) -> str: + 
""" + Reduce a list of note blocks down to a single bullet-list note artifact. + + Used for per-file reduction so each file contributes one “doc-notes” unit. + """ + current = [x for x in items if (x or "").strip()] + if not current: + return "" + + round_i = 0 + emit( + "notes_reduce_start", + scope=scope, + items=len(current), + batch_size=reduce_batch, + ) + + while len(current) > 1: + round_i += 1 + num_batches = math.ceil(len(current) / reduce_batch) + emit( + "notes_reduce_round", + scope=scope, + round=round_i, + items=len(current), + batches=num_batches, + ) + + nxt: list[str] = [] + batch_idx = 0 + + for i in range(0, len(current), reduce_batch): + batch_idx += 1 + batch = current[i : i + reduce_batch] + material = "\n\n".join(batch).strip() + + prompt_user = ( + f"Task:\n{task}\n\n" + "Merge the notes below into a single bullet list that preserves distinct, concrete details.\n\n" + f"Notes:\n{material}" + ) + + msg = self.llm.invoke( + [ + SystemMessage(content=SYSTEM_NOTES_REDUCE_PROMPT), + HumanMessage(content=prompt_user), + ], + self.build_config(tags=["summarizer", "notes_reduce"]), + ) + nxt.append(msg.text.strip()) + emit( + "notes_reduce_batch_done", + scope=scope, + round=round_i, + batch=batch_idx, + batches=num_batches, + ) + + current = nxt + + emit("notes_reduce_done", scope=scope, rounds=round_i) + return current[0].strip() + + def _reduce_final( + self, + doc_notes: list[str], + *, + task: str, + reduce_batch: int, + emit: Callable[..., None], + ) -> str: + """ + Reduce per-file notes into a single final structured summary. 
+ """ + current = [x for x in doc_notes if (x or "").strip()] + round_i = 0 + + emit("reduce_start", items=len(current), batch_size=reduce_batch) + + while len(current) > 1: + round_i += 1 + num_batches = math.ceil(len(current) / reduce_batch) + emit( + "reduce_round", + round=round_i, + items=len(current), + batches=num_batches, + ) + + nxt: list[str] = [] + batch_idx = 0 + + for i in range(0, len(current), reduce_batch): + batch_idx += 1 + batch = current[i : i + reduce_batch] + + emit( + "reduce_batch_start", + round=round_i, + batch=batch_idx, + batches=num_batches, + size=len(batch), + ) + + material = "\n\n".join(batch).strip() + prompt_user = ( + f"Task:\n{task}\n\n" + f"{FINAL_COVERAGE_INSTRUCTION}\n" + f"Material:\n{material}" + ) + + msg = self.llm.invoke( + [ + SystemMessage(content=SYSTEM_REDUCE_PROMPT), + HumanMessage(content=prompt_user), + ], + self.build_config(tags=["summarizer", "reduce"]), + ) + nxt.append(msg.text.strip()) + emit( + "reduce_batch_done", + round=round_i, + batch=batch_idx, + batches=num_batches, + ) + + current = nxt + + emit("reduce_done", rounds=round_i) + return current[0].strip() + + def _select_files( + self, + *, + root: Path, + recurse: bool, + allowed_exts: tuple[str, ...], + max_files: Optional[int], + ) -> tuple[list[str], list[str]]: + """ + Select eligible files under `root`. + + - Skips hidden paths (any segment starting with '.'). + - Filters by allowed file extensions. + - Applies `max_files` deterministically after sorting. 
+ """ + + def eligible(p: Path) -> bool: + if not p.is_file(): + return False + if any(part.startswith(".") for part in p.relative_to(root).parts): + return False + return p.suffix.lower() in allowed_exts + + paths = [ + p + for p in (root.rglob("*") if recurse else root.iterdir()) + if eligible(p) + ] + rels = sorted( + [p.relative_to(root).as_posix() for p in paths], + key=lambda s: s.lower(), + ) + + skipped: list[str] = [] + if max_files is not None and len(rels) > int(max_files): + skipped.extend(rels[int(max_files) :]) + rels = rels[: int(max_files)] + + return rels, sorted(set(skipped)) diff --git a/src/ursa/prompt_library/summarizing_prompts.py b/src/ursa/prompt_library/summarizing_prompts.py new file mode 100644 index 00000000..0ae6dacd --- /dev/null +++ b/src/ursa/prompt_library/summarizing_prompts.py @@ -0,0 +1,71 @@ +""" +Prompt strings used by the summarizing agent. + +Kept centralized so: +- prompt diffs are prompt-only, +- agent logic stays readable, +- other agents can reuse constraints. +""" + +FORBIDDEN_REFERENCES = ( + "filenames, file paths, URLs/links, sources, documents, chunks, excerpts, " + "or any statement about inputs (e.g., 'the text above', 'this document', 'the first story')." +) + +# MAP phase: produce compact bullet notes from each chunk. +SYSTEM_MAP_PROMPT = ( + "You are a summarization engine.\n" + f"Hard constraints: do NOT mention {FORBIDDEN_REFERENCES}\n" + "No citations or disclaimers. No headings.\n" +) + +# Reduce notes *within a single file* (keeps per-file coverage balanced). 
+SYSTEM_NOTES_REDUCE_PROMPT = (
+    "You are a summarization engine.\n"
+    f"Hard constraints: do NOT mention {FORBIDDEN_REFERENCES}\n"
+    "Output ONLY a bullet list.\n"
+    "Constraints:\n"
+    "- 10–18 bullets\n"
+    "- Each bullet <= 24 words\n"
+    "- Preserve specific, decision-relevant details (names, numbers, constraints, tradeoffs)\n"
+    "- No headings, no preamble, no conclusion\n"
+)
+
+# Final REDUCE phase: merge per-file notes into one unified summary with strict structure.
+SYSTEM_REDUCE_PROMPT = (
+    "You are a summarization engine.\n"
+    f"Hard constraints: do NOT mention {FORBIDDEN_REFERENCES}\n"
+    "Output must have EXACTLY these parts, in this order:\n"
+    "1) Executive Summary: a single paragraph (<=200 words).\n"
+    "2) Required Actions: 5–10 bullet points.\n"
+    "3) Main Synthesis: cohesive narrative prose paragraphs (no bullets, no lists), 600–900 words.\n"
+    "No per-input segmentation. No extra sections.\n"
+)
+
+# Optional rewrite pass: scrub meta-language/segmentation if the model violates constraints.
+SYSTEM_REWRITE_PROMPT = (
+    "You are a rewriting engine.\n"
+    f"Hard constraints: remove any mentions of {FORBIDDEN_REFERENCES}\n"
+    "Preserve the required structure exactly:\n"
+    "Executive Summary (paragraph), Required Actions (bullets), Main Synthesis (paragraphs).\n"
+    "Do NOT remove the bullets under Required Actions.\n"
+    "Output only the rewritten text.\n"
+)
+
+# User instruction block for MAP.
+MAP_USER_INSTRUCTIONS = (
+    "Write ultra-compact notes for later synthesis.\n"
+    "- 6–10 bullets max\n"
+    "- Each bullet <= 20 words\n"
+    "- Focus only on decision-relevant facts, conflicts, constraints, and outcomes\n"
+    "- No fluff, no scene-setting, no quotes\n"
+    "Do not mention inputs.\n"
+)
+
+# Used during final reduction to prevent “one-source collapse” when inputs differ in size.
+FINAL_COVERAGE_INSTRUCTION = (
+    "Coverage constraint:\n"
+    "- The output must reflect material from multiple distinct inputs.\n"
+    "- Include at least 1 concrete, non-generic detail from each input if possible.\n"
+    "- If details are incompatible, explicitly describe the tension without attributing to inputs.\n"
+)
diff --git a/src/ursa/util/llm_factory.py b/src/ursa/util/llm_factory.py
new file mode 100644
index 00000000..7f70b5d2
--- /dev/null
+++ b/src/ursa/util/llm_factory.py
@@ -0,0 +1,413 @@
+from __future__ import annotations
+
+import importlib
+import json
+import os
+
+# needed for SSL / PKI verifications, if the user needs that
+import ssl
+from typing import Any
+
+import httpx
+from langchain.chat_models import init_chat_model
+
+"""
+llm_factory.py
+
+Utilities for constructing LangChain chat models for URSA runners using a YAML config.
+
+This centralizes:
+- provider alias resolution (e.g., "openai:gpt-5" or "my_endpoint:openai/gpt-oss-120b")
+- auth and base_url wiring from cfg.models.providers
+- merging of per-run defaults + optional YAML profiles + per-agent overrides
+- safe(ish) logging banners that avoid printing secrets verbatim
+
+Goal: any URSA program (plan/execute, hypothesizer, etc.) can share the same model
+configuration behavior and get consistent logging and overrides.
+"""
+
+
+# ---------------------------------------------------------------------
+# Secret masking / sanitization for logs
+# ---------------------------------------------------------------------
+_SECRET_KEY_SUBSTRS = (
+    "api_key",
+    "apikey",
+    "access_token",
+    "refresh_token",
+    "secret",
+    "password",
+    "bearer",
+)
+
+
+def _looks_like_secret_key(name: str) -> bool:
+    n = name.lower()
+    return any(s in n for s in _SECRET_KEY_SUBSTRS)
+
+
+def _mask_secret(value: str, keep_start: int = 6, keep_end: int = 4) -> str:
+    """
+    Mask a secret-like string, keeping only the beginning and end.
+    Example: sk-proj-abc123456789xyz -> sk-pro...9xyz
+    """
+    if not isinstance(value, str):
+        return value
+    if len(value) <= keep_start + keep_end + 3:
+        return "…"
+    return f"{value[:keep_start]}...{value[-keep_end:]}"
+
+
+def _json_safe(obj: Any) -> Any:
+    """Best-effort conversion to something json.dumps can handle."""
+    # Primitives are fine
+    if obj is None or isinstance(obj, (str, int, float, bool)):
+        return obj
+    # Common containers
+    if isinstance(obj, dict):
+        return {str(k): _json_safe(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_json_safe(v) for v in obj]
+    if isinstance(obj, tuple):
+        return [_json_safe(v) for v in obj]
+    # Fallback: readable repr (keeps type info)
+    return f"<{obj.__class__.__name__}>"
+
+
+def _sanitize_for_logging(obj: Any) -> Any:
+    if isinstance(obj, dict):
+        out = {}
+        for k, v in obj.items():
+            if _looks_like_secret_key(str(k)):
+                out[k] = _mask_secret(v) if isinstance(v, str) else "..."
+            else:
+                out[k] = _sanitize_for_logging(v)
+        return out
+    if isinstance(obj, list):
+        return [_sanitize_for_logging(v) for v in obj]
+    return _json_safe(obj)
+
+
+# ---------------------------------------------------------------------
+# Dict merge + YAML param resolution
+# ---------------------------------------------------------------------
+def _deep_merge_dicts(base: dict, override: dict) -> dict:
+    """
+    Recursively merge override into base and return a new dict.
+    - dict + dict => deep merge
+    - otherwise => override wins
+    """
+    base = dict(base or {})
+    override = dict(override or {})
+    out = dict(base)
+    for k, v in override.items():
+        if k in out and isinstance(out[k], dict) and isinstance(v, dict):
+            out[k] = _deep_merge_dicts(out[k], v)
+        else:
+            out[k] = v
+    return out
+
+
+def _resolve_llm_kwargs_for_agent(
+    models_cfg: dict | None, agent_name: str | None
+) -> dict:
+    """
+    Given the YAML `models:` dict, compute merged kwargs for init_chat_model(...)
+    for a specific agent ('planner', 'executor', etc.).
+
+    Merge order (later wins):
+      1) {} (empty)
+      2) models.defaults.params (optional)
+      3) models.profiles[defaults.profile] (optional)
+      4) models.agents[agent_name].profile (optional; merges that profile on top)
+      5) models.agents[agent_name].params (optional)
+    """
+    models_cfg = models_cfg or {}
+    profiles = models_cfg.get("profiles") or {}
+    defaults = models_cfg.get("defaults") or {}
+    agents = models_cfg.get("agents") or {}
+
+    merged: dict = {}
+    merged = _deep_merge_dicts(merged, defaults.get("params") or {})
+
+    default_profile_name = defaults.get("profile")
+    if default_profile_name and default_profile_name in profiles:
+        merged = _deep_merge_dicts(merged, profiles[default_profile_name] or {})
+
+    if agent_name and isinstance(agents, dict) and agent_name in agents:
+        a = agents.get(agent_name) or {}
+        agent_profile_name = a.get("profile")
+        if agent_profile_name and agent_profile_name in profiles:
+            merged = _deep_merge_dicts(
+                merged, profiles[agent_profile_name] or {}
+            )
+        merged = _deep_merge_dicts(merged, a.get("params") or {})
+
+    return merged
+
+
+# ---------------------------------------------------------------------
+# Provider / model string resolution
+# ---------------------------------------------------------------------
+def _resolve_model_choice(
+    model_choice: str, models_cfg: dict
+) -> tuple[str, str, dict]:
+    """
+    Accepts strings like:
+      - 'openai:gpt-5.2'
+      - 'my_endpoint:openai/gpt-oss-120b'
+
+    Looks up per-provider settings from cfg.models.providers.
+
+    Returns:
+      (model_provider, pure_model, provider_extra_kwargs_for_init)
+
+    where:
+      - model_provider is a LangChain provider string (e.g., "openai")
+      - pure_model is the model name passed as `model=...`
+      - provider_extra_kwargs may include base_url/api_key
+    """
+    if ":" in model_choice:
+        alias, pure_model = model_choice.split(":", 1)
+    else:
+        alias, pure_model = "openai", model_choice  # back-compat default
+
+    providers = (models_cfg or {}).get("providers", {})
+    prov = providers.get(alias, {})
+
+    model_provider = prov.get("model_provider", alias)
+
+    api_key = None
+    if prov.get("api_key_env"):
+        api_key = os.getenv(prov["api_key_env"])
+    if not api_key and prov.get("token_loader"):
+        mod, fn = prov["token_loader"].rsplit(".", 1)
+        api_key = getattr(importlib.import_module(mod), fn)()
+
+    provider_extra = {}
+    if prov.get("base_url"):
+        provider_extra["base_url"] = prov["base_url"]
+    if api_key:
+        provider_extra["api_key"] = api_key
+
+    return model_provider, pure_model, provider_extra
+
+
+# ---------------------------------------------------------------------
+# Logging banner
+# ---------------------------------------------------------------------
+def _print_llm_init_banner(
+    *,
+    agent_name: str | None,
+    provider: str,
+    model_name: str,
+    provider_extra: dict,
+    llm_kwargs: dict,
+    model_obj: Any = None,
+    console: Any | None = None,
+) -> None:
+    """
+    Print a friendly summary of the init_chat_model(...) configuration.
+
+    If `console` is a rich Console, we render Panels. Otherwise we print plain text.
+ """ + who = agent_name or "llm" + safe_provider_extra = _sanitize_for_logging(provider_extra or {}) + safe_llm_kwargs = _sanitize_for_logging(llm_kwargs or {}) + + text = ( + f"LLM init ({who})\n" + f"provider: {provider}\n" + f"model: {model_name}\n\n" + f"provider kwargs: {json.dumps(safe_provider_extra, indent=2)}\n\n" + f"llm kwargs (merged): {json.dumps(safe_llm_kwargs, indent=2)}" + ) + + if console is not None: + try: + from rich.panel import Panel + from rich.text import Text + + console.print( + Panel.fit( + Text.from_markup(text.replace("\n", "\n")), + border_style="cyan", + ) + ) + except Exception: + print(text) + else: + print(text) + + # Best-effort readback from the LangChain model object + if model_obj is not None: + readback = {} + for attr in ( + "model_name", + "model", + "reasoning", + "temperature", + "max_completion_tokens", + "max_tokens", + ): + if hasattr(model_obj, attr): + try: + readback[attr] = getattr(model_obj, attr) + except Exception: + pass + + for attr in ("model_kwargs", "kwargs"): + if hasattr(model_obj, attr): + try: + readback[attr] = getattr(model_obj, attr) + except Exception: + pass + + if readback: + rb_text = ( + "LLM readback (best-effort from LangChain object)\n" + + json.dumps(_sanitize_for_logging(readback), indent=2) + ) + if console is not None: + try: + from rich.panel import Panel + from rich.text import Text + + console.print( + Panel.fit( + Text.from_markup(rb_text), border_style="green" + ) + ) + except Exception: + print(rb_text) + else: + print(rb_text) + + # If reasoning effort was requested, highlight it + effort = None + try: + effort = (llm_kwargs or {}).get("reasoning", {}).get("effort") + except Exception: + effort = None + + if effort: + msg = ( + f"Reasoning effort requested: {effort}\n" + "Note: This confirms what we sent to init_chat_model; actual enforcement is provider-side." 
+        )
+        if console is not None:
+            try:
+                from rich.panel import Panel
+
+                console.print(Panel.fit(msg, border_style="yellow"))
+            except Exception:
+                print(msg)
+        else:
+            print(msg)
+
+
+def _maybe_add_system_trust_httpx_clients(
+    provider: str, provider_extra: dict
+) -> dict:
+    """
+    If we're using OpenAI-compatible provider + running on macOS corporate PKI,
+    use system trust store via truststore to avoid certifi/conda OpenSSL issues.
+    """
+    if provider != "openai":
+        return provider_extra
+
+    # Don't override if caller already provided custom clients
+    if "http_client" in provider_extra or "http_async_client" in provider_extra:
+        return provider_extra
+
+    try:
+        import truststore  # pip install truststore
+    except Exception:
+        return provider_extra
+
+    ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+
+    # Provide both sync + async so invoke() and ainvoke() both work.
+    provider_extra = dict(provider_extra or {})
+    provider_extra["http_client"] = httpx.Client(verify=ctx, trust_env=False)
+    provider_extra["http_async_client"] = httpx.AsyncClient(
+        verify=ctx, trust_env=False
+    )
+    return provider_extra
+
+
+# ---------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------
+def resolve_model_choice(
+    model_choice: str, models_cfg: dict
+) -> tuple[str, str, dict]:
+    """
+    Public wrapper around the internal model/provider resolver.
+
+    Accepts strings like 'openai:gpt-5.2' or 'my_endpoint:openai/gpt-oss-120b'
+    and returns:
+      (model_provider, pure_model, provider_extra_kwargs)
+
+    provider_extra_kwargs may include base_url/api_key.
+    """
+    return _resolve_model_choice(model_choice, models_cfg)
+
+
+def setup_llm(
+    *,
+    model_choice: str,
+    models_cfg: dict | None = None,
+    agent_name: str | None = None,
+    base_llm_kwargs: dict | None = None,
+    console: Any | None = None,
+):
+    """
+    Build a LangChain chat model via init_chat_model(...), applying YAML-driven params.
+
+    - `model_choice`: e.g. "openai:gpt-5" or "my_endpoint:openai/gpt-oss-120b"
+    - `models_cfg`: cfg.models dict
+    - `agent_name`: "planner" / "executor" / "hypothesizer" etc. (applies per-agent overrides)
+    - `base_llm_kwargs`: default kwargs that apply before YAML overrides
+    - `console`: optional rich console for pretty banners
+
+    Behavior matches your runner:
+      base defaults < YAML overrides
+    """
+    models_cfg = models_cfg or {}
+
+    provider, pure_model, provider_extra = _resolve_model_choice(
+        model_choice, models_cfg
+    )
+    provider_extra = _maybe_add_system_trust_httpx_clients(
+        provider, provider_extra
+    )
+
+    # Preserve your existing hardcoded defaults by default
+    default_base = {
+        "max_completion_tokens": 10000,
+        "max_retries": 2,
+    }
+    base = _deep_merge_dicts(default_base, base_llm_kwargs or {})
+
+    yaml_llm_kwargs = _resolve_llm_kwargs_for_agent(models_cfg, agent_name)
+    llm_kwargs = _deep_merge_dicts(base, yaml_llm_kwargs)
+
+    model = init_chat_model(
+        model=pure_model,
+        model_provider=provider,
+        **llm_kwargs,
+        **(provider_extra or {}),
+    )
+
+    _print_llm_init_banner(
+        agent_name=agent_name,
+        provider=provider,
+        model_name=pure_model,
+        provider_extra=provider_extra,
+        llm_kwargs=llm_kwargs,
+        model_obj=model,
+        console=console,
+    )
+
+    return model
diff --git a/tests/agents/test_summarizing_agent/test_summarizing_agent.py b/tests/agents/test_summarizing_agent/test_summarizing_agent.py
new file mode 100644
index 00000000..fb42b8e5
--- /dev/null
+++ b/tests/agents/test_summarizing_agent/test_summarizing_agent.py
@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, Optional
+
+import pytest
+
+from ursa.agents.summarizing_agent import SummarizingAgent
+from ursa.prompt_library.summarizing_prompts import (
+    SYSTEM_MAP_PROMPT,
+    SYSTEM_REDUCE_PROMPT,
+    SYSTEM_REWRITE_PROMPT,
+)
+
+# ----------------------------
+# Test doubles
+# ----------------------------
+
+
+@dataclass
+class _LLMCall:
+    system: str
+    user: str
+    tags: tuple[str, ...]
+
+
+class FakeLLM:
+    """
+    Deterministic LLM stub that returns stage-appropriate outputs.
+
+    The SummarizingAgent calls llm.invoke([SystemMessage, HumanMessage], config=...).
+    We detect stage from the system prompt.
+    """
+
+    def __init__(
+        self, *, reduce_output: str, rewrite_output: Optional[str] = None
+    ):
+        self.calls: list[_LLMCall] = []
+        self._reduce_output = reduce_output
+        self._rewrite_output = rewrite_output or ""
+
+    def invoke(
+        self,
+        messages: list[Any],
+        config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        system = getattr(messages[0], "content", "")
+        user = getattr(messages[1], "content", "")
+        tags = tuple((config or {}).get("tags", []))
+        self.calls.append(_LLMCall(system=system, user=user, tags=tags))
+
+        if system == SYSTEM_MAP_PROMPT:
+            # Compact bullet notes
+            return SimpleNamespace(text="- note one\n- note two\n- note three")
+
+        if system == SYSTEM_REDUCE_PROMPT:
+            # Final structured output (may intentionally contain forbidden words to trigger rewrite)
+            return SimpleNamespace(text=self._reduce_output)
+
+        if system == SYSTEM_REWRITE_PROMPT:
+            # Rewrite output to scrub forbidden meta-language
+            return SimpleNamespace(text=self._rewrite_output)
+
+        # If prompts change, make failures obvious.
+        raise AssertionError(
+            f"Unexpected system prompt passed to LLM: {system!r}"
+        )
+
+
+def _stub_read_file_from_workspace(filename: str, state: dict[str, Any]) -> str:
+    """
+    Tool stub for ursa.tools.read_file.
+
+    The agent passes tool_state = {"workspace": <workspace path>}.
+    The agent passes filename as a relative path under that workspace.
+ """ + workspace = Path(state["workspace"]) + p = workspace / filename + return p.read_text(encoding="utf-8") + + +# ---------------------------- +# Fixtures +# ---------------------------- + + +@pytest.fixture(autouse=True) +def stub_read_file_tool(monkeypatch): + """ + Replace the agent's imported read_file tool with a deterministic local reader. + + Note: we patch the symbol in the agent module, not ursa.tools itself, + mirroring your existing tests pattern. + """ + monkeypatch.setattr( + "ursa.agents.summarizing_agent.read_file", + _stub_read_file_from_workspace, + ) + + +# ---------------------------- +# Helpers +# ---------------------------- + + +def _write(tmp_path: Path, rel: str, text: str) -> None: + p = tmp_path / rel + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(text, encoding="utf-8") + + +def _minimal_good_summary() -> str: + # A reduce-stage output that already satisfies the structure. + return ( + "Executive Summary: This is a short executive summary.\n\n" + "Required Actions:\n" + "- Eng: Do the thing this week\n" + "- PM: Confirm scope in 30 days\n\n" + "Main Synthesis: " + ("This is a paragraph. " * 120).strip() + ) + + +def _forbidden_reduce_summary() -> str: + # Intentionally includes forbidden meta-language to trigger rewrite. + return ( + "Executive Summary: This document explains the text above.\n\n" + "Required Actions:\n" + "- Eng: Do the thing this week\n\n" + "Main Synthesis: The first document says X and the second file says Y." + ) + + +def _clean_rewrite_summary() -> str: + return ( + "Executive Summary: This is a unified summary without referencing inputs.\n\n" + "Required Actions:\n" + "- Eng: Validate the key assumptions this week\n" + "- PM: Align stakeholders in 30 days\n\n" + "Main Synthesis: " + ("Unified synthesis paragraph. 
" * 80).strip() + ) + + +# ---------------------------- +# Tests +# ---------------------------- + + +def test_summarizing_agent_empty_dir_returns_empty_summary(tmp_path: Path): + llm = FakeLLM(reduce_output=_minimal_good_summary()) + agent = SummarizingAgent(llm=llm, workspace=tmp_path) + + result = agent.invoke({"input_docs_dir": str(tmp_path)}) + + assert result["summary"] == "" + assert result["selected_files"] == [] + assert result["skipped_files"] == [] + assert llm.calls == [] + + +def test_summarizing_agent_selects_files_filters_hidden_and_extensions( + tmp_path: Path, +): + # Eligible + _write(tmp_path, "a.txt", "alpha") + _write(tmp_path, "b.md", "bravo") + # Ineligible extensions + _write(tmp_path, "c.json", '{"x":1}') + # Hidden path should be skipped + _write(tmp_path, ".hidden/secret.txt", "nope") + + llm = FakeLLM(reduce_output=_minimal_good_summary()) + agent = SummarizingAgent(llm=llm, workspace=tmp_path) + + result = agent.invoke({ + "input_docs_dir": str(tmp_path), + "allowed_extensions": (".txt", ".md"), + "recurse": True, + }) + + # We can't directly see selection order from agent output except via selected_files + assert sorted(result["selected_files"]) == ["a.txt", "b.md"] + assert all(".hidden" not in s for s in result["selected_files"]) + assert "c.json" not in result["selected_files"] + + # Should produce a summary (non-empty) given two small docs. 
+ assert (result["summary"] or "").strip() + + +def test_summarizing_agent_max_files_enforces_deterministic_cap(tmp_path: Path): + _write(tmp_path, "a.txt", "a") + _write(tmp_path, "b.txt", "b") + _write(tmp_path, "c.txt", "c") + + llm = FakeLLM(reduce_output=_minimal_good_summary()) + agent = SummarizingAgent(llm=llm, workspace=tmp_path) + + result = agent.invoke({ + "input_docs_dir": str(tmp_path), + "allowed_extensions": (".txt",), + "max_files": 2, + }) + + # Sorted case-insensitive: a, b, c -> first two selected, remaining skipped + assert result["selected_files"] == ["a.txt", "b.txt"] + assert result["skipped_files"] == ["c.txt"] + + +def test_summarizing_agent_non_strict_skips_tool_error( + tmp_path: Path, monkeypatch +): + _write(tmp_path, "a.txt", "alpha") + _write(tmp_path, "b.txt", "bravo") + + def bad_read_file(filename: str, state: dict[str, Any]) -> str: + if filename == "a.txt": + return "[Error] failed to read" + return _stub_read_file_from_workspace(filename, state) + + monkeypatch.setattr( + "ursa.agents.summarizing_agent.read_file", bad_read_file + ) + + llm = FakeLLM(reduce_output=_minimal_good_summary()) + agent = SummarizingAgent(llm=llm, workspace=tmp_path) + + result = agent.invoke({ + "input_docs_dir": str(tmp_path), + "allowed_extensions": (".txt",), + "strict": False, + }) + + assert result["selected_files"] == ["a.txt", "b.txt"] + # a.txt is skipped internally; selected_files lists discovered, not successfully read + assert (result["summary"] or "").strip() + # Only b.txt content is summarized; LLM should still have been called. 
+    assert llm.calls
+
+
+def test_summarizing_agent_strict_raises_on_tool_error(
+    tmp_path: Path, monkeypatch
+):
+    _write(tmp_path, "a.txt", "alpha")
+
+    def bad_read_file(filename: str, state: dict[str, Any]) -> str:
+        return "[Error] failed to read"
+
+    monkeypatch.setattr(
+        "ursa.agents.summarizing_agent.read_file", bad_read_file
+    )
+
+    llm = FakeLLM(reduce_output=_minimal_good_summary())
+    agent = SummarizingAgent(llm=llm, workspace=tmp_path)
+
+    with pytest.raises(RuntimeError):
+        agent.invoke({
+            "input_docs_dir": str(tmp_path),
+            "allowed_extensions": (".txt",),
+            "strict": True,
+        })
+
+
+def test_summarizing_agent_rewrite_pass_triggers_and_scrubs_forbidden_refs(
+    tmp_path: Path,
+):
+    _write(tmp_path, "a.txt", "alpha " * 50)
+    _write(tmp_path, "b.txt", "bravo " * 50)
+
+    llm = FakeLLM(
+        reduce_output=_forbidden_reduce_summary(),
+        rewrite_output=_clean_rewrite_summary(),
+    )
+    agent = SummarizingAgent(llm=llm, workspace=tmp_path)
+
+    result = agent.invoke({
+        "input_docs_dir": str(tmp_path),
+        "allowed_extensions": (".txt",),
+        # Keep chunks small so we always exercise map->reduce with multiple calls
+        "chunk_size_chars": 200,
+        "chunk_overlap_chars": 50,
+        "max_chunks_per_file": 20,
+        "reduce_batch_size": 2,
+    })
+
+    summary = result["summary"]
+    assert "Executive Summary" in summary
+    assert "Required Actions" in summary
+    assert "Main Synthesis" in summary
+
+    # Rewrite should remove common forbidden tokens that appear in the bad reduce output.
+    lowered = summary.lower()
+    assert "document" not in lowered
+    assert "the text above" not in lowered
+    assert "first document" not in lowered
+
+    # Ensure the rewrite stage was actually invoked.
+    assert any(call.system == SYSTEM_REWRITE_PROMPT for call in llm.calls)
+
+
+def test_summarizing_agent_emits_events(tmp_path: Path):
+    _write(tmp_path, "a.txt", "alpha " * 20)
+
+    llm = FakeLLM(reduce_output=_minimal_good_summary())
+
+    events: list[str] = []
+
+    def on_event(name: str, data: dict[str, Any]) -> None:
+        events.append(name)
+
+    agent = SummarizingAgent(llm=llm, workspace=tmp_path)
+    result = agent.invoke({
+        "input_docs_dir": str(tmp_path),
+        "allowed_extensions": (".txt",),
+        "on_event": on_event,
+    })
+
+    assert (result["summary"] or "").strip()
+    # Spot-check that major lifecycle events were emitted
+    assert "start" in events
+    assert "discover_done" in events
+    assert "chunking_done" in events
+    assert "reduce_done" in events
+    assert "done" in events