diff --git a/Cargo.toml b/Cargo.toml index 51f04f0..1c63844 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,13 @@ categories = ["command-line-utilities", "data-structures"] [lib] crate-type = ["rlib"] +[features] +default = [] +# Enable vector retrieval + cross-encoder rerank + adaptive KG traversal. +# Adds fastembed (embedding + reranker inference via ONNX) and usearch (ANN index). +# Off by default to keep the default binary slim. +embeddings = ["dep:fastembed", "dep:usearch"] + [dependencies] clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } @@ -56,6 +63,13 @@ tree-sitter = "0.26" tree-sitter-perl = "1.1.2" sysinfo = "0.32" +# Embedding retrieval stack (optional, behind `embeddings` feature). +# fastembed: in-process embedding inference AND cross-encoder reranking +# via ONNX Runtime. Covers BGE-small-en-v1.5 (embed) + bge-reranker-v2-m3 (rerank). +# usearch: HNSW ANN index, SIMD cosine search, file persistence. +fastembed = { version = "4", optional = true } +usearch = { version = "2", optional = true } + [target.'cfg(unix)'.dependencies] libc = "0.2" diff --git a/docs/mcp-setup.md b/docs/mcp-setup.md index b503d1a..2d94bc9 100644 --- a/docs/mcp-setup.md +++ b/docs/mcp-setup.md @@ -131,3 +131,51 @@ When the MCP server starts with an existing LeanKG project, it checks if the ind ## Fallback If the MCP server reports "LeanKG not initialized", manually run `leankg init` in your project directory, then restart the AI tool. + +## Embedding Retrieval (optional, `embeddings` feature) + +The `kg_semantic_context` tool — vector retrieve + cross-encoder rerank + adaptive KG traversal — only ships when LeanKG is built with `--features embeddings`. Default builds skip it to keep the binary lean. + +### Building with the feature + +```bash +# From a LeanKG checkout: +cargo build --release --features embeddings + +# Or directly install the binary with the feature on: +cargo install --path . --features embeddings +``` + +This pulls in `fastembed` (ONNX-backed embedding + reranker inference) and `usearch` (HNSW ANN index). The first build downloads ONNX Runtime binaries via fastembed's deps. + +### One-time setup per machine + +```bash +# 1. Pre-download embedding (BGE-small-en-v1.5, ~130MB) and reranker +# (bge-reranker-v2-m3, ~600MB) into ~/.cache/leankg/models: +leankg embed --init + +# 2. Index your project (if not already indexed): +leankg index ./src + +# 3. Build the embedding index (~seconds for incremental, minutes for a +# fresh 10k-node repo on CPU): +leankg embed +``` + +### Index lifecycle + +`leankg embed` (default) is **incremental**: it reads the `embedding_state` CozoDB table that tracks per-node freshness and only re-embeds nodes that are stale (touched by a recent `index` run), missing (newly added), or whose text blob hash changed. Orphans (state rows whose `qualified_name` is no longer in `code_elements`) are reaped. + +`leankg embed --full` ignores state and re-embeds every node. Use after a model swap or suspected index corruption. + +The `index` command marks touched elements stale but does **not** trigger `embed` automatically — embedding is a separate explicit step. The MCP tool surfaces a stale-embeddings warning in `diagnostics.embeddings_stale` so callers know when to re-run `embed`. + +### Worktree exclusion + +By default, `kg_semantic_context` filters out paths under `.worktrees/`, `.claude/worktrees/`, and `.opencode/worktrees/` to avoid duplicate-noise from agent scratch copies. Pass `include_worktrees: true` to include them. + +### Reranker fallback + +If the reranker fails to load or score, the tool falls back to ANN-order top-N (no cross-encoder). `diagnostics.reranker` will be `"fallback_ann"` instead of `"bge-reranker-v2-m3"`. The most common cause is a partial model download — re-running `leankg embed --init` fixes it. + diff --git a/docs/mcp-tools.md b/docs/mcp-tools.md index eb8f6f6..8fd0aa2 100644 --- a/docs/mcp-tools.md +++ b/docs/mcp-tools.md @@ -66,6 +66,40 @@ LeanKG exposes a comprehensive set of MCP tools for AI tools to query the knowle When the MCP server starts without an existing LeanKG project, it automatically initializes and indexes the current directory. This provides a "plug and play" experience for AI tools. +## Semantic Retrieval (optional, `embeddings` feature) + +These tools ship only when LeanKG is built with `--features embeddings`. They add vector retrieval + cross-encoder rerank + adaptive graph traversal on top of the existing keyword/graph search. + +| Tool | Description | +|------|-------------| +| `kg_semantic_context` | Vector retrieve → rerank → traverse. Best for natural-language questions where keyword search misses (e.g., 'where do we validate access rights'). Returns ranked seed nodes plus 1-2 hop graph context. | + +Setup (one-time): + +```bash +cargo run --release --features embeddings -- embed --init # pre-download models (~700MB) +cargo run --release --features embeddings -- embed # build the embedding index +``` + +Then call from any MCP client: + +```json +{ "tool": "kg_semantic_context", "arguments": { "query": "where is refund failure handled" } } +``` + +Optional arguments: `env` (default `local`), `top_k` (default 50), `rerank_top_n` (default 10), `traverse` (default true), `include_worktrees` (default false), `debug` (default false). + +Response shape (debug=false): `{ query, env, seeds[], traversed[] }`. With `debug=true`: adds `diagnostics` with reranker status, candidate counts, per-stage latency, and the edges traversed. + +Behavior notes: +- If the reranker fails to load, the tool silently falls back to ANN-order top-N (Q4 option A). `diagnostics.reranker = "fallback_ann"` surfaces this. +- If the embedding index is older than the last `index` run, `diagnostics.embeddings_stale = true` (still serves, just warns). +- Worktree scratch copies (`.worktrees/`, `.claude/worktrees/`, `.opencode/worktrees/`) are filtered out by default to avoid duplicate-noise results. + +## Auto-Indexing + +When the MCP server starts with an existing LeanKG project, it checks if the index is stale (by comparing git HEAD commit time vs database file modification time). If stale, it automatically runs incremental indexing to ensure AI tools have up-to-date context. + ## Auto-Indexing When the MCP server starts with an existing LeanKG project, it checks if the index is stale (by comparing git HEAD commit time vs database file modification time). If stale, it automatically runs incremental indexing to ensure AI tools have up-to-date context. diff --git a/docs/plans/2026-06-30-embedding-retrieve-rerank-traverse.md b/docs/plans/2026-06-30-embedding-retrieve-rerank-traverse.md new file mode 100644 index 0000000..b7f1b53 --- /dev/null +++ b/docs/plans/2026-06-30-embedding-retrieve-rerank-traverse.md @@ -0,0 +1,332 @@ +# Embedding Retrieve → Rerank → KG Traverse — Unified Plan + +Date: 2026-06-30 +Status: Draft +Scope: Add vector retrieval + cross-encoder rerank + adaptive graph traversal on top of the existing ontology layer and current keyword/graph retrieval. + +## Origin and Reconciliation + +This plan unifies and supersedes the embedding-deferred sections of two prior docs: + +- `docs/design/hybrid-retrieval-reranking.md` — Phase 4 (Optional Embeddings). The hybrid retriever architecture (parallel retrievers → RRF fusion → graph-aware rerank) stays as the future direction for *multi-channel* retrieval. This plan delivers the **embedding channel** that the hybrid design flagged as optional. +- `docs/planning/2026-05-17-ontology-semantic-search-mvp.md` — Future Enhancement "embeddings for semantic alias matching". The ontology layer described there is now implemented (`src/ontology/`); this plan adds the embedding-backed semantic match that the MVP explicitly deferred. + +What this plan is **not** doing: +- Not replacing existing `semantic_search` / `kg_context` keyword and regex matching. +- Not building the full parallel-retriever + RRF fusion pipeline from `hybrid-retrieval-reranking.md`. That remains a separate follow-on. +- Not introducing Graphiti, bi-temporal history, or hosted vector databases. + +## Locked Decisions (from 2026-06-30 design review) + +| Decision | Choice | +| --- | --- | +| Runtime stack | All-Rust in-process: `fastembed` (embeddings + cross-encoder rerank, ONNX-backed) + `usearch` (ANN). No external services. `ort` was originally listed as a separate reranker dep but `fastembed::TextRerank` covers `bge-reranker-v2-m3` natively, so the dedicated `ort` dep was dropped during Phase 0. | +| What gets embedded | Text blob (qualified_name + name + doc/signature) **and** ontology description + aliases. No code body. No offline GNN embeddings. | +| Traversal policy | Adaptive: hops depend on seed `element_type`. Workflow/procedural seeds → 2 hops; function/file/concept seeds → 1 hop. Fanout cap per seed. | +| Plan handling | Extend the two existing docs into this unified plan (done). | +| Model placement (Q1) | Lazy-download to `~/.cache/leankg/models/`, SHA-256 verified. `embed --init` pre-downloads both models explicitly. | +| Worktree filter (Q2) | Default-on: exclude `**/.worktrees/**`, `**/.claude/worktrees/**`, `**/.opencode/worktrees/**` at the ANN stage. Opt-in flag `include_worktrees: bool = false`. | +| Index freshness (Q3) | `index` and `embed` are separate. `index` marks touched CodeElements as stale in a new `embedding_state` CozoDB table. `embed` (default) does **incremental** work: only nodes whose `content_hash` changed or whose state is `stale`/`missing`. No auto-embed inside `index`. | +| Reranker failure (Q4) | Option A — ANN-only fallback with `diagnostics.reranker = "fallback_ann"`. Model missing / load failure / OOM all fall back, never refuse. | + +## Architecture + +``` +[query text] + │ + ▼ Stage 1 — Embed query (fastembed, BGE-small or jina-code) +[query_vec] + │ + ▼ Stage 2 — ANN retrieve (usearch, cosine, top-K=50) +[top-50 candidate node IDs] + │ + ▼ Stage 3 — Cross-encoder rerank (fastembed::TextRerank, bge-reranker-v2-m3) +[top-N=10 seed nodes] + │ + ▼ Stage 4 — Adaptive KG traversal (CozoDB Datalog) +[seeds + 1-2 hop neighbors + edges] + │ + ▼ Stage 5 — Optional final rerank on union; compress; return MCP payload +[enriched context] +``` + +Stages 1–3 are new. Stage 4 reuses the existing CozoDB graph. Stage 5 reuses the compression logic already in `kg_context`. + +## Data Model + +### What gets a vector + +Every `CodeElement` row that is one of: + +- `element_type` in `{file, function, class, module}` (code nodes) — embed the **code text blob** +- `element_type` in `{domain_entity, service, api_endpoint, data_store, workflow, workflow_step, decision_point, failure_mode, playbook, playbook_step, known_issue, team_knowledge}` (ontology nodes) — embed the **ontology text blob** + +Docs files (when `documents` rows exist with prose) — embed the **doc text blob** (title + heading path + first paragraph). + +### Text blob construction + +Code blob: + +``` +qualified_name + "\n" + name + "\n" + doc_comment (if any) + "\n" + signature +``` + +No function body. Bounded length: truncate at 512 tokens (embedding model max). + +Ontology blob: + +``` +name + "\n" + aliases.join(", ") + "\n" + metadata.description + "\n" + element_type +``` + +Doc blob: + +``` +heading_path.join(" / ") + "\n" + title + "\n" + first_paragraph +``` + +### Vector storage + +Sidecar ANN index file, not a CozoDB table: + +``` +.leankg/ + embeddings.usearch # vectors keyed by CodeElement.id (i32) + embeddings.meta.json # model_id, dim, metric, element_type counts, build timestamp +``` + +Rationale: CozoDB has no native HNSW; `usearch` gives sub-ms cosine search in pure Rust. The `CodeElement.id` → vector mapping is the only bridge. Rebuilding the index is cheap (fastembed is CPU-friendly). + +### Incremental embedding & staleness + +Embedding a 10k-node repo takes minutes; rebuilding from scratch on every `index` run is unacceptable. The design is incremental: + +**CozoDB side table `embedding_state`** (new, in the same CozoDB file): + +``` +embedding_state { + code_element_id: i64, # FK to code_elements.id + content_hash: String, # SHA-256 of the text blob at last embed + state: String, # "fresh" | "stale" | "missing" + embedded_at: String, # ISO 8601 timestamp +} +``` + +**Marking stale.** The `index` command, after upserting `code_elements` rows, runs a single Datalog statement that flips `embedding_state.state` to `"stale"` for every `code_element_id` it touched (inserts, updates, deletes). Deleted CodeElements get `"stale"` too — `embed` will reap them from usearch and drop the state row. + +**Incremental `embed`.** Default `embed` behavior: + +1. Read all `code_elements` rows. +2. For each, compute the text blob and its `content_hash`. +3. Compare against `embedding_state.content_hash`: + - No state row → `missing` → embed, insert state row. + - Hash differs → `stale` → embed, update state row. + - Hash matches and `state = "fresh"` → skip. + - Hash matches and `state = "stale"` → re-embed (handles "touched but content unchanged" cases cheaply). +4. For state rows whose `code_element_id` no longer exists in `code_elements` → remove the vector from usearch, delete the state row. +5. Mark all touched rows `state = "fresh"`, write `embedded_at`. +6. Persist `embeddings.usearch` + `embeddings.meta.json`. + +For a typical re-index that touches 50–200 nodes, `embed` runs in seconds, not minutes. + +**Full rebuild.** `cargo run --release -- embed --full` ignores state and re-embeds every CodeElement. Use after model swap, usearch corruption, or version upgrade. + +**Model pre-download.** `cargo run --release -- embed --init` downloads both the embedding model (`bge-small-en-v1.5`) and the reranker (`bge-reranker-v2-m3` ONNX) to `~/.cache/leankg/models/`, verifies SHA-256, and exits without touching the index. Recommended setup step. If skipped, lazy-download fires on first use of each model. + +### Embedding model + +Default: `BAAI/bge-small-en-v1.5` (384-dim, fast, general text). +Optional config swap: `jinaai/jina-embeddings-v2-base-code` (code-aware, 768-dim) for code-heavy deployments. + +Model choice is stored in `embeddings.meta.json` so retrieval knows which model to use for the query vector. + +## Adaptive Traversal Rules + +Stage 4 runs a bounded CozoDB Datalog traversal per seed node. Hop count and edge filter depend on the seed's `element_type`: + +| Seed type | Hops | Allowed edge types | Fanout cap (per hop) | +| --- | --- | --- | --- | +| `workflow` | 2 | `has_step`, `next_step`, `branches_to`, `implemented_by`, `entry_point_of`, `step_in_process`, `has_failure_mode` | 20 | +| `workflow_step`, `decision_point`, `failure_mode` | 2 | `next_step`, `branches_to`, `implemented_by`, `handled_by_playbook`, `has_failure_mode`, `resolved_by_playbook` | 15 | +| `domain_entity`, `service`, `api_endpoint`, `data_store` | 1 | `owns_concept`, `implements_concept`, `exposes_endpoint`, `reads_from`, `writes_to`, `documents_concept`, `has_known_issue` | 15 | +| `known_issue`, `playbook`, `team_knowledge` | 1 | `has_known_issue`, `resolved_by_playbook`, `documents_concept` | 10 | +| `function`, `class` | 1 | `calls`, `imports`, `references`, `tested_by`, `documented_by`, `implements_concept` | 10 | +| `file`, `module` | 1 | `imports`, `references`, `tested_by`, `documented_by`, `contains`, `defines` | 10 | +| `doc` / other | 1 | `documented_by` (reverse), `documents_concept` | 5 | + +Global caps: +- Total traversed neighbors across all seeds: 60 (dedup by `qualified_name`). +- Traversal skips nodes already in the seed set. +- Edges to nodes outside the active `env` are filtered. + +## MCP Tool Contract + +New tool `kg_semantic_context`. Kept separate from `kg_context` so the existing keyword-based flow stays the default; agents can opt in. + +``` +kg_semantic_context( + query: string, + env?: string = "local", // metadata pre-filter + top_k?: number = 50, // ANN retrieve depth + rerank_top_n?: number = 10, // cross-encoder keep depth + traverse?: bool = true, // toggle Stage 4 + final_rerank?: bool = true, // toggle Stage 5 on union + debug?: bool = false // include diagnostics +) +``` + +Response shape (Stage 5 union): + +```json +{ + "query": "where is refund failure handled", + "env": "local", + "intent_hint": "explain_flow", + "seeds": [ + { + "qualified_name": "local:checkout-service:workflow:checkout:v1", + "element_type": "workflow", + "final_score": 0.88, + "ann_rank": 3, + "rerank_score": 0.88, + "matched_blob_excerpt": "Checkout flow ... authorize payment ..." + } + ], + "traversed": [ + { + "qualified_name": "local:checkout-service:workflow_step:authorize_payment:v1", + "element_type": "workflow_step", + "via_edge": "has_step", + "from_seed": "local:checkout-service:workflow:checkout:v1", + "hop": 1 + } + ], + "edges": [ + { "source": "...", "target": "...", "rel_type": "has_step" } + ], + "diagnostics": { + "ann_candidate_count": 50, + "reranker": "bge-reranker-v2-m3", + "embedder": "bge-small-en-v1.5", + "traversal": { "hops_used": 2, "neighbors_traversed": 23, "capped": false }, + "latency_ms": { "embed": 4, "ann": 1, "rerank": 22, "traverse": 6, "total": 33 } + } +} +``` + +When `debug=false`, drop `diagnostics`, `matched_blob_excerpt`, and edge list. Compress to fit MCP token budget using the same logic as `kg_context`. + +## Implementation Phases + +### Phase 0 — Dependencies and feature gate + +1. Add Cargo deps under a new feature `embeddings`: `fastembed` (covers embed + rerank), `usearch`. `ort` was originally listed but dropped — `fastembed::TextRerank` covers the reranker natively. +2. Gate all new modules behind `#[cfg(feature = "embeddings")]` so default builds stay slim. +3. Document ONNX runtime requirements in `docs/mcp-setup.md` (fastembed bundles ONNX runtime via its own deps). + +### Phase 1 — `src/embeddings/` module + +1. `src/embeddings/mod.rs` — `Embedder` trait, factory. +2. `src/embeddings/text_blob.rs` — code/ontology/doc blob builders (table above). +3. `src/embeddings/index.rs` — `usearch::Index` wrapper: `build_from_code_elements`, `load`, `save`, `search`, `remove`, supports incremental add/remove. +4. `src/embeddings/build.rs` — orchestrate incremental build: read `code_elements`, compute text blob hashes, diff against `embedding_state`, embed only changed/missing/stale, reap deleted, persist `.leankg/embeddings.usearch` + `.meta.json`. +5. `src/embeddings/state.rs` — `embedding_state` CozoDB table DDL + helpers (`mark_stale_for_ids`, `upsert_fresh`, `list_stale`, `list_orphans`). +6. `src/embeddings/models.rs` — lazy-download + SHA-256 verify to `~/.cache/leankg/models/`; `init_models()` for `embed --init`. +7. **Indexer hook.** Modify `src/indexer/` to call `embedding_state::mark_stale_for_ids` after upserting/deleting CodeElements during `index`. Behind `#[cfg(feature = "embeddings")]`. + +CLI: +- `cargo run --release -- embed --init` — download models, no build. +- `cargo run --release -- embed` — incremental (default). +- `cargo run --release -- embed --full` — full rebuild. + +### Phase 2 — `src/retrieval/` module + +1. `src/retrieval/ann.rs` — embed query → `usearch` top-K → return `(CodeElement.id, score)[]`. Apply worktree path filter here (Q2 default-on). +2. `src/retrieval/rerank.rs` — `fastembed::TextRerank` with `RerankerModel::BGERerankerV2M3`, batch-score `(query, blob)` pairs, return reranked top-N. **On any failure** (model missing after lazy-download attempt, init error, inference OOM/panic) → return ANN-order top-N unchanged and set a `RerankerStatus::Fallback` flag on the result (Q4 option A). +3. `src/retrieval/pipeline.rs` — `SemanticRetrievalPipeline` struct with `retrieve(query, env, top_k, rerank_top_n) -> RetrievalResult { seeds, reranker_status, embeddings_stale }`. + +No MCP wiring yet. Unit-testable end to end. + +### Phase 3 — Adaptive traversal + +1. Extend `src/ontology/query.rs` (or new `src/graph/traverse.rs`) with `traverse_seeds(seeds, env, rules) -> TraverseResult`. +2. Rules table encoded as Rust `match` on `element_type` (per table above). +3. CozoDB Datalog queries parameterized by `(hops, edge_types, fanout)`; reuse existing arity-correct patterns from `src/ontology/query.rs`. +4. Dedup, env filter, global cap. + +### Phase 4 — MCP wiring + +1. Add `kg_semantic_context` to `src/mcp/tools.rs` schema. +2. Handler in `src/mcp/handler.rs` calls pipeline → traverse → compress → return. +3. Add `debug` field passthrough. +4. Register in `kg_self_test` smoke flow. + +### Phase 5 — CLI parity + +```bash +cargo run --release -- embed --init # pre-download both models (setup, no build) +cargo run --release -- embed # incremental rebuild (default, stale-only) +cargo run --release -- embed --full # full rebuild (recovery / model swap) +cargo run --release -- semantic-context "query" # one-shot CLI for testing +``` + +### Phase 6 — Tests + +- Unit: text blob construction (per element_type), adaptive rule selection, dedup. +- Integration: small fixture repo, build index, run known queries, assert seed + traversed membership. +- Regression: ensure existing `kg_context` and `semantic_search` outputs are unchanged. +- Latency: budget assertions in `kg_self_test` (embed < 10ms, rerank < 50ms, traverse < 30ms on the fixture). + +## File Touchpoints + +| Area | Change | +| --- | --- | +| `Cargo.toml` | New `embeddings` feature, deps | +| `src/lib.rs` | Export `embeddings`, `retrieval` modules | +| `src/embeddings/*` | New | +| `src/retrieval/*` | New | +| `src/graph/traverse.rs` or `src/ontology/query.rs` | Add `traverse_seeds` | +| `src/mcp/tools.rs` | New `kg_semantic_context` schema | +| `src/mcp/handler.rs` | New handler, pipeline orchestration | +| `src/cli.rs` | New `embed` and `semantic-context` subcommands | +| `src/mcp/tools.rs::kg_self_test` | Add semantic smoke check | +| `docs/mcp-setup.md` | Document embedding deps and `embed` command | +| `docs/mcp-tools.md` | Document `kg_semantic_context` | + +## Acceptance Criteria + +- `cargo run --release -- embed` builds `.leankg/embeddings.usearch` from an indexed repo; idempotent on re-run. +- `kg_semantic_context("checkout refund failure", env="local")` returns at least one workflow/concept seed and at least one traversed neighbor (file/function/step). +- Adaptive hop rule respected: workflow seeds produce 2-hop traversed sets; function seeds produce 1-hop sets (verified via `debug=true` diagnostics). +- p95 total latency on a 5k-node repo < 150ms with embeddings enabled. +- Existing `kg_context`, `semantic_search`, `find_function`, `get_impact_radius` outputs unchanged. +- Default `cargo build --release` (without `embeddings` feature) still succeeds and produces no binary bloat. +- `kg_self_test` includes a semantic retrieval assertion. + +## Resolved Questions (2026-06-30) + +All four open questions settled before branch creation: + +1. **Reranker model placement → lazy-download + `--init`.** Models live in `~/.cache/leankg/models/`, SHA-256 verified. `embed --init` is the explicit setup; lazy-download is the fallback for users who skip it. +2. **Worktree exclusion → default-on.** `**/.worktrees/**`, `**/.claude/worktrees/**`, `**/.opencode/worktrees/**` filtered at ANN stage. `include_worktrees: bool = false` opt-in. +3. **Index freshness → incremental embed via `embedding_state`.** `index` marks touched nodes stale; `embed` does incremental batch on the stale/missing/changed-hash set; `embed --full` ignores state for recovery. Query-time `diagnostics.embeddings_stale` flags a stale index but still serves. +4. **Reranker fallback → option A (ANN-only).** Any reranker failure (missing model, load failure, OOM, panic) drops Stage 3 and returns ANN-order top-N. `diagnostics.reranker = "fallback_ann"` flag makes degradation visible to agents. + +## Future Enhancements (explicitly deferred) + +- Parallel multi-channel retrieval + RRF fusion (the full `hybrid-retrieval-reranking.md` design). +- Offline structural GNN embeddings (node2vec / GraphSAGE via PyG) as a second vector channel. +- Code body embeddings for "find by implementation detail" queries. +- LLM-assisted query intent classification (replaces the lightweight deterministic intent hints). +- Cross-repo embedding shards for multi-repo deployments. + +## Risks + +| Risk | Mitigation | +| --- | --- | +| Model download size blocks first-run UX | Lazy download with progress, cache reuse, document offline path | +| Reranker latency dominates | Cap rerank input at top-K=50; batch in one `ort` call | +| Traversal returns noise | Edge-type filter per seed type, global cap, dedup | +| Embeddings miss exact identifiers | Keep `semantic_search` and `find_function` unchanged; agents choose tool | +| Index drift after re-indexing | Timestamp check + explicit `embed` step + warning in MCP response | +| `ort` / ONNX runtime portability | fastembed bundles ONNX runtime; document supported targets; fall back to ANN-only on load failure | diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 6f3c692..073f03d 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -113,6 +113,58 @@ pub enum CLICommand { #[arg(long)] lang: Option, }, + /// Build or refresh the embedding index (requires --features embeddings). + /// Default mode is incremental: only re-embed nodes touched since the + /// last `embed` run, plus newly-added nodes. Orphans (state rows whose + /// qualified_name no longer exists) are reaped from usearch + state. + #[cfg(feature = "embeddings")] + Embed { + /// Download the embedding + reranker models to the cache and exit. + /// No index is built. Recommended first step on a fresh install. + #[arg(long)] + init: bool, + /// Ignore embedding_state freshness and re-embed every node from + /// scratch. Use after a model swap or index corruption. + #[arg(long)] + full: bool, + /// Override the embedding batch size (default 256). Lower this on + /// memory-constrained hosts. + #[arg(long, default_value = "256")] + batch_size: usize, + /// Project root (defaults to current working directory). + #[arg(long, default_value = ".")] + project: String, + }, + /// One-shot embedding retrieval for CLI testing (requires + /// --features embeddings). Useful for validating the retrieve→rerank→ + /// traverse pipeline without standing up the MCP server. + #[cfg(feature = "embeddings")] + SemanticContext { + /// Natural language query. + query: String, + /// Environment filter. + #[arg(long, default_value = "local")] + env: String, + /// ANN retrieve depth. + #[arg(long, default_value = "50")] + top_k: usize, + /// Final seed count after rerank. + #[arg(long, default_value = "10")] + rerank_top_n: usize, + /// Disable Stage 4 graph enrichment. + #[arg(long)] + no_traverse: bool, + /// Include paths under .worktrees/ / .claude/worktrees/ / + /// .opencode/worktrees/ (filtered by default). + #[arg(long)] + include_worktrees: bool, + /// Print diagnostics: candidate counts, latency, reranker status. + #[arg(long)] + debug: bool, + /// Project root (defaults to current working directory). + #[arg(long, default_value = ".")] + project: String, + }, /// Export knowledge graph Export { /// Output file path diff --git a/src/db/schema.rs b/src/db/schema.rs index 370c6c8..83bdbbe 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -302,6 +302,14 @@ fn init_schema(db: &CozoDb) -> Result<(), Box> { } } + // Embedding-state table (only when the `embeddings` feature is compiled in). + // Without the feature, the table is never created and `embeddings::*` + // calls are absent from the binary — keeps default builds lean. + #[cfg(feature = "embeddings")] + { + crate::embeddings::state::ensure_embedding_state_table(db)?; + } + Ok(()) } diff --git a/src/embeddings/build.rs b/src/embeddings/build.rs new file mode 100644 index 0000000..adc109c --- /dev/null +++ b/src/embeddings/build.rs @@ -0,0 +1,261 @@ +//! Embedding build orchestration: incremental vs full rebuild, plus orphan +//! reaping. Implements `cargo run --release -- embed [--full]`. +//! +//! Incremental flow (default): +//! 1. Load (or create) the usearch index from `embeddings.usearch`. +//! 2. Walk all `code_elements` and compute the current text blob + hash for +//! each embeddable node. +//! 3. Diff against `embedding_state`: embed any qualified_name where +//! (a) no state row exists, OR +//! (b) `state != "fresh"`, OR +//! (c) stored `content_hash` differs from the current blob hash. +//! 4. Reap orphans: state rows whose qualified_name is no longer in +//! `code_elements` get their vector removed from usearch and their row +//! deleted. +//! 5. Persist `embeddings.usearch` + `embeddings.meta.json`. +//! +//! Full rebuild (`--full`): step 3 becomes "embed every embeddable node". + +use crate::embeddings::{ + index::AnnIndex, + models::{EMBEDDING_DIM, Embedder}, + state::{self, FreshRow}, + text_blob, +}; +use crate::graph::query::GraphEngine; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BuildMode { + /// Skip up-to-date rows; embed only stale/missing/changed. + Incremental, + /// Re-embed every embeddable CodeElement, regardless of state. + Full, +} + +#[derive(Debug, Clone)] +pub struct BuildOptions { + pub mode: BuildMode, + /// Vectors per embed call. fastembed handles batching internally; we + /// chunk to keep peak memory bounded on very large repos. + pub batch_size: usize, + /// Optional capacity hint for the usearch index. If None, reserve to + /// the current element count + 10% headroom. + pub reserve_capacity: Option, +} + +impl Default for BuildOptions { + fn default() -> Self { + Self { + mode: BuildMode::Incremental, + batch_size: 256, + reserve_capacity: None, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct BuildReport { + pub considered_count: usize, + pub embedded_count: usize, + pub skipped_fresh_count: usize, + pub orphaned_count: usize, + pub index_size: usize, + pub index_path: PathBuf, +} + +pub fn run( + graph: &GraphEngine, + index_path: &Path, + opts: &BuildOptions, +) -> Result> { + let embedder = Embedder::new()?; + let dim = embedder.dim(); + + let index = if index_path.exists() { + match AnnIndex::load(index_path) { + Ok(loaded) if loaded.dim() == dim => loaded, + Ok(loaded) => { + tracing::warn!( + "existing index dim {} != model dim {}; rebuilding from scratch", + loaded.dim(), + dim + ); + AnnIndex::new(dim)? + } + Err(e) => { + tracing::warn!("failed to load existing index ({}); rebuilding", e); + AnnIndex::new(dim)? + } + } + } else { + let new = AnnIndex::new(dim)?; + if let Some(cap) = opts.reserve_capacity { + new.reserve(cap)?; + } + new + }; + + // 1. Walk code_elements and build the work list. + let elements = graph.all_elements()?; + let work: Vec = elements + .iter() + .filter_map(|el| { + let blob = text_blob::build_blob(el)?; + let hash = text_blob::content_hash_for(&blob); + let key = text_blob::usearch_key_for(&el.qualified_name); + Some(WorkItem { + qualified_name: el.qualified_name.clone(), + blob, + current_hash: hash, + key, + }) + }) + .collect(); + + // 2. Build the "needs embed" set. + let existing_state: std::collections::HashMap = state::list_all(graph.db())? + .into_iter() + .map(|r| (r.qualified_name.clone(), r)) + .collect(); + + let to_embed: Vec<&WorkItem> = work + .iter() + .filter(|w| match opts.mode { + BuildMode::Full => true, + BuildMode::Incremental => match existing_state.get(&w.qualified_name) { + None => true, + Some(row) => { + row.state != "fresh" + || row.content_hash.is_empty() + || row.content_hash != w.current_hash + } + }, + }) + .collect(); + + let considered = work.len(); + let skipped_fresh = considered - to_embed.len(); + + // 3. Reserve usearch capacity ahead of any insertions. usearch panics + // ("Reserve capacity ahead of insertions!") if you add before reserving. + // Use the existing index size + the new embed count as a lower bound, + // with 10% headroom for future incremental runs. + let needed_capacity = match opts.reserve_capacity { + Some(cap) => cap, + None => index.size() + to_embed.len() + (to_embed.len() / 10).max(16), + }; + if needed_capacity > index.size() { + index.reserve(needed_capacity)?; + } + + // 4. Batch embed and add to usearch. + let mut embedded = 0usize; + let mut fresh_rows: Vec = Vec::with_capacity(to_embed.len()); + for chunk in to_embed.chunks(opts.batch_size) { + let texts: Vec = chunk.iter().map(|w| w.blob.clone()).collect(); + let vectors = embedder.embed(&texts)?; + for (item, vector) in chunk.iter().zip(vectors.iter()) { + // Remove the old vector if it exists (usearch `add` does NOT + // overwrite by default — it can leave duplicate keys). + let _ = index.remove(item.key); + index.add(item.key, vector)?; + fresh_rows.push(FreshRow { + qualified_name: item.qualified_name.clone(), + usearch_key: item.key, + content_hash: item.current_hash.clone(), + }); + embedded += 1; + } + } + + // 4. Persist fresh state. + state::upsert_fresh(graph.db(), &fresh_rows)?; + + // 5. Reap orphans: state rows whose qualified_name is no longer present. + let work_qns: std::collections::HashSet<&str> = + work.iter().map(|w| w.qualified_name.as_str()).collect(); + let orphans: Vec = existing_state + .keys() + .filter(|qn| !work_qns.contains(qn.as_str())) + .cloned() + .collect(); + for qn in &orphans { + if let Ok(Some(key)) = state::lookup_usearch_key(graph.db(), qn) { + let _ = index.remove(key); + } + } + if !orphans.is_empty() { + state::delete_state_rows(graph.db(), &orphans)?; + } + + // 6. Persist index + meta. + index.save(index_path)?; + write_meta(index_path, dim, embedded, index.size())?; + + Ok(BuildReport { + considered_count: considered, + embedded_count: embedded, + skipped_fresh_count: skipped_fresh, + orphaned_count: orphans.len(), + index_size: index.size(), + index_path: index_path.to_path_buf(), + }) +} + +struct WorkItem { + qualified_name: String, + blob: String, + current_hash: String, + key: u64, +} + +#[derive(Debug, Clone, serde::Serialize)] +struct IndexMeta { + model_id: &'static str, + dim: usize, + metric: &'static str, + size: usize, + built_at: u64, +} + +fn write_meta(index_path: &Path, dim: usize, _embedded: usize, size: usize) -> Result<(), Box> { + let meta = IndexMeta { + model_id: "BAAI/bge-small-en-v1.5", + dim, + metric: "cosine", + size, + built_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0), + }; + let meta_path = meta_path_for(index_path); + let bytes = serde_json::to_vec_pretty(&meta)?; + std::fs::write(&meta_path, bytes)?; + Ok(()) +} + +pub fn meta_path_for(index_path: &Path) -> PathBuf { + let mut p = index_path.to_path_buf(); + p.set_extension("meta.json"); + p +} + +pub const EMBEDDING_DIM_CONST: usize = EMBEDDING_DIM; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn meta_path_swaps_extension() { + let p = PathBuf::from("/tmp/.leankg/embeddings.usearch"); + let meta = meta_path_for(&p); + assert_eq!(meta.file_name().unwrap(), "embeddings.meta.json"); + } + + // End-to-end build tests live in /tests/embeddings_build_e2e.rs (Phase 6). + // They require a live CozoDB + fastembed model cache, so they aren't run + // as part of `cargo test` on machines without the `embeddings` feature. +} diff --git a/src/embeddings/index.rs b/src/embeddings/index.rs new file mode 100644 index 0000000..83ec7c8 --- /dev/null +++ b/src/embeddings/index.rs @@ -0,0 +1,171 @@ +//! usearch HNSW ANN wrapper with file persistence. +//! +//! Cosine similarity, f32 quantization, auto connectivity/expansion. Keys +//! are u64 — we use the deterministic SHA-256-derived key from +//! `text_blob::usearch_key_for` so the same `qualified_name` always maps to +//! the same usearch key across rebuilds. + +use std::path::Path; +use usearch::{Index, IndexOptions, MetricKind, ScalarKind, new_index}; + +pub struct AnnIndex { + inner: Index, + dim: usize, +} + +#[derive(Debug, Clone)] +pub struct AnnSearchResult { + pub key: u64, + /// Cosine distance in [-1, 1]. Lower is more similar for L2; for Cos, + /// usearch returns similarity directly (higher is better) — semantics + /// depend on the underlying library version. We expose the raw value + /// and let callers decide. + pub distance: f32, +} + +impl AnnIndex { + /// Create an empty index. Memory-only until `save` is called. + pub fn new(dim: usize) -> Result> { + let opts = IndexOptions { + dimensions: dim, + metric: MetricKind::Cos, + quantization: ScalarKind::F32, + connectivity: 0, // auto + expansion_add: 0, + expansion_search: 0, + multi: false, + }; + let inner = new_index(&opts)?; + Ok(Self { inner, dim }) + } + + /// Load an existing index from disk. The dimension is read from the + /// file's metadata via `inner.dimensions()` after load. + pub fn load(path: &Path) -> Result> { + let path_str = path.to_string_lossy().to_string(); + // We need dimensions to construct an index, but the index file + // already encodes them. Trick: create a 1-d placeholder, then load + // (which overrides), then read back the real dim. + let placeholder = Self::new(1)?; + placeholder.inner.load(&path_str)?; + let dim = placeholder.inner.dimensions() as usize; + Ok(Self { + inner: placeholder.inner, + dim, + }) + } + + pub fn save(&self, path: &Path) -> Result<(), Box> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let path_str = path.to_string_lossy().to_string(); + self.inner.save(&path_str)?; + Ok(()) + } + + pub fn add(&self, key: u64, vector: &[f32]) -> Result<(), Box> { + if vector.len() != self.dim { + return Err(format!( + "vector dim mismatch: expected {}, got {}", + self.dim, + vector.len() + ) + .into()); + } + self.inner.add(key, vector)?; + Ok(()) + } + + /// Remove a vector by key. Best-effort: silently succeeds if the key + /// isn't present. Used by the embed step to reap orphans. + pub fn remove(&self, key: u64) -> Result<(), Box> { + self.inner.remove(key)?; + Ok(()) + } + + pub fn search( + &self, + query: &[f32], + k: usize, + ) -> Result, Box> { + if query.len() != self.dim { + return Err(format!( + "query dim mismatch: expected {}, got {}", + self.dim, + query.len() + ) + .into()); + } + let matches = self.inner.search(query, k)?; + Ok(matches + .keys + .iter() + .zip(matches.distances.iter()) + .map(|(&k, &d)| AnnSearchResult { + key: k, + distance: d, + }) + .collect()) + } + + /// Hint capacity to avoid reallocations during bulk insert. Optional. + pub fn reserve(&self, capacity: usize) -> Result<(), Box> { + self.inner.reserve(capacity)?; + Ok(()) + } + + pub fn size(&self) -> usize { + self.inner.size() as usize + } + + pub fn dim(&self) -> usize { + self.dim + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_and_search_returns_nearest_first() { + let index = AnnIndex::new(3).unwrap(); + index.add(1, &[0.1, 0.2, 0.3]).unwrap(); + index.add(2, &[0.9, 0.8, 0.7]).unwrap(); + index.add(3, &[0.15, 0.25, 0.35]).unwrap(); + + let results = index.search(&[0.1, 0.2, 0.3], 2).unwrap(); + assert_eq!(results.len(), 2); + // Closest to [0.1, 0.2, 0.3] is key 1 (exact match), then key 3. + assert_eq!(results[0].key, 1); + assert_eq!(results[1].key, 3); + } + + #[test] + fn save_load_roundtrip_preserves_size() { + let tmp = tempfile::tempdir().unwrap(); + let path = tmp.path().join("test.usearch"); + + { + let index = AnnIndex::new(2).unwrap(); + index.add(10, &[1.0, 0.0]).unwrap(); + index.add(20, &[0.0, 1.0]).unwrap(); + index.save(&path).unwrap(); + } + + let loaded = AnnIndex::load(&path).unwrap(); + assert_eq!(loaded.dim(), 2); + assert_eq!(loaded.size(), 2); + + let results = loaded.search(&[1.0, 0.0], 1).unwrap(); + assert_eq!(results[0].key, 10); + } + + #[test] + fn dim_mismatch_is_an_error() { + let index = AnnIndex::new(3).unwrap(); + let err = index.add(1, &[0.0, 0.0]).unwrap_err(); + assert!(err.to_string().contains("dim mismatch")); + } +} diff --git a/src/embeddings/mod.rs b/src/embeddings/mod.rs new file mode 100644 index 0000000..256b4db --- /dev/null +++ b/src/embeddings/mod.rs @@ -0,0 +1,38 @@ +//! Embedding-based retrieval for LeanKG. +//! +//! Behind the `embeddings` cargo feature. Provides: +//! - Text-blob construction for code, ontology, and doc nodes +//! - fastembed-backed embedding inference (BGE-small-en-v1.5) and reranking +//! (bge-reranker-v2-m3) +//! - usearch HNSW ANN index with file persistence +//! - Incremental build via the `embedding_state` CozoDB table +//! - Lazy model download + `embed --init` pre-download +//! +//! See `docs/plans/2026-06-30-embedding-retrieve-rerank-traverse.md` for +//! the design rationale and decision history. + +#![cfg(feature = "embeddings")] + +pub mod build; +pub mod index; +pub mod models; +pub mod state; +pub mod text_blob; + +#[allow(unused_imports)] +pub use build::{run as build_index, BuildMode, BuildOptions, BuildReport}; +#[allow(unused_imports)] +pub use index::{AnnIndex, AnnSearchResult}; +#[allow(unused_imports)] +pub use models::{ + cache_dir, init_models, Embedder, InitReport, Reranker, RerankerStatus, RerankScore, + DEFAULT_EMBEDDING_MODEL, DEFAULT_RERANKER_MODEL, EMBEDDING_DIM, +}; +#[allow(unused_imports)] +pub use state::{ + count_by_state, delete_state_rows, ensure_embedding_state_table, list_all, list_orphans, + list_stale, lookup_usearch_key, mark_stale_for_qualified_names, upsert_fresh, + EmbeddingStateRow, FreshRow, StateCounts, +}; +#[allow(unused_imports)] +pub use text_blob::{build_blob, classify, usearch_key_for, BlobKind}; diff --git a/src/embeddings/models.rs b/src/embeddings/models.rs new file mode 100644 index 0000000..9aba21d --- /dev/null +++ b/src/embeddings/models.rs @@ -0,0 +1,169 @@ +//! fastembed wrappers: embedding inference + cross-encoder reranking, plus +//! model pre-download (`embed --init`) and lazy-download cache configuration. +//! +//! Both the embedder (BGE-small-en-v1.5, 384-dim) and the reranker +//! (bge-reranker-v2-m3) are loaded via fastembed, which handles ONNX +//! Runtime initialization and model caching internally. We set the cache +//! directory to a LeanKG-specific location so models don't collide with +//! other fastembed users. + +use fastembed::{ + EmbeddingModel, InitOptions, RerankInitOptions, RerankerModel, TextEmbedding, TextRerank, +}; +use std::path::PathBuf; + +/// Where fastembed will store downloaded ONNX weights. Linux: +/// `~/.cache/leankg/models`; macOS: `~/Library/Caches/leankg/models`; +/// Windows: `%LOCALAPPDATA%\leankg\models`. Falls back to +/// `./.leankg-cache/models` if no home directory is resolvable. +pub fn cache_dir() -> PathBuf { + dirs::cache_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".leankg-cache")) + .join("leankg") + .join("models") +} + +/// Default embedding model. 384-dim, ~130MB ONNX, fast on CPU. +pub const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15; + +/// Default reranker model. Multilingual, ~600MB ONNX. +pub const DEFAULT_RERANKER_MODEL: RerankerModel = RerankerModel::BGERerankerV2M3; + +/// Embedding dimension for the default embedding model. Used to size the +/// usearch index without having to load the model first. +pub const EMBEDDING_DIM: usize = 384; + +/// Wraps a fastembed `TextEmbedding`. Cheap to clone post-construction; +/// construction is expensive (model load, ~1s after first cache). +pub struct Embedder { + inner: TextEmbedding, +} + +impl Embedder { + /// Load the default embedding model. Triggers lazy-download on first + /// call per machine. Subsequent calls hit the on-disk cache. + pub fn new() -> Result> { + Self::with_model(DEFAULT_EMBEDDING_MODEL) + } + + pub fn with_model(model: EmbeddingModel) -> Result> { + let opts = InitOptions::new(model) + .with_cache_dir(cache_dir()) + .with_show_download_progress(true) + // Pin to a single intra-op thread. ONNX Runtime pre-allocates + // memory pools per thread; on small hosts (e.g. 1-vCPU ARM + // instances) the default of "all cores" explodes RSS and OOMs + // the container. Single-threaded inference is also faster on + // 1-CPU hosts because it avoids cross-thread contention. + .with_intra_threads(Some(1)); + let inner = TextEmbedding::try_new(opts)?; + Ok(Self { inner }) + } + + /// Embed a batch of texts. Returns one 384-dim vector per input text, + /// in the same order. Batch size is fastembed's default (256). + pub fn embed(&self, texts: &[String]) -> Result>, Box> { + let borrowed: Vec<&str> = texts.iter().map(|s| s.as_str()).collect(); + let vectors = self.inner.embed(borrowed, None)?; + Ok(vectors) + } + + pub fn dim(&self) -> usize { + EMBEDDING_DIM + } +} + +/// Wraps a fastembed `TextRerank` cross-encoder. +pub struct Reranker { + inner: TextRerank, +} + +impl Reranker { + pub fn new() -> Result> { + Self::with_model(DEFAULT_RERANKER_MODEL) + } + + pub fn with_model(model: RerankerModel) -> Result> { + let opts = RerankInitOptions::new(model) + .with_cache_dir(cache_dir()) + .with_show_download_progress(true) + .with_intra_threads(Some(1)); + let inner = TextRerank::try_new(opts)?; + Ok(Self { inner }) + } + + /// Score `(query, document)` pairs and return indices sorted by + /// descending score. `documents` is consumed; the returned indices + /// reference the original input positions. + pub fn rerank( + &self, + query: &str, + documents: Vec, + ) -> Result, Box> { + let borrowed: Vec<&str> = documents.iter().map(|s| s.as_str()).collect(); + let results = self.inner.rerank(query, borrowed, false, None)?; + Ok(results + .into_iter() + .map(|r| RerankScore { + document_idx: r.index, + score: r.score, + }) + .collect()) + } +} + +#[derive(Debug, Clone)] +pub struct RerankScore { + pub document_idx: usize, + pub score: f32, +} + +/// Operational status of the reranker. Used by the retrieval pipeline to +/// decide whether to skip Stage 3 (Q4 option A: ANN-only fallback). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RerankerStatus { + /// Cross-encoder is loaded and being applied. + Active, + /// Reranker failed to initialize; pipeline is returning ANN-order top-N. + Fallback, +} + +/// Pre-download both models into the cache so subsequent `embed` and +/// `kg_semantic_context` calls don't pay the download cost. Implements +/// `cargo run --release -- embed --init`. +pub fn init_models() -> Result> { + tracing::info!("initializing embedding + reranker models at {}", cache_dir().display()); + let _embedder = Embedder::new()?; + let _reranker = Reranker::new()?; + Ok(InitReport { + cache_dir: cache_dir(), + }) +} + +#[derive(Debug, Clone)] +pub struct InitReport { + pub cache_dir: PathBuf, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cache_dir_ends_with_leankg_models() { + let dir = cache_dir(); + let components: Vec<_> = dir.components().collect(); + let last_two: Vec = components + .into_iter() + .rev() + .take(2) + .map(|c| c.as_os_str().to_string_lossy().to_string()) + .collect(); + assert_eq!(last_two, vec!["models".to_string(), "leankg".to_string()]); + } + + #[test] + fn embedding_dim_matches_bge_small() { + assert_eq!(EMBEDDING_DIM, 384); + } +} diff --git a/src/embeddings/state.rs b/src/embeddings/state.rs new file mode 100644 index 0000000..2fe2eae --- /dev/null +++ b/src/embeddings/state.rs @@ -0,0 +1,289 @@ +//! `embedding_state` CozoDB table and helpers. +//! +//! Tracks per-CodeElement embedding freshness so `embed` runs incrementally. +//! See plan §"Incremental embedding & staleness". +//! +//! Lifecycle: +//! 1. `index` upserts CodeElements, then calls `mark_stale_for_qualified_names` +//! on every touched qualified_name. Existing rows flip to `state="stale"`; +//! new rows get a placeholder (`content_hash=""`, `state="stale"`). +//! 2. `embed` queries for rows where `state != "fresh"` OR `content_hash` no +//! longer matches the current blob, embeds them, and calls `upsert_fresh`. +//! 3. `embed` also reaps orphans: state rows whose qualified_name is no longer +//! in `code_elements`. Their usearch vectors are removed and the state row +//! is deleted. + +use crate::db::schema::CozoDb; + +const CREATE_EMBEDDING_STATE: &str = + r#":create embedding_state {qualified_name: String, usearch_key: Int, content_hash: String, state: String, embedded_at: String}"#; + +const CREATE_QN_INDEX: &str = + r#":create embedding_state::qn_index {ref: (qualified_name), compressed: true, unique: true}"#; + +const CREATE_KEY_INDEX: &str = + r#":create embedding_state::usearch_key_index {ref: (usearch_key), compressed: true, unique: true}"#; + +const CREATE_STATE_INDEX: &str = + r#":create embedding_state::state_index {ref: (state), compressed: true}"#; + +#[derive(Debug, Clone)] +pub struct EmbeddingStateRow { + pub qualified_name: String, + /// Stored in CozoDB as i64; cast to u64 when feeding usearch. Bit pattern + /// is preserved across the cast. + pub usearch_key: i64, + pub content_hash: String, + pub state: String, + pub embedded_at: String, +} + +/// Idempotently create the `embedding_state` table. Called from `init_schema` +/// on every DB open, so it must be cheap when the table already exists. +pub fn ensure_embedding_state_table(db: &CozoDb) -> Result<(), Box> { + let existing: std::collections::HashSet = db + .run_script("::relations", Default::default()) + .map(|r| { + r.rows + .iter() + .filter_map(|row| row.first().and_then(|v| v.as_str().map(String::from))) + .collect() + }) + .unwrap_or_default(); + + if !existing.contains("embedding_state") { + db.run_script(CREATE_EMBEDDING_STATE, Default::default())?; + for idx in &[CREATE_QN_INDEX, CREATE_KEY_INDEX, CREATE_STATE_INDEX] { + if let Err(e) = db.run_script(idx, Default::default()) { + tracing::debug!("embedding_state index note: {:?}", e); + } + } + tracing::info!("created embedding_state table"); + } + Ok(()) +} + +/// Mark a batch of qualified_names as stale. Idempotent: rows that already +/// exist flip to `state="stale"`; rows that don't exist are inserted with a +/// placeholder (`content_hash=""`) so the next `embed` run picks them up. +/// +/// `usearch_key` is computed deterministically from each qualified_name and +/// stored even on first insert, so the embed step can lookup the key without +/// recomputing. +pub fn mark_stale_for_qualified_names( + db: &CozoDb, + qualified_names: &[String], +) -> Result<(), Box> { + if qualified_names.is_empty() { + return Ok(()); + } + let now = now_iso(); + let rows: Vec = qualified_names + .iter() + .map(|qn| { + let key_i64 = crate::embeddings::text_blob::usearch_key_for(qn) as i64; + format!( + "[{}, {}, {}, {}, {}]", + serde_json::Value::String(qn.clone()), + serde_json::Value::Number(key_i64.into()), + serde_json::Value::String("".to_string()), + serde_json::Value::String("stale".to_string()), + serde_json::Value::String(now.clone()), + ) + }) + .collect(); + let values_clause = rows.join(", "); + + let query = format!( + r#"?[qualified_name, usearch_key, content_hash, state, embedded_at] <- [{values_clause}] + :put embedding_state {{qualified_name, usearch_key, content_hash, state, embedded_at}}"# + ); + db.run_script(&query, Default::default())?; + Ok(()) +} + +/// Return every row whose `state != "fresh"`. Includes newly-inserted +/// placeholders (state="stale", content_hash="") and existing rows that were +/// re-touched by the indexer. +pub fn list_stale(db: &CozoDb) -> Result, Box> { + let query = + r#"?[qualified_name, usearch_key, content_hash, state, embedded_at] := *embedding_state[qualified_name, usearch_key, content_hash, state, embedded_at], state != "fresh""#; + let result = db.run_script(query, Default::default())?; + Ok(result + .rows + .iter() + .filter_map(row_to_state_row) + .collect()) +} + +/// Return every state row whose qualified_name no longer exists in +/// `code_elements`. The embed step reaps these (removes the vector from +/// usearch and deletes the state row). +pub fn list_orphans(db: &CozoDb) -> Result, Box> { + let query = r#" + ?[qualified_name, usearch_key, content_hash, state, embedded_at] := + *embedding_state[qualified_name, usearch_key, content_hash, state, embedded_at], + not *code_elements[qualified_name, _, _, _, _, _, _, _, _, _, _, _, _] + "#; + let result = db.run_script(query, Default::default())?; + Ok(result + .rows + .iter() + .filter_map(row_to_state_row) + .collect()) +} + +/// Return all state rows. Used by `embed --full` to re-embed every existing +/// vector. +pub fn list_all(db: &CozoDb) -> Result, Box> { + let query = + r#"?[qualified_name, usearch_key, content_hash, state, embedded_at] := *embedding_state[qualified_name, usearch_key, content_hash, state, embedded_at]"#; + let result = db.run_script(query, Default::default())?; + Ok(result + .rows + .iter() + .filter_map(row_to_state_row) + .collect()) +} + +/// Lookup the usearch key for a single qualified_name. Returns None if the +/// row is missing (e.g., the element was never indexed). +pub fn lookup_usearch_key( + db: &CozoDb, + qualified_name: &str, +) -> Result, Box> { + let query = + r#"?[usearch_key] := *embedding_state[qualified_name, usearch_key, _, _, _], qualified_name = $qn"#; + let mut params = std::collections::BTreeMap::new(); + params.insert( + "qn".to_string(), + serde_json::Value::String(qualified_name.to_string()), + ); + let result = db.run_script(query, params)?; + Ok(result + .rows + .first() + .and_then(|row| row.first()) + .and_then(|v| v.as_i64()) + .map(|i| i as u64)) +} + +/// Batch upsert: mark rows fresh and stamp their content_hash + embedded_at. +/// Called by the embed step after vectors land in usearch. +pub fn upsert_fresh( + db: &CozoDb, + updates: &[FreshRow], +) -> Result<(), Box> { + if updates.is_empty() { + return Ok(()); + } + let now = now_iso(); + let rows: Vec = updates + .iter() + .map(|u| { + let key_i64 = u.usearch_key as i64; + format!( + "[{}, {}, {}, {}, {}]", + serde_json::Value::String(u.qualified_name.clone()), + serde_json::Value::Number(key_i64.into()), + serde_json::Value::String(u.content_hash.clone()), + serde_json::Value::String("fresh".to_string()), + serde_json::Value::String(now.clone()), + ) + }) + .collect(); + let values_clause = rows.join(", "); + let query = format!( + r#"?[qualified_name, usearch_key, content_hash, state, embedded_at] <- [{values_clause}] + :put embedding_state {{qualified_name, usearch_key, content_hash, state, embedded_at}}"# + ); + db.run_script(&query, Default::default())?; + Ok(()) +} + +/// Delete state rows for a set of qualified_names. Called after the embed +/// step removes orphan vectors from usearch. +pub fn delete_state_rows( + db: &CozoDb, + qualified_names: &[String], +) -> Result<(), Box> { + if qualified_names.is_empty() { + return Ok(()); + } + let rows: Vec = qualified_names + .iter() + .map(|qn| format!("[{}]", serde_json::Value::String(qn.clone()))) + .collect(); + let rows_clause = rows.join(", "); + let query = format!( + r#"?[qualified_name] <- [{rows_clause}] :delete embedding_state {{qualified_name}}"# + ); + db.run_script(&query, Default::default())?; + Ok(()) +} + +/// Count of fresh vs stale rows, for diagnostics. +pub fn count_by_state(db: &CozoDb) -> Result> { + let query = r#" + ?[state, n] := + *embedding_state[_, _, _, state, _], + n = count(state) + "#; + let result = db.run_script(query, Default::default())?; + let mut counts = StateCounts::default(); + for row in &result.rows { + let state = row.first().and_then(|v| v.as_str()).unwrap_or(""); + let n = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0) as usize; + match state { + "fresh" => counts.fresh = n, + "stale" => counts.stale = n, + _ => counts.other += n, + } + } + Ok(counts) +} + +#[derive(Debug, Clone, Default)] +pub struct StateCounts { + pub fresh: usize, + pub stale: usize, + pub other: usize, +} + +#[derive(Debug, Clone)] +pub struct FreshRow { + pub qualified_name: String, + pub usearch_key: u64, + pub content_hash: String, +} + +fn row_to_state_row(row: &Vec) -> Option { + let qualified_name = row.first()?.as_str()?.to_string(); + let usearch_key = row.get(1)?.as_i64()?; + let content_hash = row.get(2)?.as_str()?.to_string(); + let state = row.get(3)?.as_str()?.to_string(); + let embedded_at = row.get(4)?.as_str()?.to_string(); + Some(EmbeddingStateRow { + qualified_name, + usearch_key, + content_hash, + state, + embedded_at, + }) +} + +fn now_iso() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + let secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + format!("{}", secs) +} + +#[cfg(test)] +mod tests { + // Integration tests live in /tests; these are unit-level guards for the + // SQL builders. The state helpers themselves require a live CozoDB and + // are exercised by tests/embeddings_state_e2e.rs (added in Phase 6). +} diff --git a/src/embeddings/text_blob.rs b/src/embeddings/text_blob.rs new file mode 100644 index 0000000..e4c50b0 --- /dev/null +++ b/src/embeddings/text_blob.rs @@ -0,0 +1,257 @@ +//! Text-blob construction for embedding. +//! +//! Each CodeElement is converted to a short text blob suitable for embedding +//! with a sentence transformer (BGE-small-en-v1.5, 384-dim, 512-token max). +//! Blobs are deliberately compact: name + qualified_name + doc/signature for +//! code nodes; name + aliases + description for ontology nodes. Source bodies +//! are intentionally excluded — see plan §"What gets embedded". + +use crate::db::models::CodeElement; +use sha2::{Digest, Sha256}; + +/// Maximum text-blob length in characters before truncation. The embedding +/// model's hard limit is 512 BPE tokens; ~1500 ASCII characters is a safe +/// approximation that leaves headroom for tokenization expansion. +pub const MAX_BLOB_CHARS: usize = 1500; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BlobKind { + Code, + Ontology, + Doc, + Skip, +} + +/// Classify a CodeElement into a blob-construction strategy. Returns `Skip` +/// for element types that should not be embedded (e.g. clusters, processes +/// that duplicate the code they group). +pub fn classify(element_type: &str) -> BlobKind { + match element_type { + "file" | "function" | "class" | "module" | "method" | "trait" | "interface" => BlobKind::Code, + "domain_entity" + | "service" + | "api_endpoint" + | "data_store" + | "environment" + | "known_issue" + | "playbook" + | "playbook_step" + | "team_knowledge" + | "workflow" + | "workflow_step" + | "decision_point" + | "failure_mode" => BlobKind::Ontology, + // Skip clusters/processes/etc.: they're grouping abstractions whose + // members already get embedded individually. + _ => BlobKind::Skip, + } +} + +/// Build the text blob for a CodeElement. Returns `None` if the element type +/// is in the Skip category or if the resulting blob is empty. +pub fn build_blob(element: &CodeElement) -> Option { + let kind = classify(&element.element_type); + let raw = match kind { + BlobKind::Code => build_code_blob(element), + BlobKind::Ontology => build_ontology_blob(element), + BlobKind::Doc => build_doc_blob(element), + BlobKind::Skip => return None, + }; + let trimmed = raw.trim(); + if trimmed.is_empty() { + None + } else { + Some(truncate_to_chars(trimmed, MAX_BLOB_CHARS).to_string()) + } +} + +fn build_code_blob(element: &CodeElement) -> String { + let mut parts: Vec = Vec::with_capacity(4); + parts.push(element.qualified_name.clone()); + if !element.name.is_empty() && element.name != element.qualified_name { + parts.push(element.name.clone()); + } + if let Some(doc) = extract_doc_signature(&element.metadata) { + parts.push(doc); + } else { + // Fallback: file path + language as a weak signature stand-in. + if !element.file_path.is_empty() { + parts.push(element.file_path.clone()); + } + } + parts.join("\n") +} + +fn build_ontology_blob(element: &CodeElement) -> String { + let mut parts: Vec = Vec::with_capacity(4); + parts.push(element.name.clone()); + if let Some(aliases) = element.metadata.get("aliases").and_then(|v| v.as_array()) { + let alias_str: Vec = aliases + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + if !alias_str.is_empty() { + parts.push(alias_str.join(", ")); + } + } + if let Some(desc) = element.metadata.get("description").and_then(|v| v.as_str()) { + if !desc.is_empty() { + parts.push(desc.to_string()); + } + } + parts.push(element.element_type.clone()); + parts.join("\n") +} + +fn build_doc_blob(element: &CodeElement) -> String { + let mut parts: Vec = Vec::with_capacity(3); + if let Some(title) = element.metadata.get("title").and_then(|v| v.as_str()) { + parts.push(title.to_string()); + } + if let Some(heading) = element + .metadata + .get("heading_path") + .and_then(|v| v.as_array()) + { + let heading_str: Vec = heading + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + if !heading_str.is_empty() { + parts.push(heading_str.join(" / ")); + } + } + if let Some(body) = element.metadata.get("first_paragraph").and_then(|v| v.as_str()) { + parts.push(body.to_string()); + } + parts.join("\n") +} + +/// Pull a doc comment / signature out of the CodeElement metadata, if the +/// indexer stored one. Different extractor paths use different keys; we +/// accept any of the known ones. +fn extract_doc_signature(metadata: &serde_json::Value) -> Option { + for key in &["doc_comment", "doc", "signature", "signature_text"] { + if let Some(s) = metadata.get(key).and_then(|v| v.as_str()) { + if !s.trim().is_empty() { + return Some(s.to_string()); + } + } + } + None +} + +fn truncate_to_chars(s: &str, max_chars: usize) -> &str { + if s.len() <= max_chars { + return s; + } + let mut end = max_chars; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + &s[..end] +} + +/// Deterministic u64 key derived from `qualified_name`. Used as the usearch +/// HNSW key — stable across re-indexes as long as qualified_name is unchanged. +/// We store the same value in `embedding_state.usearch_key` so reverse lookup +/// (search result → qualified_name) is a single equality predicate. +/// +/// We use the first 8 bytes of SHA-256 (little-endian) rather than a faster +/// non-crypto hash: collisions across a 10k-node repo are ~10^-11, and we get +/// portability across architectures for free. +pub fn usearch_key_for(qualified_name: &str) -> u64 { + let mut hasher = Sha256::new(); + hasher.update(qualified_name.as_bytes()); + let digest = hasher.finalize(); + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&digest[..8]); + u64::from_le_bytes(bytes) +} + +/// SHA-256 hex digest of the text blob. Stored in `embedding_state.content_hash` +/// to detect content changes between embed runs. +pub fn content_hash_for(blob: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(blob.as_bytes()); + format!("{:x}", hasher.finalize()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_element(element_type: &str, name: &str, qualified_name: &str) -> CodeElement { + CodeElement { + element_type: element_type.to_string(), + name: name.to_string(), + qualified_name: qualified_name.to_string(), + ..Default::default() + } + } + + #[test] + fn classify_known_types() { + assert_eq!(classify("function"), BlobKind::Code); + assert_eq!(classify("class"), BlobKind::Code); + assert_eq!(classify("workflow"), BlobKind::Ontology); + assert_eq!(classify("domain_entity"), BlobKind::Ontology); + assert_eq!(classify("cluster"), BlobKind::Skip); + } + + #[test] + fn code_blob_uses_qualified_name_and_doc() { + let mut el = make_element("function", "do_thing", "src/main.rs::do_thing"); + el.metadata = serde_json::json!({"doc_comment": "Does the thing."}); + let blob = build_blob(&el).unwrap(); + assert!(blob.contains("src/main.rs::do_thing")); + assert!(blob.contains("do_thing")); + assert!(blob.contains("Does the thing.")); + } + + #[test] + fn ontology_blob_includes_aliases_and_description() { + let mut el = make_element( + "domain_entity", + "Refund", + "ontology://local:checkout:domain_entity:refund:v1", + ); + el.metadata = serde_json::json!({ + "aliases": ["reversal", "chargeback"], + "description": "Money returned to a customer after payment capture" + }); + let blob = build_blob(&el).unwrap(); + assert!(blob.contains("Refund")); + assert!(blob.contains("reversal")); + assert!(blob.contains("chargeback")); + assert!(blob.contains("Money returned")); + assert!(blob.contains("domain_entity")); + } + + #[test] + fn skip_element_types_return_none() { + let el = make_element("cluster", "cluster1", "cluster://x"); + assert!(build_blob(&el).is_none()); + } + + #[test] + fn usearch_key_is_deterministic() { + let k1 = usearch_key_for("src/main.rs::main"); + let k2 = usearch_key_for("src/main.rs::main"); + assert_eq!(k1, k2); + } + + #[test] + fn usearch_key_differs_for_different_names() { + let k1 = usearch_key_for("src/main.rs::main"); + let k2 = usearch_key_for("src/main.rs::helper"); + assert_ne!(k1, k2); + } + + #[test] + fn truncation_respects_char_boundaries() { + let s = "a".repeat(2000); + let truncated = truncate_to_chars(&s, MAX_BLOB_CHARS); + assert_eq!(truncated.len(), MAX_BLOB_CHARS); + } +} diff --git a/src/graph/traversal.rs b/src/graph/traversal.rs index 9df31d5..6fec1d6 100644 --- a/src/graph/traversal.rs +++ b/src/graph/traversal.rs @@ -115,3 +115,282 @@ pub struct ImpactResult { pub affected_elements: Vec, pub affected_with_confidence: Vec, } + +// =========================================================================== +// Semantic-retrieval traversal (Stage 4) +// +// Adaptive N-hop BFS from each seed node. Hops + allowed edge types + fanout +// cap depend on the seed's element_type — see `traverse_rule_for`. The +// function is feature-independent: callers pass plain `(qualified_name, +// element_type)` tuples so this module doesn't depend on the gated retrieval +// types. The retrieval pipeline + MCP handler adapt their seed list to this +// shape. +// +// See docs/plans/2026-06-30-embedding-retrieve-rerank-traverse.md §"Adaptive +// Traversal Rules". +// ==================================================================================== + +/// Hard ceiling on total traversed neighbors across all seeds, regardless of +/// per-seed fanout. Keeps MCP response size bounded for agents. +const GLOBAL_NEIGHBOR_CAP: usize = 60; + +pub struct TraverseRule { + pub hops: u32, + pub edge_types: &'static [&'static str], + pub fanout_cap: usize, +} + +const WORKFLOW_EDGES: &[&str] = &[ + "has_step", + "next_step", + "branches_to", + "implemented_by", + "entry_point_of", + "step_in_process", + "has_failure_mode", +]; + +const STEP_EDGES: &[&str] = &[ + "next_step", + "branches_to", + "implemented_by", + "handled_by_playbook", + "has_failure_mode", + "resolved_by_playbook", +]; + +const CONCEPT_EDGES: &[&str] = &[ + "owns_concept", + "implements_concept", + "exposes_endpoint", + "reads_from", + "writes_to", + "documents_concept", + "has_known_issue", +]; + +const ISSUE_EDGES: &[&str] = &[ + "has_known_issue", + "resolved_by_playbook", + "documents_concept", +]; + +const CODE_EDGES: &[&str] = &[ + "calls", + "imports", + "references", + "tested_by", + "documented_by", + "implements_concept", +]; + +const FILE_EDGES: &[&str] = &[ + "imports", + "references", + "tested_by", + "documented_by", + "contains", + "defines", +]; + +const DOC_EDGES: &[&str] = &["documented_by", "documents_concept"]; + +pub fn traverse_rule_for(element_type: &str) -> TraverseRule { + match element_type { + "workflow" => TraverseRule { + hops: 2, + edge_types: WORKFLOW_EDGES, + fanout_cap: 20, + }, + "workflow_step" | "decision_point" | "failure_mode" => TraverseRule { + hops: 2, + edge_types: STEP_EDGES, + fanout_cap: 15, + }, + "domain_entity" | "service" | "api_endpoint" | "data_store" => TraverseRule { + hops: 1, + edge_types: CONCEPT_EDGES, + fanout_cap: 15, + }, + "known_issue" | "playbook" | "team_knowledge" => TraverseRule { + hops: 1, + edge_types: ISSUE_EDGES, + fanout_cap: 10, + }, + "function" | "class" => TraverseRule { + hops: 1, + edge_types: CODE_EDGES, + fanout_cap: 10, + }, + "file" | "module" => TraverseRule { + hops: 1, + edge_types: FILE_EDGES, + fanout_cap: 10, + }, + _ => TraverseRule { + hops: 1, + edge_types: DOC_EDGES, + fanout_cap: 5, + }, + } +} + +#[derive(Debug, Clone)] +pub struct TraversedNode { + pub qualified_name: String, + pub element_type: String, + pub from_seed: String, + pub via_edge: String, + pub hop: u32, +} + +#[derive(Debug, Clone)] +pub struct TraverseResult { + pub nodes: Vec, + pub edges: Vec, + pub total_neighbors: usize, + pub capped: bool, +} + +#[derive(Debug, Clone)] +pub struct TraversedEdge { + pub source: String, + pub target: String, + pub rel_type: String, +} + +/// Adaptive multi-hop BFS from a set of seed nodes. Returns traversed +/// neighbors (excluding the seeds themselves) plus the edges that connect +/// them. Honors per-seed-type fanout caps and the global +/// `GLOBAL_NEIGHBOR_CAP`. +pub fn traverse_seeds( + graph: &GraphEngine, + seeds: I, + env: Option<&str>, +) -> Result> +where + I: IntoIterator, +{ + use std::collections::{HashSet, VecDeque}; + + let mut visited: HashSet = HashSet::new(); + let mut nodes: Vec = Vec::new(); + let mut edges: Vec = Vec::new(); + let mut total = 0usize; + + for (seed_qn, seed_type) in seeds { + // Seed itself is always "visited" — we don't return it in the + // traversed set even if a cycle would bring us back to it. + visited.insert(seed_qn.clone()); + + let rule = traverse_rule_for(&seed_type); + let mut frontier: VecDeque<(String, u32, String, String)> = VecDeque::new(); + // (current_qn, current_hop, from_seed_qn, via_edge_into_current) + frontier.push_back((seed_qn.clone(), 0, seed_qn.clone(), "seed".to_string())); + + let mut seed_count = 0usize; + while let Some((current, hop, from, _via)) = frontier.pop_front() { + if hop >= rule.hops { + continue; + } + if seed_count >= rule.fanout_cap || total >= GLOBAL_NEIGHBOR_CAP { + break; + } + + let outgoing = graph.get_relationships(¤t).unwrap_or_default(); + let incoming = graph.get_relationships_for_target(¤t).unwrap_or_default(); + + for rel in outgoing.iter().chain(incoming.iter()) { + if !rule.edge_types.contains(&rel.rel_type.as_str()) { + continue; + } + if let Some(wanted) = env { + if rel.env != wanted { + continue; + } + } + + let neighbor = if rel.source_qualified == current { + rel.target_qualified.clone() + } else { + rel.source_qualified.clone() + }; + + if visited.contains(&neighbor) { + continue; + } + visited.insert(neighbor.clone()); + + edges.push(TraversedEdge { + source: rel.source_qualified.clone(), + target: rel.target_qualified.clone(), + rel_type: rel.rel_type.clone(), + }); + + let element_type = graph + .find_element(&neighbor) + .ok() + .flatten() + .map(|e| e.element_type) + .unwrap_or_else(|| "unknown".to_string()); + + nodes.push(TraversedNode { + qualified_name: neighbor.clone(), + element_type: element_type.clone(), + from_seed: from.clone(), + via_edge: rel.rel_type.clone(), + hop: hop + 1, + }); + + frontier.push_back(( + neighbor, + hop + 1, + from.clone(), + rel.rel_type.clone(), + )); + + seed_count += 1; + total += 1; + if total >= GLOBAL_NEIGHBOR_CAP || seed_count >= rule.fanout_cap { + break; + } + } + } + } + + Ok(TraverseResult { + nodes, + edges, + total_neighbors: total, + capped: total >= GLOBAL_NEIGHBOR_CAP, + }) +} + +#[cfg(test)] +mod traverse_tests { + use super::*; + + #[test] + fn rule_for_workflow_is_two_hops() { + let r = traverse_rule_for("workflow"); + assert_eq!(r.hops, 2); + assert_eq!(r.fanout_cap, 20); + assert!(r.edge_types.contains(&"has_step")); + } + + #[test] + fn rule_for_function_is_one_hop() { + let r = traverse_rule_for("function"); + assert_eq!(r.hops, 1); + assert_eq!(r.fanout_cap, 10); + assert!(r.edge_types.contains(&"calls")); + } + + #[test] + fn rule_for_unknown_type_falls_back_to_docs() { + let r = traverse_rule_for("some-random-type"); + assert_eq!(r.hops, 1); + assert_eq!(r.fanout_cap, 5); + assert!(r.edge_types.contains(&"documented_by")); + } +} diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index ae3cf00..de8d5a2 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -791,6 +791,22 @@ pub fn index_files_parallel( total_elements, total_elements ); } + + // Mark every touched element as embedding-stale so the next + // `embed` run picks them up incrementally. Only fires when the + // `embeddings` feature is compiled in; otherwise no-op. + #[cfg(feature = "embeddings")] + { + let touched: Vec = all_elements + .iter() + .map(|e| e.qualified_name.clone()) + .collect(); + if let Err(e) = + crate::embeddings::state::mark_stale_for_qualified_names(graph.db(), &touched) + { + tracing::warn!("embedding_state stale-mark failed: {}", e); + } + } } if !all_relationships.is_empty() { diff --git a/src/lib.rs b/src/lib.rs index e1f8abf..8f1ca8a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,8 @@ pub mod config; pub mod db; pub mod doc; pub mod doc_indexer; +#[cfg(feature = "embeddings")] +pub mod embeddings; pub mod graph; pub mod hooks; pub mod indexer; @@ -16,5 +18,7 @@ pub mod obsidian; pub mod ontology; pub mod orchestrator; pub mod registry; +#[cfg(feature = "embeddings")] +pub mod retrieval; pub mod runtime; pub mod watcher; diff --git a/src/main.rs b/src/main.rs index f77f14d..f7a5a6d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,8 @@ mod db; mod doc; mod doc_indexer; mod embed; +#[cfg(feature = "embeddings")] +mod embeddings; mod graph; mod indexer; mod mcp; @@ -15,6 +17,8 @@ mod obsidian; mod ontology; mod orchestrator; mod registry; +#[cfg(feature = "embeddings")] +mod retrieval; mod runtime; mod watcher; mod web; @@ -277,6 +281,37 @@ async fn main() -> Result<(), Box> { let db_path = project_path.join(".leankg"); find_oversized_functions(min_lines, lang.as_deref(), &db_path)?; } + #[cfg(feature = "embeddings")] + cli::CLICommand::Embed { + init, + full, + batch_size, + project, + } => { + run_embed(init, full, batch_size, &project)?; + } + #[cfg(feature = "embeddings")] + cli::CLICommand::SemanticContext { + query, + env, + top_k, + rerank_top_n, + no_traverse, + include_worktrees, + debug, + project, + } => { + run_semantic_context( + &query, + &env, + top_k, + rerank_top_n, + !no_traverse, + include_worktrees, + debug, + &project, + )?; + } cli::CLICommand::Export { output, format, @@ -3783,3 +3818,174 @@ fn handle_ontology_command( Ok(()) } + +#[cfg(feature = "embeddings")] +fn run_embed( + init: bool, + full: bool, + batch_size: usize, + project: &str, +) -> Result<(), Box> { + if init { + let report = embeddings::init_models()?; + println!("Models cached at: {}", report.cache_dir.display()); + println!(); + println!("Next steps:"); + println!(" cargo run --release -- index {project}"); + println!(" cargo run --release -- embed --project {project}"); + return Ok(()); + } + + let project_path = std::path::PathBuf::from(project); + let leankg_dir = project_path.join(".leankg"); + let db_path = leankg_dir.join("leankg.db"); + let index_path = leankg_dir.join("embeddings.usearch"); + + if !db_path.exists() { + return Err(format!( + "LeanKG database not found at {}. Run `cargo run --release -- index {}` first.", + db_path.display(), + project + ) + .into()); + } + + let db = db::schema::init_db(&db_path)?; + let graph = graph::GraphEngine::new(db); + + let mode = if full { + embeddings::BuildMode::Full + } else { + embeddings::BuildMode::Incremental + }; + let opts = embeddings::BuildOptions { + mode, + batch_size, + reserve_capacity: None, + }; + + let started = std::time::Instant::now(); + let report = embeddings::build_index(&graph, &index_path, &opts)?; + let elapsed = started.elapsed(); + + println!("Embed build complete ({:?}) in {:.2}s", mode, elapsed.as_secs_f64()); + println!(" Considered: {}", report.considered_count); + println!(" Embedded: {}", report.embedded_count); + println!(" Skipped fresh: {}", report.skipped_fresh_count); + println!(" Orphans reaped: {}", report.orphaned_count); + println!(" Index size: {} vectors", report.index_size); + println!(" Index path: {}", report.index_path.display()); + Ok(()) +} + +#[cfg(feature = "embeddings")] +#[allow(clippy::too_many_arguments)] +fn run_semantic_context( + query: &str, + env: &str, + top_k: usize, + rerank_top_n: usize, + traverse: bool, + include_worktrees: bool, + debug: bool, + project: &str, +) -> Result<(), Box> { + use embeddings::RerankerStatus; + + let project_path = std::path::PathBuf::from(project); + let leankg_dir = project_path.join(".leankg"); + let db_path = leankg_dir.join("leankg.db"); + let index_path = leankg_dir.join("embeddings.usearch"); + + if !index_path.exists() { + return Err(format!( + "Embedding index not found at {}. Run `cargo run --release -- embed --init` \ + (to download models), then `cargo run --release -- embed` (to build the index).", + index_path.display() + ) + .into()); + } + + let db = db::schema::init_db(&db_path)?; + let graph = graph::GraphEngine::new(db.clone()); + + let pipeline = retrieval::SemanticRetrievalPipeline::new(db, &index_path)?; + let opts = retrieval::RetrieveOptions { + env: Some(env.to_string()), + ann_top_k: top_k, + rerank_top_n, + include_worktrees, + embeddings_stale: false, + }; + + let started = std::time::Instant::now(); + let retrieval = pipeline.retrieve(query, &opts)?; + let retrieve_ms = started.elapsed().as_millis(); + + println!("Query: {}", query); + println!( + "Reranker: {}", + match retrieval.reranker_status { + RerankerStatus::Active => "active (bge-reranker-v2-m3)", + RerankerStatus::Fallback => "FALLBACK (ANN-only)", + } + ); + println!(); + + println!("Seeds ({}):", retrieval.seeds.len()); + for (i, s) in retrieval.seeds.iter().enumerate() { + let score = s + .rerank_score + .map(|x| format!("rerank={:.4}", x)) + .unwrap_or_else(|| format!("ann={:.4}", s.ann_distance)); + println!( + " {:>2}. [{:<15}] {} ({})", + i + 1, + s.element_type, + s.qualified_name, + score + ); + if debug { + println!(" blob: {}", s.blob_excerpt); + } + } + + if traverse && !retrieval.seeds.is_empty() { + let t = std::time::Instant::now(); + let seeds_iter = retrieval + .seeds + .iter() + .map(|s| (s.qualified_name.clone(), s.element_type.clone())); + let result = graph::traversal::traverse_seeds(&graph, seeds_iter, Some(env))?; + let trav_ms = t.elapsed().as_millis(); + + println!(); + println!( + "Traversed ({} neighbors, {} edges{}) in {}ms:", + result.nodes.len(), + result.edges.len(), + if result.capped { ", CAPPED" } else { "" }, + trav_ms + ); + for n in &result.nodes { + println!( + " hop {} via {:<20} [{:<15}] {} (from {})", + n.hop, n.via_edge, n.element_type, n.qualified_name, n.from_seed + ); + } + } + + if debug { + println!(); + println!("Diagnostics:"); + println!(" ANN candidates: {}", retrieval.ann_candidate_count); + println!( + "Worktree-filtered: {}", + retrieval.worktree_filtered_count + ); + println!("Env-filtered: {}", retrieval.env_filtered_count); + println!("Retrieve latency: {}ms", retrieve_ms); + } + + Ok(()) +} diff --git a/src/mcp/handler.rs b/src/mcp/handler.rs index 9c31c8b..9355100 100644 --- a/src/mcp/handler.rs +++ b/src/mcp/handler.rs @@ -270,6 +270,8 @@ impl ToolHandler { "kg_trace_workflow" => self.kg_trace_workflow(arguments), "kg_ontology_status" => self.kg_ontology_status(arguments), "kg_self_test" => self.kg_self_test(arguments), + #[cfg(feature = "embeddings")] + "kg_semantic_context" => self.kg_semantic_context(arguments), _ => Err(format!("Unknown tool: {}", tool_name)), }; @@ -2809,6 +2811,169 @@ impl ToolHandler { })) } + /// Embedding-backed semantic retrieval + adaptive KG traversal. + /// Compiles out entirely unless the binary was built with + /// `--features embeddings`. + #[cfg(feature = "embeddings")] + fn kg_semantic_context(&self, args: &Value) -> Result { + use crate::embeddings as emb; + use crate::graph::traversal::traverse_seeds; + use crate::retrieval::{RetrieveOptions, SemanticRetrievalPipeline}; + + let query = args["query"] + .as_str() + .ok_or("Missing 'query' parameter")? + .trim(); + if query.is_empty() { + return Err("'query' must not be empty".to_string()); + } + + let env = args["env"].as_str().unwrap_or("local").to_string(); + let top_k = args["top_k"].as_u64().unwrap_or(50) as usize; + let rerank_top_n = args["rerank_top_n"].as_u64().unwrap_or(10) as usize; + let do_traverse = args["traverse"].as_bool().unwrap_or(true); + let include_worktrees = args["include_worktrees"].as_bool().unwrap_or(false); + let debug = args["debug"].as_bool().unwrap_or(false); + let project = args["project"].as_str().unwrap_or("."); + + let index_path = std::path::Path::new(project) + .join(".leankg") + .join("embeddings.usearch"); + if !index_path.exists() { + return Err(format!( + "Embedding index not found at {}. Run `cargo run --release -- embed --init` \ + to download models, then `cargo run --release -- embed` to build the index.", + index_path.display() + )); + } + + let t0 = std::time::Instant::now(); + let pipeline = + SemanticRetrievalPipeline::new(self.graph_engine.db().clone(), &index_path) + .map_err(|e| format!("Failed to init retrieval pipeline: {}", e))?; + let t_pipeline_ms = t0.elapsed().as_millis() as u64; + + let meta_path = index_path.with_extension("meta.json"); + let embeddings_stale = embeddings_are_stale(&meta_path); + + let opts = RetrieveOptions { + env: Some(env.clone()), + ann_top_k: top_k, + rerank_top_n, + include_worktrees, + embeddings_stale, + }; + + let t1 = std::time::Instant::now(); + let retrieval = pipeline + .retrieve(query, &opts) + .map_err(|e| format!("Retrieval failed: {}", e))?; + let t_retrieve_ms = t1.elapsed().as_millis() as u64; + + let mut traversed_json: Vec = Vec::new(); + let mut edges_json: Vec = Vec::new(); + let mut traverse_capped = false; + let mut total_neighbors = 0usize; + let mut t_traverse_ms = 0u64; + + if do_traverse && !retrieval.seeds.is_empty() { + let t2 = std::time::Instant::now(); + let seeds_iter = retrieval + .seeds + .iter() + .map(|s| (s.qualified_name.clone(), s.element_type.clone())); + let traverse_result = traverse_seeds(&self.graph_engine, seeds_iter, Some(env.as_str())) + .map_err(|e| format!("Traversal failed: {}", e))?; + t_traverse_ms = t2.elapsed().as_millis() as u64; + traverse_capped = traverse_result.capped; + total_neighbors = traverse_result.total_neighbors; + + traversed_json = traverse_result + .nodes + .iter() + .map(|n| { + json!({ + "qualified_name": n.qualified_name, + "element_type": n.element_type, + "from_seed": n.from_seed, + "via_edge": n.via_edge, + "hop": n.hop, + }) + }) + .collect(); + edges_json = traverse_result + .edges + .iter() + .map(|e| { + json!({ + "source": e.source, + "target": e.target, + "rel_type": e.rel_type, + }) + }) + .collect(); + } + + let total_ms = t0.elapsed().as_millis() as u64; + + let seeds_json: Vec = retrieval + .seeds + .iter() + .map(|s| { + let mut obj = json!({ + "qualified_name": s.qualified_name, + "element_type": s.element_type, + "file_path": s.file_path, + "ann_distance": s.ann_distance, + }); + if let Some(score) = s.rerank_score { + obj["rerank_score"] = json!(score); + } + if debug { + obj["blob_excerpt"] = json!(s.blob_excerpt); + } + obj + }) + .collect(); + + let mut response = json!({ + "query": query, + "env": env, + "seeds": seeds_json, + "traversed": traversed_json, + }); + + if debug { + let reranker_label = match retrieval.reranker_status { + emb::RerankerStatus::Active => "bge-reranker-v2-m3", + emb::RerankerStatus::Fallback => "fallback_ann", + }; + response["diagnostics"] = json!({ + "ann_candidate_count": retrieval.ann_candidate_count, + "worktree_filtered_count": retrieval.worktree_filtered_count, + "env_filtered_count": retrieval.env_filtered_count, + "reranker": reranker_label, + "embeddings_stale": retrieval.embeddings_stale, + "traversal": { + "enabled": do_traverse, + "capped": traverse_capped, + "neighbor_count": total_neighbors, + }, + "latency_ms": { + "pipeline_init": t_pipeline_ms, + "retrieve": t_retrieve_ms, + "traverse": t_traverse_ms, + "total": total_ms, + }, + }); + if !edges_json.is_empty() { + response["diagnostics"]["edges"] = json!(edges_json); + } + } + + Ok(response) + } + fn wake_up(&self, args: &Value) -> Result { let project_path = args["project"].as_str().unwrap_or("."); let cache_path = std::path::Path::new(project_path) @@ -2858,6 +3023,36 @@ fn truncate_str(s: &str, max_len: usize) -> String { } } +/// Q3 stale-embeddings detection: compare `embeddings.meta.json.built_at` +/// against the cozodb file's mtime. If the database was modified after the +/// embedding index was built, we're stale. Conservative: on any read or +/// parse error, returns false (assume fresh) so queries still serve. +#[cfg(feature = "embeddings")] +fn embeddings_are_stale(meta_path: &std::path::Path) -> bool { + let meta_bytes = match std::fs::read(meta_path) { + Ok(b) => b, + Err(_) => return false, + }; + let meta: serde_json::Value = match serde_json::from_slice(&meta_bytes) { + Ok(v) => v, + Err(_) => return false, + }; + let built_at = meta.get("built_at").and_then(|v| v.as_u64()).unwrap_or(0); + + let cozo_path = match meta_path.parent() { + Some(p) => p.join("leankg.db"), + None => return false, + }; + let cozo_mtime = std::fs::metadata(&cozo_path) + .and_then(|m| m.modified()) + .ok() + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| d.as_secs()) + .unwrap_or(0); + + cozo_mtime > built_at +} + fn generate_review_prompt(elements: &[CodeElement], _relationships: &[Relationship]) -> String { if elements.is_empty() { return "No elements found for review.".to_string(); diff --git a/src/mcp/tools.rs b/src/mcp/tools.rs index 211f056..49d9767 100644 --- a/src/mcp/tools.rs +++ b/src/mcp/tools.rs @@ -806,6 +806,25 @@ impl ToolRegistry { "required": [] }), }, + #[cfg(feature = "embeddings")] + ToolDefinition { + name: "kg_semantic_context".to_string(), + description: "Vector retrieval + cross-encoder rerank + adaptive KG traversal. Use for natural-language queries where keyword search misses: 'where do we validate access rights', 'how does the refund flow work'. Returns ranked seed nodes plus 1-2 hop graph context (related code, tests, docs, workflows). Requires the `embeddings` cargo feature and an embedding index built via `cargo run --release -- embed`.".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Natural language query"}, + "env": {"type": "string", "enum": ["local", "staging", "production"], "default": "local", "description": "Environment to search"}, + "top_k": {"type": "integer", "default": 50, "description": "ANN retrieve depth (candidates before rerank)"}, + "rerank_top_n": {"type": "integer", "default": 10, "description": "Final seed count after cross-encoder rerank"}, + "traverse": {"type": "boolean", "default": true, "description": "Toggle Stage 4 graph enrichment (1-2 hop neighbors via ontology + code edges)"}, + "include_worktrees": {"type": "boolean", "default": false, "description": "Include paths under .worktrees/ / .claude/worktrees/ / .opencode/worktrees/ (filtered by default to dedupe agent scratch copies)"}, + "debug": {"type": "boolean", "default": false, "description": "Include diagnostics: candidate counts, latency per stage, reranker status"}, + "project": {"type": "string", "description": "Optional: project path (defaults to current working directory)"} + }, + "required": ["query"] + }), + }, ] } } diff --git a/src/retrieval/ann.rs b/src/retrieval/ann.rs new file mode 100644 index 0000000..79e3d02 --- /dev/null +++ b/src/retrieval/ann.rs @@ -0,0 +1,34 @@ +//! Stage 2: embed query, run usearch top-K, return raw key+distance pairs. +//! +//! This module deliberately does NOT touch CozoDB — it just wraps the embedder +//! and the usearch index. The pipeline (`pipeline.rs`) is responsible for +//! mapping u64 keys back to qualified_names and applying worktree filters. + +use crate::embeddings::{index::AnnIndex, models::Embedder}; + +pub struct AnnRetrieve<'a> { + embedder: &'a Embedder, + index: &'a AnnIndex, +} + +impl<'a> AnnRetrieve<'a> { + pub fn new(embedder: &'a Embedder, index: &'a AnnIndex) -> Self { + Self { embedder, index } + } + + /// Embed the query and run top-K search. Returns keys + raw distances, + /// sorted by usearch's internal ordering (best-first for cosine). + pub fn retrieve( + &self, + query: &str, + top_k: usize, + ) -> Result, Box> { + let qv = self + .embedder + .embed(&[query.to_string()])? + .into_iter() + .next() + .ok_or("fastembed returned no vectors for query")?; + Ok(self.index.search(&qv, top_k)?) + } +} diff --git a/src/retrieval/mod.rs b/src/retrieval/mod.rs new file mode 100644 index 0000000..682ae29 --- /dev/null +++ b/src/retrieval/mod.rs @@ -0,0 +1,14 @@ +//! Embedding-backed retrieval pipeline. Stages 2 (ANN) and 3 (cross-encoder +//! rerank) live here; Stage 4 (KG traversal) stays in `crate::graph` and is +//! invoked by the MCP handler after this pipeline returns its seeds. +//! +//! Behind the `embeddings` feature like `crate::embeddings`. + +#![cfg(feature = "embeddings")] + +pub mod ann; +pub mod pipeline; +pub mod rerank; + +#[allow(unused_imports)] +pub use pipeline::{RetrieveOptions, RetrievalResult, Seed, SemanticRetrievalPipeline}; diff --git a/src/retrieval/pipeline.rs b/src/retrieval/pipeline.rs new file mode 100644 index 0000000..3726b20 --- /dev/null +++ b/src/retrieval/pipeline.rs @@ -0,0 +1,258 @@ +//! Retrieval pipeline orchestration: query → embed → ANN → worktree/env +//! filter → cross-encoder rerank. Returns a `RetrievalResult` ready for the +//! MCP handler to hand off to the traversal stage. + +use crate::db::models::CodeElement; +use crate::db::schema::CozoDb; +use crate::embeddings::{ + index::AnnIndex, + models::{Embedder, RerankerStatus}, +}; +use crate::retrieval::{ann::AnnRetrieve, rerank::RerankStage}; +use std::collections::{HashMap, HashSet}; +use std::path::Path; + +pub struct SemanticRetrievalPipeline { + embedder: Embedder, + index: AnnIndex, + rerank_stage: RerankStage, + db: CozoDb, +} + +#[derive(Debug, Clone)] +pub struct Seed { + pub qualified_name: String, + pub usearch_key: u64, + /// Raw usearch cosine distance/similarity (semantics depend on usearch + /// version; we surface the value as-is for diagnostics). + pub ann_distance: f32, + /// Set by the cross-encoder. None when the pipeline ran in ANN-only + /// fallback mode (Q4 option A). + pub rerank_score: Option, + pub element_type: String, + pub file_path: String, + pub env: String, + /// Short text-blob excerpt used for rerank; included in diagnostics so + /// agents can see *why* a seed matched. + pub blob_excerpt: String, +} + +#[derive(Debug, Clone)] +pub struct RetrievalResult { + pub seeds: Vec, + pub reranker_status: RerankerStatus, + pub ann_candidate_count: usize, + pub worktree_filtered_count: usize, + pub env_filtered_count: usize, + pub embeddings_stale: bool, +} + +#[derive(Debug, Clone)] +pub struct RetrieveOptions { + /// Restrict results to a single env ("local" / "staging" / "production"). + /// None disables env filtering. + pub env: Option, + /// ANN depth. The reranker then narrows to `rerank_top_n`. Default 50. + pub ann_top_k: usize, + /// Final seed count after rerank. Default 10. + pub rerank_top_n: usize, + /// Q2 default-on worktree filter. Set true to include worktree copies. + pub include_worktrees: bool, + /// Surface a stale-embeddings warning in diagnostics. Set by the caller + /// based on comparing embeddings.meta.json.built_at vs last index run. + pub embeddings_stale: bool, +} + +impl Default for RetrieveOptions { + fn default() -> Self { + Self { + env: Some("local".to_string()), + ann_top_k: 50, + rerank_top_n: 10, + include_worktrees: false, + embeddings_stale: false, + } + } +} + +impl SemanticRetrievalPipeline { + pub fn new(db: CozoDb, index_path: &Path) -> Result> { + let embedder = Embedder::new()?; + let index = AnnIndex::load(index_path)?; + let rerank_stage = RerankStage::try_new(); + Ok(Self { + embedder, + index, + rerank_stage, + db, + }) + } + + pub fn reranker_active(&self) -> bool { + self.rerank_stage.is_active() + } + + pub fn retrieve( + &self, + query: &str, + opts: &RetrieveOptions, + ) -> Result> { + // Stage 2: ANN retrieve. + let ann = AnnRetrieve::new(&self.embedder, &self.index); + let raw = ann.retrieve(query, opts.ann_top_k)?; + let ann_candidate_count = raw.len(); + + // Map keys → qualified_names (single batched query). + let qn_map = self.build_key_to_qn_map()?; + + // Resolve desired qualified_names for the batch CodeElements fetch. + let desired_qns: Vec = raw + .iter() + .filter_map(|r| qn_map.get(&r.key).cloned()) + .collect(); + + // Fetch CodeElements for those qualified_names. + let element_map = self.fetch_elements_batch(&desired_qns)?; + + // Build seeds, applying worktree + env filters. + let mut seeds: Vec = Vec::with_capacity(raw.len()); + let mut worktree_filtered = 0usize; + let mut env_filtered = 0usize; + for r in &raw { + let Some(qn) = qn_map.get(&r.key) else { + continue; + }; + let Some(el) = element_map.get(qn) else { + continue; + }; + + if !opts.include_worktrees && is_worktree_path(&el.file_path) { + worktree_filtered += 1; + continue; + } + if let Some(wanted_env) = &opts.env { + if &el.env != wanted_env { + env_filtered += 1; + continue; + } + } + + let blob = crate::embeddings::build_blob(el).unwrap_or_default(); + seeds.push(Seed { + qualified_name: qn.clone(), + usearch_key: r.key, + ann_distance: r.distance, + rerank_score: None, + element_type: el.element_type.clone(), + file_path: el.file_path.clone(), + env: el.env.clone(), + blob_excerpt: truncate(&blob, 200), + }); + } + + // Stage 3: cross-encoder rerank. + let docs: Vec = seeds.iter().map(|s| s.blob_excerpt.clone()).collect(); + let (scores, status) = self.rerank_stage.rerank(query, docs); + let mut ranked_seeds: Vec = Vec::with_capacity(scores.len()); + for s in &scores { + if let Some(mut seed) = seeds.get(s.document_idx).cloned() { + seed.rerank_score = Some(s.score); + ranked_seeds.push(seed); + } + } + ranked_seeds.truncate(opts.rerank_top_n); + + Ok(RetrievalResult { + seeds: ranked_seeds, + reranker_status: status, + ann_candidate_count, + worktree_filtered_count: worktree_filtered, + env_filtered_count: env_filtered, + embeddings_stale: opts.embeddings_stale, + }) + } + + fn build_key_to_qn_map(&self) -> Result, Box> { + let rows = crate::embeddings::state::list_all(&self.db)?; + Ok(rows + .into_iter() + .map(|r| (r.usearch_key as u64, r.qualified_name)) + .collect()) + } + + fn fetch_elements_batch( + &self, + qns: &[String], + ) -> Result, Box> { + if qns.is_empty() { + return Ok(HashMap::new()); + } + // Phase 2 simplicity: pull all elements and filter in Rust. This is + // O(n) per query which is fine for repos up to ~50k elements; larger + // deployments should swap in a real batched Datalog lookup. + let engine = crate::graph::query::GraphEngine::new(self.db.clone()); + let all = engine.all_elements()?; + let qn_set: HashSet<&str> = qns.iter().map(|s| s.as_str()).collect(); + Ok(all + .into_iter() + .filter(|e| qn_set.contains(e.qualified_name.as_str())) + .map(|e| (e.qualified_name.clone(), e)) + .collect()) + } +} + +/// Match the patterns from Q2: `.worktrees/`, `.claude/worktrees/`, +/// `.opencode/worktrees/`. Path-separator aware so `.worktrees-x/` doesn't +/// false-positive. +fn is_worktree_path(path: &str) -> bool { + const PATTERNS: &[&str] = &[ + "/.worktrees/", + "/.claude/worktrees/", + "/.opencode/worktrees/", + ]; + if path.starts_with(".worktrees/") + || path.starts_with(".claude/worktrees/") + || path.starts_with(".opencode/worktrees/") + { + return true; + } + PATTERNS.iter().any(|p| path.contains(p)) +} + +fn truncate(s: &str, max_chars: usize) -> String { + if s.len() <= max_chars { + return s.to_string(); + } + let mut end = max_chars; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + s[..end].to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn worktree_filter_matches_q2_patterns() { + assert!(is_worktree_path("src/.worktrees/foo/bar.rs")); + assert!(is_worktree_path(".worktrees/foo.rs")); + assert!(is_worktree_path("repo/.claude/worktrees/abc/main.rs")); + assert!(is_worktree_path("repo/.opencode/worktrees/x/y.rs")); + } + + #[test] + fn worktree_filter_does_not_match_unrelated_dirs() { + assert!(!is_worktree_path("src/main.rs")); + assert!(!is_worktree_path(".worktrees-extra/foo.rs")); + assert!(!is_worktree_path("src/.worktrees_other/x.rs")); + } + + #[test] + fn truncate_respects_char_boundaries() { + let s = "hello".repeat(100); + let t = truncate(&s, 200); + assert!(t.len() <= 200); + } +} diff --git a/src/retrieval/rerank.rs b/src/retrieval/rerank.rs new file mode 100644 index 0000000..95420b0 --- /dev/null +++ b/src/retrieval/rerank.rs @@ -0,0 +1,69 @@ +//! Stage 3: cross-encoder rerank with Q4 option-A fallback. +//! +//! On any failure to load OR score, the reranker degrades to ANN-order +//! pass-through. The pipeline reads `RerankerStatus` from the result to +//! populate diagnostics so callers know when they're in fallback mode. + +use crate::embeddings::{models::Reranker, RerankScore, RerankerStatus}; + +/// Wraps an optional `Reranker`. Constructed once at pipeline startup; if +/// construction fails, `inner` stays None and every `rerank` call returns +/// `RerankerStatus::Fallback` with the input order unchanged. +pub struct RerankStage { + inner: Option, +} + +impl RerankStage { + /// Try to load the reranker. Failure is non-fatal — the pipeline still + /// works, just without Stage 3. + pub fn try_new() -> Self { + match Reranker::new() { + Ok(r) => Self { inner: Some(r) }, + Err(e) => { + tracing::warn!( + "reranker load failed; pipeline will run in ANN-only fallback mode: {}", + e + ); + Self { inner: None } + } + } + } + + pub fn is_active(&self) -> bool { + self.inner.is_some() + } + + /// Score `(query, doc)` pairs and return indices into `documents` sorted + /// by descending score. If the reranker is unavailable or the call + /// fails, returns `(0..n, RerankerStatus::Fallback)` — i.e., ANN order + /// is preserved. + pub fn rerank( + &self, + query: &str, + documents: Vec, + ) -> (Vec, RerankerStatus) { + let n = documents.len(); + let Some(reranker) = &self.inner else { + return (ann_order(n), RerankerStatus::Fallback); + }; + match reranker.rerank(query, documents) { + Ok(scores) => (scores, RerankerStatus::Active), + Err(e) => { + tracing::warn!( + "rerank inference failed; falling back to ANN order: {}", + e + ); + (ann_order(n), RerankerStatus::Fallback) + } + } + } +} + +fn ann_order(n: usize) -> Vec { + (0..n) + .map(|i| RerankScore { + document_idx: i, + score: 0.0, + }) + .collect() +} diff --git a/tests/embeddings_state_e2e.rs b/tests/embeddings_state_e2e.rs new file mode 100644 index 0000000..ddf88ef --- /dev/null +++ b/tests/embeddings_state_e2e.rs @@ -0,0 +1,151 @@ +//! Integration tests for the embedding_state CozoDB table. +//! +//! Feature-gated: only compiled when the `embeddings` feature is on. Run with: +//! +//! ```bash +//! cargo test --release --features embeddings --test embeddings_state_e2e +//! ``` +//! +//! These tests don't touch fastembed/usearch — they only exercise the state +//! table helpers in `leankg::embeddings::state`. Model downloads are not +//! required. + +#![cfg(feature = "embeddings")] + +use leankg::db::schema::init_db; +use leankg::embeddings::state::{ + count_by_state, delete_state_rows, ensure_embedding_state_table, list_all, list_orphans, + list_stale, mark_stale_for_qualified_names, upsert_fresh, FreshRow, +}; + +fn fresh_db() -> leankg::db::schema::CozoDb { + let tmp = tempfile::tempdir().expect("tempdir"); + let db_path = tmp.path().join("test.db"); + // init_db runs init_schema, which creates embedding_state when the + // feature is compiled in. We hold on to tmp for the life of the test by + // leaking it — these DBs are tiny and tests are short-lived. + std::mem::forget(tmp); + init_db(&db_path).expect("init_db") +} + +#[test] +fn ensure_embedding_state_table_is_idempotent() { + let db = fresh_db(); + ensure_embedding_state_table(&db).expect("first call"); + ensure_embedding_state_table(&db).expect("second call"); +} + +#[test] +fn mark_stale_inserts_rows_with_placeholder_hash() { + let db = fresh_db(); + let qns: Vec = (0..5) + .map(|i| format!("src/file{i}.rs::fn{i}")) + .collect(); + mark_stale_for_qualified_names(&db, &qns).expect("mark_stale"); + + let stale = list_stale(&db).expect("list_stale"); + assert_eq!(stale.len(), 5); + for row in &stale { + assert_eq!(row.state, "stale"); + assert!(row.content_hash.is_empty()); + } +} + +#[test] +fn mark_stale_is_idempotent() { + let db = fresh_db(); + let qns = vec!["src/a.rs::f".to_string()]; + mark_stale_for_qualified_names(&db, &qns).expect("first"); + mark_stale_for_qualified_names(&db, &qns).expect("second"); + + let all = list_all(&db).expect("list_all"); + assert_eq!(all.len(), 1, "no duplicates after double mark_stale"); +} + +#[test] +fn upsert_fresh_transitions_state_and_stores_hash() { + let db = fresh_db(); + let qns: Vec = (0..3).map(|i| format!("q{i}")).collect(); + mark_stale_for_qualified_names(&db, &qns).expect("mark"); + + let fresh_rows: Vec = qns + .iter() + .map(|qn| FreshRow { + qualified_name: qn.clone(), + usearch_key: leankg::embeddings::usearch_key_for(qn), + content_hash: format!("hash-{qn}"), + }) + .collect(); + upsert_fresh(&db, &fresh_rows).expect("upsert_fresh"); + + let stale = list_stale(&db).expect("list_stale"); + assert!(stale.is_empty(), "no rows should still be stale"); + + let all = list_all(&db).expect("list_all"); + for row in &all { + assert_eq!(row.state, "fresh"); + assert!(row.content_hash.starts_with("hash-")); + } +} + +#[test] +fn list_orphans_detects_rows_without_code_elements() { + let db = fresh_db(); + let qns = vec!["ghost1".to_string(), "ghost2".to_string()]; + mark_stale_for_qualified_names(&db, &qns).expect("mark"); + // No code_elements rows created → both are orphans. + let orphans = list_orphans(&db).expect("list_orphans"); + assert_eq!(orphans.len(), 2); +} + +#[test] +fn delete_state_rows_removes_named_rows() { + let db = fresh_db(); + let qns: Vec = (0..4).map(|i| format!("q{i}")).collect(); + mark_stale_for_qualified_names(&db, &qns).expect("mark"); + + delete_state_rows(&db, &qns[0..2].to_vec()).expect("delete"); + + let remaining = list_all(&db).expect("list_all"); + assert_eq!(remaining.len(), 2); + let remaining_qns: std::collections::HashSet = + remaining.iter().map(|r| r.qualified_name.clone()).collect(); + assert!(remaining_qns.contains("q2")); + assert!(remaining_qns.contains("q3")); +} + +#[test] +fn count_by_state_partitions_correctly() { + let db = fresh_db(); + let qns: Vec = (0..5).map(|i| format!("q{i}")).collect(); + mark_stale_for_qualified_names(&db, &qns).expect("mark"); + + let counts = count_by_state(&db).expect("count_by_state"); + assert_eq!(counts.stale, 5); + assert_eq!(counts.fresh, 0); + + let fresh_rows: Vec = qns[0..2] + .iter() + .map(|qn| FreshRow { + qualified_name: qn.clone(), + usearch_key: leankg::embeddings::usearch_key_for(qn), + content_hash: "x".to_string(), + }) + .collect(); + upsert_fresh(&db, &fresh_rows).expect("upsert"); + + let counts = count_by_state(&db).expect("count_by_state again"); + assert_eq!(counts.fresh, 2); + assert_eq!(counts.stale, 3); +} + +#[test] +fn lookup_usearch_key_returns_computed_value() { + let db = fresh_db(); + let qn = "src/main.rs::main".to_string(); + mark_stale_for_qualified_names(&db, &[qn.clone()]).expect("mark"); + + let key = leankg::embeddings::state::lookup_usearch_key(&db, &qn).expect("lookup"); + let expected = leankg::embeddings::usearch_key_for(&qn); + assert_eq!(key, Some(expected)); +}