diff --git a/.gitignore b/.gitignore index 3a6c02f..f2106b5 100644 --- a/.gitignore +++ b/.gitignore @@ -443,3 +443,11 @@ pyrightconfig.json .ionide # End of https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains + +# KG-RAG Example Data (Large Files - Keep Locally) +docs/examples/kgrag/.env +docs/examples/kgrag/dataset/*.jsonl.bz2 +docs/examples/kgrag/dataset/movie/*.json +docs/examples/kgrag/dataset/movie/*.bz2 +docs/examples/kgrag/output/ +docs/examples/kgrag/data/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..465742b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,137 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +**mellea-contribs** is an incubation repository for contributions to the Mellea ecosystem. It provides a library for incubating generative programming tools and utilities that integrate with the Mellea LLM agent framework. + +- **Tech Stack**: Python 3.10+, PDM build system, Mellea framework with LiteLLM +- **Key Dependencies**: rapidfuzz (fuzzy matching), eyecite (legal citations), playwright (web scraping), markdown +- **License**: Apache License 2.0 + +## Common Development Commands + +### Setup and Installation + +```bash +# Install with development dependencies +pdm install --group dev + +# Or using uv (faster): +uv pip install -e . --group dev +``` + +### Code Quality + +```bash +# Format code (ruff is used for both formatting and linting) +ruff format . +ruff check . --fix + +# Type checking +mypy mellea_contribs/ + +# Run all linters individually +isort mellea_contribs/ test/ +pylint mellea_contribs/ test/ +ruff check mellea_contribs/ test/ + +# Run pre-commit hooks +pre-commit run --all-files +``` + +### Testing + +```bash +# Run all tests +pytest + +# Run a specific test file +pytest test/test_citation_exists.py + +# Run a specific test function +pytest test/test_citation_exists.py::test_function_name + +# Run tests with verbose output +pytest -v + +# Run only tests marked as qualitative (LLM-dependent) +pytest -m qualitative + +# Run tests excluding qualitative tests (useful in CI) +pytest -m "not qualitative" + +# Run tests with asyncio verbose output (asyncio_mode is auto-configured) +pytest -v --asyncio-mode=auto +``` + +### Documentation + +```bash +# Build Sphinx documentation +sphinx-build -b html docs/ docs/_build/ +``` + +## Architecture Overview + +### Module Organization + +**`mellea_contribs/tools/`** - Reusable LLM-based selection and ranking algorithms: +- `top_k.py`: Generic Top-K selection engine using LLM judgment. Selects top K from N items with rejection sampling and caching. +- `double_round_robin.py`: Pairwise comparison scoring engine. Performs all-pair comparisons and returns items ranked by accumulated scores. + +**`mellea_contribs/reqlib/`** - Domain-specific validators extending Mellea's validation framework: +- `citation_exists.py`: Validates legal case citations via case.law metadata + fuzzy matching +- `is_appellate_case.py`: Classifies cases as appellate by court abbreviation patterns +- `import_repair.py`: Fixes Python import errors in LLM-generated code via AST analysis +- `import_resolution.py`: Parses and resolves module not found / import errors with confidence-scored suggestions +- `grounding_context_formatter.py`: Structures multi-field context for LLM prompts (auto-skips empty fields) +- `common_aliases.py`: Module name mappings and relocations for import resolution +- `statute_data.py`: Legal statute data handling + +**`mellea_contribs/kg/`** - Knowledge Graph database abstraction: +- `base.py`: Core data structures (GraphNode, GraphEdge, GraphPath) +- `graph_dbs/base.py`: GraphBackend abstract interface +- `graph_dbs/neo4j.py`: Production Neo4j implementation (requires [kg] dependencies) +- `graph_dbs/mock.py`: In-memory mock backend for testing without infrastructure +- `components/`: Query, result, traversal components (minimal Layer 4 implementations) + +**Installation**: `pip install mellea-contribs[kg]` for Neo4j support + +### Design Patterns + +**Mellea Integration Pattern**: All validators are Requirement classes with a `validation_fn(output) → ValidationResult`. This enables iterative LLM refinement via the Instruct-Validate-Repair loop. + +**Caching Strategy**: Tools use decorator-based caching keyed on item hash + context + prompts to avoid redundant LLM calls. + +**Model Interaction**: Uses `mellea.instruct()` with grounding context, system prompts for output formatting (JSON arrays, single tokens), and rejection sampling (loop_budget=2) for reliability. + +**Data Validation Layers**: +- Legal citations: Direct lookup → fuzzy match → LLM judgment +- Python imports: Static AST analysis + dynamic error parsing +- Court classification: Pattern matching + database lookup + +### Test Infrastructure + +- Large test databases: `test/data/citation_exists_database.json` (~2.8MB) +- Qualitative test markers for LLM-dependent tests (auto-xfail in CI when `MELLEA_SKIP_QUALITATIVE` env var is set) +- Neo4j integration tests: Marked with `@pytest.mark.neo4j`, skipped unless NEO4J_URI is set +- Separate database for CI efficiency +- Shared fixtures via `test/conftest.py` +- Async-first testing: `asyncio_mode = "auto"` (no explicit marking needed) + +## Code Quality Standards + +- **Docstrings**: Google-style convention (enforced via ruff rule D) +- **Complexity**: Maximum cyclomatic complexity of 20 (ruff C901) +- **Type Hints**: Full type annotation with mypy checking enabled +- **Imports**: isort enforced for consistent ordering with `combine-as-imports` +- **Pre-commit hooks**: Automatic formatting and validation on commit + +## Release and Versioning + +- **Semantic Versioning**: Angular commit parser (feat → minor, fix → patch) +- **Automated Releases**: python-semantic-release on main branch +- **CI/CD**: GitHub Actions workflows (ci.yml, cd.yml, pypi.yml) +- **PyPI Publishing**: Automated on semantic version tags diff --git a/README.md b/README.md index 2b4d225..dcab960 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ - # Mellea Contribs The `mellea-contribs` repository is an incubation point for contributions to diff --git a/docs/examples/kgrag/.env_template b/docs/examples/kgrag/.env_template new file mode 100644 index 0000000..bcd93da --- /dev/null +++ b/docs/examples/kgrag/.env_template @@ -0,0 +1,52 @@ +# Graph Database Configuration (Neo4j default; adjust for other graph DBs) +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=password + +# Data Directory +KG_BASE_DIRECTORY=./dataset +DATA_PATH=./data + +# --------------------------------------------------------------------------- +# Primary LLM — any OpenAI-compatible endpoint +# --------------------------------------------------------------------------- +# Option A: OpenAI +# API_KEY=sk-... +# MODEL_NAME=gpt-4o-mini + +# Option B: Local Ollama +# API_BASE=http://localhost:11434/v1 +# API_KEY=ollama +# MODEL_NAME=llama3.2 + +# Option C: vLLM / self-hosted +# API_BASE=http://localhost:8000/v1 +# API_KEY=dummy +# MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct + +# Option D: Azure OpenAI +# API_BASE=https://.openai.azure.com/openai/deployments// +# API_KEY= +# MODEL_NAME=gpt-4o-mini + +# --------------------------------------------------------------------------- +# Optional: Separate evaluation model (defaults to primary LLM if unset) +# --------------------------------------------------------------------------- +# EVAL_API_BASE=... +# EVAL_API_KEY=... +# EVAL_MODEL_NAME=... + +# --------------------------------------------------------------------------- +# Optional: Embedding model for vector entity alignment +# --------------------------------------------------------------------------- +# EMB_API_BASE=http://localhost:11434/v1 +# EMB_API_KEY=ollama +# EMB_MODEL_NAME=nomic-embed-text +# VECTOR_DIMENSIONS=768 + +# Request Configuration +MAX_RETRIES=3 +TIME_OUT=1800 + +# OpenTelemetry — disable if you don't have a collector running +OTEL_SDK_DISABLED=true diff --git a/docs/examples/kgrag/GraphRag.drawio.png b/docs/examples/kgrag/GraphRag.drawio.png new file mode 100644 index 0000000..94840e5 Binary files /dev/null and b/docs/examples/kgrag/GraphRag.drawio.png differ diff --git a/docs/examples/kgrag/README.md b/docs/examples/kgrag/README.md new file mode 100644 index 0000000..d554d77 --- /dev/null +++ b/docs/examples/kgrag/README.md @@ -0,0 +1,429 @@ +# KG-RAG: Knowledge Graph-Enhanced Retrieval-Augmented Generation + +A complete example system demonstrating how to build intelligent retrieval-augmented generation (RAG) using knowledge graphs with the Mellea framework and RITS cloud LLM service. + +## Overview + +This example demonstrates a five-stage KG-RAG pipeline: + +1. **Preprocessing**: Load predefined structured data into a Neo4j knowledge graph +2. **Embedding**: Generate and store vector embeddings for entities and relations +3. **Updating**: Process documents to extract and merge new entities/relations into the KG +4. **QA**: Answer questions using multi-hop Think-on-Graph reasoning over the KG +5. **Evaluation**: Score predictions with an LLM judge and compute CRAG metrics + +**Tech Stack**: +- **Neo4j**: Graph database (localhost:7687) +- **RITS**: Cloud LLM service (llama-3-3-70b-instruct model) +- **Mellea**: LLM orchestration framework with `OpenAIBackend` + +**Domain Example:** Movie & Entertainment domain with 64K+ movies, 373K+ persons, and 1M+ relations. + +## Directory Structure + +``` +docs/examples/kgrag/ +├── README.md (this file) +├── .env_template # Copy to .env and fill in credentials +├── dataset/ +│ └── README.md # Dataset acquisition instructions +├── models/ +│ └── movie_domain_models.py # Movie entity classes (MovieEntity, PersonEntity, AwardEntity) +├── preprocessor/ +│ └── movie_preprocessor.py # Domain-specific preprocessing +├── rep/ +│ └── movie_rep.py # Movie-specific representations for LLM prompts +└── scripts/ + ├── run.sh # Pipeline orchestration (all 5 steps) + ├── create_tiny_dataset.py # Slice a small dataset for testing + ├── run_kg_preprocess.py # Step 1: Load predefined data into Neo4j + ├── run_kg_embed.py # Step 2: Generate embeddings for entities + ├── run_kg_update.py # Step 3: Update KG with new documents + ├── run_qa.py # Step 4: Run QA retrieval over questions + └── run_eval.py # Step 5: Evaluate QA results (LLM judge + CRAG metrics) +``` + +## Quick Start + +### Prerequisites + +1. **Start Neo4j Server** + ```bash + # Neo4j should be running on localhost:7687 + docker ps | grep neo4j # If using Docker + ``` + +2. **Configure credentials** + ```bash + cd docs/examples/kgrag + cp .env_template .env + # Edit .env and set: API_BASE, RITS_API_KEY, MODEL_NAME + ``` + +3. **Place dataset files** in `dataset/` (see [Dataset Files](#dataset-files) below) + +### Running the Pipeline + +Use `run.sh` to run the full pipeline or individual steps. + +#### Dataset mode + +| Flag | Description | +|------|-------------| +| `--tiny` (default) | Uses the tiny test dataset (~10 docs). Step 0 creates it if missing. | +| `--full` | Uses the full dataset (`crag_movie_dev.jsonl.bz2`). Skips step 0. | + +```bash +cd scripts + +# Run all steps on the tiny dataset (default) +bash run.sh + +# Run all steps on the full dataset +bash run.sh --full + +# Run specific steps only +bash run.sh --tiny 3 4 5 # update + QA + eval on tiny set +bash run.sh --full 4 5 # QA + eval on full set +bash run.sh 1 2 # reload Neo4j + recompute embeddings +``` + +#### Step reference + +| Step | Script | Description | +|------|--------|-------------| +| 0 | `create_tiny_dataset.py` | Create tiny dataset from full set (tiny mode only) | +| 1 | `run_kg_preprocess.py` | Load 64K movies + 373K persons into Neo4j | +| 2 | `run_kg_embed.py` | Compute and store entity/relation embeddings | +| 3 | `run_kg_update.py` | Extract entities from documents and merge into KG | +| 4 | `run_qa.py` | Answer questions via Think-on-Graph retrieval | +| 5 | `run_eval.py` | Score predictions with LLM judge; compute CRAG metrics | + +Output files (written to `output/`): + +| File | Produced by | +|------|-------------| +| `preprocess_stats.json` | Step 1 | +| `embedding_stats.json` | Step 2 | +| `update_stats.json` | Step 3 | +| `qa_results.jsonl` | Step 4 | +| `qa_progress.json` | Step 4 (resumption state) | +| `eval_results.json` | Step 5 (annotated per-item results) | +| `eval_metrics.json` | Step 5 (aggregate CRAG metrics) | + +## Individual Scripts + +### Step 1: Preprocessing + +Load predefined movie/person data into Neo4j: + +```bash +python run_kg_preprocess.py \ + --data-dir ../dataset/movie \ + --neo4j-uri bolt://localhost:7687 \ + --neo4j-user neo4j \ + --neo4j-password password \ + --batch-size 500 +``` + +Output (`preprocess_stats.json`): +```json +{ + "total_documents": 1, + "entities_loaded": 437891, + "entities_inserted": 437891, + "relations_inserted": 1045369 +} +``` + +### Step 2: Embedding + +Compute embeddings for entities and relations: + +```bash +python run_kg_embed.py \ + --neo4j-uri bolt://localhost:7687 \ + --neo4j-user neo4j \ + --neo4j-password password \ + --batch-size 100 +``` + +Set `EMB_API_BASE` / `EMB_MODEL_NAME` in `.env` to use a custom embedding endpoint. + +### Step 3: KG Update + +Extract entities/relations from documents and merge them into the KG: + +```bash +# Tiny dataset (testing) +python run_kg_update.py \ + --dataset ../dataset/crag_movie_tiny.jsonl.bz2 \ + --domain movie \ + --num-workers 10 + +# Full dataset +python run_kg_update.py \ + --dataset ../dataset/crag_movie_dev.jsonl.bz2 \ + --domain movie \ + --num-workers 64 +``` + +LLM configuration is read from `.env` (`API_BASE`, `MODEL_NAME`, `RITS_API_KEY`). + +### Step 4: QA + +Answer questions using Think-on-Graph multi-hop retrieval: + +```bash +python run_qa.py \ + --dataset ../dataset/crag_movie_tiny.jsonl.bz2 \ + --output ../output/qa_results.jsonl \ + --progress ../output/qa_progress.json \ + --domain movie \ + --routes 3 \ + --width 30 \ + --depth 3 +``` + +Key options: +- `--routes N` — number of independent solving routes (default: 3) +- `--width N` — max candidate relations per traversal step (default: 30) +- `--depth N` — max traversal depth (default: 3) +- `--reset-progress` — ignore previous progress and reprocess all questions +- `--workers N` — parallel async workers (default: 1) + +Output JSONL format: +```json +{ + "id": "q_0", + "query": "Who directed Inception?", + "predicted": "Christopher Nolan", + "answer": "Christopher Nolan", + "answer_aliases": ["Christopher Nolan", "Nolan"], + "elapsed_ms": 1234.5 +} +``` + +### Step 5: Evaluation + +Evaluate predictions using LLM judge + CRAG-style scoring: + +```bash +python run_eval.py \ + --input ../output/qa_results.jsonl \ + --output ../output/eval_results.json \ + --metrics ../output/eval_metrics.json + +# Skip LLM calls (fuzzy match only, for testing) +python run_eval.py \ + --input ../output/qa_results.jsonl \ + --metrics ../output/eval_metrics.json \ + --mock +``` + +The evaluator runs each prediction through: +1. Exact match against `answer_aliases` +2. Fuzzy match (rapidfuzz token_set_ratio ≥ 0.8) +3. LLM judge (for cases not resolved by string matching) + +Output (`eval_metrics.json`): +```json +{ + "total": 100, + "n_correct": 72, + "n_miss": 5, + "n_hallucination": 23, + "accuracy": 72.0, + "score": 49.0, + "hallucination": 23.0, + "missing": 5.0, + "eval_model": "meta-llama/llama-3-3-70b-instruct" +} +``` + +**CRAG score formula**: `((2 × correct + missing) / total − 1) × 100` +— penalises hallucination more than unanswered questions. + +Use a separate eval model by setting `EVAL_API_BASE` / `EVAL_MODEL_NAME` / `EVAL_RITS_API_KEY` +in `.env`; the script falls back to the main session if these are not set. + +## Configuration + +### .env file + +```bash +cp .env_template .env +``` + +```bash +# Neo4j +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=password + +# Primary LLM (RITS or any OpenAI-compatible endpoint) +API_BASE=https://your-rits-endpoint/v1 +MODEL_NAME=meta-llama/llama-3-3-70b-instruct +API_KEY=dummy +RITS_API_KEY=your_rits_api_key + +# Optional: separate eval model +EVAL_API_BASE=https://your-eval-endpoint/v1 +EVAL_MODEL_NAME=meta-llama/llama-3-3-70b-instruct +# EVAL_RITS_API_KEY= # falls back to RITS_API_KEY if unset + +# Optional: embedding model +EMB_API_BASE=https://your-embedding-endpoint/v1 +EMB_MODEL_NAME=text-embedding-3-small + +# Misc +OTEL_SDK_DISABLED=true +``` + +The Python scripts load this file automatically via `python-dotenv` (`override=False`), +so values already exported in the shell take precedence. + +### Session architecture + +All scripts create sessions using `create_session_from_env` from +`mellea_contribs.kg.utils`, which wires up `MelleaSession(backend=OpenAIBackend(...))` +directly — no LiteLLM proxy needed. The `RITS_API_KEY` is forwarded as a custom HTTP +header that RITS requires for authentication. + +## Domain-Specific Components + +### Movie Domain Models (`models/movie_domain_models.py`) + +Defines domain-specific entity classes extending the core `Entity`/`Relation` models: +- `MovieEntity`: genre, release_year, budget, box_office +- `PersonEntity`: birth_year, nationality +- `AwardEntity`: category, year, ceremony + +### Movie Domain Preprocessor (`preprocessor/movie_preprocessor.py`) + +Extends `KGPreprocessor` with movie-specific extraction hints and post-processing +(entity type standardisation, relation normalisation). + +### Movie Domain Representation (`rep/movie_rep.py`) + +Formatting utilities for LLM prompts: `movie_entity_to_text`, `person_entity_to_text`, +`format_movie_context`, `movie_relation_to_text`. + +## Creating a Custom Domain + +1. **Models** — create `models/[domain]_models.py` extending `Entity`/`Relation` +2. **Preprocessor** — create `preprocessor/[domain]_preprocessor.py` extending `KGPreprocessor`; + implement `get_hints()` and optionally `post_process_extraction()` +3. **Representation** — create `rep/[domain]_rep.py` with domain-specific text formatters +4. **Run** — pass `--domain [domain]` to `run_kg_update.py` and `run_qa.py` + +## Testing + +```bash +# Run all KG utility tests from the project root +pytest test/kg/ -v +pytest test/kg/utils/ -v + +# Test scripts without Neo4j or LLM (mock mode) +cd scripts +python run_kg_update.py --dataset ../dataset/crag_movie_tiny.jsonl.bz2 --mock +python run_qa.py --dataset ../dataset/crag_movie_tiny.jsonl.bz2 --mock +python run_eval.py --input ../output/qa_results.jsonl --mock +``` + +## Troubleshooting + +### Neo4j connection + +```bash +# Verify Neo4j is reachable +nc -zv localhost 7687 +``` + +### LLM / RITS authentication + +- Ensure `.env` exists and `RITS_API_KEY` is set (not just the template placeholder) +- Do **not** `export API_BASE` / `RITS_API_KEY` as empty strings in the shell before + running `run.sh` — empty exports prevent `load_dotenv` from filling them in +- The scripts log `create_session_from_env(prefix=...): api_base=set/MISSING, rits_api_key=set/MISSING` + at `INFO` level to help diagnose missing credentials + +### Resuming an interrupted QA run + +Step 4 writes a progress file (`qa_progress.json`). Re-running without +`--reset-progress` picks up where it left off. Use `--reset-progress` to start fresh. + +### Dataset not found + +```bash +ls -lh ../dataset/crag_movie_tiny.jsonl.bz2 +# If missing, run step 0 first: +bash run.sh 0 +``` + +## Architecture + +``` +Step 1: Preprocessing Step 2: Embedding Step 3: Updating +run_kg_preprocess.py run_kg_embed.py run_kg_update.py + │ │ │ + ├─ Load predefined data ├─ Fetch entities ├─ Load documents + ├─ Batch insert Neo4j ├─ Compute embeddings├─ Extract entities/relations + └─ Output stats └─ Store + index └─ Align & merge with KG + + Neo4j Knowledge Graph + (bolt://localhost:7687) + +Step 4: QA Step 5: Evaluation +run_qa.py run_eval.py + │ │ + ├─ Decompose question ├─ Exact / fuzzy match + ├─ Align entities (embed+fuzzy) ├─ LLM judge + ├─ Think-on-Graph traversal └─ CRAG metrics (accuracy, score, + ├─ Prune + synthesise answer hallucination, missing) + └─ Output JSONL results +``` + +## Dataset Files + +Large data files are **not tracked in git** to keep repository size manageable. + +| File | Size | Used by | +|------|------|---------| +| `dataset/crag_movie_dev.jsonl.bz2` | ~140 MB | Step 3 (full mode) | +| `dataset/movie/movie_db.json` | ~181 MB | Step 1 | +| `dataset/movie/person_db.json` | ~44 MB | Step 1 | + +### Generating a tiny test dataset + +```bash +cd scripts +python create_tiny_dataset.py --output ../dataset/crag_movie_tiny.jsonl +# Or truncate from the full set: +python create_truncated_dataset.py --max-examples 100 +``` + +### Acquiring the full dataset + +- Full CRAG dataset: contact project maintainers for access to `crag_movie_dev.jsonl.bz2` +- Movie/person databases: sourced from the TMDB dataset + +### Testing without any data files + +Use `--mock` to test all scripts without a database or dataset: + +```bash +cd scripts +python run_kg_update.py --dataset ../dataset/crag_movie_tiny.jsonl.bz2 --mock +python run_qa.py --dataset ../dataset/crag_movie_tiny.jsonl.bz2 --mock +python run_eval.py --input ../output/qa_results.jsonl --mock +``` + +## See Also + +- **Core Library**: [mellea_contribs/kg/README.md](../../mellea_contribs/kg/README.md) +- **Mellea Framework**: https://github.com/generative-computing/mellea + +## License + +Apache License 2.0 (same as mellea-contribs) diff --git a/docs/examples/kgrag/dataset/__init__.py b/docs/examples/kgrag/dataset/__init__.py new file mode 100644 index 0000000..e300ff1 --- /dev/null +++ b/docs/examples/kgrag/dataset/__init__.py @@ -0,0 +1,5 @@ +"""Dataset loaders for the KG-RAG example.""" + +from .movie_dataset_loader import MovieDatasetLoader + +__all__ = ["MovieDatasetLoader"] diff --git a/docs/examples/kgrag/dataset/movie_dataset_loader.py b/docs/examples/kgrag/dataset/movie_dataset_loader.py new file mode 100644 index 0000000..6df8470 --- /dev/null +++ b/docs/examples/kgrag/dataset/movie_dataset_loader.py @@ -0,0 +1,135 @@ +"""Movie dataset loader for CRAG-format QA data. + +Reads the CRAG movie benchmark JSONL (plain or bz2-compressed) and yields +normalised QA items suitable for ``orchestrate_qa_retrieval``. + +Expected input format per line:: + + { + "query": "Who directed Inception?", + "query_time": "2024-03-05 00:00:00", + "answer": "Christopher Nolan", + "answer_aliases": ["Christopher Nolan", "Nolan"], + "search_results": [...] # ignored + } + +Each yielded item contains: + +* ``id`` — ``"{prefix}{index}"`` string used for resumption. +* ``query`` — question text. +* ``query_time`` — original timestamp string. +* ``answer`` — canonical gold answer string. +* ``answer_aliases``— list of acceptable answer strings. +* ``_raw`` — the original dict from the file. +""" + +from typing import Any, Dict, Generator + +from mellea_contribs.kg.utils.data_utils import BaseDatasetLoader, load_jsonl + + +class MovieDatasetLoader(BaseDatasetLoader): + """Dataset loader for the CRAG movie QA benchmark. + + Iterates over a JSONL (or ``.jsonl.bz2``) file and emits normalised QA + items. Supports slicing with ``prefix`` / ``postfix`` to process a + sub-range of the dataset, which is useful for parallel batch jobs. + + Args: + dataset_path: Path to the JSONL or ``.jsonl.bz2`` dataset file. + num_workers: Number of parallel async workers (default: ``1``). + queue_size: Internal asyncio queue capacity (default: ``100``). + id_prefix: String prepended to the index to form the item ``id`` + (default: ``"q_"``). + prefix: 0-based index of the first item to include (default: ``0``). + postfix: Exclusive upper bound; ``None`` means read to end of file + (default: ``None``). + """ + + def __init__( + self, + dataset_path: str, + num_workers: int = 1, + queue_size: int = 100, + id_prefix: str = "q_", + prefix: int = 0, + postfix: int | None = None, + ) -> None: + super().__init__( + dataset_path=dataset_path, + num_workers=num_workers, + queue_size=queue_size, + ) + self._id_prefix = id_prefix + self._prefix = prefix + self._postfix = postfix + + # ------------------------------------------------------------------ + # BaseDatasetLoader interface + # ------------------------------------------------------------------ + + def iter_items(self) -> Generator[Dict[str, Any], None, None]: + """Yield normalised QA items from the dataset file. + + Items outside the [``prefix``, ``postfix``) index range are skipped + silently. + + Yields: + Dict with keys ``id``, ``query``, ``query_time``, ``answer``, + ``answer_aliases``, and ``_raw``. + """ + for global_idx, raw in enumerate(load_jsonl(self.dataset_path)): + if global_idx < self._prefix: + continue + if self._postfix is not None and global_idx >= self._postfix: + break + + item = self._normalise(raw, global_idx) + if item is not None: + yield item + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _normalise( + self, raw: Dict[str, Any], index: int + ) -> Dict[str, Any] | None: + """Convert a raw CRAG record into a normalised QA item. + + Args: + raw: Raw dict from the JSONL file. + index: Global 0-based position in the file. + + Returns: + Normalised dict, or *None* when the record has no ``query``. + """ + query = raw.get("query") or raw.get("question") or "" + if not query: + return None + + query_time = raw.get("query_time") or raw.get("query_date") or "" + + # Canonical answer + answer = raw.get("answer") or "" + if isinstance(answer, list): + answer = answer[0] if answer else "" + + # Acceptable answer aliases + aliases = raw.get("answer_aliases") or [] + if isinstance(aliases, str): + aliases = [aliases] + # Always include the canonical answer + if answer and answer not in aliases: + aliases = [answer] + list(aliases) + + item_id = f"{self._id_prefix}{index}" + + return { + "id": item_id, + "query": str(query), + "query_time": str(query_time), + "answer": str(answer), + "answer_aliases": [str(a) for a in aliases], + "_raw": raw, + } diff --git a/docs/examples/kgrag/kg_overview.md b/docs/examples/kgrag/kg_overview.md new file mode 100644 index 0000000..0641000 --- /dev/null +++ b/docs/examples/kgrag/kg_overview.md @@ -0,0 +1,1597 @@ +# Mellea Graph Query Library Design + +## Vision + +A graph query library that embodies Mellea's philosophy: **graph queries and results as Components that format for LLMs, with composable query building, validation/repair loops, and backend abstraction.** + +## Core Philosophy Alignment + +### 1. Components as Prompt Templates +Graph queries and results are **Components** with `format_for_llm()` that produce LLM-ready representations. + +### 2. Instruct/Validate/Repair Loops +Query construction supports validation and iterative repair using Sampling Strategies. + +### 3. Composability and Immutability +Query components compose naturally and follow functional patterns (immutable updates via `deepcopy`). + +### 4. GraphDB Abstraction +Abstract over different graph databases (Neo4j, Neptune, RDF stores) similar to how Mellea abstracts LLM backends. + +--- + +## Architecture Overview + +```mermaid +flowchart TB + subgraph Layer1["🔷 LAYER 1: Application Layer"] + direction TB + App["User code using graph queries
with LLM reasoning"] + end + + subgraph Layer2["🔷 LAYER 2: Graph Query Components"] + direction TB + L2_1["GraphQuery (base Component)"] + L2_2["CypherQuery, SparqlQuery
(query builders)"] + L2_3["GraphTraversal
(multi-hop patterns)"] + L2_4["GraphResult
(formatted results)"] + end + + subgraph Layer3["🔷 LAYER 3: LLM-Guided Query Construction"] + direction TB + L3_1["QueryInstruction
(natural language → query)"] + L3_2["QueryValidation
(check query correctness)"] + L3_3["QueryRepair
(fix malformed queries)"] + L3_4["@generative functions
for query generation"] + end + + subgraph Layer4["🔷 LAYER 4: Graph Backend Abstraction"] + direction TB + L4_1["GraphBackend
(abstract base)"] + L4_2["Neo4jBackend, NeptuneBackend,
RDFBackend"] + L4_3["Query execution and
result parsing"] + end + + Layer1 --> Layer2 + Layer2 --> Layer3 + Layer3 --> Layer4 + + style Layer1 fill:#e8f5e9,stroke:#2e7d32,stroke-width:3px + style Layer2 fill:#e3f2fd,stroke:#1565c0,stroke-width:3px + style Layer3 fill:#fff3e0,stroke:#ef6c00,stroke-width:3px + style Layer4 fill:#f3e5f5,stroke:#6a1b9a,stroke-width:3px +``` + +The following is an example of how the data flow through the system when user give a question. (Note, we use neo4j database as an example but the implementation is graph database agnostic and the difference only exist in layer 4). +![KGRag example flowchart](GraphRag.drawio.png) + +--- + +## Module Structure + +**Location**: `mellea-contribs` repository (github.com/generative-computing/mellea-contribs) + +**Dependencies**: All graph database and KG dependencies should be added to a `[kg]` dependency group in `pyproject.toml` + +``` +mellea-contribs/kg/ +├── __init__.py # Public API exports +├── base.py # Core data structures +│ ├── GraphNode # Node dataclass +│ ├── GraphEdge # Edge dataclass +│ └── GraphPath # Path representation +│ +├── graph_dbs/ # Graph database backend implementations +│ ├── __init__.py +│ ├── base.py # GraphBackend (ABC) +│ ├── neo4j.py # Neo4j backend +│ ├── neptune.py # AWS Neptune backend +│ ├── rdf.py # RDF/SPARQL backend +│ └── mock.py # Mock backend for testing +│ +├── components/ # Graph Components (REQUIRED NAME) +│ ├── __init__.py +│ ├── query.py # GraphQuery, CypherQuery, SparqlQuery +│ ├── result.py # GraphResult component +│ ├── traversal.py # GraphTraversal patterns +│ └── llm_guided.py # LLM-guided query construction +│ ├── QueryInstruction # NL → Query instruction +│ ├── @generative functions # Query generation functions +│ └── query validation helpers +│ +├── sampling/ # Graph-specific sampling strategies (REQUIRED NAME) +│ ├── __init__.py +│ ├── validation.py # QueryValidationStrategy +│ ├── traversal.py # TraversalStrategy with pruning +│ └── extraction.py # SubgraphExtractionStrategy +│ +├── requirements/ # Graph-specific requirements (REQUIRED NAME) +│ └── __init__.py # All requirement functions +│ ├── is_valid_cypher() # Cypher syntax validation +│ ├── is_valid_sparql() # SPARQL syntax validation +│ ├── returns_results() # Query returns non-empty results +│ └── respects_schema() # Query respects graph schema +│ +└── README.md # Documentation +``` + +--- + +## Core Design: Data Structures + +### GraphNode and GraphEdge (Not Components) + +These are pure data structures, not Components. They represent graph data. + +```python +from dataclasses import dataclass +from typing import Any + +@dataclass +class GraphNode: + """A node in a graph. + + This is a dataclass, not a Component. It's just data. + """ + id: str + label: str # Node type/label + properties: dict[str, Any] + + @classmethod + def from_neo4j_node(cls, node: Any) -> "GraphNode": + """Create from Neo4j node object.""" + return cls( + id=str(node.id), + label=list(node.labels)[0] if node.labels else "Unknown", + properties=dict(node.items()), + ) + + +@dataclass +class GraphEdge: + """An edge in a graph. + + This is a dataclass, not a Component. It's just data. + """ + id: str + source: GraphNode + label: str # Relationship type + target: GraphNode + properties: dict[str, Any] + + @classmethod + def from_neo4j_relationship( + cls, + rel: Any, + source: GraphNode, + target: GraphNode + ) -> "GraphEdge": + """Create from Neo4j relationship object.""" + return cls( + id=str(rel.id), + source=source, + label=rel.type, + target=target, + properties=dict(rel.items()), + ) +``` + +**Key Design Principle**: Separate data (GraphNode, GraphEdge) from Components (GraphQuery, GraphResult). Components wrap data and provide `format_for_llm()`. + +--- + +## Core Design: Components + + +### Mellea Component Pattern + +Before implementing graph components, understand Mellea's Component pattern: + +```python +# From mellea.stdlib.base +@runtime_checkable +class Component(Protocol): + """A Component is a composite data structure intended to be represented to an LLM.""" + + def parts(self) -> list[Component | CBlock]: + """The constituent parts of the Component.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> TemplateRepresentation | str: + """Formats the Component into a TemplateRepresentation or string.""" + raise NotImplementedError("format_for_llm isn't implemented by default") +``` + +**Key Mellea Patterns**: +1. ✅ **Private fields**: Use `self._field` for all internal state +2. ✅ **Public properties**: Use `@property` for read access +3. ✅ **blockify()**: Convert strings to CBlocks using `blockify()` +4. ✅ **Implement parts()**: Even if it just raises NotImplementedError +5. ✅ **TemplateRepresentation**: Always include `tools=None, images=None` +6. ✅ **template_order**: Always a list starting with `["*", "ComponentName"]` +7. ✅ **Immutability**: Use `deepcopy(self)` for creating modified copies + +--- + +### 1. GraphQuery Component + +Base Component for all graph queries. + +```python +from mellea.stdlib.base import Component, TemplateRepresentation, CBlock, blockify +from copy import deepcopy +from typing import Any + +class GraphQuery(Component): + """Base Component for graph queries. + + Represents a graph query that can be executed against a GraphBackend + and formatted for LLM consumption. + + Following Mellea patterns: + - Private fields with _ prefix + - Public properties for access + - format_for_llm() returns TemplateRepresentation + - Immutable updates via deepcopy + """ + + def __init__( + self, + query_string: str | CBlock | None = None, + parameters: dict | None = None, + description: str | CBlock | None = None, + metadata: dict | None = None, + ): + """Initialize a graph query. + + Args: + query_string: The actual query (Cypher, SPARQL, etc.) + parameters: Query parameters for parameterized queries + description: Natural language description of what the query does + metadata: Additional metadata (schema hints, temporal constraints, etc.) + """ + # MELLEA PATTERN: Store as private fields with _ prefix + self._query_string = blockify(query_string) if query_string is not None else None + self._parameters = parameters or {} + self._description = blockify(description) if description is not None else None + self._metadata = metadata or {} + + # MELLEA PATTERN: Public properties for access + @property + def query_string(self) -> str | None: + """Get the query string.""" + return str(self._query_string) if self._query_string else None + + @property + def parameters(self) -> dict: + """Get query parameters.""" + return self._parameters + + @property + def description(self) -> str | None: + """Get description.""" + return str(self._description) if self._description else None + + # MELLEA PATTERN: Implement parts() even if raising + def parts(self) -> list[Component | CBlock]: + """The constituent parts of the query.""" + raise NotImplementedError("parts isn't implemented by default") + + # MELLEA PATTERN: format_for_llm returns TemplateRepresentation + def format_for_llm(self) -> TemplateRepresentation: + """Format query for LLM consumption. + + Returns a representation that shows: + - What the query is trying to find (description) + - The query structure (query_string) + - Any constraints or parameters + """ + return TemplateRepresentation( + obj=self, + args={ + "description": self.description or "Graph query", + "query": self.query_string, + "parameters": self._parameters, + "metadata": self._metadata, + }, + tools=None, # MELLEA PATTERN: Always include + images=None, # MELLEA PATTERN: Always include + template_order=["*", "GraphQuery"], # MELLEA PATTERN: List with "*" first + ) + + # MELLEA PATTERN: Immutable updates using deepcopy + def with_description(self, description: str | CBlock) -> "GraphQuery": + """Return new query with updated description (immutable). + + Following Mellea's copy_and_repair pattern from Instruction. + """ + result = deepcopy(self) + result._description = blockify(description) if description is not None else None + return result + + def with_parameters(self, **params) -> "GraphQuery": + """Return new query with updated parameters (immutable).""" + result = deepcopy(self) + result._parameters = {**self._parameters, **params} + return result + + def with_metadata(self, **metadata) -> "GraphQuery": + """Return new query with updated metadata (immutable).""" + result = deepcopy(self) + result._metadata = {**self._metadata, **metadata} + return result +``` + +--- + +### 2. CypherQuery Component + +Composable Cypher query builder with fluent interface. + +```python +class CypherQuery(GraphQuery): + """Component for building Cypher queries (Neo4j). + + Provides a fluent, composable interface for building Cypher queries + that follows Mellea's immutability patterns (like Instruction). + + Example: + query = ( + CypherQuery() + .match("(m:Movie)") + .where("m.year = $year") + .return_("m.title", "m.year") + .order_by("m.year DESC") + .limit(10) + .with_parameters(year=2020) + ) + """ + + def __init__( + self, + query_string: str | CBlock | None = None, + parameters: dict | None = None, + description: str | CBlock | None = None, + metadata: dict | None = None, + # Query clauses (for composable building) + match_clauses: list[str] | None = None, + where_clauses: list[str] | None = None, + return_clauses: list[str] | None = None, + order_by: list[str] | None = None, + limit: int | None = None, + ): + """Initialize Cypher query builder.""" + # MELLEA PATTERN: Store clauses as private fields + self._match_clauses = match_clauses or [] + self._where_clauses = where_clauses or [] + self._return_clauses = return_clauses or [] + self._order_by = order_by or [] + self._limit = limit + + # Build query string from clauses if not provided + if query_string is None and match_clauses: + query_string = self._build_query_string( + self._match_clauses, + self._where_clauses, + self._return_clauses, + self._order_by, + self._limit + ) + + # Call parent constructor + super().__init__(query_string, parameters, description, metadata) + + @staticmethod + def _build_query_string(match, where, return_, order, limit) -> str: + """Build Cypher query string from clauses.""" + parts = [] + if match: + parts.append("MATCH " + ", ".join(match)) + if where: + parts.append("WHERE " + " AND ".join(where)) + if return_: + parts.append("RETURN " + ", ".join(return_)) + if order: + parts.append("ORDER BY " + ", ".join(order)) + if limit: + parts.append(f"LIMIT {limit}") + return "\n".join(parts) + + # MELLEA PATTERN: Fluent builder methods using deepcopy for immutability + def match(self, pattern: str) -> "CypherQuery": + """Add a MATCH clause (immutable). + + Returns a new CypherQuery with the clause added. + """ + result = deepcopy(self) + result._match_clauses = [*self._match_clauses, pattern] + result._query_string = blockify(result._build_query_string( + result._match_clauses, + result._where_clauses, + result._return_clauses, + result._order_by, + result._limit, + )) + return result + + def where(self, condition: str) -> "CypherQuery": + """Add a WHERE clause (immutable).""" + result = deepcopy(self) + result._where_clauses = [*self._where_clauses, condition] + result._query_string = blockify(result._build_query_string( + result._match_clauses, + result._where_clauses, + result._return_clauses, + result._order_by, + result._limit, + )) + return result + + def return_(self, *items: str) -> "CypherQuery": + """Add a RETURN clause (immutable).""" + result = deepcopy(self) + result._return_clauses = [*self._return_clauses, *items] + result._query_string = blockify(result._build_query_string( + result._match_clauses, + result._where_clauses, + result._return_clauses, + result._order_by, + result._limit, + )) + return result + + def order_by(self, *fields: str) -> "CypherQuery": + """Add ORDER BY clause (immutable).""" + result = deepcopy(self) + result._order_by = [*self._order_by, *fields] + result._query_string = blockify(result._build_query_string( + result._match_clauses, + result._where_clauses, + result._return_clauses, + result._order_by, + result._limit, + )) + return result + + def limit(self, n: int) -> "CypherQuery": + """Add LIMIT clause (immutable).""" + result = deepcopy(self) + result._limit = n + result._query_string = blockify(result._build_query_string( + result._match_clauses, + result._where_clauses, + result._return_clauses, + result._order_by, + result._limit, + )) + return result + + # MELLEA PATTERN: format_for_llm with proper template_order + def format_for_llm(self) -> TemplateRepresentation: + """Format Cypher query for LLM.""" + return TemplateRepresentation( + obj=self, + args={ + "description": self.description or "Cypher graph query", + "query": self.query_string, + "parameters": self._parameters, + "query_type": "Cypher (Neo4j)", + }, + tools=None, + images=None, + template_order=["*", "CypherQuery", "GraphQuery"], # Inheritance chain + ) +``` + +--- + +### 3. GraphResult Component + +Represents query results formatted for LLM consumption. + +```python +class GraphResult(Component): + """Component for graph query results. + + Formats query results in LLM-friendly ways: + - "triplets": (subject, predicate, object) format + - "natural": Natural language descriptions + - "paths": Path narratives + - "structured": JSON/XML representations + + Following Mellea patterns: + - Private fields, public properties + - format_for_llm() with multiple style options + - Can be used in Instruction.grounding_context + """ + + def __init__( + self, + nodes: list[GraphNode] | None = None, + edges: list[GraphEdge] | None = None, + paths: list[list[GraphNode | GraphEdge]] | None = None, + raw_result: Any | None = None, + query: GraphQuery | None = None, + format_style: str = "triplets", + ): + """Initialize graph result. + + Args: + nodes: List of nodes in the result + edges: List of edges in the result + paths: List of paths (sequences of nodes/edges) + raw_result: Raw result from backend + query: The query that produced this result + format_style: "triplets", "natural", "paths", "structured" + """ + # MELLEA PATTERN: Private fields + self._nodes = nodes or [] + self._edges = edges or [] + self._paths = paths or [] + self._raw_result = raw_result + self._query = query + self._format_style = format_style + + # MELLEA PATTERN: Public properties + @property + def nodes(self) -> list[GraphNode]: + """Get result nodes.""" + return self._nodes + + @property + def edges(self) -> list[GraphEdge]: + """Get result edges.""" + return self._edges + + @property + def paths(self) -> list[list[GraphNode | GraphEdge]]: + """Get result paths.""" + return self._paths + + # MELLEA PATTERN: Implement parts() + def parts(self) -> list[Component | CBlock]: + """The constituent parts.""" + raise NotImplementedError("parts isn't implemented by default") + + # MELLEA PATTERN: format_for_llm returns TemplateRepresentation + def format_for_llm(self) -> TemplateRepresentation: + """Format result for LLM based on format_style.""" + formatted_content = self._format_based_on_style() + + return TemplateRepresentation( + obj=self, + args={ + "query_description": self._query.description if self._query else None, + "result_count": len(self._nodes) + len(self._edges) + len(self._paths), + "content": formatted_content, + "format_style": self._format_style, + }, + tools=None, + images=None, + template_order=["*", "GraphResult"], + ) + + def _format_based_on_style(self) -> str: + """Format results based on the selected style.""" + if self._format_style == "triplets": + return self._format_as_triplets() + elif self._format_style == "natural": + return self._format_as_natural_text() + elif self._format_style == "paths": + return self._format_as_paths() + elif self._format_style == "structured": + return self._format_as_structured() + else: + return str(self._raw_result) + + def _format_as_triplets(self) -> str: + """Format as (subject, predicate, object) triplets.""" + lines = [] + for i, edge in enumerate(self._edges): + source = edge.source.label + relation = edge.label + target = edge.target.label + lines.append(f" {i+1}. ({source})-[{relation}]->({target})") + return "\n".join(lines) if lines else "No edges found" + + def _format_as_natural_text(self) -> str: + """Format as natural language descriptions.""" + descriptions = [] + for edge in self._edges: + source = edge.source.label + relation = edge.label.replace("_", " ").lower() + target = edge.target.label + descriptions.append(f"{source} {relation} {target}") + return ". ".join(descriptions) + "." if descriptions else "No relationships found" + + def _format_as_paths(self) -> str: + """Format as path descriptions.""" + path_descriptions = [] + for i, path in enumerate(self._paths): + elements = [] + for item in path: + if isinstance(item, GraphNode): + elements.append(f"[{item.label}]") + elif isinstance(item, GraphEdge): + elements.append(f"-{item.label}->") + path_descriptions.append(f" Path {i+1}: " + "".join(elements)) + return "\n".join(path_descriptions) if path_descriptions else "No paths found" + + def _format_as_structured(self) -> str: + """Format as structured JSON.""" + import json + data = { + "nodes": [ + {"id": n.id, "label": n.label, "properties": n.properties} + for n in self._nodes + ], + "edges": [ + {"source": e.source.id, "label": e.label, "target": e.target.id} + for e in self._edges + ], + } + return json.dumps(data, indent=2) +``` + +--- + +### 4. GraphTraversal Component + +High-level traversal patterns (BFS, DFS, multi-hop, shortest path). + +```python +from typing import Callable + +class GraphTraversal(Component): + """High-level graph traversal patterns. + + Provides common traversal patterns that work across different backends: + - "bfs": Breadth-first search + - "dfs": Depth-first search + - "multi_hop": Variable-length path traversal + - "shortest_path": Shortest path between nodes + + Following Mellea patterns for Component implementation. + """ + + def __init__( + self, + start_nodes: list[str], + pattern: str = "multi_hop", + max_depth: int = 3, + edge_filter: Callable[[GraphEdge], bool] | None = None, + node_filter: Callable[[GraphNode], bool] | None = None, + description: str | CBlock | None = None, + ): + """Initialize a traversal pattern. + + Args: + start_nodes: Starting node IDs or labels + pattern: "bfs", "dfs", "multi_hop", "shortest_path" + max_depth: Maximum depth to traverse + edge_filter: Optional filter function for edges + node_filter: Optional filter function for nodes + description: Description of traversal intent + """ + # MELLEA PATTERN: Private fields + self._start_nodes = start_nodes + self._pattern = pattern + self._max_depth = max_depth + self._edge_filter = edge_filter + self._node_filter = node_filter + self._description = blockify(description) if description is not None else None + + # MELLEA PATTERN: Public properties + @property + def start_nodes(self) -> list[str]: + return self._start_nodes + + @property + def pattern(self) -> str: + return self._pattern + + @property + def max_depth(self) -> int: + return self._max_depth + + # MELLEA PATTERN: Implement parts() + def parts(self) -> list[Component | CBlock]: + raise NotImplementedError("parts isn't implemented by default") + + # MELLEA PATTERN: format_for_llm + def format_for_llm(self) -> TemplateRepresentation: + """Format traversal for LLM.""" + description = str(self._description) if self._description else "Graph traversal" + + return TemplateRepresentation( + obj=self, + args={ + "description": description, + "start_nodes": self._start_nodes, + "pattern": self._pattern, + "max_depth": self._max_depth, + }, + tools=None, + images=None, + template_order=["*", "GraphTraversal"], + ) + + def to_cypher(self) -> CypherQuery: + """Convert traversal to Cypher query. + + This allows high-level traversal patterns to be compiled + to backend-specific queries. + """ + if self._pattern == "multi_hop": + # Variable-length path pattern + match_pattern = f"(start)-[*1..{self._max_depth}]->(end)" + description = f"Multi-hop traversal from {self._start_nodes}" + + return ( + CypherQuery() + .match(match_pattern) + .where(f"start.id IN {self._start_nodes}") + .return_("start", "end") + .with_description(description) + ) + elif self._pattern == "shortest_path": + match_pattern = f"path = shortestPath((start)-[*1..{self._max_depth}]->(end))" + description = f"Shortest path from {self._start_nodes}" + + return ( + CypherQuery() + .match(match_pattern) + .where(f"start.id IN {self._start_nodes}") + .return_("path") + .with_description(description) + ) + else: + raise ValueError(f"Unknown pattern: {self._pattern}") +``` + +--- + +## Core Design: Backends + +### GraphBackend (Abstract Base) + +Similar to Mellea's `Backend` abstraction for LLMs, but for graph databases. + +```python +from abc import ABC, abstractmethod +from typing import Any + +class GraphBackend(ABC): + """Abstract backend for graph databases. + + Provides a unified interface for executing graph queries across + different graph database systems (Neo4j, Neptune, RDF stores, etc.). + + Following Mellea's Backend pattern: + - Takes backend_id (like model_id) + - Takes backend_options (like model_options) + - Abstract methods for core operations + """ + + def __init__( + self, + backend_id: str, + *, + connection_uri: str | None = None, + auth: tuple[str, str] | None = None, + database: str | None = None, + backend_options: dict | None = None, + ): + """Initialize graph backend. + + Following Mellea's Backend(model_id, model_options) pattern. + + Args: + backend_id: Identifier for backend type (e.g., "neo4j", "neptune") + connection_uri: URI for connecting to the database + auth: (username, password) tuple for authentication + database: Database name (if multi-database system) + backend_options: Backend-specific options + """ + # MELLEA PATTERN: Similar to Backend.__init__ + self.backend_id = backend_id + self.backend_options = backend_options if backend_options is not None else {} + + # Graph-specific fields + self.connection_uri = connection_uri + self.auth = auth + self.database = database + + @abstractmethod + async def execute_query( + self, + query: GraphQuery, + **execution_options, + ) -> GraphResult: + """Execute a graph query and return results. + + Similar to Backend.generate_from_context() for LLMs. + Takes a Component (GraphQuery), returns a Component (GraphResult). + + Args: + query: The GraphQuery Component to execute + execution_options: Backend-specific execution options + + Returns: + GraphResult Component containing formatted results + """ + ... + + @abstractmethod + async def get_schema(self) -> dict[str, Any]: + """Get the graph schema. + + Returns: + Dictionary with node_types, edge_types, properties, etc. + """ + ... + + @abstractmethod + async def validate_query(self, query: GraphQuery) -> tuple[bool, str | None]: + """Validate query syntax and semantics. + + Returns: + (is_valid, error_message) + """ + ... + + def supports_query_type(self, query_type: str) -> bool: + """Check if this backend supports a query type (Cypher, SPARQL, etc.). + + Default implementation returns False. Subclasses should override. + """ + return False + + async def execute_traversal( + self, + traversal: GraphTraversal, + **execution_options, + ) -> GraphResult: + """Execute a high-level traversal pattern. + + Default implementation converts to backend-specific query. + """ + if self.supports_query_type("cypher"): + query = traversal.to_cypher() + return await self.execute_query(query, **execution_options) + else: + raise NotImplementedError( + f"Traversal not implemented for {self.__class__.__name__}" + ) +``` + +--- + +### Neo4jBackend (Concrete Implementation) + +```python +import neo4j +from typing import Any + +class Neo4jBackend(GraphBackend): + """Neo4j implementation of GraphBackend. + + Implements the abstract GraphBackend interface for Neo4j databases. + """ + + def __init__( + self, + connection_uri: str = "bolt://localhost:7687", + auth: tuple[str, str] | None = None, + database: str | None = None, + backend_options: dict | None = None, + ): + """Initialize Neo4j backend. + + Args: + connection_uri: Neo4j connection URI + auth: (username, password) tuple + database: Database name (for multi-database) + backend_options: Neo4j-specific options + """ + # Call parent constructor following Mellea pattern + super().__init__( + backend_id="neo4j", + connection_uri=connection_uri, + auth=auth, + database=database, + backend_options=backend_options, + ) + + # Create Neo4j drivers + self._driver = neo4j.GraphDatabase.driver( + connection_uri, + auth=auth, + ) + self._async_driver = neo4j.AsyncGraphDatabase.driver( + connection_uri, + auth=auth, + ) + + async def execute_query( + self, + query: GraphQuery, + **execution_options, + ) -> GraphResult: + """Execute a query in Neo4j. + + Takes a GraphQuery Component, executes it, returns GraphResult Component. + """ + # Get query string and parameters + query_string = query.query_string + parameters = query.parameters + + if not query_string: + raise ValueError("Query string is empty") + + # Execute query + async with self._async_driver.session(database=self.database) as session: + result = await session.run(query_string, parameters) + records = [record async for record in result] + + # Parse results into nodes, edges, paths + nodes, edges, paths = self._parse_neo4j_result(records) + + # Return GraphResult Component + return GraphResult( + nodes=nodes, + edges=edges, + paths=paths, + raw_result=records, + query=query, + format_style=execution_options.get("format_style", "triplets"), + ) + + def _parse_neo4j_result( + self, + records + ) -> tuple[list[GraphNode], list[GraphEdge], list]: + """Parse Neo4j records into GraphNode and GraphEdge objects.""" + nodes = [] + edges = [] + paths = [] + + node_cache = {} # Cache nodes by ID for edge creation + + for record in records: + for key in record.keys(): + value = record[key] + + if isinstance(value, neo4j.graph.Node): + node = GraphNode.from_neo4j_node(value) + node_cache[node.id] = node + nodes.append(node) + + elif isinstance(value, neo4j.graph.Relationship): + # Get source and target nodes + source_id = str(value.start_node.id) + target_id = str(value.end_node.id) + + # Get from cache or create + if source_id not in node_cache: + node_cache[source_id] = GraphNode.from_neo4j_node(value.start_node) + if target_id not in node_cache: + node_cache[target_id] = GraphNode.from_neo4j_node(value.end_node) + + source = node_cache[source_id] + target = node_cache[target_id] + + edge = GraphEdge.from_neo4j_relationship(value, source, target) + edges.append(edge) + + elif isinstance(value, neo4j.graph.Path): + # Parse path into alternating nodes and edges + path_items = [] + for node in value.nodes: + path_items.append(GraphNode.from_neo4j_node(node)) + for rel in value.relationships: + src = GraphNode.from_neo4j_node(rel.start_node) + tgt = GraphNode.from_neo4j_node(rel.end_node) + path_items.append(GraphEdge.from_neo4j_relationship(rel, src, tgt)) + paths.append(path_items) + + return nodes, edges, paths + + async def get_schema(self) -> dict[str, Any]: + """Get Neo4j schema. + + Queries for node labels, relationship types, and property keys. + """ + # Get node labels + labels_query = "CALL db.labels() YIELD label RETURN collect(label) as labels" + labels_result = await self.execute_query( + CypherQuery(query_string=labels_query) + ) + + # Get relationship types + types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) as types" + types_result = await self.execute_query( + CypherQuery(query_string=types_query) + ) + + # This is simplified - real implementation would extract from results + return { + "node_types": [], # Would parse from labels_result + "edge_types": [], # Would parse from types_result + "properties": {}, # Would need additional queries + } + + async def validate_query(self, query: GraphQuery) -> tuple[bool, str | None]: + """Validate Cypher query syntax. + + Uses Neo4j's EXPLAIN to validate without executing. + """ + try: + explain_query = f"EXPLAIN {query.query_string}" + async with self._async_driver.session(database=self.database) as session: + await session.run(explain_query, query.parameters) + return True, None + except neo4j.exceptions.CypherSyntaxError as e: + return False, str(e) + except Exception as e: + return False, f"Validation error: {str(e)}" + + def supports_query_type(self, query_type: str) -> bool: + """Neo4j supports Cypher queries.""" + return query_type.lower() == "cypher" + + async def close(self): + """Close Neo4j connections.""" + await self._async_driver.close() +``` + +--- + +## LLM-Guided Query Construction + +### @generative Functions for Query Generation + +Following Mellea's @generative pattern for NL → Query generation. + +```python +from mellea.stdlib.genslot import generative +from pydantic import BaseModel +from typing import Any + +class GeneratedQuery(BaseModel): + """Pydantic model for generated query output.""" + query: str + explanation: str + parameters: dict[str, Any] | None = None + + +@generative +async def natural_language_to_cypher( + natural_language_query: str, + graph_schema: str, + examples: str, +) -> GeneratedQuery: + """Generate a Cypher query from natural language. + + Given a natural language question and the graph schema, generate a + valid Cypher query that answers the question. + + Graph Schema: + {graph_schema} + + Examples: + {examples} + + Question: {natural_language_query} + + Generate a Cypher query to answer this question. Return as JSON: + {{"query": "MATCH ...", "explanation": "This query...", "parameters": {{}}}} + + Query:""" + pass + + +@generative +async def explain_query_result( + query: str, + result: str, + original_question: str, +) -> str: + """Explain a graph query result in natural language. + + Original Question: {original_question} + + Query Executed: + {query} + + Results: + {result} + + Explain what these results mean in relation to the original question. + Write a clear, natural language answer. + + Answer:""" + pass + + +@generative +async def suggest_query_improvement( + query: str, + error_message: str, + schema: str, +) -> GeneratedQuery: + """Suggest an improved query based on an error. + + The following query failed: + {query} + + Error: {error_message} + + Graph Schema: + {schema} + + Suggest a corrected version of the query. Return as JSON: + {{"query": "...", "explanation": "The issue was...", "parameters": {{}}}} + + Corrected Query:""" + pass +``` + +--- + +## Sampling Strategies + +### QueryValidationStrategy + +Uses Instruct/Validate/Repair pattern for query generation. + +```python +from mellea.stdlib.sampling import BaseSamplingStrategy +from mellea.stdlib.base import Context, Component, ModelOutputThunk, CBlock +from mellea.stdlib.requirement import Requirement, ValidationResult + +class QueryValidationStrategy(BaseSamplingStrategy): + """Sampling strategy for generating and validating graph queries. + + Uses Instruct/Validate/Repair pattern: + 1. Generate query from NL + 2. Validate syntax and executability + 3. If invalid, repair using error feedback + + Following Mellea's BaseSamplingStrategy pattern. + """ + + def __init__( + self, + backend: GraphBackend, + loop_budget: int = 3, + requirements: list[Requirement] | None = None, + ): + """Initialize strategy. + + Args: + backend: Graph backend for validation + loop_budget: Max repair attempts + requirements: Query validation requirements + """ + super().__init__(loop_budget=loop_budget, requirements=requirements) + self._backend = backend + + @staticmethod + def repair( + old_ctx: Context, + new_ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> tuple[Component, Context]: + """Repair failed query using error feedback. + + Following Mellea's repair pattern. + """ + # Get the last validation failure + last_validation = past_val[-1] + + # Extract error messages + error_messages = [] + for req, result in last_validation: + if not result.result and result.reason: + error_messages.append(result.reason) + + # Get the failed query + failed_query = past_results[-1].value + + # Create repair instruction using CBlock + repair_instruction = CBlock( + f"The previous query failed validation:\n" + f"Query: {failed_query}\n" + f"Errors: {', '.join(error_messages)}\n" + f"Please generate a corrected query." + ) + + return repair_instruction, new_ctx + + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + """Select best query when all attempts failed. + + Returns the query with the fewest validation errors. + """ + error_counts = [] + for validation in sampled_val: + error_count = sum(1 for _, result in validation if not result.result) + error_counts.append(error_count) + + return error_counts.index(min(error_counts)) +``` + +--- + +## Requirements + +### Graph-Specific Requirements + +```python +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.base import Context + +def is_valid_cypher(backend: GraphBackend) -> Requirement: + """Requirement: Query must be valid Cypher syntax.""" + + async def validate(ctx: Context) -> ValidationResult: + query_string = ctx.last_assistant_message.as_str() + query = CypherQuery(query_string=query_string) + + is_valid, error = await backend.validate_query(query) + + return ValidationResult( + result=is_valid, + reason=error if not is_valid else "Valid Cypher syntax", + ) + + return Requirement( + description="Query must be valid Cypher syntax", + validation_fn=validate, + ) + + +def returns_results(backend: GraphBackend) -> Requirement: + """Requirement: Query must return non-empty results.""" + + async def validate(ctx: Context) -> ValidationResult: + query_string = ctx.last_assistant_message.as_str() + query = CypherQuery(query_string=query_string) + + result = await backend.execute_query(query) + has_results = len(result.nodes) > 0 or len(result.edges) > 0 + + return ValidationResult( + result=has_results, + reason="Query returned results" if has_results else "Query returned no results", + ) + + return Requirement( + description="Query must return non-empty results", + validation_fn=validate, + ) + + +def respects_schema(backend: GraphBackend) -> Requirement: + """Requirement: Query must only reference valid node/edge types from schema.""" + + async def validate(ctx: Context) -> ValidationResult: + query_string = ctx.last_assistant_message.as_str() + schema = await backend.get_schema() + + # Would need actual Cypher parsing logic + # For now, simplified validation + + return ValidationResult( + result=True, + reason="Query respects schema", + ) + + return Requirement( + description="Query must only reference valid schema types", + validation_fn=validate, + ) +``` + +--- + +## End-to-End Flow Example + +This example shows how the different layers work together when processing a natural language query about a knowledge graph. + +### Flow Description + +**User Query**: "Who acted in The Matrix?" + +**Layer 1 - Application Layer**: +- **Input**: Natural language question +- **Action**: Initiates graph query workflow +- **Component**: Application code (could be KGRag or custom implementation) + +**Layer 2 - Query Construction**: +- **Input**: Query intent + graph schema +- **Action**: Build structured Cypher query +- **Output**: CypherQuery with query string, parameters, and description + ``` + query_string: "MATCH (m:Movie {title: $title})<-[:ACTED_IN]-(p:Person) RETURN p" + parameters: {"title": "The Matrix"} + ``` + +**Layer 3 - LLM-Guided Validation**: +- **Input**: CypherQuery +- **Action**: Validate query against schema and syntax rules +- **Checks**: + - Valid Cypher syntax? + - Uses valid node/edge types from schema? + - Likely to return results? +- **Output**: Validated query (or repaired query if validation failed) + +**Layer 4 - Backend Execution**: +- **Input**: Validated CypherQuery +- **Action**: Execute query against Neo4j database +- **Process**: + - Submit Cypher to Neo4j + - Receive raw Neo4j records + - Parse records into GraphNode and GraphEdge objects +- **Output**: Structured graph data (nodes and edges) + +**Layer 2 - Result Formatting**: +- **Input**: Raw graph nodes and edges +- **Action**: Create GraphResult and format for LLM consumption +- **Output**: LLM-readable representation + ``` + Format styles available: + - "triplets": (Person)-[ACTED_IN]->(Movie) + - "natural": "Keanu Reeves acted in The Matrix" + - "paths": Path descriptions through graph + - "structured": JSON representation + ``` + +**Layer 1 - Answer Generation**: +- **Input**: Original question + formatted graph data +- **Action**: LLM generates natural language answer using graph context +- **Output**: "Keanu Reeves and Carrie-Anne Moss acted in The Matrix." + +### Flow Diagram + + + + +### Key Design Points + +1. **Not Everything is a Component**: + - `GraphNode` and `GraphEdge` are pure dataclasses (just data) + - `CypherQuery` and `GraphResult` are Components (have `format_for_llm()`) + - Application layer code is not a Component + +2. **Layer 2 Components are Created BY Other Layers**: + - Layer 2 defines the Component classes (`CypherQuery`, `GraphResult`) + - Layer 3 **creates** CypherQuery instances via `@generative` functions + - Layer 4 **creates** GraphResult instances when returning query results + - Layer 2 is not a processing step - it's a library of Component definitions + +3. **Actual Processing Flow** (for "who act in The Matrix"): + - **Layer 1**: Receives NL query → passes to Layer 3 + - **Layer 3**: LLM converts NL → CypherQuery Component → validates → repairs if needed + - **Layer 4**: Executes CypherQuery → parses results → creates GraphResult Component + - **Layer 1**: Uses GraphResult.format_for_llm() → LLM generates final answer + +4. **Clear Layer Responsibilities**: + - **Layer 1** (Application): Orchestration and answer generation + - **Layer 2** (Components): Defines data structures with `format_for_llm()` + - **Layer 3** (LLM-Guided): Query generation, validation, and repair + - **Layer 4** (Backend): Database execution and result parsing + +5. **Data Transformations**: + - Natural language → CypherQuery Component → Validated query → Neo4j records → GraphNode/GraphEdge → GraphResult Component → Formatted text → Natural language answer + +--- + +## Usage Examples + +### Example 1: Simple Query Building + +```python +from mellea_contribs.kg.components import CypherQuery +from mellea_contribs.kg.graph_dbs import Neo4jBackend + +# Create backend +backend = Neo4jBackend( + connection_uri="bolt://localhost:7687", + auth=("neo4j", "password"), +) + +# Build query using fluent interface (immutable - each call returns new instance) +query = ( + CypherQuery() + .match("(m:Movie)") + .where("m.year = $year") + .return_("m.title", "m.year") + .order_by("m.year DESC") + .limit(10) + .with_parameters(year=2020) + .with_description("Find 10 movies from 2020") +) + +# Execute query +result = await backend.execute_query(query, format_style="natural") + +# Result is a Component - can be formatted for LLM +print(result.format_for_llm()) +``` + +### Example 2: LLM-Guided Query Construction + +```python +from mellea_contribs.kg.components import CypherQuery +from mellea_contribs.kg.components.llm_guided import natural_language_to_cypher +from mellea_contribs.kg.sampling import QueryValidationStrategy +from mellea_contribs.kg.requirements import is_valid_cypher, returns_results +from mellea_contribs.kg.graph_dbs import Neo4jBackend +from mellea import MelleaSession + +# Create Mellea session for LLM +session = MelleaSession(...) + +# Create graph backend +graph_backend = Neo4jBackend(...) + +# Get graph schema +schema = await graph_backend.get_schema() + +# Use @generative to convert NL → Cypher +query_result, _ = await natural_language_to_cypher( + session.ctx, + session.backend, + natural_language_query="Find all movies directed by Christopher Nolan", + graph_schema=format_schema(schema), + examples=get_few_shot_examples(), +) + +# Create query from generated output +query = CypherQuery(query_string=query_result.query) + +# Execute with validation strategy +strategy = QueryValidationStrategy( + backend=graph_backend, + loop_budget=3, +) + +# This will validate and auto-repair if needed +sampling_result = await strategy.sample( + action=query, + context=session.ctx, + backend=session.backend, + requirements=[ + is_valid_cypher(graph_backend), + returns_results(graph_backend), + ], +) + +# Get the final valid query +final_query = sampling_result.value +result = await graph_backend.execute_query(final_query) +``` + +### Example 3: Graph Results in LLM Reasoning + +```python +from mellea_contribs.kg.components import CypherQuery +from mellea.stdlib.instruction import Instruction + +# Query graph +query = ( + CypherQuery() + .match("(a:Actor)-[:ACTED_IN]->(m:Movie)") + .where("m.genre = 'Sci-Fi'") + .return_("a.name", "m.title") + .limit(20) + .with_description("Find actors in sci-fi movies") +) + +result = await backend.execute_query(query, format_style="triplets") + +# Use result in downstream LLM reasoning +instruction = Instruction( + description="Answer the question using the graph context", + grounding_context={ + "graph_knowledge": result, # Result is a Component! + }, +) + +# The formatter will call result.format_for_llm() automatically +answer, _ = await session.backend.generate_from_context( + action=instruction, + ctx=session.ctx, +) +``` + +### Example 4: Multi-Hop Traversal + +```python +from mellea_contribs.kg.components import GraphTraversal + +# Define traversal +traversal = GraphTraversal( + start_nodes=["Christopher Nolan"], + pattern="multi_hop", + max_depth=3, + description="Find entities connected to Christopher Nolan within 3 hops", +) + +# Execute traversal (converts to Cypher internally) +result = await backend.execute_traversal(traversal, format_style="paths") + +# Result shows paths through the graph +print(result.format_for_llm()) +``` + +--- + +## Benefits of This Design + +1. **Philosophical Alignment**: Everything is a Component with `format_for_llm()` +2. **Follows Mellea Patterns**: Private fields, properties, deepcopy, TemplateRepresentation +3. **Composability**: Fluent query building with immutable updates +4. **Backend Abstraction**: Swap Neo4j ↔ Neptune ↔ RDF seamlessly +5. **LLM Integration**: Natural language → Query with validation/repair +6. **Result Formatting**: Multiple styles for different LLM reasoning tasks +7. **Reusability**: Works standalone or as KGRag foundation + +--- + +## Implementation Plan + +### Phase 1: Core Data Structures +1. Implement `base.py` with GraphNode, GraphEdge dataclasses +2. Write tests for data structure creation and serialization + +### Phase 2: Components +1. Implement `components.py`: + - GraphQuery (base component) + - CypherQuery (with fluent builder) + - GraphResult (with format styles) + - GraphTraversal (high-level patterns) +2. Write tests for Component protocol compliance +3. Verify format_for_llm() output + +### Phase 3: Backends +1. Implement `backends/base.py` with GraphBackend ABC +2. Implement `backends/neo4j.py` with Neo4jBackend +3. Implement `backends/mock.py` for testing +4. Write integration tests with actual Neo4j instance + +### Phase 4: LLM-Guided Features +1. Implement `llm_guided.py` with @generative functions +2. Implement `sampling.py` with QueryValidationStrategy +3. Implement `requirements.py` with validation functions +4. Write end-to-end tests for NL → Query generation + +### Phase 5: Documentation & Examples +1. Write comprehensive README +2. Create usage examples +3. Document integration with KGRag +4. Write API reference + +--- + +## Open Questions + +1. **Query Optimization**: Should we include query optimization hints? +2. **Caching**: Should query results be cached at the Component level? +3. **Streaming**: Support for streaming large result sets? +4. **Graph Algorithms**: Include common algorithms (PageRank, community detection)? +5. **Vector Search**: Integration with vector indices for semantic search? +6. **Multi-Backend Queries**: Support queries across multiple graph backends? + +--- + +This design creates a **Graph Query library that truly embodies Mellea's philosophy** while being practical and extensible. It treats graph queries and results as first-class Components, follows Mellea's established patterns, and integrates naturally with LLM-based reasoning. diff --git a/docs/examples/kgrag/models/__init__.py b/docs/examples/kgrag/models/__init__.py new file mode 100644 index 0000000..cd6fad8 --- /dev/null +++ b/docs/examples/kgrag/models/__init__.py @@ -0,0 +1,32 @@ +"""Domain-specific entity models for KG-RAG examples. + +This package contains example implementations of domain-specific entity models. +Each module demonstrates how to extend the base Entity class for different domains. + +Example: + To use the movie domain models:: + + from docs.examples.kgrag.models import MovieEntity, PersonEntity, AwardEntity + + movie = MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 biographical film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + release_year=2023, + director="Christopher Nolan" + ) +""" + +from docs.examples.kgrag.models.movie_domain_models import ( + AwardEntity, + MovieEntity, + PersonEntity, +) + +__all__ = [ + "MovieEntity", + "PersonEntity", + "AwardEntity", +] diff --git a/docs/examples/kgrag/models/movie_domain_models.py b/docs/examples/kgrag/models/movie_domain_models.py new file mode 100644 index 0000000..3c5c19a --- /dev/null +++ b/docs/examples/kgrag/models/movie_domain_models.py @@ -0,0 +1,64 @@ +"""Domain-specific entity models for the movie domain example. + +This module demonstrates how to extend the base Entity class for a specific domain. +Users can follow this pattern to create domain-specific entities for their own domains. + +Note: + This example assumes mellea-contribs is installed. Install with: + pip install mellea-contribs[kg] +""" + +from typing import Optional + +from pydantic import Field + +from mellea_contribs.kg.models import Entity + + +class MovieEntity(Entity): + """Movie-specific entity. + + Extends the base Entity class with movie-domain properties. + All extraction and storage fields are inherited from Entity. + """ + + release_year: Optional[int] = Field(default=None, description="Year the movie was released") + director: Optional[str] = Field(default=None, description="Director(s) of the movie") + box_office: Optional[float] = Field( + default=None, description="Box office earnings in millions USD" + ) + language: Optional[str] = Field(default=None, description="Primary language of the movie") + rating: Optional[float] = Field(default=None, description="IMDb/review rating 0-10") + + +class PersonEntity(Entity): + """Person-specific entity (actor, director, producer, etc.). + + Extends the base Entity class with person-domain properties. + """ + + birth_year: Optional[int] = Field(default=None, description="Birth year") + nationality: Optional[str] = Field(default=None, description="Nationality/country") + profession: Optional[str] = Field( + default=None, description="Primary profession (actor, director, producer, etc.)" + ) + + +class AwardEntity(Entity): + """Award-specific entity (Academy Award, Golden Globe, etc.). + + Extends the base Entity class with award-domain properties. + """ + + ceremony_number: Optional[int] = Field( + default=None, description="Ceremony/edition number (e.g., 96th Academy Awards)" + ) + award_type: Optional[str] = Field(default=None, description="Type of award (e.g., Best Picture)") + award_year: Optional[int] = Field(default=None, description="Year of the award ceremony") + + +__all__ = [ + "MovieEntity", + "PersonEntity", + "AwardEntity", +] diff --git a/docs/examples/kgrag/preprocessor/__init__.py b/docs/examples/kgrag/preprocessor/__init__.py new file mode 100644 index 0000000..6caee8f --- /dev/null +++ b/docs/examples/kgrag/preprocessor/__init__.py @@ -0,0 +1,31 @@ +"""Domain-specific KG preprocessor examples. + +This package contains example implementations of domain-specific KG preprocessors. +Each module demonstrates how to extend the generic KGPreprocessor for different domains. + +Available Examples: + MovieKGPreprocessor: Example preprocessor for the movie domain + +Example Usage:: + + from movie_preprocessor import MovieKGPreprocessor + from mellea import start_session + from mellea_contribs.kg import MockGraphBackend + + async def process_movies(): + session = start_session(backend_name="litellm", model_id="gpt-4o-mini") + backend = MockGraphBackend() + processor = MovieKGPreprocessor(backend=backend, session=session) + + result = await processor.process_document( + doc_text="Avatar directed by James Cameron was released in 2009.", + doc_id="avatar_1" + ) + print(f"Extracted {len(result.entities)} entities and {len(result.relations)} relations") +""" + +from .movie_preprocessor import MovieKGPreprocessor + +__all__ = [ + "MovieKGPreprocessor", +] diff --git a/docs/examples/kgrag/preprocessor/movie_preprocessor.py b/docs/examples/kgrag/preprocessor/movie_preprocessor.py new file mode 100644 index 0000000..a85c862 --- /dev/null +++ b/docs/examples/kgrag/preprocessor/movie_preprocessor.py @@ -0,0 +1,574 @@ +"""Domain-specific KG Preprocessor for the movie domain. + +This module demonstrates how to extend the generic KGPreprocessor for a specific domain +by providing domain-specific hints and post-processing logic. + +Example:: + + import asyncio + from mellea import start_session + from mellea_contribs.kg import MockGraphBackend + from movie_preprocessor import MovieKGPreprocessor + + async def main(): + session = start_session(backend_name="litellm", model_id="gpt-4o-mini") + backend = MockGraphBackend() + processor = MovieKGPreprocessor(backend=backend, session=session) + + # Process a movie document + doc_text = '''Avatar is a 2009 science fiction film directed by James Cameron. + It stars Sam Worthington, Zoe Saldana, and Sigourney Weaver. + The film was nominated for multiple Academy Awards. + ''' + + result = await processor.process_document( + doc_text=doc_text, + doc_id="avatar_wiki" + ) + print(f"Extracted {len(result.entities)} entities and {len(result.relations)} relations") + await backend.close() + + asyncio.run(main()) +""" + +import json +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from mellea_contribs.kg.components.query import CypherQuery +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.models import Entity, ExtractionResult +from mellea_contribs.kg.preprocessor import KGPreprocessor +from mellea_contribs.kg.utils import log_progress + + +class MovieKGPreprocessor(KGPreprocessor): + """Domain-specific KG preprocessor for the movie domain. + + Extends the generic KGPreprocessor with movie-specific extraction hints and + post-processing logic. Demonstrates how to customize preprocessing for a domain. + """ + + def __init__(self, *args, **kwargs): + """Initialize the movie preprocessor.""" + # Set domain to "movies" if not specified + if "domain" not in kwargs: + kwargs["domain"] = "movies" + super().__init__(*args, **kwargs) + + def get_hints(self) -> str: + """Get movie-domain-specific hints for LLM extraction. + + Provides guidance on what entity and relation types to look for in movie texts. + + Returns: + String with movie domain hints + """ + return """ +Movie Domain Extraction Guide: + +ENTITY TYPES to extract: +- Movie: Film titles, release years, budgets, box office +- Person: Actors, directors, producers, writers, cinematographers +- Award: Academy Awards, Golden Globes, BAFTA, Cannes Film Festival awards +- Studio: Production studios, distributors +- Genre: Film genres (Action, Drama, Comedy, etc.) +- Character: Character names and roles + +RELATION TYPES to extract: +- directed_by: Movie → Director +- acted_in: Actor → Movie +- produced_by: Movie → Producer +- written_by: Movie → Writer +- distributed_by: Movie → Studio +- nominated_for: Movie → Award +- won_award: Movie → Award +- starred_as: Actor → Character (in specific movie) +- belongs_to_genre: Movie → Genre +- prequel_of: Movie → Movie +- sequel_of: Movie → Movie + +EXTRACTION PRIORITIES: +1. Movie title and release year (most important) +2. Director and main cast +3. Awards and nominations +4. Production company +5. Box office and budget (if mentioned) +6. Plot and characters (optional) + +FORMATTING NOTES: +- Use standard English names for entities +- Include full movie titles (e.g., "Avatar: The Way of Water") +- For actors, use their professional names +- Include award year and category if available +""" + + async def post_process_extraction( + self, result: ExtractionResult, doc_text: str + ) -> ExtractionResult: + """Post-process extraction results for the movie domain. + + Applies movie-specific cleaning and enrichment to extracted entities and relations. + + Args: + result: The raw extraction result from LLM + doc_text: The original document text + + Returns: + Enriched extraction result with movie-specific post-processing + """ + # Clean up entity names and types + for entity in result.entities: + # Standardize entity types + entity.type = self._standardize_entity_type(entity.type) + + # Clean up names (trim whitespace, fix common issues) + entity.name = entity.name.strip() + + # Add movie-specific properties if possible + if entity.type == "Movie": + entity = self._enrich_movie_entity(entity, doc_text) + elif entity.type == "Person": + entity = self._enrich_person_entity(entity, doc_text) + + # Clean up relation types + for relation in result.relations: + relation.relation_type = self._standardize_relation_type(relation.relation_type) + + return result + + def _standardize_entity_type(self, entity_type: str) -> str: + """Standardize entity type names to movie domain vocabulary. + + Args: + entity_type: Raw entity type from LLM + + Returns: + Standardized entity type + """ + type_map = { + "film": "Movie", + "movie": "Movie", + "cinema": "Movie", + "actor": "Person", + "actress": "Person", + "director": "Person", + "producer": "Person", + "writer": "Person", + "cinematographer": "Person", + "composer": "Person", + "performer": "Person", + "studio": "Studio", + "production_studio": "Studio", + "distributor": "Studio", + "award": "Award", + "oscar": "Award", + "golden_globe": "Award", + "award_nomination": "Award", + "genre": "Genre", + "character": "Character", + "role": "Character", + } + + # Case-insensitive lookup + normalized = entity_type.lower().replace(" ", "_") + return type_map.get(normalized, entity_type) + + def _standardize_relation_type(self, relation_type: str) -> str: + """Standardize relation type names to movie domain vocabulary. + + Args: + relation_type: Raw relation type from LLM + + Returns: + Standardized relation type + """ + type_map = { + "directed": "directed_by", + "direct": "directed_by", + "acted": "acted_in", + "acted_in": "acted_in", + "starred_in": "acted_in", + "starring": "acted_in", + "produced": "produced_by", + "written": "written_by", + "distributed": "distributed_by", + "nominated_for": "nominated_for", + "nominated": "nominated_for", + "won": "won_award", + "won_award": "won_award", + "prequel": "prequel_of", + "sequel": "sequel_of", + "spinoff": "spinoff_of", + "based_on": "based_on", + "remake_of": "remake_of", + } + + # Case-insensitive lookup + normalized = relation_type.lower().replace(" ", "_") + return type_map.get(normalized, relation_type) + + def _enrich_movie_entity(self, entity: Entity, doc_text: str) -> Entity: + """Enrich a movie entity with additional extracted information. + + Args: + entity: The movie entity to enrich + doc_text: The source document text + + Returns: + Enriched entity with additional properties + """ + # This would be implemented with more sophisticated extraction logic + # For now, just return as-is + return entity + + def _enrich_person_entity(self, entity: Entity, doc_text: str) -> Entity: + """Enrich a person entity with additional extracted information. + + Args: + entity: The person entity to enrich + doc_text: The source document text + + Returns: + Enriched entity with additional properties + """ + # This would be implemented with more sophisticated extraction logic + # For now, just return as-is + return entity + + +@dataclass +class PreprocessingStats: + """Statistics for a predefined-data preprocessing run.""" + + domain: str + start_time: datetime + end_time: datetime + duration_seconds: float + entities_loaded: int + entities_inserted: int + relations_loaded: int + relations_inserted: int + success: bool + error_message: str = "" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON output.""" + return { + "domain": self.domain, + "start_time": self.start_time.isoformat(), + "end_time": self.end_time.isoformat(), + "duration_seconds": self.duration_seconds, + "entities_loaded": self.entities_loaded, + "entities_inserted": self.entities_inserted, + "relations_loaded": self.relations_loaded, + "relations_inserted": self.relations_inserted, + "success": self.success, + "error_message": self.error_message, + } + + def __str__(self) -> str: + """Format statistics for display.""" + status = "✓ SUCCESS" if self.success else "✗ FAILED" + lines = [ + f"Domain: {self.domain}", + f"Status: {status}", + f"Duration: {self.duration_seconds:.2f}s", + f"Entities loaded: {self.entities_loaded}", + f"Entities inserted: {self.entities_inserted}", + f"Relations loaded: {self.relations_loaded}", + f"Relations inserted: {self.relations_inserted}", + ] + if self.error_message: + lines.append(f"Error: {self.error_message}") + return "\n".join(lines) + + +class PredefinedDataPreprocessor: + """Loads predefined movie/person JSON databases into the knowledge graph. + + Reads ``movie_db.json`` and ``person_db.json`` from *data_dir* and inserts + entities and relations via the graph backend's Cypher execution API. + + Args: + backend: Graph database backend. + data_dir: Directory containing ``movie_db.json`` and ``person_db.json``. + batch_size: Number of records per Cypher batch (default: 50). + """ + + def __init__( + self, + backend: GraphBackend, + data_dir: Path, + batch_size: int = 50, + ): + """Initialize preprocessor.""" + self.backend = backend + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.movie_db: Dict[str, Dict] = {} + self.person_db: Dict[str, Dict] = {} + + async def preprocess(self) -> PreprocessingStats: + """Run the full preprocessing pipeline. + + Returns: + PreprocessingStats with counts and timing. + """ + start_time = datetime.now() + try: + log_progress("Loading movie database...") + self.movie_db = self._load_json_file("movie_db.json") + log_progress(f"✓ Loaded {len(self.movie_db)} movies") + + log_progress("Loading person database...") + self.person_db = self._load_json_file("person_db.json") + log_progress(f"✓ Loaded {len(self.person_db)} persons") + + log_progress("\nInserting movie entities...") + movies_inserted = await self._insert_movies() + log_progress("Inserting person entities...") + persons_inserted = await self._insert_persons() + log_progress("\nInserting movie-person relations...") + relations_inserted = await self._insert_movie_relations() + + end_time = datetime.now() + return PreprocessingStats( + domain="movie", + start_time=start_time, + end_time=end_time, + duration_seconds=(end_time - start_time).total_seconds(), + entities_loaded=len(self.movie_db) + len(self.person_db), + entities_inserted=movies_inserted + persons_inserted, + relations_loaded=0, + relations_inserted=relations_inserted, + success=True, + ) + except Exception as e: + end_time = datetime.now() + log_progress(f"✗ Preprocessing failed: {e}") + return PreprocessingStats( + domain="movie", + start_time=start_time, + end_time=end_time, + duration_seconds=(end_time - start_time).total_seconds(), + entities_loaded=0, + entities_inserted=0, + relations_loaded=0, + relations_inserted=0, + success=False, + error_message=str(e), + ) + finally: + await self.backend.close() + + def _load_json_file(self, filename: str) -> Dict[str, Any]: + """Load a JSON file from the data directory.""" + file_path = self.data_dir / filename + if not file_path.exists(): + raise FileNotFoundError(f"Data file not found: {file_path}") + with open(file_path, "r") as f: + return json.load(f) + + async def _execute_cypher_batch(self, cypher_query: str, batch: List[Dict]) -> None: + """Execute a parameterised Cypher query with a batch of items.""" + if not batch: + return + try: + query = CypherQuery(cypher_query, parameters={"batch": batch}) + await self.backend.execute_query(query) + except Exception as e: + log_progress(f" Warning: Batch insert failed: {e}") + + async def _insert_movies(self) -> int: + """Insert movie entities. Returns count inserted.""" + count = 0 + batch: List[Dict] = [] + for movie_id, movie_data in self.movie_db.items(): + batch.append({ + "name": movie_data.get("title", f"Movie_{movie_id}").upper(), + "release_date": movie_data.get("release_date"), + "original_language": movie_data.get("original_language"), + "budget": str(movie_data.get("budget")) if movie_data.get("budget") else None, + "revenue": str(movie_data.get("revenue")) if movie_data.get("revenue") else None, + "rating": str(movie_data.get("rating")) if movie_data.get("rating") else None, + }) + count += 1 + if len(batch) >= self.batch_size: + await self._execute_cypher_batch( + """ + UNWIND $batch AS movie + MERGE (m:Movie {name: movie.name}) + SET m.release_date = movie.release_date, + m.original_language = movie.original_language, + m.budget = movie.budget, + m.revenue = movie.revenue, + m.rating = movie.rating + """, + batch, + ) + log_progress(f" Inserted {count} movies...") + batch = [] + if batch: + await self._execute_cypher_batch( + """ + UNWIND $batch AS movie + MERGE (m:Movie {name: movie.name}) + SET m.release_date = movie.release_date, + m.original_language = movie.original_language, + m.budget = movie.budget, + m.revenue = movie.revenue, + m.rating = movie.rating + """, + batch, + ) + log_progress(f"✓ Inserted {count} movie entities") + return count + + async def _insert_persons(self) -> int: + """Insert person entities. Returns count inserted.""" + count = 0 + batch: List[Dict] = [] + for person_id, person_data in self.person_db.items(): + batch.append({ + "name": person_data.get("name", f"Person_{person_id}").upper(), + "birthday": person_data.get("birthday"), + }) + count += 1 + if len(batch) >= self.batch_size: + await self._execute_cypher_batch( + """ + UNWIND $batch AS person + MERGE (p:Person {name: person.name}) + SET p.birthday = person.birthday + """, + batch, + ) + log_progress(f" Inserted {count} persons...") + batch = [] + if batch: + await self._execute_cypher_batch( + """ + UNWIND $batch AS person + MERGE (p:Person {name: person.name}) + SET p.birthday = person.birthday + """, + batch, + ) + log_progress(f"✓ Inserted {count} person entities") + return count + + async def _insert_movie_relations(self) -> int: + """Insert relations between movies and persons. Returns count inserted.""" + count = 0 + cast_batch: List[Dict] = [] + director_batch: List[Dict] = [] + genre_batch: List[Dict] = [] + + for movie_id, movie_data in self.movie_db.items(): + movie_name = movie_data.get("title", f"Movie_{movie_id}").upper() + + for cast_member in movie_data.get("cast") or []: + if not isinstance(cast_member, dict): + continue + person_name = cast_member.get("name", "").upper() + if not person_name: + continue + cast_batch.append({ + "person_name": person_name, + "movie_name": movie_name, + "character": cast_member.get("character", ""), + "order": cast_member.get("order", 0), + }) + count += 1 + if len(cast_batch) >= self.batch_size: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MATCH (p:Person {name: item.person_name}) + MERGE (p)-[:ACTED_IN {character: item.character, order: item.order}]->(m) + """, + cast_batch, + ) + cast_batch = [] + + for crew_member in movie_data.get("crew") or []: + if not isinstance(crew_member, dict): + continue + person_name = crew_member.get("name", "").upper() + job = crew_member.get("job", "").lower() + if not person_name or not job: + continue + if "director" in job: + director_batch.append({"person_name": person_name, "movie_name": movie_name}) + if len(director_batch) >= self.batch_size: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MATCH (p:Person {name: item.person_name}) + MERGE (p)-[:DIRECTED]->(m) + """, + director_batch, + ) + director_batch = [] + count += 1 + + for genre in movie_data.get("genres") or []: + genre_name = (genre.get("name", "") if isinstance(genre, dict) else str(genre)).upper() + if not genre_name: + continue + genre_batch.append({"movie_name": movie_name, "genre_name": genre_name}) + count += 1 + if len(genre_batch) >= self.batch_size: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MERGE (g:Genre {name: item.genre_name}) + MERGE (m)-[:BELONGS_TO_GENRE]->(g) + """, + genre_batch, + ) + genre_batch = [] + + # Flush remaining batches + if cast_batch: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MATCH (p:Person {name: item.person_name}) + MERGE (p)-[:ACTED_IN {character: item.character, order: item.order}]->(m) + """, + cast_batch, + ) + if director_batch: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MATCH (p:Person {name: item.person_name}) + MERGE (p)-[:DIRECTED]->(m) + """, + director_batch, + ) + if genre_batch: + await self._execute_cypher_batch( + """ + UNWIND $batch AS item + MATCH (m:Movie {name: item.movie_name}) + MERGE (g:Genre {name: item.genre_name}) + MERGE (m)-[:BELONGS_TO_GENRE]->(g) + """, + genre_batch, + ) + + log_progress(f"✓ Inserted {count} relations") + return count + + +__all__ = ["MovieKGPreprocessor", "PredefinedDataPreprocessor", "PreprocessingStats"] diff --git a/docs/examples/kgrag/rep/__init__.py b/docs/examples/kgrag/rep/__init__.py new file mode 100644 index 0000000..08b97a1 --- /dev/null +++ b/docs/examples/kgrag/rep/__init__.py @@ -0,0 +1,36 @@ +"""Domain-specific KG representation examples. + +This package contains example implementations of domain-specific entity +and relation representation utilities for different domains. + +Example Usage:: + + from movie_rep import movie_entity_to_text, format_movie_context + from docs.examples.kgrag.models import MovieEntity + + movie = MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + release_year=2023, + director="Christopher Nolan" + ) + + # Format for LLM prompts + text = movie_entity_to_text(movie, include_confidence=True) + print(text) +""" + +from movie_rep import ( + format_movie_context, + movie_entity_to_text, + movie_relation_to_text, +) + +__all__ = [ + "movie_entity_to_text", + "movie_relation_to_text", + "format_movie_context", +] diff --git a/docs/examples/kgrag/rep/movie_rep.py b/docs/examples/kgrag/rep/movie_rep.py new file mode 100644 index 0000000..0b450e6 --- /dev/null +++ b/docs/examples/kgrag/rep/movie_rep.py @@ -0,0 +1,155 @@ +"""Movie domain-specific entity and relation representation utilities. + +This module demonstrates how to extend the generic representation utilities +for a specific domain (movie) with domain-specific formatting and validation. +""" + +from typing import Optional + +from docs.examples.kgrag.models import MovieEntity, PersonEntity, AwardEntity +from mellea_contribs.kg.models import Entity, Relation +from mellea_contribs.kg.rep import entity_to_text as base_entity_to_text +from mellea_contribs.kg.rep import relation_to_text as base_relation_to_text + + +def movie_entity_to_text(entity: Entity, include_confidence: bool = False) -> str: + """Format movie entity with domain-specific details. + + Extends the generic formatting with movie-specific fields like release year, + director, box office, etc. + + Args: + entity: Entity to format (should be MovieEntity, PersonEntity, or AwardEntity) + include_confidence: Whether to include confidence score + + Returns: + Formatted text representation optimized for movie domain + """ + # Use base formatting + text = base_entity_to_text(entity, include_confidence) + + # Add domain-specific details if available + if isinstance(entity, MovieEntity): + details = [] + if entity.release_year: + details.append(f"Released: {entity.release_year}") + if entity.director: + details.append(f"Director: {entity.director}") + if entity.box_office: + details.append(f"Box Office: ${entity.box_office}M") + if entity.language: + details.append(f"Language: {entity.language}") + if entity.rating: + details.append(f"Rating: {entity.rating}/10") + + if details: + text += "\n" + "\n".join(details) + + elif isinstance(entity, PersonEntity): + details = [] + if entity.birth_year: + details.append(f"Born: {entity.birth_year}") + if entity.nationality: + details.append(f"Nationality: {entity.nationality}") + if entity.profession: + details.append(f"Profession: {entity.profession}") + + if details: + text += "\n" + "\n".join(details) + + elif isinstance(entity, AwardEntity): + details = [] + if entity.ceremony_number: + details.append(f"Ceremony: #{entity.ceremony_number}") + if entity.award_type: + details.append(f"Award: {entity.award_type}") + if entity.award_year: + details.append(f"Year: {entity.award_year}") + + if details: + text += "\n" + "\n".join(details) + + return text + + +def movie_relation_to_text( + relation: Relation, include_confidence: bool = False +) -> str: + """Format movie relation with domain-specific context. + + Extends generic formatting with movie-specific relation handling. + + Args: + relation: Relation to format + include_confidence: Whether to include confidence score + + Returns: + Formatted text representation optimized for movie domain + """ + # Use base formatting + return base_relation_to_text(relation, include_confidence) + + +def format_movie_context( + entities: list[Entity], + relations: list[Relation], + include_confidence: bool = False, + max_entities: Optional[int] = None, + max_relations: Optional[int] = None, +) -> str: + """Format movie KG context with domain-specific formatting. + + Uses movie-specific entity formatting for better readability. + + Args: + entities: List of entities from movie KG + relations: List of relations from movie KG + include_confidence: Whether to include confidence scores + max_entities: Maximum entities to display + max_relations: Maximum relations to display + + Returns: + Formatted movie KG context text + """ + sections = [] + + if entities: + sections.append("## Entities\n") + display_entities = entities[:max_entities] if max_entities else entities + + formatted = [] + for i, entity in enumerate(display_entities, 1): + formatted.append(f"{i}. {movie_entity_to_text(entity, include_confidence)}") + + if max_entities and len(entities) > max_entities: + formatted.append(f"\n... and {len(entities) - max_entities} more entities") + + sections.append("\n\n".join(formatted)) + + if relations: + sections.append("\n## Relations\n") + display_relations = ( + relations[:max_relations] if max_relations else relations + ) + + formatted = [] + for i, relation in enumerate(display_relations, 1): + formatted.append( + f"{i}. {movie_relation_to_text(relation, include_confidence)}" + ) + + if max_relations and len(relations) > max_relations: + formatted.append( + f"\n... and {len(relations) - max_relations} more relations" + ) + + sections.append("\n\n".join(formatted)) + + return "\n".join(sections) if sections else "(Empty movie knowledge graph)" + + +__all__ = [ + "movie_entity_to_text", + "movie_relation_to_text", + "format_movie_context", +] diff --git a/docs/examples/kgrag/scripts/__init__.py b/docs/examples/kgrag/scripts/__init__.py new file mode 100644 index 0000000..04fec17 --- /dev/null +++ b/docs/examples/kgrag/scripts/__init__.py @@ -0,0 +1 @@ +"""KG-RAG example scripts for preprocessing, embedding, QA, and evaluation.""" diff --git a/docs/examples/kgrag/scripts/create_tiny_dataset.py b/docs/examples/kgrag/scripts/create_tiny_dataset.py new file mode 100644 index 0000000..014a0e7 --- /dev/null +++ b/docs/examples/kgrag/scripts/create_tiny_dataset.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +"""Create a tiny dataset for quick testing (10-20 documents). + +This script creates an extremely small dataset for rapid testing and development. +It takes the first N documents from the full CRAG dataset. + +Usage: + python create_tiny_dataset.py --num-docs 10 + python create_tiny_dataset.py --num-docs 20 --output ../data/crag_movie_tiny.jsonl.bz2 +""" + +import argparse +import bz2 +import json +from pathlib import Path + + +def parse_arguments() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Create a tiny dataset for quick testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Create 10-document dataset + %(prog)s --num-docs 10 + + # Create 20-document dataset + %(prog)s --num-docs 20 + + # Create 5-document dataset for ultra-fast testing + %(prog)s --num-docs 5 --output ../data/crag_movie_micro.jsonl.bz2 + """ + ) + + parser.add_argument( + "--num-docs", + type=int, + default=10, + help="Number of documents to include (default: 10)" + ) + + parser.add_argument( + "--input", + type=str, + default="../dataset/crag_movie_dev.jsonl.bz2", + help="Input dataset file (default: ../dataset/crag_movie_dev.jsonl.bz2)" + ) + + parser.add_argument( + "--output", + type=str, + default="../data/crag_movie_tiny.jsonl.bz2", + help="Output dataset file (default: ../data/crag_movie_tiny.jsonl.bz2)" + ) + + return parser.parse_args() + + +def main(): + """Main entry point.""" + args = parse_arguments() + + print("=" * 60) + print("Creating Tiny Test Dataset") + print("=" * 60) + print(f"Input: {args.input}") + print(f"Output: {args.output}") + print(f"Number of documents: {args.num_docs}") + print("=" * 60) + + # Ensure output directory exists + Path(args.output).parent.mkdir(parents=True, exist_ok=True) + + # Read first N documents from input + documents = [] + try: + with bz2.open(args.input, 'rt', encoding='utf-8') as f: + for i, line in enumerate(f): + if i >= args.num_docs: + break + try: + doc = json.loads(line.strip()) + documents.append(doc) + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse line {i+1}: {e}") + continue + + print(f"\nRead {len(documents)} documents from input") + + # Write to output + with bz2.open(args.output, 'wt', encoding='utf-8') as f: + for doc in documents: + f.write(json.dumps(doc, ensure_ascii=False) + '\n') + + print(f"Wrote {len(documents)} documents to {args.output}") + + # Print some statistics + print("\n" + "=" * 60) + print("Tiny Dataset Created Successfully!") + print("=" * 60) + print(f"Total documents: {len(documents)}") + print(f"Output file: {args.output}") + + # Show first document as example + if documents: + print("\nFirst document fields:") + for key in list(documents[0].keys())[:5]: # Show first 5 fields + print(f" - {key}") + + print("\nTo use this dataset, either:") + print(" 1. Update KG_BASE_DIRECTORY in your .env to point to the dataset directory") + print(f" 2. Or pass --dataset {args.output} to run_kg_update.py") + + return 0 + + except FileNotFoundError: + print(f"\nError: Input file not found: {args.input}") + print(f"\nMake sure you have the full dataset at:") + print(f" {Path(args.input).resolve()}") + print(f"\nCurrent working directory: {Path.cwd()}") + return 1 + except Exception as e: + print(f"\nError: {e}") + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/docs/examples/kgrag/scripts/run.sh b/docs/examples/kgrag/scripts/run.sh new file mode 100755 index 0000000..e58dbcb --- /dev/null +++ b/docs/examples/kgrag/scripts/run.sh @@ -0,0 +1,257 @@ +#!/bin/bash +# run.sh — Orchestrates the five-stage KG-RAG pipeline for the movie domain. +# +# Before running, configure your LLM and graph database: +# cp ../env_template ../.env && editor ../.env +# +# Requires: Python env with mellea-contribs[kg] installed, and a running +# graph database (or pass --mock for local testing without any database). +set -e # Exit the script (not the terminal) on any command failure + +# Change to the script's directory to ensure correct module paths +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +cd "$SCRIPT_DIR" + +# Get the parent directory (kgrag root) +KGRAG_ROOT="$(cd .. && pwd)" + +# --------------------------------------------------------------------------- +# Usage +# --------------------------------------------------------------------------- +usage() { + cat < "$KGRAG_ROOT/output/preprocess_stats.json" + echo "✓ Movie database loaded" +fi + +# --------------------------------------------------------------------------- +# Step 2: Compute entity embeddings +# --------------------------------------------------------------------------- +if run_step 2; then + echo "" + echo "Step 2: Running KG embedding on loaded entities..." + uv run --with mellea-contribs[kg] run_kg_embed.py \ + --db-uri "$NEO4J_URI" \ + --db-user "$NEO4J_USER" \ + --db-password "$NEO4J_PASSWORD" \ + --batch-size 100 > "$KGRAG_ROOT/output/embedding_stats.json" + echo "✓ Entity embeddings computed" +fi + +# --------------------------------------------------------------------------- +# Step 3: Update KG with documents +# --------------------------------------------------------------------------- +if run_step 3; then + echo "" + echo "Step 3: Updating Knowledge Graph with documents ($DATASET_LABEL)..." + uv run --with mellea-contribs[kg] run_kg_update.py \ + --dataset "$ACTIVE_DATASET" \ + --domain movie \ + --db-uri "$NEO4J_URI" \ + --db-user "$NEO4J_USER" \ + --db-password "$NEO4J_PASSWORD" \ + --num-workers 10 \ + --verbose > "$KGRAG_ROOT/output/update_stats.json" + echo "✓ Knowledge Graph updated with documents" +fi + +# --------------------------------------------------------------------------- +# Step 4: Run QA +# --------------------------------------------------------------------------- +if run_step 4; then + echo "" + echo "Step 4: Running QA ($DATASET_LABEL)..." + uv run --with mellea-contribs[kg] run_qa.py \ + --dataset "$ACTIVE_DATASET" \ + --output "$KGRAG_ROOT/output/qa_results.jsonl" \ + --progress "$KGRAG_ROOT/output/qa_progress.json" \ + --reset-progress \ + --domain movie \ + --routes 3 \ + --width 30 \ + --depth 3 \ + --db-uri "$NEO4J_URI" \ + --db-user "$NEO4J_USER" \ + --db-password "$NEO4J_PASSWORD" + echo "✓ QA completed" +fi + +# --------------------------------------------------------------------------- +# Step 5: Evaluate QA results +# --------------------------------------------------------------------------- +if run_step 5; then + echo "" + echo "Step 5: Evaluating QA results with LLM judge..." + uv run --with mellea-contribs[kg] run_eval.py \ + --input "$KGRAG_ROOT/output/qa_results.jsonl" \ + --output "$KGRAG_ROOT/output/eval_results.json" \ + --metrics "$KGRAG_ROOT/output/eval_metrics.json" + echo "✓ Evaluation completed" +fi + +# --------------------------------------------------------------------------- +# Summary +# --------------------------------------------------------------------------- +echo "" +echo "==================================================" +echo "✅ KG-RAG Pipeline Execution Completed!" +echo "==================================================" +echo "Steps run: ${STEPS[*]}" +echo "Dataset: $DATASET_LABEL ($ACTIVE_DATASET)" +echo "" +echo "Graph DB is running at: $NEO4J_URI" +echo "Logs saved to: $KGRAG_ROOT/output/" +echo " - preprocess_stats.json (step 1)" +echo " - embedding_stats.json (step 2)" +echo " - update_stats.json (step 3)" +echo " - qa_results.jsonl (step 4)" +echo " - qa_progress.json (step 4)" +echo " - eval_results.json (step 5)" +echo " - eval_metrics.json (step 5)" +echo "==================================================" diff --git a/docs/examples/kgrag/scripts/run_eval.py b/docs/examples/kgrag/scripts/run_eval.py new file mode 100644 index 0000000..3d125ef --- /dev/null +++ b/docs/examples/kgrag/scripts/run_eval.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +"""Evaluate QA results using LLM-based judgement. + +Reads a QA results JSONL file (produced by ``run_qa.py``), evaluates each +prediction against the ground-truth answer using a combination of exact +match, fuzzy match, and LLM judgement, then outputs CRAG-style metrics. + +Configuration is driven by the same environment variables as the other +scripts: + +.. code-block:: bash + + # Optional — separate eval model (defaults to main API_BASE / MODEL_NAME) + export EVAL_API_BASE=https://your-rits-endpoint/v1 + export EVAL_API_KEY=your-api-key + export EVAL_MODEL_NAME=meta-llama/llama-3-70b-instruct + +Use ``--mock`` to skip LLM calls and evaluate with fuzzy match only +(useful for local testing without a live LLM endpoint). + +Example:: + + python run_eval.py \\ + --input ../output/qa_results.jsonl \\ + --output ../output/eval_results.json \\ + --metrics ../output/eval_metrics.json + +Output JSON format (``--metrics``):: + + { + "total": 100, + "n_correct": 72, + "n_miss": 5, + "n_hallucination": 23, + "accuracy": 72.0, + "score": 49.0, + "hallucination": 23.0, + "missing": 5.0, + "eval_model": "meta-llama/llama-3-70b-instruct" + } +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import sys +from pathlib import Path +from typing import Any, Dict, List + +from dotenv import load_dotenv + +from mellea_contribs.kg.utils import ( + create_session_from_env, + evaluate_predictions, + load_jsonl, + log_progress, + setup_logging, +) + + +def compute_crag_metrics( + evaluated: List[Dict[str, Any]], + eval_model: str = "", +) -> Dict[str, Any]: + """Compute CRAG-style evaluation metrics. + + Args: + evaluated: List of evaluated result dicts (each has a ``"correct"`` field). + eval_model: Model name used for evaluation (for logging). + + Returns: + Dict with total, n_correct, n_miss, n_hallucination, accuracy, score, + hallucination, missing, and eval_model fields. + """ + n = len(evaluated) + if n == 0: + return {"total": 0, "eval_model": eval_model} + + n_correct = sum(1 for r in evaluated if r.get("correct")) + # "I don't know" responses count as missing (not hallucination) + n_miss = sum( + 1 for r in evaluated + if "i don't know" in str(r.get("predicted", "")).lower() + and not r.get("correct") + ) + n_hallucination = n - n_correct - n_miss + + accuracy = (n_correct / n) * 100.0 + # CRAG score formula: penalises hallucination more than missing answers + crag_score = ((2 * n_correct + n_miss) / n - 1) * 100.0 + + return { + "total": n, + "n_correct": n_correct, + "n_miss": n_miss, + "n_hallucination": n_hallucination, + "accuracy": round(accuracy, 2), + "score": round(crag_score, 2), + "hallucination": round((n_hallucination / n) * 100.0, 2), + "missing": round((n_miss / n) * 100.0, 2), + "eval_model": eval_model, + } + + +async def main() -> None: + """Entry point.""" + env_path = Path(__file__).parent.parent / ".env" + if env_path.exists(): + load_dotenv(env_path, override=False) + + parser = argparse.ArgumentParser( + description="Evaluate QA results with LLM-based judgement" + ) + parser.add_argument( + "--input", + type=str, + required=True, + help="Input JSONL file with QA results (from run_qa.py)", + ) + parser.add_argument( + "--output", + type=str, + default="", + help="Output JSONL file with per-item evaluation scores added", + ) + parser.add_argument( + "--metrics", + type=str, + default="", + help="Output JSON file for aggregate CRAG metrics", + ) + parser.add_argument( + "--mock", + action="store_true", + help="Skip LLM evaluation; use fuzzy match only", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + args = parser.parse_args() + + setup_logging(log_level=args.log_level) + + # ------------------------------------------------------------------ + # Load QA results + # ------------------------------------------------------------------ + input_path = Path(args.input) + if not input_path.exists(): + log_progress(f"Input file not found: {input_path}", level="ERROR") + sys.exit(1) + + results: List[Dict[str, Any]] = list(load_jsonl(input_path)) + log_progress(f"Loaded {len(results)} results from {input_path}") + + if not results: + log_progress("No results to evaluate.", level="WARNING") + sys.exit(0) + + # ------------------------------------------------------------------ + # LLM-based evaluation + # ------------------------------------------------------------------ + eval_model = "" + + if args.mock: + log_progress("--mock: skipping LLM evaluation; using fuzzy match only.") + # evaluate_predictions handles the None session path gracefully via + # exact/fuzzy match before the LLM branch — pass a dummy session. + session = None + evaluated = await evaluate_predictions( + session=session, + predictions=results, + query_key="query", + answer_key="predicted", + gold_key="answer_aliases", + ) + else: + # Use EVAL_* vars if set, otherwise fall back to main session vars. + if os.getenv("EVAL_API_BASE"): + session, eval_model = create_session_from_env(env_prefix="EVAL_") + else: + session, eval_model = create_session_from_env() + log_progress(f"Evaluating with model: {eval_model}") + + evaluated = await evaluate_predictions( + session=session, + predictions=results, + query_key="query", + answer_key="predicted", + gold_key="answer_aliases", + ) + + # ------------------------------------------------------------------ + # Compute and display metrics + # ------------------------------------------------------------------ + metrics = compute_crag_metrics(evaluated, eval_model=eval_model) + + log_progress("=" * 50) + log_progress("Evaluation Results") + log_progress("=" * 50) + log_progress(f"Total questions : {metrics['total']}") + log_progress(f"Correct : {metrics['n_correct']}") + log_progress(f"Hallucination : {metrics['n_hallucination']}") + log_progress(f"Missing : {metrics['n_miss']}") + log_progress(f"Accuracy : {metrics['accuracy']:.1f}%") + log_progress(f"CRAG Score : {metrics['score']:.1f}") + log_progress("=" * 50) + + # ------------------------------------------------------------------ + # Save annotated results + # ------------------------------------------------------------------ + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as fh: + for item in evaluated: + fh.write(json.dumps(item, ensure_ascii=False, default=str) + "\n") + log_progress(f"Annotated results saved to {out_path}") + + # ------------------------------------------------------------------ + # Save aggregate metrics + # ------------------------------------------------------------------ + if args.metrics: + metrics_path = Path(args.metrics) + metrics_path.parent.mkdir(parents=True, exist_ok=True) + with open(metrics_path, "w", encoding="utf-8") as fh: + json.dump(metrics, fh, indent=2) + log_progress(f"Metrics saved to {metrics_path}") + + # Always print metrics to stdout as JSON + print(json.dumps(metrics, indent=2)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/examples/kgrag/scripts/run_kg_embed.py b/docs/examples/kgrag/scripts/run_kg_embed.py new file mode 100644 index 0000000..3f28451 --- /dev/null +++ b/docs/examples/kgrag/scripts/run_kg_embed.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +"""Knowledge Graph Embedding Script. + +Generates and stores embeddings for all graph components: +- Entity nodes (Movie, Person, Genre) +- Relations (ACTED_IN, DIRECTED, BELONGS_TO_GENRE) +- Stores embeddings back to the graph database with vector indices + +Delegates to :class:`~mellea_contribs.kg.embedder.KGEmbedder` which handles +fetch / embed / store / create-index in a single +:meth:`~mellea_contribs.kg.embedder.KGEmbedder.embed_and_store_all` call. + +Usage: + python run_kg_embed.py --db-uri bolt://localhost:7687 + python run_kg_embed.py --mock # Mock backend (no actual embedding) + python run_kg_embed.py --batch-size 100 --model text-embedding-3-large +""" + +import argparse +import asyncio +import json +import os +import sys + +from dotenv import load_dotenv + +from mellea_contribs.kg.embedder import KGEmbedder +from mellea_contribs.kg.utils import ( + create_backend, + create_session_from_env, + log_progress, + output_json, + print_stats, +) + + +async def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Embed all KG entities/relations and store them in the graph", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s --db-uri bolt://localhost:7687 + %(prog)s --batch-size 500 --model text-embedding-3-large + %(prog)s --mock # Mock backend (no actual embedding) + """, + ) + + # Backend configuration + parser.add_argument( + "--db-uri", + type=str, + default=os.getenv("NEO4J_URI", "bolt://localhost:7687"), + help="Graph database connection URI (default: $NEO4J_URI or bolt://localhost:7687)", + ) + parser.add_argument( + "--db-user", + type=str, + default=os.getenv("NEO4J_USER", "neo4j"), + help="Graph database username (default: $NEO4J_USER or neo4j)", + ) + parser.add_argument( + "--db-password", + type=str, + default=os.getenv("NEO4J_PASSWORD", "password"), + help="Graph database password (default: $NEO4J_PASSWORD or password)", + ) + parser.add_argument( + "--mock", + action="store_true", + help="Use MockGraphBackend (no graph database needed)", + ) + + # Embedding configuration + parser.add_argument( + "--model", + type=str, + default="text-embedding-3-small", + help="Embedding model (default: text-embedding-3-small)", + ) + parser.add_argument( + "--dimension", + type=int, + default=1536, + help="Embedding dimension (default: 1536)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Batch size for progress logging (default: 100)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Load .env from the parent directory (docs/examples/kgrag/.env) + env_path = os.path.join(os.path.dirname(__file__), "..", ".env") + load_dotenv(env_path, override=False) + + backend = create_backend( + backend_type="neo4j" if not args.mock else "mock", + neo4j_uri=args.db_uri, + neo4j_user=args.db_user, + neo4j_password=args.db_password, + ) + session, _ = create_session_from_env() + + emb_api_base = os.getenv("EMB_API_BASE") + emb_api_key = os.getenv("EMB_API_KEY") + emb_model = os.getenv("EMB_MODEL_NAME", args.model) + emb_dimension = int(os.getenv("VECTOR_DIMENSIONS", str(args.dimension))) + # RITS authenticates via a custom header; fall back to the primary key. + rits_api_key = os.getenv("EMB_RITS_API_KEY") or os.getenv("RITS_API_KEY") + extra_headers = {"RITS_API_KEY": rits_api_key} if rits_api_key else {} + + embedder = KGEmbedder( + session=session, + model=emb_model, + dimension=emb_dimension, + api_base=emb_api_base, + api_key=emb_api_key, + extra_headers=extra_headers, + batch_size=args.batch_size, + backend=backend, + ) + + try: + log_progress("=" * 60) + log_progress("KG Embedding Pipeline") + log_progress("=" * 60) + + stats = await embedder.embed_and_store_all(batch_size=args.batch_size) + + log_progress("=" * 60) + log_progress("EMBEDDING SUMMARY") + log_progress("=" * 60) + print_stats(stats) + log_progress("=" * 60) + + output_json(stats) + sys.exit(0 if stats.success else 1) + + except KeyboardInterrupt: + log_progress("\n⚠️ Embedding interrupted by user") + sys.exit(130) + except Exception as e: + log_progress(f"❌ Embedding failed: {e}") + if args.verbose: + import traceback + traceback.print_exc() + sys.exit(1) + finally: + await backend.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/examples/kgrag/scripts/run_kg_preprocess.py b/docs/examples/kgrag/scripts/run_kg_preprocess.py new file mode 100644 index 0000000..f3acced --- /dev/null +++ b/docs/examples/kgrag/scripts/run_kg_preprocess.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +"""Knowledge Graph Preprocessing from Predefined Data. + +Loads movie and person databases and inserts them into the graph database using +:class:`~preprocessor.movie_preprocessor.PredefinedDataPreprocessor`. + +Usage: + python run_kg_preprocess.py --data-dir ./dataset/movie --db-uri bolt://localhost:7687 + python run_kg_preprocess.py --data-dir ./dataset/movie --mock + python run_kg_preprocess.py --data-dir ./dataset/movie --verbose +""" + +import argparse +import asyncio +import json +import os +import sys +from pathlib import Path + +from mellea_contribs.kg.utils import ( + create_backend, + create_session, + log_progress, +) + +from preprocessor.movie_preprocessor import PredefinedDataPreprocessor + + +async def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Preprocess and load predefined movie data into KG", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s --data-dir ./data/movie # Load from data directory + %(prog)s --data-dir ./data/movie --mock # Use mock backend + %(prog)s --data-dir ./data/movie --db-uri bolt://localhost:7687 # Custom graph DB URI + %(prog)s --data-dir ./data/movie --verbose # Verbose logging + """, + ) + + parser.add_argument( + "--data-dir", + type=str, + required=True, + help="Directory containing movie_db.json and person_db.json", + ) + parser.add_argument( + "--db-uri", + type=str, + default=os.getenv("NEO4J_URI", "bolt://localhost:7687"), + help="Graph database connection URI (default: $NEO4J_URI or bolt://localhost:7687)", + ) + parser.add_argument( + "--db-user", + type=str, + default=os.getenv("NEO4J_USER", "neo4j"), + help="Graph database username (default: $NEO4J_USER or neo4j)", + ) + parser.add_argument( + "--db-password", + type=str, + default=os.getenv("NEO4J_PASSWORD", "password"), + help="Graph database password (default: $NEO4J_PASSWORD or password)", + ) + parser.add_argument( + "--mock", + action="store_true", + help="Use MockGraphBackend instead of the graph database (no database needed)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Batch size for inserting entities (default: 50)", + ) + parser.add_argument( + "--model", + type=str, + default="gpt-4o-mini", + help="LLM model to use (default: gpt-4o-mini)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + data_dir = Path(args.data_dir) + if not data_dir.exists(): + log_progress(f"ERROR: Data directory not found: {data_dir}") + sys.exit(1) + + for filename in ["movie_db.json", "person_db.json"]: + if not (data_dir / filename).exists(): + log_progress(f"ERROR: Required file not found: {data_dir / filename}") + sys.exit(1) + + try: + backend = create_backend( + backend_type="neo4j" if not args.mock else "mock", + neo4j_uri=args.db_uri, + neo4j_user=args.db_user, + neo4j_password=args.db_password, + ) + _ = create_session(model_id=args.model) + + log_progress("=" * 60) + log_progress("KG Preprocessing from Predefined Data") + log_progress("=" * 60) + log_progress(f"Data directory: {data_dir}") + log_progress(f"Backend: {'Mock' if args.mock else 'Graph DB'}") + log_progress("") + + preprocessor = PredefinedDataPreprocessor( + backend=backend, + data_dir=data_dir, + batch_size=args.batch_size, + ) + + stats = await preprocessor.preprocess() + + log_progress("") + log_progress("=" * 60) + log_progress("PREPROCESSING SUMMARY") + log_progress("=" * 60) + log_progress(str(stats)) + log_progress("=" * 60) + log_progress("") + + print(json.dumps(stats.to_dict())) + sys.exit(0 if stats.success else 1) + + except KeyboardInterrupt: + log_progress("\n⚠️ Preprocessing interrupted by user") + sys.exit(130) + except Exception as e: + log_progress(f"❌ Preprocessing failed: {e}") + if args.verbose: + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/examples/kgrag/scripts/run_kg_update.py b/docs/examples/kgrag/scripts/run_kg_update.py new file mode 100644 index 0000000..90793f0 --- /dev/null +++ b/docs/examples/kgrag/scripts/run_kg_update.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +"""Knowledge Graph Update Script. + +Updates the knowledge graph by processing documents and extracting +entities and relations. + +Usage: + python run_kg_update.py --domain movie --progress-path results/progress.json + python run_kg_update.py --dataset data/corpus.jsonl.bz2 --num-workers 64 + python run_kg_update.py --mock --verbose +""" + +import argparse +import asyncio +import os +import sys +import time +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv + +from mellea_contribs.kg.kgrag import orchestrate_kg_update +from mellea_contribs.kg.updater_models import ( + KGUpdateRunConfig, + UpdateBatchResult, + UpdateResult, + UpdateStats, +) +from mellea_contribs.kg.utils import ( + BaseProgressLogger, + create_backend, + create_session_from_env, + load_jsonl, + log_progress, + output_json, + print_stats, + setup_logging, +) + + +def parse_arguments() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Update knowledge graph from documents", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s --dataset data/corpus.jsonl.bz2 --mock + %(prog)s --num-workers 32 --queue-size 32 + %(prog)s --domain movie --progress-path results/progress.json + %(prog)s --verbose --mock + """, + ) + + parser.add_argument( + "--dataset", type=str, default=None, + help="Path to dataset file (overrides env KG_BASE_DIRECTORY)", + ) + parser.add_argument( + "--domain", type=str, default="movie", + help="Knowledge domain (default: movie)", + ) + parser.add_argument( + "--num-workers", type=int, default=64, + help="Number of concurrent workers (default: 64)", + ) + parser.add_argument( + "--queue-size", type=int, default=64, + help="Queue size for data loading (default: 64)", + ) + parser.add_argument( + "--db-uri", type=str, + default=os.getenv("NEO4J_URI", "bolt://localhost:7687"), + help="Graph database connection URI (default: $NEO4J_URI or bolt://localhost:7687)", + ) + parser.add_argument( + "--db-user", type=str, + default=os.getenv("NEO4J_USER", "neo4j"), + help="Graph database username (default: $NEO4J_USER or neo4j)", + ) + parser.add_argument( + "--db-password", type=str, + default=os.getenv("NEO4J_PASSWORD", "password"), + help="Graph database password (default: $NEO4J_PASSWORD or password)", + ) + parser.add_argument( + "--mock", action="store_true", + help="Use MockGraphBackend instead of the graph database", + ) + parser.add_argument( + "--model", type=str, default="gpt-4o-mini", + help="LLM model to use (default: gpt-4o-mini)", + ) + parser.add_argument( + "--extraction-loop-budget", type=int, default=3, + help="Entity/relation extraction loop budget (default: 3)", + ) + parser.add_argument( + "--alignment-loop-budget", type=int, default=2, + help="Alignment refinement loop budget (default: 2)", + ) + parser.add_argument( + "--align-topk", type=int, default=10, + help="Number of top candidates for alignment (default: 10)", + ) + parser.add_argument( + "--progress-path", type=str, default="results/update_kg_progress.json", + help="Progress log file path", + ) + parser.add_argument( + "--verbose", "-v", action="store_true", + help="Enable verbose logging", + ) + + return parser.parse_args() + + +async def process_document( + doc_id: str, + text: str, + backend: Any, + session: Any, + domain: str, + model: str, + progress_tracker: BaseProgressLogger, +) -> UpdateResult: + """Process a single document and update the KG. + + Args: + doc_id: Document ID. + text: Document text. + backend: Graph backend. + session: Mellea session. + domain: Knowledge domain. + model: LLM model name. + progress_tracker: Progress tracker. + + Returns: + UpdateResult with processing details. + """ + start_time = time.perf_counter() + log_progress(f"[{doc_id[:12]}] Starting...") + + try: + update_result = await orchestrate_kg_update( + session=session, + backend=backend, + doc_text=text, + domain=domain, + hints="", + entity_types="", + relation_types="", + ) + + elapsed_time = time.perf_counter() - start_time + entities_found = len(update_result.get("extracted_entities", [])) + relations_found = len(update_result.get("extracted_relations", [])) + + result = UpdateResult( + document_id=doc_id, + success=True, + entities_found=entities_found, + relations_found=relations_found, + entities_added=entities_found, + relations_added=relations_found, + processing_time_ms=elapsed_time * 1000, + model_used=model, + ) + + log_progress( + f"[{doc_id[:12]}] Done — {entities_found} entities, " + f"{relations_found} relations ({elapsed_time * 1000:.0f}ms)" + ) + + progress_tracker.add_stat({ + "doc_id": doc_id, + "entities_extracted": entities_found, + "entities_new": entities_found, + "relations_extracted": relations_found, + "relations_new": relations_found, + "processing_time": round(elapsed_time, 2), + }) + progress_tracker.mark_processed(doc_id) + + return result + + except Exception as e: + import traceback + elapsed_time = time.perf_counter() - start_time + log_progress(f"[{doc_id[:12]}] ERROR: {type(e).__name__}: {e}", level="ERROR") + log_progress(traceback.format_exc(), level="DEBUG") + return UpdateResult( + document_id=doc_id, + success=False, + error=str(e), + processing_time_ms=elapsed_time * 1000, + model_used=model, + ) + + +async def process_dataset( + dataset_path: Path, + config: KGUpdateRunConfig, + progress_tracker: BaseProgressLogger, +) -> UpdateBatchResult: + """Process the entire dataset with parallel workers. + + Args: + dataset_path: Path to the JSONL dataset file. + config: Run configuration. + progress_tracker: Progress tracker. + + Returns: + Aggregated batch result. + """ + backend = create_backend( + backend_type="neo4j" if not config.mock else "mock", + neo4j_uri=config.db_uri, + neo4j_user=config.db_user, + neo4j_password=config.db_password, + ) + + session, model_id = create_session_from_env(default_model=config.model) + log_progress(f"Using model: {model_id}, API base: {os.getenv('API_BASE') or '(default)'}") + + batch_result = UpdateBatchResult() + tasks = [] + semaphore = asyncio.Semaphore(config.num_workers) + + async def process_with_semaphore(doc_id: str, text: str) -> UpdateResult: + async with semaphore: + return await process_document( + doc_id=doc_id, text=text, backend=backend, session=session, + domain=config.domain, model=model_id, progress_tracker=progress_tracker, + ) + + try: + doc_num = 0 + for doc_num, doc in enumerate(load_jsonl(dataset_path), 1): + doc_id = doc.get("id") or doc.get("interaction_id") or f"doc_{doc_num}" + text = doc.get("text") or doc.get("query") or doc.get("context") or "" + if not text: + log_progress(f"[{doc_num}] WARNING: Empty text for {doc_id}") + continue + tasks.append(process_with_semaphore(doc_id, text)) + + total_tasks = len(tasks) + completed_count = 0 + + async def _tracked(coro: Any) -> UpdateResult: + nonlocal completed_count + result = await coro + completed_count += 1 + status = "✓" if result.success else "✗" + log_progress(f"[{completed_count}/{total_tasks}] {status} {result.document_id[:12]}") + return result + + log_progress(f"Processing {total_tasks} documents with {config.num_workers} workers...") + results = list(await asyncio.gather(*[_tracked(t) for t in tasks])) + + for result in results: + if result.success: + batch_result.successful_documents += 1 + else: + batch_result.failed_documents += 1 + + finally: + await backend.close() + + batch_result.total_documents = len(results) + batch_result.results = results + + if results: + stats = UpdateStats() + stats.total_documents = len(results) + stats.successful_documents = batch_result.successful_documents + stats.failed_documents = batch_result.failed_documents + for result in results: + stats.entities_extracted += result.entities_found + stats.relations_extracted += result.relations_found + stats.entities_new += result.entities_added + stats.relations_new += result.relations_added + batch_result.total_time_ms = sum(r.processing_time_ms for r in results) + batch_result.avg_time_per_document_ms = ( + batch_result.total_time_ms / len(results) + ) + stats.total_processing_time_ms = batch_result.total_time_ms + stats.average_processing_time_per_doc_ms = batch_result.avg_time_per_document_ms + batch_result.stats = stats + + return batch_result + + +def load_env_file() -> None: + """Load environment variables from .env file in parent directory.""" + script_dir = Path(__file__).parent + env_path = script_dir.parent / ".env" + if env_path.exists(): + log_progress(f"Loading environment from: {env_path}") + load_dotenv(env_path, override=False) + else: + log_progress(f"⚠️ .env not found at {env_path} (optional)") + + +async def main() -> int: + """Main async entry point.""" + setup_logging(log_level="INFO") + load_env_file() + args = parse_arguments() + + config = KGUpdateRunConfig( + model=args.model, + num_workers=args.num_workers, + queue_size=args.queue_size, + extraction_loop_budget=args.extraction_loop_budget, + alignment_loop_budget=args.alignment_loop_budget, + align_topk=args.align_topk, + domain=args.domain, + progress_path=args.progress_path, + db_uri=args.db_uri, + db_user=args.db_user, + db_password=args.db_password, + mock=args.mock, + verbose=args.verbose, + ) + + # Resolve dataset path + if args.dataset: + config.dataset_path = args.dataset + else: + base_dir = os.getenv( + "KG_BASE_DIRECTORY", + os.path.join(os.path.dirname(__file__), "..", "dataset"), + ) + config.dataset_path = os.path.join(base_dir, "crag_movie_dev.jsonl.bz2") + + if not Path(config.dataset_path).exists(): + log_progress(f"ERROR: Dataset not found: {config.dataset_path}") + return 1 + + try: + progress_tracker = BaseProgressLogger(config.progress_path) + progress_tracker.load() + if progress_tracker.num_processed: + log_progress(f"Resuming: {progress_tracker.num_processed} documents already processed.") + + log_progress("=" * 60) + log_progress("KG Update Configuration:") + log_progress("=" * 60) + log_progress(f"Dataset: {config.dataset_path}") + log_progress(f"Domain: {config.domain}") + log_progress(f"Workers: {config.num_workers}") + log_progress(f"Queue size: {config.queue_size}") + log_progress(f"Extraction loop budget: {config.extraction_loop_budget}") + log_progress(f"Alignment loop budget: {config.alignment_loop_budget}") + log_progress(f"Top-K candidates: {config.align_topk}") + log_progress(f"Model: {config.model}") + log_progress(f"Backend: {'Mock' if config.mock else 'Graph DB'}") + log_progress(f"Progress: {config.progress_path}") + log_progress("=" * 60) + + Path("results").mkdir(exist_ok=True) + + log_progress("Starting KG update...") + batch_result = await process_dataset(Path(config.dataset_path), config, progress_tracker) + + progress_tracker.save() + + log_progress("=" * 60) + log_progress("✅ KG Update Completed Successfully!") + log_progress("=" * 60) + log_progress(f"Processed documents: {batch_result.total_documents}") + log_progress(f"Successful: {batch_result.successful_documents}") + log_progress(f"Failed: {batch_result.failed_documents}") + if batch_result.stats: + log_progress(f"Total entities: {batch_result.stats.entities_extracted}") + log_progress(f"Total relations: {batch_result.stats.relations_extracted}") + log_progress(f"Average time per doc: {batch_result.avg_time_per_document_ms:.2f}ms") + log_progress(f"Progress saved to: {config.progress_path}") + log_progress("=" * 60) + + if batch_result.stats: + print_stats(batch_result.stats) + output_json(batch_result) + + return 0 + + except KeyboardInterrupt: + log_progress("\n⚠️ KG update interrupted by user") + return 130 + except Exception as e: + log_progress(f"❌ KG update failed: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/docs/examples/kgrag/scripts/run_qa.py b/docs/examples/kgrag/scripts/run_qa.py new file mode 100644 index 0000000..ee8211c --- /dev/null +++ b/docs/examples/kgrag/scripts/run_qa.py @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +"""Run QA on CRAG movie questions via the Think-on-Graph pipeline. + +Reads questions from a JSONL file, answers each one using +``orchestrate_qa_retrieval``, and writes per-question results to an output +JSONL file. Progress is persisted so interrupted runs can resume from where +they left off. + +Three independent sessions are created: + +* **main session** — question decomposition, entity alignment, relation / + triplet pruning. +* **eval session** — knowledge sufficiency evaluation, consensus validation, + direct-answer fallback. Defaults to the main session when not configured. +* **embedding client** — async OpenAI-compatible client for vector-based + entity alignment. Optional; falls back to fuzzy name search only. + +Configuration is driven by environment variables so the script works +transparently in any containerised environment. The variable names +mirror those used by ``run_kg_update.py``: + +.. code-block:: bash + + # Required — any OpenAI-compatible endpoint (OpenAI, vLLM, Ollama, Azure, etc.) + export API_BASE=https://your-llm-endpoint/v1 + export API_KEY=your-api-key + export MODEL_NAME=meta-llama/llama-3-70b-instruct # or gpt-4o-mini etc. + + # Optional — separate eval model (defaults to main session) + export EVAL_API_BASE=... + export EVAL_API_KEY=... + export EVAL_MODEL_NAME=... + + # Optional — embedding model for vector entity alignment + export EMB_API_BASE=... + export EMB_API_KEY=... + export EMB_MODEL_NAME=text-embedding-3-small + +When ``API_BASE`` is set the session uses ``OpenAIBackend`` directly, which +works for any OpenAI-compatible endpoint (OpenAI, vLLM, Azure, Ollama, etc.) +regardless of model ID. Use ``--mock`` to skip LLM calls entirely during +local testing. + + python run_qa.py \\ + --dataset ../dataset/crag_movie_tiny.jsonl.bz2 \\ + --output /tmp/qa_results.jsonl \\ + --progress /tmp/qa_progress.json \\ + --mock + +Output JSONL format (one JSON object per line):: + + { + "id": "q_0", + "query": "Who directed Inception?", + "query_time": "2024-03-05", + "predicted": "Christopher Nolan", + "answer": "Christopher Nolan", + "answer_aliases": ["Christopher Nolan", "Nolan"], + "correct": true, + "eval_method": "exact", + "elapsed_ms": 1234.5 + } +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import sys +import time +from pathlib import Path +from typing import Any, Dict, Optional + +from dotenv import load_dotenv + +# --------------------------------------------------------------------------- +# Project imports +# --------------------------------------------------------------------------- +from mellea_contribs.kg.kgrag import orchestrate_qa_retrieval +from mellea_contribs.kg.utils import ( + QAProgressLogger, + create_backend, + create_embedding_client, + create_session_from_env, + evaluate_predictions, + log_progress, + setup_logging, +) + +sys.path.insert(0, str(Path(__file__).parent.parent)) +from dataset.movie_dataset_loader import MovieDatasetLoader + +# --------------------------------------------------------------------------- +# Session / config helpers +# --------------------------------------------------------------------------- + +_HINTS = ( + "This is a movie-domain knowledge graph containing Movies, Persons, Awards, " + "and Year nodes. Common relation types: acted_in, directed_by, produced_by, " + "nominated_for, won, released_in." +) + +_DEFAULT_MODEL = "gpt-4o-mini" + + + + +def create_emb_client_optional(): + """Create an embedding client if ``EMB_API_BASE`` is set.""" + api_base = os.getenv("EMB_API_BASE") + if not api_base: + return None + return create_embedding_client( + api_base=api_base, + api_key=os.getenv("EMB_API_KEY", "dummy"), + model_name=os.getenv("EMB_MODEL_NAME", "text-embedding-3-small"), + ) + + +# --------------------------------------------------------------------------- +# Per-question processing +# --------------------------------------------------------------------------- + + +async def process_question( + item: Dict[str, Any], + *, + backend, + session, + eval_session, + emb_client, + domain: str, + num_routes: int, + width: int, + depth: int, +) -> Dict[str, Any]: + """Answer one QA item and return a result dict. + + Args: + item: Normalised QA item from ``MovieDatasetLoader``. + backend: Graph database backend. + session: Primary Mellea session. + eval_session: Eval Mellea session (may equal ``session``). + emb_client: Optional embedding client. + domain: Domain hint string. + num_routes: Number of solving routes. + width: ToG traversal width. + depth: ToG traversal depth. + + Returns: + Result dict with ``id``, ``query``, ``predicted``, timing, and + correctness fields. + """ + t0 = time.perf_counter() + error: Optional[str] = None + predicted = "" + + try: + predicted = await orchestrate_qa_retrieval( + session=session, + backend=backend, + query=item["query"], + query_time=item.get("query_time", ""), + domain=domain, + num_routes=num_routes, + hints=_HINTS, + eval_session=eval_session, + emb_client=emb_client, + width=width, + depth=depth, + ) + except Exception as exc: + error = str(exc) + log_progress(f" ERROR [{item['id']}]: {exc}", level="WARNING") + + elapsed_ms = (time.perf_counter() - t0) * 1000 + + return { + "id": item["id"], + "query": item["query"], + "query_time": item.get("query_time", ""), + "predicted": predicted, + "answer": item.get("answer", ""), + "answer_aliases": item.get("answer_aliases", []), + "elapsed_ms": round(elapsed_ms, 1), + "error": error, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main() -> None: + """Entry point.""" + env_path = Path(__file__).parent.parent / ".env" + if env_path.exists(): + load_dotenv(env_path, override=False) + + parser = argparse.ArgumentParser( + description="Run Think-on-Graph QA on CRAG movie questions" + ) + parser.add_argument( + "--dataset", + type=str, + default=str( + Path(__file__).parent.parent / "dataset" / "crag_movie_tiny.jsonl.bz2" + ), + help="Input dataset path (.jsonl or .jsonl.bz2)", + ) + parser.add_argument( + "--output", + type=str, + default="", + help="Output JSONL file for predictions (default: stdout only)", + ) + parser.add_argument( + "--progress", + type=str, + default="", + help="JSON file for progress tracking / resumption", + ) + parser.add_argument( + "--mock", + action="store_true", + help="Use MockGraphBackend (no graph database required)", + ) + parser.add_argument( + "--db-uri", + type=str, + default=os.getenv("NEO4J_URI", "bolt://localhost:7687"), + ) + parser.add_argument( + "--db-user", + type=str, + default=os.getenv("NEO4J_USER", "neo4j"), + ) + parser.add_argument( + "--db-password", + type=str, + default=os.getenv("NEO4J_PASSWORD", "password"), + ) + parser.add_argument( + "--domain", + type=str, + default="movie", + help="Knowledge domain hint (default: movie)", + ) + parser.add_argument( + "--routes", + type=int, + default=3, + help="Number of solving routes (default: 3)", + ) + parser.add_argument( + "--width", + type=int, + default=30, + help="ToG traversal width (default: 30)", + ) + parser.add_argument( + "--depth", + type=int, + default=3, + help="ToG traversal depth (default: 3)", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Parallel async workers (default: 1)", + ) + parser.add_argument( + "--prefix", + type=int, + default=0, + help="First dataset item index to process (default: 0)", + ) + parser.add_argument( + "--postfix", + type=int, + default=None, + help="Exclusive upper bound on dataset items (default: all)", + ) + parser.add_argument( + "--no-eval", + action="store_true", + help="Skip post-hoc correctness evaluation", + ) + parser.add_argument( + "--reset-progress", + action="store_true", + help="Delete any existing progress file and re-process all questions", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + args = parser.parse_args() + + setup_logging(log_level=args.log_level) + + if args.log_level == "DEBUG": + import litellm + litellm.set_verbose = True + + # ------------------------------------------------------------------ + # LLM configuration check + # ------------------------------------------------------------------ + # Any OpenAI-compatible endpoint is configured via API_BASE + API_KEY. + # Without API_BASE, the session falls back to direct OpenAI, which + # requires a valid OPENAI_API_KEY. + if not args.mock and not os.getenv("API_BASE"): + log_progress( + "WARNING: API_BASE is not set. The session will attempt to use the " + "OpenAI API directly. Set API_BASE (and API_KEY / MODEL_NAME) to " + "point to your LLM endpoint, or pass --mock for local testing.", + level="WARNING", + ) + + # ------------------------------------------------------------------ + # Dataset + # ------------------------------------------------------------------ + dataset_path = Path(args.dataset) + if not dataset_path.exists(): + log_progress(f"Dataset not found: {dataset_path}", level="ERROR") + sys.exit(1) + + # ------------------------------------------------------------------ + # Progress / resumption + # ------------------------------------------------------------------ + progress_path = args.progress or "" + if progress_path and args.reset_progress and Path(progress_path).exists(): + Path(progress_path).unlink() + log_progress("Progress file deleted; starting fresh.") + progress = QAProgressLogger(progress_path) if progress_path else None + if progress: + progress.load() + skipped = progress.num_processed + if skipped: + log_progress(f"Resuming: {skipped} questions already processed.") + + skip_ids = progress.processed_ids if progress else set() + + # ------------------------------------------------------------------ + # Backend & sessions + # ------------------------------------------------------------------ + backend = create_backend( + backend_type="mock" if args.mock else "neo4j", + neo4j_uri=args.db_uri, + neo4j_user=args.db_user, + neo4j_password=args.db_password, + ) + + # Main session — uses API_BASE / API_KEY / MODEL_NAME + session, model_id = create_session_from_env() + log_progress(f"Using model: {model_id}, API base: {os.getenv('API_BASE') or '(default)'}") + + # Eval session — uses EVAL_* env vars if set, otherwise reuses main session + eval_session, _ = ( + create_session_from_env(default_model=model_id, env_prefix="EVAL_") + if os.getenv("EVAL_API_BASE") + else (session, model_id) + ) + + emb_client = create_emb_client_optional() + + # ------------------------------------------------------------------ + # Output file + # ------------------------------------------------------------------ + output_fh = None + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + output_fh = open(out_path, "a", encoding="utf-8") + + results: list = [] + + async def _process(item: Dict[str, Any]): + result = await process_question( + item, + backend=backend, + session=session, + eval_session=eval_session, + emb_client=emb_client, + domain=args.domain, + num_routes=args.routes, + width=args.width, + depth=args.depth, + ) + # Emit immediately + line = json.dumps(result, ensure_ascii=False, default=str) + print(line) + if output_fh: + output_fh.write(line + "\n") + output_fh.flush() + # Update progress + if progress: + progress.add_result(item["id"], result) + progress.save() + return result + + # ------------------------------------------------------------------ + # Run via loader worker pool + # ------------------------------------------------------------------ + try: + loader = MovieDatasetLoader( + dataset_path=str(dataset_path), + num_workers=args.workers, + prefix=args.prefix, + postfix=args.postfix, + ) + results = await loader.run( + process_fn=_process, + id_key="id", + skip_ids=skip_ids, + ) + finally: + await backend.close() + if output_fh: + output_fh.close() + + log_progress(f"Done. {len(results)} questions answered.") + + # ------------------------------------------------------------------ + # Post-hoc evaluation + # ------------------------------------------------------------------ + if results and not args.no_eval: + log_progress("Running correctness evaluation...") + evaluated = await evaluate_predictions( + session=eval_session, + predictions=results, + query_key="query", + answer_key="predicted", + gold_key="answer_aliases", + ) + n_correct = sum(1 for r in evaluated if r.get("correct")) + accuracy = n_correct / len(evaluated) if evaluated else 0.0 + log_progress( + f"Accuracy: {n_correct}/{len(evaluated)} = {accuracy:.1%}" + ) + + # ------------------------------------------------------------------ + # Update progress metadata + # ------------------------------------------------------------------ + if progress: + progress.update_meta(total_answered=len(results)) + progress.save() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mellea_contribs/kg/README.md b/mellea_contribs/kg/README.md new file mode 100644 index 0000000..ae7c734 --- /dev/null +++ b/mellea_contribs/kg/README.md @@ -0,0 +1,656 @@ +# Knowledge Graph (KG) Library + +A complete Knowledge Graph-Enhanced Retrieval-Augmented Generation (KG-RAG) system built with the Mellea framework. Combines multi-hop reasoning, entity extraction, consensus validation, and graph database backends for sophisticated question answering and knowledge graph updates. + +## Overview + +The KG library provides: + +- **Multi-route QA Pipeline**: Break down complex questions into multiple solving routes and reach consensus +- **Document-based KG Updates**: Extract entities and relations from documents and merge with existing knowledge graphs +- **Backend-agnostic Design**: Works with Neo4j (production) or MockGraphBackend (testing) +- **LLM-Guided Operations**: All decisions powered by Mellea's @generative framework +- **Structured Data Models**: Pydantic models for all inputs and outputs + +## Installation + +```bash +# Basic installation (MockGraphBackend, no database) +pip install mellea-contribs + +# With Neo4j support (for production) +pip install mellea-contribs[kg] + +# With progress bars (tqdm) +pip install mellea-contribs[kg-utils] + +# Complete installation (everything) +pip install mellea-contribs[kg,kg-utils,dev] +``` + +**Optional Dependencies:** +- `tqdm`: Progress bars for batch processing +- `neo4j`: Neo4j driver for production backend +- `rapidfuzz`: Fuzzy string matching for evaluation + +## Quick Start + +### Knowledge Graph-Enhanced Question Answering (Multi-Route QA) + +```python +import asyncio +from mellea import start_session +from mellea_contribs.kg import ( + orchestrate_qa_retrieval, + MockGraphBackend, +) + +async def main(): + # Initialize Mellea session for LLM calls + session = start_session(backend_name="litellm", model_id="gpt-4o-mini") + + # Use mock backend for testing (or Neo4jBackend for production) + backend = MockGraphBackend() + + # Multi-route QA pipeline + answer = await orchestrate_qa_retrieval( + session=session, + backend=backend, + query="Who directed the highest-grossing film of 2024?", + query_time="2024-12-31", + domain="movies", + num_routes=3, # Explore 3 different reasoning paths + hints="Consider box office revenue data", + ) + + print(f"Answer: {answer}") + await backend.close() + +asyncio.run(main()) +``` + +**Pipeline Steps:** +1. Break down question into 3 solving routes +2. Extract topic entities from each route +3. Align entities with knowledge graph +4. Prune relevant relations +5. Evaluate if knowledge is sufficient +6. Validate consensus across routes +7. Return final answer with reasoning + +### Document-Based Knowledge Graph Updates + +```python +from mellea_contribs.kg import orchestrate_kg_update + +async def update_kg(): + # Extract entities and relations from document + result = await orchestrate_kg_update( + session=session, + backend=backend, + doc_text=""" + Oppenheimer is a 2023 biographical film directed by Christopher Nolan. + It stars Cillian Murphy and Emily Blunt. The film won Best Picture at + the 2024 Academy Awards. + """, + domain="movies", + entity_types="Person,Movie,Award", + relation_types="DIRECTED,STARRED_IN,WON", + ) + + print(f"Extracted {len(result['extracted_entities'])} entities") + print(f"Extracted {len(result['extracted_relations'])} relations") + # Output: Automatically aligns and merges with existing KG data + +asyncio.run(update_kg()) +``` + +**Pipeline Steps:** +1. Extract entities and relations from text +2. Align extracted entities with existing KG entities +3. Decide whether to merge or create new entities +4. Align extracted relations with existing KG relations +5. Decide whether to merge or create new relations +6. Update knowledge graph with merged data + +### Using Mock Backend (No Infrastructure) + +```python +from mellea_contribs.kg import MockGraphBackend, GraphNode, GraphEdge + +# Create mock nodes +alice = GraphNode(id="1", label="Person", properties={"name": "Alice"}) +bob = GraphNode(id="2", label="Person", properties={"name": "Bob"}) + +# Create mock edge +knows = GraphEdge( + id="e1", + source=alice, + label="KNOWS", + target=bob, + properties={} +) + +# Create backend and execute query +backend = MockGraphBackend( + mock_nodes=[alice, bob], + mock_edges=[knows] +) + +from mellea_contribs.kg import GraphQuery + +query = GraphQuery(query_string="MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a, b") +result = await backend.execute_query(query) + +print(f"Nodes: {len(result.nodes)}") +print(f"Edges: {len(result.edges)}") +``` + +### Using Neo4j Backend + +```python +from mellea_contribs.kg import Neo4jBackend, GraphQuery + +# Connect to Neo4j +backend = Neo4jBackend( + connection_uri="bolt://localhost:7687", + auth=("neo4j", "password") +) + +# Execute Cypher query +query = GraphQuery( + query_string="MATCH (p:Person)-[:ACTED_IN]->(m:Movie) RETURN p, m", + parameters={} +) + +result = await backend.execute_query(query) + +# Get schema +schema = await backend.get_schema() +print(f"Node types: {schema['node_types']}") +print(f"Edge types: {schema['edge_types']}") + +# Validate query before execution +is_valid, error = await backend.validate_query(query) +print(f"Query valid: {is_valid}") + +# Cleanup +await backend.close() +``` + +## Architecture + +The system follows a 4-layer architecture: + +### Layer 1: Application Orchestration + +Entry points for KG-RAG pipelines: +- **`orchestrate_qa_retrieval()`**: Multi-route question answering with consensus validation +- **`orchestrate_kg_update()`**: Document-based KG updates with entity/relation alignment + +### Layer 2: Components & Query Building + +Query construction and result formatting: +- **GraphQuery / CypherQuery / SparqlQuery**: Query type abstractions +- **GraphResult**: Result formatting with `format_for_llm()` method +- **natural_language_to_cypher()**: Convert questions to Cypher queries +- **explain_query_result()**: Format results for LLM consumption + +### Layer 3: LLM-Guided Logic (@generative functions) + +All decisions made by LLM through Mellea's @generative framework: + +**QA Functions (8):** +1. `break_down_question()` → Routes: Break complex questions into solving strategies +2. `extract_topic_entities()` → TopicEntities: Extract search entities from query +3. `align_topic_entities()` → RelevantEntities: Score entity relevance (0-1 scale) +4. `prune_relations()` → RelevantRelations: Filter relevant relations from entities +5. `prune_triplets()` → RelevantRelations: Score triplet relevance for answering +6. `evaluate_knowledge_sufficiency()` → EvaluationResult: Determine if KG knowledge suffices +7. `validate_consensus()` → ValidationResult: Validate consensus across routes +8. `generate_direct_answer()` → DirectAnswer: Generate answer without KG (fallback) + +**Update Functions (5):** +1. `extract_entities_and_relations()` → ExtractionResult: Extract from documents +2. `align_entity_with_kg()` → AlignmentResult: Find matching KG entities +3. `decide_entity_merge()` → MergeDecision: Decide entity merge strategy +4. `align_relation_with_kg()` → AlignmentResult: Find matching KG relations +5. `decide_relation_merge()` → MergeDecision: Decide relation merge strategy + +### Layer 4: Backend Abstraction + +Database operations: +- **GraphNode / GraphEdge / GraphPath**: Pure dataclasses representing graph data +- **GraphBackend**: Abstract interface for graph databases +- **Neo4jBackend**: Production-ready Neo4j implementation +- **MockGraphBackend**: In-memory testing backend (no infrastructure required) + +## Data Structures + +### GraphNode + +```python +@dataclass +class GraphNode: + id: str # Unique identifier + label: str # Node type/label + properties: dict[str, Any] # Node properties +``` + +### GraphEdge + +```python +@dataclass +class GraphEdge: + id: str # Unique identifier + source: GraphNode # Source node + label: str # Relationship type + target: GraphNode # Target node + properties: dict[str, Any] # Relationship properties +``` + +### GraphPath + +```python +@dataclass +class GraphPath: + nodes: list[GraphNode] # Sequence of nodes + edges: list[GraphEdge] # Sequence of edges +``` + +## Backend Interface + +All backends implement `GraphBackend` which provides: + +- `execute_query(query: GraphQuery) -> GraphResult`: Execute a query +- `get_schema() -> dict`: Get graph schema (node types, edge types, properties) +- `validate_query(query: GraphQuery) -> tuple[bool, str | None]`: Validate query +- `supports_query_type(query_type: str) -> bool`: Check if query type supported +- `execute_traversal(traversal: GraphTraversal) -> GraphResult`: Execute traversal pattern +- `close()`: Close backend connections + +## Testing + +```bash +# Run base data structure tests (no dependencies) +pytest test/kg/test_base.py -v + +# Run mock backend tests (no dependencies) +pytest test/kg/test_mock_backend.py -v + +# Run Neo4j tests (requires Neo4j running) +export NEO4J_URI=bolt://localhost:7687 +export NEO4J_USER=neo4j +export NEO4J_PASSWORD=password + +pytest test/kg/test_neo4j_backend.py -v +``` + +## Starting Neo4j for Testing + +```bash +# Docker +docker run -d --name neo4j-test -p 7687:7687 -p 7474:7474 \ + -e NEO4J_AUTH=neo4j/testpassword \ + neo4j:5.0 + +# Run tests +export NEO4J_URI=bolt://localhost:7687 +export NEO4J_USER=neo4j +export NEO4J_PASSWORD=testpassword + +pytest test/kg/ -v + +# Cleanup +docker stop neo4j-test && docker rm neo4j-test +``` + +## Implementation Status + +### Phase 1: Core KG Modules ✓ COMPLETE +- ✓ **Layer 1**: Application Orchestration + - `orchestrate_qa_retrieval()` - Multi-route QA entry point + - `orchestrate_kg_update()` - KG update entry point + +- ✓ **Layer 2**: Components + - GraphQuery, CypherQuery, SparqlQuery types + - GraphResult with format_for_llm() + - natural_language_to_cypher, explain_query_result + +- ✓ **Layer 3**: LLM-Guided Logic + - 8 QA @generative functions with full prompts + - 5 Update @generative functions with full prompts + - 12 Pydantic models for structured outputs + +- ✓ **Layer 4**: Backend Abstraction + - GraphNode, GraphEdge, GraphPath data structures + - GraphBackend abstract interface + - Neo4jBackend production implementation + - MockGraphBackend for testing + +### Phase 2: Run Scripts ✓ COMPLETE +- ✓ **8 Production-Ready CLI Scripts** (docs/examples/kgrag/scripts/) + - Dataset creation: create_demo_dataset.py, create_tiny_dataset.py, create_truncated_dataset.py + - Pipeline operations: run_kg_preprocess.py, run_kg_embed.py, run_kg_update.py, run_qa.py, run_eval.py + - All scripts support --mock flag for testing without database + - JSONL I/O for seamless pipeline chaining + +### Phase 3: Utility Modules ✓ COMPLETE (95 Tests Passing) +- ✓ **5 Reusable Utility Modules** (mellea_contribs/kg/utils/) + - `data_utils.py` - JSONL I/O, batching, schema validation (27 tests) + - `session_manager.py` - Session/backend factories, async resource management (19 tests) + - `progress.py` - Logging, progress tracking, JSON output (23 tests) + - `eval_utils.py` - Evaluation metrics, result aggregation (26 tests) + - All utilities tested with 95 comprehensive unit + integration tests + +### Phase 4: Configuration & Validation ✓ COMPLETE +- ✓ **.env_template** - Configuration template with all variables +- ✓ **pyproject.toml** - Updated with kg-utils optional dependency group +- ✓ **sun.sh** - Comprehensive end-to-end test suite validating all phases + +## Utility Modules (Phase 3) + +The `mellea_contribs.kg.utils` package provides reusable utilities extracted from the run scripts: + +### JSONL Data Utilities +```python +from mellea_contribs.kg.utils import ( + load_jsonl, save_jsonl, append_jsonl, + batch_iterator, truncate_jsonl, shuffle_jsonl, + validate_jsonl_schema +) + +# Load JSONL file +items = list(load_jsonl("data/questions.jsonl")) + +# Save and append +save_jsonl(items, "output/results.jsonl") +append_jsonl({"new": "item"}, "output/results.jsonl") + +# Batch processing +batches = list(batch_iterator(items, batch_size=10)) + +# Truncate and shuffle +truncate_jsonl("input.jsonl", "output.jsonl", max_items=100) +shuffle_jsonl("input.jsonl", "output_shuffled.jsonl") + +# Validate schema +valid, errors = validate_jsonl_schema("data.jsonl", required_fields=["id", "text"]) +``` + +### Session & Backend Management +```python +from mellea_contribs.kg.utils import ( + create_session, create_backend, MelleaResourceManager +) + +# Create session and backend +session = create_session(model_id="gpt-4o-mini") +backend = create_backend(backend_type="mock") + +# Or use async context manager for automatic cleanup +async with MelleaResourceManager(backend_type="mock") as manager: + # manager.session and manager.backend available + schema = await manager.backend.get_schema() +``` + +### Progress Tracking & Logging +```python +from mellea_contribs.kg.utils import ( + setup_logging, log_progress, output_json, + print_stats, ProgressTracker +) + +# Setup logging +setup_logging(log_level="INFO", log_file="pipeline.log") +log_progress("Processing started", level="INFO") + +# Output JSON +stats = compute_stats() +output_json(stats) # Prints to stdout + +# Print formatted stats +print_stats(stats, indent=2, to_stderr=False) + +# Progress tracking +tracker = ProgressTracker(total=1000, desc="Processing") +for item in items: + process(item) + tracker.update(1) +tracker.close() +``` + +### Evaluation Metrics +```python +from mellea_contribs.kg.utils import ( + exact_match, fuzzy_match, mean_reciprocal_rank, + precision, recall, f1_score, + aggregate_qa_results, aggregate_update_results +) + +# Matching +is_match = exact_match("Paris", "PARIS") # True (case-insensitive) +is_similar = fuzzy_match("Oppenheimer", "Oppenheimer", threshold=0.8) # True + +# Metrics +mrr = mean_reciprocal_rank(qa_results) +prec = precision(predicted_entities, expected_entities) +rec = recall(predicted_entities, expected_entities) +f1 = f1_score(prec, rec) + +# Aggregation +stats = aggregate_qa_results(qa_results_list) +stats = aggregate_update_results(update_results_list) +``` + +### Complete Workflow Example +```python +from mellea_contribs.kg.utils import ( + load_jsonl, batch_iterator, create_session, create_backend, + log_progress, output_json, aggregate_qa_results +) + +async def evaluate_qa_pipeline(): + # Setup + setup_logging(log_level="INFO") + session = create_session(model_id="gpt-4o-mini") + backend = create_backend(backend_type="mock") + + # Load questions + questions = list(load_jsonl("questions.jsonl")) + + # Process in batches + all_results = [] + for batch in batch_iterator(questions, batch_size=10): + results = await run_qa_batch(session, backend, batch) + all_results.extend(results) + + # Aggregate and output + stats = aggregate_qa_results(all_results) + output_json(stats) + + await backend.close() +``` + +## Testing & Validation (Phase 3 & 4) + +### Running Tests +```bash +# All KG tests +pytest test/kg/ -v + +# Unit tests only (Phase 1, 3) +pytest test/kg/ --ignore=test/kg/test_scripts/ -v + +# Utility module tests (95 tests) +pytest test/kg/utils/ -v + +# Neo4j tests (requires running Neo4j) +export NEO4J_URI=bolt://localhost:7687 +pytest test/kg/ -v -m neo4j +``` + +### Comprehensive Validation Suite (sun.sh) +```bash +# Run complete end-to-end validation +./sun.sh + +# Quick validation (skip some slower tests) +./sun.sh --quick + +# Unit tests only +./sun.sh --unit-only +``` + +The `sun.sh` script validates: +- Phase 0: Environment (Python, dependencies, imports) +- Phase 1: Core KG modules (95+ unit tests) +- Phase 2: Run scripts (all 8 scripts with mock backend) +- Phase 3: Utility modules (95 comprehensive tests) +- Phase 4: Configuration and dependencies + +### Test Coverage +- **Phase 1**: Core modules (entity models, preprocessor, embedder, orchestrators) +- **Phase 2**: Run scripts (dataset creation, preprocessing, embedding, QA, evaluation) +- **Phase 3**: Utility modules + - JSONL I/O: 27 tests (load, save, append, batch, truncate, shuffle, validate) + - Session management: 19 tests (backend creation, session creation, async resources) + - Progress/logging: 23 tests (logging levels, JSON output, progress tracking) + - Evaluation: 26 tests (exact/fuzzy matching, MRR, precision/recall/F1, aggregation) + +## Key Problems Solved + +### Multi-Hop Reasoning +Traditional LLMs struggle with questions requiring multiple steps through a knowledge graph. KG-RAG breaks questions into solving routes and explores them systematically. + +**Example:** +- Query: "Who won Best Picture at the Oscars and what other awards did they win?" +- Solved by: Entity extraction → Relation discovery → Multi-hop traversal → Consensus + +### Temporal Understanding +Time-sensitive queries require proper context. The system tracks query times and considers temporal aspects in both questions and graph properties. + +**Example:** +- Query: "Who was the highest-paid actor in 2023?" (different from 2024) +- Handled by: query_time parameter → temporal property filtering → time-aware alignment + +### Structured Relationship Comprehension +Complex relationships with properties need careful reasoning. The system scores and filters relations based on relevance. + +**Example:** +- Query: "Which movies did actor X star in that won awards?" +- Handled by: Extract ACTED_IN relations → Filter by WON properties → Score relevance + +### Explainable Reasoning +Get not just answers, but reasoning paths through the knowledge graph showing how the answer was derived. + +**Example:** Answer includes: +- Which solving route was used +- What entities were found +- Which relations were traversed +- Why the answer was sufficient or needed fallback + +### Document Integration +Automatically extract new information from documents and intelligently merge with existing knowledge graph without duplicates. + +**Example:** Merge "Leonardo DiCaprio" from document with existing entity, preserving both old and new properties. + +## Performance Considerations + +### Optimization Strategies + +1. **Multi-Route Exploration**: Configurable number of solving routes + - Fewer routes = faster but less certain + - More routes = slower but more confident + +2. **Relation Pruning Width**: Control how many relations to explore + - Default: 20 relations per entity + - Adjustable via `width` parameter + +3. **Consensus Validation**: Stop early when routes agree + - Fast path: 2 of 3 routes agree → return answer + - Slow path: All routes explored → reach consensus + +4. **Caching**: Neo4j vector index caching, schema caching + - Results cached per query_time + domain combination + - Entity similarity searches cached by backend + +### Scalability + +- **Async/await throughout** for non-blocking I/O +- **Configurable parameters** for tuning vs. quality tradeoff +- **Efficient Pydantic models** for structured validation +- **MockBackend** for parallel testing without infrastructure + +## Known Limitations + +- **Domain-Specific**: Currently optimized for movie/entertainment domain (easily adapted) +- **Requires Pre-built Graphs**: Expects Neo4j or data in MockBackend already populated +- **Computational Cost**: Multi-hop traversal can be expensive on large graphs +- **English-Only**: Currently designed for English-language queries (LLM-dependent) +- **Entity Disambiguation**: Relies on good entity naming conventions in KG + +## Design Notes + +- Pure dataclasses (GraphNode, GraphEdge, GraphPath) for data representation +- Components for queries and results (evolving in Layer 2) +- Async/await throughout for scalability +- Optional Neo4j dependency - graceful degradation if not installed +- MockBackend for unit testing without infrastructure +- All LLM decisions through Mellea's @generative framework (pluggable LLM backends) +- Structured outputs via Pydantic models (validated at LLM output) + +## Migration Notes + +This library was adapted from the KGRag system in mellea PR#3. Key differences: + +- **Open-ended**: Works with any domain (not movie-specific) +- **Mellea-integrated**: Uses Mellea's @generative decorators and MelleaSession +- **Backend-agnostic**: MockBackend for testing, Neo4jBackend for production +- **Structured API**: Clear Layer 1-4 separation with orchestration entry points +- **Full type hints**: Pydantic models throughout + +## Quick Reference: Running Everything + +### Minimal Setup (Mock Backend, No Database) +```bash +# Install +pip install -e .[kg,kg-utils] + +# Run tests +pytest test/kg/utils/ -v + +# Run full validation +./sun.sh + +# Try an example script +python docs/examples/kgrag/scripts/create_demo_dataset.py --output /tmp/demo.jsonl +python docs/examples/kgrag/scripts/run_qa.py --input /tmp/demo.jsonl --mock --output /tmp/qa.jsonl +``` + +### Production Setup (With Neo4j) +```bash +# Install with all features +pip install -e .[kg,kg-utils,dev] + +# Start Neo4j +docker run -d -p 7687:7687 -e NEO4J_AUTH=neo4j/password neo4j:5.0 + +# Configure +cp .env_template .env +# Edit .env with your Neo4j credentials + +# Run scripts +python docs/examples/kgrag/scripts/run_kg_preprocess.py --input data.jsonl +python docs/examples/kgrag/scripts/run_qa.py --input questions.jsonl --output results.jsonl +``` + +## See Also + +- [Main README](../../README.md) - KG-RAG overview and quick start +- [CLAUDE.md](../../CLAUDE.md) - Development guide and architecture +- [PHASE4_CONFIGURATION.md](../../PHASE4_CONFIGURATION.md) - Configuration templates and setup +- [missing_for_run_sh.txt](../../missing_for_run_sh.txt) - Implementation status and progress +- [Test README](utils/README.md) - Phase 3 utility module tests documentation +- [Mellea Framework](https://github.com/generative-computing/mellea) - Parent framework +- [Original PR#3](https://github.com/ydzhu98/mellea/pull/3) - Source of KG-RAG system diff --git a/mellea_contribs/kg/__init__.py b/mellea_contribs/kg/__init__.py new file mode 100644 index 0000000..dbbc24f --- /dev/null +++ b/mellea_contribs/kg/__init__.py @@ -0,0 +1,264 @@ +"""Knowledge Graph library for mellea-contribs. + +Backend-agnostic graph database abstraction for graph-based RAG applications. + +Optional Dependencies: + Neo4j support requires: pip install mellea-contribs[kg] + +Example: + Basic usage with MockGraphBackend:: + + from mellea_contribs.kg import MockGraphBackend, GraphNode, Entity + + backend = MockGraphBackend() + node = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + + # Create a generic entity + entity = Entity( + type="Person", + name="Alice", + description="Example person", + paragraph_start="Alice is", + paragraph_end="here.", + ) + + With domain-specific entities (movie example):: + + from docs.examples.kgrag.models import MovieEntity + + movie = MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + release_year=2023, + director="Christopher Nolan" + ) + + With Neo4j:: + + from mellea_contribs.kg import Neo4jBackend + + backend = Neo4jBackend( + uri="bolt://localhost:7687", + auth=("neo4j", "password") + ) +""" + +from typing import Any + +from mellea_contribs.kg.base import GraphEdge, GraphNode, GraphPath + +# Optional imports from mellea components (requires mellea to be installed) +try: + from mellea_contribs.kg.components import ( + CypherQuery, + GeneratedQuery, + GraphResult, + GraphTraversal, + SparqlQuery, + align_entity_with_kg, + align_relation_with_kg, + align_topic_entities, + break_down_question, + decide_entity_merge, + decide_relation_merge, + evaluate_knowledge_sufficiency, + extract_entities_and_relations, + extract_topic_entities, + generate_direct_answer, + prune_relations, + prune_triplets, + validate_consensus, + ) +except ImportError: + # These functions are optional - mellea may not be installed + CypherQuery = None # type: ignore[assignment] + GeneratedQuery = None # type: ignore[assignment] + GraphResult = None # type: ignore[assignment] + GraphTraversal = None # type: ignore[assignment] + SparqlQuery = None # type: ignore[assignment] + align_entity_with_kg = None + align_relation_with_kg = None + align_topic_entities = None + break_down_question = None + decide_entity_merge = None + decide_relation_merge = None + evaluate_knowledge_sufficiency = None + extract_entities_and_relations = None + extract_topic_entities = None + generate_direct_answer = None + prune_relations = None + prune_triplets = None + validate_consensus = None +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend +from mellea_contribs.kg.embedder import KGEmbedder +from mellea_contribs.kg.kgrag import KGRag, format_schema +from mellea_contribs.kg.preprocessor import KGPreprocessor +from mellea_contribs.kg.rep import ( + camelcase_to_snake_case, + entity_to_text, + format_entity_list, + format_kg_context, + format_relation_list, + normalize_entity_name, + relation_to_text, + snake_case_to_camelcase, +) +from mellea_contribs.kg.requirements_models import ( + entity_confidence_threshold, + entity_has_description, + entity_has_name, + entity_type_valid, + relation_entities_exist, + relation_type_valid, +) +from mellea_contribs.kg.embed_models import ( + EmbeddingConfig, + EmbeddingResult, + EmbeddingSimilarity, + EmbeddingStats, +) +from mellea_contribs.kg.qa_models import ( + QAConfig, + QADatasetConfig, + QAResult, + QASessionConfig, + QAStats, +) +from mellea_contribs.kg.updater_models import ( + MergeConflict, + UpdateBatchResult, + UpdateConfig, + UpdateResult, + UpdateSessionConfig, + UpdateStats, +) +from mellea_contribs.kg.models import ( + DirectAnswer, + Entity, + EvaluationResult, + ExtractionResult, + QuestionRoutes, + Relation, + RelevantEntities, + RelevantRelations, + TopicEntities, + ValidationResult, +) + +_neo4j_import_error: Exception | None = None +_Neo4jBackend: Any = None + +try: + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend as _Neo4jBackend +except ImportError as e: + _neo4j_import_error = e + + def _Neo4jBackend(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc] + """Raise ImportError with helpful message when Neo4j not installed.""" + raise ImportError( + "Neo4j support requires additional dependencies. " + "Install with: pip install mellea-contribs[kg]" + ) from _neo4j_import_error + + +Neo4jBackend = _Neo4jBackend + +# Build __all__ list dynamically, only including successfully imported items +_all = [ + # Core data structures + "GraphBackend", + "GraphEdge", + "GraphNode", + "GraphPath", + # Models - QA/extraction outputs + "DirectAnswer", + "EvaluationResult", + "ExtractionResult", + "QuestionRoutes", + "RelevantEntities", + "RelevantRelations", + "TopicEntities", + "ValidationResult", + # Models - Stored entities/relations (base classes) + "Entity", + "Relation", + # Models - Embedding configuration and results + "EmbeddingConfig", + "EmbeddingResult", + "EmbeddingSimilarity", + "EmbeddingStats", + # Models - QA pipeline configuration and results + "QAConfig", + "QASessionConfig", + "QADatasetConfig", + "QAResult", + "QAStats", + # Models - KG update pipeline configuration and results + "UpdateConfig", + "UpdateSessionConfig", + "UpdateStats", + "MergeConflict", + "UpdateResult", + "UpdateBatchResult", + # Layer 1 Applications + "KGRag", + "KGPreprocessor", + "KGEmbedder", + # Backends (Layer 4) + "MockGraphBackend", + "Neo4jBackend", + # Utilities - Representation (Optional) + "normalize_entity_name", + "entity_to_text", + "relation_to_text", + "format_entity_list", + "format_relation_list", + "format_kg_context", + "camelcase_to_snake_case", + "snake_case_to_camelcase", + # Utilities - Requirements/Validation (Optional) + "entity_type_valid", + "entity_has_name", + "entity_has_description", + "relation_type_valid", + "relation_entities_exist", + "entity_confidence_threshold", + # Other Utilities + "format_schema", +] + +# Add component classes if they were successfully imported +if CypherQuery is not None: + _all.extend([ + "CypherQuery", + "GeneratedQuery", + "GraphResult", + "GraphTraversal", + "SparqlQuery", + ]) + +# Add generative functions only if they were successfully imported +if extract_entities_and_relations is not None: + _all.extend([ + # Generative functions - QA + "break_down_question", + "extract_topic_entities", + "align_topic_entities", + "prune_relations", + "prune_triplets", + "evaluate_knowledge_sufficiency", + "validate_consensus", + "generate_direct_answer", + # Generative functions - Update + "extract_entities_and_relations", + "align_entity_with_kg", + "decide_entity_merge", + "align_relation_with_kg", + "decide_relation_merge", + ]) + +__all__ = _all diff --git a/mellea_contribs/kg/base.py b/mellea_contribs/kg/base.py new file mode 100644 index 0000000..2aea382 --- /dev/null +++ b/mellea_contribs/kg/base.py @@ -0,0 +1,103 @@ +"""Core data structures for graph representation. + +These are pure dataclasses, not Components. They represent graph data. +""" + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class GraphNode: + """A node in a graph. + + This is a dataclass, not a Component. It's just data. + """ + + id: str + label: str # Node type/label + properties: dict[str, Any] + + @classmethod + def from_neo4j_node(cls, node: Any) -> "GraphNode": + """Create from Neo4j node object. + + Args: + node: Neo4j node object + + Returns: + GraphNode instance + """ + return cls( + id=str(node.element_id), + label=next(iter(node.labels)) if node.labels else "Unknown", + properties=dict(node.items()), + ) + + +@dataclass +class GraphEdge: + """An edge in a graph. + + This is a dataclass, not a Component. It's just data. + """ + + id: str + source: GraphNode + label: str # Relationship type + target: GraphNode + properties: dict[str, Any] + + @classmethod + def from_neo4j_relationship( + cls, rel: Any, source: GraphNode, target: GraphNode + ) -> "GraphEdge": + """Create from Neo4j relationship object. + + Args: + rel: Neo4j relationship object + source: Source GraphNode + target: Target GraphNode + + Returns: + GraphEdge instance + """ + return cls( + id=str(rel.element_id), + source=source, + label=rel.type, + target=target, + properties=dict(rel.items()), + ) + + +@dataclass +class GraphPath: + """A path through a graph (sequence of nodes and edges). + + This is a dataclass, not a Component. It's just data. + """ + + nodes: list[GraphNode] + edges: list[GraphEdge] + + @classmethod + def from_neo4j_path(cls, path: Any) -> "GraphPath": + """Create from Neo4j path object. + + Args: + path: Neo4j path object + + Returns: + GraphPath instance + """ + nodes = [GraphNode.from_neo4j_node(node) for node in path.nodes] + edges = [] + + # Build edges from relationships + for i, rel in enumerate(path.relationships): + source = nodes[i] + target = nodes[i + 1] + edges.append(GraphEdge.from_neo4j_relationship(rel, source, target)) + + return cls(nodes=nodes, edges=edges) diff --git a/mellea_contribs/kg/components/__init__.py b/mellea_contribs/kg/components/__init__.py new file mode 100644 index 0000000..afdb732 --- /dev/null +++ b/mellea_contribs/kg/components/__init__.py @@ -0,0 +1,53 @@ +"""Query components for graph database operations.""" + +from mellea_contribs.kg.components.generative import ( + align_entity_with_kg, + align_relation_with_kg, + align_topic_entities, + break_down_question, + decide_entity_merge, + decide_relation_merge, + evaluate_knowledge_sufficiency, + extract_entities_and_relations, + extract_topic_entities, + generate_direct_answer, + prune_relations, + prune_triplets, + validate_consensus, +) +from mellea_contribs.kg.components.llm_guided import ( + GeneratedQuery, + explain_query_result, + natural_language_to_cypher, + suggest_query_improvement, +) +from mellea_contribs.kg.components.query import CypherQuery, GraphQuery, SparqlQuery +from mellea_contribs.kg.components.result import GraphResult +from mellea_contribs.kg.components.traversal import GraphTraversal + +__all__ = [ + "CypherQuery", + "GeneratedQuery", + "GraphQuery", + "GraphResult", + "GraphTraversal", + "SparqlQuery", + "explain_query_result", + "natural_language_to_cypher", + "suggest_query_improvement", + # QA generative functions + "break_down_question", + "extract_topic_entities", + "align_topic_entities", + "prune_relations", + "prune_triplets", + "evaluate_knowledge_sufficiency", + "validate_consensus", + "generate_direct_answer", + # Update generative functions + "extract_entities_and_relations", + "align_entity_with_kg", + "decide_entity_merge", + "align_relation_with_kg", + "decide_relation_merge", +] diff --git a/mellea_contribs/kg/components/generative.py b/mellea_contribs/kg/components/generative.py new file mode 100644 index 0000000..7378692 --- /dev/null +++ b/mellea_contribs/kg/components/generative.py @@ -0,0 +1,549 @@ +"""Generative functions for KG-RAG using Mellea's @generative decorator. + +These are Layer 2-3 functions that combine LLM generative calls with orchestration. +Layer 2: Executor functions that orchestrate the pipeline +Layer 3: @generative functions that call the LLM +""" +from typing import List + +from mellea import generative + +from mellea_contribs.kg.models import ( + DirectAnswer, + EvaluationResult, + ExtractionResult, + QuestionRoutes, + RelevantEntities, + RelevantRelations, + TopicEntities, + ValidationResult, +) + + +# QA Generative Functions (Layer 3 LLM Functions) + + +@generative +async def break_down_question( + query: str, + query_time: str, + domain: str, + route: int, + hints: str +) -> QuestionRoutes: + """You are a helpful assistant who is good at answering questions in the {domain} domain by using knowledge from an external knowledge graph. Before answering the question, you need to break down the question + so that you may look for the information from the knowledge graph in a step-wise operation. Hence, please break down the process of answering the question into as few sub-objectives as possible based on semantic analysis. + A query time is also provided; please consider including the time information when applicable. + + There can be multiple possible route to break down the question, aim for generating {route} possible routes. Note that every route may have a different solving efficiency, order the route by their solving efficiency. + Return your reasoning and sub-objectives as multiple lists of strings in a flat JSON of format: {{"reason": "...", "routes": [[], [], ...]}}. (TIP: You will need to escape any double quotes in the string to make the JSON valid) + + Domain-specific Hints: + {hints} + + -Example- + Q: Which of the countries in the Caribbean has the smallest country calling code? + Query Time: 03/05/2024, 23:35:21 PT + Output: {{ + "reason": "The most efficient route involves directly identifying Caribbean countries and their respective calling codes, as this limits the scope of the search. In contrast, routes that involve broader searches, such as listing all country calling codes worldwide before filtering, are less efficient due to the larger dataset that needs to be processed. Therefore, routes are ordered based on the specificity of the initial search and the subsequent steps required to narrow down to the answer.", + "routes": [["List all Caribbean countries", "Determine the country calling code for each country", "Identify the country with the smallest calling code"], + ["Identify Caribbean countries", "Retrieve their country calling codes", "Compare to find the smallest"], + ["Identify the smallest country calling code globally", "Filter by Caribbean countries", "Select the smallest among them"], + ["List all country calling codes worldwide", "Filter the calling codes by Caribbean countries", "Find the smallest one"]] + }} + + Q: {query} + Query Time: {query_time} + Output Format (flat JSON): {{"reason": "...", "routes": [[], [], ...]}} + Output:""" + pass + + +@generative +async def extract_topic_entities( + query: str, + query_time: str, + route: List[str], + domain: str +) -> TopicEntities: + """-Goal- + You are presented with a question in the {domain} domain, its query time, and a potential route to solve it. + + 1) Determine the topic entities asked in the query and each step in the solving route. The topic entities will be used as source entities to search through a knowledge graph for answers. + It's preferrable to mention the entity type explictly to ensure a more precise search hit. + + 2) Extract those topic entities from the query into a string list in the format of ["entity1", "entity2", ...]. + Consider extracting the entities in an informative way, combining adjectives or surrounding information. + A query time is provided - please consider including the time information when applicable. + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + ###################### + -Examples- + Question: Who wins the best actor award in 2020 Oscars? + Solving Route: ['List the nominees for the best actor award in the 2020 Oscars', 'Identify the winner among the nominees'] + Query Time: 03/05/2024, 23:35:21 PT + Output: ["2020 Oscars best actor award"] + Explanation (don't output this): This is an Award typed entity, extract an entity with the name "2020 Oscars best actor award" will best help search source entities in the knowledge graph. + + Question: Which movie wins the best visual effect award in this year's Oscars? + Query Time: 03/19/2024, 23:49:30 PT + Solving Route: ["Retrieve the list of nominees of this year's best visual effects Oscars award", 'Find the winner from the nominees'] + Output: ["2024 Oscars best visual effect award"] + Explanation (don't output this): This is an Award typed entity, and the query time for this year is "2024", extract an entity with the name "2024 Oscars best visual effect award" will best help search source entities in the knowledge graph. + + Question: Who is the lead actor for Titanic? + Query Time: 03/17/2024, 17:19:52 PT + Solving Route: ["List the main cast of Titanic", "Identify the lead actor among them"] + Output: ["Titanic Movie"] + Explanation (don't output this): This is a Movie typed entity, just simply extract an entity with the movie name "Titanic Movie" will best help search source entities in the knowledge graph. + + Question: How many countries were "Inception" filmed in? + Query Time: 03/19/2024, 22:59:20 PT + Solving Route: ["Retrieve information about the movie 'Inception'", "Extract filming locations", "Count the countries"] + Output: ["Inception Movie"] + Explanation (don't output this): This is a Movie typed entity, just simply extract an entity with the movie name "Inception Movie" will best help search source entities in the knowledge graph. + + Question: {query} + Query Time: {query_time} + Solving Route: {route} + + Output Format: ["entity1", "entity2", ...] + Output: + """ + pass + + +@generative +async def align_topic_entities( + query: str, + query_time: str, + route: List[str], + domain: str, + top_k_entities_str: str +) -> RelevantEntities: + """-Goal- + You are presented with a question in the {domain} domain, its query time, a potential route to solve it, and a list of entities extracted from a noisy knowledge graph. + The goal is to identify all possible relevant entities to answering the steps in the solving route and, therefore, answer the question. + You need to consider that the knowledge graph may be noisy and relations may split into similar entities, so it's essential to identify all relevant entities. + The entities' relevance would be scored on a scale from 0 to 1 (use at most 3 decimal places, and remove trailing zeros; the sum of the scores of all entities is 1). + + -Steps- + 1. You are provided a set of entities (type, name, description, and potential properties) globally searched from a knowledge graph that most similar to the question description, but may not directly relevant to the question itself. + Given in the format of "ent_i: (: , desc: "description", props: {{key: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}})" + where "i" is the index, the percentage is confidence score, "ctx" is an optional context under which the value is valid. Each property may have only a single value, or multiple valid values of vary confidence under different context. + + 2. Score *ALL POSSIBLE* entities that are relevant to answering the steps in the solving route and therefore answering the question, and provide a short reason for your scoring. + Return its index (ent_i) and score into a valid JSON of the format: {{"reason": "reason", "relevant_entities": {{"ent_i": 0.6, "ent_j": 0.3, ...}}}}. (TIP: You will need to escape any double quotes in the string to make the JSON valid) + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + ###################### + -Examples- + Question: How many countries were "Inception" filmed in? + Solving Route: ["Retrieve information about the movie 'Inception'", "Extract filming locations", "Count the countries"] + Query Time: 03/05/2024, 23:35:21 PT + Entities: ent_0: (Movie: INCEPTION, desc: 2010 sci-fi action film, props: {{year: 2010, release_date: 2012-07-20, rating: 8.6}}) + ent_1: (Movie: INCEPTION: THE COBOL JOB, props: {{release_date: 2010-12-07, rating: 7.263, original_name: Inception: The Cobol Job}}) + ent_2: (Movie: INVASION, props: {{release_date: 2005-10-06, original_name: Invasion}}) + ent_3: (Movie: THE INVITATION, props: {{release_date: 2016-04-08, rating: 6.462, original_name: The Invitation}}) + Output: {{"reason": "The solving route asks about the movie 'Inception', and ent_0 is the entity that directly corresponds to the movie 'Inception'.", "relevant_entities": {{"ent_0": 1}}}} + + Question: In this year, which animated film was recognized with the best animated feature film Oscar? + Solving Route: ["Retrieve the list of nominees of this year's best animated feature film Oscars award", 'Find the winner from the nominees'] + Query Time: 03/19/2024, 23:49:30 PT + Entities: ent_0: (Award: ANIMATED FEATURE FILM, props: {{year: 2024, ceremony_number: 96, type: OSCAR AWARD}}) + ent_1: (Award: SHORT FILM (ANIMATED), props: {{year: 2004, ceremony_number: 76, type: OSCAR AWARD}}) + ent_2: (Award: ANIMATED FEATURE FILM, props: {{year: 2005, ceremony_number: 77, type: OSCAR AWARD}}) + ent_3: (Award: ANIMATED FEATURE FILM, props: {{year: 2002, ceremony_number: 74, type: OSCAR AWARD}}) + ent_4: (Award: ANIMATED FEATURE FILM, props: {{year: 2003, ceremony_number: 75, type: OSCAR AWARD}}) + Output: {{"reason": "The entity ent_0 is the award for the best animated feature film in the year of query time, 2024, asked in the solving route.", "relevant_entities": {{"ent_0": 1}}}} + + Question: Can you tell me the name of the actress who starred in the film that won the best picture oscar in 2018? + Solving Route: ["Find the Best Picture Oscar winner for 2018", "Retrieve the cast of the film", "Identify the lead actress"], + Query Time: 03/19/2024, 22:59:20 PT + Entities: ent_0: (Award: ACTRESS IN A LEADING ROLE, props:{{year: 2018, ceremony_number: 90, type: OSCAR AWARD}}) + ent_1: (Award: ACTOR IN A LEADING ROLE, props: {{year: 2018, ceremony_number: 90, type: OSCAR AWARD}}) + ent_2: (Award: BEST PICTURE, props: {{year: 2018, ceremony_number: 90, type: OSCAR AWARD}}) + ent_3: (Award: ACTRESS IN A SUPPORTING ROLE, props: {{year: 2018, ceremony_number: 90, type: OSCAR AWARD}}) + Output:{{"reason": "The solving route requests the 2018 best picture Oscar movies, and award ent_2 is for the best picture in 2018. The award ent_0 is for the actress in a leading role in 2018, which may also help answer the question.", "relevant_entities": {{"ent_2": 0.8, "ent_0": 0.1, "ent_3": 0.1}}}} + + Question: {query} + Query Time: {query_time} + Solving Route: {route} + Entities: {top_k_entities_str} + + Output Format (flat JSON): {{"reason": "reason", "relevant_entities": {{"ent_i": 0.6, "ent_j": 0.3, ...}}}} + Output: + """ + pass + + +@generative +async def prune_relations( + query: str, + query_time: str, + route: List[str], + domain: str, + entity_str: str, + relations_str: str, + width: int, + hints: str +) -> RelevantRelations: + """-Goal- + You are given a question in the {domain} domain, its query time, a potential route to solve it, an entity, and a list of relations starting from it. + The goal is to retrieve up to {width} relations that contribute to answering the steps in the solving route and, therefore, answer the question. Rate their relevance from 0 to 1 (use at most 3 decimal places, and remove trailing zeros; the sum of the scores of these relations is 1). + + -Steps- + 1. You are provided a list of directed relations between entities in the format of + rel_i: (entity_type: entity_name)-[relation_type, desc: "description", props: {{key: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}}]->(entity_type: entity_name). + where "i" is the index, arrow symbol ("->" or "<-") is the relation direction, the percentage is confidence score, "ctx" is an optional context under which the value is valid. Each property may have only a single value, or multiple valid values of vary confidence under different context. + + 2. Retrieve relations only from the given list that contribute to answering the question, and provide a short reason for your scoring. + Return its index (rel_i) and score into a json of the format: {{"reason": "reason", "relevant_relations": {{"rel_i": score_i, "rel_i": score_j, ...}}}}. + (TIP: You will need to escape any double quotes in the string to make the JSON valid) + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + Domain-specific Hints: + {hints} + + ###################### + -Examples- + Question: Which movie wins the best visual effect award in 2006 Oscars? + Solving Route: ["Identify the 2006 Oscars best visual effects winner directly from the knowledge graph"] + + Entity: (Award: VISUAL EFFECTS, properties: ) + Relations: rel_0: (Award: VISUAL EFFECTS)-[HELD_IN]->(Year: None) + rel_1: (Award: VISUAL EFFECTS)-[NOMINATED_FOR, properties: ]->(Movie: None) + rel_2: (Award: VISUAL EFFECTS)-[WON, properties: ]->(Movie: None) + Output: {{"reason": "The question is asking for movies that won the award, relation rel_2 is the most relevant to award winning. rel_1 is relation that find movies released in 2006 and may help find the movie that wins the award. A movie that won the award should also got nominated for the award, so rel_1 also has slight relevance. ", + "relevant_relations": {{"rel_2": 0.7, "rel_0": 0.2, "rel_1": 0.1}} + }} + ##### + + Question: {query} + Query Time: {query_time} + Solving Route: {route} + + Entity: {entity_str} + Relations: {relations_str} + + Output Format (flat JSON): {{"reason": "reason", "relevant_relations": {{"rel_i": score_i, "rel_i": score_j, ...}}}}. + Output: + """ + pass + + +@generative +async def prune_triplets( + query: str, + query_time: str, + route: List[str], + domain: str, + entity_str: str, + relations_str: str, + hints: str +) -> RelevantRelations: + """-Goal- + You are presented with a question in the {domain} domain, its query time, a potential route to solve it. + You will then given a source entity (type, name, description, and potential properties) and a list of directed relations starting from / ended at the source entity in the format of (source entity)-[relation]->(target entity). + The goal is to score the relations' contribution to answering the steps in the solving route and, therefore, answer the question. Rate them on a scale from 0 to 1 (use at most 3 decimal places, and remove trailing zeros; the sum of the scores of all relations is 1). + + -Steps- + 1. You are provided the source entity in the format of "(source_entity_type: source_entity_name, desc: "description", props: {{key1: val, key2: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}})" + where the percentage is confidence score, "ctx" is an optional context under which the value is valid. Each property may have only a single value, or multiple valid values of vary confidence under different context. + + 2. You are then provided a list of directed relations in the format of + "rel_i: (source_entity_type: source_entity_name)-[relation_type, desc: "description", props: {{key1: val, key2: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}}]->(entity_type: entity_name, desc: "description", props: {{key: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}})" + where "i" is the index, arrow symbol ("->" or "<-") is the relation direction, the percentage is confidence score, "ctx" is an optional context under which the value is valid. Each property may have only a single value, or multiple valid values of vary confidence under different context. + You are going to assess the relevance of the relation type and its properties, along with the target entity name and its properties, to the given question. + + 3. Score the relations' relevance to answering the question, and provide a short reason for your scoring. + Return its index (ent_i) and score into a valid JSON of the format: {{"reason": "reason", "relevant_relations": {{"rel_i": score_i, "rel_i": score_j, ...}}}}. + (TIP: You will need to escape any double quotes in the string to make the JSON valid) + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + Domain-specific Hints: + {hints} + + ##### Examples ##### + Question: The movie featured Miley Cyrus and was produced by Tobin Armbrust? + Query Time: 03/19/2024, 22:59:20 PT + Solving Route: ["List movies produced by Tobin Armbrust", "Filter by movies featuring Miley Cyrus", "Identify the movie"] + + Source Entity: (Person: Tobin Armbrust) + Relations: rel_0: (Person: Tobin Armbrust)-[PRODUCED]->(Movie: The Resident) + rel_1: (Person: Tobin Armbrust)-[PRODUCED]->(Movie: So Undercover, properties: ) + rel_2: (Person: Tobin Armbrust)-[PRODUCED]->(Movie: Let Me In, properties: ) + rel_3: (Person: Tobin Armbrust)-[PRODUCED]->(Movie: Begin Again, properties: ) + rel_4: (Person: Tobin Armbrust)-[PRODUCED]->(Movie: A Walk Among the Tombstones, properties: ) + Output: {{"reason": "The movie that matches the given criteria is 'So Undercover' with Miley Cyrus and produced by Tobin Armbrust. Therefore, the score for 'So Undercover' would be 1, and the scores for all other entities would be 0.", "relevant_relations": {{"rel_1": 1.0}}}} + #### + + Question: {query} + Query Time: {query_time} + Solving Route: {route} + + Source Entity: {entity_str} + Relations: {relations_str} + + Output Format (flat JSON): {{"reason": "reason", "relevant_relations": {{"rel_i": score_i, "rel_i": score_j, ...}}}} + Output: + """ + pass + + +@generative +async def evaluate_knowledge_sufficiency( + query: str, + query_time: str, + route: List[str], + domain: str, + entities: str, + triplets: str, + hints: str +) -> EvaluationResult: + """-Goal- + You are presented with a question in the {domain} domain, its query time, and a potential route to solve it. Given the retrieved related entities and triplets from a noisy knowledge graph, you are asked to determine whether these references and your knowledge are sufficient to answer the question (Yes or No). + - If yes, answer the question using fewer than 50 words. + - If no, respond with 'I don't know'. + + 1. The entities will be given in the format of + "ent_i: (: , desc: "description", props: {{key_1: val, key_2: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}})" + The triplets will be given in the format of + "rel_i: (: )-[, desc: "description", props: {{key_1: val, key_2: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}}]->(: )" + where "i" is the index, arrow symbol ("->" or "<-") is the relation direction, "props" are associated properties of the entity or relation. + Each property may have a single value, or multiple valid values of vary confidence under different context. The percentage is confidence score, and "ctx" is the optional context under which the value is valid. + If multiple conflicting candidates are found, use the one with stronger supporting evidence such as temporal-aligned triplets or consists of additional supporting properties. If a more strongly justified answer exists, prefer it. + + 2. Return your judgment in a JSON of the format {{"sufficient": "Yes/No", "reason": "...", "answer": "..."}} (TIP: You will need to escape any double quotes in the string to make the JSON valid) + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + Domain-specific Hints: + {hints} + + #### Examples #### + Question: Find the person who said "Taste cannot be controlled by law", what did this person die from? + Knowledge Triplets: Taste cannot be controlled by law., media_common.quotation.author, Thomas Jefferson + Output: {{"sufficient": "No", + "reason": "Based on the given knowledge triplets, it's not sufficient to answer the entire question. The triplets only provide information about the person who said 'Taste cannot be controlled by law,' which is Thomas Jefferson. To answer the second part of the question, it's necessary to have additional knowledge about where Thomas Jefferson's dead.", + "answer": "I don't know."}} + + Question: The artist nominated for The Long Winter lived where? + Knowledge Triplets: The Long Winter, book.written_work.author, Laura Ingalls Wilder + Laura Ingalls Wilder, people.person.places_lived, Unknown-Entity + Unknown-Entity, people.place_lived.location, De Smet + Output: {{"sufficient": "Yes", + "reason": "Based on the given knowledge triplets, the author of The Long Winter, Laura Ingalls Wilder, lived in De Smet. Therefore, the answer to the question is De Smet.", + "answer": "De Smet."}} + + Question: Who is the coach of the team owned by Steve Bisciotti? + Knowledge Triplets: Steve Bisciotti, sports.professional_sports_team.owner_s, Baltimore Ravens + Steve Bisciotti, sports.sports_team_owner.teams_owned, Baltimore Ravens + Steve Bisciotti, organization.organization_founder.organizations_founded, Allegis Group + Output: {{"sufficient": "No", + "reason": "Based on the given knowledge triplets, the coach of the team owned by Steve Bisciotti is not explicitly mentioned. However, it can be inferred that the team owned by Steve Bisciotti is the Baltimore Ravens, a professional sports team. Therefore, additional knowledge about the current coach of the Baltimore Ravens can be used to answer the question.", + "answer": "I don't know."}} + + Question: Rift Valley Province is located in a nation that uses which form of currency? + Knowledge Triplets: Rift Valley Province, location.administrative_division.country, Kenya + Rift Valley Province, location.location.geolocation, UnName_Entity + Rift Valley Province, location.mailing_address.state_province_region, UnName_Entity + Kenya, location.country.currency_used, Kenyan shilling + Output: {{"sufficient": "Yes", + "reason": "Based on the given knowledge triplets, Rift Valley Province is located in Kenya, which uses the Kenyan shilling as its currency. Therefore, the answer to the question is Kenyan shilling.", + "answer": "Kenyan shilling."}} + + Question: The country with the National Anthem of Bolivia borders which nations? + Knowledge Triplets: National Anthem of Bolivia, government.national_anthem_of_a_country.anthem, UnName_Entity + National Anthem of Bolivia, music.composition.composer, Leopoldo Benedetto Vincenti + National Anthem of Bolivia, music.composition.lyricist, José Ignacio de Sanjinés + UnName_Entity, government.national_anthem_of_a_country.country, Bolivia + Bolivia, location.country.national_anthem, UnName_Entity + Output: {{"sufficient": "No", + "reason": "Based on the given knowledge triplets, we can infer that the National Anthem of Bolivia is the anthem of Bolivia. Therefore, the country with the National Anthem of Bolivia is Bolivia itself. However, the given knowledge triplets do not provide information about which nations border Bolivia. To answer this question, we need additional knowledge about the geography of Bolivia and its neighboring countries.", + "answer": "I don't know."}} + + Question: {query} + Query Time: {query_time} + Solving Route: {route} + Knowledge Entities: {entities} + Knowledge Triplets: {triplets} + + Output Format (flat JSON): {{"sufficient": "Yes/No", "reason": "...", "answer": "..."}} + Output: + """ + pass + + +@generative +async def validate_consensus( + query: str, + query_time: str, + domain: str, + attempt: str, + routes_info: str, + hints: str +) -> ValidationResult: + """-Goal- + You are presented with a question in the {domain} domain, and its query time. The goal is to answer the question *accurately* - you will be rewarded for correctly answering the question, *penalized* by providing a wrong answer. + + A confident but careless friend has provided us a tentative answer, denote as "attempt". We don't really trust it, so we have identified a list of potential routes to solve it. So far, we have followed a portion of the routes, retrieved a list of potential associated retrieved knowledge graph entities and triplets (entity, relation, entity), and provided tentative answers. + The entities will be given in the format of + "ent_i: (: , desc: "description", props: {{key: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}})" + The triplets will be given in the format of + "rel_i: (: )-[, desc: "description", props: {{key: [val_1 (70%, ctx:"context"), val_2 (30%, ctx:"context")], ...}}]->(: )" + where "i" is the index, arrow symbol ("->" or "<-") is the relation direction, the percentage is confidence score, "ctx" is an optional context under which the value is valid. Each property may have only a single value, or multiple valid values of vary confidence under different context. + + You will act as a rigorous judge to whether the answers reach a consensus or not before running out of solving routes. Consensus is defined by at least a half of the answers (including my friend's attempt) agree on a specific answer. + Please exactly follow these strategies to guarantee that your answer will perform at least better than my friend: + + 1. If there is a consensus, then respond with "Yes", and summarize them into a final answer following with a summarized explanation. + + 2. If there is not consensus, and there are still unexplored solving routes, then respond with "No", and don't provide a final answer. We will continue exploring the next solving route. + + 3. If there is not consensus, and we run out of unexplored solving route, you have to respond with "Yes", and summarize them into a final answer following with a summarized explanation. + If multiple conflicting answers are found, use the one with more votes (consensus), stronger supporting evidence such as temporal-aligned triplets or consists of additional supporting properties. If a more strongly justified answer exists, prefer it. + + 4. Lastly, if none of the solving routes give a resonable answer (all "I don't know"), then fall back to use my friend's attempt. + + If the references do not contain the necessary information to answer the question, respond with 'I don't know'. + Remember, you will be rewarded for correctly answering the question, penalized by providing a wrong answer. There is no reward or penalty if you answer "I don't know", which is more preferable than providing a wrong answer. + + Please return the output in a JSON of the format: {{"judgement": "Yes/No", "final_answer": ". "}} + + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + Domain-specific Hints: + {hints} + + Question: {query} + Query Time: {query_time} + Attempt: {attempt} + {routes_info} + """ + pass + + +@generative +async def generate_direct_answer( + query: str, + query_time: str, + domain: str +) -> DirectAnswer: + """-Goal- + You are provided with a question in the {domain} domain, and its query time. You are asked to determine whether your knowledge are sufficient to answer the question (Yes or No). + - If yes, answer the question succinctly, using the fewest words possible. + - If no, respond with 'I don't know'. + Please explain your reasoning and provide supporting evidence from your knowledge to support your answer. + + Return your judgment in a JSON of the format {{"sufficient": "Yes/No", "reason": "...", "answer": "..."}} (TIP: You will need to escape any double quotes in the string to make the JSON valid) + *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* + + #### Examples #### + Question: What state is home to the university that is represented in sports by George Washington Colonials men's basketball? + Output: {{"sufficient": "Yes", + "reason": "First, the education institution has a sports team named George Washington Colonials men's basketball in is George Washington University , Second, George Washington University is in Washington D.C. The answer is Washington, D.C.", + "answer": "Washington, D.C."}} + + Question: Who lists Pramatha Chaudhuri as an influence and wrote Jana Gana Mana? + Output: {{"sufficient": "Yes", + "reason": "First, Bharoto Bhagyo Bidhata wrote Jana Gana Mana. Second, Bharoto Bhagyo Bidhata lists Pramatha Chaudhuri as an influence. The answer is Bharoto Bhagyo Bidhata.", + "answer": "Bharoto Bhagyo Bidhata"}} + + + Question: Who was the artist nominated for an award for You Drive Me Crazy? + Output: {{"sufficient": "Yes", + "reason": "First, the song 'You Drive Me Crazy' was performed by Britney Spears. Second, Britney Spears was nominated for awards for this song. The answer is Britney Spears.", + "answer": "Britney Spears"}} + + + Question: What person born in Siegen influenced the work of Vincent Van Gogh? + Output: {{"sufficient": "Yes", + "reason": " First, Peter Paul Rubens, Claude Monet and etc. influenced the work of Vincent Van Gogh. Second, Peter Paul Rubens born in Siegen. The answer is Peter Paul Rubens.", + "answer": "Peter Paul Rubens"}} + + + Question: What is the country close to Russia where Mikheil Saakashvii holds a government position? + Output: {{"sufficient": "Yes", + "reason": "First, China, Norway, Finland, Estonia and Georgia is close to Russia. Second, Mikheil Saakashvii holds a government position at Georgia. The answer is Georgia.", + "answer": "Georgia"}} + + + Question: What drug did the actor who portrayed the character Urethane Wheels Guy overdosed on? + Output: {{"sufficient": "Yes", + "reason": "First, Mitchell Lee Hedberg portrayed character Urethane Wheels Guy. Second, Mitchell Lee Hedberg overdose Heroin. The answer is Heroin.", + "answer": "Heroin"}} + + Question: {query} + Query Time: {query_time} + + Output Format (flat JSON): {{"sufficient": "Yes/No", "reason": "...", "answer": "..."}} + Output: + """ + pass + + +# Update Generative Functions (will be implemented similarly) +@generative +async def extract_entities_and_relations( + doc_context: str, + domain: str, + hints: str, + reference: str, + entity_types: str = "", + relation_types: str = "" +) -> ExtractionResult: + """Extract entities and relations from a document context. + + See full docstring in source repository for complete extraction guidelines. + """ + pass + + +@generative +async def align_entity_with_kg( + extracted_entity_name: str, + extracted_entity_type: str, + extracted_entity_desc: str, + candidate_entities: str, + domain: str, + doc_text: str = "" +): + """Align extracted entity with knowledge graph candidates.""" + pass + + +@generative +async def decide_entity_merge( + entity_pair: str, + doc_text: str, + domain: str +): + """Decide whether to merge two entities.""" + pass + + +@generative +async def align_relation_with_kg( + extracted_relation: str, + candidate_relations: str, + synonym_relations: str, + domain: str, + doc_text: str = "" +): + """Align extracted relation with knowledge graph candidates.""" + pass + + +@generative +async def decide_relation_merge( + relation_pair: str, + doc_text: str, + domain: str +): + """Decide whether to merge two relations.""" + pass diff --git a/mellea_contribs/kg/components/llm_guided.py b/mellea_contribs/kg/components/llm_guided.py new file mode 100644 index 0000000..19d1b74 --- /dev/null +++ b/mellea_contribs/kg/components/llm_guided.py @@ -0,0 +1,98 @@ +"""LLM-guided query construction for knowledge graphs. + +Uses Mellea's @generative pattern to convert natural language into graph queries +and to explain/repair query results. + +.. note:: + + ``natural_language_to_cypher``, ``explain_query_result``, and + ``suggest_query_improvement`` are **planned future functionality**. They + are fully implemented ``@generative`` functions but are not yet wired into + the main orchestration pipeline (``kgrag.py``). They will be integrated in + a future release as an optional Layer 3 query-construction step. +""" + +from typing import Any + +from mellea import generative +from pydantic import BaseModel + + +class GeneratedQuery(BaseModel): + """Pydantic model for a generated graph query.""" + + query: str + explanation: str + parameters: dict[str, Any] | None = None + + +@generative +async def natural_language_to_cypher( + natural_language_query: str, + graph_schema: str, + examples: str, +) -> GeneratedQuery: + """Generate a Cypher query from a natural language question. + + Given a natural language question and the graph schema, generate a + valid Cypher query that answers the question. + + Graph Schema: + {graph_schema} + + Examples: + {examples} + + Question: {natural_language_query} + + Generate a Cypher query to answer this question. Return as JSON: + {{"query": "MATCH ...", "explanation": "This query...", "parameters": {{}}}} + + Query:""" + pass + + +@generative +async def explain_query_result( + query: str, + result: str, + original_question: str, +) -> str: + """Explain a graph query result in natural language. + + Original Question: {original_question} + + Query Executed: + {query} + + Results: + {result} + + Explain what these results mean in relation to the original question. + Write a clear, natural language answer. + + Answer:""" + pass + + +@generative +async def suggest_query_improvement( + query: str, + error_message: str, + schema: str, +) -> GeneratedQuery: + """Suggest a corrected query based on an error message. + + The following query failed: + {query} + + Error: {error_message} + + Graph Schema: + {schema} + + Suggest a corrected version of the query. Return as JSON: + {{"query": "...", "explanation": "The issue was...", "parameters": {{}}}} + + Corrected Query:""" + pass diff --git a/mellea_contribs/kg/components/query.py b/mellea_contribs/kg/components/query.py new file mode 100644 index 0000000..db05227 --- /dev/null +++ b/mellea_contribs/kg/components/query.py @@ -0,0 +1,325 @@ +"""Graph query components (Layer 2: full Mellea Component implementations).""" + +from copy import deepcopy +from typing import Any + +from mellea.stdlib.components import CBlock, Component, TemplateRepresentation, blockify + + +class GraphQuery(Component): + """Base Component for graph queries. + + Represents a graph query that can be executed against a GraphBackend + and formatted for LLM consumption. + + Follows Mellea patterns: + - Private fields with _ prefix + - Public properties for read access + - format_for_llm() returns TemplateRepresentation + - Immutable updates via deepcopy + """ + + def __init__( + self, + query_string: str | CBlock | None = None, + parameters: dict | None = None, + description: str | CBlock | None = None, + metadata: dict | None = None, + ): + """Initialize a graph query. + + Args: + query_string: The actual query (Cypher, SPARQL, etc.) + parameters: Query parameters for parameterized queries + description: Natural language description of what the query does + metadata: Additional metadata (schema hints, temporal constraints, etc.) + """ + self._query_string = blockify(query_string) if query_string is not None else None + self._parameters = parameters or {} + self._description = blockify(description) if description is not None else None + self._metadata = metadata or {} + + @property + def query_string(self) -> str | None: + """The query string.""" + return str(self._query_string) if self._query_string is not None else None + + @property + def parameters(self) -> dict[str, Any]: + """Query parameters.""" + return self._parameters + + @property + def description(self) -> str | None: + """Natural language description of the query.""" + return str(self._description) if self._description is not None else None + + @property + def metadata(self) -> dict[str, Any]: + """Additional query metadata.""" + return self._metadata + + def parts(self) -> list[Component | CBlock]: + """The constituent parts of this query.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> TemplateRepresentation: + """Format query for LLM consumption.""" + return TemplateRepresentation( + obj=self, + args={ + "description": self.description or "Graph query", + "query": self.query_string, + "parameters": self._parameters, + "metadata": self._metadata, + }, + tools=None, + images=None, + template_order=["*", "GraphQuery"], + ) + + def with_description(self, description: str | CBlock) -> "GraphQuery": + """Return a new query with an updated description (immutable). + + Args: + description: New natural language description. + + Returns: + New GraphQuery with the updated description. + """ + result = deepcopy(self) + result._description = blockify(description) + return result + + def with_parameters(self, **params: Any) -> "GraphQuery": + """Return a new query with additional parameters merged in (immutable). + + Args: + **params: Parameters to merge into the current set. + + Returns: + New GraphQuery with the merged parameters. + """ + result = deepcopy(self) + result._parameters = {**self._parameters, **params} + return result + + def with_metadata(self, **metadata: Any) -> "GraphQuery": + """Return a new query with additional metadata merged in (immutable). + + Args: + **metadata: Metadata entries to merge. + + Returns: + New GraphQuery with the merged metadata. + """ + result = deepcopy(self) + result._metadata = {**self._metadata, **metadata} + return result + + +class CypherQuery(GraphQuery): + """Component for building Cypher queries (Neo4j). + + Provides a fluent, composable interface for building Cypher queries. + Each builder method returns a new instance (immutable). + + Example: + query = ( + CypherQuery() + .match("(m:Movie)") + .where("m.year = $year") + .return_("m.title", "m.year") + .order_by("m.year DESC") + .limit(10) + .with_parameters(year=2020) + ) + """ + + def __init__( + self, + query_string: str | CBlock | None = None, + parameters: dict | None = None, + description: str | CBlock | None = None, + metadata: dict | None = None, + match_clauses: list[str] | None = None, + where_clauses: list[str] | None = None, + return_clauses: list[str] | None = None, + order_by: list[str] | None = None, + limit: int | None = None, + ): + """Initialize a Cypher query builder. + + Args: + query_string: Explicit query string (bypasses clause building when provided). + parameters: Query parameters. + description: Natural language description. + metadata: Additional metadata. + match_clauses: MATCH clause patterns. + where_clauses: WHERE conditions. + return_clauses: RETURN expressions. + order_by: ORDER BY expressions. + limit: LIMIT value. + """ + self._match_clauses: list[str] = match_clauses or [] + self._where_clauses: list[str] = where_clauses or [] + self._return_clauses: list[str] = return_clauses or [] + self._order_by: list[str] = order_by or [] + self._limit: int | None = limit + + # Build query string from clauses if not provided + if query_string is None and self._match_clauses: + query_string = self._build_query_string( + self._match_clauses, + self._where_clauses, + self._return_clauses, + self._order_by, + self._limit, + ) + + super().__init__(query_string, parameters, description, metadata) + + @staticmethod + def _build_query_string( + match: list[str], + where: list[str], + return_: list[str], + order: list[str], + limit: int | None, + ) -> str: + """Build a Cypher query string from clause lists.""" + parts = [] + if match: + parts.append("MATCH " + ", ".join(match)) + if where: + parts.append("WHERE " + " AND ".join(where)) + if return_: + parts.append("RETURN " + ", ".join(return_)) + if order: + parts.append("ORDER BY " + ", ".join(order)) + if limit is not None: + parts.append(f"LIMIT {limit}") + return "\n".join(parts) + + def _rebuild(self) -> "CypherQuery": + """Rebuild the query string from the current clause lists.""" + if self._match_clauses or self._return_clauses: + self._query_string = blockify( + self._build_query_string( + self._match_clauses, + self._where_clauses, + self._return_clauses, + self._order_by, + self._limit, + ) + ) + return self + + def match(self, pattern: str) -> "CypherQuery": + """Add a MATCH clause (immutable). + + Args: + pattern: Cypher MATCH pattern, e.g. "(n:Person)". + + Returns: + New CypherQuery with the clause appended. + """ + result = deepcopy(self) + result._match_clauses = [*self._match_clauses, pattern] + return result._rebuild() + + def where(self, condition: str) -> "CypherQuery": + """Add a WHERE condition (immutable). + + Args: + condition: Cypher WHERE condition, e.g. "n.age > 30". + + Returns: + New CypherQuery with the condition appended. + """ + result = deepcopy(self) + result._where_clauses = [*self._where_clauses, condition] + return result._rebuild() + + def return_(self, *items: str) -> "CypherQuery": + """Add RETURN expressions (immutable). + + Args: + *items: One or more RETURN expressions. + + Returns: + New CypherQuery with the expressions appended. + """ + result = deepcopy(self) + result._return_clauses = [*self._return_clauses, *items] + return result._rebuild() + + def order_by(self, *fields: str) -> "CypherQuery": + """Add ORDER BY fields (immutable). + + Args: + *fields: One or more ORDER BY expressions. + + Returns: + New CypherQuery with the fields appended. + """ + result = deepcopy(self) + result._order_by = [*self._order_by, *fields] + return result._rebuild() + + def limit(self, n: int) -> "CypherQuery": + """Set the LIMIT clause (immutable). + + Args: + n: Maximum number of results. + + Returns: + New CypherQuery with the limit set. + """ + result = deepcopy(self) + result._limit = n + return result._rebuild() + + def format_for_llm(self) -> TemplateRepresentation: + """Format the Cypher query for LLM consumption.""" + return TemplateRepresentation( + obj=self, + args={ + "description": self.description or "Cypher graph query", + "query": self.query_string, + "parameters": self._parameters, + "query_type": "Cypher (Neo4j)", + }, + tools=None, + images=None, + template_order=["*", "CypherQuery", "GraphQuery"], + ) + + +class SparqlQuery(GraphQuery): + """Component for SPARQL queries (RDF/triple stores). + + Extends GraphQuery with SPARQL-specific formatting. + + .. note:: + + ``SparqlQuery`` is **planned future functionality** for RDF/triple-store + backends. The current production backend is Neo4j (``CypherQuery``). + ``SparqlQuery`` is exported for completeness but has no callers in the + current pipeline. + """ + + def format_for_llm(self) -> TemplateRepresentation: + """Format the SPARQL query for LLM consumption.""" + return TemplateRepresentation( + obj=self, + args={ + "description": self.description or "SPARQL graph query", + "query": self.query_string, + "parameters": self._parameters, + "query_type": "SPARQL", + }, + tools=None, + images=None, + template_order=["*", "SparqlQuery", "GraphQuery"], + ) diff --git a/mellea_contribs/kg/components/result.py b/mellea_contribs/kg/components/result.py new file mode 100644 index 0000000..59881fa --- /dev/null +++ b/mellea_contribs/kg/components/result.py @@ -0,0 +1,220 @@ +"""Graph result component (Layer 2: full Mellea Component implementation).""" + +import json +from typing import TYPE_CHECKING, Any + +from mellea.stdlib.components import CBlock, Component, TemplateRepresentation + +from mellea_contribs.kg.base import GraphEdge, GraphNode, GraphPath + +if TYPE_CHECKING: + from mellea_contribs.kg.components.query import GraphQuery + + +class GraphResult(Component): + """Component for graph query results. + + Formats query results for LLM consumption in one of several styles: + - "triplets": ``(Subject:Label)-[EDGE_TYPE]->(Object:Label)`` format + - "natural": Short natural language sentences per edge + - "paths": Narrative descriptions of graph paths + - "structured": JSON representation of nodes and edges + + Follows Mellea patterns: + - Private fields with _ prefix + - Public properties for read access + - format_for_llm() returns TemplateRepresentation + """ + + def __init__( + self, + nodes: list[GraphNode] | None = None, + edges: list[GraphEdge] | None = None, + paths: list[GraphPath] | None = None, + raw_result: Any | None = None, + query: "GraphQuery | None" = None, + format_style: str = "triplets", + ): + """Initialize a graph result. + + Args: + nodes: Nodes returned by the query. + edges: Edges returned by the query. + paths: Paths returned by the query. + raw_result: Raw backend result for debugging. + query: The query that produced this result. + format_style: One of "triplets", "natural", "paths", "structured". + """ + self._nodes = nodes or [] + self._edges = edges or [] + self._paths = paths or [] + self._raw_result = raw_result + self._query = query + self._format_style = format_style + + # --- public properties --- + + @property + def nodes(self) -> list[GraphNode]: + """Nodes in the result.""" + return self._nodes + + @property + def edges(self) -> list[GraphEdge]: + """Edges in the result.""" + return self._edges + + @property + def paths(self) -> list[GraphPath]: + """Paths in the result.""" + return self._paths + + @property + def raw_result(self) -> Any: + """Raw backend result.""" + return self._raw_result + + @property + def query(self) -> "GraphQuery | None": + """The query that produced this result.""" + return self._query + + @property + def format_style(self) -> str: + """Active format style.""" + return self._format_style + + # --- Component protocol --- + + def parts(self) -> list[Component | CBlock]: + """The constituent parts of this result.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> TemplateRepresentation: + """Format the result for LLM consumption. + + Delegates to the appropriate formatter based on format_style. + """ + formatters = { + "triplets": self._format_triplets, + "natural": self._format_natural, + "paths": self._format_paths, + "structured": self._format_structured, + } + formatter = formatters.get(self._format_style, self._format_triplets) + formatted_text = formatter() + + return TemplateRepresentation( + obj=self, + args={ + "format_style": self._format_style, + "result": formatted_text, + "node_count": len(self._nodes), + "edge_count": len(self._edges), + }, + tools=None, + images=None, + template_order=["*", "GraphResult"], + ) + + # --- format helpers --- + + def _node_label(self, node: GraphNode) -> str: + """Return a human-readable label for a node.""" + name = ( + node.properties.get("name") + or node.properties.get("title") + or node.properties.get("id") + or node.id + ) + return f"{node.label}:{name}" + + def _format_triplets(self) -> str: + """Format as (Subject)-[PREDICATE]->(Object) triplets.""" + if not self._edges and not self._nodes: + return "(no results)" + + lines = [] + for edge in self._edges: + src = self._node_label(edge.source) + tgt = self._node_label(edge.target) + lines.append(f"({src})-[{edge.label}]->({tgt})") + + # Standalone nodes that appear in no edge + edge_node_ids = { + nid + for edge in self._edges + for nid in (edge.source.id, edge.target.id) + } + for node in self._nodes: + if node.id not in edge_node_ids: + lines.append(f"({self._node_label(node)})") + + return "\n".join(lines) + + def _format_natural(self) -> str: + """Format as natural language sentences.""" + if not self._edges and not self._nodes: + return "No results found." + + lines = [] + for edge in self._edges: + src = self._node_label(edge.source) + tgt = self._node_label(edge.target) + rel = edge.label.replace("_", " ").lower() + lines.append(f"{src} {rel} {tgt}.") + + edge_node_ids = { + nid + for edge in self._edges + for nid in (edge.source.id, edge.target.id) + } + for node in self._nodes: + if node.id not in edge_node_ids: + lines.append(f"{self._node_label(node)} is present in the graph.") + + return "\n".join(lines) + + def _format_paths(self) -> str: + """Format graph paths as narratives.""" + if not self._paths and not self._edges: + return "No paths found." + + lines = [] + + for path in self._paths: + if not path.nodes: + continue + segments = [f"({self._node_label(path.nodes[0])})"] + for i, edge in enumerate(path.edges): + next_node = path.nodes[i + 1] if i + 1 < len(path.nodes) else None + segments.append(f"-[{edge.label}]->") + if next_node: + segments.append(f"({self._node_label(next_node)})") + lines.append("".join(segments)) + + # Fall back to triplets if no explicit paths + if not lines: + return self._format_triplets() + + return "\n".join(lines) + + def _format_structured(self) -> str: + """Format as a JSON structure.""" + data: dict[str, Any] = { + "nodes": [ + {"id": n.id, "label": n.label, "properties": n.properties} + for n in self._nodes + ], + "edges": [ + { + "id": e.id, + "source": e.source.id, + "target": e.target.id, + "label": e.label, + "properties": e.properties, + } + for e in self._edges + ], + } + return json.dumps(data, indent=2, default=str) diff --git a/mellea_contribs/kg/components/traversal.py b/mellea_contribs/kg/components/traversal.py new file mode 100644 index 0000000..11abb19 --- /dev/null +++ b/mellea_contribs/kg/components/traversal.py @@ -0,0 +1,160 @@ +"""Graph traversal component (Layer 2: full Mellea Component implementation).""" + +from collections.abc import Callable + +from mellea.stdlib.components import CBlock, Component, TemplateRepresentation + +from mellea_contribs.kg.base import GraphEdge, GraphNode + + +class GraphTraversal(Component): + """Component for high-level graph traversal patterns. + + Represents a multi-hop traversal intent that can be converted to a + backend-specific query (e.g. Cypher) and formatted for LLM consumption. + + .. note:: + + ``GraphTraversal`` and its ``to_cypher()`` method are **planned future + functionality**. The current pipeline (``kgrag.py``) builds all Cypher + queries directly via ``@generative`` functions. ``GraphTraversal`` will + be integrated as a structured intermediate representation in a future + release. + + Supported patterns: + - "multi_hop": Follow all relationships up to max_depth hops + - "shortest_path": Find the shortest path between two node sets + - "bfs": Breadth-first traversal from start nodes + - "dfs": Depth-first traversal from start nodes + + Follows Mellea patterns: + - Private fields with _ prefix + - Public properties for read access + - format_for_llm() returns TemplateRepresentation + """ + + def __init__( + self, + start_nodes: list[str], + pattern: str = "multi_hop", + max_depth: int = 3, + edge_filter: Callable[[GraphEdge], bool] | None = None, + node_filter: Callable[[GraphNode], bool] | None = None, + description: str | None = None, + ): + """Initialize a traversal pattern. + + Args: + start_nodes: Starting node IDs or labels. + pattern: Traversal pattern ("multi_hop", "shortest_path", "bfs", "dfs"). + max_depth: Maximum depth to traverse. + edge_filter: Optional filter function for edges. + node_filter: Optional filter function for nodes. + description: Natural language description of the traversal intent. + """ + self._start_nodes = start_nodes + self._pattern = pattern + self._max_depth = max_depth + self._edge_filter = edge_filter + self._node_filter = node_filter + self._description = description + + # --- public properties --- + + @property + def start_nodes(self) -> list[str]: + """Starting node IDs or labels.""" + return self._start_nodes + + @property + def pattern(self) -> str: + """Traversal pattern name.""" + return self._pattern + + @property + def max_depth(self) -> int: + """Maximum traversal depth.""" + return self._max_depth + + @property + def edge_filter(self) -> Callable[[GraphEdge], bool] | None: + """Optional edge filter function.""" + return self._edge_filter + + @property + def node_filter(self) -> Callable[[GraphNode], bool] | None: + """Optional node filter function.""" + return self._node_filter + + @property + def description(self) -> str | None: + """Natural language description of the traversal.""" + return self._description + + # --- Component protocol --- + + def parts(self) -> list[Component | CBlock]: + """The constituent parts of this traversal.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> TemplateRepresentation: + """Format the traversal for LLM consumption.""" + return TemplateRepresentation( + obj=self, + args={ + "description": self._description or f"{self._pattern} traversal", + "start_nodes": self._start_nodes, + "pattern": self._pattern, + "max_depth": self._max_depth, + "cypher": self.to_cypher().query_string, + }, + tools=None, + images=None, + template_order=["*", "GraphTraversal"], + ) + + # --- query conversion --- + + def to_cypher(self) -> "CypherQuery": # type: ignore[name-defined] + """Convert this traversal to an equivalent CypherQuery. + + Returns: + CypherQuery that implements the traversal pattern. + + Raises: + ValueError: If the pattern is not supported. + """ + from mellea_contribs.kg.components.query import CypherQuery + + if self._pattern in ("multi_hop", "bfs", "dfs"): + return ( + CypherQuery() + .match(f"(start)-[*1..{self._max_depth}]->(end)") + .where("start.id IN $start_nodes") + .return_("start", "end") + .with_parameters(start_nodes=self._start_nodes) + .with_description( + self._description + or f"{self._pattern} traversal from {self._start_nodes}" + ) + ) + + if self._pattern == "shortest_path": + return ( + CypherQuery() + .match( + f"path = shortestPath((start)-[*1..{self._max_depth}]->(end))" + ) + .where("start.id IN $start_nodes") + .return_("path") + .with_parameters(start_nodes=self._start_nodes) + .with_description( + self._description + or f"Shortest path from {self._start_nodes}" + ) + ) + + raise ValueError( + f"Unsupported traversal pattern: {self._pattern!r}. " + "Supported: 'multi_hop', 'bfs', 'dfs', 'shortest_path'." + ) diff --git a/mellea_contribs/kg/embed_models.py b/mellea_contribs/kg/embed_models.py new file mode 100644 index 0000000..0d15b27 --- /dev/null +++ b/mellea_contribs/kg/embed_models.py @@ -0,0 +1,139 @@ +"""Pydantic models for KG embedding configuration. + +This module provides configuration models for the embedding pipeline. +""" + +from typing import Optional + +from pydantic import BaseModel, Field + + +class EmbeddingConfig(BaseModel): + """Configuration for embedding model and API. + + Specifies which embedding model to use and how to connect to it. + """ + + model: str = Field( + default="text-embedding-3-small", + description="Name of embedding model (LiteLLM compatible). " + "Examples: 'text-embedding-3-small', 'text-embedding-3-large', " + "'all-MiniLM-L6-v2' (HuggingFace), 'nomic-embed-text' (Ollama)", + ) + + api_base: Optional[str] = Field( + default=None, + description="API base URL for custom embedding service. " + "If None, uses default LiteLLM routing", + ) + + api_key: Optional[str] = Field( + default=None, + description="API key for embedding service (if required). " + "Can be set via environment variable", + ) + + dimension: int = Field( + default=1536, + description="Dimension of embedding vectors. " + "Default 1536 for OpenAI text-embedding-3-small. " + "384 for all-MiniLM-L6-v2, 768 for nomic-embed-text", + ) + + batch_size: int = Field( + default=10, + description="Number of entities to embed in parallel per batch", + ) + + +class EmbeddingResult(BaseModel): + """Result of embedding a single entity or document. + + Contains the original text and its vector representation. + """ + + text: str = Field(description="Original text that was embedded") + + embedding: list[float] = Field(description="Vector embedding") + + model: str = Field(description="Model name used for embedding") + + dimension: int = Field(description="Dimension of the embedding vector") + + +class EmbeddingSimilarity(BaseModel): + """Result of similarity search. + + Represents an entity matched by embedding similarity. + """ + + entity_id: str = Field(description="ID of the matched entity") + + entity_name: str = Field(description="Name of the matched entity") + + similarity_score: float = Field( + description="Similarity score (0-1, where 1 is most similar)" + ) + + entity_type: Optional[str] = Field( + default=None, description="Type/label of the entity" + ) + + +class EmbeddingStats(BaseModel): + """Statistics about embedding operations. + + Tracks performance and results of batch embedding, including optional + storage and index creation counts when used in a full pipeline. + """ + + total_entities: int = Field(description="Total entities processed") + + successful_embeddings: int = Field(description="Number of successfully embedded entities") + + failed_embeddings: int = Field(description="Number of entities that failed to embed") + + skipped_embeddings: int = Field(description="Number of entities skipped (e.g., already embedded)") + + average_embedding_time: float = Field( + description="Average time (seconds) to embed one entity" + ) + + total_time: float = Field(description="Total time (seconds) for the batch") + + model_used: str = Field(description="Embedding model used") + + # Pipeline-level fields (populated by embed_and_store_all) + dimension: int = Field(default=0, description="Embedding vector dimension") + + total_relations: int = Field(default=0, description="Total relations processed") + + successful_relation_embeddings: int = Field( + default=0, description="Number of successfully embedded relations" + ) + + failed_relation_embeddings: int = Field( + default=0, description="Number of relations that failed to embed" + ) + + entities_stored: int = Field(default=0, description="Embeddings stored back to the graph DB") + + relations_stored: int = Field( + default=0, description="Relation embeddings stored back to the graph DB" + ) + + vector_indices_created: int = Field( + default=0, description="Vector indices created for similarity search" + ) + + success: bool = Field(default=True, description="Whether the pipeline completed successfully") + + error_message: str = Field(default="", description="Error message if pipeline failed") + + +__all__ = [ + "EmbeddingConfig", + "EmbeddingResult", + "EmbeddingSimilarity", + "EmbeddingStats", +] diff --git a/mellea_contribs/kg/embedder.py b/mellea_contribs/kg/embedder.py new file mode 100644 index 0000000..45d64c5 --- /dev/null +++ b/mellea_contribs/kg/embedder.py @@ -0,0 +1,802 @@ +"""KG Embedder: Layer 1 application for generating vector embeddings for KG entities. + +This module provides embedding infrastructure for converting entities and relations +into vector representations using LiteLLM's embedding API. + +The architecture follows Mellea's Layer 1 pattern: +- Layer 1: KGEmbedder (this module) orchestrates embedding operations +- Layer 3: Can integrate with LLM session for embedding generation +- Layer 4: Uses GraphBackend for storing/retrieving entities + +Example:: + + import asyncio + from mellea import start_session + from mellea_contribs.kg import MockGraphBackend, Entity + from mellea_contribs.kg.embedder import KGEmbedder + + async def main(): + session = start_session(backend_name="litellm", model_id="gpt-3.5-turbo") + backend = MockGraphBackend() + embedder = KGEmbedder( + session=session, + embedding_model="text-embedding-3-small", + embedding_dimension=1536 + ) + + # Create an entity + entity = Entity( + type="Movie", + name="Avatar", + description="A science fiction film directed by James Cameron", + paragraph_start="Avatar is", + paragraph_end="by Cameron." + ) + + # Generate embedding + entity_with_embedding = await embedder.embed_entity(entity) + assert entity_with_embedding.embedding is not None + assert len(entity_with_embedding.embedding) == 1536 + + # Find similar entities + similar = await embedder.get_similar_entities( + entity_with_embedding, + [entity_with_embedding], # Search against the same entity for demo + top_k=1 + ) + assert len(similar) == 1 + + await backend.close() + + asyncio.run(main()) +""" + +import logging +import asyncio +import math +import time +from datetime import datetime +from typing import Optional + +from mellea import MelleaSession + +from mellea_contribs.kg.embed_models import EmbeddingStats +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.models import Entity, Relation + +logger = logging.getLogger(__name__) + + +class KGEmbedder: + """Generates and manages vector embeddings for KG entities and relations. + + This is a Layer 1 application that orchestrates embedding operations. + It uses LiteLLM's embedding API for generating vector representations. + + The class supports: + - Embedding individual entities + - Batch embedding of multiple entities + - Finding similar entities by embedding distance (cosine similarity) + - Persistence through GraphBackend + """ + + def __init__( + self, + session: MelleaSession, + model: str = "text-embedding-3-small", + dimension: int = 1536, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + extra_headers: Optional[dict] = None, + batch_size: int = 10, + backend: Optional[GraphBackend] = None, + ): + """Initialize the KG embedder using individual parameters (Mellea Layer 1 pattern). + + Matches the pattern used by KGRag and KGPreprocessor with individual + configuration parameters rather than a config object. + + Args: + session: MelleaSession for LLM operations (required) + model: Name of embedding model (LiteLLM compatible). + Default: "text-embedding-3-small" (OpenAI model) + dimension: Dimension of embedding vectors. + Default: 1536 (OpenAI's embedding size) + api_base: Optional API base URL for custom embedding service. + If None, uses default LiteLLM routing + api_key: Optional API key for embedding service. When api_base is + set and no key is provided, defaults to "dummy" so that + OpenAI-compatible endpoints that authenticate via custom + headers (e.g. RITS) do not raise an auth error. + extra_headers: Optional dict of extra HTTP headers forwarded to + the embedding endpoint (e.g. {"RITS_API_KEY": "..."}). + batch_size: Number of entities to embed in parallel per batch. + Default: 10 + backend: Optional GraphBackend for persisting embeddings + """ + + self.session = session + # When routing through a custom OpenAI-compatible endpoint (api_base + # is set) and the model name doesn't already carry a LiteLLM provider + # prefix (e.g. "openai/", "huggingface/", "azure/"), prepend "openai/" + # so LiteLLM routes to the correct adapter instead of raising + # "LLM Provider NOT provided". + _LITELLM_PROVIDERS = { + "openai", "azure", "huggingface", "ollama", "cohere", + "anthropic", "replicate", "together_ai", "vertex_ai", + } + if api_base and model.split("/")[0] not in _LITELLM_PROVIDERS: + model = f"openai/{model}" + self.embedding_model = model + self.embedding_dimension = dimension + self.api_base = api_base + # Default to "dummy" when a custom endpoint is in use — services like + # RITS authenticate via extra_headers rather than a bearer token. + self.api_key = api_key or ("dummy" if api_base else None) + self.extra_headers = extra_headers or {} + self.batch_size = batch_size + self.backend = backend + + async def embed_entity( + self, + entity: Entity, + use_name: bool = True, + use_description: bool = True, + ) -> Entity: + """Generate embedding for a single entity. + + Args: + entity: The Entity to embed + use_name: Include entity name in embedding text (default True) + use_description: Include entity description in embedding text (default True) + + Returns: + Entity with embedding field populated + """ + # Build text to embed + text_parts = [] + if use_name: + text_parts.append(f"Name: {entity.name}") + if use_description: + text_parts.append(f"Description: {entity.description}") + + embed_text = " ".join(text_parts) + logger.debug(f"Embedding entity: {entity.name}") + + # Generate embedding using LiteLLM + try: + embedding = await self._get_embedding(embed_text) + entity.embedding = embedding + logger.debug(f"Generated embedding for {entity.name} ({len(embedding)} dimensions)") + return entity + except Exception as e: + logger.error(f"Failed to embed entity {entity.name}: {e}") + raise + + async def embed_batch( + self, + entities: list[Entity], + use_name: bool = True, + use_description: bool = True, + batch_size: int = 10, + ) -> list[Entity]: + """Generate embeddings for multiple entities in parallel batches. + + Args: + entities: List of entities to embed + use_name: Include entity name in embedding text + use_description: Include entity description in embedding text + batch_size: Number of entities to embed in parallel per batch + + Returns: + List of entities with embeddings populated + """ + logger.info(f"Embedding batch of {len(entities)} entities") + embedded_entities = [] + + # Process in batches + for i in range(0, len(entities), batch_size): + batch = entities[i : i + batch_size] + logger.debug(f"Processing batch {i // batch_size + 1} ({len(batch)} entities)") + + # Embed each entity in batch (could be parallelized further) + for entity in batch: + try: + embedded = await self.embed_entity( + entity, + use_name=use_name, + use_description=use_description, + ) + embedded_entities.append(embedded) + except Exception as e: + logger.warning(f"Skipping entity {entity.name} due to embedding error: {e}") + # Add entity without embedding + embedded_entities.append(entity) + + logger.info(f"Embedded {len(embedded_entities)} entities") + return embedded_entities + + async def get_similar_entities( + self, + query_entity: Entity, + candidate_entities: list[Entity], + top_k: int = 5, + similarity_threshold: float = 0.0, + ) -> list[tuple[Entity, float]]: + """Find similar entities by embedding distance. + + Uses cosine similarity to find entities most similar to the query entity. + + Args: + query_entity: Entity to query (must have embedding) + candidate_entities: List of entities to search (must all have embeddings) + top_k: Number of top similar entities to return + similarity_threshold: Minimum similarity score (0-1) to include results + + Returns: + List of (Entity, similarity_score) tuples sorted by similarity (highest first) + + Raises: + ValueError: If query_entity or candidates don't have embeddings + """ + if query_entity.embedding is None: + raise ValueError("Query entity must have embedding") + + if not candidate_entities: + return [] + + # Compute similarity scores + similarities = [] + for candidate in candidate_entities: + if candidate.embedding is None: + logger.warning(f"Skipping candidate {candidate.name} (no embedding)") + continue + + similarity = self._cosine_similarity(query_entity.embedding, candidate.embedding) + + if similarity >= similarity_threshold: + similarities.append((candidate, similarity)) + + # Sort by similarity (highest first) and return top_k + similarities.sort(key=lambda x: x[1], reverse=True) + result = similarities[:top_k] + + logger.debug( + f"Found {len(result)} similar entities (threshold={similarity_threshold}, top_k={top_k})" + ) + return result + + @staticmethod + def _cosine_similarity(vec1: list[float], vec2: list[float]) -> float: + """Compute cosine similarity between two vectors. + + Args: + vec1: First vector + vec2: Second vector + + Returns: + Cosine similarity score between -1 and 1 (typically 0 to 1 for embeddings) + """ + if len(vec1) != len(vec2): + raise ValueError("Vectors must have same dimension") + + if not vec1 or not vec2: + return 0.0 + + # Compute dot product + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + + # Compute magnitudes + mag1 = math.sqrt(sum(a * a for a in vec1)) + mag2 = math.sqrt(sum(b * b for b in vec2)) + + if mag1 == 0 or mag2 == 0: + return 0.0 + + return dot_product / (mag1 * mag2) + + async def _get_embedding(self, text: str) -> list[float]: + """Get embedding for text using LiteLLM API. + + Args: + text: Text to embed + + Returns: + List of floats representing the embedding + + Raises: + Exception: If embedding API call fails + """ + try: + # Use LiteLLM's embedding API through litellm.embedding() + import litellm + + kwargs: dict = { + "model": self.embedding_model, + "input": text, + "encoding_format": "float", + } + if self.api_base: + kwargs["api_base"] = self.api_base + if self.api_key: + kwargs["api_key"] = self.api_key + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + response = await litellm.aembedding(**kwargs) + + # LiteLLM returns an EmbeddingResponse object with a .data list. + # Items may be Embedding objects (.embedding) or plain dicts. + if hasattr(response, "data"): + item = response.data[0] + return item["embedding"] if isinstance(item, dict) else item.embedding + # Fallback for plain dict responses + if isinstance(response, dict) and "data" in response: + return response["data"][0]["embedding"] + # Fallback for plain list responses + if isinstance(response, list): + return response[0]["embedding"] + return response + except Exception as e: + logger.error(f"Embedding API error: {e}") + raise + + async def _get_embeddings_batch(self, texts: list[str]) -> list[list[float]]: + """Embed a list of texts in a single API call. + + OpenAI-compatible endpoints accept a list as the ``input`` field, + which is far more efficient than one call per text. + + Args: + texts: List of strings to embed. + + Returns: + List of embedding vectors, one per input text. + + Raises: + Exception: If the embedding API call fails. + """ + import litellm + + kwargs: dict = { + "model": self.embedding_model, + "input": texts, + "encoding_format": "float", + } + if self.api_base: + kwargs["api_base"] = self.api_base + if self.api_key: + kwargs["api_key"] = self.api_key + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + + response = await litellm.aembedding(**kwargs) + + if hasattr(response, "data"): + items = response.data + return [ + item["embedding"] if isinstance(item, dict) else item.embedding + for item in items + ] + if isinstance(response, dict) and "data" in response: + return [item["embedding"] for item in response["data"]] + raise ValueError(f"Unexpected embedding response format: {type(response)}") + + # ------------------------------------------------------------------ + # Neo4j pipeline: fetch / store / index + # ------------------------------------------------------------------ + + async def fetch_entities_from_neo4j(self) -> list[Entity]: + """Fetch all entities from Neo4j as Entity objects. + + Returns: + List of Entity objects from the graph, or empty list for non-Neo4j backends. + """ + if not self.backend or getattr(self.backend, "backend_id", None) != "neo4j": + return [] + + cypher = """ + MATCH (n) + RETURN n.name AS name, labels(n)[0] AS type + LIMIT 100000 + """ + try: + driver = getattr(self.backend, "_async_driver", None) + if driver is None: + return [] + async with driver.session() as session: + result = await session.run(cypher) + records = [r async for r in result] + return [ + Entity( + type=r.get("type", "Unknown"), + name=r.get("name", ""), + description=f"Node of type {r.get('type')}", + ) + for r in records + ] + except Exception as exc: + logger.warning(f"Failed to fetch entities from Neo4j: {exc}") + return [] + + async def fetch_relations_from_neo4j(self) -> list[Relation]: + """Fetch all relations from Neo4j as Relation objects. + + Returns: + List of Relation objects from the graph, or empty list for non-Neo4j backends. + """ + if not self.backend or getattr(self.backend, "backend_id", None) != "neo4j": + return [] + + cypher = """ + MATCH ()-[r]->() + RETURN type(r) AS relation_type, id(r) AS rel_id + LIMIT 100000 + """ + try: + driver = getattr(self.backend, "_async_driver", None) + if driver is None: + return [] + async with driver.session() as session: + result = await session.run(cypher) + records = [r async for r in result] + return [ + Relation( + source_entity=f"Source_{r.get('rel_id', i)}", + relation_type=r.get("relation_type", "UNKNOWN"), + target_entity="Target", + description=f"Relation of type {r.get('relation_type', 'UNKNOWN')}", + ) + for i, r in enumerate(records) + ] + except Exception as exc: + logger.warning(f"Failed to fetch relations from Neo4j: {exc}") + return [] + + async def store_entity_embeddings( + self, entities: list[Entity], store_batch_size: int = 1000 + ) -> int: + """Store entity embeddings back to Neo4j. + + Args: + entities: Entities with populated ``embedding`` fields. + store_batch_size: Rows per Cypher transaction (default: 1000). + All chunks run within a single session to avoid connection + overhead while still providing progress visibility. + + Returns: + Number of embeddings stored, or 0 for non-Neo4j backends. + """ + if not self.backend or getattr(self.backend, "backend_id", None) != "neo4j": + return 0 + + driver = getattr(self.backend, "_async_driver", None) + if driver is None: + return 0 + + cypher = """ + UNWIND $batch AS item + MATCH (n {name: item.name}) + SET n.embedding = item.embedding + RETURN count(n) AS updated + """ + rows = [ + {"name": e.name, "embedding": getattr(e, "embedding", []) or []} + for e in entities + if getattr(e, "embedding", None) + ] + + try: + from tqdm import tqdm as _tqdm + except ImportError: + _tqdm = None + + try: + from tqdm import tqdm as _tqdm + except ImportError: + _tqdm = None + + chunks = [rows[i : i + store_batch_size] for i in range(0, len(rows), store_batch_size)] + total_stored = 0 + pbar = _tqdm(total=len(rows), desc="Storing entity embeddings", unit="ent") if _tqdm else None + try: + async with driver.session() as s: + for chunk in chunks: + try: + result = await s.run(cypher, batch=chunk) + record = await result.single() + total_stored += record.get("updated", 0) if record else 0 + except Exception as exc: + logger.warning(f"Failed to store entity embedding chunk: {exc}") + if pbar: + pbar.update(len(chunk)) + else: + logger.info(f" Stored {total_stored}/{len(rows)} entity embeddings…") + finally: + if pbar: + pbar.close() + return total_stored + + async def store_relation_embeddings( + self, relations: list[Relation], store_batch_size: int = 1000 + ) -> int: + """Store relation embeddings back to Neo4j. + + Args: + relations: Relations with populated ``embedding`` fields. + store_batch_size: Rows per Cypher transaction (default: 1000). + + Returns: + Number of embeddings stored, or 0 for non-Neo4j backends. + """ + if not self.backend or getattr(self.backend, "backend_id", None) != "neo4j": + return 0 + + driver = getattr(self.backend, "_async_driver", None) + if driver is None: + return 0 + + cypher = """ + UNWIND $batch AS item + MATCH ()-[r {type: item.relation_type}]->() + SET r.embedding = item.embedding + RETURN count(r) AS updated + """ + rows = [ + { + "relation_type": r.relation_type, + "embedding": getattr(r, "embedding", []) or [], + } + for r in relations + if getattr(r, "embedding", None) + ] + + try: + from tqdm import tqdm as _tqdm + except ImportError: + _tqdm = None + + try: + from tqdm import tqdm as _tqdm + except ImportError: + _tqdm = None + + chunks = [rows[i : i + store_batch_size] for i in range(0, len(rows), store_batch_size)] + total_stored = 0 + pbar = _tqdm(total=len(rows), desc="Storing relation embeddings", unit="rel") if _tqdm else None + try: + async with driver.session() as s: + for chunk in chunks: + try: + result = await s.run(cypher, batch=chunk) + record = await result.single() + total_stored += record.get("updated", 0) if record else 0 + except Exception as exc: + logger.warning(f"Failed to store relation embedding chunk: {exc}") + if pbar: + pbar.update(len(chunk)) + else: + logger.info(f" Stored {total_stored}/{len(rows)} relation embeddings…") + finally: + if pbar: + pbar.close() + return total_stored + + async def create_vector_indices(self) -> int: + """Create Neo4j vector indices for embedding similarity search. + + Creates one index for entity nodes and one for relationship embeddings + using the configured ``embedding_dimension``. + + Returns: + Number of indices created, or 0 for non-Neo4j backends. + """ + if not self.backend or getattr(self.backend, "backend_id", None) != "neo4j": + return 0 + + driver = getattr(self.backend, "_async_driver", None) + if driver is None: + return 0 + + indices_created = 0 + dim = self.embedding_dimension + index_queries = [ + f""" + CREATE VECTOR INDEX IF NOT EXISTS entity_embedding_index + FOR (n) ON (n.embedding) + OPTIONS {{ + indexConfig: {{ + `vector.dimensions`: {dim}, + `vector.similarity_function`: 'cosine' + }} + }} + """, + f""" + CREATE VECTOR INDEX IF NOT EXISTS relation_embedding_index + FOR (r: RELATIONSHIP) ON (r.embedding) + OPTIONS {{ + indexConfig: {{ + `vector.dimensions`: {dim}, + `vector.similarity_function`: 'cosine' + }} + }} + """, + ] + try: + async with driver.session() as session: + for query in index_queries: + try: + await session.run(query) + indices_created += 1 + except Exception as exc: + logger.debug(f"Vector index creation note: {exc}") + except Exception as exc: + logger.warning(f"Failed to create vector indices: {exc}") + + return indices_created + + async def embed_and_store_all(self, batch_size: int = 100) -> EmbeddingStats: + """Run the full embedding pipeline: fetch → embed → store → index. + + Fetches entities and relations from the graph backend, generates + embeddings for each, stores them back, and creates vector indices + for similarity search. For non-Neo4j backends the fetch/store/index + steps are no-ops so the method can still be used in tests with the + mock backend. + + Args: + batch_size: Number of items to log progress at each interval. + + Returns: + :class:`~mellea_contribs.kg.embed_models.EmbeddingStats` populated + with counts for entities, relations, storage, and index creation. + """ + t_start = time.monotonic() + + entities_embedded = entities_failed = entities_stored = 0 + relations_embedded = relations_failed = relations_stored = 0 + vector_indices = 0 + + try: + try: + from tqdm import tqdm as _tqdm + except ImportError: + _tqdm = None + + # --- entities --------------------------------------------------- + logger.info("Embedding pipeline: fetching entities…") + entities = await self.fetch_entities_from_neo4j() + logger.info(f" Fetched {len(entities)} entities") + + semaphore = asyncio.Semaphore(16) # max concurrent batch requests + + async def _embed_entity_batch(batch: list) -> tuple[int, int]: + """Embed one batch via a single API call; return (ok, failed).""" + texts = [] + for ent in batch: + parts = [] + if ent.name: + parts.append(f"Name: {ent.name}") + if ent.description: + parts.append(f"Description: {ent.description}") + texts.append(" ".join(parts) or ent.name or "") + async with semaphore: + try: + embeddings = await self._get_embeddings_batch(texts) + for ent, emb in zip(batch, embeddings): + ent.embedding = emb + return len(batch), 0 + except Exception as exc: + logger.error(f"Batch embed error: {exc}") + return 0, len(batch) + + entity_batches = [ + entities[i : i + batch_size] + for i in range(0, len(entities), batch_size) + ] + pbar = _tqdm(total=len(entities), desc="Embedding entities", unit="ent") if _tqdm else None + tasks = [_embed_entity_batch(b) for b in entity_batches] + for coro in asyncio.as_completed(tasks): + ok, failed = await coro + entities_embedded += ok + entities_failed += failed + if pbar: + pbar.update(ok + failed) + else: + done = entities_embedded + entities_failed + if done % (batch_size * 10) == 0: + logger.info(f" Embedded {done}/{len(entities)} entities…") + if pbar: + pbar.close() + + if entities_embedded: + entities_stored = await self.store_entity_embeddings(entities) + logger.info(f" Stored {entities_stored} entity embeddings") + + # --- relations -------------------------------------------------- + logger.info("Embedding pipeline: fetching relations…") + relations = await self.fetch_relations_from_neo4j() + logger.info(f" Fetched {len(relations)} relations") + + async def _embed_relation_batch(batch: list) -> tuple[int, int]: + texts = [f"Relation: {r.relation_type}" for r in batch] + async with semaphore: + try: + embeddings = await self._get_embeddings_batch(texts) + for rel, emb in zip(batch, embeddings): + rel.embedding = emb # type: ignore[attr-defined] + return len(batch), 0 + except Exception as exc: + logger.error(f"Relation batch embed error: {exc}") + return 0, len(batch) + + relation_batches = [ + relations[i : i + batch_size] + for i in range(0, len(relations), batch_size) + ] + pbar = _tqdm(total=len(relations), desc="Embedding relations", unit="rel") if _tqdm else None + tasks = [_embed_relation_batch(b) for b in relation_batches] + for coro in asyncio.as_completed(tasks): + ok, failed = await coro + relations_embedded += ok + relations_failed += failed + if pbar: + pbar.update(ok + failed) + else: + done = relations_embedded + relations_failed + if done % (batch_size * 10) == 0: + logger.info(f" Embedded {done}/{len(relations)} relations…") + if pbar: + pbar.close() + + if relations_embedded: + relations_stored = await self.store_relation_embeddings(relations) + logger.info(f" Stored {relations_stored} relation embeddings") + + # --- vector indices --------------------------------------------- + logger.info("Embedding pipeline: creating vector indices…") + vector_indices = await self.create_vector_indices() + logger.info(f" Created {vector_indices} vector indices") + + total_time = time.monotonic() - t_start + n_total = max(len(entities), 1) + return EmbeddingStats( + total_entities=len(entities), + successful_embeddings=entities_embedded, + failed_embeddings=entities_failed, + skipped_embeddings=0, + average_embedding_time=total_time / n_total, + total_time=total_time, + model_used=self.embedding_model, + dimension=self.embedding_dimension, + total_relations=len(relations), + successful_relation_embeddings=relations_embedded, + failed_relation_embeddings=relations_failed, + entities_stored=entities_stored, + relations_stored=relations_stored, + vector_indices_created=vector_indices, + success=True, + ) + + except Exception as exc: + total_time = time.monotonic() - t_start + logger.error(f"Embedding pipeline failed: {exc}") + return EmbeddingStats( + total_entities=0, + successful_embeddings=0, + failed_embeddings=0, + skipped_embeddings=0, + average_embedding_time=0.0, + total_time=total_time, + model_used=self.embedding_model, + dimension=self.embedding_dimension, + success=False, + error_message=str(exc), + ) + + async def close(self): + """Close connections and cleanup resources.""" + if self.backend: + await self.backend.close() + + +__all__ = ["KGEmbedder"] diff --git a/mellea_contribs/kg/graph_dbs/__init__.py b/mellea_contribs/kg/graph_dbs/__init__.py new file mode 100644 index 0000000..6740b97 --- /dev/null +++ b/mellea_contribs/kg/graph_dbs/__init__.py @@ -0,0 +1,17 @@ +"""Graph database backend implementations. + +Provides backend abstraction for different graph database systems. +""" + +from typing import TYPE_CHECKING + +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend + +try: + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend +except ImportError: + if not TYPE_CHECKING: + Neo4jBackend = None # type: ignore[assignment] + +__all__ = ["GraphBackend", "MockGraphBackend", "Neo4jBackend"] diff --git a/mellea_contribs/kg/graph_dbs/base.py b/mellea_contribs/kg/graph_dbs/base.py new file mode 100644 index 0000000..fdcd30a --- /dev/null +++ b/mellea_contribs/kg/graph_dbs/base.py @@ -0,0 +1,132 @@ +"""Abstract backend for graph databases. + +Provides a unified interface for executing graph queries across +different graph database systems (Neo4j, Neptune, RDF stores, etc.). +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from mellea_contribs.kg.components.query import GraphQuery + from mellea_contribs.kg.components.result import GraphResult + from mellea_contribs.kg.components.traversal import GraphTraversal + + +class GraphBackend(ABC): + """Abstract backend for graph databases. + + Following Mellea's Backend pattern: + - Takes backend_id (like model_id) + - Takes backend_options (like model_options) + - Abstract methods for core operations + """ + + def __init__( + self, + backend_id: str, + *, + connection_uri: str | None = None, + auth: tuple[str, str] | None = None, + database: str | None = None, + backend_options: dict | None = None, + ): + """Initialize graph backend. + + Following Mellea's Backend(model_id, model_options) pattern. + + Args: + backend_id: Identifier for backend type (e.g., "neo4j", "neptune") + connection_uri: URI for connecting to the database + auth: (username, password) tuple for authentication + database: Database name (if multi-database system) + backend_options: Backend-specific options + """ + # MELLEA PATTERN: Similar to Backend.__init__ + self.backend_id = backend_id + self.backend_options = backend_options if backend_options is not None else {} + + # Graph-specific fields + self.connection_uri = connection_uri + self.auth = auth + self.database = database + + @abstractmethod + async def execute_query( + self, query: "GraphQuery", **execution_options + ) -> "GraphResult": + """Execute a graph query and return results. + + Similar to Backend.generate_from_context() for LLMs. + Takes a Component (GraphQuery), returns a Component (GraphResult). + + Args: + query: The GraphQuery Component to execute + execution_options: Backend-specific execution options + + Returns: + GraphResult Component containing formatted results + """ + ... + + @abstractmethod + async def get_schema(self) -> dict[str, Any]: + """Get the graph schema. + + Returns: + Dictionary with node_types, edge_types, properties, etc. + """ + ... + + @abstractmethod + async def validate_query(self, query: "GraphQuery") -> tuple[bool, str | None]: + """Validate query syntax and semantics. + + Args: + query: The GraphQuery to validate + + Returns: + Tuple of (is_valid, error_message) + """ + ... + + def supports_query_type(self, query_type: str) -> bool: + """Check if this backend supports a query type (Cypher, SPARQL, etc.). + + Default implementation returns False. Subclasses should override. + + Args: + query_type: Query language type (e.g., "cypher", "sparql") + + Returns: + True if supported, False otherwise + """ + return False + + async def execute_traversal( + self, traversal: "GraphTraversal", **execution_options + ) -> "GraphResult": + """Execute a high-level traversal pattern. + + Default implementation converts to backend-specific query. + + Args: + traversal: The GraphTraversal pattern to execute + execution_options: Backend-specific execution options + + Returns: + GraphResult Component containing formatted results + """ + if self.supports_query_type("cypher"): + query = traversal.to_cypher() + return await self.execute_query(query, **execution_options) + else: + raise NotImplementedError( + f"Traversal not implemented for {self.__class__.__name__}" + ) + + async def close(self): + """Close connections to the graph database. + + Default implementation does nothing. Subclasses should override if needed. + """ diff --git a/mellea_contribs/kg/graph_dbs/mock.py b/mellea_contribs/kg/graph_dbs/mock.py new file mode 100644 index 0000000..e8d82bf --- /dev/null +++ b/mellea_contribs/kg/graph_dbs/mock.py @@ -0,0 +1,115 @@ +"""Mock backend for testing without a real graph database.""" + +from typing import TYPE_CHECKING, Any + +from mellea_contribs.kg.base import GraphEdge, GraphNode +from mellea_contribs.kg.graph_dbs.base import GraphBackend + +if TYPE_CHECKING: + from mellea_contribs.kg.components.query import GraphQuery + from mellea_contribs.kg.components.result import GraphResult + + +class MockGraphBackend(GraphBackend): + """Mock graph backend for testing. + + Returns predefined results without connecting to a real database. + """ + + def __init__( + self, + mock_nodes: list[GraphNode] | None = None, + mock_edges: list[GraphEdge] | None = None, + mock_schema: dict[str, Any] | None = None, + backend_options: dict | None = None, + ): + """Initialize mock backend. + + Args: + mock_nodes: Predefined nodes to return + mock_edges: Predefined edges to return + mock_schema: Predefined schema to return + backend_options: Additional options + """ + super().__init__( + backend_id="mock", + connection_uri="mock://localhost", + auth=None, + database=None, + backend_options=backend_options, + ) + + self.mock_nodes = mock_nodes or [] + self.mock_edges = mock_edges or [] + self.mock_schema = mock_schema or { + "node_types": ["MockNode"], + "edge_types": ["MOCK_EDGE"], + "property_keys": ["name", "value"], + } + self.query_history: list[tuple[str, dict]] = [] + + async def execute_query( + self, query: "GraphQuery", **execution_options + ) -> "GraphResult": + """Execute a mock query. + + Records the query and returns mock results. + + Args: + query: GraphQuery to execute + execution_options: Additional options + + Returns: + GraphResult with mock data + """ + # Import here to avoid circular dependency + from mellea_contribs.kg.components.result import GraphResult + + # Record query for testing + self.query_history.append((query.query_string or "", query.parameters)) + + # Return mock result + return GraphResult( + nodes=self.mock_nodes, + edges=self.mock_edges, + paths=[], + raw_result=None, + query=query, + format_style=execution_options.get("format_style", "triplets"), + ) + + async def get_schema(self) -> dict[str, Any]: + """Get mock schema. + + Returns: + Mock schema dictionary + """ + return self.mock_schema + + async def validate_query(self, query: "GraphQuery") -> tuple[bool, str | None]: + """Validate mock query. + + Always returns True for mock queries. + + Args: + query: GraphQuery to validate + + Returns: + Tuple of (True, None) + """ + return True, None + + def supports_query_type(self, query_type: str) -> bool: + """Mock backend supports all query types. + + Args: + query_type: Query language type + + Returns: + True for all types + """ + return True + + def clear_history(self): + """Clear query history.""" + self.query_history = [] diff --git a/mellea_contribs/kg/graph_dbs/neo4j.py b/mellea_contribs/kg/graph_dbs/neo4j.py new file mode 100644 index 0000000..175e853 --- /dev/null +++ b/mellea_contribs/kg/graph_dbs/neo4j.py @@ -0,0 +1,270 @@ +"""Neo4j implementation of GraphBackend.""" + +from typing import TYPE_CHECKING, Any + +from mellea_contribs.kg.base import GraphEdge, GraphNode, GraphPath +from mellea_contribs.kg.graph_dbs.base import GraphBackend + +if TYPE_CHECKING: + from mellea_contribs.kg.components.query import GraphQuery + from mellea_contribs.kg.components.result import GraphResult + +NEO4J_AVAILABLE = False +neo4j: Any = None +GraphDatabase: Any = None +AsyncGraphDatabase: Any = None + +try: + import neo4j as _neo4j_module # type: ignore[import-not-found] + from neo4j import ( # type: ignore[import-not-found] + AsyncGraphDatabase as _AsyncGraphDatabase, + GraphDatabase as _GraphDatabase, + ) + neo4j = _neo4j_module + GraphDatabase = _GraphDatabase + AsyncGraphDatabase = _AsyncGraphDatabase + NEO4J_AVAILABLE = True +except ImportError: + pass + + +class Neo4jBackend(GraphBackend): + """Neo4j implementation of GraphBackend. + + Implements the abstract GraphBackend interface for Neo4j databases. + """ + + def __init__( + self, + connection_uri: str = "bolt://localhost:7687", + auth: tuple[str, str] | None = None, + database: str | None = None, + backend_options: dict | None = None, + ): + """Initialize Neo4j backend. + + Args: + connection_uri: Neo4j connection URI + auth: (username, password) tuple + database: Database name (for multi-database) + backend_options: Neo4j-specific options + + Raises: + ImportError: If Neo4j driver is not installed + """ + if not NEO4J_AVAILABLE: + raise ImportError( + "Neo4j support requires additional dependencies. " + "Install with: pip install mellea-contribs[kg]" + ) + + # Call parent constructor following Mellea pattern + super().__init__( + backend_id="neo4j", + connection_uri=connection_uri, + auth=auth, + database=database, + backend_options=backend_options, + ) + + # Create Neo4j drivers + self._driver = GraphDatabase.driver(connection_uri, auth=auth) + self._async_driver = AsyncGraphDatabase.driver(connection_uri, auth=auth) + + async def execute_query( + self, query: "GraphQuery", **execution_options + ) -> "GraphResult": + """Execute a query in Neo4j. + + Takes a GraphQuery Component, executes it, returns GraphResult Component. + + Args: + query: GraphQuery Component to execute + execution_options: Additional options (format_style, etc.) + + Returns: + GraphResult Component with parsed results + + Raises: + ValueError: If query string is empty + """ + # Import here to avoid circular dependency + from mellea_contribs.kg.components.result import GraphResult + + # Get query string and parameters + query_string = query.query_string + parameters = query.parameters + + if not query_string: + raise ValueError("Query string is empty") + + # Execute query + async with self._async_driver.session(database=self.database) as session: + result = await session.run(query_string, parameters) + records = [record async for record in result] + + # Parse results into nodes, edges, paths + nodes, edges, paths = self._parse_neo4j_result(records) + + # Return GraphResult Component + return GraphResult( + nodes=nodes, + edges=edges, + paths=paths, + raw_result=records, + query=query, + format_style=execution_options.get("format_style", "triplets"), + ) + + def _parse_neo4j_result( + self, records: list + ) -> tuple[list[GraphNode], list[GraphEdge], list[GraphPath]]: + """Parse Neo4j records into GraphNode and GraphEdge objects. + + Args: + records: List of Neo4j records + + Returns: + Tuple of (nodes, edges, paths) + """ + nodes = [] + edges = [] + paths = [] + + node_cache = {} # Cache nodes by ID for edge creation + seen_node_ids = set() + seen_edge_ids = set() + + for record in records: + for key in record.keys(): + value = record[key] + + if isinstance(value, neo4j.graph.Node): + node = GraphNode.from_neo4j_node(value) + if node.id not in seen_node_ids: + node_cache[node.id] = node + nodes.append(node) + seen_node_ids.add(node.id) + + elif isinstance(value, neo4j.graph.Relationship): + # Get source and target nodes + # Neo4j relationships always have start_node and end_node + assert value.start_node is not None + assert value.end_node is not None + source_id = str(value.start_node.element_id) + target_id = str(value.end_node.element_id) + + # Get from cache or create + if source_id not in node_cache: + source = GraphNode.from_neo4j_node(value.start_node) + node_cache[source_id] = source + if source_id not in seen_node_ids: + nodes.append(source) + seen_node_ids.add(source_id) + + if target_id not in node_cache: + target = GraphNode.from_neo4j_node(value.end_node) + node_cache[target_id] = target + if target_id not in seen_node_ids: + nodes.append(target) + seen_node_ids.add(target_id) + + source = node_cache[source_id] + target = node_cache[target_id] + + edge = GraphEdge.from_neo4j_relationship(value, source, target) + edge_id = str(value.element_id) + if edge_id not in seen_edge_ids: + edges.append(edge) + seen_edge_ids.add(edge_id) + + elif isinstance(value, neo4j.graph.Path): + # Parse path + path = GraphPath.from_neo4j_path(value) + paths.append(path) + + # Also add nodes and edges to main lists if not seen + for node in path.nodes: + if node.id not in seen_node_ids: + nodes.append(node) + node_cache[node.id] = node + seen_node_ids.add(node.id) + + for edge in path.edges: + if edge.id not in seen_edge_ids: + edges.append(edge) + seen_edge_ids.add(edge.id) + + return nodes, edges, paths + + async def get_schema(self) -> dict[str, Any]: + """Get Neo4j schema. + + Queries for node labels, relationship types, and property keys. + + Returns: + Dictionary with node_types, edge_types, and properties + """ + # Get node labels + labels_query = "CALL db.labels() YIELD label RETURN collect(label) as labels" + async with self._async_driver.session(database=self.database) as session: + labels_result = await session.run(labels_query) + labels_record = await labels_result.single() + node_types = labels_record["labels"] if labels_record else [] + + # Get relationship types + types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) as types" + async with self._async_driver.session(database=self.database) as session: + types_result = await session.run(types_query) + types_record = await types_result.single() + edge_types = types_record["types"] if types_record else [] + + # Get property keys + props_query = "CALL db.propertyKeys() YIELD propertyKey RETURN collect(propertyKey) as keys" + async with self._async_driver.session(database=self.database) as session: + props_result = await session.run(props_query) + props_record = await props_result.single() + property_keys = props_record["keys"] if props_record else [] + + return { + "node_types": node_types, + "edge_types": edge_types, + "property_keys": property_keys, + } + + async def validate_query(self, query: "GraphQuery") -> tuple[bool, str | None]: + """Validate Cypher query syntax. + + Uses Neo4j's EXPLAIN to validate without executing. + + Args: + query: GraphQuery to validate + + Returns: + Tuple of (is_valid, error_message) + """ + try: + explain_query = f"EXPLAIN {query.query_string}" + async with self._async_driver.session(database=self.database) as session: + await session.run(explain_query, query.parameters) + return True, None + except neo4j.exceptions.CypherSyntaxError as e: + return False, str(e) + except Exception as e: + return False, f"Validation error: {e!s}" + + def supports_query_type(self, query_type: str) -> bool: + """Neo4j supports Cypher queries. + + Args: + query_type: Query language type + + Returns: + True if "cypher", False otherwise + """ + return query_type.lower() == "cypher" + + async def close(self): + """Close Neo4j connections.""" + await self._async_driver.close() + self._driver.close() diff --git a/mellea_contribs/kg/kgrag.py b/mellea_contribs/kg/kgrag.py new file mode 100644 index 0000000..0d8edb8 --- /dev/null +++ b/mellea_contribs/kg/kgrag.py @@ -0,0 +1,990 @@ +"""KGRag: Knowledge Graph Retrieval-Augmented Generation. + +Layer 1 application that orchestrates the full KG RAG pipeline: + +**QA Pipeline**: + 1. Break down question into solving routes (Layer 3 @generative) + 2. Extract topic entities from routes (Layer 3 @generative) + 3. Align entities with KG candidates (Layer 3 @generative) + 4. Prune relevant relations (Layer 3 @generative) + 5. Evaluate knowledge sufficiency (Layer 3 @generative) + 6. Generate answer or validate consensus (Layer 3 @generative) + +**Update Pipeline**: + 1. Extract entities and relations from document (Layer 3 @generative) + 2. Align extracted entities with KG (Layer 3 @generative) + 3. Decide entity merges (Layer 3 @generative) + 4. Align extracted relations with KG (Layer 3 @generative) + 5. Decide relation merges (Layer 3 @generative) + +Example:: + + import asyncio + from mellea import start_session + from mellea_contribs.kg import Neo4jBackend, KGRag + + async def main(): + session = start_session(backend_name="litellm", model_id="gpt-4o-mini") + backend = Neo4jBackend( + connection_uri="bolt://localhost:7687", + auth=("neo4j", "password"), + ) + rag = KGRag(backend=backend, session=session) + answer = await rag.answer("Who acted in The Matrix?") + print(answer) + await backend.close() + + asyncio.run(main()) +""" + +from typing import Any, Optional + +try: + from mellea import MelleaSession +except ImportError: + MelleaSession = None # type: ignore[assignment,misc] + +# Optional imports from mellea components (requires mellea to be installed) +try: + from mellea_contribs.kg.components import ( + align_entity_with_kg, + align_relation_with_kg, + align_topic_entities, + break_down_question, + decide_entity_merge, + decide_relation_merge, + evaluate_knowledge_sufficiency, + extract_entities_and_relations, + extract_topic_entities, + generate_direct_answer, + prune_relations, + prune_triplets, + validate_consensus, + ) + from mellea_contribs.kg.components.llm_guided import ( + explain_query_result, + natural_language_to_cypher, + suggest_query_improvement, + ) + from mellea_contribs.kg.components.query import CypherQuery + from mellea_contribs.kg.components.result import GraphResult +except ImportError: + # These are optional - mellea may not be installed + align_entity_with_kg = None + align_relation_with_kg = None + align_topic_entities = None + break_down_question = None + decide_entity_merge = None + decide_relation_merge = None + evaluate_knowledge_sufficiency = None + extract_entities_and_relations = None + extract_topic_entities = None + generate_direct_answer = None + prune_relations = None + prune_triplets = None + validate_consensus = None + explain_query_result = None + natural_language_to_cypher = None + suggest_query_improvement = None + CypherQuery = None + GraphResult = None + +from mellea_contribs.kg.graph_dbs.base import GraphBackend + +# Maximum Cypher repair attempts before giving up +_MAX_REPAIR_ATTEMPTS = 2 + + +def format_schema(schema: dict) -> str: + """Format a graph schema dictionary into a readable string for LLM prompts. + + Args: + schema: Dictionary with "node_types", "edge_types", and "property_keys" + keys (as returned by ``GraphBackend.get_schema()``). + + Returns: + A human-readable schema description. + """ + node_types = schema.get("node_types", []) + edge_types = schema.get("edge_types", []) + property_keys = schema.get("property_keys", []) + + lines = ["Graph Schema:"] + if node_types: + lines.append(f" Node labels: {', '.join(node_types)}") + if edge_types: + lines.append(f" Relationship types: {', '.join(edge_types)}") + if property_keys: + lines.append(f" Property keys: {', '.join(property_keys)}") + + return "\n".join(lines) + + +class KGRag: + """Knowledge Graph Retrieval-Augmented Generation pipeline. + + Combines a Mellea session (for LLM calls) with a graph backend (for query + execution) to answer natural language questions about a knowledge graph. + + The pipeline for each question: + + 1. **Schema retrieval** — fetch the current graph schema so the LLM knows + what node labels and relationship types exist. + 2. **Query generation** — ``natural_language_to_cypher`` converts the + question into a Cypher query via a ``@generative`` LLM call. + 3. **Validation & repair** — the generated Cypher is validated against the + database. If invalid, ``suggest_query_improvement`` is called (up to + ``max_repair_attempts`` times) to produce a corrected query. + 4. **Execution** — the validated query is executed against the backend. + 5. **Answer generation** — ``explain_query_result`` produces a natural + language answer grounded in the query results. + + Args: + backend: Graph database backend (Layer 4). + session: Active Mellea session wrapping an LLM backend. + format_style: How query results are formatted for the LLM + ("triplets", "natural", "paths", or "structured"). + max_repair_attempts: Maximum number of Cypher repair attempts before + the pipeline gives up and returns whatever was last generated. + """ + + def __init__( + self, + backend: GraphBackend, + session: MelleaSession, + format_style: str = "natural", + max_repair_attempts: int = _MAX_REPAIR_ATTEMPTS, + ): + """Initialize a KGRag pipeline. + + Args: + backend: Graph database backend. + session: Mellea session for LLM calls. + format_style: Result format style passed to GraphResult. + max_repair_attempts: Max Cypher repair attempts. + """ + self._backend = backend + self._session = session + self._format_style = format_style + self._max_repair_attempts = max_repair_attempts + + async def answer(self, question: str, examples: str = "") -> str: + """Answer a natural language question using the knowledge graph. + + Args: + question: A natural language question about the graph data. + examples: Optional few-shot Cypher examples to guide generation. + + Returns: + A natural language answer grounded in graph query results. + """ + # Step 1: Get graph schema + schema = await self._backend.get_schema() + schema_text = format_schema(schema) + + # Step 2: Generate Cypher query from natural language + generated = await natural_language_to_cypher( + self._session, + natural_language_query=question, + graph_schema=schema_text, + examples=examples, + ) + cypher_string = generated.query + + # Step 3: Validate and repair loop + cypher_string = await self._validate_and_repair( + cypher_string, schema_text + ) + + # Step 4: Execute validated query + query = CypherQuery(query_string=cypher_string, description=question) + graph_result = await self._backend.execute_query( + query, format_style=self._format_style + ) + + # Step 5: Generate natural language answer + result_component = GraphResult( + nodes=graph_result.nodes, + edges=graph_result.edges, + paths=graph_result.paths, + query=query, + format_style=self._format_style, + ) + result_text = result_component.format_for_llm().args["result"] + + answer = await explain_query_result( + self._session, + query=cypher_string, + result=result_text, + original_question=question, + ) + return answer + + async def _validate_and_repair( + self, cypher_string: str, schema_text: str + ) -> str: + """Validate Cypher syntax; repair via LLM if invalid. + + Attempts up to ``_max_repair_attempts`` repairs. Returns the last + generated string whether or not it passed validation, so the caller + always gets a best-effort answer. + + Args: + cypher_string: Cypher query to validate. + schema_text: Formatted schema text used when requesting repairs. + + Returns: + The validated (or best-effort repaired) Cypher string. + """ + for attempt in range(self._max_repair_attempts + 1): + query = CypherQuery(query_string=cypher_string) + is_valid, error = await self._backend.validate_query(query) + + if is_valid: + return cypher_string + + if attempt < self._max_repair_attempts: + improved = await suggest_query_improvement( + self._session, + query=cypher_string, + error_message=error or "Unknown syntax error", + schema=schema_text, + ) + cypher_string = improved.query + + # Return last attempt regardless (best-effort) + return cypher_string + + +# ============================================================================ +# Layer 1 - KG traversal helpers +# ============================================================================ + + +def _node_to_text(node: Any) -> str: + """Format a GraphNode as entity text for LLM prompts. + + Produces ``(Label: NAME, desc: "...", props: {...})`` used by the + @generative alignment / pruning functions. + + Args: + node: GraphNode instance. + + Returns: + Formatted entity string. + """ + name = str(node.properties.get("name", node.id)).strip().upper() + desc = node.properties.get("description", "") + _SKIP = {"name", "description", "embedding"} + props = { + k: v + for k, v in node.properties.items() + if k not in _SKIP and not k.startswith("_") + } + parts = [f"({node.label}: {name}"] + if desc: + parts.append(f', desc: "{str(desc).replace(chr(34), chr(39))}"') + if props: + prop_items = [f"{k}: {v}" for k, v in list(props.items())[:8]] + parts.append(f", props: {{{', '.join(prop_items)}}}") + parts.append(")") + return "".join(parts) + + +def _edge_to_triplet_text(edge: Any) -> str: + """Format a GraphEdge as a triplet text for LLM prompts. + + Produces ``(Src)-[REL, props: {...}]->(Tgt)`` format. + + Args: + edge: GraphEdge instance. + + Returns: + Formatted triplet string. + """ + src = _node_to_text(edge.source) + tgt = _node_to_text(edge.target) + _SKIP = {"embedding"} + props = { + k: v + for k, v in edge.properties.items() + if k not in _SKIP and not k.startswith("_") + } + if props: + prop_items = [f"{k}: {v}" for k, v in list(props.items())[:5]] + return f"{src}-[{edge.label}, props: {{{', '.join(prop_items)}}}]->{tgt}" + return f"{src}-[{edge.label}]->{tgt}" + + +async def _search_entities_by_name( + backend: GraphBackend, name: str, k: int = 4 +) -> list: + """Search KG entities by case-insensitive name containment. + + Args: + backend: Graph database backend. + name: Entity name fragment to search for. + k: Maximum number of results. + + Returns: + List of matching GraphNode objects. + """ + q = CypherQuery( + query_string=( + "MATCH (n) WHERE toLower(n.name) CONTAINS toLower($name) " + "RETURN n LIMIT $k" + ), + parameters={"name": name, "k": k}, + ) + try: + result = await backend.execute_query(q) + return result.nodes + except Exception: + return [] + + +async def _search_entities_by_embedding( + backend: GraphBackend, + embedding: list, + k: int = 10, + exclude_ids: Optional[set] = None, +) -> list: + """Search KG entities using a Neo4j vector index. + + Falls back gracefully to an empty list when no vector index exists. + + Args: + backend: Graph database backend. + embedding: Query embedding vector. + k: Maximum number of results to return. + exclude_ids: Node IDs to exclude from the returned list. + + Returns: + List of GraphNode objects ordered by similarity. + """ + exclude_ids = exclude_ids or set() + fetch_k = k + len(exclude_ids) + q = CypherQuery( + query_string=( + "CALL db.index.vector.queryNodes('entity_embedding', $k, $emb) " + "YIELD node RETURN node" + ), + parameters={"k": fetch_k, "emb": embedding}, + ) + try: + result = await backend.execute_query(q) + return [n for n in result.nodes if n.id not in exclude_ids][:k] + except Exception: + return [] + + +async def _get_unique_relation_types( + backend: GraphBackend, node_id: str, width: int = 30 +) -> list: + """Retrieve distinct ``(relation_type, target_label)`` pairs from a node. + + Args: + backend: Graph database backend. + node_id: Element ID of the source node. + width: Maximum number of distinct relation types to return. + + Returns: + List of ``(relation_type, target_label)`` tuples. + """ + q = CypherQuery( + query_string=( + "MATCH (n)-[r]->(m) WHERE elementId(n) = $nid " + "RETURN DISTINCT type(r) AS rel_type, labels(m)[0] AS tgt_type " + "LIMIT $w" + ), + parameters={"nid": node_id, "w": width}, + ) + try: + result = await backend.execute_query(q) + pairs: list = [] + # raw_result holds Neo4j records for non-node/edge RETURN clauses + if result.raw_result: + for record in result.raw_result: + try: + data = record.data() if hasattr(record, "data") else dict(record) + rt = data.get("rel_type") + tt = data.get("tgt_type") or "Unknown" + if rt: + pairs.append((str(rt), str(tt))) + except Exception: + continue + # Fallback: deduplicate from edges when the backend already parsed them + if not pairs and result.edges: + seen: set = set() + for edge in result.edges: + key = (edge.label, edge.target.label) + if key not in seen: + seen.add(key) + pairs.append(key) + return pairs + except Exception: + return [] + + +async def _get_triplets( + backend: GraphBackend, + node_id: str, + rel_type: str, + target_type: str = "Unknown", + k: int = 30, +) -> list: + """Retrieve full ``(source)-[rel]->(target)`` triplets from the KG. + + Args: + backend: Graph database backend. + node_id: Element ID of the source node. + rel_type: Relationship type to traverse. + target_type: Label of target nodes; ignored when ``"Unknown"``/``"None"``. + k: Maximum number of triplets to return. + + Returns: + List of GraphEdge objects (each carries source and target GraphNodes). + """ + # Sanitise identifiers to prevent Cypher injection + safe_rel = "".join(c for c in rel_type if c.isalnum() or c == "_") + safe_tgt = "".join(c for c in target_type if c.isalnum() or c == "_") + + if safe_tgt and safe_tgt not in ("Unknown", "None"): + cypher = ( + f"MATCH (n)-[r:{safe_rel}]->(m:{safe_tgt}) " + "WHERE elementId(n) = $nid RETURN n, r, m LIMIT $k" + ) + else: + cypher = ( + f"MATCH (n)-[r:{safe_rel}]->(m) " + "WHERE elementId(n) = $nid RETURN n, r, m LIMIT $k" + ) + q = CypherQuery(query_string=cypher, parameters={"nid": node_id, "k": k}) + try: + result = await backend.execute_query(q) + return result.edges + except Exception: + return [] + + +# ============================================================================ +# Layer 1 - QA Orchestration (Multi-Route Question Answering) +# ============================================================================ + + +async def orchestrate_qa_retrieval( + session: Any, + backend: GraphBackend, + query: str, + query_time: str = "", + domain: str = "general", + num_routes: int = 3, + hints: str = "", + eval_session: Optional[Any] = None, + emb_client: Optional[Any] = None, + width: int = 30, + depth: int = 3, +) -> str: + """Orchestrate multi-route QA via Think-on-Graph (ToG) algorithm. + + Implements the full Think-on-Graph pipeline: + + 1. Break the question into ``num_routes`` solving routes. + 2. In parallel, compute a direct LLM answer (``attempt``) and explore the + first two routes. + 3. After each new explored route (starting from route 2), call + ``validate_consensus`` to check whether answers agree. Stop early on + consensus. + 4. If consensus is never reached, return the direct answer as fallback. + + Each route exploration performs up to ``depth`` hops of graph traversal: + extract topic entities → align with KG → prune relations → retrieve + triplets → prune triplets → evaluate knowledge sufficiency. + + Args: + session: Mellea session for main LLM calls (question-decomposition, + entity alignment, relation/triplet pruning). + backend: Graph database backend used for all Cypher queries. + query: Natural language question to answer. + query_time: Temporal context string (e.g. ``"2024-03-05"``). + domain: Knowledge domain hint (e.g. ``"movie"``). + num_routes: Number of solving routes to generate and explore. + hints: Domain-specific text hints appended to prompts. + eval_session: Separate session for evaluation calls (knowledge + sufficiency, consensus, direct answer). Defaults to ``session``. + emb_client: Optional async OpenAI-compatible embedding client. When + provided, entity alignment also uses vector-index search. + width: Maximum entities / relations considered at each step. + depth: Maximum graph-traversal hops per route. + + Returns: + Natural language answer string. + """ + import asyncio + + _eval = eval_session or session + + # ------------------------------------------------------------------ + # Inner helpers + # ------------------------------------------------------------------ + + async def _embed(texts: list) -> list: + """Return embeddings via emb_client, or None placeholders.""" + if emb_client is None or not texts: + return [None] * len(texts) + try: + model_name = getattr(emb_client, "_model_name", "text-embedding-3-small") + response = await emb_client.embeddings.create( + input=texts, model=model_name + ) + return [item.embedding for item in response.data] + except Exception: + return [None] * len(texts) + + async def _align_topic(route: list, topic_name: str, topic_emb: Any, top_k: int = 45) -> list: + """Align one topic entity name with KG candidates. + + Returns list of ``(GraphNode, float)`` scored pairs. + """ + norm_coeff = 1.0 / max(1, len(route)) + + fuzzy_nodes = await _search_entities_by_name(backend, topic_name, k=4) + fuzzy_ids = {n.id for n in fuzzy_nodes} + + emb_nodes: list = [] + if topic_emb is not None and len(fuzzy_nodes) < top_k: + emb_nodes = await _search_entities_by_embedding( + backend, topic_emb, k=top_k - len(fuzzy_nodes), exclude_ids=fuzzy_ids + ) + + all_nodes = fuzzy_nodes + emb_nodes + if not all_nodes: + return [] + + entities_dict = {f"ent_{i}": n for i, n in enumerate(all_nodes)} + entities_str = "\n".join( + f"{k}: {_node_to_text(v)}" for k, v in entities_dict.items() + ) + + align_result = await align_topic_entities( + session, + query=query, + query_time=query_time, + route=route, + domain=domain, + top_k_entities_str=entities_str, + ) + + scored = [] + for key, score_str in align_result.relevant_entities.items(): + try: + score = float(score_str) + except (TypeError, ValueError): + score = 0.0 + if score > 0 and key in entities_dict: + scored.append((entities_dict[key], norm_coeff * score)) + return scored + + async def _explore_one_route(route: list) -> dict: + """Run the ToG traversal for one solving route. + + Returns dict with keys ``ans``, ``context``, ``route``. + """ + # Step A: Extract topic entities + topic_result = await extract_topic_entities( + session, + query=query, + query_time=query_time, + route=route, + domain=domain, + ) + raw_topics = topic_result.entities or [] + topic_names = [str(t).strip().upper() for t in raw_topics if t] + + if not topic_names: + return {"ans": "I don't know.", "context": "", "route": route} + + # Step B: Align each topic entity with KG + topic_embeddings = await _embed(topic_names) + align_tasks = [ + _align_topic(route, name, emb, top_k=min(45, width)) + for name, emb in zip(topic_names, topic_embeddings) + ] + align_results = await asyncio.gather(*align_tasks) + + # Aggregate scores across topics (sum for same node) + score_map: dict = {} + for scored_list in align_results: + for node, score in scored_list: + prev = score_map.get(node.id, (node, 0.0))[1] + score_map[node.id] = (node, prev + score) + + topic_entities_scores = list(score_map.values()) + initial_entities = [n for n, _ in topic_entities_scores] + + # Step C: Initial knowledge sufficiency check (before any traversal) + ent_str = ( + "\n".join(f"ent_{i}: {_node_to_text(n)}" for i, n in enumerate(initial_entities)) + or "None" + ) + eval_result = await evaluate_knowledge_sufficiency( + _eval, + query=query, + query_time=query_time, + route=route, + domain=domain, + entities=ent_str, + triplets="None", + hints=hints, + ) + if eval_result.sufficient.lower().strip() == "yes": + return { + "ans": eval_result.answer, + "context": f"Knowledge Entities:\n{ent_str}\nKnowledge Triplets:\nNone", + "route": route, + } + + # Step D: Multi-hop traversal + cluster_chain: list = [] + visited_edges: set = set() + + for _hop in range(depth): + triplet_scored: list = [] + + for node, entity_score in topic_entities_scores: + rel_types = await _get_unique_relation_types(backend, node.id, width=width) + if not rel_types: + continue + + entity_str = _node_to_text(node) + src_name = node.properties.get("name", node.id).strip().upper() + rels_str = "\n".join( + f"rel_{i}: ({node.label}: {src_name})-[{rt}]->({tt}: None)" + for i, (rt, tt) in enumerate(rel_types) + ) + + rel_prune = await prune_relations( + session, + query=query, + query_time=query_time, + route=route, + domain=domain, + entity_str=entity_str, + relations_str=rels_str, + width=width, + hints=hints, + ) + + for key, score_str in rel_prune.relevant_relations.items(): + try: + score = float(score_str) + idx = int(key.split("_")[1]) + except (TypeError, ValueError, IndexError): + continue + if score <= 0 or idx >= len(rel_types): + continue + rt, tt = rel_types[idx] + + triplets = await _get_triplets(backend, node.id, rt, tt, k=width) + triplets = [e for e in triplets if e.id not in visited_edges] + if not triplets: + continue + + trips_str = "\n".join( + f"rel_{j}: {_edge_to_triplet_text(e)}" + for j, e in enumerate(triplets[:width]) + ) + trip_prune = await prune_triplets( + session, + query=query, + query_time=query_time, + route=route, + domain=domain, + entity_str=entity_str, + relations_str=trips_str, + hints=hints, + ) + + for tkey, tscore_str in trip_prune.relevant_relations.items(): + try: + tscore = float(tscore_str) + tidx = int(tkey.split("_")[1]) + except (TypeError, ValueError, IndexError): + continue + if tscore > 0 and tidx < len(triplets): + triplet_scored.append( + (triplets[tidx], entity_score * score * tscore) + ) + + # Keep top-width triplets by score + triplet_scored.sort(key=lambda x: x[1], reverse=True) + triplet_scored = triplet_scored[:width] + + if not triplet_scored: + break + + chain_texts = [_edge_to_triplet_text(e) for e, _ in triplet_scored] + cluster_chain.append(chain_texts) + for edge, _ in triplet_scored: + visited_edges.add(edge.id) + + # Advance topic entities to triplet targets for next hop + next_scores: dict = {} + norm_sum = sum(s for _, s in triplet_scored) + norm = 1.0 / norm_sum if norm_sum > 0 else 1.0 + for edge, score in triplet_scored: + tgt = edge.target + prev = next_scores.get(tgt.id, (tgt, 0.0))[1] + next_scores[tgt.id] = (tgt, prev + score * norm) + topic_entities_scores = list(next_scores.values()) + + # Evaluate sufficiency with accumulated knowledge + ent_str2 = ( + "\n".join(f"ent_{i}: {_node_to_text(n)}" for i, n in enumerate(initial_entities)) + or "None" + ) + idx = 0 + trip_parts: list = [] + for chain in cluster_chain: + for t in chain: + trip_parts.append(f"rel_{idx}: {t}") + idx += 1 + triplets_str = "\n".join(trip_parts) or "None" + + suf = await evaluate_knowledge_sufficiency( + _eval, + query=query, + query_time=query_time, + route=route, + domain=domain, + entities=ent_str2, + triplets=triplets_str, + hints=hints, + ) + if suf.sufficient.lower().strip() == "yes": + return { + "ans": suf.answer, + "context": ( + f"Knowledge Entities:\n{ent_str2}\n" + f"Knowledge Triplets:\n{triplets_str}" + ), + "route": route, + } + + # Depth exhausted — force a final answer from accumulated knowledge + ent_str_f = ( + "\n".join(f"ent_{i}: {_node_to_text(n)}" for i, n in enumerate(initial_entities)) + or "None" + ) + idx = 0 + trip_parts_f: list = [] + for chain in cluster_chain: + for t in chain: + trip_parts_f.append(f"rel_{idx}: {t}") + idx += 1 + trip_str_f = "\n".join(trip_parts_f) or "None" + + final_suf = await evaluate_knowledge_sufficiency( + _eval, + query=query, + query_time=query_time, + route=route, + domain=domain, + entities=ent_str_f, + triplets=trip_str_f, + hints=hints, + ) + return { + "ans": final_suf.answer, + "context": ( + f"Knowledge Entities:\n{ent_str_f}\n" + f"Knowledge Triplets:\n{trip_str_f}" + ), + "route": route, + } + + # ------------------------------------------------------------------ + # Main orchestration + # ------------------------------------------------------------------ + + # Break question into routes + routes_result = await break_down_question( + session, + query=query, + query_time=query_time, + domain=domain, + route=num_routes, + hints=hints, + ) + routes = routes_result.routes or [] + + if not routes: + direct = await generate_direct_answer( + _eval, query=query, query_time=query_time, domain=domain + ) + return direct.answer + + # Launch direct answer + first two routes in parallel + parallel_coros = [ + generate_direct_answer(_eval, query=query, query_time=query_time, domain=domain), + _explore_one_route(routes[0]), + ] + if len(routes) > 1: + parallel_coros.append(_explore_one_route(routes[1])) + + parallel_results = await asyncio.gather(*parallel_coros) + direct_result = parallel_results[0] + route_results = list(parallel_results[1:]) + attempt = f'"{direct_result.answer}". {direct_result.reason}' + + # Explore remaining routes, checking consensus after each + final = attempt + stop = False + + for i, route in enumerate(routes[2:], start=2): + route_results.append(await _explore_one_route(route)) + + n_total = len(routes) + n_explored = len(route_results) + n_remaining = n_total - n_explored + routes_info = ( + f"\nWe have identified {n_total} solving route(s) below, " + f"and have {n_remaining} unexplored solving route(s) left.:\n" + ) + for j, rr in enumerate(route_results): + route_label = routes[j] if j < len(routes) else [] + routes_info += ( + f"Route {j + 1}: {route_label}\n" + f"Reference: {rr['context']}\n" + f"Answer: {rr['ans']}\n\n" + ) + for j in range(n_explored, n_total): + routes_info += f"Route {j + 1}: {routes[j]}\n\n" + + val = await validate_consensus( + _eval, + query=query, + query_time=query_time, + domain=domain, + attempt=attempt, + routes_info=routes_info, + hints=hints, + ) + final = val.final_answer + stop = val.judgement.lower().strip().replace(" ", "") == "yes" + if stop: + break + + if not stop and len(route_results) >= 2: + # Final consensus check using all explored routes + n_total = len(routes) + routes_info = ( + f"\nWe have identified {n_total} solving route(s) below, " + "and have 0 unexplored solving route(s) left.:\n" + ) + for j, rr in enumerate(route_results): + route_label = routes[j] if j < len(routes) else [] + routes_info += ( + f"Route {j + 1}: {route_label}\n" + f"Reference: {rr['context']}\n" + f"Answer: {rr['ans']}\n\n" + ) + val = await validate_consensus( + _eval, + query=query, + query_time=query_time, + domain=domain, + attempt=attempt, + routes_info=routes_info, + hints=hints, + ) + final = val.final_answer + + return final + + +# ============================================================================ +# Layer 1 - Update Orchestration (Document-based KG Updating) +# ============================================================================ + + +async def orchestrate_kg_update( + session: MelleaSession, + backend: GraphBackend, + doc_text: str, + domain: str = "general", + hints: str = "", + entity_types: str = "", + relation_types: str = "", +) -> dict: + """Orchestrate KG update pipeline. + + This is the main Layer 1 entry point for updating a knowledge graph with + information extracted from documents. It extracts entities and relations, + aligns them with existing KG data, and decides on merges. + + Args: + session: Mellea session for LLM calls + backend: Graph database backend for queries and updates + doc_text: Document text to extract information from + domain: Domain-specific knowledge + hints: Domain-specific hints for the LLM + entity_types: Comma-separated list of valid entity types + relation_types: Comma-separated list of valid relation types + + Returns: + Dictionary with: + - extracted_entities: List of extracted entity objects + - extracted_relations: List of extracted relation objects + - aligned_entities: List of alignment results + - aligned_relations: List of alignment results + - update_summary: Summary of updates made to KG + """ + # Step 1: Extract entities and relations from document + extraction = await extract_entities_and_relations( + session, + doc_context=doc_text, + domain=domain, + hints=hints, + reference="", + entity_types=entity_types, + relation_types=relation_types, + ) + + # Step 2-3: Align entities with KG and decide merges + # (Simplified - full implementation would iterate through extracted entities) + + # Step 4-5: Align relations with KG and decide merges + # (Simplified - full implementation would iterate through extracted relations) + + return { + "extracted_entities": extraction.entities, + "extracted_relations": extraction.relations, + "aligned_entities": [], + "aligned_relations": [], + "update_summary": "Document processed and entities/relations extracted", + } + + +__all__ = [ + # Main Layer 1 orchestration functions + "KGRag", + "format_schema", + "orchestrate_qa_retrieval", + "orchestrate_kg_update", + # QA Generative functions (Layer 3) + "break_down_question", + "extract_topic_entities", + "align_topic_entities", + "prune_relations", + "prune_triplets", + "evaluate_knowledge_sufficiency", + "validate_consensus", + "generate_direct_answer", + # Update Generative functions (Layer 3) + "extract_entities_and_relations", + "align_entity_with_kg", + "decide_entity_merge", + "align_relation_with_kg", + "decide_relation_merge", +] diff --git a/mellea_contribs/kg/models.py b/mellea_contribs/kg/models.py new file mode 100644 index 0000000..b107c2c --- /dev/null +++ b/mellea_contribs/kg/models.py @@ -0,0 +1,200 @@ +"""Pydantic models for KG-RAG structured outputs.""" +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +# QA Models +class QuestionRoutes(BaseModel): + """Routes for breaking down a complex question into sub-objectives.""" + + reason: str = Field(description="Reasoning for the route ordering") + routes: List[List[str]] = Field( + description="List of solving routes, each containing sub-objectives" + ) + + +class TopicEntities(BaseModel): + """Extracted topic entities from a query.""" + + entities: List[str] = Field(description="List of extracted entity names") + + +class RelevantEntities(BaseModel): + """Relevant entities with their scores.""" + + reason: str = Field(description="Reasoning for entity relevance") + relevant_entities: Dict[str, float] = Field( + description="Mapping of entity index (e.g., 'ent_0') to relevance score" + ) + + +class RelevantRelations(BaseModel): + """Relevant relations with their scores.""" + + reason: str = Field(description="Reasoning for relation relevance") + relevant_relations: Dict[str, float] = Field( + description="Mapping of relation index (e.g., 'rel_0') to relevance score" + ) + + +class EvaluationResult(BaseModel): + """Evaluation result for whether knowledge is sufficient to answer.""" + + sufficient: str = Field(description="'Yes' or 'No' indicating if knowledge is sufficient") + reason: str = Field(description="Reasoning for the sufficiency judgment") + answer: str = Field(description="The answer if sufficient, 'I don't know' otherwise") + + +class ValidationResult(BaseModel): + """Validation result for consensus among multiple routes.""" + + judgement: str = Field( + description="'Yes' or 'No' for whether consensus is reached" + ) + final_answer: str = Field(description="The final answer with explanation") + + +class DirectAnswer(BaseModel): + """Direct answer without knowledge graph.""" + + sufficient: str = Field( + description="'Yes' or 'No' indicating if LLM knowledge is sufficient" + ) + reason: str = Field(description="Reasoning for the answer") + answer: str = Field(description="The answer or 'I don't know'") + + +# ── Unified Entity / Relation models ──────────────────────────────────────── +# Single class for both extracted and stored entities/relations. +# Storage fields (id, confidence, etc.) are optional: None until persisted. + + +class Entity(BaseModel): + """Entity in the KG (extracted from document or stored). + + Unified model for both extracted and stored entities. + Storage fields are None until the entity has been persisted. + + Fields: + Extraction context (always present): + type: Entity type (e.g., Person, Movie, Organization) + name: Entity name + description: Brief description + paragraph_start: First 5-30 chars of supporting paragraph + paragraph_end: Last 5-30 chars of supporting paragraph + properties: Additional domain-specific properties + + Storage fields (optional, None before persistence): + id: Stable KG node ID + confidence: Extraction confidence score 0-1 + embedding: Vector embedding for similarity search + """ + + # Extraction context + type: str = Field(description="Entity type (e.g., Person, Movie, Organization)") + name: str = Field(description="Entity name") + description: str = Field(description="Brief description of the entity") + paragraph_start: Optional[str] = Field( + default=None, description="First 5-30 chars of supporting paragraph" + ) + paragraph_end: Optional[str] = Field( + default=None, description="Last 5-30 chars of supporting paragraph" + ) + properties: Dict[str, Any] = Field( + default_factory=dict, description="Additional properties" + ) + + # Storage fields (optional, assigned on persistence) + id: Optional[str] = Field(default=None, description="Stable KG node ID") + confidence: float = Field(default=1.0, description="Extraction confidence 0-1") + embedding: Optional[List[float]] = Field( + default=None, description="Vector embedding for similarity search" + ) + + +class Relation(BaseModel): + """Relation in the KG (extracted from document or stored). + + Unified model for both extracted and stored relations. + Storage fields are None until the relation has been persisted and aligned. + + Fields: + Extraction context (always present): + source_entity: Source entity name + relation_type: Relation type (e.g., acted_in, directed) + target_entity: Target entity name + description: Description of the relation + paragraph_start: First 5-30 chars of supporting paragraph + paragraph_end: Last 5-30 chars of supporting paragraph + properties: Additional domain-specific properties + + Storage fields (optional, None before persistence): + id: Stable KG edge ID + source_entity_id: Resolved source node ID in KG + target_entity_id: Resolved target node ID in KG + valid_from: ISO date when relation became valid + valid_until: ISO date when relation ceased to be valid + """ + + # Extraction context + source_entity: str = Field(description="Source entity name") + relation_type: str = Field(description="Relation type (e.g., acted_in, directed)") + target_entity: str = Field(description="Target entity name") + description: str = Field(description="Description of the relation") + paragraph_start: Optional[str] = Field( + default=None, description="First 5-30 chars of supporting paragraph" + ) + paragraph_end: Optional[str] = Field( + default=None, description="Last 5-30 chars of supporting paragraph" + ) + properties: Dict[str, Any] = Field( + default_factory=dict, description="Additional properties" + ) + + # Storage fields (optional, assigned on persistence) + id: Optional[str] = Field(default=None, description="Stable KG edge ID") + source_entity_id: Optional[str] = Field( + default=None, description="Resolved source node ID in KG" + ) + target_entity_id: Optional[str] = Field( + default=None, description="Resolved target node ID in KG" + ) + valid_from: Optional[str] = Field( + default=None, description="ISO date when relation became valid" + ) + valid_until: Optional[str] = Field( + default=None, description="ISO date when relation ceased to be valid" + ) + embedding: Optional[List[float]] = Field( + default=None, description="Vector embedding for similarity search" + ) + + +class ExtractionResult(BaseModel): + """Result of entity and relation extraction.""" + + entities: List[Entity] = Field(description="List of extracted entities") + relations: List[Relation] = Field(description="List of extracted relations") + reasoning: str = Field(description="Reasoning for the extractions") + + +class AlignmentResult(BaseModel): + """Result of entity alignment with existing KG.""" + + aligned_entity_id: Optional[str] = Field( + description="ID of matched entity in KG, or None" + ) + confidence: float = Field(description="Confidence score 0-1 for the alignment") + reasoning: str = Field(description="Reasoning for the alignment decision") + + +class MergeDecision(BaseModel): + """Decision on whether to merge entities.""" + + should_merge: bool = Field(description="Whether entities should be merged") + reasoning: str = Field(description="Reasoning for the merge decision") + merged_properties: Dict[str, Any] = Field( + default_factory=dict, description="Properties of merged entity if merging" + ) diff --git a/mellea_contribs/kg/preprocessor.py b/mellea_contribs/kg/preprocessor.py new file mode 100644 index 0000000..3dda0af --- /dev/null +++ b/mellea_contribs/kg/preprocessor.py @@ -0,0 +1,229 @@ +"""KG Preprocessor: Layer 1 application for preprocessing raw data into KG entities/relations. + +This module provides generic preprocessing infrastructure for converting raw documents +into Knowledge Graph entities and relations using the Layer 2-3 extraction functions. + +The architecture follows Mellea's Layer 1 pattern: +- Layer 1: KGPreprocessor (this module) orchestrates the pipeline +- Layer 2-3: extract_entities_and_relations, align_entity_with_kg, etc. +- Layer 4: GraphBackend for persisting to Neo4j + +Example:: + + import asyncio + from mellea import start_session + from mellea_contribs.kg import MockGraphBackend + from mellea_contribs.kg.preprocessor import KGPreprocessor + + async def main(): + session = start_session(backend_name="litellm", model_id="gpt-4o-mini") + backend = MockGraphBackend() + processor = KGPreprocessor(backend=backend, session=session) + + # Process a document + doc = {"text": "Avatar was directed by James Cameron in 2009."} + result = await processor.process_document( + doc_text=doc["text"], + domain="movies", + doc_id="doc_1" + ) + print(f"Extracted {len(result.entities)} entities and {len(result.relations)} relations") + await backend.close() + + asyncio.run(main()) +""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional + +try: + from mellea import MelleaSession +except ImportError: + MelleaSession = None + +from mellea_contribs.kg.base import GraphEdge, GraphNode + +try: + from mellea_contribs.kg.components import ( + extract_entities_and_relations, + ) +except ImportError: + extract_entities_and_relations = None + +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.models import Entity, ExtractionResult, Relation + +logger = logging.getLogger(__name__) + + +class KGPreprocessor(ABC): + """Generic base class for preprocessing raw data into KG entities and relations. + + Orchestrates the Layer 2-3 extraction pipeline and handles entity/relation storage. + Subclasses should override get_hints() and optionally post_process_entities/relations(). + + This is a Layer 1 application that: + 1. Uses Layer 3 extract_entities_and_relations for LLM extraction + 2. Optionally calls Layer 3 alignment functions + 3. Uses Layer 4 GraphBackend for persistence + """ + + def __init__( + self, + backend: GraphBackend, + session: MelleaSession, + domain: str = "generic", + batch_size: int = 10, + ): + """Initialize the preprocessor. + + Args: + backend: GraphBackend instance (Layer 4) for storing entities/relations + session: MelleaSession for LLM operations + domain: Domain name (used in extraction hints) + batch_size: Number of documents to process in parallel + """ + self.backend = backend + self.session = session + self.domain = domain + self.batch_size = batch_size + + @abstractmethod + def get_hints(self) -> str: + """Get domain-specific hints for the LLM extraction. + + Should be overridden by subclasses to provide domain-specific guidance. + + Returns: + String with domain hints for LLM extraction + """ + pass + + async def process_document( + self, + doc_text: str, + doc_id: Optional[str] = None, + entity_types: str = "", + relation_types: str = "", + ) -> ExtractionResult: + """Process a single document to extract entities and relations. + + Uses Layer 3 extract_entities_and_relations function to call the LLM. + + Args: + doc_text: The document text to process + doc_id: Optional document ID for tracking + entity_types: Optional comma-separated list of entity types to extract + relation_types: Optional comma-separated list of relation types to extract + + Returns: + ExtractionResult with extracted entities and relations + """ + logger.info(f"Processing document {doc_id} with {len(doc_text)} chars") + + # Layer 3: Extract entities and relations using LLM + result = await extract_entities_and_relations( + doc_context=doc_text, + domain=self.domain, + hints=self.get_hints(), + reference=doc_id or "unknown", + entity_types=entity_types, + relation_types=relation_types, + ) + + # Post-process if needed (can be overridden by subclasses) + result = await self.post_process_extraction(result, doc_text) + + logger.info( + f"Extracted {len(result.entities)} entities and {len(result.relations)} relations" + ) + return result + + async def post_process_extraction( + self, result: ExtractionResult, doc_text: str + ) -> ExtractionResult: + """Post-process extracted entities and relations. + + Can be overridden by subclasses for domain-specific processing. + + Args: + result: The extraction result from LLM + doc_text: The original document text + + Returns: + Modified extraction result + """ + # Default: no post-processing + return result + + async def persist_extraction( + self, + result: ExtractionResult, + doc_id: str, + merge_strategy: str = "merge_if_similar", + ) -> dict[str, Any]: + """Persist extracted entities and relations to the KG. + + Converts Entity/Relation models to GraphNode/GraphEdge and stores them. + + Args: + result: ExtractionResult to persist + doc_id: Document ID for tracking provenance + merge_strategy: Strategy for handling existing entities ("merge_if_similar", "skip", "overwrite") + + Returns: + Dictionary with persisted node/edge IDs and statistics + """ + persisted = {"entities": {}, "relations": {}, "stats": {}} + + # Convert entities to GraphNodes and store + for i, entity in enumerate(result.entities): + # Create GraphNode from Entity + node = GraphNode( + id=f"{doc_id}_entity_{i}", + label=entity.type, + properties={ + "name": entity.name, + "description": entity.description, + "confidence": entity.confidence, + "source_doc": doc_id, + }, + ) + + # Add domain-specific properties if present + if entity.properties: + node.properties.update(entity.properties) + + # Store node ID for relation linking + persisted["entities"][i] = node.id + logger.debug(f"Persisted entity: {node.id}") + + # Convert relations to GraphEdges and store + for i, relation in enumerate(result.relations): + # For now, just store relation data + # In a full implementation, would link to actual entity IDs + edge_data = { + "source_entity": relation.source_entity, + "relation_type": relation.relation_type, + "target_entity": relation.target_entity, + "description": relation.description, + "source_doc": doc_id, + } + + if relation.properties: + edge_data.update(relation.properties) + + persisted["relations"][i] = edge_data + logger.debug(f"Persisted relation: {relation.relation_type}") + + persisted["stats"] = { + "entities_count": len(result.entities), + "relations_count": len(result.relations), + } + + return persisted + + async def close(self): + """Close connections and cleanup resources.""" + await self.backend.close() diff --git a/mellea_contribs/kg/qa_models.py b/mellea_contribs/kg/qa_models.py new file mode 100644 index 0000000..2a7ac98 --- /dev/null +++ b/mellea_contribs/kg/qa_models.py @@ -0,0 +1,311 @@ +"""Configuration and result models for KG-RAG QA pipeline. + +This module provides Pydantic models for configuring and tracking QA operations, +integrating with the generative functions in components/generative.py and the +KGRag orchestrator in kgrag.py. + +The models are designed to be reused across different QA scenarios and to track +metrics and configurations for reproducibility. +""" + +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from mellea_contribs.kg.models import ( + DirectAnswer, + EvaluationResult, + QuestionRoutes, + RelevantEntities, + RelevantRelations, + TopicEntities, + ValidationResult, +) + + +class QAConfig(BaseModel): + """Configuration for QA pipeline parameters. + + Controls how the QA pipeline breaks down questions and searches the graph. + These parameters directly map to generative function inputs. + """ + + route_count: int = Field( + default=3, + description="Number of different solving routes to generate per question. " + "Higher values explore more possibilities but increase LLM calls.", + ) + + depth: int = Field( + default=2, + description="Maximum depth for relation traversal in graph search. " + "How many hops deep to search from initial entities.", + ) + + width: int = Field( + default=5, + description="Maximum number of entities to consider at each level. " + "Limits breadth of graph search to avoid combinatorial explosion.", + ) + + domain: Optional[str] = Field( + default=None, + description="Domain for the QA task (e.g., 'movies', 'medicine'). " + "Used for domain-specific extraction hints.", + ) + + consensus_threshold: float = Field( + default=0.7, + description="Confidence threshold for consensus validation (0.0-1.0). " + "Only use consensus if agreement above this threshold.", + ) + + format_style: str = Field( + default="detailed", + description="How to format query results for the LLM. " + "Options: 'detailed', 'concise', 'structured'.", + ) + + max_repair_attempts: int = Field( + default=2, + description="Maximum attempts to repair invalid Cypher queries. " + "Used in query validation and repair loop.", + ) + + +class QASessionConfig(BaseModel): + """Configuration for LLM and evaluation settings in QA session. + + Manages session-level settings for QA operations like model parameters + and evaluation criteria. + """ + + llm_model: str = Field( + default="gpt-4o-mini", + description="LLM model to use for QA (e.g., 'gpt-4o-mini', 'claude-3-sonnet'). " + "Should be a LiteLLM-compatible model ID.", + ) + + temperature: float = Field( + default=0.1, + description="Temperature for LLM generation (0.0-2.0). " + "Lower = more deterministic, higher = more creative.", + ) + + max_tokens: int = Field( + default=2048, + description="Maximum tokens for LLM responses.", + ) + + few_shot_examples: Optional[list[dict[str, Any]]] = Field( + default=None, + description="Few-shot examples for Cypher generation.", + ) + + eval_model: str = Field( + default="gpt-4o-mini", + description="Model for evaluation/validation.", + ) + + evaluation_threshold: float = Field( + default=0.5, + description="Confidence threshold for evaluation (0.0-1.0). " + "Answers below this are marked as uncertain.", + ) + + +class QADatasetConfig(BaseModel): + """Configuration for QA dataset and batch processing. + + Specifies dataset paths and batch processing parameters for running + QA on multiple questions. + """ + + dataset_path: str = Field( + default="data/qa_dataset.jsonl", + description="Path to dataset file (JSONL format). " + "Each line: {\"question\": \"...\", \"expected_answer\": \"...\", ...}", + ) + + batch_size: int = Field( + default=32, + description="Number of questions to process in parallel.", + ) + + output_path: str = Field( + default="output/qa_results.jsonl", + description="Path to save QA results (JSONL format).", + ) + + num_workers: int = Field( + default=4, + description="Number of parallel workers for batch processing.", + ) + + shuffle: bool = Field( + default=True, + description="Whether to shuffle dataset before processing.", + ) + + max_examples: Optional[int] = Field( + default=None, + description="Maximum number of examples to process. " + "If None, process entire dataset.", + ) + + skip_errors: bool = Field( + default=True, + description="Whether to continue processing on errors. " + "If False, halt on first error.", + ) + + +class QAResult(BaseModel): + """Result of a single QA query. + + Combines the input question, the generated answer, and intermediate results + for tracing and evaluation. + """ + + question: str = Field(description="The input question") + + answer: str = Field(description="The final answer generated by the system") + + confidence: float = Field( + default=0.5, + description="Confidence score for the answer (0.0-1.0)", ge=0.0, le=1.0 + ) + + # Intermediate results from QA pipeline + question_routes: Optional[QuestionRoutes] = Field( + default=None, description="Breakdown routes generated in step 1" + ) + + topic_entities: Optional[TopicEntities] = Field( + default=None, description="Topic entities extracted in step 2" + ) + + relevant_entities: Optional[RelevantEntities] = Field( + default=None, description="Relevant entities found in graph" + ) + + relevant_relations: Optional[RelevantRelations] = Field( + default=None, description="Relevant relations found in graph" + ) + + evaluation_result: Optional[EvaluationResult] = Field( + default=None, description="Knowledge sufficiency evaluation from step 4" + ) + + validation_result: Optional[ValidationResult] = Field( + default=None, description="Consensus validation result (if multi-route)" + ) + + direct_answer: Optional[DirectAnswer] = Field( + default=None, description="Direct answer from LLM (if knowledge insufficient)" + ) + + # Evidence and reasoning + cypher_query: Optional[str] = Field( + default=None, description="Cypher query used to search the graph" + ) + + graph_evidence: Optional[list[str]] = Field( + default=None, description="Graph data that supported the answer" + ) + + reasoning: str = Field( + default="", description="Detailed reasoning/explanation for the answer" + ) + + # Metadata + processing_time_ms: float = Field( + default=0.0, + description="Time taken to process the question (milliseconds)" + ) + + model_used: str = Field( + default="gpt-4o-mini", + description="LLM model used for this result" + ) + + route_used: Optional[int] = Field( + default=None, description="Index of route that was used (if multi-route)" + ) + + error: Optional[str] = Field( + default=None, description="Error message if processing failed" + ) + + +class QAStats(BaseModel): + """Statistics from batch QA processing. + + Aggregates metrics across multiple QA queries for performance analysis. + """ + + total_questions: int = Field( + default=0, description="Total questions processed" + ) + + successful_answers: int = Field( + default=0, description="Number of successful answers" + ) + + failed_answers: int = Field( + default=0, description="Number of failed answers" + ) + + average_confidence: float = Field( + default=0.0, description="Average confidence score" + ) + + average_processing_time_ms: float = Field( + default=0.0, description="Average processing time" + ) + + min_processing_time_ms: float = Field( + default=0.0, description="Minimum processing time" + ) + + max_processing_time_ms: float = Field( + default=0.0, description="Maximum processing time" + ) + + models_used: list[str] = Field( + default_factory=list, description="List of models used in processing" + ) + + # Evaluation metrics (if ground truth available) + exact_match_count: int = Field( + default=0, description="Number of exact match answers" + ) + + partial_match_count: int = Field( + default=0, description="Number of partial match answers" + ) + + no_match_count: int = Field( + default=0, description="Number of non-matching answers" + ) + + mean_reciprocal_rank: float = Field( + default=0.0, description="MRR metric if ranking-based evaluation" + ) + + total_time_ms: float = Field( + default=0.0, description="Total processing time (milliseconds)" + ) + + avg_time_per_question_ms: float = Field( + default=0.0, description="Average processing time per question (milliseconds)" + ) + + +__all__ = [ + "QAConfig", + "QASessionConfig", + "QADatasetConfig", + "QAResult", + "QAStats", +] diff --git a/mellea_contribs/kg/rep.py b/mellea_contribs/kg/rep.py new file mode 100644 index 0000000..9d422df --- /dev/null +++ b/mellea_contribs/kg/rep.py @@ -0,0 +1,223 @@ +"""Entity and relation representation utilities for KG operations. + +Provides utility functions for formatting entities and relations into +human-readable text for LLM prompts and display. +""" + +import re +from typing import Optional + +from mellea_contribs.kg.models import Entity, Relation + + +def normalize_entity_name(name: str) -> str: + """Normalize entity name to a canonical form. + + Handles: + - Converting to title case + - Removing extra whitespace + - Standardizing quotes and punctuation + + Args: + name: Raw entity name + + Returns: + Normalized entity name + """ + # Remove extra whitespace + normalized = " ".join(name.split()) + + # Convert to title case + normalized = normalized.title() + + # Standardize quotes + normalized = normalized.replace("'", "'").replace(""", '"').replace(""", '"') + + return normalized + + +def entity_to_text(entity: Entity, include_confidence: bool = False) -> str: + """Format entity into human-readable text for LLM prompts. + + Args: + entity: Entity to format + include_confidence: Whether to include confidence score + + Returns: + Formatted text representation of entity + """ + parts = [f"**{entity.name}** ({entity.type})"] + + if entity.description: + parts.append(f"Description: {entity.description}") + + if entity.properties: + prop_str = ", ".join( + f"{k}: {v}" for k, v in entity.properties.items() if v is not None + ) + if prop_str: + parts.append(f"Properties: {prop_str}") + + if include_confidence: + parts.append(f"Confidence: {entity.confidence:.2%}") + + return "\n".join(parts) + + +def relation_to_text(relation: Relation, include_confidence: bool = False) -> str: + """Format relation into human-readable text for LLM prompts. + + Args: + relation: Relation to format + include_confidence: Whether to include confidence score + + Returns: + Formatted text representation of relation + """ + parts = [ + f"**{relation.source_entity}** --[{relation.relation_type}]--> **{relation.target_entity}**" + ] + + if relation.description: + parts.append(f"Description: {relation.description}") + + if relation.properties: + prop_str = ", ".join( + f"{k}: {v}" for k, v in relation.properties.items() if v is not None + ) + if prop_str: + parts.append(f"Properties: {prop_str}") + + if include_confidence and hasattr(relation, "confidence"): + parts.append(f"Confidence: {relation.confidence:.2%}") + + return "\n".join(parts) + + +def format_entity_list( + entities: list[Entity], include_confidence: bool = False, max_items: Optional[int] = None +) -> str: + """Format list of entities into readable text. + + Args: + entities: List of entities to format + include_confidence: Whether to include confidence scores + max_items: Maximum number of entities to display (None for all) + + Returns: + Formatted text with all entities + """ + display_entities = entities[:max_items] if max_items else entities + + formatted = [] + for i, entity in enumerate(display_entities, 1): + formatted.append(f"{i}. {entity_to_text(entity, include_confidence)}") + + if max_items and len(entities) > max_items: + formatted.append(f"\n... and {len(entities) - max_items} more entities") + + return "\n\n".join(formatted) + + +def format_relation_list( + relations: list[Relation], + include_confidence: bool = False, + max_items: Optional[int] = None, +) -> str: + """Format list of relations into readable text. + + Args: + relations: List of relations to format + include_confidence: Whether to include confidence scores + max_items: Maximum number of relations to display (None for all) + + Returns: + Formatted text with all relations + """ + display_relations = relations[:max_items] if max_items else relations + + formatted = [] + for i, relation in enumerate(display_relations, 1): + formatted.append(f"{i}. {relation_to_text(relation, include_confidence)}") + + if max_items and len(relations) > max_items: + formatted.append(f"\n... and {len(relations) - max_items} more relations") + + return "\n\n".join(formatted) + + +def format_kg_context( + entities: list[Entity], + relations: list[Relation], + include_confidence: bool = False, + max_entities: Optional[int] = None, + max_relations: Optional[int] = None, +) -> str: + """Format knowledge graph context for LLM prompts. + + Combines entities and relations into a structured text representation. + + Args: + entities: List of entities from KG + relations: List of relations from KG + include_confidence: Whether to include confidence scores + max_entities: Maximum entities to display + max_relations: Maximum relations to display + + Returns: + Formatted KG context text + """ + sections = [] + + if entities: + sections.append("## Entities\n") + sections.append(format_entity_list(entities, include_confidence, max_entities)) + + if relations: + sections.append("\n## Relations\n") + sections.append(format_relation_list(relations, include_confidence, max_relations)) + + return "\n".join(sections) if sections else "(Empty knowledge graph)" + + +def camelcase_to_snake_case(name: str) -> str: + """Convert camelCase to snake_case. + + Args: + name: Name in camelCase + + Returns: + Name converted to snake_case + """ + # Insert underscore before uppercase letters + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def snake_case_to_camelcase(name: str, upper_first: bool = False) -> str: + """Convert snake_case to camelCase. + + Args: + name: Name in snake_case + upper_first: Whether to capitalize first letter (PascalCase) + + Returns: + Name converted to camelCase or PascalCase + """ + components = name.split("_") + if upper_first: + return "".join(x.title() for x in components) + else: + return components[0].lower() + "".join(x.title() for x in components[1:]) + + +__all__ = [ + "normalize_entity_name", + "entity_to_text", + "relation_to_text", + "format_entity_list", + "format_relation_list", + "format_kg_context", + "camelcase_to_snake_case", + "snake_case_to_camelcase", +] diff --git a/mellea_contribs/kg/requirements/__init__.py b/mellea_contribs/kg/requirements/__init__.py new file mode 100644 index 0000000..3acaf46 --- /dev/null +++ b/mellea_contribs/kg/requirements/__init__.py @@ -0,0 +1,98 @@ +"""Graph-specific requirements for query validation. + +Provides Requirement factories for validating graph queries against +syntax rules, schema constraints, and executability. +""" + +from mellea.stdlib.context import Context +from mellea.stdlib.requirements import Requirement, ValidationResult + +from mellea_contribs.kg.components.query import GraphQuery +from mellea_contribs.kg.graph_dbs.base import GraphBackend + + +def is_valid_cypher(backend: GraphBackend) -> Requirement: + """Require that the query is valid Cypher syntax. + + Uses the backend's validate_query() to check syntax without executing. + + Args: + backend: Graph backend used to validate the query. + + Returns: + A Requirement that passes when the query is syntactically valid. + """ + + async def validate(ctx: Context) -> ValidationResult: + query_string = ctx.last_output().value + query = GraphQuery(query_string=str(query_string)) + is_valid, error = await backend.validate_query(query) + return ValidationResult( + is_valid, + reason=error if not is_valid else "Valid Cypher syntax", + ) + + return Requirement( + description="Query must be valid Cypher syntax", + validation_fn=validate, + ) + + +def returns_results(backend: GraphBackend) -> Requirement: + """Require that the query returns at least one node or edge. + + Executes the query and checks for non-empty results. + + Args: + backend: Graph backend used to execute the query. + + Returns: + A Requirement that passes when the query produces results. + """ + + async def validate(ctx: Context) -> ValidationResult: + query_string = ctx.last_output().value + query = GraphQuery(query_string=str(query_string)) + result = await backend.execute_query(query) + has_results = len(result.nodes) > 0 or len(result.edges) > 0 + return ValidationResult( + has_results, + reason=( + "Query returned results" + if has_results + else "Query returned no results" + ), + ) + + return Requirement( + description="Query must return non-empty results", + validation_fn=validate, + ) + + +def respects_schema(backend: GraphBackend) -> Requirement: + """Require that the query only references node and edge types in the schema. + + Args: + backend: Graph backend used to retrieve the schema. + + Returns: + A Requirement that passes when the query respects the schema. + """ + + async def validate(ctx: Context) -> ValidationResult: + await backend.get_schema() + # Full Cypher AST parsing would be needed for strict enforcement. + # This is a placeholder that always passes once the schema is fetched. + return ValidationResult( + True, + reason="Query respects schema", + ) + + return Requirement( + description="Query must only reference valid schema types", + validation_fn=validate, + ) + + +__all__ = ["is_valid_cypher", "respects_schema", "returns_results"] diff --git a/mellea_contribs/kg/requirements_models.py b/mellea_contribs/kg/requirements_models.py new file mode 100644 index 0000000..6a1b3f9 --- /dev/null +++ b/mellea_contribs/kg/requirements_models.py @@ -0,0 +1,211 @@ +"""Entity and relation validation requirements for KG operations. + +Provides Requirement factories for validating entities and relations +against type constraints, schema rules, and data quality standards. + +Requires: Mellea >= 0.3.0 (for mellea.stdlib.context and mellea.stdlib.requirements) +""" + +from typing import Optional + +from mellea.stdlib.context import Context +from mellea.stdlib.requirements import Requirement, ValidationResult + +from mellea_contribs.kg.models import Entity, Relation + + +def entity_type_valid(allowed_types: list[str]) -> Requirement: + """Require that an entity type is in the allowed types list. + + Args: + allowed_types: List of valid entity type names (e.g., ['Movie', 'Person']) + + Returns: + A Requirement that passes when the entity's type is allowed. + """ + + async def validate(ctx: Context) -> ValidationResult: + try: + entity = ctx.last_output() + if isinstance(entity, Entity): + is_valid = entity.type in allowed_types + return ValidationResult( + is_valid, + reason=( + f"Entity type '{entity.type}' is valid" + if is_valid + else f"Entity type '{entity.type}' not in allowed types: {allowed_types}" + ), + ) + return ValidationResult(False, reason="Output is not an Entity") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + return Requirement( + description=f"Entity type must be one of: {allowed_types}", + validation_fn=validate, + ) + + +def entity_has_name(ctx: Context) -> ValidationResult: + """Require that an entity has a non-empty name. + + Returns a ValidationResult indicating if entity name is valid. + """ + try: + entity = ctx.last_output() + if isinstance(entity, Entity): + is_valid = bool(entity.name and entity.name.strip()) + return ValidationResult( + is_valid, + reason="Entity has valid name" if is_valid else "Entity name is empty", + ) + return ValidationResult(False, reason="Output is not an Entity") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + +def entity_has_description(ctx: Context) -> ValidationResult: + """Require that an entity has a non-empty description. + + Returns a ValidationResult indicating if entity description is valid. + """ + try: + entity = ctx.last_output() + if isinstance(entity, Entity): + is_valid = bool(entity.description and entity.description.strip()) + return ValidationResult( + is_valid, + reason=( + "Entity has valid description" + if is_valid + else "Entity description is empty" + ), + ) + return ValidationResult(False, reason="Output is not an Entity") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + +def relation_type_valid(allowed_types: list[str]) -> Requirement: + """Require that a relation type is in the allowed types list. + + Args: + allowed_types: List of valid relation type names (e.g., ['directed_by', 'acted_in']) + + Returns: + A Requirement that passes when the relation's type is allowed. + """ + + async def validate(ctx: Context) -> ValidationResult: + try: + relation = ctx.last_output() + if isinstance(relation, Relation): + is_valid = relation.relation_type in allowed_types + return ValidationResult( + is_valid, + reason=( + f"Relation type '{relation.relation_type}' is valid" + if is_valid + else f"Relation type '{relation.relation_type}' not in allowed types: {allowed_types}" + ), + ) + return ValidationResult(False, reason="Output is not a Relation") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + return Requirement( + description=f"Relation type must be one of: {allowed_types}", + validation_fn=validate, + ) + + +def relation_entities_exist(entities: list[str]) -> Requirement: + """Require that relation source and target entities exist in the provided list. + + Args: + entities: List of valid entity names that can be relation endpoints + + Returns: + A Requirement that passes when both entities are in the allowed list. + """ + + async def validate(ctx: Context) -> ValidationResult: + try: + relation = ctx.last_output() + if isinstance(relation, Relation): + source_valid = relation.source_entity in entities + target_valid = relation.target_entity in entities + is_valid = source_valid and target_valid + + reasons = [] + if not source_valid: + reasons.append( + f"Source entity '{relation.source_entity}' not found" + ) + if not target_valid: + reasons.append( + f"Target entity '{relation.target_entity}' not found" + ) + + return ValidationResult( + is_valid, + reason=( + "Both relation entities are valid" + if is_valid + else "; ".join(reasons) + ), + ) + return ValidationResult(False, reason="Output is not a Relation") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + return Requirement( + description=f"Relation entities must be in: {entities}", + validation_fn=validate, + ) + + +def entity_confidence_threshold( + min_confidence: float = 0.5, +) -> Requirement: + """Require that an entity meets a minimum confidence threshold. + + Args: + min_confidence: Minimum confidence score (0-1) + + Returns: + A Requirement that passes when entity confidence meets threshold. + """ + + async def validate(ctx: Context) -> ValidationResult: + try: + entity = ctx.last_output() + if isinstance(entity, Entity): + is_valid = entity.confidence >= min_confidence + return ValidationResult( + is_valid, + reason=( + f"Entity confidence {entity.confidence:.2f} meets threshold {min_confidence}" + if is_valid + else f"Entity confidence {entity.confidence:.2f} below threshold {min_confidence}" + ), + ) + return ValidationResult(False, reason="Output is not an Entity") + except Exception as e: + return ValidationResult(False, reason=f"Validation error: {str(e)}") + + return Requirement( + description=f"Entity confidence must be >= {min_confidence}", + validation_fn=validate, + ) + + +__all__ = [ + "entity_type_valid", + "entity_has_name", + "entity_has_description", + "relation_type_valid", + "relation_entities_exist", + "entity_confidence_threshold", +] diff --git a/mellea_contribs/kg/sampling/__init__.py b/mellea_contribs/kg/sampling/__init__.py new file mode 100644 index 0000000..ce73bcb --- /dev/null +++ b/mellea_contribs/kg/sampling/__init__.py @@ -0,0 +1,6 @@ +"""Query sampling strategies for knowledge graph operations.""" + +from mellea_contribs.kg.sampling.validation import QueryValidationStrategy + +__all__ = ["QueryValidationStrategy"] + diff --git a/mellea_contribs/kg/sampling/validation.py b/mellea_contribs/kg/sampling/validation.py new file mode 100644 index 0000000..f0c29b8 --- /dev/null +++ b/mellea_contribs/kg/sampling/validation.py @@ -0,0 +1,108 @@ +"""Query validation sampling strategy for knowledge graph queries. + +Implements the Instruct/Validate/Repair loop for LLM-guided query generation. +""" + +from mellea.stdlib.components import CBlock, Component, ModelOutputThunk +from mellea.stdlib.context import Context +from mellea.stdlib.requirements import Requirement, ValidationResult +from mellea.stdlib.sampling import BaseSamplingStrategy + +from mellea_contribs.kg.graph_dbs.base import GraphBackend + + +class QueryValidationStrategy(BaseSamplingStrategy): + """Sampling strategy for generating and validating graph queries. + + Uses the Instruct/Validate/Repair pattern: + 1. Generate a query from natural language. + 2. Validate syntax and executability. + 3. If invalid, repair using error feedback. + + Args: + backend: Graph backend used to validate queries. + loop_budget: Maximum number of repair attempts. + requirements: Query validation requirements to check. + """ + + def __init__( + self, + backend: GraphBackend, + loop_budget: int = 3, + requirements: list[Requirement] | None = None, + ): + """Initialize QueryValidationStrategy. + + Args: + backend: Graph backend for validation. + loop_budget: Max repair attempts. + requirements: Query validation requirements. + """ + super().__init__(loop_budget=loop_budget, requirements=requirements) + self._backend = backend + + @staticmethod + def repair( + old_ctx: Context, + new_ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> tuple[Component, Context]: + """Repair a failed query using validation error feedback. + + Constructs a new repair instruction from the last validation failures + and appends it to the context for the next generation attempt. + + Args: + old_ctx: Context before the last action and output. + new_ctx: Context including the last action and output. + past_actions: Actions executed so far without success. + past_results: Generation results for those actions. + past_val: Validation results for each attempt. + + Returns: + A tuple of (repair instruction component, updated context). + """ + last_validation = past_val[-1] + + error_messages = [ + result.reason + for _, result in last_validation + if not bool(result) and result.reason + ] + + failed_query = str(past_results[-1].value) if past_results else "" + + repair_text = ( + f"The previous query failed validation:\n" + f"Query: {failed_query}\n" + f"Errors: {', '.join(error_messages)}\n" + f"Please generate a corrected query." + ) + + return CBlock(repair_text), new_ctx + + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + """Select the best query when all attempts have failed. + + Picks the attempt with the fewest validation errors. + + Args: + sampled_actions: All actions attempted without success. + sampled_results: Generation results for those actions. + sampled_val: Validation results for each attempt. + + Returns: + Index of the attempt with the fewest validation errors. + """ + error_counts = [ + sum(1 for _, result in validation if not bool(result)) + for validation in sampled_val + ] + return error_counts.index(min(error_counts)) diff --git a/mellea_contribs/kg/updater_models.py b/mellea_contribs/kg/updater_models.py new file mode 100644 index 0000000..cf95ce4 --- /dev/null +++ b/mellea_contribs/kg/updater_models.py @@ -0,0 +1,408 @@ +"""Configuration and result models for KG update pipeline. + +This module provides Pydantic models for configuring KG update operations, +integrating with the extraction and alignment functions in components/generative.py +and the KGPreprocessor orchestrator. + +The models track configuration for extracting entities/relations from documents, +aligning them with existing KG entities, and updating the graph accordingly. +""" + +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class UpdateConfig(BaseModel): + """Configuration for KG update process parameters. + + Controls how entities and relations are extracted, aligned, and merged. + """ + + batch_size: int = Field( + default=32, + description="Number of documents to process in parallel.", + ) + + merge_strategy: str = Field( + default="merge_if_similar", + description="How to handle entity/relation conflicts: " + "'merge_if_similar' (default), 'skip', 'overwrite', 'create_variant'.", + ) + + similarity_threshold: float = Field( + default=0.8, + ge=0.0, + le=1.0, + description="Confidence threshold for merging entities/relations (0.0-1.0). " + "Only merge if alignment confidence > threshold.", + ) + + max_entities_per_doc: Optional[int] = Field( + default=None, + description="Maximum entities to extract per document. " + "If None, extract all found.", + ) + + max_relations_per_doc: Optional[int] = Field( + default=None, + description="Maximum relations to extract per document. " + "If None, extract all found.", + ) + + domain: str = Field( + default="generic", + description="Domain for entity/relation extraction (e.g., 'movies'). " + "Used for domain-specific extraction hints.", + ) + + entity_types: Optional[list[str]] = Field( + default=None, + description="List of entity types to extract. " + "If None, extract all types found.", + ) + + relation_types: Optional[list[str]] = Field( + default=None, + description="List of relation types to extract. " + "If None, extract all types found.", + ) + + skip_validation: bool = Field( + default=False, + description="Whether to skip KG schema validation. " + "If False, validate extracted entities/relations against schema.", + ) + + +class UpdateSessionConfig(BaseModel): + """Configuration for LLM and alignment settings in update session. + + Manages session-level settings for entity/relation extraction and alignment. + """ + + extraction_model: str = Field( + default="gpt-4o-mini", + description="LLM model to use for entity/relation extraction.", + ) + + extraction_temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for extraction LLM (0.0-2.0).", + ) + + alignment_model: str = Field( + default="gpt-4o-mini", + description="Model for entity alignment.", + ) + + alignment_temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for alignment LLM.", + ) + + merge_decision_model: str = Field( + default="gpt-4o-mini", + description="Model for merge decisions.", + ) + + merge_decision_temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for merge decision LLM.", + ) + + use_few_shot_examples: bool = Field( + default=True, + description="Whether to include few-shot examples in prompts.", + ) + + num_alignment_candidates: int = Field( + default=5, + description="Number of candidate entities to consider for alignment.", + ) + + +class UpdateStats(BaseModel): + """Statistics tracking for KG update process. + + Tracks metrics about entities/relations extracted, merged, and added. + """ + + total_documents: int = Field( + default=0, description="Total documents processed" + ) + + successful_documents: int = Field( + default=0, description="Documents processed successfully" + ) + + failed_documents: int = Field( + default=0, description="Documents that failed processing" + ) + + # Extraction statistics + entities_extracted: int = Field( + default=0, description="Total entities extracted from documents" + ) + + relations_extracted: int = Field( + default=0, description="Total relations extracted from documents" + ) + + # Alignment statistics + entities_aligned: int = Field( + default=0, description="Entities aligned with existing KG entities" + ) + + entities_new: int = Field( + default=0, description="New entities added to KG" + ) + + relations_aligned: int = Field( + default=0, description="Relations aligned with existing KG relations" + ) + + relations_new: int = Field( + default=0, description="New relations added to KG" + ) + + # Merge statistics + entities_merged: int = Field( + default=0, description="Entities merged with existing entities" + ) + + relations_merged: int = Field( + default=0, description="Relations merged with existing relations" + ) + + entities_skipped: int = Field( + default=0, description="Entities skipped due to merge conflicts" + ) + + relations_skipped: int = Field( + default=0, description="Relations skipped due to merge conflicts" + ) + + # Performance + average_processing_time_per_doc_ms: float = Field( + default=0.0, description="Average processing time per document (milliseconds)" + ) + + total_processing_time_ms: float = Field( + default=0.0, description="Total processing time" + ) + + +class MergeConflict(BaseModel): + """Record of a merge conflict during KG update. + + Tracks conflicts that occurred when trying to align/merge entities or relations. + """ + + source_id: str = Field(description="ID of source entity/relation") + + target_id: str = Field(description="ID of target entity/relation") + + conflict_type: str = Field( + description="Type of conflict: 'entity_merge', 'relation_merge', 'property_conflict'" + ) + + similarity_score: float = Field( + ge=0.0, le=1.0, description="Similarity score between entities/relations" + ) + + decision: str = Field( + description="How conflict was resolved: 'merged', 'skipped', 'variant_created'" + ) + + reason: str = Field(description="Reasoning for the decision") + + timestamp: Optional[str] = Field( + default=None, description="When conflict occurred (ISO format)" + ) + + +class UpdateResult(BaseModel): + """Result of updating KG with entities and relations from a document. + + Tracks what was added, merged, and skipped during update for a single document. + """ + + document_id: str = Field(description="ID of document being processed") + + success: bool = Field( + default=True, description="Whether update completed successfully" + ) + + # Extraction results + entities_found: int = Field( + default=0, description="Number of entities found in document" + ) + + relations_found: int = Field( + default=0, description="Number of relations found in document" + ) + + # What was added/merged + entities_added: int = Field( + default=0, description="Number of new entities added to KG" + ) + + entities_merged: int = Field( + default=0, description="Number of entities merged with existing KG entities" + ) + + entities_skipped: int = Field( + default=0, description="Number of entities skipped due to conflicts" + ) + + relations_added: int = Field( + default=0, description="Number of new relations added to KG" + ) + + relations_merged: int = Field( + default=0, description="Number of relations merged with existing KG relations" + ) + + relations_skipped: int = Field( + default=0, description="Number of relations skipped due to conflicts" + ) + + # Conflicts + conflicts: list[MergeConflict] = Field( + default_factory=list, description="List of merge conflicts encountered" + ) + + # Metadata + processing_time_ms: float = Field( + default=0.0, description="Time to process document (milliseconds)" + ) + + model_used: str = Field( + default="gpt-4o-mini", description="Extraction model used" + ) + + error: Optional[str] = Field(default=None, description="Error message if failed") + + warnings: list[str] = Field( + default_factory=list, description="Non-fatal warnings during processing" + ) + + +class UpdateBatchResult(BaseModel): + """Aggregated results from batch KG update. + + Combines results from updating multiple documents with statistics. + """ + + total_documents: int = Field( + default=0, description="Total documents processed" + ) + + successful_documents: int = Field( + default=0, description="Number of successful documents" + ) + + failed_documents: int = Field( + default=0, description="Number of failed documents" + ) + + results: list[UpdateResult] = Field( + default_factory=list, description="Per-document results" + ) + + stats: UpdateStats = Field( + default_factory=UpdateStats, description="Aggregated statistics" + ) + + total_time_ms: float = Field( + default=0.0, description="Total time for batch (milliseconds)" + ) + + avg_time_per_document_ms: float = Field( + default=0.0, description="Average time per document (milliseconds)" + ) + + start_time: Optional[str] = Field( + default=None, description="When batch processing started (ISO format)" + ) + + end_time: Optional[str] = Field( + default=None, description="When batch processing ended (ISO format)" + ) + + +class KGUpdateRunConfig(BaseModel): + """Aggregated run-time configuration for the KG update pipeline. + + Collects all settings needed by a KG update run script: session (LLM model), + updater (concurrency, loop budgets), dataset (path, domain, progress), and + graph backend (connection, mock flag). + + Typical usage:: + + config = KGUpdateRunConfig( + model="gpt-4o-mini", + num_workers=32, + domain="movie", + dataset_path="data/corpus.jsonl.bz2", + ) + """ + + # Session + model: str = Field(default="gpt-4o-mini", description="LLM model for extraction and alignment.") + + # Updater concurrency + num_workers: int = Field(default=64, description="Max concurrent async workers.") + queue_size: int = Field(default=64, description="Async queue capacity for document loading.") + + # Loop budgets (passed to orchestrate_kg_update) + extraction_loop_budget: int = Field( + default=3, description="Retry budget for entity/relation extraction." + ) + alignment_loop_budget: int = Field( + default=2, description="Retry budget for entity alignment." + ) + align_topk: int = Field( + default=10, description="Number of top candidates considered during alignment." + ) + align_entity: bool = Field(default=True, description="Whether to align entities.") + merge_entity: bool = Field(default=True, description="Whether to merge aligned entities.") + align_relation: bool = Field(default=True, description="Whether to align relations.") + merge_relation: bool = Field(default=True, description="Whether to merge aligned relations.") + + # Dataset + dataset_path: Optional[str] = Field(default=None, description="Path to input JSONL dataset.") + domain: str = Field(default="movie", description="Knowledge domain label.") + progress_path: str = Field( + default="results/update_kg_progress.json", + description="Path to the progress checkpoint file.", + ) + + # Graph backend + db_uri: str = Field( + default="bolt://localhost:7687", description="Graph database connection URI." + ) + db_user: str = Field(default="neo4j", description="Graph database username.") + db_password: str = Field(default="password", description="Graph database password.") + mock: bool = Field(default=False, description="Use in-memory mock backend instead of real DB.") + + # Misc + verbose: bool = Field(default=False, description="Enable verbose logging.") + + +__all__ = [ + "UpdateConfig", + "UpdateSessionConfig", + "UpdateStats", + "MergeConflict", + "UpdateResult", + "UpdateBatchResult", + "KGUpdateRunConfig", +] diff --git a/mellea_contribs/kg/utils/__init__.py b/mellea_contribs/kg/utils/__init__.py new file mode 100644 index 0000000..dc40c60 --- /dev/null +++ b/mellea_contribs/kg/utils/__init__.py @@ -0,0 +1,88 @@ +"""KG utility modules for JSONL I/O, session management, progress tracking, and evaluation. + +This package provides reusable utilities extracted from run scripts: +- data_utils: JSONL reading/writing, batch processing +- session_manager: Mellea session and backend creation +- progress: Logging, progress tracking, structured output +- eval_utils: Evaluation metrics and result aggregation +""" + +from .data_utils import ( + BaseDatasetLoader, + append_jsonl, + batch_iterator, + load_jsonl, + save_jsonl, + shuffle_jsonl, + stream_batch_process, + truncate_jsonl, + validate_jsonl_schema, +) +from .eval_utils import ( + aggregate_qa_results, + aggregate_update_results, + evaluate_predictions, + exact_match, + f1_score, + fuzzy_match, + mean_reciprocal_rank, + precision, + recall, +) +from .progress import ( + BaseProgressLogger, + ProgressTracker, + QAProgressLogger, + log_progress, + output_json, + print_stats, + setup_logging, +) +from .session_manager import ( + MelleaResourceManager, + create_backend, + create_embedding_client, + create_openai_session, + create_session, + create_session_from_env, + generate_embeddings, +) + +__all__ = [ + # data_utils + "load_jsonl", + "save_jsonl", + "append_jsonl", + "batch_iterator", + "stream_batch_process", + "truncate_jsonl", + "shuffle_jsonl", + "validate_jsonl_schema", + "BaseDatasetLoader", + # session_manager + "create_session", + "create_openai_session", + "create_session_from_env", + "create_backend", + "create_embedding_client", + "generate_embeddings", + "MelleaResourceManager", + # progress + "setup_logging", + "log_progress", + "output_json", + "print_stats", + "ProgressTracker", + "BaseProgressLogger", + "QAProgressLogger", + # eval_utils + "exact_match", + "fuzzy_match", + "mean_reciprocal_rank", + "precision", + "recall", + "f1_score", + "aggregate_qa_results", + "aggregate_update_results", + "evaluate_predictions", +] diff --git a/mellea_contribs/kg/utils/data_utils.py b/mellea_contribs/kg/utils/data_utils.py new file mode 100644 index 0000000..78e210d --- /dev/null +++ b/mellea_contribs/kg/utils/data_utils.py @@ -0,0 +1,338 @@ +"""JSONL and data processing utilities. + +Provides reusable functions for reading/writing JSONL files, batch processing, +and dataset manipulation. +""" + +import asyncio +import bz2 +import json +import random +import sys +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set + + +def load_jsonl(path: Path) -> Iterator[Dict[str, Any]]: + """Load JSONL file and yield each line as a dictionary. + + Supports both plain text and bz2-compressed JSONL files. + + Args: + path: Path to JSONL file (plain or .bz2). + + Yields: + Dictionary from each JSON line. + + Raises: + FileNotFoundError: If file does not exist. + json.JSONDecodeError: If line is not valid JSON. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if str(path).endswith('.bz2'): + f = bz2.open(path, "rt", encoding="utf-8") + else: + f = open(path, "r") + + try: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + yield json.loads(line) + except json.JSONDecodeError as e: + print(f"[Line {line_num}] JSON decode error: {e}", file=sys.stderr) + raise + finally: + f.close() + + +def save_jsonl(data: List[Dict[str, Any]], path: Path) -> None: + """Save list of dictionaries as JSONL file. + + Args: + data: List of dictionaries to save. + path: Path to output JSONL file. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w") as f: + for item in data: + f.write(json.dumps(item) + "\n") + + +def append_jsonl(item: Dict[str, Any], path: Path) -> None: + """Append a single dictionary to JSONL file. + + Args: + item: Dictionary to append. + path: Path to JSONL file. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "a") as f: + f.write(json.dumps(item) + "\n") + + +def batch_iterator(items: List[Any], batch_size: int) -> Iterator[List[Any]]: + """Iterate through items in batches. + + Args: + items: List of items to batch. + batch_size: Size of each batch. + + Yields: + Lists of items, each of size batch_size (last batch may be smaller). + """ + for i in range(0, len(items), batch_size): + yield items[i : i + batch_size] + + +def stream_batch_process( + input_path: Path, + output_path: Path, + process_fn: Callable, + batch_size: int = 1, +) -> int: + """Process JSONL file in batches and write results. + + Args: + input_path: Path to input JSONL file. + output_path: Path to output JSONL file. + process_fn: Function that takes a list of items and returns processed list. + batch_size: Number of items to process at once (default: 1). + + Returns: + Number of items processed. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + count = 0 + batch = [] + + with open(output_path, "w") as out_f: + try: + for item in load_jsonl(input_path): + batch.append(item) + count += 1 + + if len(batch) >= batch_size: + # Process batch + processed = process_fn(batch) + for result in processed: + out_f.write(json.dumps(result) + "\n") + batch = [] + + # Process remaining items + if batch: + processed = process_fn(batch) + for result in processed: + out_f.write(json.dumps(result) + "\n") + + except Exception as e: + print(f"Error during batch processing: {e}", file=sys.stderr) + raise + + return count + + +def truncate_jsonl( + input_path: Path, output_path: Path, max_lines: int +) -> int: + """Truncate JSONL file to specified number of lines. + + Args: + input_path: Path to input JSONL file. + output_path: Path to output truncated JSONL file. + max_lines: Maximum number of lines to keep. + + Returns: + Number of lines written. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + count = 0 + with open(output_path, "w") as out_f: + for item in load_jsonl(input_path): + if count >= max_lines: + break + out_f.write(json.dumps(item) + "\n") + count += 1 + + return count + + +def shuffle_jsonl(input_path: Path, output_path: Path) -> int: + """Shuffle JSONL file randomly. + + Args: + input_path: Path to input JSONL file. + output_path: Path to output shuffled JSONL file. + + Returns: + Number of lines written. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Load all items + items = list(load_jsonl(input_path)) + + # Shuffle + random.shuffle(items) + + # Write + with open(output_path, "w") as f: + for item in items: + f.write(json.dumps(item) + "\n") + + return len(items) + + +def validate_jsonl_schema( + path: Path, required_fields: List[str] +) -> tuple[bool, List[str]]: + """Validate that all items in JSONL have required fields. + + Args: + path: Path to JSONL file. + required_fields: List of field names that must be present. + + Returns: + Tuple of (is_valid, error_messages). + """ + errors = [] + + try: + for line_num, item in enumerate(load_jsonl(path), 1): + for field in required_fields: + if field not in item: + errors.append(f"Line {line_num}: Missing field '{field}'") + except Exception as e: + errors.append(f"Error validating file: {e}") + + return len(errors) == 0, errors + + +class BaseDatasetLoader(ABC): + """Abstract base class for async dataset loaders with worker-pool support. + + Subclasses implement :meth:`iter_items` to yield raw dataset records. + :meth:`run` feeds those records through a configurable number of async + workers, skipping IDs that appear in ``skip_ids``. + + Usage:: + + class MyLoader(BaseDatasetLoader): + def iter_items(self): + for item in load_jsonl(self.dataset_path): + yield item + + loader = MyLoader(dataset_path="data.jsonl", num_workers=4) + results = await loader.run( + process_fn=my_async_fn, + id_key="id", + skip_ids=already_done, + ) + """ + + def __init__( + self, + dataset_path: str, + num_workers: int = 1, + queue_size: int = 100, + ) -> None: + """Initialise the loader. + + Args: + dataset_path: Path to the dataset file. + num_workers: Number of parallel async workers (default: ``1``). + queue_size: Internal asyncio queue capacity (default: ``100``). + """ + self.dataset_path = dataset_path + self.num_workers = num_workers + self.queue_size = queue_size + + @abstractmethod + def iter_items(self) -> Generator[Dict[str, Any], None, None]: + """Yield raw dataset items one by one. + + Subclasses must implement this method. It is called synchronously + from the producer coroutine inside :meth:`run`. + + Yields: + Dict representing a single dataset record. + """ + + async def run( + self, + process_fn: Callable, + id_key: str = "id", + skip_ids: Optional[Set[str]] = None, + on_result: Optional[Callable] = None, + ) -> List[Any]: + """Process all items through an async worker pool. + + Args: + process_fn: ``async (item) -> result`` coroutine called for each + item. Should return *None* to discard results. + id_key: Key in each item dict used as the unique ID for + ``skip_ids`` matching (default: ``"id"``). + skip_ids: Set of item IDs to skip (for resumption). + on_result: Optional async or sync callback ``(item_id, result)`` + invoked after each item is processed successfully. + + Returns: + List of non-None results collected from ``process_fn``. + """ + skip_ids = skip_ids or set() + queue: asyncio.Queue = asyncio.Queue(maxsize=self.queue_size) + results: List[Any] = [] + _sentinel = object() + + async def _producer() -> None: + for item in self.iter_items(): + item_id = str(item.get(id_key, "")) + if item_id and item_id in skip_ids: + continue + await queue.put(item) + # Send one sentinel per worker to signal end-of-stream + for _ in range(self.num_workers): + await queue.put(_sentinel) + + async def _worker() -> None: + while True: + item = await queue.get() + if item is _sentinel: + queue.task_done() + break + try: + result = await process_fn(item) + if result is not None: + results.append(result) + if on_result is not None: + item_id = str(item.get(id_key, "")) + if asyncio.iscoroutinefunction(on_result): + await on_result(item_id, result) + else: + on_result(item_id, result) + except Exception as exc: + print( + f"Worker error on item {item.get(id_key, '?')}: {exc}", + file=sys.stderr, + ) + finally: + queue.task_done() + + workers = [asyncio.create_task(_worker()) for _ in range(self.num_workers)] + await asyncio.gather(_producer(), *workers) + return results diff --git a/mellea_contribs/kg/utils/eval_utils.py b/mellea_contribs/kg/utils/eval_utils.py new file mode 100644 index 0000000..29076c0 --- /dev/null +++ b/mellea_contribs/kg/utils/eval_utils.py @@ -0,0 +1,334 @@ +"""Evaluation and metrics utilities. + +Provides functions for computing evaluation metrics and aggregating results. +""" + +from typing import Any, Dict, List, Optional + +try: + from rapidfuzz import fuzz +except ImportError: + fuzz = None + +from mellea_contribs.kg.qa_models import QAResult, QAStats +from mellea_contribs.kg.updater_models import UpdateResult, UpdateStats + + +def exact_match(predicted: str, expected: str) -> bool: + """Check if predicted answer exactly matches expected answer (case-insensitive). + + Args: + predicted: Predicted answer string. + expected: Expected answer string. + + Returns: + True if answers match exactly (case-insensitive), False otherwise. + """ + return predicted.lower().strip() == expected.lower().strip() + + +def fuzzy_match(predicted: str, expected: str, threshold: float = 0.8) -> bool: + """Check if predicted answer fuzzy-matches expected answer. + + Uses rapidfuzz token_set_ratio if available, otherwise falls back to exact match. + + Args: + predicted: Predicted answer string. + expected: Expected answer string. + threshold: Similarity threshold (0-1, default: 0.8). + + Returns: + True if similarity score >= threshold, False otherwise. + """ + if fuzz is None: + return exact_match(predicted, expected) + + score = fuzz.token_set_ratio(predicted.lower(), expected.lower()) / 100.0 + return score >= threshold + + +def mean_reciprocal_rank(results: List[Dict[str, Any]]) -> float: + """Compute Mean Reciprocal Rank (MRR) for ranking results. + + Args: + results: List of result dictionaries with 'answer', 'expected', and optional 'confidence'. + + Returns: + MRR score (0-1). + """ + if not results: + return 0.0 + + reciprocal_ranks = [] + + for result in results: + # Check for exact match first + if exact_match(result.get("answer", ""), result.get("expected", "")): + reciprocal_ranks.append(1.0) + else: + # For non-exact matches, use confidence as proxy + confidence = result.get("confidence", 0.0) + if confidence >= 0.9: + reciprocal_ranks.append(1.0 / (1.0 + (1.0 - confidence))) + else: + reciprocal_ranks.append(0.0) + + if not reciprocal_ranks: + return 0.0 + + return sum(reciprocal_ranks) / len(reciprocal_ranks) + + +def precision(predicted: List[str], expected: List[str]) -> float: + """Compute precision metric (TP / (TP + FP)). + + Args: + predicted: List of predicted items. + expected: List of expected items. + + Returns: + Precision score (0-1). + """ + if not predicted: + return 0.0 + + tp = len(set(predicted) & set(expected)) + return tp / len(predicted) + + +def recall(predicted: List[str], expected: List[str]) -> float: + """Compute recall metric (TP / (TP + FN)). + + Args: + predicted: List of predicted items. + expected: List of expected items. + + Returns: + Recall score (0-1). + """ + if not expected: + return 0.0 + + tp = len(set(predicted) & set(expected)) + return tp / len(expected) + + +def f1_score(prec: float, rec: float) -> float: + """Compute F1 score (harmonic mean of precision and recall). + + Args: + prec: Precision score. + rec: Recall score. + + Returns: + F1 score (0-1). + """ + if prec + rec == 0: + return 0.0 + + return 2 * (prec * rec) / (prec + rec) + + +def aggregate_qa_results(qa_results: List[QAResult]) -> QAStats: + """Aggregate QA results into statistics. + + Args: + qa_results: List of QAResult objects. + + Returns: + QAStats object with aggregated statistics. + """ + stats = QAStats() + + if not qa_results: + return stats + + stats.total_questions = len(qa_results) + + # Count successful and failed + successful = 0 + failed = 0 + times = [] + confidences = [] + + for result in qa_results: + if result.error: + failed += 1 + else: + successful += 1 + + if result.processing_time_ms: + times.append(result.processing_time_ms) + + if result.confidence: + confidences.append(result.confidence) + + stats.successful_answers = successful + stats.failed_answers = failed + + # Compute timing stats + if times: + stats.average_processing_time_ms = sum(times) / len(times) + stats.min_processing_time_ms = min(times) + stats.max_processing_time_ms = max(times) + stats.total_time_ms = sum(times) + + # Compute confidence stats + if confidences: + stats.average_confidence = sum(confidences) / len(confidences) + + # Collect models used + models = set(r.model_used for r in qa_results if r.model_used) + stats.models_used = list(models) + + return stats + + +async def evaluate_predictions( + session: Any, + predictions: List[Dict[str, Any]], + query_key: str = "query", + answer_key: str = "answer", + gold_key: str = "answer_aliases", +) -> List[Dict[str, Any]]: + """Evaluate a list of QA predictions with LLM-based judgement. + + For each prediction the function checks whether the predicted answer + matches the gold answer using a combination of fast heuristics (exact + match, fuzzy match) and, for borderline cases, LLM judgement via a + ``@generative`` function. + + Args: + session: Mellea session used for LLM-based evaluation calls. + predictions: List of dicts, each containing at least ``query_key`` + and ``answer_key`` fields plus an optional ``gold_key`` list of + acceptable answers. + query_key: Dict key for the question text (default: ``"query"``). + answer_key: Dict key for the predicted answer (default: ``"answer"``). + gold_key: Dict key for the list of gold answers + (default: ``"answer_aliases"``). + + Returns: + Same list with an added ``"correct"`` (bool) and + ``"eval_method"`` (str) field on every item. + """ + try: + from mellea import generative + from pydantic import BaseModel + + class _EvalResult(BaseModel): + correct: bool + reason: str + + @generative + async def _llm_judge( + query: str, + predicted: str, + gold_answers: str, + ) -> _EvalResult: + """Judge whether a predicted answer is correct. + + Question: {query} + Predicted answer: {predicted} + Acceptable answers: {gold_answers} + + Respond with a JSON object: + {{"correct": true/false, "reason": "brief explanation"}} + """ + pass + + _generative_available = True + except Exception: + _generative_available = False + + results = [] + for item in predictions: + pred = str(item.get(answer_key, "")).strip() + golds = item.get(gold_key, []) + if isinstance(golds, str): + golds = [golds] + + correct = False + method = "none" + + # 1. Exact match + for gold in golds: + if exact_match(pred, str(gold)): + correct = True + method = "exact" + break + + # 2. Fuzzy match + if not correct: + for gold in golds: + if fuzzy_match(pred, str(gold)): + correct = True + method = "fuzzy" + break + + # 3. LLM judgement for uncertain cases + if not correct and _generative_available and golds: + try: + gold_str = " | ".join(str(g) for g in golds) + judge = await _llm_judge( + session, + query=str(item.get(query_key, "")), + predicted=pred, + gold_answers=gold_str, + ) + correct = judge.correct + method = "llm" + except Exception: + pass + + result = dict(item) + result["correct"] = correct + result["eval_method"] = method + results.append(result) + + return results + + +def aggregate_update_results(update_results: List[UpdateResult]) -> UpdateStats: + """Aggregate update results into statistics. + + Args: + update_results: List of UpdateResult objects. + + Returns: + UpdateStats object with aggregated statistics. + """ + stats = UpdateStats() + + if not update_results: + return stats + + stats.total_documents = len(update_results) + + # Count successful and failed + successful = 0 + failed = 0 + times = [] + + for result in update_results: + if result.success: + successful += 1 + stats.entities_extracted += result.entities_found + stats.relations_extracted += result.relations_found + stats.entities_new += result.entities_added + stats.relations_new += result.relations_added + else: + failed += 1 + + if result.processing_time_ms: + times.append(result.processing_time_ms) + + stats.successful_documents = successful + stats.failed_documents = failed + + # Compute timing stats + if times: + stats.total_processing_time_ms = sum(times) + stats.average_processing_time_per_doc_ms = sum(times) / len(update_results) + + return stats diff --git a/mellea_contribs/kg/utils/progress.py b/mellea_contribs/kg/utils/progress.py new file mode 100644 index 0000000..403ab9c --- /dev/null +++ b/mellea_contribs/kg/utils/progress.py @@ -0,0 +1,319 @@ +"""Progress tracking and logging utilities. + +Provides functions for logging, progress tracking, and structured output. +""" + +import json +import logging +import os +import sys +import time +from typing import Any, Dict, Optional, Set, Union + +from pydantic import BaseModel + +try: + from tqdm import tqdm +except ImportError: + tqdm = None # type: ignore + + +def setup_logging(log_level: str = "INFO", log_file: Optional[str] = None) -> None: + """Configure logging for the application. + + Args: + log_level: Logging level ("DEBUG", "INFO", "WARNING", "ERROR", default: "INFO"). + log_file: Optional file path to write logs to. + """ + level = getattr(logging, log_level.upper(), logging.INFO) + + # Create logger + logger = logging.getLogger("mellea_contribs.kg") + logger.setLevel(level) + + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Add console handler + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setLevel(level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # Add file handler if requested + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + +def log_progress(msg: str, level: str = "INFO") -> None: + """Log a progress message to stderr. + + Args: + msg: Message to log. + level: Logging level ("DEBUG", "INFO", "WARNING", "ERROR", default: "INFO"). + """ + logger = logging.getLogger("mellea_contribs.kg") + level_func = getattr(logger, level.lower(), logger.info) + level_func(msg) + + +def output_json(obj: BaseModel) -> None: + """Output a Pydantic model as JSON to stdout. + + Args: + obj: Pydantic model instance to output. + """ + print(json.dumps(obj.model_dump())) + + +def print_stats( + stats: BaseModel, indent: int = 0, to_stderr: bool = True +) -> None: + """Pretty-print statistics to stderr or stdout. + + Args: + stats: Statistics object (QAStats, UpdateStats, EmbeddingStats, etc.). + indent: Number of spaces to indent (default: 0). + to_stderr: Print to stderr if True, stdout if False (default: True). + """ + output = sys.stderr if to_stderr else sys.stdout + prefix = " " * indent + + # Get all fields from stats object + data = stats.model_dump() + + for key, value in data.items(): + # Format key (snake_case to Title Case) + display_key = key.replace("_", " ").title() + + # Format value + if isinstance(value, float): + display_value = f"{value:.2f}" + elif isinstance(value, list): + display_value = ", ".join(str(v) for v in value) + else: + display_value = str(value) + + print(f"{prefix}{display_key}: {display_value}", file=output) + + +class ProgressTracker: + """Progress tracker with optional tqdm integration. + + If tqdm is available, uses progress bar; otherwise prints text updates. + """ + + def __init__(self, total: int, desc: str = "Processing", use_tqdm: bool = True): + """Initialize progress tracker. + + Args: + total: Total number of items to process. + desc: Description of progress (default: "Processing"). + use_tqdm: Use tqdm if available (default: True). + """ + self.total = total + self.desc = desc + self.current = 0 + self.use_tqdm = use_tqdm and tqdm is not None + + if self.use_tqdm: + self.pbar = tqdm(total=total, desc=desc) + else: + self.pbar = None + + def update(self, n: int = 1) -> None: + """Update progress by n items. + + Args: + n: Number of items to add to progress (default: 1). + """ + self.current += n + + if self.use_tqdm and self.pbar: + self.pbar.update(n) + else: + # Print text update + percent = (self.current / self.total) * 100 + print( + f"{self.desc}: {self.current}/{self.total} ({percent:.1f}%)", + file=sys.stderr, + ) + + def close(self) -> None: + """Close the progress tracker.""" + if self.use_tqdm and self.pbar: + self.pbar.close() + + +class BaseProgressLogger: + """JSON-file-backed progress logger with resumption support. + + Persists a set of processed item IDs and arbitrary key-value metadata + to a JSON file so that long-running pipelines can resume after + interruption. + + Usage:: + + logger = BaseProgressLogger("progress.json") + logger.load() + for item in items: + if logger.is_processed(item["id"]): + continue + result = process(item) + logger.mark_processed(item["id"]) + logger.add_stat(result) + logger.save() + """ + + def __init__(self, progress_path: str) -> None: + """Initialise the logger. + + Args: + progress_path: Path to the JSON file used for persistence. + """ + self._path = progress_path + self._processed: Set[str] = set() + self._stats: list = [] + self._meta: Dict[str, Any] = {} + + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def load(self) -> None: + """Load existing progress from disk (no-op when file is absent).""" + if not os.path.exists(self._path): + return + try: + with open(self._path, "r", encoding="utf-8") as fh: + data = json.load(fh) + self._processed = set(data.get("processed", [])) + self._stats = data.get("stats", []) + self._meta = data.get("meta", {}) + except (json.JSONDecodeError, OSError): + pass + + def save(self, retries: int = 3) -> None: + """Save progress to disk. + + Args: + retries: Number of write attempts before giving up. + """ + os.makedirs(os.path.dirname(os.path.abspath(self._path)), exist_ok=True) + data = { + "processed": list(self._processed), + "stats": self._stats, + "meta": self._meta, + } + for attempt in range(retries): + try: + tmp = self._path + ".tmp" + with open(tmp, "w", encoding="utf-8") as fh: + json.dump(data, fh, ensure_ascii=False, default=str) + os.replace(tmp, self._path) + return + except OSError: + if attempt < retries - 1: + time.sleep(0.1) + + # ------------------------------------------------------------------ + # Tracking + # ------------------------------------------------------------------ + + def is_processed(self, item_id: str) -> bool: + """Check whether an item has already been processed. + + Args: + item_id: Unique identifier for the item. + + Returns: + True if the item is in the processed set. + """ + return item_id in self._processed + + def mark_processed(self, item_id: str) -> None: + """Mark an item as processed. + + Args: + item_id: Unique identifier for the item. + """ + self._processed.add(item_id) + + def add_stat(self, stat: Any) -> None: + """Append a result/stat entry. + + Args: + stat: Any JSON-serialisable value. + """ + if hasattr(stat, "model_dump"): + self._stats.append(stat.model_dump()) + else: + self._stats.append(stat) + + def update_meta(self, **kwargs: Any) -> None: + """Update key-value metadata entries. + + Args: + **kwargs: Key-value pairs to store in the metadata dict. + """ + self._meta.update(kwargs) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def processed_ids(self) -> Set[str]: + """Set of all processed item IDs.""" + return set(self._processed) + + @property + def stats(self) -> list: + """List of collected stat entries.""" + return list(self._stats) + + @property + def meta(self) -> Dict[str, Any]: + """Metadata dictionary.""" + return dict(self._meta) + + @property + def num_processed(self) -> int: + """Number of items marked as processed.""" + return len(self._processed) + + +class QAProgressLogger(BaseProgressLogger): + """Progress logger specialised for QA pipeline runs. + + Stores per-question results alongside a processed-question ID set so + that the run can resume mid-dataset without repeating work. + + The progress file uses the format:: + + { + "processed": ["q_0", "q_3", ...], + "stats": [{"query": "...", "answer": "...", ...}, ...], + "meta": {"total": 100, "last_updated": "..."} + } + """ + + def add_result(self, query_id: str, result: Any) -> None: + """Record a QA result and mark the query as processed. + + Args: + query_id: Unique identifier for the query (used for resumption). + result: QAResult, dict, or any JSON-serialisable value. + """ + self.add_stat(result) + self.mark_processed(query_id) + + @property + def processed_queries(self) -> Set[str]: + """Set of query IDs that have been answered.""" + return self.processed_ids diff --git a/mellea_contribs/kg/utils/session_manager.py b/mellea_contribs/kg/utils/session_manager.py new file mode 100644 index 0000000..4f9c500 --- /dev/null +++ b/mellea_contribs/kg/utils/session_manager.py @@ -0,0 +1,400 @@ +"""Session and backend management utilities. + +Provides factory functions for creating Mellea sessions and graph backends. +""" + +import sys +from typing import Optional + +try: + from mellea import start_session, MelleaSession +except ImportError: + MelleaSession = None # type: ignore + +from mellea_contribs.kg.graph_dbs.base import GraphBackend +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend + +try: + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend +except ImportError: + Neo4jBackend = None + + +def create_session( + backend_name: str = "litellm", + model_id: str = "gpt-4o-mini", + temperature: float = 0.7, + api_base: Optional[str] = None, + api_key: Optional[str] = None, +) -> "MelleaSession": + """Create a Mellea session. + + Args: + backend_name: Backend name (default: "litellm"). + model_id: Model ID to use (default: "gpt-4o-mini"). + temperature: Temperature for generation (default: 0.7). + api_base: Optional API base URL. + api_key: Optional API key. + + Returns: + MelleaSession object. + + Raises: + ImportError: If mellea is not installed. + """ + if MelleaSession is None: + print("Error: mellea not installed. Run: pip install mellea[litellm]") + sys.exit(1) + + return start_session(backend_name=backend_name, model_id=model_id) + + +def create_backend( + backend_type: str = "mock", + neo4j_uri: Optional[str] = None, + neo4j_user: Optional[str] = None, + neo4j_password: Optional[str] = None, +) -> GraphBackend: + """Create a graph backend. + + Args: + backend_type: Type of backend ("mock" or "neo4j", default: "mock"). + neo4j_uri: Neo4j connection URI (default: "bolt://localhost:7687"). + neo4j_user: Neo4j username (default: "neo4j"). + neo4j_password: Neo4j password (default: "password"). + + Returns: + GraphBackend instance. + + Raises: + SystemExit: If Neo4j backend requested but not available. + """ + if backend_type == "mock": + return MockGraphBackend() + + if backend_type == "neo4j": + if Neo4jBackend is None: + print( + "Error: Neo4j backend not available. " + "Install: pip install mellea-contribs[kg]" + ) + sys.exit(1) + + neo4j_uri = neo4j_uri or "bolt://localhost:7687" + neo4j_user = neo4j_user or "neo4j" + neo4j_password = neo4j_password or "password" + + return Neo4jBackend( + connection_uri=neo4j_uri, + auth=(neo4j_user, neo4j_password), + ) + + raise ValueError(f"Unknown backend type: {backend_type}") + + +class MelleaResourceManager: + """Async context manager for managing Mellea session and backend resources. + + Usage: + async with MelleaResourceManager(backend_type="mock") as manager: + session = manager.session + backend = manager.backend + # Use session and backend + """ + + def __init__( + self, + backend_type: str = "mock", + model_id: str = "gpt-4o-mini", + neo4j_uri: Optional[str] = None, + neo4j_user: Optional[str] = None, + neo4j_password: Optional[str] = None, + ): + """Initialize resource manager. + + Args: + backend_type: Type of backend ("mock" or "neo4j", default: "mock"). + model_id: Model ID for session (default: "gpt-4o-mini"). + neo4j_uri: Neo4j connection URI. + neo4j_user: Neo4j username. + neo4j_password: Neo4j password. + """ + self.backend_type = backend_type + self.model_id = model_id + self.neo4j_uri = neo4j_uri + self.neo4j_user = neo4j_user + self.neo4j_password = neo4j_password + self.session: Optional[MelleaSession] = None + self.backend: Optional[GraphBackend] = None + + async def __aenter__(self) -> "MelleaResourceManager": + """Enter async context and create resources.""" + self.session = create_session(model_id=self.model_id) + self.backend = create_backend( + backend_type=self.backend_type, + neo4j_uri=self.neo4j_uri, + neo4j_user=self.neo4j_user, + neo4j_password=self.neo4j_password, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit async context and cleanup resources.""" + if self.backend: + await self.backend.close() + + +def create_openai_session( + model_id: str = "gpt-4o-mini", + api_base: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 1800, + extra_headers: Optional[dict] = None, + force_openai_schema: bool = True, +) -> "MelleaSession": + """Create a Mellea session backed by an OpenAI-compatible endpoint. + + Unlike :func:`create_session` (which uses the generic LiteLLM backend), + this function wires up ``OpenAIBackend`` directly, allowing fine-grained + control over base URL, API key, timeout, and custom headers. Suitable + for Azure OpenAI, vLLM, and IBM RITS endpoints. + + Args: + model_id: Model identifier recognised by the endpoint + (e.g. ``"gpt-4o-mini"``). + api_base: Base URL for the OpenAI-compatible API. Falls back to the + ``OPENAI_API_BASE`` environment variable when *None*. + api_key: API key. Falls back to the ``OPENAI_API_KEY`` environment + variable when *None*. + timeout: Request timeout in seconds (default: ``1800``). + extra_headers: Additional HTTP headers forwarded with every request. + force_openai_schema: When *True* (default), override Mellea's + server-type detection so it always uses the strict OpenAI JSON + schema format (``additionalProperties=False``, ``strict=True``). + Mellea's auto-detection classifies non-``api.openai.com`` URLs as + ``UNKNOWN`` and sends schemas without ``additionalProperties``, + which many OpenAI-compatible endpoints (including IBM RITS) reject. + + Returns: + A configured :class:`MelleaSession`. + + Raises: + SystemExit: If ``mellea`` or ``mellea.backends.openai`` is not + installed. + """ + import os + + if MelleaSession is None: + print("Error: mellea not installed. Run: pip install mellea[litellm]") + sys.exit(1) + + try: + from mellea import MelleaSession as _MS + from mellea.backends.openai import OpenAIBackend, TemplateFormatter + except ImportError: + print("Error: mellea.backends.openai not available.") + sys.exit(1) + + resolved_base = api_base or os.environ.get("OPENAI_API_BASE") + resolved_key = api_key or os.environ.get("OPENAI_API_KEY", "dummy") + + backend = OpenAIBackend( + model_id=model_id, + formatter=TemplateFormatter(model_id=model_id), + base_url=resolved_base, + api_key=resolved_key, + timeout=timeout, + default_headers=extra_headers or {}, + ) + + if force_openai_schema and resolved_base is not None: + try: + from mellea.helpers.server_type import _ServerType + backend._server_type = _ServerType.OPENAI + except ImportError: + pass + + if resolved_base is not None: + _patch_openai_backend_error_logging() + + return _MS(backend=backend) + + +def _patch_openai_backend_error_logging() -> None: + """Patch ``send_to_queue`` in the OpenAI backend module so API exceptions + are logged before being silently swallowed by Mellea's async queue. + + Mellea catches all exceptions from the LLM call inside ``send_to_queue`` + and puts them on a queue. The backend's ``processing()`` function ignores + non-response objects, so the original error is permanently lost and + ``post_processing()`` raises a confusing ``KeyError: 'oai_chat_response'`` + instead. This patch logs the real exception at ERROR level first. + """ + import logging + import traceback as _tb + from collections.abc import AsyncIterator, Coroutine + + try: + import mellea.backends.openai as _mo + import mellea.helpers.async_helpers as _ah + except ImportError: + return + + _logger = logging.getLogger("mellea_contribs.kg") + + async def _logged_send_to_queue(co, aqueue) -> None: # type: ignore[type-arg] + try: + aresponse = await co if isinstance(co, Coroutine) else co + if isinstance(aresponse, AsyncIterator): + async for item in aresponse: + await aqueue.put(item) + else: + await aqueue.put(aresponse) + await aqueue.put(None) + except Exception as exc: + _logger.error( + f"[API] Call failed: {type(exc).__name__}: {exc}" + ) + _logger.debug(_tb.format_exc()) + await aqueue.put(exc) + + _mo.send_to_queue = _logged_send_to_queue + _ah.send_to_queue = _logged_send_to_queue + + +def create_session_from_env( + default_model: str = "gpt-4o-mini", + timeout: int = 1800, + env_prefix: str = "", +) -> tuple: + """Create a Mellea session from standard environment variables. + + Reads ``{prefix}API_BASE``, ``{prefix}API_KEY``, ``{prefix}MODEL_NAME``, + and ``{prefix}RITS_API_KEY`` from the environment and delegates to + :func:`create_openai_session`. Suitable for any OpenAI-compatible + endpoint including IBM RITS. + + Args: + default_model: Model to use when ``MODEL_NAME`` is not set. + timeout: Request timeout in seconds (default: ``1800``). + env_prefix: Optional prefix for all env var names (e.g. ``"EVAL_"`` + reads ``EVAL_API_BASE``, ``EVAL_API_KEY``, etc.). + + Returns: + ``(session, model_id)`` tuple — the configured + :class:`MelleaSession` and the resolved model name string. + """ + import os + + import logging + _log = logging.getLogger("mellea_contribs.kg") + + p = env_prefix + api_base = os.getenv(f"{p}API_BASE") + api_key = os.getenv(f"{p}API_KEY", "dummy") + model_id = os.getenv(f"{p}MODEL_NAME", default_model) + # Fall back to the unprefixed RITS_API_KEY when the prefixed one isn't set, + # since the same RITS credentials are typically shared across all sessions. + rits_api_key = os.getenv(f"{p}RITS_API_KEY") or ( + os.getenv("RITS_API_KEY") if p else None + ) + + _log.info( + f"create_session_from_env(prefix={repr(env_prefix)}): " + f"api_base={'set' if api_base else 'MISSING'}, " + f"api_key={'set' if api_key != 'dummy' else 'dummy/unset'}, " + f"rits_api_key={'set' if rits_api_key else 'MISSING'}" + ) + + extra_headers: dict = {} + if rits_api_key: + extra_headers["RITS_API_KEY"] = rits_api_key + + session = create_openai_session( + model_id=model_id, + api_base=api_base, + api_key=api_key, + timeout=timeout, + extra_headers=extra_headers or None, + ) + return session, model_id + + +def create_embedding_client( + api_base: Optional[str] = None, + api_key: Optional[str] = None, + model_name: str = "text-embedding-3-small", + timeout: int = 1800, +): + """Create an async OpenAI-compatible embedding client. + + The returned client exposes ``client.embeddings.create(input, model)`` + just like the official ``openai.AsyncOpenAI`` client, making it usable + with any OpenAI-compatible embedding endpoint (Azure, vLLM, IBM RITS, + etc.). + + Args: + api_base: Base URL of the embedding endpoint. Falls back to + ``OPENAI_API_BASE`` when *None*. + api_key: API key. Falls back to ``OPENAI_API_KEY`` when *None*. + model_name: Default model name attached to the client as + ``client._model_name`` for convenience. + timeout: HTTP timeout in seconds (default: ``1800``). + + Returns: + An ``openai.AsyncOpenAI`` instance (or ``None`` when *openai* is not + installed). + """ + import os + + resolved_base = api_base or os.environ.get("OPENAI_API_BASE") + resolved_key = api_key or os.environ.get("OPENAI_API_KEY", "dummy") + + try: + from openai import AsyncOpenAI + + client = AsyncOpenAI( + base_url=resolved_base, + api_key=resolved_key, + timeout=timeout, + ) + # Attach model name so callers can read it without a separate arg + client._model_name = model_name # type: ignore[attr-defined] + return client + except ImportError: + print("Warning: openai package not installed; embedding client unavailable.") + return None + + +async def generate_embeddings( + client, + texts: list, + model_name: Optional[str] = None, +) -> list: + """Generate embeddings for a list of texts using an async OpenAI client. + + Args: + client: Async OpenAI-compatible client (from + :func:`create_embedding_client`). + texts: List of strings to embed. + model_name: Override the model name. Uses ``client._model_name`` + when *None*, falling back to ``"text-embedding-3-small"``. + + Returns: + List of embedding vectors (one per input text), or a list of + ``None`` values when the client is unavailable or the call fails. + """ + if client is None or not texts: + return [None] * len(texts) + + resolved_model = ( + model_name + or getattr(client, "_model_name", None) + or "text-embedding-3-small" + ) + try: + response = await client.embeddings.create(input=texts, model=resolved_model) + return [item.embedding for item in response.data] + except Exception as exc: + print(f"Warning: embedding call failed — {exc}", file=sys.stderr) + return [None] * len(texts) diff --git a/pyproject.toml b/pyproject.toml index 17f2829..c834763 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,13 @@ classifiers = [ ] dependencies = [ - "mellea[litellm]", + "mellea[litellm]>=0.3.0", "rapidfuzz>=3.14.3", "eyecite", "citeurl", "playwright", "markdown", + "neo4j>=6.1.0", ] [project.optional-dependencies] @@ -59,6 +60,15 @@ docs = [ "sphinx_mdinclude", ] +kg = [ + "neo4j>=5.0.0", +] + +# Phase 4: Optional dependencies for KG-RAG pipeline utilities +kg-utils = [ + "tqdm>=4.65.0", # Progress bars for batch processing +] + [tool.ruff] target-version = "py310" respect-gitignore = true @@ -128,7 +138,8 @@ python_version = "3.10" [tool.pytest.ini_options] markers = [ - "qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD" + "qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD", + "neo4j: Marks tests requiring a running Neo4j instance (skipped in CI unless NEO4J_URI is set)", ] asyncio_mode = "auto" # Don't require explicitly marking async tests. diff --git a/test/conftest.py b/test/conftest.py index c670852..7f78593 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,8 +11,17 @@ def gh_run() -> int: def pytest_runtest_setup(item): - """Skip qualitative tests when running in CI environment.""" - # Allow tests *not* marked with `@pytest.mark.qualitative` to run normally. + """Skip qualitative and neo4j tests when appropriate.""" + # Handle neo4j marker - skip if NEO4J_URI not set + if item.get_closest_marker("neo4j"): + if not os.environ.get("NEO4J_URI"): + pytest.skip( + reason="Skipping neo4j test: NEO4J_URI environment variable not set. " + "Set NEO4J_URI to enable Neo4j integration tests." + ) + return + + # Handle qualitative marker - skip if in CI if not item.get_closest_marker("qualitative"): return diff --git a/test/kg/__init__.py b/test/kg/__init__.py new file mode 100644 index 0000000..5644f52 --- /dev/null +++ b/test/kg/__init__.py @@ -0,0 +1 @@ +"""Tests for Knowledge Graph library.""" diff --git a/test/kg/conftest.py b/test/kg/conftest.py new file mode 100644 index 0000000..12b1309 --- /dev/null +++ b/test/kg/conftest.py @@ -0,0 +1,98 @@ +"""Pytest configuration for KG tests. + +Reads Neo4j credentials from environment variables for testing. +""" + +import os + +import pytest + +try: + from mellea_contribs.kg.components.query import GraphQuery +except ImportError: + # Allow tests to run even if mellea is not fully installed + GraphQuery = None + +try: + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend + + NEO4J_AVAILABLE = True +except ImportError: + NEO4J_AVAILABLE = False + + +@pytest.fixture +def neo4j_credentials(): + """Get Neo4j credentials from environment variables. + + Set these environment variables before running tests: + NEO4J_URI: Connection URI (default: bolt://localhost:7687) + NEO4J_USER: Username (default: neo4j) + NEO4J_PASSWORD: Password (default: testpassword) + + Example: + export NEO4J_PASSWORD="your_password" + uv run pytest test/contribs/kg/test_neo4j_backend.py -v + """ + return { + "uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"), + "user": os.getenv("NEO4J_USER", "neo4j"), + "password": os.getenv("NEO4J_PASSWORD", "testpassword"), + } + + +@pytest.fixture +async def neo4j_backend(neo4j_credentials): + """Create a Neo4j backend for testing. + + Uses credentials from environment variables via neo4j_credentials fixture. + Tests will be skipped if Neo4j is not available or connection fails. + """ + if not NEO4J_AVAILABLE: + pytest.skip("Neo4j driver not installed") + + backend = Neo4jBackend( + connection_uri=neo4j_credentials["uri"], + auth=(neo4j_credentials["user"], neo4j_credentials["password"]), + ) + + # Test connection + try: + await backend.get_schema() + yield backend + except Exception as e: + pytest.skip(f"Could not connect to Neo4j: {e}") + finally: + await backend.close() + + +@pytest.fixture +async def populated_neo4j_backend(neo4j_backend): + """Create a Neo4j backend with test data. + + Clears existing data, populates with test data, and cleans up after tests. + """ + if GraphQuery is None: + pytest.skip("GraphQuery not available (mellea not fully installed)") + + # Clear any existing data + clear_query = GraphQuery(query_string="MATCH (n) DETACH DELETE n") + await neo4j_backend.execute_query(clear_query) + + # Create test data + create_query = GraphQuery( + query_string=""" + CREATE (alice:Person {name: 'Alice', age: 30}) + CREATE (bob:Person {name: 'Bob', age: 35}) + CREATE (matrix:Movie {title: 'The Matrix', year: 1999}) + CREATE (alice)-[:ACTED_IN {role: 'Trinity'}]->(matrix) + CREATE (bob)-[:ACTED_IN {role: 'Morpheus'}]->(matrix) + RETURN alice, bob, matrix + """ + ) + await neo4j_backend.execute_query(create_query) + + yield neo4j_backend + + # Cleanup + await neo4j_backend.execute_query(clear_query) diff --git a/test/kg/test_base.py b/test/kg/test_base.py new file mode 100644 index 0000000..26bba08 --- /dev/null +++ b/test/kg/test_base.py @@ -0,0 +1,132 @@ +"""Tests for base graph data structures.""" + +import pytest + +from mellea_contribs.kg.base import GraphEdge, GraphNode, GraphPath + + +class TestGraphNode: + """Tests for GraphNode dataclass.""" + + def test_create_graph_node(self): + """Test creating a GraphNode.""" + node = GraphNode( + id="1", label="Person", properties={"name": "Alice", "age": 30} + ) + + assert node.id == "1" + assert node.label == "Person" + assert node.properties["name"] == "Alice" + assert node.properties["age"] == 30 + + def test_graph_node_empty_properties(self): + """Test GraphNode with empty properties.""" + node = GraphNode(id="2", label="Movie", properties={}) + + assert node.id == "2" + assert node.label == "Movie" + assert node.properties == {} + + def test_graph_node_equality(self): + """Test GraphNode equality.""" + node1 = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + node2 = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + node3 = GraphNode(id="2", label="Person", properties={"name": "Bob"}) + + assert node1 == node2 + assert node1 != node3 + + +class TestGraphEdge: + """Tests for GraphEdge dataclass.""" + + def test_create_graph_edge(self): + """Test creating a GraphEdge.""" + source = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + target = GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}) + + edge = GraphEdge( + id="e1", + source=source, + label="ACTED_IN", + target=target, + properties={"role": "Neo"}, + ) + + assert edge.id == "e1" + assert edge.source == source + assert edge.label == "ACTED_IN" + assert edge.target == target + assert edge.properties["role"] == "Neo" + + def test_graph_edge_empty_properties(self): + """Test GraphEdge with empty properties.""" + source = GraphNode(id="1", label="Person", properties={}) + target = GraphNode(id="2", label="Movie", properties={}) + + edge = GraphEdge( + id="e1", source=source, label="DIRECTED", target=target, properties={} + ) + + assert edge.properties == {} + + def test_graph_edge_equality(self): + """Test GraphEdge equality.""" + source = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + target = GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}) + + edge1 = GraphEdge( + id="e1", source=source, label="ACTED_IN", target=target, properties={} + ) + edge2 = GraphEdge( + id="e1", source=source, label="ACTED_IN", target=target, properties={} + ) + edge3 = GraphEdge( + id="e2", source=source, label="DIRECTED", target=target, properties={} + ) + + assert edge1 == edge2 + assert edge1 != edge3 + + +class TestGraphPath: + """Tests for GraphPath dataclass.""" + + def test_create_graph_path(self): + """Test creating a GraphPath.""" + node1 = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + node2 = GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}) + node3 = GraphNode(id="3", label="Person", properties={"name": "Bob"}) + + edge1 = GraphEdge( + id="e1", source=node1, label="ACTED_IN", target=node2, properties={} + ) + edge2 = GraphEdge( + id="e2", source=node3, label="ACTED_IN", target=node2, properties={} + ) + + path = GraphPath(nodes=[node1, node2, node3], edges=[edge1, edge2]) + + assert len(path.nodes) == 3 + assert len(path.edges) == 2 + assert path.nodes[0] == node1 + assert path.nodes[1] == node2 + assert path.nodes[2] == node3 + assert path.edges[0] == edge1 + assert path.edges[1] == edge2 + + def test_empty_graph_path(self): + """Test creating an empty GraphPath.""" + path = GraphPath(nodes=[], edges=[]) + + assert len(path.nodes) == 0 + assert len(path.edges) == 0 + + def test_graph_path_single_node(self): + """Test GraphPath with single node.""" + node = GraphNode(id="1", label="Person", properties={"name": "Alice"}) + path = GraphPath(nodes=[node], edges=[]) + + assert len(path.nodes) == 1 + assert len(path.edges) == 0 + assert path.nodes[0] == node diff --git a/test/kg/test_kgrag.py b/test/kg/test_kgrag.py new file mode 100644 index 0000000..4b6866e --- /dev/null +++ b/test/kg/test_kgrag.py @@ -0,0 +1,173 @@ +"""Tests for Layer 1: KGRag pipeline. + +Structural tests run without an LLM. Qualitative (end-to-end) tests require +a real Mellea session and are marked @pytest.mark.qualitative so they are +skipped in CI. +""" + +import pytest + +from mellea_contribs.kg.base import GraphEdge, GraphNode +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend +from mellea_contribs.kg.kgrag import KGRag, format_schema + + +# --------------------------------------------------------------------------- +# format_schema +# --------------------------------------------------------------------------- + + +class TestFormatSchema: + """Tests for the format_schema() helper.""" + + def test_full_schema(self): + """format_schema includes node types, edge types, and property keys.""" + schema = { + "node_types": ["Person", "Movie"], + "edge_types": ["ACTED_IN", "DIRECTED"], + "property_keys": ["name", "title", "year"], + } + result = format_schema(schema) + assert "Person" in result + assert "Movie" in result + assert "ACTED_IN" in result + assert "name" in result + + def test_empty_schema(self): + """format_schema handles an empty schema without errors.""" + result = format_schema({}) + assert "Graph Schema" in result + + def test_partial_schema(self): + """format_schema handles a schema with only node types.""" + result = format_schema({"node_types": ["Person"]}) + assert "Person" in result + + def test_returns_string(self): + """format_schema always returns a string.""" + assert isinstance(format_schema({}), str) + assert isinstance(format_schema({"node_types": ["A"]}), str) + + +# --------------------------------------------------------------------------- +# KGRag structural tests (no LLM required) +# --------------------------------------------------------------------------- + + +class TestKGRagStructural: + """Structural tests for KGRag that do not call an LLM.""" + + @pytest.fixture + def backend(self): + """Mock backend with sample data.""" + nodes = [ + GraphNode(id="1", label="Person", properties={"name": "Alice"}), + GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}), + ] + edges = [ + GraphEdge( + id="e1", + source=nodes[0], + label="ACTED_IN", + target=nodes[1], + properties={}, + ) + ] + return MockGraphBackend(mock_nodes=nodes, mock_edges=edges) + + def test_init_stores_backend_and_session(self, backend): + """KGRag stores backend, session, and config options.""" + rag = KGRag(backend=backend, session=None) + assert rag._backend is backend + assert rag._session is None + + def test_default_format_style(self, backend): + """KGRag defaults to 'natural' format style.""" + rag = KGRag(backend=backend, session=None) + assert rag._format_style == "natural" + + def test_custom_format_style(self, backend): + """KGRag accepts a custom format style.""" + rag = KGRag(backend=backend, session=None, format_style="triplets") + assert rag._format_style == "triplets" + + def test_default_max_repair_attempts(self, backend): + """KGRag defaults to 2 max repair attempts.""" + rag = KGRag(backend=backend, session=None) + assert rag._max_repair_attempts == 2 + + def test_custom_max_repair_attempts(self, backend): + """KGRag accepts a custom max_repair_attempts.""" + rag = KGRag(backend=backend, session=None, max_repair_attempts=5) + assert rag._max_repair_attempts == 5 + + def test_answer_is_coroutine(self, backend): + """KGRag.answer() is a coroutine function.""" + import asyncio + + rag = KGRag(backend=backend, session=None) + assert asyncio.iscoroutinefunction(rag.answer) + + async def test_validate_and_repair_returns_valid_query(self, backend): + """_validate_and_repair returns the query unchanged when already valid.""" + rag = KGRag(backend=backend, session=None) + schema_text = format_schema(await backend.get_schema()) + # MockBackend.validate_query always returns True + result = await rag._validate_and_repair( + "MATCH (n) RETURN n", schema_text + ) + assert result == "MATCH (n) RETURN n" + + +# --------------------------------------------------------------------------- +# Qualitative (end-to-end) tests +# --------------------------------------------------------------------------- + + +@pytest.mark.qualitative +class TestKGRagQualitative: + """End-to-end tests for KGRag that require a real LLM session. + + These tests are skipped in CI unless MELLEA_SKIP_QUALITATIVE is unset. + They require a running Neo4j database (NEO4J_URI env var) and a + configured Mellea LLM backend. + """ + + @pytest.fixture + def neo4j_backend(self): + """Neo4j backend for qualitative tests.""" + import os + + uri = os.environ.get("NEO4J_URI") + if not uri: + pytest.skip("NEO4J_URI not set") + + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend + + user = os.environ.get("NEO4J_USER", "neo4j") + password = os.environ.get("NEO4J_PASSWORD", "password") + return Neo4jBackend(connection_uri=uri, auth=(user, password)) + + @pytest.fixture + def mellea_session(self): + """Mellea session for qualitative tests.""" + from mellea import start_session + + return start_session(backend_name="litellm", model_id="gpt-4o-mini") + + async def test_answer_returns_string(self, neo4j_backend, mellea_session): + """KGRag.answer() returns a non-empty string for a simple question.""" + rag = KGRag(backend=neo4j_backend, session=mellea_session) + answer = await rag.answer("What nodes exist in the graph?") + assert isinstance(answer, str) + assert len(answer) > 0 + await neo4j_backend.close() + + async def test_answer_is_grounded(self, neo4j_backend, mellea_session): + """KGRag.answer() produces an answer grounded in real graph data.""" + rag = KGRag(backend=neo4j_backend, session=mellea_session) + answer = await rag.answer("What node labels are in the graph?") + assert isinstance(answer, str) + # Answer should contain some domain-relevant content + assert len(answer.split()) > 3 + await neo4j_backend.close() diff --git a/test/kg/test_layer2.py b/test/kg/test_layer2.py new file mode 100644 index 0000000..93f83f8 --- /dev/null +++ b/test/kg/test_layer2.py @@ -0,0 +1,402 @@ +"""Tests for Layer 2: Graph Query Components. + +Covers: +- components/query.py: GraphQuery, CypherQuery, SparqlQuery +- components/result.py: GraphResult (all format styles) +- components/traversal.py: GraphTraversal (all patterns + to_cypher) +""" + +import json + +import pytest +from mellea.stdlib.components import Component, TemplateRepresentation + +from mellea_contribs.kg.base import GraphEdge, GraphNode, GraphPath +from mellea_contribs.kg.components.query import CypherQuery, GraphQuery, SparqlQuery +from mellea_contribs.kg.components.result import GraphResult +from mellea_contribs.kg.components.traversal import GraphTraversal + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def alice(): + return GraphNode(id="1", label="Person", properties={"name": "Alice"}) + + +@pytest.fixture +def matrix(): + return GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}) + + +@pytest.fixture +def acted_in(alice, matrix): + return GraphEdge( + id="e1", source=alice, label="ACTED_IN", target=matrix, properties={} + ) + + +@pytest.fixture +def graph_path(alice, acted_in, matrix): + return GraphPath(nodes=[alice, matrix], edges=[acted_in]) + + +# --------------------------------------------------------------------------- +# GraphQuery +# --------------------------------------------------------------------------- + + +class TestGraphQuery: + """Tests for the full GraphQuery Component.""" + + def test_implements_component_protocol(self): + """GraphQuery satisfies the Mellea Component protocol.""" + q = GraphQuery(query_string="MATCH (n) RETURN n") + assert isinstance(q, Component) + + def test_properties_accessible(self): + """GraphQuery exposes query_string, parameters, description, metadata.""" + q = GraphQuery( + query_string="MATCH (n) RETURN n", + parameters={"limit": 5}, + description="All nodes", + metadata={"schema": "test"}, + ) + assert q.query_string == "MATCH (n) RETURN n" + assert q.parameters == {"limit": 5} + assert q.description == "All nodes" + assert q.metadata == {"schema": "test"} + + def test_defaults(self): + """GraphQuery defaults to empty parameters / metadata and None strings.""" + q = GraphQuery() + assert q.query_string is None + assert q.parameters == {} + assert q.description is None + assert q.metadata == {} + + def test_format_for_llm_returns_template_representation(self): + """format_for_llm() returns a TemplateRepresentation.""" + q = GraphQuery(query_string="MATCH (n) RETURN n", description="All nodes") + rep = q.format_for_llm() + assert isinstance(rep, TemplateRepresentation) + assert rep.args["query"] == "MATCH (n) RETURN n" + assert rep.args["description"] == "All nodes" + + def test_with_description_is_immutable(self): + """with_description() returns a new instance without modifying the original.""" + q = GraphQuery(query_string="MATCH (n) RETURN n") + q2 = q.with_description("Updated") + assert q.description is None + assert q2.description == "Updated" + assert q is not q2 + + def test_with_parameters_merges(self): + """with_parameters() merges new params into a new instance.""" + q = GraphQuery(parameters={"a": 1}) + q2 = q.with_parameters(b=2) + assert q.parameters == {"a": 1} + assert q2.parameters == {"a": 1, "b": 2} + + def test_with_metadata_merges(self): + """with_metadata() merges new metadata into a new instance.""" + q = GraphQuery(metadata={"x": 1}) + q2 = q.with_metadata(y=2) + assert q.metadata == {"x": 1} + assert q2.metadata == {"x": 1, "y": 2} + + def test_parts_raises(self): + """parts() raises NotImplementedError.""" + q = GraphQuery() + with pytest.raises(NotImplementedError): + q.parts() + + +# --------------------------------------------------------------------------- +# CypherQuery +# --------------------------------------------------------------------------- + + +class TestCypherQuery: + """Tests for the fluent CypherQuery builder.""" + + def test_is_graph_query(self): + """CypherQuery is a subclass of GraphQuery.""" + assert issubclass(CypherQuery, GraphQuery) + + def test_empty_query(self): + """CypherQuery with no clauses has None query_string.""" + q = CypherQuery() + assert q.query_string is None + + def test_match_clause(self): + """match() adds a MATCH clause.""" + q = CypherQuery().match("(n:Person)") + assert q.query_string == "MATCH (n:Person)" + + def test_where_clause(self): + """where() adds a WHERE condition.""" + q = CypherQuery().match("(n:Person)").where("n.age > 18") + assert "WHERE n.age > 18" in q.query_string + + def test_return_clause(self): + """return_() adds RETURN expressions.""" + q = CypherQuery().match("(n)").return_("n.name", "n.age") + assert "RETURN n.name, n.age" in q.query_string + + def test_order_by_clause(self): + """order_by() adds an ORDER BY clause.""" + q = CypherQuery().match("(n)").return_("n").order_by("n.name ASC") + assert "ORDER BY n.name ASC" in q.query_string + + def test_limit_clause(self): + """limit() adds a LIMIT clause.""" + q = CypherQuery().match("(n)").return_("n").limit(10) + assert "LIMIT 10" in q.query_string + + def test_full_fluent_chain(self): + """All builder methods compose correctly.""" + q = ( + CypherQuery() + .match("(m:Movie)") + .where("m.year = $year") + .return_("m.title", "m.year") + .order_by("m.year DESC") + .limit(5) + .with_parameters(year=2020) + .with_description("Movies from 2020") + ) + qs = q.query_string + assert "MATCH (m:Movie)" in qs + assert "WHERE m.year = $year" in qs + assert "RETURN m.title, m.year" in qs + assert "ORDER BY m.year DESC" in qs + assert "LIMIT 5" in qs + assert q.parameters == {"year": 2020} + assert q.description == "Movies from 2020" + + def test_each_step_immutable(self): + """Each builder step returns a new CypherQuery instance.""" + base = CypherQuery() + with_match = base.match("(n)") + assert base is not with_match + assert base.query_string is None + + def test_multiple_where_conditions(self): + """Multiple where() calls are ANDed together.""" + q = CypherQuery().match("(n)").where("n.age > 18").where("n.active = true") + assert "n.age > 18 AND n.active = true" in q.query_string + + def test_explicit_query_string_bypasses_clauses(self): + """Providing query_string directly bypasses clause building.""" + raw = "MATCH (n:Custom) RETURN n" + q = CypherQuery(query_string=raw) + assert q.query_string == raw + + def test_format_for_llm_includes_query_type(self): + """CypherQuery format_for_llm includes query_type field.""" + q = CypherQuery(query_string="MATCH (n) RETURN n") + rep = q.format_for_llm() + assert rep.args["query_type"] == "Cypher (Neo4j)" + + +# --------------------------------------------------------------------------- +# SparqlQuery +# --------------------------------------------------------------------------- + + +class TestSparqlQuery: + """Tests for SparqlQuery.""" + + def test_is_graph_query(self): + """SparqlQuery is a subclass of GraphQuery.""" + assert issubclass(SparqlQuery, GraphQuery) + + def test_format_for_llm_includes_query_type(self): + """SparqlQuery format_for_llm includes SPARQL query_type.""" + q = SparqlQuery(query_string="SELECT ?s WHERE { ?s a :Person }") + rep = q.format_for_llm() + assert rep.args["query_type"] == "SPARQL" + + +# --------------------------------------------------------------------------- +# GraphResult +# --------------------------------------------------------------------------- + + +class TestGraphResult: + """Tests for GraphResult Component and its format styles.""" + + def test_implements_component_protocol(self, alice, matrix, acted_in): + """GraphResult satisfies the Mellea Component protocol.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in]) + assert isinstance(r, Component) + + def test_properties(self, alice, matrix, acted_in): + """GraphResult exposes nodes, edges, paths, format_style.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in], format_style="natural") + assert r.nodes == [alice, matrix] + assert r.edges == [acted_in] + assert r.format_style == "natural" + + def test_empty_result(self): + """GraphResult with no data has empty lists.""" + r = GraphResult() + assert r.nodes == [] + assert r.edges == [] + assert r.paths == [] + + def test_format_triplets(self, alice, matrix, acted_in): + """'triplets' style formats edges as (Src)-[REL]->(Tgt).""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in], format_style="triplets") + rep = r.format_for_llm() + result_text = rep.args["result"] + assert "(Person:Alice)-[ACTED_IN]->(Movie:The Matrix)" in result_text + + def test_format_triplets_empty(self): + """'triplets' on empty result returns a placeholder.""" + r = GraphResult(format_style="triplets") + rep = r.format_for_llm() + assert rep.args["result"] == "(no results)" + + def test_format_natural(self, alice, matrix, acted_in): + """'natural' style formats edges as natural language sentences.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in], format_style="natural") + rep = r.format_for_llm() + result_text = rep.args["result"] + assert "acted in" in result_text.lower() + + def test_format_natural_empty(self): + """'natural' on empty result returns a descriptive message.""" + r = GraphResult(format_style="natural") + rep = r.format_for_llm() + assert "no results" in rep.args["result"].lower() + + def test_format_paths(self, alice, matrix, acted_in, graph_path): + """'paths' style renders paths as node-edge-node chains.""" + r = GraphResult(paths=[graph_path], format_style="paths") + rep = r.format_for_llm() + result_text = rep.args["result"] + assert "ACTED_IN" in result_text + assert "Alice" in result_text + + def test_format_paths_falls_back_to_triplets(self, alice, matrix, acted_in): + """'paths' style falls back to triplets when there are no explicit paths.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in], format_style="paths") + rep = r.format_for_llm() + assert "ACTED_IN" in rep.args["result"] + + def test_format_structured(self, alice, matrix, acted_in): + """'structured' style returns valid JSON.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in], format_style="structured") + rep = r.format_for_llm() + data = json.loads(rep.args["result"]) + assert "nodes" in data + assert "edges" in data + assert len(data["nodes"]) == 2 + assert len(data["edges"]) == 1 + + def test_format_for_llm_includes_counts(self, alice, matrix, acted_in): + """format_for_llm args include node_count and edge_count.""" + r = GraphResult(nodes=[alice, matrix], edges=[acted_in]) + rep = r.format_for_llm() + assert rep.args["node_count"] == 2 + assert rep.args["edge_count"] == 1 + + def test_standalone_node_appears_in_triplets(self, alice): + """Nodes without edges appear in triplet output.""" + r = GraphResult(nodes=[alice], format_style="triplets") + rep = r.format_for_llm() + assert "Alice" in rep.args["result"] + + def test_parts_raises(self): + """parts() raises NotImplementedError.""" + r = GraphResult() + with pytest.raises(NotImplementedError): + r.parts() + + +# --------------------------------------------------------------------------- +# GraphTraversal +# --------------------------------------------------------------------------- + + +class TestGraphTraversal: + """Tests for GraphTraversal Component.""" + + def test_implements_component_protocol(self): + """GraphTraversal satisfies the Mellea Component protocol.""" + t = GraphTraversal(start_nodes=["Alice"]) + assert isinstance(t, Component) + + def test_properties(self): + """GraphTraversal exposes start_nodes, pattern, max_depth, description.""" + t = GraphTraversal( + start_nodes=["1", "2"], + pattern="bfs", + max_depth=4, + description="BFS from roots", + ) + assert t.start_nodes == ["1", "2"] + assert t.pattern == "bfs" + assert t.max_depth == 4 + assert t.description == "BFS from roots" + + def test_format_for_llm_returns_template_representation(self): + """format_for_llm() returns TemplateRepresentation with a cypher field.""" + t = GraphTraversal(start_nodes=["Alice"], description="Find connections") + rep = t.format_for_llm() + assert isinstance(rep, TemplateRepresentation) + assert "cypher" in rep.args + assert rep.args["cypher"] is not None + + def test_to_cypher_multi_hop(self): + """to_cypher() for multi_hop returns a CypherQuery.""" + t = GraphTraversal(start_nodes=["Alice"], pattern="multi_hop", max_depth=2) + q = t.to_cypher() + assert isinstance(q, CypherQuery) + assert "*1..2" in q.query_string + assert q.parameters == {"start_nodes": ["Alice"]} + + def test_to_cypher_bfs(self): + """to_cypher() for bfs returns a CypherQuery (same pattern as multi_hop).""" + t = GraphTraversal(start_nodes=["1"], pattern="bfs", max_depth=3) + q = t.to_cypher() + assert isinstance(q, CypherQuery) + assert "*1..3" in q.query_string + + def test_to_cypher_dfs(self): + """to_cypher() for dfs returns a CypherQuery.""" + t = GraphTraversal(start_nodes=["1"], pattern="dfs", max_depth=3) + q = t.to_cypher() + assert isinstance(q, CypherQuery) + + def test_to_cypher_shortest_path(self): + """to_cypher() for shortest_path uses shortestPath().""" + t = GraphTraversal(start_nodes=["A"], pattern="shortest_path", max_depth=5) + q = t.to_cypher() + assert isinstance(q, CypherQuery) + assert "shortestPath" in q.query_string + + def test_to_cypher_unsupported_pattern(self): + """to_cypher() raises ValueError for unknown patterns.""" + t = GraphTraversal(start_nodes=["A"], pattern="unknown_pattern") + with pytest.raises(ValueError, match="Unsupported traversal pattern"): + t.to_cypher() + + def test_to_cypher_preserves_description(self): + """to_cypher() carries the traversal description into the CypherQuery.""" + t = GraphTraversal( + start_nodes=["Alice"], pattern="multi_hop", description="My traversal" + ) + q = t.to_cypher() + assert q.description == "My traversal" + + def test_parts_raises(self): + """parts() raises NotImplementedError.""" + t = GraphTraversal(start_nodes=[]) + with pytest.raises(NotImplementedError): + t.parts() diff --git a/test/kg/test_layer3.py b/test/kg/test_layer3.py new file mode 100644 index 0000000..22b6b0e --- /dev/null +++ b/test/kg/test_layer3.py @@ -0,0 +1,291 @@ +"""Tests for Layer 3: LLM-guided query construction. + +Covers: +- components/llm_guided.py: @generative functions and GeneratedQuery model +- sampling/validation.py: QueryValidationStrategy +- requirements/__init__.py: is_valid_cypher, returns_results, respects_schema +""" + +import pytest +from mellea.stdlib.components import CBlock, ModelOutputThunk +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.requirements import Requirement, ValidationResult + +from mellea_contribs.kg.base import GraphEdge, GraphNode +from mellea_contribs.kg.components.llm_guided import ( + GeneratedQuery, + explain_query_result, + natural_language_to_cypher, + suggest_query_improvement, +) +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend +from mellea_contribs.kg.requirements import ( + is_valid_cypher, + respects_schema, + returns_results, +) +from mellea_contribs.kg.sampling import QueryValidationStrategy + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_backend(): + """Create a mock backend with sample data.""" + nodes = [ + GraphNode(id="1", label="Person", properties={"name": "Alice"}), + GraphNode(id="2", label="Movie", properties={"title": "The Matrix"}), + ] + edges = [ + GraphEdge( + id="e1", + source=nodes[0], + label="ACTED_IN", + target=nodes[1], + properties={}, + ) + ] + return MockGraphBackend(mock_nodes=nodes, mock_edges=edges) + + +@pytest.fixture +def empty_backend(): + """Create a mock backend with no data.""" + return MockGraphBackend() + + +@pytest.fixture +def ctx_with_query(): + """Create a SimpleContext whose last output is a Cypher query string.""" + ctx = SimpleContext() + thunk = ModelOutputThunk(value="MATCH (n:Person) RETURN n") + return SimpleContext.from_previous(ctx, thunk) + + +# --------------------------------------------------------------------------- +# GeneratedQuery model +# --------------------------------------------------------------------------- + + +class TestGeneratedQuery: + """Tests for the GeneratedQuery Pydantic model.""" + + def test_create_with_all_fields(self): + """GeneratedQuery can be created with all fields.""" + gq = GeneratedQuery( + query="MATCH (n) RETURN n", + explanation="Returns all nodes", + parameters={"limit": 10}, + ) + assert gq.query == "MATCH (n) RETURN n" + assert gq.explanation == "Returns all nodes" + assert gq.parameters == {"limit": 10} + + def test_parameters_optional(self): + """GeneratedQuery parameters field defaults to None.""" + gq = GeneratedQuery(query="MATCH (n) RETURN n", explanation="All nodes") + assert gq.parameters is None + + def test_is_pydantic_model(self): + """GeneratedQuery is a Pydantic BaseModel.""" + from pydantic import BaseModel + + assert issubclass(GeneratedQuery, BaseModel) + + +# --------------------------------------------------------------------------- +# @generative functions +# --------------------------------------------------------------------------- + + +class TestGenerativeFunctions: + """Tests for the @generative LLM-guided functions.""" + + def test_natural_language_to_cypher_is_callable(self): + """natural_language_to_cypher is a callable generative slot.""" + assert callable(natural_language_to_cypher) + + def test_explain_query_result_is_callable(self): + """explain_query_result is a callable generative slot.""" + assert callable(explain_query_result) + + def test_suggest_query_improvement_is_callable(self): + """suggest_query_improvement is a callable generative slot.""" + assert callable(suggest_query_improvement) + + def test_functions_have_correct_names(self): + """@generative functions preserve their original names.""" + assert natural_language_to_cypher.__name__ == "natural_language_to_cypher" + assert explain_query_result.__name__ == "explain_query_result" + assert suggest_query_improvement.__name__ == "suggest_query_improvement" + + +# --------------------------------------------------------------------------- +# QueryValidationStrategy +# --------------------------------------------------------------------------- + + +class TestQueryValidationStrategy: + """Tests for QueryValidationStrategy.""" + + def test_create_strategy(self, mock_backend): + """QueryValidationStrategy can be created with a backend.""" + strategy = QueryValidationStrategy(backend=mock_backend) + assert strategy._backend is mock_backend + assert strategy.loop_budget == 3 + + def test_create_strategy_custom_budget(self, mock_backend): + """QueryValidationStrategy respects a custom loop_budget.""" + strategy = QueryValidationStrategy(backend=mock_backend, loop_budget=5) + assert strategy.loop_budget == 5 + + def test_create_strategy_with_requirements(self, mock_backend): + """QueryValidationStrategy stores requirements.""" + req = is_valid_cypher(mock_backend) + strategy = QueryValidationStrategy( + backend=mock_backend, requirements=[req] + ) + assert strategy.requirements == [req] + + def test_repair_returns_cblock_and_context(self, ctx_with_query): + """repair() returns a (CBlock, Context) tuple.""" + failed_thunk = ModelOutputThunk(value="MATCH (n) RETURN n WRONG") + val_result = ValidationResult(False, reason="Syntax error near WRONG") + from mellea.stdlib.requirements import Requirement + + req = Requirement(description="valid cypher") + + action, new_ctx = QueryValidationStrategy.repair( + old_ctx=ctx_with_query, + new_ctx=ctx_with_query, + past_actions=[], + past_results=[failed_thunk], + past_val=[[(req, val_result)]], + ) + + assert isinstance(action, CBlock) + assert "Syntax error near WRONG" in str(action) + assert "MATCH (n) RETURN n WRONG" in str(action) + + def test_repair_handles_multiple_errors(self, ctx_with_query): + """repair() collects all error messages from the last validation.""" + thunk = ModelOutputThunk(value="bad query") + req1 = Requirement(description="req1") + req2 = Requirement(description="req2") + val = [ + (req1, ValidationResult(False, reason="Error A")), + (req2, ValidationResult(False, reason="Error B")), + ] + + action, _ = QueryValidationStrategy.repair( + old_ctx=ctx_with_query, + new_ctx=ctx_with_query, + past_actions=[], + past_results=[thunk], + past_val=[val], + ) + + assert "Error A" in str(action) + assert "Error B" in str(action) + + def test_select_from_failure_picks_fewest_errors(self): + """select_from_failure returns the index with the fewest failures.""" + req = Requirement(description="req") + val_0 = [(req, ValidationResult(False)), (req, ValidationResult(False))] + val_1 = [(req, ValidationResult(False))] + val_2 = [(req, ValidationResult(True))] + + idx = QueryValidationStrategy.select_from_failure( + sampled_actions=[], + sampled_results=[], + sampled_val=[val_0, val_1, val_2], + ) + + assert idx == 2 + + def test_select_from_failure_tie_picks_first(self): + """select_from_failure picks the first index when error counts tie.""" + req = Requirement(description="req") + val = [(req, ValidationResult(False))] + + idx = QueryValidationStrategy.select_from_failure( + sampled_actions=[], + sampled_results=[], + sampled_val=[val, val], + ) + + assert idx == 0 + + +# --------------------------------------------------------------------------- +# Requirements +# --------------------------------------------------------------------------- + + +class TestRequirements: + """Tests for graph-specific Requirement factories.""" + + def test_is_valid_cypher_returns_requirement(self, mock_backend): + """is_valid_cypher() returns a Requirement.""" + req = is_valid_cypher(mock_backend) + assert isinstance(req, Requirement) + assert "Cypher" in req.description + + def test_returns_results_returns_requirement(self, mock_backend): + """returns_results() returns a Requirement.""" + req = returns_results(mock_backend) + assert isinstance(req, Requirement) + assert "results" in req.description.lower() + + def test_respects_schema_returns_requirement(self, mock_backend): + """respects_schema() returns a Requirement.""" + req = respects_schema(mock_backend) + assert isinstance(req, Requirement) + assert "schema" in req.description.lower() + + def test_all_requirements_have_validation_fn(self, mock_backend): + """All requirement factories set a validation_fn.""" + for req in [ + is_valid_cypher(mock_backend), + returns_results(mock_backend), + respects_schema(mock_backend), + ]: + assert req.validation_fn is not None + + async def test_is_valid_cypher_passes_for_valid_query( + self, mock_backend, ctx_with_query + ): + """is_valid_cypher validation passes when backend validates OK.""" + req = is_valid_cypher(mock_backend) + result = await req.validation_fn(ctx_with_query) + assert isinstance(result, ValidationResult) + assert bool(result) is True + + async def test_returns_results_passes_when_data_present( + self, mock_backend, ctx_with_query + ): + """returns_results validation passes when backend has nodes/edges.""" + req = returns_results(mock_backend) + result = await req.validation_fn(ctx_with_query) + assert isinstance(result, ValidationResult) + assert bool(result) is True + + async def test_returns_results_fails_when_no_data( + self, empty_backend, ctx_with_query + ): + """returns_results validation fails when backend has no data.""" + req = returns_results(empty_backend) + result = await req.validation_fn(ctx_with_query) + assert isinstance(result, ValidationResult) + assert bool(result) is False + assert "no results" in result.reason.lower() + + async def test_respects_schema_passes(self, mock_backend, ctx_with_query): + """respects_schema validation passes (placeholder implementation).""" + req = respects_schema(mock_backend) + result = await req.validation_fn(ctx_with_query) + assert isinstance(result, ValidationResult) + assert bool(result) is True diff --git a/test/kg/test_mock_backend.py b/test/kg/test_mock_backend.py new file mode 100644 index 0000000..bd5be40 --- /dev/null +++ b/test/kg/test_mock_backend.py @@ -0,0 +1,108 @@ +"""Tests for MockGraphBackend.""" + +import pytest + +from mellea_contribs.kg.base import GraphEdge, GraphNode +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend + + +@pytest.fixture +def mock_nodes(): + """Create mock nodes for testing.""" + return [ + GraphNode(id="1", label="Person", properties={"name": "Alice"}), + GraphNode(id="2", label="Person", properties={"name": "Bob"}), + GraphNode(id="3", label="Movie", properties={"title": "The Matrix"}), + ] + + +@pytest.fixture +def mock_edges(mock_nodes): + """Create mock edges for testing.""" + return [ + GraphEdge( + id="e1", + source=mock_nodes[0], + label="ACTED_IN", + target=mock_nodes[2], + properties={"role": "Neo"}, + ), + GraphEdge( + id="e2", + source=mock_nodes[1], + label="ACTED_IN", + target=mock_nodes[2], + properties={"role": "Morpheus"}, + ), + ] + + +@pytest.fixture +def mock_backend(mock_nodes, mock_edges): + """Create a mock backend for testing.""" + return MockGraphBackend(mock_nodes=mock_nodes, mock_edges=mock_edges) + + +class TestMockGraphBackend: + """Tests for MockGraphBackend.""" + + def test_create_mock_backend(self): + """Test creating a MockGraphBackend.""" + backend = MockGraphBackend() + + assert backend.backend_id == "mock" + assert backend.connection_uri == "mock://localhost" + assert backend.mock_nodes == [] + assert backend.mock_edges == [] + + def test_mock_backend_with_data(self, mock_backend, mock_nodes, mock_edges): + """Test MockGraphBackend with predefined data.""" + assert len(mock_backend.mock_nodes) == 3 + assert len(mock_backend.mock_edges) == 2 + assert mock_backend.mock_nodes == mock_nodes + assert mock_backend.mock_edges == mock_edges + + @pytest.mark.asyncio + async def test_get_schema(self, mock_backend): + """Test getting mock schema.""" + schema = await mock_backend.get_schema() + + assert "node_types" in schema + assert "edge_types" in schema + assert "property_keys" in schema + assert "MockNode" in schema["node_types"] + assert "MOCK_EDGE" in schema["edge_types"] + + @pytest.mark.asyncio + async def test_validate_query_always_valid(self, mock_backend): + """Test that mock backend always validates queries as valid.""" + # Need to create a minimal query object + from mellea_contribs.kg.components.query import GraphQuery + + query = GraphQuery(query_string="MATCH (n) RETURN n") + is_valid, error = await mock_backend.validate_query(query) + + assert is_valid is True + assert error is None + + def test_supports_all_query_types(self, mock_backend): + """Test that mock backend supports all query types.""" + assert mock_backend.supports_query_type("cypher") is True + assert mock_backend.supports_query_type("sparql") is True + assert mock_backend.supports_query_type("gremlin") is True + assert mock_backend.supports_query_type("unknown") is True + + def test_query_history_tracking(self, mock_backend): + """Test that mock backend tracks query history.""" + assert len(mock_backend.query_history) == 0 + + mock_backend.clear_history() + assert len(mock_backend.query_history) == 0 + + def test_clear_history(self, mock_backend): + """Test clearing query history.""" + mock_backend.query_history.append(("MATCH (n) RETURN n", {})) + assert len(mock_backend.query_history) == 1 + + mock_backend.clear_history() + assert len(mock_backend.query_history) == 0 diff --git a/test/kg/test_neo4j_backend.py b/test/kg/test_neo4j_backend.py new file mode 100644 index 0000000..4e385a2 --- /dev/null +++ b/test/kg/test_neo4j_backend.py @@ -0,0 +1,227 @@ +"""Tests for Neo4jBackend. + +These tests require a running Neo4j instance. They will be skipped if: +- Neo4j is not installed +- No Neo4j instance is available at the connection URI +- Authentication fails + +To run these tests: + +1. Set environment variables (recommended): + export NEO4J_PASSWORD="your_password" + uv run pytest test/contribs/kg/test_neo4j_backend.py -v + +2. Or start Neo4j with Docker: + docker run --rm -p 7687:7687 -p 7474:7474 \ + -e NEO4J_AUTH=neo4j/testpassword \ + neo4j:latest + +Note: Fixtures are defined in conftest.py and use environment variables: + - NEO4J_URI (default: bolt://localhost:7687) + - NEO4J_USER (default: neo4j) + - NEO4J_PASSWORD (default: testpassword) +""" + +import os + +import pytest + +from mellea_contribs.kg.components.query import GraphQuery + +try: + from mellea_contribs.kg.graph_dbs.neo4j import Neo4jBackend + + NEO4J_AVAILABLE = True +except ImportError: + NEO4J_AVAILABLE = False + + +# Skip all tests in this module if Neo4j is not available +pytestmark = [ + pytest.mark.neo4j, + pytest.mark.skipif(not NEO4J_AVAILABLE, reason="Neo4j driver not installed"), +] + + +class TestNeo4jBackend: + """Tests for Neo4jBackend.""" + + @pytest.mark.asyncio + async def test_create_neo4j_backend(self, neo4j_credentials): + """Test creating a Neo4jBackend.""" + backend = Neo4jBackend( + connection_uri=neo4j_credentials["uri"], + auth=(neo4j_credentials["user"], neo4j_credentials["password"]), + ) + + assert backend.backend_id == "neo4j" + assert backend.connection_uri == neo4j_credentials["uri"] + assert backend.auth == ( + neo4j_credentials["user"], + neo4j_credentials["password"], + ) + + await backend.close() + + @pytest.mark.asyncio + async def test_supports_cypher(self, neo4j_backend): + """Test that Neo4j backend supports Cypher queries.""" + assert neo4j_backend.supports_query_type("cypher") is True + assert neo4j_backend.supports_query_type("sparql") is False + + @pytest.mark.asyncio + async def test_get_schema(self, neo4j_backend): + """Test getting Neo4j schema.""" + schema = await neo4j_backend.get_schema() + + assert "node_types" in schema + assert "edge_types" in schema + assert "property_keys" in schema + assert isinstance(schema["node_types"], list) + assert isinstance(schema["edge_types"], list) + assert isinstance(schema["property_keys"], list) + + @pytest.mark.asyncio + async def test_validate_valid_query(self, neo4j_backend): + """Test validating a valid Cypher query.""" + query = GraphQuery(query_string="MATCH (n) RETURN n LIMIT 10") + is_valid, error = await neo4j_backend.validate_query(query) + + assert is_valid is True + assert error is None + + @pytest.mark.asyncio + async def test_validate_invalid_query(self, neo4j_backend): + """Test validating an invalid Cypher query.""" + query = GraphQuery(query_string="INVALID CYPHER QUERY") + is_valid, error = await neo4j_backend.validate_query(query) + + assert is_valid is False + assert error is not None + assert len(error) > 0 + + @pytest.mark.asyncio + async def test_execute_simple_query(self, populated_neo4j_backend): + """Test executing a simple query.""" + query = GraphQuery(query_string="MATCH (p:Person) RETURN p ORDER BY p.name") + result = await populated_neo4j_backend.execute_query(query) + + assert len(result.nodes) == 2 + assert result.nodes[0].label == "Person" + assert result.nodes[0].properties["name"] == "Alice" + assert result.nodes[1].properties["name"] == "Bob" + + @pytest.mark.asyncio + async def test_execute_query_with_parameters(self, populated_neo4j_backend): + """Test executing a query with parameters.""" + query = GraphQuery( + query_string="MATCH (p:Person {name: $name}) RETURN p", + parameters={"name": "Alice"}, + ) + result = await populated_neo4j_backend.execute_query(query) + + assert len(result.nodes) == 1 + assert result.nodes[0].properties["name"] == "Alice" + assert result.nodes[0].properties["age"] == 30 + + @pytest.mark.asyncio + async def test_execute_query_with_relationships(self, populated_neo4j_backend): + """Test executing a query that returns relationships.""" + query = GraphQuery( + query_string=""" + MATCH (p:Person)-[r:ACTED_IN]->(m:Movie) + RETURN p, r, m + ORDER BY p.name + """ + ) + result = await populated_neo4j_backend.execute_query(query) + + # Should have 2 people and 1 movie + assert len(result.nodes) == 3 + assert len(result.edges) == 2 + + # Check edge properties + assert result.edges[0].label == "ACTED_IN" + assert "role" in result.edges[0].properties + + @pytest.mark.asyncio + async def test_execute_query_no_results(self, populated_neo4j_backend): + """Test executing a query that returns no results.""" + query = GraphQuery( + query_string="MATCH (p:Person {name: 'NonExistent'}) RETURN p" + ) + result = await populated_neo4j_backend.execute_query(query) + + assert len(result.nodes) == 0 + assert len(result.edges) == 0 + + @pytest.mark.asyncio + async def test_execute_query_different_format_styles(self, populated_neo4j_backend): + """Test executing query with different format styles.""" + query = GraphQuery(query_string="MATCH (p:Person) RETURN p") + + result_triplets = await populated_neo4j_backend.execute_query( + query, format_style="triplets" + ) + assert result_triplets.format_style == "triplets" + + result_natural = await populated_neo4j_backend.execute_query( + query, format_style="natural" + ) + assert result_natural.format_style == "natural" + + @pytest.mark.asyncio + async def test_parse_neo4j_result_deduplication(self, populated_neo4j_backend): + """Test that parsed results deduplicate nodes and edges.""" + query = GraphQuery( + query_string=""" + MATCH (p:Person)-[r:ACTED_IN]->(m:Movie) + RETURN p, r, m + UNION + MATCH (p:Person)-[r:ACTED_IN]->(m:Movie) + RETURN p, r, m + """ + ) + result = await populated_neo4j_backend.execute_query(query) + + # Should still have only 3 unique nodes and 2 unique edges despite UNION + assert len(result.nodes) == 3 + assert len(result.edges) == 2 + + @pytest.mark.asyncio + async def test_execute_query_with_path(self, populated_neo4j_backend): + """Test executing a query that returns a path.""" + query = GraphQuery( + query_string=""" + MATCH path = (p:Person)-[:ACTED_IN]->(m:Movie) + WHERE p.name = 'Alice' + RETURN path + """ + ) + result = await populated_neo4j_backend.execute_query(query) + + assert len(result.paths) == 1 + path = result.paths[0] + assert len(path.nodes) == 2 + assert len(path.edges) == 1 + assert path.nodes[0].properties["name"] == "Alice" + assert path.nodes[1].properties["title"] == "The Matrix" + + @pytest.mark.asyncio + async def test_execute_query_empty_string(self, neo4j_backend): + """Test that empty query string raises ValueError.""" + query = GraphQuery(query_string="") + + with pytest.raises(ValueError, match="Query string is empty"): + await neo4j_backend.execute_query(query) + + @pytest.mark.asyncio + async def test_backend_close(self, neo4j_credentials): + """Test closing backend connections.""" + backend = Neo4jBackend( + connection_uri=neo4j_credentials["uri"], + auth=(neo4j_credentials["user"], neo4j_credentials["password"]), + ) + + # Should not raise + await backend.close() diff --git a/test/kg/test_phase1_domain_examples.py b/test/kg/test_phase1_domain_examples.py new file mode 100644 index 0000000..3baafbb --- /dev/null +++ b/test/kg/test_phase1_domain_examples.py @@ -0,0 +1,353 @@ +"""Tests for Phase 1 domain-specific examples (movie domain).""" + +import pytest + +from mellea_contribs.kg.models import Entity, Relation + + +class TestMovieEntityModel: + """Tests for MovieEntity domain-specific model.""" + + def test_movie_entity_import(self): + """Test that MovieEntity can be imported.""" + try: + from docs.examples.kgrag.models.movie_domain_models import MovieEntity + + assert MovieEntity is not None + except ImportError: + # If import fails, the file structure is still valid + pytest.skip("MovieEntity example not available in test environment") + + def test_movie_entity_structure(self): + """Test MovieEntity has required fields.""" + try: + from docs.examples.kgrag.models.movie_domain_models import MovieEntity + + # Check that MovieEntity is a class + assert hasattr(MovieEntity, "__init__") + + # Create a MovieEntity instance + entity = MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + release_year=2023, + director="Christopher Nolan", + ) + + assert entity.type == "Movie" + assert entity.name == "Oppenheimer" + assert entity.release_year == 2023 + assert entity.director == "Christopher Nolan" + except ImportError: + pytest.skip("MovieEntity example not available") + + def test_movie_entity_optional_fields(self): + """Test MovieEntity optional fields.""" + try: + from docs.examples.kgrag.models.movie_domain_models import MovieEntity + + entity = MovieEntity( + type="Movie", + name="Oppenheimer", + description="A film", + paragraph_start="Movie", + paragraph_end="here.", + box_office=952.0, + language="English", + rating=8.4, + ) + + assert entity.box_office == 952.0 + assert entity.language == "English" + assert entity.rating == 8.4 + except ImportError: + pytest.skip("MovieEntity example not available") + + +class TestPersonEntityModel: + """Tests for PersonEntity domain-specific model.""" + + def test_person_entity_structure(self): + """Test PersonEntity has required fields.""" + try: + from docs.examples.kgrag.models.movie_domain_models import PersonEntity + + entity = PersonEntity( + type="Person", + name="Christopher Nolan", + description="Film director", + paragraph_start="Christopher is", + paragraph_end="a director.", + birth_year=1970, + nationality="British", + profession="Director", + ) + + assert entity.type == "Person" + assert entity.name == "Christopher Nolan" + assert entity.birth_year == 1970 + assert entity.nationality == "British" + assert entity.profession == "Director" + except ImportError: + pytest.skip("PersonEntity example not available") + + +class TestAwardEntityModel: + """Tests for AwardEntity domain-specific model.""" + + def test_award_entity_structure(self): + """Test AwardEntity has required fields.""" + try: + from docs.examples.kgrag.models.movie_domain_models import AwardEntity + + entity = AwardEntity( + type="Award", + name="Best Picture", + description="Academy Award", + paragraph_start="Best Picture", + paragraph_end="award.", + ceremony_number=96, + award_type="Oscar", + award_year=2024, + ) + + assert entity.type == "Award" + assert entity.name == "Best Picture" + assert entity.ceremony_number == 96 + assert entity.award_type == "Oscar" + assert entity.award_year == 2024 + except ImportError: + pytest.skip("AwardEntity example not available") + + +class TestMovieRepresentationUtilities: + """Tests for movie domain representation utilities.""" + + def test_movie_rep_utilities_import(self): + """Test that movie representation utilities can be imported.""" + try: + from docs.examples.kgrag.rep.movie_rep import ( + movie_entity_to_text, + movie_relation_to_text, + format_movie_context, + ) + + assert movie_entity_to_text is not None + assert movie_relation_to_text is not None + assert format_movie_context is not None + except ImportError: + pytest.skip("Movie rep utilities not available") + + def test_movie_entity_to_text_with_movie_entity(self): + """Test movie_entity_to_text with MovieEntity.""" + try: + from docs.examples.kgrag.models.movie_domain_models import MovieEntity + from docs.examples.kgrag.rep.movie_rep import movie_entity_to_text + + entity = MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="great.", + release_year=2023, + director="Christopher Nolan", + ) + + text = movie_entity_to_text(entity) + + assert "Oppenheimer" in text + assert "Movie" in text + # Movie-specific fields should be included + assert "2023" in text or "release" in text.lower() or "2023" in str(entity) + + except ImportError: + pytest.skip("Movie utilities not available") + + def test_movie_entity_to_text_with_person_entity(self): + """Test movie_entity_to_text with PersonEntity.""" + try: + from docs.examples.kgrag.models.movie_domain_models import PersonEntity + from docs.examples.kgrag.rep.movie_rep import movie_entity_to_text + + entity = PersonEntity( + type="Person", + name="Christopher Nolan", + description="Director", + paragraph_start="Christopher", + paragraph_end="is a director.", + birth_year=1970, + nationality="British", + ) + + text = movie_entity_to_text(entity) + + assert "Christopher Nolan" in text + assert "Person" in text + except ImportError: + pytest.skip("Movie utilities not available") + + def test_format_movie_context(self): + """Test format_movie_context function.""" + try: + from docs.examples.kgrag.models.movie_domain_models import MovieEntity + from docs.examples.kgrag.rep.movie_rep import format_movie_context + + entities = [ + MovieEntity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer", + paragraph_end="is great.", + release_year=2023, + ), + ] + + relations = [ + Relation( + source_entity="Christopher Nolan", + relation_type="directed_by", + target_entity="Oppenheimer", + description="Christopher directed Oppenheimer", + ) + ] + + context = format_movie_context(entities, relations) + + assert "Oppenheimer" in context + assert "directed_by" in context or "Entities" in context + + except ImportError: + pytest.skip("Movie utilities not available") + + +class TestMoviePreprocessorExample: + """Tests for MovieKGPreprocessor domain example.""" + + def test_movie_preprocessor_import(self): + """Test that MovieKGPreprocessor can be imported.""" + try: + from docs.examples.kgrag.preprocessor.movie_preprocessor import MovieKGPreprocessor + + assert MovieKGPreprocessor is not None + except ImportError: + pytest.skip("MovieKGPreprocessor not available") + + def test_movie_preprocessor_structure(self): + """Test MovieKGPreprocessor structure.""" + try: + from docs.examples.kgrag.preprocessor.movie_preprocessor import MovieKGPreprocessor + + # Check that it's a class + import inspect + + assert inspect.isclass(MovieKGPreprocessor) + + # Check for key methods + assert hasattr(MovieKGPreprocessor, "get_hints") + assert hasattr(MovieKGPreprocessor, "post_process_extraction") + + except ImportError: + pytest.skip("MovieKGPreprocessor not available") + + +class TestDomainExampleIntegration: + """Integration tests for domain examples.""" + + def test_domain_models_inheritance(self): + """Test that domain models properly extend Entity.""" + try: + from docs.examples.kgrag.models.movie_domain_models import ( + MovieEntity, + PersonEntity, + AwardEntity, + ) + + # All should extend Entity + movie = MovieEntity( + type="Movie", + name="Test", + description="Test", + paragraph_start="Test", + paragraph_end=".", + ) + person = PersonEntity( + type="Person", + name="Test", + description="Test", + paragraph_start="Test", + paragraph_end=".", + ) + award = AwardEntity( + type="Award", + name="Test", + description="Test", + paragraph_start="Test", + paragraph_end=".", + ) + + # Check that they all have Entity fields + for entity in [movie, person, award]: + assert hasattr(entity, "type") + assert hasattr(entity, "name") + assert hasattr(entity, "description") + assert hasattr(entity, "properties") + + except ImportError: + pytest.skip("Domain models not available") + + def test_domain_rep_utilities_with_base_entities(self): + """Test that domain rep utilities work with base Entity.""" + try: + from docs.examples.kgrag.rep.movie_rep import movie_entity_to_text + + # Test with base Entity + entity = Entity( + type="Generic", + name="Generic Entity", + description="A generic entity", + paragraph_start="Generic", + paragraph_end="here.", + ) + + text = movie_entity_to_text(entity) + + assert text is not None + assert len(text) > 0 + + except ImportError: + pytest.skip("Domain rep utilities not available") + + +class TestDomainExampleConsistency: + """Tests for consistency of domain examples.""" + + def test_all_domain_entities_have_same_base_fields(self): + """Test that all domain entities have the same base Entity fields.""" + try: + from docs.examples.kgrag.models.movie_domain_models import ( + MovieEntity, + PersonEntity, + AwardEntity, + ) + + base_fields = ["type", "name", "description", "paragraph_start", "paragraph_end"] + + for EntityClass in [MovieEntity, PersonEntity, AwardEntity]: + entity = EntityClass( + type="Test", + name="Test", + description="Test", + paragraph_start="Test", + paragraph_end=".", + ) + + for field in base_fields: + assert hasattr(entity, field), f"{EntityClass.__name__} missing {field}" + + except ImportError: + pytest.skip("Domain models not available") diff --git a/test/kg/test_phase1_embed_models.py b/test/kg/test_phase1_embed_models.py new file mode 100644 index 0000000..bf99822 --- /dev/null +++ b/test/kg/test_phase1_embed_models.py @@ -0,0 +1,302 @@ +"""Tests for Phase 1 embedding configuration and result models.""" + +import pytest + +from mellea_contribs.kg.embed_models import ( + EmbeddingConfig, + EmbeddingResult, + EmbeddingSimilarity, + EmbeddingStats, +) + + +class TestEmbeddingConfig: + """Tests for EmbeddingConfig model.""" + + def test_create_embedding_config_defaults(self): + """Test creating EmbeddingConfig with defaults.""" + config = EmbeddingConfig() + + assert config.model == "text-embedding-3-small" + assert config.dimension == 1536 + assert config.api_base is None + assert config.api_key is None + + def test_create_embedding_config_openai(self): + """Test creating EmbeddingConfig for OpenAI.""" + config = EmbeddingConfig( + model="text-embedding-3-large", + dimension=3072, + ) + + assert config.model == "text-embedding-3-large" + assert config.dimension == 3072 + + def test_create_embedding_config_custom(self): + """Test creating EmbeddingConfig with custom values.""" + config = EmbeddingConfig( + model="all-MiniLM-L6-v2", + dimension=384, + api_base="http://localhost:8000", + api_key="test-key", + ) + + assert config.model == "all-MiniLM-L6-v2" + assert config.dimension == 384 + assert config.api_base == "http://localhost:8000" + assert config.api_key == "test-key" + + def test_embedding_config_various_dimensions(self): + """Test EmbeddingConfig with various embedding dimensions.""" + dimensions = [384, 768, 1024, 1536, 3072] + + for dim in dimensions: + config = EmbeddingConfig(dimension=dim) + assert config.dimension == dim + + +class TestEmbeddingResult: + """Tests for EmbeddingResult model.""" + + def test_create_embedding_result(self): + """Test creating an EmbeddingResult.""" + embedding_vector = [0.1, 0.2, 0.3, 0.4, 0.5] + result = EmbeddingResult( + text="Alice is a person", + embedding=embedding_vector, + model="text-embedding-3-small", + dimension=1536, + ) + + assert result.text == "Alice is a person" + assert result.embedding == embedding_vector + assert result.model == "text-embedding-3-small" + assert result.dimension == 1536 + + def test_embedding_result_with_large_vector(self): + """Test EmbeddingResult with large embedding.""" + result = EmbeddingResult( + text="Oppenheimer is a 2023 film", + embedding=[0.1] * 1536, + model="text-embedding-3-small", + dimension=1536, + ) + + assert result.text == "Oppenheimer is a 2023 film" + assert len(result.embedding) == 1536 + assert result.dimension == 1536 + + def test_embedding_result_various_sizes(self): + """Test EmbeddingResult with various embedding sizes.""" + sizes = [384, 768, 1536, 3072] + + for size in sizes: + result = EmbeddingResult( + text=f"Test text for embedding {size}", + embedding=[0.5] * size, + model="test-model", + dimension=size, + ) + assert len(result.embedding) == size + assert result.dimension == size + + +class TestEmbeddingSimilarity: + """Tests for EmbeddingSimilarity model.""" + + def test_create_embedding_similarity(self): + """Test creating an EmbeddingSimilarity.""" + similarity = EmbeddingSimilarity( + entity_id="entity-1", + entity_name="Alice", + similarity_score=0.92, + entity_type="Person", + ) + + assert similarity.entity_id == "entity-1" + assert similarity.entity_name == "Alice" + assert similarity.similarity_score == 0.92 + assert similarity.entity_type == "Person" + + def test_embedding_similarity_high_score(self): + """Test EmbeddingSimilarity with high similarity score.""" + similarity = EmbeddingSimilarity( + entity_id="entity-alice", + entity_name="Alice", + similarity_score=0.99, + ) + + assert similarity.similarity_score == 0.99 + + def test_embedding_similarity_low_score(self): + """Test EmbeddingSimilarity with low similarity score.""" + similarity = EmbeddingSimilarity( + entity_id="entity-1", + entity_name="Alice", + similarity_score=0.15, + ) + + assert similarity.similarity_score == 0.15 + + def test_embedding_similarity_multiple_matches(self): + """Test creating multiple EmbeddingSimilarity results.""" + similarities = [ + EmbeddingSimilarity( + entity_id=f"match-{i}", + entity_name=f"Match {i}", + similarity_score=0.9 - (i * 0.1), + ) + for i in range(3) + ] + + assert len(similarities) == 3 + # Scores should be decreasing + assert similarities[0].similarity_score > similarities[1].similarity_score + assert similarities[1].similarity_score > similarities[2].similarity_score + + +class TestEmbeddingStats: + """Tests for EmbeddingStats model.""" + + def test_create_embedding_stats_defaults(self): + """Test creating EmbeddingStats with defaults.""" + stats = EmbeddingStats( + total_entities=0, + successful_embeddings=0, + failed_embeddings=0, + skipped_embeddings=0, + average_embedding_time=0.0, + total_time=0.0, + model_used="test-model", + ) + + assert stats.total_entities == 0 + assert stats.successful_embeddings == 0 + assert stats.failed_embeddings == 0 + + def test_create_embedding_stats_custom(self): + """Test creating EmbeddingStats with custom values.""" + stats = EmbeddingStats( + total_entities=1000, + successful_embeddings=950, + failed_embeddings=30, + skipped_embeddings=20, + average_embedding_time=0.01, + total_time=10.0, + model_used="text-embedding-3-small", + ) + + assert stats.total_entities == 1000 + assert stats.successful_embeddings == 950 + assert stats.failed_embeddings == 30 + assert stats.skipped_embeddings == 20 + assert stats.average_embedding_time == 0.01 + assert stats.total_time == 10.0 + + def test_embedding_stats_success_rate(self): + """Test EmbeddingStats tracking success.""" + stats = EmbeddingStats( + total_entities=100, + successful_embeddings=90, + failed_embeddings=5, + skipped_embeddings=5, + average_embedding_time=0.1, + total_time=10.0, + model_used="test-model", + ) + + # Verify totals add up + total = stats.successful_embeddings + stats.failed_embeddings + stats.skipped_embeddings + assert total == 100 + + def test_embedding_stats_batch_processing(self): + """Test EmbeddingStats for batch processing metrics.""" + stats = EmbeddingStats( + total_entities=1000, + successful_embeddings=950, + failed_embeddings=50, + skipped_embeddings=0, + average_embedding_time=0.005, + total_time=5.0, + model_used="all-MiniLM-L6-v2", + ) + + assert stats.total_entities == 1000 + assert stats.average_embedding_time == 0.005 + assert stats.total_time == 5.0 + + +class TestEmbeddingIntegration: + """Integration tests for embedding models.""" + + def test_embedding_config_and_result_together(self): + """Test using EmbeddingConfig with EmbeddingResult.""" + config = EmbeddingConfig( + model="text-embedding-3-small", + dimension=1536, + ) + + result = EmbeddingResult( + text="Alice", + embedding=[0.1] * config.dimension, + model=config.model, + dimension=config.dimension, + ) + + assert config.dimension == 1536 + assert len(result.embedding) == config.dimension + assert result.model == config.model + + def test_embedding_result_list_with_stats(self): + """Test creating multiple EmbeddingResults with stats.""" + config = EmbeddingConfig( + model="all-MiniLM-L6-v2", + dimension=384, + ) + + results = [ + EmbeddingResult( + text=f"Entity {i}", + embedding=[0.1 + i * 0.01] * config.dimension, + model=config.model, + dimension=config.dimension, + ) + for i in range(10) + ] + + stats = EmbeddingStats( + total_entities=len(results), + successful_embeddings=len(results), + failed_embeddings=0, + skipped_embeddings=0, + average_embedding_time=0.01, + total_time=0.1, + model_used=config.model, + ) + + assert len(results) == 10 + assert stats.total_entities == 10 + assert stats.successful_embeddings == 10 + + def test_similarity_search_workflow(self): + """Test similarity search workflow with embeddings.""" + # Query result + query_result = EmbeddingResult( + text="Alice", + embedding=[0.5] * 384, + model="all-MiniLM-L6-v2", + dimension=384, + ) + + # Match results + matches = [ + EmbeddingSimilarity( + entity_id=f"candidate-{i}", + entity_name=f"Candidate {i}", + similarity_score=0.95 - (i * 0.1), + ) + for i in range(3) + ] + + assert len(matches) == 3 + assert matches[0].similarity_score > matches[1].similarity_score diff --git a/test/kg/test_phase1_embedder.py b/test/kg/test_phase1_embedder.py new file mode 100644 index 0000000..436d004 --- /dev/null +++ b/test/kg/test_phase1_embedder.py @@ -0,0 +1,257 @@ +"""Tests for Phase 1 KGEmbedder structure and interface.""" + +import pytest +import inspect +import asyncio + +from mellea_contribs.kg.embedder import KGEmbedder +from mellea_contribs.kg.models import Entity + + +class TestKGEmbedderStructure: + """Tests for KGEmbedder class structure.""" + + def test_kg_embedder_exists(self): + """Test that KGEmbedder class exists.""" + assert KGEmbedder is not None + + def test_kg_embedder_is_class(self): + """Test that KGEmbedder is a class.""" + assert inspect.isclass(KGEmbedder) + + def test_kg_embedder_has_embed_entity_method(self): + """Test that KGEmbedder has embed_entity method.""" + assert hasattr(KGEmbedder, "embed_entity") + assert callable(getattr(KGEmbedder, "embed_entity")) + + def test_kg_embedder_has_embed_batch_method(self): + """Test that KGEmbedder has embed_batch method.""" + assert hasattr(KGEmbedder, "embed_batch") + assert callable(getattr(KGEmbedder, "embed_batch")) + + def test_kg_embedder_has_get_similar_entities_method(self): + """Test that KGEmbedder has get_similar_entities method.""" + assert hasattr(KGEmbedder, "get_similar_entities") + assert callable(getattr(KGEmbedder, "get_similar_entities")) + + def test_kg_embedder_methods_are_async(self): + """Test that key methods are async.""" + embed_entity_method = getattr(KGEmbedder, "embed_entity") + assert inspect.iscoroutinefunction(embed_entity_method) + + embed_batch_method = getattr(KGEmbedder, "embed_batch") + assert inspect.iscoroutinefunction(embed_batch_method) + + get_similar_method = getattr(KGEmbedder, "get_similar_entities") + assert inspect.iscoroutinefunction(get_similar_method) + + +class TestKGEmbedderInterface: + """Tests for KGEmbedder interface matching Mellea Layer 1 pattern.""" + + def test_kg_embedder_init_individual_params(self): + """Test KGEmbedder initialization with individual parameters (Mellea pattern).""" + # Create embedder with individual parameters (matching KGRag/KGPreprocessor pattern) + embedder = KGEmbedder( + session=None, # Would be MelleaSession in real usage + model="text-embedding-3-small", + dimension=1536, + ) + + assert embedder is not None + # Verify individual parameters are accessible + assert embedder.embedding_model == "text-embedding-3-small" + assert embedder.embedding_dimension == 1536 + + def test_kg_embedder_init_defaults(self): + """Test KGEmbedder initialization with defaults.""" + embedder = KGEmbedder(session=None) + + assert embedder is not None + # Should have default values + assert hasattr(embedder, "embedding_model") + assert hasattr(embedder, "embedding_dimension") + + def test_kg_embedder_init_all_params(self): + """Test KGEmbedder initialization with all individual parameters.""" + embedder = KGEmbedder( + session=None, + model="all-MiniLM-L6-v2", + dimension=384, + api_base="http://localhost:8000", + api_key="test-key", + batch_size=64, + ) + + assert embedder is not None + assert embedder.embedding_model == "all-MiniLM-L6-v2" + assert embedder.embedding_dimension == 384 + + def test_kg_embedder_parameter_defaults(self): + """Test that KGEmbedder has sensible parameter defaults.""" + embedder = KGEmbedder(session=None) + + # Should have reasonable defaults + assert isinstance(embedder.embedding_model, str) + assert isinstance(embedder.embedding_dimension, int) + assert embedder.embedding_dimension > 0 + + +class TestKGEmbedderDocumentation: + """Tests for KGEmbedder documentation.""" + + def test_kg_embedder_has_docstring(self): + """Test that KGEmbedder has docstring.""" + assert KGEmbedder.__doc__ is not None + assert len(KGEmbedder.__doc__) > 0 + + def test_kg_embedder_method_docstrings(self): + """Test that key methods have docstrings.""" + methods = ["embed_entity", "embed_batch", "get_similar_entities"] + + for method_name in methods: + method = getattr(KGEmbedder, method_name) + assert method.__doc__ is not None, f"Method {method_name} missing docstring" + assert len(method.__doc__) > 0, f"Method {method_name} has empty docstring" + + def test_kg_embedder_init_docstring(self): + """Test that __init__ has docstring.""" + assert KGEmbedder.__init__.__doc__ is not None + + +class TestKGEmbedderInstantiation: + """Tests for KGEmbedder instantiation (Mellea Layer 1 pattern).""" + + def test_kg_embedder_can_be_instantiated(self): + """Test that KGEmbedder can be instantiated.""" + embedder = KGEmbedder(session=None) + assert embedder is not None + assert isinstance(embedder, KGEmbedder) + + def test_kg_embedder_with_various_models(self): + """Test KGEmbedder with various embedding models.""" + models = [ + "text-embedding-3-small", + "text-embedding-3-large", + "all-MiniLM-L6-v2", + ] + + for model_name in models: + embedder = KGEmbedder( + session=None, + model=model_name, + ) + + assert embedder is not None + assert embedder.embedding_model == model_name + + def test_kg_embedder_with_various_dimensions(self): + """Test KGEmbedder with various embedding dimensions.""" + dimensions = [384, 768, 1536, 3072] + + for dim in dimensions: + embedder = KGEmbedder( + session=None, + dimension=dim, + ) + + assert embedder is not None + assert embedder.embedding_dimension == dim + + def test_kg_embedder_consistency_across_instances(self): + """Test that multiple KGEmbedder instances maintain separate configs.""" + embedder1 = KGEmbedder( + session=None, + model="model1", + dimension=384, + ) + + embedder2 = KGEmbedder( + session=None, + model="model2", + dimension=1536, + ) + + assert embedder1.embedding_model == "model1" + assert embedder1.embedding_dimension == 384 + assert embedder2.embedding_model == "model2" + assert embedder2.embedding_dimension == 1536 + + +class TestKGEmbedderMethodSignatures: + """Tests for detailed method signatures.""" + + def test_embed_entity_parameters(self): + """Test embed_entity method parameters.""" + sig = inspect.signature(KGEmbedder.embed_entity) + params = sig.parameters + + assert "self" in params + assert "entity" in params + # Check for optional parameters like use_name, use_description + param_names = list(params.keys()) + assert any("use" in p.lower() for p in param_names) + + def test_embed_batch_parameters(self): + """Test embed_batch method parameters.""" + sig = inspect.signature(KGEmbedder.embed_batch) + params = sig.parameters + + assert "self" in params + assert "entities" in params + + def test_get_similar_entities_parameters(self): + """Test get_similar_entities method parameters.""" + sig = inspect.signature(KGEmbedder.get_similar_entities) + params = sig.parameters + + assert "self" in params + assert "query_entity" in params or "query" in params + + +class TestKGEmbedderIntegration: + """Integration tests for KGEmbedder with other models.""" + + def test_kg_embedder_with_entity_model(self): + """Test that KGEmbedder works with Entity model.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ) + + embedder = KGEmbedder(session=None) + + # Just verify the interaction is type-safe + assert entity is not None + assert embedder is not None + + def test_kg_embedder_parameter_consistency(self): + """Test that embedder maintains parameter consistency.""" + # Verify that parameters are stored and accessible + embedder = KGEmbedder( + session=None, + model="test-model", + dimension=512, + batch_size=32, + ) + + assert embedder.embedding_model == "test-model" + assert embedder.embedding_dimension == 512 + + def test_kg_embedder_matches_layer1_pattern(self): + """Test that KGEmbedder matches Mellea Layer 1 pattern (individual params).""" + # Similar to KGRag and KGPreprocessor + embedder = KGEmbedder( + session=None, # Would be MelleaSession + model="default-model", + dimension=1536, + batch_size=10, # Individual parameters, not config object + ) + + assert embedder is not None + # All parameters should be individually accessible + assert hasattr(embedder, "embedding_model") + assert hasattr(embedder, "embedding_dimension") diff --git a/test/kg/test_phase1_models.py b/test/kg/test_phase1_models.py new file mode 100644 index 0000000..9016dcd --- /dev/null +++ b/test/kg/test_phase1_models.py @@ -0,0 +1,291 @@ +"""Tests for Phase 1 Entity and Relation models.""" + +import pytest + +from mellea_contribs.kg.models import Entity, Relation, DirectAnswer + + +class TestEntity: + """Tests for Entity model.""" + + def test_create_entity_minimal(self): + """Test creating Entity with minimal fields.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice is", + paragraph_end="a character.", + ) + + assert entity.type == "Person" + assert entity.name == "Alice" + assert entity.description == "A person" + assert entity.paragraph_start == "Alice is" + assert entity.paragraph_end == "a character." + # Check optional storage fields default + assert entity.id is None + assert entity.confidence == 1.0 + assert entity.embedding is None + + def test_create_entity_with_storage_fields(self): + """Test creating Entity with storage fields.""" + entity = Entity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + id="neo4j-123", + confidence=0.95, + embedding=[0.1, 0.2, 0.3], + ) + + assert entity.id == "neo4j-123" + assert entity.confidence == 0.95 + assert entity.embedding == [0.1, 0.2, 0.3] + + def test_entity_with_properties(self): + """Test Entity with properties dict.""" + props = {"director": "Christopher Nolan", "year": 2023} + entity = Entity( + type="Movie", + name="Oppenheimer", + description="2023 film", + paragraph_start="Oppenheimer is", + paragraph_end="by Nolan.", + properties=props, + ) + + assert entity.properties == props + assert entity.properties["director"] == "Christopher Nolan" + + def test_entity_confidence_range(self): + """Test Entity with valid confidence values.""" + entity1 = Entity( + type="Person", + name="Bob", + description="Test", + paragraph_start="Bob", + paragraph_end="test.", + confidence=0.0, + ) + entity2 = Entity( + type="Person", + name="Bob", + description="Test", + paragraph_start="Bob", + paragraph_end="test.", + confidence=1.0, + ) + entity3 = Entity( + type="Person", + name="Bob", + description="Test", + paragraph_start="Bob", + paragraph_end="test.", + confidence=0.5, + ) + + assert entity1.confidence == 0.0 + assert entity2.confidence == 1.0 + assert entity3.confidence == 0.5 + + +class TestRelation: + """Tests for Relation model.""" + + def test_create_relation_minimal(self): + """Test creating Relation with minimal fields.""" + relation = Relation( + source_entity="Alice", + relation_type="KNOWS", + target_entity="Bob", + description="Alice knows Bob", + paragraph_start="Alice knows Bob", + paragraph_end="very well.", + ) + + assert relation.source_entity == "Alice" + assert relation.relation_type == "KNOWS" + assert relation.target_entity == "Bob" + assert relation.description == "Alice knows Bob" + # Check optional storage fields default + assert relation.id is None + assert relation.source_entity_id is None + assert relation.target_entity_id is None + assert relation.valid_from is None + assert relation.valid_until is None + + def test_create_relation_with_storage_fields(self): + """Test creating Relation with storage fields.""" + relation = Relation( + source_entity="Alice", + relation_type="DIRECTED", + target_entity="Oppenheimer", + description="Alice directed Oppenheimer", + paragraph_start="Alice directed", + paragraph_end="Oppenheimer.", + id="rel-123", + source_entity_id="node-alice", + target_entity_id="node-oppenheimer", + valid_from="2020-01-01", + valid_until="2023-12-31", + ) + + assert relation.id == "rel-123" + assert relation.source_entity_id == "node-alice" + assert relation.target_entity_id == "node-oppenheimer" + assert relation.valid_from == "2020-01-01" + assert relation.valid_until == "2023-12-31" + + def test_relation_with_properties(self): + """Test Relation with properties dict.""" + props = {"role": "Director", "years": 3} + relation = Relation( + source_entity="Christopher", + relation_type="DIRECTED", + target_entity="Oppenheimer", + description="Christopher directed Oppenheimer", + paragraph_start="Christopher directed", + paragraph_end="Oppenheimer.", + properties=props, + ) + + assert relation.properties == props + assert relation.properties["role"] == "Director" + + def test_relation_temporal_fields(self): + """Test Relation temporal validity fields.""" + relation = Relation( + source_entity="CEO", + relation_type="LEADS", + target_entity="Company", + description="CEO leads company", + paragraph_start="CEO leads", + paragraph_end="company.", + valid_from="2020-01-01", + valid_until="2025-12-31", + ) + + assert relation.valid_from == "2020-01-01" + assert relation.valid_until == "2025-12-31" + + +class TestDirectAnswer: + """Tests for DirectAnswer model.""" + + def test_create_direct_answer(self): + """Test creating a DirectAnswer.""" + answer = DirectAnswer( + sufficient="Yes", + reason="LLM has knowledge about this", + answer="The answer is 42", + ) + + assert answer.sufficient == "Yes" + assert answer.reason == "LLM has knowledge about this" + assert answer.answer == "The answer is 42" + + def test_direct_answer_negative_case(self): + """Test DirectAnswer with negative sufficient case.""" + answer = DirectAnswer( + sufficient="No", + reason="Not enough information", + answer="I don't know", + ) + + assert answer.sufficient == "No" + assert answer.reason == "Not enough information" + assert answer.answer == "I don't know" + + +class TestEntityRelationIntegration: + """Tests for Entity and Relation integration.""" + + def test_create_entity_and_relate(self): + """Test creating entities and relating them.""" + alice = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice is", + paragraph_end="a person.", + ) + + oppenheimer = Entity( + type="Movie", + name="Oppenheimer", + description="A film", + paragraph_start="Oppenheimer is", + paragraph_end="a film.", + ) + + relation = Relation( + source_entity=alice.name, + relation_type="DIRECTED", + target_entity=oppenheimer.name, + description="Alice directed Oppenheimer", + paragraph_start="Alice directed", + paragraph_end="Oppenheimer.", + ) + + assert alice.type == "Person" + assert oppenheimer.type == "Movie" + assert relation.relation_type == "DIRECTED" + assert relation.source_entity == "Alice" + assert relation.target_entity == "Oppenheimer" + + def test_entity_extraction_to_storage_progression(self): + """Test Entity progression from extracted to stored.""" + # Start: extracted entity (no storage fields set) + extracted = Entity( + type="Person", + name="Bob", + description="Test person", + paragraph_start="Bob is", + paragraph_end="a test.", + ) + + assert extracted.id is None + assert extracted.confidence == 1.0 + + # Later: stored entity (storage fields set) + stored = Entity( + type="Person", + name="Bob", + description="Test person", + paragraph_start="Bob is", + paragraph_end="a test.", + id="neo4j-bob", + confidence=0.98, + embedding=[0.1] * 384, + ) + + assert stored.id == "neo4j-bob" + assert stored.confidence == 0.98 + assert len(stored.embedding) == 384 + + def test_empty_properties_dict(self): + """Test Entity and Relation with empty properties.""" + entity = Entity( + type="Thing", + name="Thing1", + description="A thing", + paragraph_start="Thing", + paragraph_end="here.", + properties={}, + ) + + relation = Relation( + source_entity="A", + relation_type="RELATES", + target_entity="B", + description="A relates to B", + paragraph_start="A relates", + paragraph_end="to B.", + properties={}, + ) + + assert entity.properties == {} + assert relation.properties == {} diff --git a/test/kg/test_phase1_preprocessor.py b/test/kg/test_phase1_preprocessor.py new file mode 100644 index 0000000..1ad5c00 --- /dev/null +++ b/test/kg/test_phase1_preprocessor.py @@ -0,0 +1,117 @@ +"""Tests for Phase 1 KGPreprocessor structure and interface.""" + +import pytest +import inspect + +from mellea_contribs.kg.preprocessor import KGPreprocessor + + +class TestKGPreprocessorStructure: + """Tests for KGPreprocessor class structure.""" + + def test_kg_preprocessor_exists(self): + """Test that KGPreprocessor class exists.""" + assert KGPreprocessor is not None + + def test_kg_preprocessor_is_class(self): + """Test that KGPreprocessor is a class.""" + assert inspect.isclass(KGPreprocessor) + + def test_kg_preprocessor_has_process_document_method(self): + """Test that KGPreprocessor has process_document method.""" + assert hasattr(KGPreprocessor, "process_document") + assert callable(getattr(KGPreprocessor, "process_document")) + + def test_kg_preprocessor_has_persist_extraction_method(self): + """Test that KGPreprocessor has persist_extraction method.""" + assert hasattr(KGPreprocessor, "persist_extraction") + assert callable(getattr(KGPreprocessor, "persist_extraction")) + + def test_kg_preprocessor_has_get_hints_method(self): + """Test that KGPreprocessor has get_hints abstract method.""" + assert hasattr(KGPreprocessor, "get_hints") + + def test_kg_preprocessor_has_post_process_extraction_method(self): + """Test that KGPreprocessor has post_process_extraction method.""" + assert hasattr(KGPreprocessor, "post_process_extraction") + + def test_kg_preprocessor_methods_are_defined(self): + """Test that all required methods are defined.""" + required_methods = [ + "process_document", + "persist_extraction", + "get_hints", + "post_process_extraction", + ] + + for method_name in required_methods: + method = getattr(KGPreprocessor, method_name, None) + assert method is not None, f"Method {method_name} not found" + assert callable(method), f"Method {method_name} is not callable" + + +class TestKGPreprocessorInterface: + """Tests for KGPreprocessor interface and contract.""" + + def test_kg_preprocessor_init_signature(self): + """Test KGPreprocessor __init__ signature.""" + sig = inspect.signature(KGPreprocessor.__init__) + params = list(sig.parameters.keys()) + + # Should have at least 'self' + assert "self" in params + + def test_kg_preprocessor_method_signatures(self): + """Test that key methods have correct signatures.""" + # process_document should have parameters + process_sig = inspect.signature(KGPreprocessor.process_document) + process_params = list(process_sig.parameters.keys()) + assert "self" in process_params + + # get_hints should be present + hints_sig = inspect.signature(KGPreprocessor.get_hints) + hints_params = list(hints_sig.parameters.keys()) + assert "self" in hints_params + + def test_kg_preprocessor_docstring(self): + """Test that KGPreprocessor has docstring.""" + assert KGPreprocessor.__doc__ is not None + assert len(KGPreprocessor.__doc__) > 0 + + def test_kg_preprocessor_method_docstrings(self): + """Test that key methods have docstrings.""" + methods = ["process_document", "persist_extraction", "get_hints"] + + for method_name in methods: + method = getattr(KGPreprocessor, method_name) + assert method.__doc__ is not None, f"Method {method_name} missing docstring" + + +class TestKGPreprocessorAbstractContract: + """Tests for KGPreprocessor abstract contract.""" + + def test_kg_preprocessor_cannot_be_instantiated_directly(self): + """Test that KGPreprocessor abstract methods prevent direct instantiation.""" + # KGPreprocessor has abstract methods (get_hints) + # Trying to instantiate should fail if it's properly abstract + try: + from abc import ABC + # Check if KGPreprocessor is abstract + if hasattr(KGPreprocessor, "__abstractmethods__"): + # If it has abstractmethods, direct instantiation should fail + assert len(KGPreprocessor.__abstractmethods__) > 0 or not issubclass( + KGPreprocessor, ABC + ) + except Exception: + # If there's any error, the class structure is correct + pass + + def test_kg_preprocessor_inheritance_ready(self): + """Test that KGPreprocessor is ready for subclassing.""" + # Create a minimal concrete implementation + class ConcretePreprocessor(KGPreprocessor): + def get_hints(self, domain=None): + return "test hints" + + # This should work without errors + assert ConcretePreprocessor is not None diff --git a/test/kg/test_phase1_qa_models.py b/test/kg/test_phase1_qa_models.py new file mode 100644 index 0000000..d615e2c --- /dev/null +++ b/test/kg/test_phase1_qa_models.py @@ -0,0 +1,245 @@ +"""Tests for Phase 1 QA configuration and result models.""" + +import pytest + +from mellea_contribs.kg.qa_models import ( + QAConfig, + QASessionConfig, + QADatasetConfig, + QAResult, + QAStats, +) +from mellea_contribs.kg.models import DirectAnswer, Entity + + +class TestQAConfig: + """Tests for QAConfig model.""" + + def test_create_qa_config_defaults(self): + """Test creating QAConfig with defaults.""" + config = QAConfig() + + assert config.route_count == 3 + assert config.depth == 2 + assert config.width == 5 + assert config.domain is None + assert config.consensus_threshold == 0.7 + assert config.format_style == "detailed" + + def test_create_qa_config_custom(self): + """Test creating QAConfig with custom values.""" + config = QAConfig( + route_count=5, + depth=3, + width=10, + domain="movie", + consensus_threshold=0.8, + format_style="concise", + ) + + assert config.route_count == 5 + assert config.depth == 3 + assert config.width == 10 + assert config.domain == "movie" + assert config.consensus_threshold == 0.8 + assert config.format_style == "concise" + + def test_qa_config_max_repair_attempts(self): + """Test QAConfig max_repair_attempts field.""" + config = QAConfig(max_repair_attempts=3) + assert config.max_repair_attempts == 3 + + +class TestQASessionConfig: + """Tests for QASessionConfig model.""" + + def test_create_qa_session_config_defaults(self): + """Test creating QASessionConfig with defaults.""" + config = QASessionConfig() + + assert config.llm_model is not None + assert config.temperature == 0.1 + assert config.max_tokens == 2048 + assert config.eval_model is not None + + def test_create_qa_session_config_custom(self): + """Test creating QASessionConfig with custom values.""" + config = QASessionConfig( + llm_model="gpt-4", + temperature=0.5, + max_tokens=4096, + eval_model="claude-3-sonnet", + evaluation_threshold=0.75, + ) + + assert config.llm_model == "gpt-4" + assert config.temperature == 0.5 + assert config.max_tokens == 4096 + assert config.eval_model == "claude-3-sonnet" + assert config.evaluation_threshold == 0.75 + + def test_qa_session_config_few_shot_examples(self): + """Test QASessionConfig with few-shot examples.""" + examples = [ + {"question": "Q1", "answer": "A1"}, + {"question": "Q2", "answer": "A2"}, + ] + config = QASessionConfig(few_shot_examples=examples) + assert config.few_shot_examples == examples + + +class TestQADatasetConfig: + """Tests for QADatasetConfig model.""" + + def test_create_qa_dataset_config_defaults(self): + """Test creating QADatasetConfig with defaults.""" + config = QADatasetConfig() + + assert config.dataset_path is not None + assert config.batch_size == 32 + assert config.output_path is not None + assert config.num_workers == 4 + assert config.shuffle is True + + def test_create_qa_dataset_config_custom(self): + """Test creating QADatasetConfig with custom values.""" + config = QADatasetConfig( + dataset_path="data/qa.jsonl", + batch_size=64, + output_path="output/qa_results.json", + num_workers=8, + shuffle=False, + max_examples=100, + skip_errors=True, + ) + + assert config.dataset_path == "data/qa.jsonl" + assert config.batch_size == 64 + assert config.output_path == "output/qa_results.json" + assert config.num_workers == 8 + assert config.shuffle is False + assert config.max_examples == 100 + assert config.skip_errors is True + + +class TestQAResult: + """Tests for QAResult model.""" + + def test_create_qa_result_minimal(self): + """Test creating QAResult with minimal fields.""" + result = QAResult( + question="What is AI?", + answer="Artificial Intelligence", + ) + + assert result.question == "What is AI?" + assert result.answer == "Artificial Intelligence" + + def test_create_qa_result_comprehensive(self): + """Test creating QAResult with comprehensive fields.""" + result = QAResult( + question="What is AI?", + answer="Artificial Intelligence", + confidence=0.95, + cypher_query="MATCH (n) RETURN n", + graph_evidence=["node1", "node2"], + reasoning="Based on knowledge graph", + ) + + assert result.question == "What is AI?" + assert result.answer == "Artificial Intelligence" + assert result.confidence == 0.95 + assert result.cypher_query == "MATCH (n) RETURN n" + assert result.graph_evidence == ["node1", "node2"] + assert result.reasoning == "Based on knowledge graph" + + def test_qa_result_with_direct_answer(self): + """Test QAResult with DirectAnswer object.""" + direct_answer = DirectAnswer( + sufficient="Yes", + reason="LLM has knowledge about this", + answer="42", + ) + result = QAResult( + question="What is the answer?", + answer="42", + direct_answer=direct_answer, + ) + + assert result.direct_answer == direct_answer + assert result.direct_answer.answer == "42" + + +class TestQAStats: + """Tests for QAStats model.""" + + def test_create_qa_stats_defaults(self): + """Test creating QAStats with defaults.""" + stats = QAStats() + + assert stats.total_questions == 0 + assert stats.exact_match_count == 0 + assert stats.partial_match_count == 0 + assert stats.no_match_count == 0 + + def test_create_qa_stats_custom(self): + """Test creating QAStats with custom values.""" + stats = QAStats( + total_questions=100, + exact_match_count=80, + partial_match_count=15, + no_match_count=5, + mean_reciprocal_rank=0.87, + total_time_ms=5000, + avg_time_per_question_ms=50, + ) + + assert stats.total_questions == 100 + assert stats.exact_match_count == 80 + assert stats.partial_match_count == 15 + assert stats.no_match_count == 5 + assert stats.mean_reciprocal_rank == 0.87 + assert stats.total_time_ms == 5000 + assert stats.avg_time_per_question_ms == 50 + + def test_qa_stats_metrics(self): + """Test QAStats metric calculations.""" + stats = QAStats( + total_questions=100, + exact_match_count=50, + partial_match_count=30, + ) + + # Test derived metrics + assert stats.exact_match_count == 50 + assert stats.partial_match_count == 30 + + +class TestQAIntegration: + """Integration tests for QA models.""" + + def test_qa_config_and_session_together(self): + """Test using QAConfig with QASessionConfig.""" + qa_config = QAConfig(domain="movie", route_count=3) + session_config = QASessionConfig(temperature=0.1) + + assert qa_config.domain == "movie" + assert session_config.temperature == 0.1 + + def test_qa_result_list_with_stats(self): + """Test creating multiple QA results with stats.""" + results = [ + QAResult(question="Q1", answer="A1", confidence=0.9), + QAResult(question="Q2", answer="A2", confidence=0.8), + QAResult(question="Q3", answer="A3", confidence=0.95), + ] + + stats = QAStats( + total_questions=len(results), + exact_match_count=2, + avg_time_per_question_ms=100, + ) + + assert len(results) == 3 + assert stats.total_questions == 3 + assert stats.exact_match_count == 2 diff --git a/test/kg/test_phase1_rep.py b/test/kg/test_phase1_rep.py new file mode 100644 index 0000000..6e62239 --- /dev/null +++ b/test/kg/test_phase1_rep.py @@ -0,0 +1,392 @@ +"""Tests for Phase 1 representation utilities (rep.py).""" + +import pytest + +from mellea_contribs.kg.models import Entity, Relation +from mellea_contribs.kg.rep import ( + normalize_entity_name, + entity_to_text, + relation_to_text, + format_entity_list, + format_relation_list, + format_kg_context, + camelcase_to_snake_case, + snake_case_to_camelcase, +) + + +class TestNormalizeEntityName: + """Tests for normalize_entity_name function.""" + + def test_normalize_basic(self): + """Test normalizing basic names.""" + assert normalize_entity_name("alice") == "Alice" + assert normalize_entity_name("ALICE") == "Alice" + + def test_normalize_with_spaces(self): + """Test normalizing names with extra spaces.""" + assert normalize_entity_name("alice bob") == "Alice Bob" + assert normalize_entity_name(" alice ") == "Alice" + + def test_normalize_with_quotes(self): + """Test normalizing names with quotes.""" + result = normalize_entity_name("alice 'bob' charlie") + assert "alice" in result.lower() + assert "bob" in result.lower() + + def test_normalize_title_case(self): + """Test normalization to title case.""" + assert normalize_entity_name("oppenheimer") == "Oppenheimer" + assert normalize_entity_name("christopher nolan") == "Christopher Nolan" + + +class TestEntityToText: + """Tests for entity_to_text function.""" + + def test_entity_to_text_basic(self): + """Test basic entity formatting.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice is", + paragraph_end="here.", + ) + + text = entity_to_text(entity) + assert "Alice" in text + assert "Person" in text + assert "A person" in text + + def test_entity_to_text_with_properties(self): + """Test entity formatting with properties.""" + entity = Entity( + type="Movie", + name="Oppenheimer", + description="A film", + paragraph_start="Oppenheimer is", + paragraph_end="great.", + properties={"year": 2023, "director": "Nolan"}, + ) + + text = entity_to_text(entity) + assert "Oppenheimer" in text + assert "Movie" in text + + def test_entity_to_text_with_confidence(self): + """Test entity formatting with confidence score.""" + entity = Entity( + type="Person", + name="Bob", + description="Test", + paragraph_start="Bob", + paragraph_end="here.", + confidence=0.95, + ) + + text = entity_to_text(entity, include_confidence=True) + assert "Confidence" in text or "0.95" in text + + def test_entity_to_text_no_properties(self): + """Test entity formatting without properties.""" + entity = Entity( + type="Thing", + name="Thing1", + description="A thing", + paragraph_start="Thing", + paragraph_end="here.", + properties={}, + ) + + text = entity_to_text(entity) + assert "Thing1" in text + assert "Thing" in text + + +class TestRelationToText: + """Tests for relation_to_text function.""" + + def test_relation_to_text_basic(self): + """Test basic relation formatting.""" + relation = Relation( + source_entity="Alice", + relation_type="KNOWS", + target_entity="Bob", + description="Alice knows Bob", + paragraph_start="Alice knows", + paragraph_end="Bob.", + ) + + text = relation_to_text(relation) + assert "Alice" in text + assert "KNOWS" in text + assert "Bob" in text + assert "Alice knows Bob" in text + + def test_relation_to_text_with_properties(self): + """Test relation formatting with properties.""" + relation = Relation( + source_entity="Alice", + relation_type="DIRECTED", + target_entity="Oppenheimer", + description="Alice directed Oppenheimer", + paragraph_start="Alice directed", + paragraph_end="Oppenheimer.", + properties={"year": 2023, "budget": "100M"}, + ) + + text = relation_to_text(relation) + assert "Alice" in text + assert "DIRECTED" in text + assert "Oppenheimer" in text + + def test_relation_to_text_with_confidence(self): + """Test relation formatting with confidence.""" + relation = Relation( + source_entity="X", + relation_type="RELATED", + target_entity="Y", + description="X related to Y", + paragraph_start="X related", + paragraph_end="Y.", + ) + + text = relation_to_text(relation, include_confidence=False) + assert "X" in text + assert "RELATED" in text + + +class TestFormatEntityList: + """Tests for format_entity_list function.""" + + def test_format_single_entity(self): + """Test formatting single entity list.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ) + + text = format_entity_list([entity]) + assert "1." in text + assert "Alice" in text + + def test_format_multiple_entities(self): + """Test formatting multiple entities.""" + entities = [ + Entity( + type="Person", + name=f"Person{i}", + description=f"Person {i}", + paragraph_start=f"Person{i}", + paragraph_end="here.", + ) + for i in range(3) + ] + + text = format_entity_list(entities) + assert "1." in text + assert "2." in text + assert "3." in text + + def test_format_entity_list_with_max_items(self): + """Test formatting entity list with max limit.""" + entities = [ + Entity( + type="Person", + name=f"Person{i}", + description=f"Person {i}", + paragraph_start=f"Person{i}", + paragraph_end="here.", + ) + for i in range(5) + ] + + text = format_entity_list(entities, max_items=3) + assert "1." in text + assert "2." in text + assert "3." in text + assert "more entities" in text or "..." in text + + def test_format_empty_entity_list(self): + """Test formatting empty entity list.""" + text = format_entity_list([]) + assert text == "" + + +class TestFormatRelationList: + """Tests for format_relation_list function.""" + + def test_format_single_relation(self): + """Test formatting single relation list.""" + relation = Relation( + source_entity="Alice", + relation_type="KNOWS", + target_entity="Bob", + description="Alice knows Bob", + paragraph_start="Alice knows", + paragraph_end="Bob.", + ) + + text = format_relation_list([relation]) + assert "1." in text + assert "Alice" in text + + def test_format_multiple_relations(self): + """Test formatting multiple relations.""" + relations = [ + Relation( + source_entity=f"Entity{i}", + relation_type="RELATES", + target_entity=f"Entity{i+1}", + description=f"Entity {i} relates to Entity {i+1}", + paragraph_start=f"Entity{i}", + paragraph_end=f"Entity{i+1}.", + ) + for i in range(3) + ] + + text = format_relation_list(relations) + assert "1." in text + assert "2." in text + assert "3." in text + + def test_format_relation_list_with_max_items(self): + """Test formatting relation list with max limit.""" + relations = [ + Relation( + source_entity="A", + relation_type="RELATES", + target_entity=f"Entity{i}", + description=f"Relation {i}", + paragraph_start="A", + paragraph_end=f"Entity{i}.", + ) + for i in range(5) + ] + + text = format_relation_list(relations, max_items=2) + assert "1." in text + assert "2." in text + assert "more relations" in text or "..." in text + + +class TestFormatKgContext: + """Tests for format_kg_context function.""" + + def test_format_kg_context_entities_only(self): + """Test formatting KG context with only entities.""" + entities = [ + Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ), + Entity( + type="Movie", + name="Oppenheimer", + description="A film", + paragraph_start="Oppenheimer", + paragraph_end="great.", + ), + ] + + text = format_kg_context(entities, []) + assert "Entities" in text or "entities" in text + assert "Alice" in text + assert "Oppenheimer" in text + + def test_format_kg_context_relations_only(self): + """Test formatting KG context with only relations.""" + relations = [ + Relation( + source_entity="Alice", + relation_type="DIRECTED", + target_entity="Oppenheimer", + description="Alice directed Oppenheimer", + paragraph_start="Alice directed", + paragraph_end="Oppenheimer.", + ), + ] + + text = format_kg_context([], relations) + assert "Relations" in text or "relations" in text + assert "Alice" in text + + def test_format_kg_context_both(self): + """Test formatting KG context with entities and relations.""" + entities = [ + Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ), + ] + relations = [ + Relation( + source_entity="Alice", + relation_type="KNOWS", + target_entity="Bob", + description="Alice knows Bob", + paragraph_start="Alice knows", + paragraph_end="Bob.", + ), + ] + + text = format_kg_context(entities, relations) + assert "Alice" in text + assert "Bob" in text + assert "KNOWS" in text + + def test_format_kg_context_empty(self): + """Test formatting empty KG context.""" + text = format_kg_context([], []) + assert "Empty" in text or "empty" in text + + +class TestCamelcaseToSnakecase: + """Tests for camelcase_to_snake_case function.""" + + def test_simple_camelcase(self): + """Test converting simple camelCase.""" + assert camelcase_to_snake_case("camelCase") == "camel_case" + assert camelcase_to_snake_case("myVariableName") == "my_variable_name" + + def test_pascalcase(self): + """Test converting PascalCase.""" + assert camelcase_to_snake_case("MyClass") == "my_class" + + def test_already_snake_case(self): + """Test with already snake_case input.""" + assert camelcase_to_snake_case("snake_case") == "snake_case" + assert camelcase_to_snake_case("my_var") == "my_var" + + def test_single_word(self): + """Test with single word.""" + assert camelcase_to_snake_case("word") == "word" + assert camelcase_to_snake_case("Word") == "word" + + +class TestSnakecaseToCamelcase: + """Tests for snake_case_to_camelcase function.""" + + def test_simple_snake_case(self): + """Test converting simple snake_case.""" + assert snake_case_to_camelcase("snake_case") == "snakeCase" + assert snake_case_to_camelcase("my_variable") == "myVariable" + + def test_snake_case_to_pascalcase(self): + """Test converting snake_case to PascalCase.""" + assert snake_case_to_camelcase("my_class", upper_first=True) == "MyClass" + assert snake_case_to_camelcase("my_variable_name", upper_first=True) == "MyVariableName" + + def test_single_word(self): + """Test with single word.""" + assert snake_case_to_camelcase("word") == "word" + assert snake_case_to_camelcase("word", upper_first=True) == "Word" diff --git a/test/kg/test_phase1_requirements.py b/test/kg/test_phase1_requirements.py new file mode 100644 index 0000000..164f976 --- /dev/null +++ b/test/kg/test_phase1_requirements.py @@ -0,0 +1,242 @@ +"""Tests for Phase 1 requirement factory functions.""" + +import pytest + +from mellea_contribs.kg.requirements_models import ( + entity_type_valid, + entity_has_name, + entity_has_description, + relation_type_valid, + relation_entities_exist, + entity_confidence_threshold, +) +from mellea_contribs.kg.models import Entity, Relation + + +class TestRequirementFactories: + """Tests for requirement factory functions.""" + + def test_entity_type_valid_returns_requirement(self): + """Test that entity_type_valid returns a Requirement.""" + req = entity_type_valid(["Person", "Movie"]) + + assert req is not None + assert hasattr(req, "description") + assert hasattr(req, "validation_fn") + assert "Entity type" in req.description or "type" in req.description + + def test_entity_type_valid_factory_multiple_calls(self): + """Test entity_type_valid can be called multiple times.""" + req1 = entity_type_valid(["Person", "Movie"]) + req2 = entity_type_valid(["Company", "Location"]) + + assert req1 is not None + assert req2 is not None + assert req1 != req2 + + def test_relation_type_valid_returns_requirement(self): + """Test that relation_type_valid returns a Requirement.""" + req = relation_type_valid(["directed_by", "acted_in"]) + + assert req is not None + assert hasattr(req, "description") + assert hasattr(req, "validation_fn") + + def test_entity_has_name_returns_function(self): + """Test that entity_has_name returns a callable.""" + assert callable(entity_has_name) + + def test_entity_has_description_returns_function(self): + """Test that entity_has_description returns a callable.""" + assert callable(entity_has_description) + + def test_relation_entities_exist_returns_requirement(self): + """Test that relation_entities_exist returns a Requirement.""" + req = relation_entities_exist(["Alice", "Bob", "Charlie"]) + + assert req is not None + assert hasattr(req, "description") + assert hasattr(req, "validation_fn") + + def test_entity_confidence_threshold_returns_requirement(self): + """Test that entity_confidence_threshold returns a Requirement.""" + req = entity_confidence_threshold(min_confidence=0.8) + + assert req is not None + assert hasattr(req, "description") + assert hasattr(req, "validation_fn") + + +class TestEntityTypeValidRequirement: + """Tests for entity_type_valid requirement.""" + + def test_entity_type_valid_description(self): + """Test entity_type_valid requirement description.""" + allowed_types = ["Person", "Movie", "Award"] + req = entity_type_valid(allowed_types) + + assert req.description is not None + assert "Person" in req.description or "type" in req.description.lower() + + def test_entity_type_valid_with_single_type(self): + """Test entity_type_valid with single allowed type.""" + req = entity_type_valid(["Person"]) + + assert req is not None + assert "Person" in req.description + + def test_entity_type_valid_with_multiple_types(self): + """Test entity_type_valid with multiple allowed types.""" + types = ["Person", "Movie", "Award", "Company"] + req = entity_type_valid(types) + + assert req is not None + for type_name in types: + assert type_name in req.description or len(req.description) > 10 + + +class TestRelationTypeValidRequirement: + """Tests for relation_type_valid requirement.""" + + def test_relation_type_valid_description(self): + """Test relation_type_valid requirement description.""" + allowed_types = ["directed_by", "acted_in", "won"] + req = relation_type_valid(allowed_types) + + assert req.description is not None + assert "Relation type" in req.description or "type" in req.description.lower() + + def test_relation_type_valid_with_various_types(self): + """Test relation_type_valid with various relation types.""" + types = ["KNOWS", "DIRECTED", "ACTED_IN", "WON"] + + for type_list in [types, types[:2], types[:1]]: + req = relation_type_valid(type_list) + assert req is not None + + +class TestEntityNameDescriptionValidation: + """Tests for entity name and description validation functions.""" + + def test_entity_has_name_callable(self): + """Test that entity_has_name is callable.""" + assert callable(entity_has_name) + + def test_entity_has_description_callable(self): + """Test that entity_has_description is callable.""" + assert callable(entity_has_description) + + def test_entity_has_name_with_entity(self): + """Test entity_has_name with an Entity.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ) + + # The function should be callable with entity + assert callable(entity_has_name) + + def test_entity_has_description_with_entity(self): + """Test entity_has_description with an Entity.""" + entity = Entity( + type="Person", + name="Alice", + description="A person", + paragraph_start="Alice", + paragraph_end="here.", + ) + + # The function should be callable with entity + assert callable(entity_has_description) + + +class TestRelationEntitiesExistRequirement: + """Tests for relation_entities_exist requirement.""" + + def test_relation_entities_exist_description(self): + """Test relation_entities_exist requirement description.""" + entities = ["Alice", "Bob", "Charlie"] + req = relation_entities_exist(entities) + + assert req.description is not None + assert "entities" in req.description.lower() + + def test_relation_entities_exist_with_various_entity_lists(self): + """Test relation_entities_exist with various entity lists.""" + test_cases = [ + ["Alice", "Bob"], + ["Person1", "Movie1", "Award1"], + ["node-1", "node-2"], + ] + + for entities in test_cases: + req = relation_entities_exist(entities) + assert req is not None + assert req.description is not None + + +class TestEntityConfidenceThresholdRequirement: + """Tests for entity_confidence_threshold requirement.""" + + def test_entity_confidence_threshold_description(self): + """Test entity_confidence_threshold requirement description.""" + req = entity_confidence_threshold(min_confidence=0.8) + + assert req.description is not None + assert "0.8" in req.description or "confidence" in req.description.lower() + + def test_entity_confidence_threshold_with_various_values(self): + """Test entity_confidence_threshold with various threshold values.""" + thresholds = [0.0, 0.5, 0.7, 0.9, 1.0] + + for threshold in thresholds: + req = entity_confidence_threshold(min_confidence=threshold) + assert req is not None + assert req.description is not None + + def test_entity_confidence_threshold_low(self): + """Test entity_confidence_threshold with low threshold.""" + req = entity_confidence_threshold(min_confidence=0.3) + + assert req is not None + assert "0.3" in req.description or "confidence" in req.description.lower() + + def test_entity_confidence_threshold_high(self): + """Test entity_confidence_threshold with high threshold.""" + req = entity_confidence_threshold(min_confidence=0.99) + + assert req is not None + assert "0.99" in req.description or "confidence" in req.description.lower() + + +class TestRequirementFactoriesConsistency: + """Tests for consistency of requirement factories.""" + + def test_all_requirement_factories_return_non_none(self): + """Test that all requirement factories return non-None values.""" + factories = [ + (entity_type_valid, [["Person"]]), + (relation_type_valid, [["KNOWS"]]), + (relation_entities_exist, [["Alice", "Bob"]]), + (entity_confidence_threshold, [0.8]), + ] + + for factory, args in factories: + result = factory(*args) + assert result is not None + + def test_all_requirement_factories_have_descriptions(self): + """Test that all requirements have descriptions.""" + factories_and_args = [ + (entity_type_valid, [["Person"]]), + (relation_type_valid, [["KNOWS"]]), + (relation_entities_exist, [["Alice", "Bob"]]), + (entity_confidence_threshold, [0.8]), + ] + + for factory, args in factories_and_args: + result = factory(*args) + assert hasattr(result, "description") or callable(result) diff --git a/test/kg/test_phase1_updater_models.py b/test/kg/test_phase1_updater_models.py new file mode 100644 index 0000000..a135946 --- /dev/null +++ b/test/kg/test_phase1_updater_models.py @@ -0,0 +1,388 @@ +"""Tests for Phase 1 KG update configuration and result models.""" + +import pytest +from datetime import datetime + +from mellea_contribs.kg.updater_models import ( + UpdateConfig, + UpdateSessionConfig, + UpdateStats, + MergeConflict, + UpdateResult, + UpdateBatchResult, +) + + +class TestUpdateConfig: + """Tests for UpdateConfig model.""" + + def test_create_update_config_defaults(self): + """Test creating UpdateConfig with defaults.""" + config = UpdateConfig() + + assert config.batch_size == 32 + assert config.merge_strategy == "merge_if_similar" + assert config.similarity_threshold == 0.8 + + def test_create_update_config_custom(self): + """Test creating UpdateConfig with custom values.""" + config = UpdateConfig( + batch_size=64, + merge_strategy="skip", + similarity_threshold=0.75, + domain="movie", + ) + + assert config.batch_size == 64 + assert config.merge_strategy == "skip" + assert config.similarity_threshold == 0.75 + assert config.domain == "movie" + + def test_update_config_merge_strategies(self): + """Test UpdateConfig with different merge strategies.""" + strategies = ["merge_if_similar", "skip", "overwrite", "create_variant"] + + for strategy in strategies: + config = UpdateConfig(merge_strategy=strategy) + assert config.merge_strategy == strategy + + def test_update_config_entity_types(self): + """Test UpdateConfig with entity and relation types.""" + config = UpdateConfig( + entity_types=["Person", "Movie", "Award"], + relation_types=["directed_by", "acted_in", "won"], + ) + + assert config.entity_types == ["Person", "Movie", "Award"] + assert config.relation_types == ["directed_by", "acted_in", "won"] + + +class TestUpdateSessionConfig: + """Tests for UpdateSessionConfig model.""" + + def test_create_update_session_config_defaults(self): + """Test creating UpdateSessionConfig with defaults.""" + config = UpdateSessionConfig() + + assert config.extraction_model is not None + assert config.alignment_model is not None + assert config.merge_decision_model is not None + + def test_create_update_session_config_custom(self): + """Test creating UpdateSessionConfig with custom values.""" + config = UpdateSessionConfig( + extraction_model="gpt-4", + extraction_temperature=0.1, + alignment_model="claude-3-sonnet", + alignment_temperature=0.2, + merge_decision_model="gpt-4-turbo", + merge_decision_temperature=0.1, + ) + + assert config.extraction_model == "gpt-4" + assert config.extraction_temperature == 0.1 + assert config.alignment_model == "claude-3-sonnet" + assert config.alignment_temperature == 0.2 + assert config.merge_decision_model == "gpt-4-turbo" + assert config.merge_decision_temperature == 0.1 + + +class TestUpdateStats: + """Tests for UpdateStats model.""" + + def test_create_update_stats_defaults(self): + """Test creating UpdateStats with defaults.""" + stats = UpdateStats() + + assert stats.entities_extracted == 0 + assert stats.entities_aligned == 0 + assert stats.entities_merged == 0 + assert stats.relations_extracted == 0 + + def test_create_update_stats_custom(self): + """Test creating UpdateStats with custom values.""" + stats = UpdateStats( + entities_extracted=100, + entities_new=30, + entities_aligned=70, + entities_merged=50, + entities_skipped=20, + relations_extracted=80, + relations_new=20, + relations_aligned=60, + relations_merged=45, + relations_skipped=15, + processing_time_ms=5000, + ) + + assert stats.entities_extracted == 100 + assert stats.entities_new == 30 + assert stats.entities_aligned == 70 + assert stats.entities_merged == 50 + assert stats.relations_extracted == 80 + assert stats.relations_new == 20 + + +class TestMergeConflict: + """Tests for MergeConflict model.""" + + def test_create_merge_conflict(self): + """Test creating a MergeConflict.""" + conflict = MergeConflict( + source_id="entity-1", + target_id="entity-2", + conflict_type="entity_merge", + similarity_score=0.85, + decision="merged", + reason="High similarity score", + ) + + assert conflict.source_id == "entity-1" + assert conflict.target_id == "entity-2" + assert conflict.conflict_type == "entity_merge" + assert conflict.similarity_score == 0.85 + assert conflict.decision == "merged" + assert conflict.reason == "High similarity score" + + def test_merge_conflict_with_timestamp(self): + """Test MergeConflict with timestamp.""" + conflict = MergeConflict( + source_id="entity-1", + target_id="entity-2", + conflict_type="relation_merge", + similarity_score=0.9, + decision="skipped", + reason="User review needed", + timestamp="2026-03-08T10:30:00Z", + ) + + assert conflict.timestamp == "2026-03-08T10:30:00Z" + assert conflict.decision == "skipped" + + def test_merge_conflict_types(self): + """Test different merge conflict types.""" + types = ["entity_merge", "relation_merge", "property_conflict", "temporal_conflict"] + + for conflict_type in types: + conflict = MergeConflict( + source_id="a", + target_id="b", + conflict_type=conflict_type, + similarity_score=0.8, + decision="merged", + reason="Test", + ) + assert conflict.conflict_type == conflict_type + + +class TestUpdateResult: + """Tests for UpdateResult model.""" + + def test_create_update_result_minimal(self): + """Test creating UpdateResult with minimal fields.""" + result = UpdateResult( + document_id="doc-1", + entities_added=10, + entities_merged=5, + ) + + assert result.document_id == "doc-1" + assert result.entities_added == 10 + assert result.entities_merged == 5 + + def test_create_update_result_comprehensive(self): + """Test creating UpdateResult with all fields.""" + conflicts = [ + MergeConflict( + source_id="e1", + target_id="e2", + conflict_type="entity_merge", + similarity_score=0.85, + decision="merged", + reason="High similarity", + ) + ] + + result = UpdateResult( + document_id="doc-1", + entities_added=10, + entities_merged=5, + entities_skipped=2, + relations_added=8, + relations_merged=3, + relations_skipped=1, + conflicts=conflicts, + processing_time_ms=1000, + ) + + assert result.document_id == "doc-1" + assert result.entities_added == 10 + assert result.entities_merged == 5 + assert result.entities_skipped == 2 + assert result.relations_added == 8 + assert len(result.conflicts) == 1 + assert result.processing_time_ms == 1000 + + def test_update_result_tracking(self): + """Test UpdateResult for tracking changes.""" + result = UpdateResult( + document_id="doc-movie", + entities_added=15, + entities_merged=8, + relations_added=12, + relations_merged=6, + ) + + # Verify all tracking fields are present + assert result.entities_added > 0 + assert result.entities_merged > 0 + assert result.relations_added > 0 + assert result.relations_merged > 0 + + +class TestUpdateBatchResult: + """Tests for UpdateBatchResult model.""" + + def test_create_update_batch_result_defaults(self): + """Test creating UpdateBatchResult with defaults.""" + batch_result = UpdateBatchResult() + + assert batch_result.total_documents == 0 + assert batch_result.successful_documents == 0 + assert batch_result.failed_documents == 0 + assert batch_result.stats is not None + + def test_create_update_batch_result_custom(self): + """Test creating UpdateBatchResult with custom values.""" + stats = UpdateStats( + entities_extracted=100, + entities_merged=50, + relations_extracted=80, + relations_merged=40, + ) + + results = [ + UpdateResult( + document_id=f"doc-{i}", + entities_added=10, + entities_merged=5, + ) + for i in range(3) + ] + + batch_result = UpdateBatchResult( + total_documents=3, + successful_documents=3, + failed_documents=0, + stats=stats, + results=results, + total_time_ms=3000, + avg_time_per_document_ms=1000, + ) + + assert batch_result.total_documents == 3 + assert batch_result.successful_documents == 3 + assert batch_result.failed_documents == 0 + assert len(batch_result.results) == 3 + assert batch_result.total_time_ms == 3000 + + def test_update_batch_result_aggregation(self): + """Test UpdateBatchResult for aggregating multiple documents.""" + results = [ + UpdateResult( + document_id=f"doc-{i}", + entities_added=5, + entities_merged=2, + relations_added=3, + relations_merged=1, + ) + for i in range(5) + ] + + batch_result = UpdateBatchResult( + total_documents=5, + successful_documents=5, + results=results, + ) + + assert batch_result.total_documents == 5 + assert len(batch_result.results) == 5 + + +class TestUpdateIntegration: + """Integration tests for Update models.""" + + def test_update_config_with_session_config(self): + """Test using UpdateConfig with UpdateSessionConfig.""" + config = UpdateConfig( + batch_size=64, + merge_strategy="merge_if_similar", + domain="movie", + ) + + session = UpdateSessionConfig( + extraction_model="gpt-4", + merge_decision_model="claude-3", + ) + + assert config.domain == "movie" + assert config.batch_size == 64 + assert session.extraction_model == "gpt-4" + + def test_merge_conflict_in_update_result(self): + """Test MergeConflict within UpdateResult.""" + conflicts = [ + MergeConflict( + source_id="entity-alice", + target_id="entity-alice-2", + conflict_type="entity_merge", + similarity_score=0.92, + decision="merged", + reason="Duplicate person entry", + ), + MergeConflict( + source_id="relation-directed", + target_id="relation-directed-2", + conflict_type="relation_merge", + similarity_score=0.88, + decision="merged", + reason="Same relation type", + ), + ] + + result = UpdateResult( + document_id="doc-movie", + entities_added=5, + entities_merged=3, + relations_added=4, + relations_merged=2, + conflicts=conflicts, + ) + + assert len(result.conflicts) == 2 + assert result.conflicts[0].conflict_type == "entity_merge" + assert result.conflicts[1].conflict_type == "relation_merge" + + def test_update_result_in_batch_result(self): + """Test UpdateResult within UpdateBatchResult.""" + results = [ + UpdateResult( + document_id=f"doc-{i}", + entities_added=10 + i, + entities_merged=5 + i, + relations_added=8 + i, + relations_merged=3 + i, + ) + for i in range(3) + ] + + batch = UpdateBatchResult( + total_documents=3, + successful_documents=3, + results=results, + ) + + assert len(batch.results) == 3 + for i, result in enumerate(batch.results): + assert result.entities_added == 10 + i + assert result.entities_merged == 5 + i diff --git a/test/kg/utils/README.md b/test/kg/utils/README.md new file mode 100644 index 0000000..00e6837 --- /dev/null +++ b/test/kg/utils/README.md @@ -0,0 +1,158 @@ +# Phase 3 Utility Tests + +## Overview + +Comprehensive test suite for the Phase 3 utility modules in `mellea_contribs/kg/utils/`. This test suite ensures that all utility functions work correctly and integrates with the rest of the KG system. + +## Test Files + +### 1. test_data_utils.py (27 tests) +Tests for JSONL I/O and batch processing utilities. + +**Test Classes:** +- `TestLoadAndSaveJsonl` - Reading/writing JSONL files +- `TestAppendJsonl` - Appending items to JSONL files +- `TestBatchIterator` - Batch iteration functionality +- `TestTruncateJsonl` - Truncating JSONL files +- `TestShuffleJsonl` - Shuffling JSONL data +- `TestValidateJsonlSchema` - Schema validation +- `TestIntegration` - Integration workflows + +**Coverage:** +- ✓ Empty files, single items, multiple items +- ✓ Error handling (nonexistent files, invalid JSON) +- ✓ Directory creation +- ✓ Batch sizes (exact, uneven, edge cases) +- ✓ Truncation at various limits +- ✓ Schema validation with missing fields +- ✓ Integration workflows + +### 2. test_eval_utils.py (26 tests) +Tests for evaluation metrics and result aggregation. + +**Test Classes:** +- `TestExactMatch` - Exact string matching +- `TestFuzzyMatch` - Fuzzy string matching +- `TestMeanReciprocalRank` - MRR computation +- `TestPrecisionRecall` - Precision and recall metrics +- `TestF1Score` - F1 score computation +- `TestAggregateQaResults` - QA result aggregation +- `TestAggregateUpdateResults` - Update result aggregation +- `TestIntegration` - Integration workflows + +**Coverage:** +- ✓ Case-insensitive matching +- ✓ Whitespace handling +- ✓ Threshold sensitivity +- ✓ Empty results handling +- ✓ Perfect/partial/no matches +- ✓ Confidence-based ranking +- ✓ Classification metrics (precision, recall, F1) +- ✓ Result aggregation with errors + +### 3. test_progress.py (23 tests) +Tests for logging, progress tracking, and structured output. + +**Test Classes:** +- `TestSetupLogging` - Logging configuration +- `TestLogProgress` - Progress message logging +- `TestOutputJson` - JSON output +- `TestPrintStats` - Statistics formatting +- `TestProgressTracker` - Progress tracking class +- `TestIntegration` - Integration workflows + +**Coverage:** +- ✓ Logging levels (DEBUG, INFO, WARNING, ERROR) +- ✓ File logging +- ✓ JSON serialization of Pydantic models +- ✓ Statistics pretty-printing +- ✓ Indentation and formatting +- ✓ ProgressTracker initialization and updates +- ✓ Integration workflows + +### 4. test_session_manager.py (19 tests) +Tests for session and backend creation. + +**Test Classes:** +- `TestCreateBackend` - Backend factory function +- `TestCreateSession` - Session factory function +- `TestMelleaResourceManager` - Async resource manager +- `TestIntegration` - Integration workflows + +**Coverage:** +- ✓ Mock backend creation +- ✓ Default parameters +- ✓ Invalid backend types +- ✓ Neo4j backend (when available) +- ✓ Custom parameters +- ✓ Async context manager cleanup +- ✓ Integration workflows + +## Running Tests + +### Run all Phase 3 tests: +```bash +pytest test/kg/utils/ -v +``` + +### Run specific test file: +```bash +pytest test/kg/utils/test_data_utils.py -v +``` + +### Run specific test class: +```bash +pytest test/kg/utils/test_data_utils.py::TestLoadAndSaveJsonl -v +``` + +### Run specific test: +```bash +pytest test/kg/utils/test_data_utils.py::TestLoadAndSaveJsonl::test_save_and_load_single_item -v +``` + +## Test Statistics + +- **Total Tests:** 95 +- **Passed:** 95 ✓ +- **Failed:** 0 +- **Coverage:** Comprehensive coverage of all Phase 3 utilities + +## Test Strategy + +### Data Utilities (data_utils.py) +- **Boundary testing:** Empty files, single items, multiple items +- **Error handling:** Invalid JSON, missing files +- **Integration:** Workflows combining multiple utilities +- **Edge cases:** Batch sizes, truncation limits, shuffling order preservation + +### Evaluation Utilities (eval_utils.py) +- **Matching:** Exact, fuzzy, case-insensitive, whitespace handling +- **Metrics:** Precision, recall, F1, MRR computation +- **Aggregation:** Batch result aggregation with errors +- **Edge cases:** Empty results, perfect/no matches, confidence scoring + +### Progress Utilities (progress.py) +- **Logging:** Multiple levels, file output +- **Output:** JSON serialization, pretty-printing +- **Tracking:** Progress bar updates +- **Formatting:** Indentation, stats display + +### Session Manager (session_manager.py) +- **Backend creation:** Mock, Neo4j (conditional) +- **Session creation:** Default/custom parameters +- **Resource management:** Async context managers +- **Integration:** Workflow testing + +## Dependencies + +- pytest +- pytest-asyncio +- mellea_contribs.kg utilities +- Pydantic models (QAStats, UpdateStats, etc.) + +## Notes + +- All tests use temporary directories for file operations +- Async tests properly use pytest-asyncio +- Tests gracefully handle optional dependencies (Neo4j) +- Integration tests verify workflows combining multiple utilities diff --git a/test/kg/utils/__init__.py b/test/kg/utils/__init__.py new file mode 100644 index 0000000..92a74e0 --- /dev/null +++ b/test/kg/utils/__init__.py @@ -0,0 +1 @@ +"""Tests for mellea_contribs.kg.utils module.""" diff --git a/test/kg/utils/test_data_utils.py b/test/kg/utils/test_data_utils.py new file mode 100644 index 0000000..1b06799 --- /dev/null +++ b/test/kg/utils/test_data_utils.py @@ -0,0 +1,321 @@ +"""Tests for mellea_contribs.kg.utils.data_utils module.""" + +import json +import random +from pathlib import Path + +import pytest + +from mellea_contribs.kg.utils import ( + append_jsonl, + batch_iterator, + load_jsonl, + save_jsonl, + shuffle_jsonl, + truncate_jsonl, + validate_jsonl_schema, +) + + +@pytest.fixture +def temp_dir(tmp_path): + """Create temporary directory for tests.""" + return tmp_path + + +class TestLoadAndSaveJsonl: + """Tests for load_jsonl and save_jsonl functions.""" + + def test_save_and_load_empty_list(self, temp_dir): + """Test saving and loading empty JSONL.""" + path = temp_dir / "empty.jsonl" + data = [] + save_jsonl(data, path) + loaded = list(load_jsonl(path)) + assert loaded == [] + + def test_save_and_load_single_item(self, temp_dir): + """Test saving and loading single item.""" + path = temp_dir / "single.jsonl" + data = [{"id": 1, "name": "test"}] + save_jsonl(data, path) + loaded = list(load_jsonl(path)) + assert loaded == data + + def test_save_and_load_multiple_items(self, temp_dir): + """Test saving and loading multiple items.""" + path = temp_dir / "multiple.jsonl" + data = [ + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + {"id": 3, "value": "c"}, + ] + save_jsonl(data, path) + loaded = list(load_jsonl(path)) + assert loaded == data + + def test_load_nonexistent_file(self, temp_dir): + """Test loading nonexistent file raises error.""" + path = temp_dir / "nonexistent.jsonl" + with pytest.raises(FileNotFoundError): + list(load_jsonl(path)) + + def test_load_invalid_json(self, temp_dir): + """Test loading invalid JSON raises error.""" + path = temp_dir / "invalid.jsonl" + path.write_text("not valid json\n") + with pytest.raises(json.JSONDecodeError): + list(load_jsonl(path)) + + def test_save_creates_parent_directory(self, temp_dir): + """Test that save_jsonl creates parent directories.""" + path = temp_dir / "subdir" / "nested" / "file.jsonl" + data = [{"test": 1}] + save_jsonl(data, path) + assert path.exists() + loaded = list(load_jsonl(path)) + assert loaded == data + + +class TestAppendJsonl: + """Tests for append_jsonl function.""" + + def test_append_to_empty_file(self, temp_dir): + """Test appending to empty file.""" + path = temp_dir / "append.jsonl" + append_jsonl({"id": 1}, path) + loaded = list(load_jsonl(path)) + assert loaded == [{"id": 1}] + + def test_append_to_existing_file(self, temp_dir): + """Test appending to existing file.""" + path = temp_dir / "append.jsonl" + save_jsonl([{"id": 1}], path) + append_jsonl({"id": 2}, path) + loaded = list(load_jsonl(path)) + assert loaded == [{"id": 1}, {"id": 2}] + + def test_append_multiple_items(self, temp_dir): + """Test appending multiple items.""" + path = temp_dir / "append.jsonl" + for i in range(5): + append_jsonl({"id": i}, path) + loaded = list(load_jsonl(path)) + assert len(loaded) == 5 + assert loaded[0]["id"] == 0 + assert loaded[4]["id"] == 4 + + +class TestBatchIterator: + """Tests for batch_iterator function.""" + + def test_batch_exact_division(self): + """Test batch iterator with exact division.""" + items = list(range(9)) + batches = list(batch_iterator(items, 3)) + assert len(batches) == 3 + assert batches[0] == [0, 1, 2] + assert batches[1] == [3, 4, 5] + assert batches[2] == [6, 7, 8] + + def test_batch_uneven_division(self): + """Test batch iterator with uneven division.""" + items = list(range(10)) + batches = list(batch_iterator(items, 3)) + assert len(batches) == 4 + assert batches[0] == [0, 1, 2] + assert batches[3] == [9] + + def test_batch_single_item(self): + """Test batch iterator with batch size 1.""" + items = [1, 2, 3] + batches = list(batch_iterator(items, 1)) + assert len(batches) == 3 + assert all(len(b) == 1 for b in batches) + + def test_batch_empty_list(self): + """Test batch iterator with empty list.""" + batches = list(batch_iterator([], 3)) + assert batches == [] + + def test_batch_size_larger_than_list(self): + """Test batch iterator with batch size larger than list.""" + items = [1, 2, 3] + batches = list(batch_iterator(items, 10)) + assert len(batches) == 1 + assert batches[0] == [1, 2, 3] + + +class TestTruncateJsonl: + """Tests for truncate_jsonl function.""" + + def test_truncate_larger_than_file(self, temp_dir): + """Test truncating with limit larger than file.""" + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + data = [{"id": i} for i in range(5)] + save_jsonl(data, input_path) + + truncate_jsonl(input_path, output_path, 10) + loaded = list(load_jsonl(output_path)) + assert loaded == data + + def test_truncate_smaller_than_file(self, temp_dir): + """Test truncating with limit smaller than file.""" + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + data = [{"id": i} for i in range(10)] + save_jsonl(data, input_path) + + truncate_jsonl(input_path, output_path, 3) + loaded = list(load_jsonl(output_path)) + assert len(loaded) == 3 + assert loaded == data[:3] + + def test_truncate_zero(self, temp_dir): + """Test truncating with limit 0.""" + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + data = [{"id": i} for i in range(5)] + save_jsonl(data, input_path) + + truncate_jsonl(input_path, output_path, 0) + loaded = list(load_jsonl(output_path)) + assert loaded == [] + + +class TestShuffleJsonl: + """Tests for shuffle_jsonl function.""" + + def test_shuffle_preserves_all_items(self, temp_dir): + """Test that shuffle preserves all items.""" + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + data = [{"id": i} for i in range(20)] + save_jsonl(data, input_path) + + shuffle_jsonl(input_path, output_path) + loaded = list(load_jsonl(output_path)) + assert len(loaded) == len(data) + # Check all items are present (sorted to compare) + assert sorted(loaded, key=lambda x: x["id"]) == data + + def test_shuffle_changes_order(self, temp_dir): + """Test that shuffle changes order (probabilistically).""" + # Note: This test could theoretically fail if shuffle produces same order + # but probability is extremely low with 20 items + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + data = [{"id": i} for i in range(20)] + save_jsonl(data, input_path) + + shuffle_jsonl(input_path, output_path) + loaded = list(load_jsonl(output_path)) + # Extract just the order of IDs + loaded_ids = [item["id"] for item in loaded] + original_ids = list(range(20)) + # Very unlikely to get same order + assert loaded_ids != original_ids + + +class TestValidateJsonlSchema: + """Tests for validate_jsonl_schema function.""" + + def test_valid_schema(self, temp_dir): + """Test validation with valid schema.""" + path = temp_dir / "valid.jsonl" + data = [ + {"id": 1, "name": "a"}, + {"id": 2, "name": "b"}, + ] + save_jsonl(data, path) + + valid, errors = validate_jsonl_schema(path, ["id", "name"]) + assert valid is True + assert errors == [] + + def test_missing_required_field(self, temp_dir): + """Test validation with missing required field.""" + path = temp_dir / "invalid.jsonl" + data = [ + {"id": 1, "name": "a"}, + {"id": 2}, # Missing 'name' + ] + save_jsonl(data, path) + + valid, errors = validate_jsonl_schema(path, ["id", "name"]) + assert valid is False + assert len(errors) > 0 + assert "Line 2" in errors[0] + + def test_empty_file_valid(self, temp_dir): + """Test validation with empty file.""" + path = temp_dir / "empty.jsonl" + save_jsonl([], path) + + valid, errors = validate_jsonl_schema(path, ["id", "name"]) + assert valid is True + assert errors == [] + + def test_partial_schema_validation(self, temp_dir): + """Test validation with partial schema match.""" + path = temp_dir / "partial.jsonl" + data = [ + {"id": 1, "name": "a", "extra": "field"}, + {"id": 2, "name": "b", "extra": "field"}, + ] + save_jsonl(data, path) + + valid, errors = validate_jsonl_schema(path, ["id"]) + assert valid is True + assert errors == [] + + +class TestIntegration: + """Integration tests for data_utils functions.""" + + def test_workflow_create_truncate_load(self, temp_dir): + """Test workflow: create → truncate → load.""" + input_path = temp_dir / "input.jsonl" + output_path = temp_dir / "output.jsonl" + + # Create + data = [{"id": i, "value": f"item_{i}"} for i in range(100)] + save_jsonl(data, input_path) + + # Truncate + truncate_jsonl(input_path, output_path, 10) + + # Load and verify + loaded = list(load_jsonl(output_path)) + assert len(loaded) == 10 + assert loaded[0]["id"] == 0 + assert loaded[9]["id"] == 9 + + def test_workflow_batch_processing(self, temp_dir): + """Test workflow: create → batch processing.""" + path = temp_dir / "data.jsonl" + data = [{"id": i} for i in range(25)] + save_jsonl(data, path) + + loaded = list(load_jsonl(path)) + batches = list(batch_iterator(loaded, 5)) + assert len(batches) == 5 + assert all(len(b) == 5 for b in batches) + + def test_workflow_append_shuffle_validate(self, temp_dir): + """Test workflow: append → shuffle → validate.""" + path = temp_dir / "data.jsonl" + + # Append items + for i in range(10): + append_jsonl({"id": i, "status": "active"}, path) + + # Shuffle + output_path = temp_dir / "shuffled.jsonl" + shuffle_jsonl(path, output_path) + + # Validate + valid, errors = validate_jsonl_schema(output_path, ["id", "status"]) + assert valid is True + assert errors == [] diff --git a/test/kg/utils/test_eval_utils.py b/test/kg/utils/test_eval_utils.py new file mode 100644 index 0000000..04b8658 --- /dev/null +++ b/test/kg/utils/test_eval_utils.py @@ -0,0 +1,505 @@ +"""Tests for mellea_contribs.kg.utils.eval_utils module.""" + +import pytest + +from mellea_contribs.kg.qa_models import QAResult, QAStats +from mellea_contribs.kg.updater_models import UpdateResult, UpdateStats +from mellea_contribs.kg.utils import ( + aggregate_qa_results, + aggregate_update_results, + exact_match, + f1_score, + fuzzy_match, + mean_reciprocal_rank, + precision, + recall, +) + + +class TestExactMatch: + """Tests for exact_match function.""" + + def test_exact_match_identical(self): + """Test exact match with identical strings.""" + assert exact_match("hello", "hello") is True + + def test_exact_match_case_insensitive(self): + """Test exact match is case insensitive.""" + assert exact_match("Hello", "hello") is True + assert exact_match("WORLD", "world") is True + + def test_exact_match_with_whitespace(self): + """Test exact match handles whitespace.""" + assert exact_match(" hello ", "hello") is True + assert exact_match("hello", " hello ") is True + + def test_exact_match_different(self): + """Test exact match with different strings.""" + assert exact_match("hello", "world") is False + + def test_exact_match_partial(self): + """Test exact match with partial strings.""" + assert exact_match("hello", "hello world") is False + + def test_exact_match_empty_strings(self): + """Test exact match with empty strings.""" + assert exact_match("", "") is True + assert exact_match("", "test") is False + + +class TestFuzzyMatch: + """Tests for fuzzy_match function.""" + + def test_fuzzy_match_identical(self): + """Test fuzzy match with identical strings.""" + assert fuzzy_match("hello", "hello") is True + + def test_fuzzy_match_similar(self): + """Test fuzzy match with similar strings.""" + # Should match with high similarity + assert fuzzy_match("hello", "helo", threshold=0.7) is True + + def test_fuzzy_match_different(self): + """Test fuzzy match with different strings.""" + assert fuzzy_match("hello", "world", threshold=0.8) is False + + def test_fuzzy_match_threshold_sensitivity(self): + """Test fuzzy match threshold sensitivity.""" + # Same pair, different thresholds + pair1, pair2 = "testing", "test" + assert fuzzy_match(pair1, pair2, threshold=0.5) is True + assert fuzzy_match(pair1, pair2, threshold=0.99) is False + + def test_fuzzy_match_case_insensitive(self): + """Test fuzzy match is case insensitive.""" + assert fuzzy_match("Hello", "hello") is True + + +class TestMeanReciprocalRank: + """Tests for mean_reciprocal_rank function.""" + + def test_mrr_empty_results(self): + """Test MRR with empty results.""" + assert mean_reciprocal_rank([]) == 0.0 + + def test_mrr_all_exact_matches(self): + """Test MRR with all exact matches.""" + results = [ + {"answer": "correct", "expected": "correct"}, + {"answer": "right", "expected": "right"}, + {"answer": "yes", "expected": "yes"}, + ] + mrr = mean_reciprocal_rank(results) + assert mrr == 1.0 + + def test_mrr_no_matches(self): + """Test MRR with no matches.""" + results = [ + {"answer": "wrong", "expected": "correct"}, + {"answer": "bad", "expected": "good"}, + ] + mrr = mean_reciprocal_rank(results) + assert mrr == 0.0 + + def test_mrr_partial_matches(self): + """Test MRR with partial matches.""" + results = [ + {"answer": "correct", "expected": "correct"}, # Match + {"answer": "wrong", "expected": "right"}, # No match + {"answer": "okay", "expected": "okay"}, # Match + ] + mrr = mean_reciprocal_rank(results) + # Average of [1.0, 0.0, 1.0] = 0.667 + assert 0.6 < mrr < 0.7 + + def test_mrr_with_confidence(self): + """Test MRR considering confidence.""" + results = [ + {"answer": "maybe", "expected": "correct", "confidence": 0.95}, + {"answer": "no", "expected": "yes", "confidence": 0.1}, + ] + mrr = mean_reciprocal_rank(results) + # First has high confidence despite mismatch + assert 0.4 < mrr < 0.6 + + +class TestPrecisionRecall: + """Tests for precision and recall functions.""" + + def test_precision_perfect(self): + """Test precision with perfect predictions.""" + predicted = ["a", "b", "c"] + expected = ["a", "b", "c"] + assert precision(predicted, expected) == 1.0 + + def test_precision_partial(self): + """Test precision with partial overlap.""" + predicted = ["a", "b", "c"] + expected = ["a", "b"] + # TP = 2, TP+FP = 3, precision = 2/3 + assert abs(precision(predicted, expected) - 2/3) < 0.01 + + def test_precision_empty_predicted(self): + """Test precision with empty predictions.""" + predicted = [] + expected = ["a", "b"] + assert precision(predicted, expected) == 0.0 + + def test_recall_perfect(self): + """Test recall with perfect predictions.""" + predicted = ["a", "b", "c"] + expected = ["a", "b", "c"] + assert recall(predicted, expected) == 1.0 + + def test_recall_partial(self): + """Test recall with partial overlap.""" + predicted = ["a", "b"] + expected = ["a", "b", "c"] + # TP = 2, TP+FN = 3, recall = 2/3 + assert abs(recall(predicted, expected) - 2/3) < 0.01 + + def test_recall_empty_expected(self): + """Test recall with empty expected.""" + predicted = ["a", "b"] + expected = [] + assert recall(predicted, expected) == 0.0 + + +class TestF1Score: + """Tests for f1_score function.""" + + def test_f1_perfect(self): + """Test F1 with perfect precision and recall.""" + f1 = f1_score(1.0, 1.0) + assert f1 == 1.0 + + def test_f1_zero(self): + """Test F1 with zero precision and recall.""" + f1 = f1_score(0.0, 0.0) + assert f1 == 0.0 + + def test_f1_balanced(self): + """Test F1 with balanced precision and recall.""" + f1 = f1_score(0.8, 0.8) + assert abs(f1 - 0.8) < 0.01 + + def test_f1_imbalanced(self): + """Test F1 with imbalanced precision and recall.""" + prec = 0.9 + rec = 0.7 + expected = 2 * (prec * rec) / (prec + rec) + f1 = f1_score(prec, rec) + assert abs(f1 - expected) < 0.01 + + def test_f1_partial(self): + """Test F1 with partial scores.""" + f1 = f1_score(0.5, 0.5) + assert f1 == 0.5 + + +class TestAggregateQaResults: + """Tests for aggregate_qa_results function.""" + + def test_aggregate_empty_results(self): + """Test aggregating empty results.""" + results = [] + stats = aggregate_qa_results(results) + assert stats.total_questions == 0 + assert stats.successful_answers == 0 + + def test_aggregate_all_successful(self): + """Test aggregating all successful results.""" + results = [ + QAResult( + question="q1", + answer="a1", + confidence=0.9, + processing_time_ms=100.0, + model_used="gpt-4o-mini", + ), + QAResult( + question="q2", + answer="a2", + confidence=0.8, + processing_time_ms=120.0, + model_used="gpt-4o-mini", + ), + ] + stats = aggregate_qa_results(results) + assert stats.total_questions == 2 + assert stats.successful_answers == 2 + assert stats.failed_answers == 0 + + def test_aggregate_with_errors(self): + """Test aggregating results with errors.""" + results = [ + QAResult( + question="q1", + answer="a1", + confidence=0.9, + processing_time_ms=100.0, + ), + QAResult( + question="q2", + answer="", + error="Timeout", + confidence=0.0, + ), + ] + stats = aggregate_qa_results(results) + assert stats.total_questions == 2 + assert stats.successful_answers == 1 + assert stats.failed_answers == 1 + + def test_aggregate_confidence_stats(self): + """Test confidence aggregation.""" + results = [ + QAResult( + question="q1", + answer="a1", + confidence=0.9, + processing_time_ms=100.0, + ), + QAResult( + question="q2", + answer="a2", + confidence=0.7, + processing_time_ms=120.0, + ), + ] + stats = aggregate_qa_results(results) + # Average confidence = (0.9 + 0.7) / 2 = 0.8 + assert abs(stats.average_confidence - 0.8) < 0.01 + + +class TestAggregateUpdateResults: + """Tests for aggregate_update_results function.""" + + def test_aggregate_empty_results(self): + """Test aggregating empty update results.""" + results = [] + stats = aggregate_update_results(results) + assert stats.total_documents == 0 + assert stats.successful_documents == 0 + + def test_aggregate_all_successful(self): + """Test aggregating all successful updates.""" + results = [ + UpdateResult( + document_id="doc1", + success=True, + entities_found=5, + relations_found=3, + entities_added=5, + relations_added=3, + processing_time_ms=100.0, + ), + UpdateResult( + document_id="doc2", + success=True, + entities_found=7, + relations_found=4, + entities_added=7, + relations_added=4, + processing_time_ms=150.0, + ), + ] + stats = aggregate_update_results(results) + assert stats.total_documents == 2 + assert stats.successful_documents == 2 + assert stats.failed_documents == 0 + assert stats.entities_extracted == 12 + assert stats.relations_extracted == 7 + + def test_aggregate_with_failures(self): + """Test aggregating results with failures.""" + results = [ + UpdateResult( + document_id="doc1", + success=True, + entities_found=5, + relations_found=3, + entities_added=5, + relations_added=3, + ), + UpdateResult( + document_id="doc2", + success=False, + error="Parse error", + ), + ] + stats = aggregate_update_results(results) + assert stats.total_documents == 2 + assert stats.successful_documents == 1 + assert stats.failed_documents == 1 + + +class TestEvaluatePredictions: + """Tests for evaluate_predictions function.""" + + @pytest.mark.asyncio + async def test_exact_match(self): + """Test evaluate_predictions identifies exact matches.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"query": "What is 2+2?", "answer": "4", "answer_aliases": ["4", "four"]}, + ] + result = await evaluate_predictions(None, predictions) + assert result[0]["correct"] is True + assert result[0]["eval_method"] == "exact" + + @pytest.mark.asyncio + async def test_fuzzy_match(self): + """Test evaluate_predictions falls back to fuzzy match.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"query": "Capital of France?", "answer": "Pari", "answer_aliases": ["Paris"]}, + ] + result = await evaluate_predictions(None, predictions) + assert result[0]["correct"] is True + assert result[0]["eval_method"] == "fuzzy" + + @pytest.mark.asyncio + async def test_no_match(self): + """Test evaluate_predictions handles wrong answers.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"query": "What color is the sky?", "answer": "red", "answer_aliases": ["blue"]}, + ] + result = await evaluate_predictions(None, predictions) + assert result[0]["correct"] is False + + @pytest.mark.asyncio + async def test_missing_gold_answers(self): + """Test evaluate_predictions when no gold answers are provided.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [{"query": "What is X?", "answer": "something"}] + result = await evaluate_predictions(None, predictions) + assert result[0]["correct"] is False + assert result[0]["eval_method"] == "none" + + @pytest.mark.asyncio + async def test_preserves_original_fields(self): + """Test that evaluate_predictions keeps original dict fields.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + { + "query": "Q?", + "answer": "A", + "answer_aliases": ["A"], + "custom_field": "preserved", + } + ] + result = await evaluate_predictions(None, predictions) + assert result[0]["custom_field"] == "preserved" + assert "correct" in result[0] + assert "eval_method" in result[0] + + @pytest.mark.asyncio + async def test_custom_key_names(self): + """Test evaluate_predictions with non-default key names.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"question": "Q?", "prediction": "answer", "expected": ["answer"]}, + ] + result = await evaluate_predictions( + None, + predictions, + query_key="question", + answer_key="prediction", + gold_key="expected", + ) + assert result[0]["correct"] is True + + @pytest.mark.asyncio + async def test_string_gold_coerced_to_list(self): + """Test evaluate_predictions accepts a string for gold answers.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"query": "Q?", "answer": "Paris", "answer_aliases": "Paris"}, + ] + result = await evaluate_predictions(None, predictions) + assert result[0]["correct"] is True + + @pytest.mark.asyncio + async def test_multiple_predictions(self): + """Test evaluate_predictions processes all items in the list.""" + from mellea_contribs.kg.utils import evaluate_predictions + + predictions = [ + {"query": "Q1", "answer": "correct", "answer_aliases": ["correct"]}, + {"query": "Q2", "answer": "wrong", "answer_aliases": ["right"]}, + ] + result = await evaluate_predictions(None, predictions) + assert len(result) == 2 + assert result[0]["correct"] is True + assert result[1]["correct"] is False + + +class TestIntegration: + """Integration tests for eval_utils functions.""" + + def test_workflow_qa_evaluation(self): + """Test QA evaluation workflow.""" + # Create results + results = [ + QAResult( + question="What is 2+2?", + answer="4", + confidence=0.95, + processing_time_ms=50.0, + ), + QAResult( + question="What is the capital of France?", + answer="Paris", + confidence=0.9, + processing_time_ms=60.0, + ), + ] + + # Aggregate + stats = aggregate_qa_results(results) + + # Verify + assert stats.total_questions == 2 + assert stats.successful_answers == 2 + assert 0.9 < stats.average_confidence < 0.95 + + def test_workflow_update_evaluation(self): + """Test update operation evaluation workflow.""" + # Create results + results = [ + UpdateResult( + document_id="doc1", + success=True, + entities_found=10, + relations_found=8, + entities_added=10, + relations_added=8, + processing_time_ms=200.0, + ), + UpdateResult( + document_id="doc2", + success=True, + entities_found=12, + relations_found=9, + entities_added=12, + relations_added=9, + processing_time_ms=250.0, + ), + ] + + # Aggregate + stats = aggregate_update_results(results) + + # Verify + assert stats.total_documents == 2 + assert stats.entities_extracted == 22 + assert stats.relations_extracted == 17 diff --git a/test/kg/utils/test_progress.py b/test/kg/utils/test_progress.py new file mode 100644 index 0000000..8d39626 --- /dev/null +++ b/test/kg/utils/test_progress.py @@ -0,0 +1,248 @@ +"""Tests for mellea_contribs.kg.utils.progress module.""" + +import io +import json +import logging +import sys +from unittest.mock import patch + +import pytest + +from mellea_contribs.kg.qa_models import QAStats +from mellea_contribs.kg.updater_models import UpdateStats +from mellea_contribs.kg.utils import ( + log_progress, + output_json, + print_stats, + setup_logging, +) + + +class TestSetupLogging: + """Tests for setup_logging function.""" + + def test_setup_logging_default(self): + """Test setting up logging with default parameters.""" + setup_logging() + logger = logging.getLogger("mellea_contribs.kg") + assert logger.level == logging.INFO + + def test_setup_logging_debug(self): + """Test setting up logging with DEBUG level.""" + setup_logging(log_level="DEBUG") + logger = logging.getLogger("mellea_contribs.kg") + assert logger.level == logging.DEBUG + + def test_setup_logging_warning(self): + """Test setting up logging with WARNING level.""" + setup_logging(log_level="WARNING") + logger = logging.getLogger("mellea_contribs.kg") + assert logger.level == logging.WARNING + + def test_setup_logging_with_file(self, tmp_path): + """Test setting up logging with file output.""" + log_file = tmp_path / "test.log" + setup_logging(log_level="INFO", log_file=str(log_file)) + logger = logging.getLogger("mellea_contribs.kg") + # File handler should be added + assert any(isinstance(h, logging.FileHandler) for h in logger.handlers) + + +class TestLogProgress: + """Tests for log_progress function.""" + + def test_log_progress_default(self, caplog): + """Test logging progress with default level.""" + with caplog.at_level(logging.INFO): + log_progress("Test message") + assert "Test message" in caplog.text + + def test_log_progress_debug(self, caplog): + """Test logging progress with DEBUG level.""" + with caplog.at_level(logging.DEBUG, logger="mellea_contribs.kg"): + log_progress("Debug message", level="DEBUG") + assert "Debug message" in caplog.text + + def test_log_progress_error(self, caplog): + """Test logging progress with ERROR level.""" + with caplog.at_level(logging.ERROR): + log_progress("Error message", level="ERROR") + assert "Error message" in caplog.text + + +class TestOutputJson: + """Tests for output_json function.""" + + def test_output_json_qa_stats(self, capsys): + """Test outputting QAStats as JSON.""" + stats = QAStats( + total_questions=10, + successful_answers=8, + failed_answers=2, + average_confidence=0.85, + ) + output_json(stats) + captured = capsys.readouterr() + output = json.loads(captured.out) + assert output["total_questions"] == 10 + assert output["successful_answers"] == 8 + + def test_output_json_update_stats(self, capsys): + """Test outputting UpdateStats as JSON.""" + stats = UpdateStats( + total_documents=5, + successful_documents=4, + failed_documents=1, + entities_extracted=20, + relations_extracted=15, + ) + output_json(stats) + captured = capsys.readouterr() + output = json.loads(captured.out) + assert output["total_documents"] == 5 + assert output["entities_extracted"] == 20 + + def test_output_json_preserves_structure(self, capsys): + """Test that JSON output preserves structure.""" + stats = QAStats( + total_questions=5, + models_used=["gpt-4o-mini", "claude-opus"], + ) + output_json(stats) + captured = capsys.readouterr() + output = json.loads(captured.out) + assert isinstance(output["models_used"], list) + assert len(output["models_used"]) == 2 + + +class TestPrintStats: + """Tests for print_stats function.""" + + def test_print_stats_qa_stats(self, capsys): + """Test printing QAStats.""" + stats = QAStats( + total_questions=10, + successful_answers=8, + failed_answers=2, + average_confidence=0.85, + ) + print_stats(stats, to_stderr=False) + captured = capsys.readouterr() + assert "Total Questions" in captured.out or "total_questions" in captured.out.lower() + assert "10" in captured.out + + def test_print_stats_update_stats(self, capsys): + """Test printing UpdateStats.""" + stats = UpdateStats( + total_documents=5, + successful_documents=4, + failed_documents=1, + ) + print_stats(stats, to_stderr=False) + captured = capsys.readouterr() + assert "5" in captured.out or "Total Documents" in captured.out.lower() + + def test_print_stats_to_stderr(self, capsys): + """Test that print_stats can output to stderr.""" + stats = QAStats(total_questions=10) + print_stats(stats, to_stderr=True) + captured = capsys.readouterr() + # Should be in stderr when to_stderr=True + assert "10" in captured.err or "10" in captured.out + + def test_print_stats_formatting(self, capsys): + """Test formatting of stats output.""" + stats = QAStats( + total_questions=100, + average_confidence=0.7531, + ) + print_stats(stats, to_stderr=False) + captured = capsys.readouterr() + # Should format float nicely + assert "0.75" in captured.out or "75" in captured.out + + def test_print_stats_with_indentation(self, capsys): + """Test print_stats with indentation.""" + stats = QAStats(total_questions=5) + print_stats(stats, indent=2, to_stderr=False) + captured = capsys.readouterr() + # Should have some indentation + output_lines = captured.out.split("\n") + # At least some lines should be indented + assert any(line.startswith(" ") for line in output_lines if line.strip()) + + +class TestProgressTracker: + """Tests for ProgressTracker class.""" + + def test_progress_tracker_initialization(self): + """Test ProgressTracker initialization.""" + from mellea_contribs.kg.utils import ProgressTracker + + tracker = ProgressTracker(total=100, desc="Processing") + assert tracker.total == 100 + assert tracker.current == 0 + + def test_progress_tracker_update(self): + """Test ProgressTracker update.""" + from mellea_contribs.kg.utils import ProgressTracker + + tracker = ProgressTracker(total=100, desc="Processing") + tracker.update(10) + assert tracker.current == 10 + tracker.update(5) + assert tracker.current == 15 + + def test_progress_tracker_close(self): + """Test ProgressTracker close.""" + from mellea_contribs.kg.utils import ProgressTracker + + tracker = ProgressTracker(total=100, desc="Processing") + tracker.update(50) + tracker.close() + # Should not raise error + + def test_progress_tracker_without_tqdm(self): + """Test ProgressTracker without tqdm.""" + from mellea_contribs.kg.utils import ProgressTracker + + tracker = ProgressTracker(total=100, use_tqdm=False) + tracker.update(50) + assert tracker.current == 50 + + +class TestIntegration: + """Integration tests for progress module.""" + + def test_workflow_setup_log_output(self, capsys): + """Test workflow: setup → log → output.""" + setup_logging(log_level="INFO") + log_progress("Starting processing") + + stats = QAStats(total_questions=20, successful_answers=18) + output_json(stats) + + captured = capsys.readouterr() + assert "Starting processing" in captured.err or len(captured.out) > 0 + assert "20" in captured.out + + def test_workflow_print_and_output_stats(self, capsys): + """Test printing and outputting same stats.""" + stats = QAStats( + total_questions=50, + successful_answers=45, + average_confidence=0.88, + ) + + # Print to inspect + print_stats(stats, to_stderr=False) + captured1 = capsys.readouterr() + + # Output as JSON + output_json(stats) + captured2 = capsys.readouterr() + + # Both should reference the stats + assert "50" in captured1.out + assert "50" in captured2.out + assert json.loads(captured2.out)["total_questions"] == 50 diff --git a/test/kg/utils/test_session_manager.py b/test/kg/utils/test_session_manager.py new file mode 100644 index 0000000..5e73069 --- /dev/null +++ b/test/kg/utils/test_session_manager.py @@ -0,0 +1,276 @@ +"""Tests for mellea_contribs.kg.utils.session_manager module.""" + +import pytest + +from mellea_contribs.kg.graph_dbs.mock import MockGraphBackend +from mellea_contribs.kg.utils import create_backend, create_session + + +class TestCreateBackend: + """Tests for create_backend function.""" + + def test_create_mock_backend(self): + """Test creating mock backend.""" + backend = create_backend(backend_type="mock") + assert isinstance(backend, MockGraphBackend) + + def test_create_backend_default_type(self): + """Test creating backend with default type.""" + backend = create_backend() + assert isinstance(backend, MockGraphBackend) + + def test_create_backend_invalid_type(self): + """Test creating backend with invalid type.""" + with pytest.raises(ValueError): + create_backend(backend_type="invalid") + + def test_create_neo4j_backend_not_available(self): + """Test Neo4j backend creation when not available.""" + # Neo4jBackend might not be installed + try: + backend = create_backend(backend_type="neo4j", neo4j_uri="bolt://localhost:7687") + # If we got here, Neo4j backend is available + assert backend is not None + except (SystemExit, ImportError): + # Expected if Neo4j backend is not available + pass + + +class TestCreateSession: + """Tests for create_session function.""" + + def test_create_session_default_params(self): + """Test creating session with default parameters.""" + session = create_session() + assert session is not None + # Session should be a MelleaSession instance + assert hasattr(session, "instruct") + + def test_create_session_custom_model(self): + """Test creating session with custom model.""" + session = create_session(model_id="gpt-4o-mini") + assert session is not None + + def test_create_session_custom_temperature(self): + """Test creating session with custom temperature.""" + session = create_session(temperature=0.5) + assert session is not None + + def test_create_session_litellm_backend(self): + """Test creating session with litellm backend.""" + session = create_session(backend_name="litellm", model_id="gpt-4o-mini") + assert session is not None + + +class TestMelleaResourceManager: + """Tests for MelleaResourceManager async context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_cleanup(self): + """Test that context manager cleans up resources.""" + from mellea_contribs.kg.utils import MelleaResourceManager + + manager = MelleaResourceManager(backend_type="mock") + async with manager as mgr: + assert mgr.session is not None + assert mgr.backend is not None + + # After exiting context, backend should be closed + # (MockGraphBackend.close() is async, so this just verifies it ran) + + @pytest.mark.asyncio + async def test_context_manager_with_neo4j_params(self): + """Test context manager with Neo4j parameters.""" + from mellea_contribs.kg.utils import MelleaResourceManager + + manager = MelleaResourceManager( + backend_type="mock", + model_id="gpt-4o-mini", + ) + async with manager as mgr: + assert mgr.session is not None + assert isinstance(mgr.backend, MockGraphBackend) + + +class TestCreateOpenAISession: + """Tests for create_openai_session function.""" + + def test_create_openai_session_returns_session(self): + """Test that create_openai_session returns a usable session.""" + from mellea_contribs.kg.utils.session_manager import create_openai_session + + session = create_openai_session() + assert session is not None + assert hasattr(session, "instruct") + + def test_create_openai_session_custom_model(self): + """Test creating OpenAI session with a different model ID.""" + from mellea_contribs.kg.utils.session_manager import create_openai_session + + session = create_openai_session(model_id="gpt-4o") + assert session is not None + + def test_create_openai_session_with_api_base(self): + """Test creating OpenAI session with explicit API base and key.""" + from mellea_contribs.kg.utils.session_manager import create_openai_session + + session = create_openai_session( + model_id="gpt-4o-mini", + api_base="http://localhost:8080/v1", + api_key="test-key", + ) + assert session is not None + + def test_create_openai_session_no_force_schema(self): + """Test creating session with force_openai_schema disabled.""" + from mellea_contribs.kg.utils.session_manager import create_openai_session + + session = create_openai_session(force_openai_schema=False) + assert session is not None + + +class TestCreateSessionFromEnv: + """Tests for create_session_from_env function.""" + + def test_returns_session_model_tuple(self): + """Test that create_session_from_env returns (session, model_id) tuple.""" + from mellea_contribs.kg.utils.session_manager import create_session_from_env + + result = create_session_from_env() + assert isinstance(result, tuple) + assert len(result) == 2 + + def test_default_model_id(self): + """Test that default model ID is used when env var not set.""" + from mellea_contribs.kg.utils.session_manager import create_session_from_env + + _, model_id = create_session_from_env(default_model="gpt-4o-mini") + assert model_id == "gpt-4o-mini" + + def test_model_from_env_var(self, monkeypatch): + """Test that MODEL_NAME env var overrides default model.""" + from mellea_contribs.kg.utils.session_manager import create_session_from_env + + monkeypatch.setenv("MODEL_NAME", "gpt-4o") + _, model_id = create_session_from_env() + assert model_id == "gpt-4o" + + def test_prefixed_env_vars(self, monkeypatch): + """Test that env_prefix correctly scopes environment variables.""" + from mellea_contribs.kg.utils.session_manager import create_session_from_env + + monkeypatch.setenv("EVAL_MODEL_NAME", "gpt-4-turbo") + _, model_id = create_session_from_env(default_model="gpt-4o-mini", env_prefix="EVAL_") + assert model_id == "gpt-4-turbo" + + def test_session_is_usable(self): + """Test that returned session has expected attributes.""" + from mellea_contribs.kg.utils.session_manager import create_session_from_env + + session, _ = create_session_from_env() + assert hasattr(session, "instruct") + + +class TestCreateEmbeddingClient: + """Tests for create_embedding_client function.""" + + def test_returns_client_or_none(self): + """Test that create_embedding_client does not raise exceptions.""" + from mellea_contribs.kg.utils.session_manager import create_embedding_client + + # Should return a client (openai installed) or None (not installed) + client = create_embedding_client() + # Both outcomes are acceptable; the key is no exception is raised + + def test_model_name_attached_to_client(self): + """Test that model name is stored as client._model_name.""" + from mellea_contribs.kg.utils.session_manager import create_embedding_client + + client = create_embedding_client(model_name="text-embedding-3-large") + if client is not None: + assert client._model_name == "text-embedding-3-large" + + def test_default_model_name(self): + """Test default embedding model name.""" + from mellea_contribs.kg.utils.session_manager import create_embedding_client + + client = create_embedding_client() + if client is not None: + assert client._model_name == "text-embedding-3-small" + + def test_custom_api_base(self): + """Test creating embedding client with custom endpoint.""" + from mellea_contribs.kg.utils.session_manager import create_embedding_client + + client = create_embedding_client( + api_base="http://localhost:8080/v1", + api_key="test-key", + model_name="my-embed-model", + ) + if client is not None: + assert client._model_name == "my-embed-model" + + +class TestGenerateEmbeddings: + """Tests for generate_embeddings function.""" + + @pytest.mark.asyncio + async def test_none_client_returns_nones(self): + """Test that None client returns a list of None values.""" + from mellea_contribs.kg.utils.session_manager import generate_embeddings + + result = await generate_embeddings(None, ["text1", "text2"]) + assert result == [None, None] + + @pytest.mark.asyncio + async def test_empty_texts_returns_empty(self): + """Test that empty input list returns empty list.""" + from mellea_contribs.kg.utils.session_manager import generate_embeddings + + result = await generate_embeddings(None, []) + assert result == [] + + @pytest.mark.asyncio + async def test_correct_length_with_none_client(self): + """Test that result length matches input length when client is None.""" + from mellea_contribs.kg.utils.session_manager import generate_embeddings + + texts = ["a", "b", "c", "d"] + result = await generate_embeddings(None, texts) + assert len(result) == len(texts) + assert all(v is None for v in result) + + @pytest.mark.asyncio + async def test_model_name_override(self): + """Test that explicit model_name overrides client._model_name.""" + from mellea_contribs.kg.utils.session_manager import generate_embeddings + + # With None client this returns nones regardless of model_name, + # but the call should not raise + result = await generate_embeddings(None, ["text"], model_name="custom-model") + assert result == [None] + + +class TestIntegration: + """Integration tests for session_manager functions.""" + + def test_workflow_create_session_and_backend(self): + """Test creating both session and backend.""" + backend = create_backend(backend_type="mock") + session = create_session(model_id="gpt-4o-mini") + assert backend is not None + assert session is not None + + @pytest.mark.asyncio + async def test_workflow_with_resource_manager(self): + """Test workflow using resource manager.""" + from mellea_contribs.kg.utils import MelleaResourceManager + + async with MelleaResourceManager(backend_type="mock") as manager: + # Should be able to access both + assert manager.session is not None + assert manager.backend is not None + + # Backend should be functional + schema = await manager.backend.get_schema() + assert isinstance(schema, dict) diff --git a/uv.lock b/uv.lock index 4daa455..d801d86 100644 --- a/uv.lock +++ b/uv.lock @@ -508,6 +508,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/96/d32b941a501ab566a16358d68b6eb4e4acc373fab3c3c4d7d9e649f7b4bb/catalogue-2.0.10-py3-none-any.whl", hash = "sha256:58c2de0020aa90f4a2da7dfad161bf7b3b054c86a5f09fcedc0b2b740c109a9f", size = 17325, upload-time = "2023-09-25T06:29:23.337Z" }, ] +[[package]] +name = "boto3" +version = "1.42.63" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7f/2a/33d5d4b16fd97dfd629421ebed2456392eae1553cc401d9f86010c18065e/boto3-1.42.63.tar.gz", hash = "sha256:cd008cfd0d7ea30f1c5e22daf0998c55b7c6c68cb68eea05110e33fe641173d5", size = 112778, upload-time = "2026-03-06T22:47:55.96Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/19/f1d8d2b24871d3d0ccb2cbd0b0cb64a3396d439384bd9643d2c25c641b84/boto3-1.42.63-py3-none-any.whl", hash = "sha256:d502a89a0acc701692ae020d15981f2a82e9eb3485acc651cfd0cf1a3afe79ee", size = 140554, upload-time = "2026-03-06T22:47:53.463Z" }, +] + +[[package]] +name = "botocore" +version = "1.42.63" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/eb/a1c042f6638ada552399a9977335a6de2668a85bf80bece193c953531236/botocore-1.42.63.tar.gz", hash = "sha256:1fdfc33cff58d21e8622cf620ba2bba3cff324557932aaf935b5374e4610f059", size = 14965362, upload-time = "2026-03-06T22:47:44.158Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9a/60/17a2d3b94658bb999c6aee7bba6c76b271905debf0c8c8e6ac63ca8491bc/botocore-1.42.63-py3-none-any.whl", hash = "sha256:83f39d04f2b316bdfc59a3cac2d12238bde7126ac99d9a57d910dbd86d58c528", size = 14639889, upload-time = "2026-03-06T22:47:39.347Z" }, +] + [[package]] name = "certifi" version = "2025.11.12" @@ -1653,15 +1681,15 @@ wheels = [ [[package]] name = "granite-common" -version = "0.3.5" +version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonschema" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4d/b8/cba7a2399079838f793cc138c5b341df965c82e7cdaaf4c37deeffa0a14c/granite_common-0.3.5.tar.gz", hash = "sha256:80d4251b9294b6ec234d5aa4e273801b66f7cc4c5bc77151e5c22e7d7f5a19cd", size = 273710, upload-time = "2025-11-15T01:32:38.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/10/ca8f59c644a3574a443bb85ff807f1ebbe726a6ad75bd471e092ab002f37/granite_common-0.4.1.tar.gz", hash = "sha256:5290e03d43e2962218aaf13c9c43877af6fb7869332a4ea35983c4f6a206d801", size = 714066, upload-time = "2026-02-25T01:10:45.253Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/f4/79d121e4192cf7871122995f060f891c6dbcb91b9b4753e4dd27852704c4/granite_common-0.3.5-py3-none-any.whl", hash = "sha256:fca8fdb7caff7f5714bfda9c81438f4a6df974e0b01631421cd6ca1a19bbb07e", size = 77551, upload-time = "2025-11-15T01:32:37.186Z" }, + { url = "https://files.pythonhosted.org/packages/a3/ee/c52f5ddb073c111c19f15889646fd65e0be2b4e0b01764237d1ecbcd5bfe/granite_common-0.4.1-py3-none-any.whl", hash = "sha256:e82df48f69a98b46dbff8a36c10a64a13d4400fed2425f2ad9a6981031544062", size = 86633, upload-time = "2026-02-25T01:10:43.845Z" }, ] [[package]] @@ -2267,6 +2295,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2f/9c/6753e6522b8d0ef07d3a3d239426669e984fb0eba15a315cdbc1253904e4/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24e864cb30ab82311c6425655b0cdab0a98c5d973b065c66a3f020740c2324c", size = 346110, upload-time = "2025-11-09T20:49:21.817Z" }, ] +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + [[package]] name = "joblib" version = "1.5.2" @@ -3089,7 +3126,7 @@ wheels = [ [[package]] name = "mellea" -version = "0.2.0" +version = "0.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ansicolors" }, @@ -3113,13 +3150,14 @@ dependencies = [ { name = "types-tqdm" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d3/8f/816ecf6417fc2b2ba5ac44f5057d18337f05e8b487cbf82f779c74a63d91/mellea-0.2.0.tar.gz", hash = "sha256:d33d7c0faa33183cb35dec38073d2aef91e279baa20400bbdf51f026363c7ad0", size = 195187, upload-time = "2025-11-19T17:49:47.778Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/e9/25d87d92064b9781ef3e15d032f6f50deb2ff70b35d3c27f21ff595666ca/mellea-0.3.2.tar.gz", hash = "sha256:b73c5c1da473891e85005042a8e1ac26eae4026f448306884de3c8ba56df1965", size = 3583161, upload-time = "2026-02-26T13:43:30.979Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/4a/15e44c80a5c4caa5848c53e918aac1e0fde7be00883221a3bdb993fae31b/mellea-0.2.0-py3-none-any.whl", hash = "sha256:9f7409e7bda660afe80297120379619c4633eaf617c3e2dd78ea341291373a0e", size = 292958, upload-time = "2025-11-19T17:49:46.141Z" }, + { url = "https://files.pythonhosted.org/packages/55/d8/4bcf1174810518d7f9fb27ae8d77a9001bd975ef26d24f734f89df051ddc/mellea-0.3.2-py3-none-any.whl", hash = "sha256:ef9beb4b5b3f8c2099d17df592067646ecac2bc47b8f7f8b19951c3acbe37174", size = 3864719, upload-time = "2026-02-26T13:43:29.118Z" }, ] [package.optional-dependencies] litellm = [ + { name = "boto3" }, { name = "litellm" }, ] @@ -3132,6 +3170,7 @@ dependencies = [ { name = "eyecite" }, { name = "markdown" }, { name = "mellea", extra = ["litellm"] }, + { name = "neo4j" }, { name = "playwright" }, { name = "rapidfuzz" }, ] @@ -3159,6 +3198,12 @@ docs = [ { name = "sphinx-mdinclude" }, { name = "sphinx-rtd-theme" }, ] +kg = [ + { name = "neo4j" }, +] +kg-utils = [ + { name = "tqdm" }, +] notebook = [ { name = "ipykernel" }, { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -3172,7 +3217,8 @@ requires-dist = [ { name = "citeurl" }, { name = "eyecite" }, { name = "markdown" }, - { name = "mellea", extras = ["litellm"] }, + { name = "mellea", extras = ["litellm"], specifier = ">=0.3.0" }, + { name = "neo4j", specifier = ">=6.1.0" }, { name = "playwright" }, { name = "rapidfuzz", specifier = ">=3.14.3" }, ] @@ -3195,6 +3241,8 @@ docs = [ { name = "sphinx-mdinclude" }, { name = "sphinx-rtd-theme" }, ] +kg = [{ name = "neo4j", specifier = ">=5.0.0" }] +kg-utils = [{ name = "tqdm", specifier = ">=4.65.0" }] notebook = [ { name = "ipykernel", specifier = ">=6.29.5" }, { name = "ipython", specifier = ">=8.36.0" }, @@ -3612,6 +3660,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454, upload-time = "2024-04-04T11:20:34.895Z" }, ] +[[package]] +name = "neo4j" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/01/d6ce65e4647f6cb2b9cca3b813978f7329b54b4e36660aaec1ddf0ccce7a/neo4j-6.1.0.tar.gz", hash = "sha256:b5dde8c0d8481e7b6ae3733569d990dd3e5befdc5d452f531ad1884ed3500b84", size = 239629, upload-time = "2026-01-12T11:27:34.777Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/5c/ee71e2dd955045425ef44283f40ba1da67673cf06404916ca2950ac0cd39/neo4j-6.1.0-py3-none-any.whl", hash = "sha256:3bd93941f3a3559af197031157220af9fd71f4f93a311db687bd69ffa417b67d", size = 325326, upload-time = "2026-01-12T11:27:33.196Z" }, +] + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -6032,6 +6092,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" }, ] +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, +] + [[package]] name = "secretstorage" version = "3.5.0"