diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..218f7468 --- /dev/null +++ b/.env.example @@ -0,0 +1,22 @@ +# Core OpenAI configuration +OPENAI_API_KEY=your-openai-key +OPENAI_BASE_URL=https://api.openai.com/v1 + +# Demo defaults (OpenAI-only) +CMM_VECTORSTORE_BACKEND=chroma +CMM_EMBEDDING_MODEL=openai:text-embedding-3-large +CMM_EMBEDDING_DIMENSIONS=3072 +CMM_USE_RERANKER=false +CMM_RERANKER_PROVIDER=none +CMM_HYBRID_ALPHA=0.7 +CMM_VECTORSTORE_COLLECTION=cmm_chunks + +# Optional Weaviate backend (enabled only if provided) +CMM_WEAVIATE_URL= +CMM_WEAVIATE_API_KEY= + +# Optional Cohere reranker (enabled only if provider is cohere) +COHERE_API_KEY= + +# Legacy compatibility toggle +URSA_RAG_LEGACY_MODE=false diff --git a/.gitignore b/.gitignore index 44016bfc..d33ba26f 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ arxiv_generated_summaries arxiv_papers *.sqfs ursa_workspace/ +cmm_demo_workspace/ +cmm_demo_outputs/ .vscode/settings.json scratch/ - diff --git a/README.md b/README.md index 31ef997b..09e50ac2 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,40 @@ Documentation for each URSA agent: Documentation for combining agents: - [ArXiv -> Execution for Materials](docs/combining_arxiv_and_execution.md) - [ArXiv -> Execution for Neutron Star Properties](docs/combining_arxiv_and_execution_neutronStar.md) +- [Critical Minerals Workflow Architecture](docs/critical_minerals_workflow.md) +- [CMM Setup and Usage Guide for Domain Scientists](docs/domain_scientist_setup_and_usage.md) + +## Critical Minerals Demo (RAG + Optimization) + +URSA includes a CMM demo path with adaptive RAG retrieval and deterministic +tool-calling optimization. + +1. Copy env defaults: +```bash +cp .env.example .env +``` +2. Configure at minimum: +```bash +OPENAI_API_KEY=... +OPENAI_BASE_URL=https://api.openai.com/v1 +CMM_VECTORSTORE_BACKEND=chroma +CMM_EMBEDDING_MODEL=openai:text-embedding-3-large +CMM_EMBEDDING_DIMENSIONS=3072 +CMM_USE_RERANKER=false +CMM_RERANKER_PROVIDER=none +``` +3. Reindex local corpus: +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-large \ + --embedding-dimensions 3072 \ + --reset +``` +4. Run workflow with `local_corpus_path`, `rag_context`, and optional +`optimization_input`. ## Command Line Usage diff --git a/configs/cmm_demo_scenarios.json b/configs/cmm_demo_scenarios.json new file mode 100644 index 00000000..a929ca64 --- /dev/null +++ b/configs/cmm_demo_scenarios.json @@ -0,0 +1,215 @@ +{ + "ndfeb_la_y_5pct_baseline": { + "task": "Assess Nd2Fe14B magnet supply and sourcing resilience when the alloy specification is fixed at 5% lanthanum and 5% yttrium impurity substitution in the rare-earth fraction.", + "rag_context": "Focus on metallurgical performance implications, coercivity/remanence tradeoffs, processing constraints, and procurement implications of 5% La + 5% Y impurity levels in Nd2Fe14B.", + "execution_instruction": "Produce a source-grounded decision brief for a U.S. 
magnet manufacturing audience, including assumptions, uncertainty notes, and near-term sourcing recommendations under the 5% La and 5% Y impurity target.", + "source_queries": {}, + "optimization_input": { + "commodity": "ND2FE14B_LA5_Y5", + "demand": { + "US_DEFENSE": 120, + "US_EV": 180, + "EU_OEM": 100 + }, + "suppliers": [ + { + "name": "domestic_recycled_blend", + "capacity": 140, + "unit_cost": 92.0, + "risk_score": 0.2, + "composition_profile": { + "LA": 0.072, + "Y": 0.028 + } + }, + { + "name": "allied_separated_oxide", + "capacity": 190, + "unit_cost": 98.0, + "risk_score": 0.13, + "composition_profile": { + "LA": 0.041, + "Y": 0.059 + } + }, + { + "name": "integrated_allied_metal", + "capacity": 230, + "unit_cost": 106.0, + "risk_score": 0.09, + "composition_profile": { + "LA": 0.05, + "Y": 0.05 + } + } + ], + "shipping_cost": { + "domestic_recycled_blend": { + "US_DEFENSE": 1.2, + "US_EV": 1.1, + "EU_OEM": 2.8 + }, + "allied_separated_oxide": { + "US_DEFENSE": 2.1, + "US_EV": 2.2, + "EU_OEM": 1.6 + }, + "integrated_allied_metal": { + "US_DEFENSE": 2.8, + "US_EV": 2.7, + "EU_OEM": 1.5 + } + }, + "risk_weight": 3.2, + "unmet_demand_penalty": 18000, + "max_supplier_share": 0.72, + "composition_targets": { + "LA": 0.05, + "Y": 0.05 + }, + "composition_tolerance": 0.005 + } + }, + "ndfeb_la_y_5pct_quality_tightening": { + "task": "Evaluate procurement and production risk if Nd2Fe14B with 5% La and 5% Y impurities must meet tighter quality assurance windows for critical motor applications.", + "rag_context": "Focus on NdFeB quality control evidence for La/Y-substituted magnets, including coercivity and remanence variability, grain boundary diffusion processing, process capability (Cp/Cpk), lot acceptance testing, and supplier QA qualification under a fixed 5% La + 5% Y composition target.", + "execution_instruction": "Return an executive memo plus technical appendix bullets focused on quality assurance risk controls, supplier qualification sequencing, and contingency actions.", + "source_queries": {}, + "optimization_input": { + "commodity": "ND2FE14B_LA5_Y5_QA", + "demand": { + "US_DEFENSE": 130, + "US_EV": 170, + "EU_OEM": 95 + }, + "suppliers": [ + { + "name": "domestic_recycled_blend", + "capacity": 120, + "unit_cost": 95.0, + "risk_score": 0.22, + "composition_profile": { + "LA": 0.068, + "Y": 0.032 + } + }, + { + "name": "allied_separated_oxide", + "capacity": 175, + "unit_cost": 100.0, + "risk_score": 0.14, + "composition_profile": { + "LA": 0.043, + "Y": 0.057 + } + }, + { + "name": "integrated_allied_metal", + "capacity": 210, + "unit_cost": 109.0, + "risk_score": 0.1, + "composition_profile": { + "LA": 0.05, + "Y": 0.05 + } + } + ], + "shipping_cost": { + "domestic_recycled_blend": { + "US_DEFENSE": 1.1, + "US_EV": 1.0, + "EU_OEM": 2.9 + }, + "allied_separated_oxide": { + "US_DEFENSE": 2.0, + "US_EV": 2.1, + "EU_OEM": 1.7 + }, + "integrated_allied_metal": { + "US_DEFENSE": 2.7, + "US_EV": 2.6, + "EU_OEM": 1.6 + } + }, + "risk_weight": 3.7, + "unmet_demand_penalty": 22000, + "max_supplier_share": 0.65, + "composition_targets": { + "LA": 0.05, + "Y": 0.05 + }, + "composition_tolerance": 0.003 + } + }, + "ndfeb_la_y_5pct_supply_shock": { + "task": "Stress-test Nd2Fe14B supply posture under a logistics and refining disruption while holding composition at 5% La and 5% Y impurities.", + "rag_context": "Emphasize disruption pathways, substitution limits under fixed impurity composition, and practical mitigation levers for near-term delivery reliability.", + 
"execution_instruction": "Provide a red-team style risk narrative and rank-order mitigations by expected impact and implementation complexity for the 5% La/5% Y Nd2Fe14B specification.", + "source_queries": {}, + "optimization_input": { + "commodity": "ND2FE14B_LA5_Y5_SHOCK", + "demand": { + "US_DEFENSE": 145, + "US_EV": 210, + "EU_OEM": 110 + }, + "suppliers": [ + { + "name": "domestic_recycled_blend", + "capacity": 125, + "unit_cost": 96.0, + "risk_score": 0.24, + "composition_profile": { + "LA": 0.082, + "Y": 0.018 + } + }, + { + "name": "allied_separated_oxide", + "capacity": 145, + "unit_cost": 103.0, + "risk_score": 0.2, + "composition_profile": { + "LA": 0.037, + "Y": 0.063 + } + }, + { + "name": "integrated_allied_metal", + "capacity": 185, + "unit_cost": 112.0, + "risk_score": 0.16, + "composition_profile": { + "LA": 0.049, + "Y": 0.051 + } + } + ], + "shipping_cost": { + "domestic_recycled_blend": { + "US_DEFENSE": 1.4, + "US_EV": 1.3, + "EU_OEM": 3.2 + }, + "allied_separated_oxide": { + "US_DEFENSE": 3.0, + "US_EV": 3.1, + "EU_OEM": 2.2 + }, + "integrated_allied_metal": { + "US_DEFENSE": 3.4, + "US_EV": 3.3, + "EU_OEM": 2.0 + } + }, + "risk_weight": 4.2, + "unmet_demand_penalty": 26000, + "max_supplier_share": 0.62, + "composition_targets": { + "LA": 0.05, + "Y": 0.05 + }, + "composition_tolerance": 0.004 + } + } +} diff --git a/configs/nd_china_2025_scenarios.json b/configs/nd_china_2025_scenarios.json new file mode 100644 index 00000000..28c39987 --- /dev/null +++ b/configs/nd_china_2025_scenarios.json @@ -0,0 +1,284 @@ +{ + "nd_preshock_baseline": { + "task": "Assess NdPr oxide supply chain posture before China MOFCOM 2025 export controls take effect.", + "rag_context": "Focus on global NdPr oxide production capacity, existing trade flows, and baseline geopolitical risk for US/EU/JP/KR magnet manufacturing supply chains. 
Key suppliers: China consolidated (dominant), Lynas (Australia), MP Materials (USA), Neo Performance (EU), domestic recycled (USA).", + "execution_instruction": "Produce a baseline supply-chain assessment for NdPr oxide procurement showing allocation feasibility, cost structure, and supplier concentration risk before any export restrictions.", + "source_queries": {}, + "optimization_input": { + "commodity": "NDPR_OXIDE", + "demand": { + "US_DEFENSE": 800, + "US_COMMERCIAL": 3700, + "EU_AUTOMOTIVE": 3000, + "EU_INDUSTRIAL": 3500, + "JP_AUTOMOTIVE": 2000, + "KR_ELECTRONICS": 1500 + }, + "suppliers": [ + { + "name": "China_consolidated", + "capacity": 12000, + "unit_cost": 85.0, + "risk_score": 0.15 + }, + { + "name": "Lynas_Australia", + "capacity": 12000, + "unit_cost": 120.0, + "risk_score": 0.10 + }, + { + "name": "MP_Materials_USA", + "capacity": 1300, + "unit_cost": 110.0, + "risk_score": 0.05 + }, + { + "name": "Neo_Performance_EU", + "capacity": 2000, + "unit_cost": 120.0, + "risk_score": 0.20 + }, + { + "name": "Recycled_domestic_USA", + "capacity": 600, + "unit_cost": 70.0, + "risk_score": 0.08 + } + ], + "shipping_cost": { + "China_consolidated": { + "US_DEFENSE": 0.725, + "US_COMMERCIAL": 0.725, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 0.4, + "KR_ELECTRONICS": 0.35 + }, + "Lynas_Australia": { + "US_DEFENSE": 1.2, + "US_COMMERCIAL": 1.2, + "EU_AUTOMOTIVE": 1.725, + "EU_INDUSTRIAL": 1.725, + "JP_AUTOMOTIVE": 0.9, + "KR_ELECTRONICS": 0.95 + }, + "MP_Materials_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + }, + "Neo_Performance_EU": { + "US_DEFENSE": 1.0, + "US_COMMERCIAL": 1.0, + "EU_AUTOMOTIVE": 0.1, + "EU_INDUSTRIAL": 0.1, + "JP_AUTOMOTIVE": 1.5, + "KR_ELECTRONICS": 1.5 + }, + "Recycled_domestic_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + } + }, + "risk_weight": 5.0, + "unmet_demand_penalty": 50000, + "max_supplier_share": 0.50 + } + }, + "nd_post_april_2025": { + "task": "Model NdPr oxide supply chain impact of China MOFCOM April 2025 initial HREE export controls (Notice 2025 No. 61).", + "rag_context": "April 2025: China implements initial rare earth export licensing under MOFCOM Notice 2025 No. 61. Chinese NdPr oxide export capacity drops from 12,000 to 7,000 t/yr as license approvals slow. Prices rise across all suppliers due to market tightening. 
Risk weight increases to reflect supply uncertainty.", + "execution_instruction": "Produce a post-April 2025 impact assessment showing how initial Chinese export controls redistribute NdPr oxide flows, increase costs, and create capacity pressure on alternative suppliers.", + "source_queries": {}, + "optimization_input": { + "commodity": "NDPR_OXIDE", + "demand": { + "US_DEFENSE": 800, + "US_COMMERCIAL": 3700, + "EU_AUTOMOTIVE": 3000, + "EU_INDUSTRIAL": 3500, + "JP_AUTOMOTIVE": 2000, + "KR_ELECTRONICS": 1500 + }, + "suppliers": [ + { + "name": "China_consolidated", + "capacity": 7000, + "unit_cost": 95.0, + "risk_score": 0.50 + }, + { + "name": "Lynas_Australia", + "capacity": 12000, + "unit_cost": 126.0, + "risk_score": 0.10 + }, + { + "name": "MP_Materials_USA", + "capacity": 1300, + "unit_cost": 115.0, + "risk_score": 0.05 + }, + { + "name": "Neo_Performance_EU", + "capacity": 2000, + "unit_cost": 126.0, + "risk_score": 0.20 + }, + { + "name": "Recycled_domestic_USA", + "capacity": 600, + "unit_cost": 75.0, + "risk_score": 0.08 + } + ], + "shipping_cost": { + "China_consolidated": { + "US_DEFENSE": 0.725, + "US_COMMERCIAL": 0.725, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 0.4, + "KR_ELECTRONICS": 0.35 + }, + "Lynas_Australia": { + "US_DEFENSE": 1.2, + "US_COMMERCIAL": 1.2, + "EU_AUTOMOTIVE": 1.725, + "EU_INDUSTRIAL": 1.725, + "JP_AUTOMOTIVE": 0.9, + "KR_ELECTRONICS": 0.95 + }, + "MP_Materials_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + }, + "Neo_Performance_EU": { + "US_DEFENSE": 1.0, + "US_COMMERCIAL": 1.0, + "EU_AUTOMOTIVE": 0.1, + "EU_INDUSTRIAL": 0.1, + "JP_AUTOMOTIVE": 1.5, + "KR_ELECTRONICS": 1.5 + }, + "Recycled_domestic_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + } + }, + "risk_weight": 8.0, + "unmet_demand_penalty": 75000, + "max_supplier_share": 0.45 + } + }, + "nd_post_december_2025": { + "task": "Stress-test NdPr oxide supply chain under full China December 2025 extraterritorial export control regime.", + "rag_context": "December 2025: China implements full extraterritorial rare earth export rules. Chinese NdPr oxide export capacity drops to 3,000 t/yr. All supplier costs increase significantly. 
Total non-Chinese capacity (Lynas 12,000 + MP 1,300 + Neo 2,000 + Recycled 600 = 15,900) is insufficient to meet 14,500 t/yr demand, especially with the 40% max supplier share cap limiting Lynas to 5,800 t/yr.", + "execution_instruction": "Produce a crisis-scenario analysis showing demand shortfalls, market prioritization under scarcity, and shadow prices indicating the marginal value of additional capacity from each source.", + "source_queries": {}, + "optimization_input": { + "commodity": "NDPR_OXIDE", + "demand": { + "US_DEFENSE": 850, + "US_COMMERCIAL": 3700, + "EU_AUTOMOTIVE": 3000, + "EU_INDUSTRIAL": 3500, + "JP_AUTOMOTIVE": 2000, + "KR_ELECTRONICS": 1500 + }, + "suppliers": [ + { + "name": "China_consolidated", + "capacity": 3000, + "unit_cost": 115.0, + "risk_score": 0.95 + }, + { + "name": "Lynas_Australia", + "capacity": 12000, + "unit_cost": 145.0, + "risk_score": 0.10 + }, + { + "name": "MP_Materials_USA", + "capacity": 1300, + "unit_cost": 135.0, + "risk_score": 0.05 + }, + { + "name": "Neo_Performance_EU", + "capacity": 2000, + "unit_cost": 140.0, + "risk_score": 0.20 + }, + { + "name": "Recycled_domestic_USA", + "capacity": 600, + "unit_cost": 85.0, + "risk_score": 0.08 + } + ], + "shipping_cost": { + "China_consolidated": { + "US_DEFENSE": 0.725, + "US_COMMERCIAL": 0.725, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 0.4, + "KR_ELECTRONICS": 0.35 + }, + "Lynas_Australia": { + "US_DEFENSE": 1.2, + "US_COMMERCIAL": 1.2, + "EU_AUTOMOTIVE": 1.725, + "EU_INDUSTRIAL": 1.725, + "JP_AUTOMOTIVE": 0.9, + "KR_ELECTRONICS": 0.95 + }, + "MP_Materials_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + }, + "Neo_Performance_EU": { + "US_DEFENSE": 1.0, + "US_COMMERCIAL": 1.0, + "EU_AUTOMOTIVE": 0.1, + "EU_INDUSTRIAL": 0.1, + "JP_AUTOMOTIVE": 1.5, + "KR_ELECTRONICS": 1.5 + }, + "Recycled_domestic_USA": { + "US_DEFENSE": 0.1, + "US_COMMERCIAL": 0.1, + "EU_AUTOMOTIVE": 1.0, + "EU_INDUSTRIAL": 1.0, + "JP_AUTOMOTIVE": 1.3, + "KR_ELECTRONICS": 1.3 + } + }, + "risk_weight": 12.0, + "unmet_demand_penalty": 100000, + "max_supplier_share": 0.40 + } + } +} diff --git a/docs/arxiv_agent.md b/docs/arxiv_agent.md index 671ce955..a7a20793 100644 --- a/docs/arxiv_agent.md +++ b/docs/arxiv_agent.md @@ -1,91 +1,47 @@ # ArxivAgent Documentation -`ArxivAgent` is a class that helps fetch, process, and summarize scientific papers from arXiv. It uses LLMs to generate summaries of papers relevant to a given query and context. +`ArxivAgent` (current implementation) is an acquisition agent that fetches ArXiv papers, extracts content, and returns context-aware summaries. ## Basic Usage ```python +from langchain.chat_models import init_chat_model from ursa.agents import ArxivAgent -# Initialize the agent -agent = ArxivAgent() +llm = init_chat_model("openai:gpt-5.2") +agent = ArxivAgent(llm=llm, max_results=3) -# Run a query result = agent.invoke( - arxiv_search_query="Experimental Constraints on neutron star radius", - context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?" 
+ query="Experimental Constraints on neutron star radius", + context="What are the constraints on neutron star radius and what uncertainties are reported?", ) -# Print the summary -print(result) +print(result["final_summary"]) ``` ## Parameters -When initializing `ArxivAgent`, you can customize its behavior with these parameters: - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `llm` | `BaseChatModel` | `init_chat_model("openai:gpt-5-mini")` | The LLM model to use for summarization | -| `summarize` | bool | True | Whether to summarize the papers or just fetch them | -| `process_images` | bool | True | Whether to extract and describe images from papers | -| `max_results` | int | 3 | Maximum number of papers to fetch from arXiv | -| `database_path` | str | 'arxiv_papers' | Directory to store downloaded PDFs | -| `summaries_path` | str | 'arxiv_generated_summaries' | Directory to store paper summaries | -| `vectorstore_path` | str | 'arxiv_vectorstores' | Directory to store vector embeddings | -| `download` | bool | True | Whether to download papers or use existing ones | +- `llm`: required chat model +- `max_results`: number of papers to fetch +- `summarize`: summarize fetched items (`True` default) +- `process_images`: attempt image extraction + vision description +- `download`: if `False`, use local cache in `database_path` +- `database_path`, `summaries_path`, `vectorstore_path`: workspace-relative folders ## Advanced Usage ### Customizing the Agent ```python -from langchain.chat_models import init_chat_model -from ursa.agents import ArxivAgent - agent = ArxivAgent( - llm=init_chat_model("openai:gpt-5-mini"), # Use a more powerful model - max_results=5, # Fetch more papers - process_images=False, # Skip image processing to save time - download=False # Use only papers already in database_path + llm=llm, + max_results=5, + process_images=False, + download=False, ) ``` -### Running Multiple Queries - -```python -# First query -result1 = agent.invoke( - arxiv_search_query="quantum computing error correction", - context="Summarize recent advances in quantum error correction techniques" -) - -# Second query (will reuse downloaded papers if applicable) -result2 = agent.invoke( - arxiv_search_query="quantum computing algorithms", - context="What are the most promising quantum algorithms for near-term devices?" -) -``` - -## How It Works - -1. **Fetching Papers**: The agent searches arXiv for papers matching your query and downloads them as PDFs. - -2. **Processing**: If `summarize=True`, each paper is: - - Converted to text - - Split into chunks - - Embedded into a vector database - - If `process_images=True`, images are extracted and described using GPT-4 Vision - -3. **Summarization**: The agent: - - Retrieves the most relevant chunks based on your context - - Generates a summary for each paper - - Creates a final summary addressing your specific context - -4. **Output**: Returns a comprehensive summary that synthesizes information from all relevant papers. - ## Notes -- Summaries and vector stores are cached, making subsequent queries faster. -- The agent uses a ThreadPoolExecutor to process papers in parallel. -- You can find the combined summaries in 'summaries_combined.txt' and the final summary in 'final_summary.txt'. +- Returned state includes `items`, optional per-item `summaries`, and `final_summary`. +- Legacy `ArxivAgentLegacy` has been retired. Use `ursa.agents.ArxivAgent`. 
diff --git a/docs/cmm_demo_runbook.md b/docs/cmm_demo_runbook.md new file mode 100644 index 00000000..009ec982 --- /dev/null +++ b/docs/cmm_demo_runbook.md @@ -0,0 +1,265 @@ +# CMM Demo Runbook: Full Workflow Build and Demo Readiness + +This runbook provides a full, reproducible sequence to prepare and run the +Critical Minerals and Materials (CMM) demo on your local corpus. + +It assumes this repository is at: + +- `/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa` + +And your corpus is at: + +- `/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus` + +## 1. Prerequisites + +- Python 3.10+ and `uv` installed. +- Valid `OPENAI_API_KEY`. +- Your OpenAI-compatible endpoint URL (custom `OPENAI_BASE_URL`). +- Network access from this machine to your model endpoint. + +## 2. Open Repo and Install Dependencies + +```bash +cd /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa + +# Base + dev dependencies +uv sync --group dev + +# Optional CMM extras (needed only for local/cohere/weaviate optional paths) +uv sync --extra cmm +``` + +## 3. Configure Environment + +```bash +cp .env.example .env +``` + +Edit `.env` and set at least: + +```bash +OPENAI_API_KEY= +OPENAI_BASE_URL= + +# Demo defaults (OpenAI-only) +CMM_VECTORSTORE_BACKEND=chroma +CMM_EMBEDDING_MODEL=openai:text-embedding-3-large +CMM_EMBEDDING_DIMENSIONS=3072 +CMM_USE_RERANKER=false +CMM_RERANKER_PROVIDER=none +CMM_HYBRID_ALPHA=0.7 +CMM_VECTORSTORE_COLLECTION=cmm_chunks + +URSA_RAG_LEGACY_MODE=false +``` + +## 4. Sanity Check Model Connectivity (Recommended) + +```bash +uv run python - <<'PY' +import os +from langchain.chat_models import init_chat_model + +base_url = os.getenv("OPENAI_BASE_URL") +api_key = os.getenv("OPENAI_API_KEY") +model = init_chat_model( + model="openai:gpt-5-nano", + base_url=base_url, + api_key=api_key, + temperature=0, +) +print(model.invoke("Reply with exactly: connectivity_ok").content) +PY +``` + +Expected: `connectivity_ok`. + +## 5. Preprocess + Ingest Corpus (Smoke Pass) + +Start small to validate parsing and indexing behavior before full ingest. + +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-large \ + --embedding-dimensions 3072 \ + --include-extension pdf \ + --include-extension txt \ + --include-extension md \ + --exclude-extension py \ + --max-docs 250 \ + --reset +``` + +Review outputs for: + +- `Docs indexed` +- `Chunks indexed` +- `Vectorstore count` +- `Commodity tag counts` +- `Subdomain tag counts` + +## 6. Full Corpus Ingest + +After smoke pass succeeds, run full ingest. + +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-large \ + --embedding-dimensions 3072 \ + --include-extension pdf \ + --include-extension txt \ + --include-extension md +``` + +Notes: + +- Reindex writes a manifest at `cmm_vectorstore/_ingested_ids.txt`. +- This manifest prevents `RAGAgent` from re-ingesting already indexed docs. 
+- If you need a full rebuild, rerun with `--reset`. + +## 7. Run Healthcheck and Demo Runner + +You now have turnkey scripts: + +- `scripts/demo_healthcheck.py` +- `scripts/run_cmm_demo.py` +- `configs/cmm_demo_scenarios.json` + +### Healthcheck + +```bash +uv run python scripts/demo_healthcheck.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore +``` + +### Run a scenario + +```bash +uv run python scripts/run_cmm_demo.py \ + --scenario ndfeb_la_y_5pct_baseline \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --output-dir /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs +``` + +Available scenarios: + +- `ndfeb_la_y_5pct_baseline` +- `ndfeb_la_y_5pct_quality_tightening` +- `ndfeb_la_y_5pct_supply_shock` + +Artifacts are written under: + +- `cmm_demo_outputs///` +- `cmm_demo_outputs//latest/` + +## 8. One-Command Demo Prep + +Use the `just` target to run healthcheck + smoke demo in one command: + +```bash +just demo-prep +``` + +If `just` is not installed: + +```bash +uv tool install rust-just +``` + +Optional environment overrides: + +```bash +export CMM_CORPUS_PATH=/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus +export CMM_VECTORSTORE_PATH=/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore +export CMM_DEMO_SCENARIO=ndfeb_la_y_5pct_baseline +export CMM_OUTPUT_DIR=/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs +just demo-prep +``` + +## 9. Validate Demo Acceptance Criteria + +Confirm the run output contains all of the following: + +- Source-grounded RAG narrative in final summary. +- Retrieved context behavior that varies by query type. +- Optimization output with deterministic fields: + - `objective_value` + - `allocations` + - `constraint_residuals` + - `feasible` / `status` + - `sensitivity_summary` + +## 10. Regression / Confidence Tests + +Run targeted tests used for this demo path: + +```bash +uv run pytest -q \ + tests/agents/test_rag_agent/test_rag_agent.py \ + tests/agents/test_rag_agent/test_cmm_components.py \ + tests/tools/test_cmm_supply_chain_optimization_tool.py \ + tests/workflows/test_critical_minerals_workflow.py \ + tests/agents/test_execution_agent/test_execution_agent.py +``` + +## 11. Optional Architecture Toggles (Not Required for Demo) + +### Weaviate backend + +Set: + +```bash +CMM_VECTORSTORE_BACKEND=weaviate +CMM_WEAVIATE_URL= +CMM_WEAVIATE_API_KEY= +``` + +Re-run `scripts/reindex.py` with `--backend weaviate`. + +### Cohere reranker + +Set: + +```bash +CMM_USE_RERANKER=true +CMM_RERANKER_PROVIDER=cohere +COHERE_API_KEY= +``` + +## 12. Demo-Day Checklist + +- `.env` points to correct endpoint and key. +- `cmm_vectorstore` already built (avoid live full ingest on stage). +- `scripts/demo_healthcheck.py` reports no FAIL items. +- `scripts/run_cmm_demo.py` runs cleanly once pre-demo. +- Have 2-3 prepared prompts of increasing complexity. +- Keep one deterministic optimization scenario ready to replay. + +## 13. 
Troubleshooting + +- Empty RAG output: + - Verify `cmm_vectorstore` path and `_ingested_ids.txt`. + - Ensure corpus parseable files exist for selected extensions. +- Slow ingest: + - Run a staged ingest with `--max-docs` batches. + - Limit extensions to high-value types first (`pdf`, `txt`, `md`). +- API errors: + - Re-check `OPENAI_BASE_URL` and endpoint auth requirements. +- Rebuild clean index: + - Re-run `scripts/reindex.py ... --reset`. + +## 14. Suggested Artifact Outputs for Demo Package + +- Final `result` JSON from workflow run. +- `RAG_summary.txt` from summaries directory. +- Screenshot/log snippet showing indexed counts and tag distributions. +- A short one-page summary of insights and recommended actions. diff --git a/docs/critical_minerals_workflow.md b/docs/critical_minerals_workflow.md new file mode 100644 index 00000000..5667afae --- /dev/null +++ b/docs/critical_minerals_workflow.md @@ -0,0 +1,118 @@ +# Critical Minerals Workflow + +This workflow provides an end-to-end path for Critical Minerals and Materials +(CMM) analysis with: + +- local-corpus RAG, +- adaptive retrieval, +- deterministic CMM supply-chain optimization, +- final synthesis via the execution model. + +## Architecture + +```mermaid +flowchart TD + U["User Task"] --> WF["CriticalMineralsWorkflow"] + + WF --> PL["Planning Agent"] + WF --> ACQ["Acquisition Agents"] + WF --> RAG["RAGAgent"] + WF --> OPT["CMM Optimization Tool"] + WF --> EX["Executor"] + + ACQ --> OSTI["OSTI"] + ACQ --> ARX["arXiv"] + ACQ --> WEB["Web"] + + RAG --> CLASS["Query Classifier"] + RAG --> CHUNK["CMM Chunker"] + RAG --> VS["VectorStore Backend"] + VS --> CHROMA["Chroma + BM25 + RRF"] + VS --> WEAVIATE["Weaviate (optional)"] + + EX --> TOOLS["Execution tools"] + TOOLS --> OPT +``` + +## RAG Components + +`RAGAgent` now uses CMM-specific modules: + +- `src/ursa/agents/cmm_taxonomy.py` + - commodity/subdomain tags and temporal hints. +- `src/ursa/agents/cmm_embeddings.py` + - embedding provider abstraction (`openai`, `local`). +- `src/ursa/agents/cmm_chunker.py` + - markdown-aware, table-preserving chunking + metadata enrichment. +- `src/ursa/agents/cmm_vectorstore.py` + - backend abstraction: + - `chroma`: dense + BM25 hybrid retrieval with RRF fusion, + - `weaviate`: optional backend hook. +- `src/ursa/agents/cmm_query_classifier.py` + - rule-based query profile for adaptive retrieval. +- `src/ursa/agents/cmm_reranker.py` + - reranker abstraction (`none` default; `cohere`/`local` optional). + +## Demo Defaults (OpenAI-only) + +Set in `.env` (or start from `.env.example`): + +```bash +OPENAI_API_KEY=... 
+OPENAI_BASE_URL=https://api.openai.com/v1 + +CMM_VECTORSTORE_BACKEND=chroma +CMM_EMBEDDING_MODEL=openai:text-embedding-3-large +CMM_EMBEDDING_DIMENSIONS=3072 +CMM_USE_RERANKER=false +CMM_RERANKER_PROVIDER=none +CMM_HYBRID_ALPHA=0.7 +``` + +## Workflow Input Schema + +`CriticalMineralsWorkflow.invoke(...)` supports: + +- `task` (required) +- `local_corpus_path` +- `rag_context` +- `source_queries` +- `optimization_input` +- `execution_instruction` + +### Optimization Input + +`optimization_input` expected fields: + +- `commodity`: string +- `demand`: mapping `market -> quantity` +- `suppliers`: list of `{name, capacity, unit_cost, risk_score}` +- `shipping_cost` (optional): mapping `supplier -> market -> unit shipping cost` +- `risk_weight` (optional) +- `unmet_demand_penalty` (optional) +- `max_supplier_share` (optional, 0 to 1) + +Output is deterministic JSON with: + +- `objective_value` +- `allocations` +- `constraint_residuals` +- `feasible` / `status` +- `sensitivity_summary` + +## Reindex Script + +Use `scripts/reindex.py` to ingest corpus into configured backend: + +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-large \ + --embedding-dimensions 3072 \ + --reset +``` + +The script reports indexed document/chunk counts and commodity/subdomain tag +counts. diff --git a/docs/domain_scientist_setup_and_usage.md b/docs/domain_scientist_setup_and_usage.md new file mode 100644 index 00000000..d76b185c --- /dev/null +++ b/docs/domain_scientist_setup_and_usage.md @@ -0,0 +1,364 @@ +# URSA CMM Guide for Domain Scientists + +This guide is a practical, end-to-end manual for domain scientists who want to +use the URSA Critical Minerals and Materials (CMM) workflow without digging +through the full codebase first. + +It covers: + +1. Environment setup. +2. API/model configuration. +3. Corpus preprocessing and indexing. +4. Running CMM scenarios and custom analyses. +5. Reading outputs (RAG + optimization). +6. Troubleshooting common failures. +7. Demo preparation workflow. + +## 1. What This Workflow Does + +At a high level, the CMM workflow combines: + +- Planning (`PlanningAgent`): decomposes the task and response strategy. +- Retrieval (`RAGAgent`): pulls relevant evidence from your local corpus. +- Deterministic optimization (`run_cmm_supply_chain_optimization`): computes + supply allocations, unmet demand, and composition-constraint feasibility. +- Synthesis (`ExecutionAgent`): writes the final decision narrative. + +Primary workflow path: + +- `src/ursa/workflows/critical_minerals_workflow.py` +- `scripts/run_cmm_demo.py` + +## 2. Prerequisites + +- macOS/Linux terminal access. +- Python 3.10+. +- [`uv`](https://docs.astral.sh/uv/) installed. +- A valid `OPENAI_API_KEY`. +- An OpenAI-compatible endpoint (`OPENAI_BASE_URL`) if you are not using + `https://api.openai.com/v1`. + +Repository path used below: + +- `/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa` + +Corpus path used below: + +- `/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus` + +## 3. 
Initial Setup + +### 3.1 Enter the repo + +```bash +cd /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa +``` + +### 3.2 Install dependencies + +```bash +uv sync --group dev +uv sync --extra cmm +``` + +### 3.3 Create `.env` + +```bash +cp .env.example .env +``` + +Set at least: + +```bash +OPENAI_API_KEY= +OPENAI_BASE_URL= +``` + +## 4. Recommended Runtime Profiles + +The most important rule is: **embedding settings at runtime must match how the +vectorstore was indexed**. + +### Profile A: Use existing local index (common in this workspace) + +Use if your index was built with `text-embedding-3-small-project`: + +```bash +CMM_EMBEDDING_MODEL=openai:text-embedding-3-small-project +CMM_EMBEDDING_DIMENSIONS=1536 +``` + +### Profile B: Build a fresh index with large embeddings + +If you reindex with `text-embedding-3-large`, use: + +```bash +CMM_EMBEDDING_MODEL=openai:text-embedding-3-large +CMM_EMBEDDING_DIMENSIONS=3072 +``` + +### Common CMM defaults + +```bash +CMM_VECTORSTORE_BACKEND=chroma +CMM_USE_RERANKER=false +CMM_RERANKER_PROVIDER=none +CMM_HYBRID_ALPHA=0.7 +CMM_VECTORSTORE_COLLECTION=cmm_chunks +URSA_RAG_LEGACY_MODE=false +``` + +### Model access note + +If your endpoint does not allow `openai:gpt-5`, set explicit project models: + +```bash +CMM_PLANNER_MODEL=openai:gpt-5.2-project +CMM_EXECUTOR_MODEL=openai:gpt-5.2-project +CMM_RAG_MODEL=openai:gpt-5.2-project +``` + +## 5. Healthcheck Before Indexing/Running + +```bash +uv run python scripts/demo_healthcheck.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --model openai:gpt-5.2-project +``` + +A good run has no `FAIL` lines. + +## 6. Corpus Indexing (Preprocessing + Ingestion) + +### 6.1 Smoke test index (small batch) + +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-small-project \ + --embedding-dimensions 1536 \ + --include-extension pdf \ + --include-extension txt \ + --include-extension md \ + --exclude-extension py \ + --max-docs 250 \ + --reset +``` + +### 6.2 Full indexing + +```bash +uv run python scripts/reindex.py \ + --corpus-path /Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus \ + --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \ + --backend chroma \ + --embedding-model openai:text-embedding-3-small-project \ + --embedding-dimensions 1536 \ + --include-extension pdf \ + --include-extension txt \ + --include-extension md +``` + +Expected summary includes: + +- `Docs indexed` +- `Chunks indexed` +- `Vectorstore count` +- Tag counts by commodity/subdomain + +### 6.3 Manifest behavior + +The script writes: + +- `cmm_vectorstore/_ingested_ids.txt` + +This prevents re-ingesting documents already indexed. + +## 7. 
Running the CMM Workflow
+
+### 7.1 Available built-in scenarios
+
+Defined in:
+
+- `configs/cmm_demo_scenarios.json`
+
+Current scenario IDs:
+
+- `ndfeb_la_y_5pct_baseline`
+- `ndfeb_la_y_5pct_quality_tightening`
+- `ndfeb_la_y_5pct_supply_shock`
+
+### 7.2 Recommended run command
+
+Use an already-indexed vectorstore and an empty corpus path to avoid
+reparsing all raw files during each demo run:
+
+```bash
+CMM_EMBEDDING_MODEL=openai:text-embedding-3-small-project \
+CMM_EMBEDDING_DIMENSIONS=1536 \
+CMM_USE_RERANKER=false \
+CMM_RERANKER_PROVIDER=none \
+CMM_VECTORSTORE_BACKEND=chroma \
+uv run python scripts/run_cmm_demo.py \
+  --scenario ndfeb_la_y_5pct_baseline \
+  --planner-model openai:gpt-5.2-project \
+  --executor-model openai:gpt-5.2-project \
+  --rag-model openai:gpt-5.2-project \
+  --corpus-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/empty_corpus \
+  --vectorstore-path /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_vectorstore \
+  --output-dir /Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs
+```
+
+### 7.3 Run all three scenario variants
+
+Repeat the same command with:
+
+- `--scenario ndfeb_la_y_5pct_baseline`
+- `--scenario ndfeb_la_y_5pct_quality_tightening`
+- `--scenario ndfeb_la_y_5pct_supply_shock`
+
+## 8. Understanding Output Artifacts
+
+Each run writes:
+
+- `cmm_demo_outputs/<scenario>/<run_timestamp>/input_payload.json`
+- `cmm_demo_outputs/<scenario>/<run_timestamp>/workflow_result.json`
+- `cmm_demo_outputs/<scenario>/<run_timestamp>/optimization_output.json`
+- `cmm_demo_outputs/<scenario>/<run_timestamp>/rag_metadata.json`
+- `cmm_demo_outputs/<scenario>/<run_timestamp>/final_summary.md`
+
+And copies the current output to:
+
+- `cmm_demo_outputs/<scenario>/latest/`
+
+### 8.1 Key optimization fields
+
+In `optimization_output.json` (or `workflow_result.json -> optimization`):
+
+- `status`
+- `feasible`
+- `objective_value`
+- `allocations`
+- `unmet_demand`
+- `constraint_residuals`
+- `composition` (targets/actual/residuals/tolerance)
+- `sensitivity_summary`
+
+### 8.2 Status interpretation
+
+- `optimal_greedy`: demand and active constraints satisfied.
+- `infeasible_unmet_demand`: demand could not be fully satisfied.
+- `infeasible_composition_constraints`: demand may be met, but composition
+  targets violate tolerance.
+- `infeasible_unmet_and_composition`: both demand and composition fail.
+
+### 8.3 RAG metadata interpretation
+
+In `rag_metadata`:
+
+- `num_results`: retrieved chunk count.
+- `relevance_scores`: retrieval scores.
+- `query_type`: classifier type (`general`, `multi_hop`, etc.).
+- `filter_fallback_used`: retrieval had to retry without metadata filters.
+
+## 9. Creating a New Domain Scenario
+
+1. Open `configs/cmm_demo_scenarios.json`.
+2. Add a new top-level scenario object with fields:
+   - `task`
+   - `rag_context`
+   - `execution_instruction`
+   - `source_queries` (optional)
+   - `optimization_input`
+3. In `optimization_input`, include:
+   - demand/suppliers/shipping/risk parameters
+   - `composition_targets` (e.g., `LA`, `Y`)
+   - `composition_tolerance`
+   - per-supplier `composition_profile`
+4. Run with `scripts/run_cmm_demo.py --scenario <scenario_id>`.
+
+## 10. Testing and Regression Checks
+
+Run targeted tests after edits:
+
+```bash
+uv run pytest -q \
+  tests/agents/test_rag_agent/test_rag_agent.py \
+  tests/agents/test_rag_agent/test_cmm_components.py \
+  tests/tools/test_cmm_supply_chain_optimization_tool.py \
+  tests/workflows/test_critical_minerals_workflow.py
+```
+
+## 11. Common Failure Modes and Fixes
+
+### 11.1 `team_model_access_denied` / 401 model errors
+
+Cause: the model name is not allowed on your endpoint.
+
+Fix: set model env/flags to endpoint-allowed models, for example:
+
+```bash
+CMM_PLANNER_MODEL=openai:gpt-5.2-project
+CMM_EXECUTOR_MODEL=openai:gpt-5.2-project
+CMM_RAG_MODEL=openai:gpt-5.2-project
+```
+
+### 11.2 RAG returns zero results
+
+Check:
+
+- `vectorstore-path` is correct.
+- embedding model/dimensions match index build settings.
+- corpus was actually indexed (`_ingested_ids.txt` non-empty).
+- `rag_context` contains domain-specific terminology.
+
+### 11.3 Runs are very slow
+
+Cause: the workflow reparses the entire corpus each run.
+
+Fix: use an indexed vectorstore and point `--corpus-path` to `empty_corpus` for
+repeat demo runs.
+
+### 11.4 Indexing appears stalled
+
+Use smaller batches and frequent flushes:
+
+```bash
+--max-docs 500 --flush-docs 20
+```
+
+Then iterate by tranche.
+
+## 12. Demo-Day Checklist
+
+1. `.env` loaded and correct.
+2. Model connectivity healthcheck passes.
+3. Vectorstore is already built.
+4. Scenario commands tested once before recording.
+5. Artifact folders verified under `cmm_demo_outputs/.../latest/`.
+6. Keep one feasible scenario and one infeasible scenario ready to show
+   contrast.
+
+## 13. Fast Command Reference
+
+### Reindex
+
+```bash
+uv run python scripts/reindex.py --corpus-path <corpus_path> --vectorstore-path <vectorstore_path> --backend chroma --embedding-model openai:text-embedding-3-small-project --embedding-dimensions 1536
+```
+
+### Healthcheck
+
+```bash
+uv run python scripts/demo_healthcheck.py --corpus-path <corpus_path> --vectorstore-path <vectorstore_path> --model openai:gpt-5.2-project
+```
+
+### Run scenario
+
+```bash
+uv run python scripts/run_cmm_demo.py --scenario <scenario_id> --planner-model openai:gpt-5.2-project --executor-model openai:gpt-5.2-project --rag-model openai:gpt-5.2-project --corpus-path <corpus_path> --vectorstore-path <vectorstore_path> --output-dir <output_dir>
+```
+
diff --git a/docs/execution_agent.md b/docs/execution_agent.md
index 97ea7d80..08095e36 100644
--- a/docs/execution_agent.md
+++ b/docs/execution_agent.md
@@ -1,118 +1,47 @@
 # ExecutionAgent Documentation
-`ExecutionAgent` is a class that enables AI-powered code execution, writing, and editing. It uses a state machine architecture to safely execute commands, write code files, and search for information.
+`ExecutionAgent` runs iterative plan/act loops with tools for reading files, writing/editing code, running commands, and doing acquisition-style web/literature lookups.
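Beyond the built-in tools listed later on this page, custom tools can be registered at construction time via `extra_tools`. A minimal sketch (the `do_magic` tool is purely illustrative):

```python
from math import sqrt

from langchain.chat_models import init_chat_model
from langchain.tools import tool

from ursa.agents import ExecutionAgent


@tool
def do_magic(a: int, b: int) -> float:
    """Return sqrt(a**2 + b**2) for integers a and b."""
    return sqrt(a**2 + b**2)


agent = ExecutionAgent(
    llm=init_chat_model("openai:gpt-5.2"),
    extra_tools=[do_magic],
)
```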
## Basic Usage ```python +from langchain.chat_models import init_chat_model from ursa.agents import ExecutionAgent -# Initialize the agent -agent = ExecutionAgent() +llm = init_chat_model("openai:gpt-5.2") +agent = ExecutionAgent(llm=llm, workspace="ursa_workspace") -# Run a prompt -result = agent("Write and execute a python script to print the first 10 integers.") - -# Access the final response +result = agent.invoke("Write and run a Python script that prints the first 10 integers.") print(result["messages"][-1].text) ``` -## Parameters - -When initializing `ExecutionAgent`, you can customize its behavior with these parameters: - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `llm` | `BaseChatModel` | `init_chat_model("openai:gpt-5-mini")` | The LLM model to use | -| `extra_tools` | `Optional[list[Callable[..., Any]]]` | `None` | Additional tools for the execution agent | - -## Features - -### Code Execution - -The agent can safely execute shell commands in a controlled environment: - -```python -result = agent("Install numpy and create a script that uses it to calculate the mean of [1, 2, 3, 4, 5]") -``` - -### Code Writing - -The agent can write code files to a workspace directory: - -```python -result = agent("Create a Flask web application that displays 'Hello World'") -``` - -## Advanced Usage - -### Customizing the Workspace - -The agent creates a workspace folder with a randomly generated name for each -run. You can access this workspace path from the agent: - -```python -result = agent("Create a Python script")) -workspace_path = agent.workspace -print(f"Files were created in: {workspace_path}") -``` - -### Setting a Recursion Limit - -For complex tasks, you may need to adjust the recursion limit: - -```python -result = agent.invoke( - "Create a complex project with multiple files and tests", - recursion_limit=2000 -) -``` - -### Safety Features - -The agent includes built-in safety checks for shell commands: - -1. Commands are evaluated for safety before execution -2. Unsafe commands are blocked with explanations -3. The agent suggests safer alternatives when appropriate - -## How It Works +## Key Parameters -1. **State Machine**: The agent uses a directed graph to manage its workflow: - - `agent` node: Processes user requests and generates responses - - `action` node: Executes tools (`run_cmd`, `write_code`, `edit_code`, `search`) - - extra tools can be provided to the agent as follows: - ```py - from langchain.tools import tool +- `llm`: required `BaseChatModel` +- `workspace`: directory for files and command execution (default `ursa_workspace`) +- `extra_tools`: optional additional tools +- `safe_codes`: trusted language/tool hints for command safety checks +- `tokens_before_summarize`, `messages_to_keep`: context compaction controls - @tool - def do_magic(a: int, b: int) -> float: - """Do magic with integers a and b. +## Graph Shape - Args: - a: first integer - b: second integer - """ - return sqrt(a**2 + b**2) +- `agent`: LLM decides next action +- `action`: executes tool calls +- `recap`: summarizes final result - agent = ExecutionAgent(extra_tools=[do_magic]) - ``` - - `summarize` node: Creates a final summary when complete +The graph loops `agent -> action -> agent` until no tool calls remain, then goes to `recap`. -2. 
**Tools**: - - `run_cmd`: Executes shell commands in the workspace directory - - `write_code`: Creates new code files with syntax highlighting - - `edit_code`: Modifies existing code files with diff preview - - `search_tool`: Performs web searches via DuckDuckGo +## Built-in Tools -3. **Visualization**: - - Code changes are displayed with syntax highlighting - - File edits show detailed diffs - - Command execution shows stdout and stderr +- `run_command` +- `write_code` +- `edit_code` +- `read_file` +- `run_web_search` +- `run_osti_search` +- `run_arxiv_search` -## Notes +## Safety Notes -- The agent creates a new workspace directory for each run -- Files are written to and executed from this workspace -- Shell commands have a 60000-second timeout by default -- The agent can handle keyboard interrupts during command execution +- `run_command` performs an LLM safety check before execution. +- Unsafe commands are blocked and returned with a reason. diff --git a/docs/ndfeb_la_y_5pct_demo_comparison.md b/docs/ndfeb_la_y_5pct_demo_comparison.md new file mode 100644 index 00000000..96464e34 --- /dev/null +++ b/docs/ndfeb_la_y_5pct_demo_comparison.md @@ -0,0 +1,23 @@ +# Nd2Fe14B 5% La / 5% Y Scenario Comparison + +Generated: 2026-02-22 17:32:12Z + +## Run Summary + +| Scenario | Latest Run | Status | Feasible | Objective | Unmet | Composition Feasible | LA Actual | Y Actual | RAG Results | Filter Fallback | Score Range | +|---|---:|---|---:|---:|---:|---:|---:|---:|---:|---:|---| +| `ndfeb_la_y_5pct_baseline` | `20260222T172757Z` | `optimal_greedy` | `True` | `40035.8` | `0.0` | `True` | `0.053425` | `0.046575` | `5` | `False` | 0.008861-0.009859 | +| `ndfeb_la_y_5pct_quality_tightening` | `20260222T171520Z` | `optimal_greedy` | `True` | `40945.33` | `0.0` | `True` | `0.052367` | `0.047633` | `5` | `False` | 0.004167-0.008974 | +| `ndfeb_la_y_5pct_supply_shock` | `20260222T173144Z` | `infeasible_unmet_and_composition` | `False` | `309447.12` | `10.0` | `False` | `0.054242` | `0.045758` | `2` | `False` | 0.004839-0.004918 | + +## Demo Readout + +- `baseline`: feasible demand and composition posture under nominal assumptions. +- `quality_tightening`: feasible but closer composition/tolerance management burden; use for QA governance narrative. +- `supply_shock`: infeasible from both demand and composition constraints; use for mitigation and contingency discussion. + +## Artifact Paths + +- `/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs/ndfeb_la_y_5pct_baseline/20260222T172757Z` +- `/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs/ndfeb_la_y_5pct_quality_tightening/20260222T171520Z` +- `/Users/wash198/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/Science_Projects/MPII_CMM/ursa/cmm_demo_outputs/ndfeb_la_y_5pct_supply_shock/20260222T173144Z` diff --git a/docs/planning_agent.md b/docs/planning_agent.md index 99ee1dba..5e9f34bb 100644 --- a/docs/planning_agent.md +++ b/docs/planning_agent.md @@ -1,104 +1,39 @@ # PlanningAgent Documentation -`PlanningAgent` is a class that implements a multi-step planning approach for complex problem solving. It uses a state machine architecture to generate plans, reflect on them, and formalize the final solution. +`PlanningAgent` generates structured multi-step plans for a task and optionally performs reflection loops. 
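Because the result is a structured `Plan` rather than free text (see the output schema below), downstream code can consume the steps directly; a minimal sketch that assumes only the `steps` and `name` fields shown below:

```python
from langchain.chat_models import init_chat_model

from ursa.agents import PlanningAgent

llm = init_chat_model("openai:gpt-5.2")
agent = PlanningAgent(llm=llm)

result = agent.invoke("Plan a benchmark of NdFeB coercivity measurement methods.")

# Print an ordered outline of the generated plan steps.
for i, step in enumerate(result["plan"].steps, start=1):
    print(f"{i}. {step.name}")
```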
## Basic Usage ```python +from langchain.chat_models import init_chat_model from ursa.agents import PlanningAgent -# Initialize the agent -agent = PlanningAgent() +llm = init_chat_model("openai:gpt-5.2") +agent = PlanningAgent(llm=llm) -# Run a planning task result = agent.invoke("Find a city with at least 10 vowels in its name.") - -# Access the final plan -plan_steps = result["plan_steps"] +plan = result["plan"] +print(plan.steps[0].name) ``` -## Parameters - -When initializing `PlanningAgent`, you can customize its behavior with these parameters: - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `llm` | BaseChatModel | `init_chat_model("openai:gpt-5-mini")` | The LLM model to use for planning | -| `**kwargs` | `dict` | `{}` | Additional parameters passed to the base agent | - -## Features - -### Multi-step Planning - -The agent follows a three-stage planning process: +## Output Schema -1. **Generation**: Creates an initial plan to solve the problem -2. **Reflection**: Critically evaluates and improves the plan -3. **Formalization**: Structures the final plan as a JSON object - -### Structured Output - -The final output includes: - -- `messages`: The conversation history -- `plan_steps`: A structured list of steps to solve the problem +- `plan`: structured `Plan` object with `steps` +- `messages`: message history for generation/reflection ## Advanced Usage ### Customizing Reflection Steps -You can adjust how many reflection iterations the agent performs: - ```python -# Initialize with custom reflection steps -initial_state = { - "messages": [HumanMessage(content="Your complex problem here")], - "reflection_steps": 5 # Default is 3 -} - -result = agent.invoke(initial_state, {"configurable": {"thread_id": agent.thread_id}}) +agent = PlanningAgent(llm=llm, max_reflection_steps=3) ``` -### Streaming Results - -You can stream the agent's thinking process: - -```python -for event in agent.stream( - {"messages": [HumanMessage(content="Your problem here")]}, - {"configurable": {"thread_id": agent.thread_id}} -): - print(event[list(event.keys())[0]]["messages"][-1].text) -``` - -### Setting a Recursion Limit - -For complex planning tasks, you may need to adjust the recursion limit: - -```python -result = agent.invoke( - "Solve this complex problem...", - recursion_limit=200 # Default is 100 -) -``` - -## How It Works - -1. **State Machine**: The agent uses a directed graph to manage its workflow: - - `generate` node: Creates or improves the plan - - `reflect` node: Evaluates the plan for improvements - - `formalize` node: Structures the final plan as JSON - -2. **Termination Conditions**: The planning process ends when either: - - The agent has completed the specified number of reflection steps - - The agent explicitly marks the plan as "[APPROVED]" +`max_reflection_steps` defaults to `1`. -3. 
**JSON Output**: The final plan is structured as a JSON array of steps, each containing: - - A description of the step - - Any relevant details for executing that step +## Graph Shape -## Notes +- `generate`: creates a structured plan (`Plan`) +- `reflect`: critiques and requests regeneration if needed -- The agent continues to refine its plan through multiple reflection cycles -- The final output is a structured JSON representation of the solution steps -- You can access the complete conversation history in the `messages` field of the result \ No newline at end of file +Routing ends when reflection budget is exhausted or reflection marks the plan with `[APPROVED]`. diff --git a/docs/web_search_agent.md b/docs/web_search_agent.md index 1cfb6432..b74276dc 100644 --- a/docs/web_search_agent.md +++ b/docs/web_search_agent.md @@ -1,84 +1,45 @@ # WebSearchAgent Documentation -`WebSearchAgent` is a powerful tool for conducting internet-based research on any topic. It leverages language models and web search capabilities to gather, process, and summarize information from online sources. +`WebSearchAgent` (current implementation) is an acquisition agent that uses `ddgs` search, downloads/web-scrapes sources, and returns a synthesized summary in context. ## Basic Usage ```python +from langchain.chat_models import init_chat_model from ursa.agents import WebSearchAgent -from langchain_openai import ChatOpenAI -# Initialize with default model (gpt-5-mini) -websearcher = WebSearchAgent() +llm = init_chat_model("openai:gpt-5.2") +websearcher = WebSearchAgent(llm=llm, max_results=3) -# Or initialize with a custom model -model = ChatOpenAI(model="gpt-5-mini", max_completion_tokens=10000) -websearcher = WebSearchAgent(llm=model) - -# Run a web search query -result = websearcher.invoke("Who are the 2025 Detroit Tigers top 10 prospects and what year were they born?") - -# Access the web search results -sources = result["urls_visited"] - -print("Web Search Summary:") +result = websearcher.invoke({ + "query": "Detroit Tigers top prospects 2025 birth year", + "context": "Who are the top prospects and what year were they born?", +}) print(result["final_summary"]) -print("Sources:", sources) ``` ## Parameters -### Initialization - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `llm` | `BaseChatModel` | init_chat_model("openai:gpt-5-mini") | The language model to use for web search | -| `**kwargs` | `dict` | `{}` | Additional parameters passed to the base agent | - -### Run Method - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `prompt` | str | Required | The web search question or topic | -| `recursion_limit` | int | 100 | Maximum recursion depth for the web search process | +- `llm`: required chat model +- `max_results`: max search hits to materialize +- `summarize`: summarize fetched content (`True` by default) +- `download`: if `False`, uses cached files from `database_path` +- `database_path`, `summaries_path`, `vectorstore_path`: storage folders under workspace ## Features -- **Automated Web Search**: Uses DuckDuckGo to find relevant information -- **Content Processing**: Extracts and summarizes content from web pages -- **Iterative Web Search**: Continues researching until sufficient information is gathered -- **Source Tracking**: Records all URLs visited during research -- **Internet Connectivity Check**: Verifies internet access before attempting research +- DuckDuckGo discovery via `ddgs` +- HTML/PDF 
materialization to local cache +- boilerplate-stripped text extraction +- per-source summaries + final aggregate summary ## Output -The agent returns a dictionary containing: - -- `messages`: A list of message objects, with the final message containing the comprehensive web search summary -- `urls_visited`: A list of all sources consulted during the web search process - -## Advanced Usage - -```python -from langchain.chat_model import init_chat_model -from ursa.agents import WebSearchAgent - -# Initialize with custom parameters -websearcher = WebSearchAgent( - llm=init_chat_model("openai:gpt-5-mini"), - url="https://www.example.com" # Custom URL for internet connectivity check -) - -# Run with higher recursion limit for complex topics -result = websearcher.invoke( - "What are the latest developments in quantum computing? Summarize in markdown format.", - recursion_limit=200 -) -``` +- `final_summary`: synthesized answer +- `items`: fetched source items (metadata/content) +- `summaries`: per-item summaries ## Notes -- The agent requires internet connectivity to function properly -- Rate limiting is implemented to avoid overwhelming search services -- For networks with SSL inspection, you may need to set the `CERT_FILE` environment variable -- The websearch process includes multiple steps: search, content processing, review, and final summarization +- This documentation covers the current exported `ursa.agents.WebSearchAgent`. +- Legacy `WebSearchAgentLegacy` has been retired. Use `ursa.agents.WebSearchAgent`. diff --git a/examples/single_agent_examples/acquisition_examples/acquistion_agents.py b/examples/single_agent_examples/acquisition_examples/acquistion_agents.py index 99bbea3a..43886363 100644 --- a/examples/single_agent_examples/acquisition_examples/acquistion_agents.py +++ b/examples/single_agent_examples/acquisition_examples/acquistion_agents.py @@ -4,7 +4,7 @@ from rich import print as rprint from rich.panel import Panel -from ursa.agents import ArxivAgent, ArxivAgentLegacy, OSTIAgent, WebSearchAgent +from ursa.agents import ArxivAgent, OSTIAgent, WebSearchAgent def print_summary(summary: str, title: str): @@ -40,20 +40,6 @@ async def main(): }) print_summary(summary, title="OSTI Agent Summary") - # ArXiv agent (legacy version) - arxiv_agent_legacy = ArxivAgentLegacy( - llm=init_chat_model("openai:gpt-5-mini"), - max_results=3, - database_path="arxiv_papers", - summaries_path="arxiv_generated_summaries", - enable_metrics=True, - ) - summary = await arxiv_agent_legacy.ainvoke({ - "query": "graph neural networks for PDEs", - "context": "Summarize methods & benchmarks and potential for shock hydrodynamics", - }) - print_summary(summary, title="Arxiv Agent (Legacy) Summary") - # ArXiv agent arxiv_agent = ArxivAgent( llm=init_chat_model("openai:gpt-5-mini"), diff --git a/justfile b/justfile index 1b96a36b..039837f7 100644 --- a/justfile +++ b/justfile @@ -61,6 +61,15 @@ clean: clean-workspaces test-cli: uv run ursa run +demo-healthcheck: + bash -lc 'uv run python scripts/demo_healthcheck.py --corpus-path "${CMM_CORPUS_PATH:-/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus}" --vectorstore-path "${CMM_VECTORSTORE_PATH:-cmm_vectorstore}"' + +demo-smoke: + bash -lc 'uv run python scripts/run_cmm_demo.py --scenario "${CMM_DEMO_SCENARIO:-ndfeb_la_y_5pct_baseline}" --corpus-path "${CMM_CORPUS_PATH:-/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus}" --vectorstore-path "${CMM_VECTORSTORE_PATH:-cmm_vectorstore}" --output-dir 
"${CMM_OUTPUT_DIR:-cmm_demo_outputs}"' + +demo-prep: demo-healthcheck demo-smoke + @echo "CMM demo prep completed." + docker-build: docker buildx \ build \ diff --git a/mkdocs.yml b/mkdocs.yml old mode 100644 new mode 100755 index 51dec005..b9287458 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -34,9 +34,10 @@ nav: - Combining arXiv agent and execution agent: combining_arxiv_and_execution.md - Human in the Loop: humanInTheLoop_example.md - Neutron Star: combining_arxiv_and_execution_neutronStar.md + - Critical Minerals Workflow: critical_minerals_workflow.md + - CMM Demo Runbook: cmm_demo_runbook.md - API Reference: - agents: api_reference/agents.md - prompt_library: api_reference/prompt_library.md - tools: api_reference/tools.md - util: api_reference/util.md - diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755 index 000c6e2f..1057f846 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ "typer>=0.16.1", "trafilatura>=1.6.1,<1.7", "selectolax>=0.4.0,<0.5", + "openai>=1.0.0,<3.0", + "rank-bm25>=0.2.2,<0.3", # "langchain-google-genai>=2.1.9,<3.0", # "langchain-anthropic>=0.3.19,<0.4", "langchain>=1.0.3", @@ -84,6 +86,16 @@ otel = [ "opentelemetry-exporter-otlp>=1.39.0", "opentelemetry-sdk>=1.38.0", ] +cmm = [ + "cohere>=5.11.0", + "sentence-transformers>=3.2.0", + "weaviate-client>=4.9.0", +] +dashboard = [ + "streamlit>=1.40,<2.0", + "plotly>=5.18,<6.0", + "scipy>=1.11", +] [build-system] requires = ["setuptools>=74.1,<80", "setuptools-git-versioning>=2.0,<3"] diff --git a/scripts/cmm_dashboard.py b/scripts/cmm_dashboard.py new file mode 100644 index 00000000..d87dccff --- /dev/null +++ b/scripts/cmm_dashboard.py @@ -0,0 +1,1229 @@ +"""Streamlit dashboard for CMM supply-chain optimization scenarios. + +Launch with:: + + uv run streamlit run scripts/cmm_dashboard.py +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import streamlit as st +from plotly.subplots import make_subplots + +# --------------------------------------------------------------------------- +# Ensure project src is importable +# --------------------------------------------------------------------------- +_PROJECT_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_PROJECT_ROOT / "src")) + +from ursa.tools.cmm_supply_chain_optimization_tool import ( # noqa: E402 + solve_cmm_supply_chain_optimization, +) + +# --------------------------------------------------------------------------- +# Display configuration per scenario file +# --------------------------------------------------------------------------- + +_CMM_COLORS: dict[str, str] = { + "domestic_recycled_blend": "#00d4aa", + "allied_separated_oxide": "#636efa", + "integrated_allied_metal": "#ffa15a", + "US_DEFENSE": "#ef553b", + "US_EV": "#ab63fa", + "EU_OEM": "#19d3f3", +} + +_CMM_SUPPLIER_SHORT: dict[str, str] = { + "domestic_recycled_blend": "Domestic Recycled", + "allied_separated_oxide": "Allied Oxide", + "integrated_allied_metal": "Allied Metal", +} + +_CMM_SCENARIO_LABELS: dict[str, str] = { + "ndfeb_la_y_5pct_baseline": "Baseline", + "ndfeb_la_y_5pct_quality_tightening": "Quality Tightening", + "ndfeb_la_y_5pct_supply_shock": "Supply Shock", +} + +_ND_COLORS: dict[str, str] = { + "China_consolidated": "#ef553b", + "Lynas_Australia": "#636efa", + "MP_Materials_USA": "#00d4aa", + "Neo_Performance_EU": "#ffa15a", + "Recycled_domestic_USA": "#ab63fa", + "US_DEFENSE": 
"#ff6692", + "US_COMMERCIAL": "#19d3f3", + "EU_AUTOMOTIVE": "#b6e880", + "EU_INDUSTRIAL": "#ff97ff", + "JP_AUTOMOTIVE": "#fecb52", + "KR_ELECTRONICS": "#00cc96", +} + +_ND_SUPPLIER_SHORT: dict[str, str] = { + "China_consolidated": "China", + "Lynas_Australia": "Lynas (AU)", + "MP_Materials_USA": "MP Materials (US)", + "Neo_Performance_EU": "Neo Perf. (EU)", + "Recycled_domestic_USA": "Recycled (US)", +} + +_ND_SCENARIO_LABELS: dict[str, str] = { + "nd_preshock_baseline": "Pre-Shock Baseline", + "nd_post_april_2025": "Post-April 2025", + "nd_post_december_2025": "Post-December 2025", +} + +LAYOUT_DEFAULTS: dict[str, object] = dict( + template="plotly_dark", + font=dict(size=14), + margin=dict(t=60, b=40, l=60, r=40), +) + +_CONFIGS_DIR = _PROJECT_ROOT / "configs" + + +# --------------------------------------------------------------------------- +# Display config selection helpers +# --------------------------------------------------------------------------- + + +def _detect_display_config( + file_stem: str, +) -> tuple[ + dict[str, str], + dict[str, str], + dict[str, str], +]: + """Return (colors, supplier_short, scenario_labels) for a file. + + Falls back to auto-generated mappings when no preset + matches. + """ + if "nd_china" in file_stem: + return _ND_COLORS, _ND_SUPPLIER_SHORT, _ND_SCENARIO_LABELS + if "cmm_demo" in file_stem: + return _CMM_COLORS, _CMM_SUPPLIER_SHORT, _CMM_SCENARIO_LABELS + return {}, {}, {} + + +def _auto_display_config( + scenarios: dict[str, object], + colors: dict[str, str], + supplier_short: dict[str, str], + scenario_labels: dict[str, str], +) -> tuple[ + dict[str, str], + dict[str, str], + dict[str, str], + list[str], + list[str], +]: + """Augment display dicts with auto-generated entries. + + Returns ``(colors, supplier_short, scenario_labels, + suppliers_list, markets_list)``. 
+ """ + palette = [ + "#636efa", + "#ef553b", + "#00d4aa", + "#ffa15a", + "#ab63fa", + "#19d3f3", + "#ff6692", + "#b6e880", + "#ff97ff", + "#fecb52", + "#00cc96", + ] + + # Auto-generate scenario labels for unknown keys + labels = dict(scenario_labels) + for key in scenarios: + if key not in labels: + labels[key] = key.replace("_", " ").title() + + # Collect all suppliers and markets from all scenarios + all_suppliers: list[str] = [] + all_markets: list[str] = [] + seen_sup: set[str] = set() + seen_mkt: set[str] = set() + for cfg in scenarios.values(): + inp = cfg["optimization_input"] + for s in inp["suppliers"]: + name = s["name"] + if name not in seen_sup: + all_suppliers.append(name) + seen_sup.add(name) + for m in inp["demand"]: + if m not in seen_mkt: + all_markets.append(m) + seen_mkt.add(m) + + all_suppliers.sort() + all_markets.sort() + + # Auto-fill colors and short names + c = dict(colors) + s_short = dict(supplier_short) + pi = 0 + for name in all_suppliers + all_markets: + if name not in c: + c[name] = palette[pi % len(palette)] + pi += 1 + if name not in s_short and name in seen_sup: + s_short[name] = name.replace("_", " ") + + return c, s_short, labels, all_suppliers, all_markets + + +# --------------------------------------------------------------------------- +# Data loading (cached) +# --------------------------------------------------------------------------- + + +@st.cache_data(show_spinner="Loading scenario definitions...") +def load_scenarios(path_str: str) -> dict[str, object]: + """Load a scenario configuration JSON.""" + with open(path_str) as fh: + return json.load(fh) + + +@st.cache_data(show_spinner="Solving all scenarios...") +def run_all_scenarios( + scenarios_json: str, +) -> dict[str, dict[str, object]]: + """Solve the optimisation for every scenario. + + Parameters + ---------- + scenarios_json + JSON-serialised scenarios dict (used as cache key). + + Returns + ------- + dict + ``{scenario_key: result_dict}`` + """ + scenarios: dict[str, object] = json.loads(scenarios_json) + results: dict[str, dict[str, object]] = {} + for key, cfg in scenarios.items(): + results[key] = solve_cmm_supply_chain_optimization( + cfg["optimization_input"], + ) + return results + + +# --------------------------------------------------------------------------- +# Chart builders — each returns a go.Figure (or pd.DataFrame) +# --------------------------------------------------------------------------- + + +def build_summary_table( + scenarios: dict[str, object], + results: dict[str, dict[str, object]], + scenario_labels: dict[str, str], + supplier_short: dict[str, str], +) -> pd.DataFrame: + """Summary table across all non-errored scenarios.""" + rows: list[dict[str, object]] = [] + for key in scenarios: + label = scenario_labels.get(key, key) + r = results[key] + if r.get("status") == "validation_error": + continue + inp = scenarios[key]["optimization_input"] + sens = r["sensitivity_summary"] + rows.append({ + "Scenario": label, + "Status": r["status"], + "Feasible": ("\u2705" if r["feasible"] else "\u274c"), + "Objective ($)": (f"${r['objective_value']:,.0f}"), + "Avg Unit Cost": (f"${sens['average_unit_cost']:.1f}"), + "Unmet (t)": sens["unmet_demand_total"], + "Comp. 
Feasible": ( + "\u2705" if sens["composition_feasible"] else "\u274c" + ), + "Tolerance": inp.get("composition_tolerance", "\u2014"), + "Active Capacity": ", ".join( + supplier_short.get(s, s) + for s in sens["active_capacity_constraints"] + ) + or "None", + }) + return pd.DataFrame(rows) + + +def build_sankey( + result: dict[str, object], + suppliers_list: list[str], + markets_list: list[str], + colors: dict[str, str], + supplier_short: dict[str, str], +) -> go.Figure: + """Supplier -> Market allocation Sankey.""" + labels = [supplier_short.get(s, s) for s in suppliers_list] + markets_list + node_colors = [colors.get(s, "#888") for s in suppliers_list] + [ + colors.get(m, "#888") for m in markets_list + ] + + sources: list[int] = [] + targets: list[int] = [] + values: list[float] = [] + link_colors: list[str] = [] + for alloc in result["allocations"]: + sup_name = alloc["supplier"] + mkt_name = alloc["market"] + if sup_name not in suppliers_list: + continue + if mkt_name not in markets_list: + continue + s_idx = suppliers_list.index(sup_name) + t_idx = len(suppliers_list) + markets_list.index(mkt_name) + sources.append(s_idx) + targets.append(t_idx) + values.append(alloc["amount"]) + c = colors.get(sup_name, "#888888") + r, g, b = ( + int(c[1:3], 16), + int(c[3:5], 16), + int(c[5:7], 16), + ) + link_colors.append(f"rgba({r},{g},{b},0.45)") + + fig = go.Figure( + go.Sankey( + arrangement="snap", + node=dict( + pad=20, + thickness=30, + label=labels, + color=node_colors, + ), + link=dict( + source=sources, + target=targets, + value=values, + color=link_colors, + ), + ) + ) + fig.update_layout( + title="Supplier \u2192 Market Flows (tonnes)", + **LAYOUT_DEFAULTS, + height=450, + ) + return fig + + +def build_cost_waterfall( + result: dict[str, object], +) -> go.Figure: + """Objective-function cost waterfall.""" + bd = result["objective_breakdown"] + cost_items = [ + ("Procurement", bd["procurement"]), + ("Shipping", bd["shipping"]), + ("Risk Penalty", bd["risk_penalty"]), + ("Unmet Penalty", bd["unmet_penalty"]), + ] + + fig = go.Figure( + go.Waterfall( + x=[item[0] for item in cost_items] + ["Total"], + y=[item[1] for item in cost_items] + [0], + measure=["relative"] * len(cost_items) + ["total"], + text=[f"${v:,.0f}" for _, v in cost_items] + + [f"${result['objective_value']:,.0f}"], + textposition="outside", + connector=dict(line=dict(color="rgba(63,63,63,0.6)")), + increasing=dict(marker=dict(color="#636efa")), + decreasing=dict(marker=dict(color="#00d4aa")), + totals=dict(marker=dict(color="#ffa15a")), + ) + ) + fig.update_layout( + title="Objective Function Breakdown", + yaxis_title="Cost ($)", + **LAYOUT_DEFAULTS, + ) + return fig + + +def build_capacity_gauges( + result: dict[str, object], + scenario_input: dict[str, object], + colors: dict[str, str], + supplier_short: dict[str, str], +) -> go.Figure: + """Supplier capacity utilisation gauge charts.""" + n_suppliers = len(scenario_input["suppliers"]) + fig = make_subplots( + rows=1, + cols=n_suppliers, + specs=[[{"type": "indicator"}] * n_suppliers], + subplot_titles=[ + supplier_short.get(s["name"], s["name"]) + for s in scenario_input["suppliers"] + ], + ) + + for i, supplier_cfg in enumerate(scenario_input["suppliers"]): + name = supplier_cfg["name"] + capacity = supplier_cfg["capacity"] + used = sum( + a["amount"] for a in result["allocations"] if a["supplier"] == name + ) + + fig.add_trace( + go.Indicator( + mode="gauge+number+delta", + value=used, + number=dict(suffix=" t"), + delta=dict( + reference=capacity, + relative=False, 
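+                        # Delta compares allocated tonnage against the full
+                        # capacity set as the reference value.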
+ suffix=" t", + ), + gauge=dict( + axis=dict(range=[0, capacity]), + bar=dict(color=colors.get(name, "#888")), + bgcolor="rgba(50,50,50,0.3)", + steps=[ + dict( + range=[0, capacity * 0.7], + color="rgba(50,50,50,0.15)", + ), + dict( + range=[ + capacity * 0.7, + capacity * 0.9, + ], + color=("rgba(255,161,90,0.15)"), + ), + dict( + range=[ + capacity * 0.9, + capacity, + ], + color="rgba(239,85,59,0.15)", + ), + ], + threshold=dict( + line=dict(color="white", width=2), + value=capacity, + ), + ), + ), + row=1, + col=i + 1, + ) + + fig.update_layout( + title="Supplier Capacity Utilization", + **LAYOUT_DEFAULTS, + height=350, + ) + return fig + + +def build_composition_bullets( + result: dict[str, object], +) -> go.Figure | None: + """Composition feasibility bullet chart. + + Returns ``None`` when no composition data is present. + """ + comp = result.get("composition") + if not comp: + return None + + components = list(comp["targets"].keys()) + fig = make_subplots( + rows=1, + cols=len(components), + specs=[[{"type": "indicator"}] * len(components)], + subplot_titles=[f"{c} Fraction" for c in components], + ) + + for i, component in enumerate(components): + target = comp["targets"][component] + actual = comp["actual"][component] + tol = comp["tolerance"] + + fig.add_trace( + go.Indicator( + mode="number+gauge+delta", + value=actual, + number=dict(valueformat=".4f"), + delta=dict(reference=target, valueformat=".4f"), + gauge=dict( + shape="bullet", + axis=dict( + range=[ + target - 4 * tol, + target + 4 * tol, + ] + ), + bar=dict( + color=( + "#00d4aa" + if abs(actual - target) <= tol + else "#ef553b" + ) + ), + steps=[ + dict( + range=[ + target - tol, + target + tol, + ], + color="rgba(0,212,170,0.2)", + ), + ], + threshold=dict( + line=dict(color="white", width=3), + thickness=0.8, + value=target, + ), + ), + ), + row=1, + col=i + 1, + ) + + feas_label = "FEASIBLE" if comp["feasible"] else "INFEASIBLE" + fig.update_layout( + title=( + f"Composition Constraints \u2014 {feas_label}" + f" (tolerance \u00b1{comp['tolerance']})" + ), + **LAYOUT_DEFAULTS, + height=250, + ) + return fig + + +def build_multi_scenario_cost_bar( + results: dict[str, dict[str, object]], + scenario_labels: dict[str, str], +) -> go.Figure: + """Stacked cost-breakdown bar across scenarios.""" + rows: list[dict[str, object]] = [] + for key in results: + label = scenario_labels.get(key, key) + r = results[key] + if r.get("status") == "validation_error": + continue + bd = r["objective_breakdown"] + for cost_type, value in bd.items(): + rows.append({ + "Scenario": label, + "Cost Component": cost_type.replace("_", " ").title(), + "Value": value, + }) + + df = pd.DataFrame(rows) + fig = px.bar( + df, + x="Scenario", + y="Value", + color="Cost Component", + barmode="stack", + title=("Objective Cost Breakdown \u2014 All Scenarios"), + color_discrete_sequence=[ + "#636efa", + "#00d4aa", + "#ffa15a", + "#ef553b", + ], + template="plotly_dark", + ) + fig.update_layout(**LAYOUT_DEFAULTS, yaxis_title="Cost ($)") + return fig + + +def build_allocation_comparison( + results: dict[str, dict[str, object]], + scenario_labels: dict[str, str], + supplier_short: dict[str, str], + colors: dict[str, str], +) -> go.Figure: + """Faceted stacked allocation bar.""" + rows: list[dict[str, object]] = [] + for key in results: + label = scenario_labels.get(key, key) + r = results[key] + if r.get("status") == "validation_error": + continue + for alloc in r["allocations"]: + rows.append({ + "Scenario": label, + "Supplier": supplier_short.get( + 
alloc["supplier"], + alloc["supplier"], + ), + "Market": alloc["market"], + "Amount": alloc["amount"], + }) + + df = pd.DataFrame(rows) + color_map = { + supplier_short.get(k, k): v + for k, v in colors.items() + if k in supplier_short + } + fig = px.bar( + df, + x="Market", + y="Amount", + color="Supplier", + facet_col="Scenario", + barmode="stack", + title=("Allocation by Market & Supplier \u2014 All Scenarios"), + color_discrete_map=color_map, + template="plotly_dark", + ) + fig.update_layout( + **LAYOUT_DEFAULTS, + height=450, + yaxis_title="Tonnes", + ) + return fig + + +def build_risk_cost_scatter( + scenarios: dict[str, object], + results: dict[str, dict[str, object]], + scenario_labels: dict[str, str], + supplier_short: dict[str, str], + colors: dict[str, str], +) -> go.Figure: + """Bubble scatter: risk vs cost by supplier.""" + rows: list[dict[str, object]] = [] + for key in results: + label = scenario_labels.get(key, key) + r = results[key] + if r.get("status") == "validation_error": + continue + inp = scenarios[key]["optimization_input"] + for s in inp["suppliers"]: + used = sum( + a["amount"] + for a in r["allocations"] + if a["supplier"] == s["name"] + ) + rows.append({ + "Scenario": label, + "Supplier": supplier_short.get(s["name"], s["name"]), + "Unit Cost": s["unit_cost"], + "Risk Score": s["risk_score"], + "Allocated (t)": used, + "Capacity": s["capacity"], + }) + + df = pd.DataFrame(rows) + color_map = { + supplier_short.get(k, k): v + for k, v in colors.items() + if k in supplier_short + } + fig = px.scatter( + df, + x="Unit Cost", + y="Risk Score", + size="Allocated (t)", + color="Supplier", + symbol="Scenario", + hover_data=["Capacity", "Allocated (t)"], + title=("Supplier Risk vs. Cost \u2014 Bubble Size = Allocation"), + size_max=50, + color_discrete_map=color_map, + template="plotly_dark", + ) + fig.update_layout(**LAYOUT_DEFAULTS, height=500) + return fig + + +def build_shadow_price_bars( + result: dict[str, object], + supplier_short: dict[str, str], +) -> tuple[go.Figure, go.Figure, go.Figure | None] | None: + """Build demand + capacity + composition shadow price bar charts. + + Returns ``None`` when no shadow prices are present. + Otherwise returns ``(demand_fig, capacity_fig, composition_fig)``, + where ``composition_fig`` may be ``None`` if no composition duals + exist. 
+ """ + sp = result.get("shadow_prices") + if not sp: + return None + + # Demand shadow prices + demand_sp = sp.get("demand_balance", {}) + if demand_sp: + df_d = pd.DataFrame([ + {"Market": k, "Shadow Price ($/t)": v} for k, v in demand_sp.items() + ]) + fig_d = px.bar( + df_d, + x="Market", + y="Shadow Price ($/t)", + title="Demand Shadow Prices (Marginal Cost)", + template="plotly_dark", + color_discrete_sequence=["#636efa"], + ) + fig_d.update_layout(**LAYOUT_DEFAULTS) + else: + fig_d = go.Figure() + fig_d.update_layout( + title="No demand shadow prices", + **LAYOUT_DEFAULTS, + ) + + # Capacity + share shadow prices (stacked by type) + cap_sp = sp.get("supplier_capacity", {}) + share_sp = sp.get("supplier_share_cap", {}) + all_names = sorted(set(list(cap_sp.keys()) + list(share_sp.keys()))) + + if all_names: + cap_rows: list[dict[str, object]] = [] + for name in all_names: + short = supplier_short.get(name, name) + cap_val = cap_sp.get(name, 0.0) + share_val = share_sp.get(name, 0.0) + if abs(cap_val) > 1e-9: + cap_rows.append({ + "Supplier": short, + "Constraint": "Capacity", + "Shadow Price ($/t)": cap_val, + }) + if abs(share_val) > 1e-9: + cap_rows.append({ + "Supplier": short, + "Constraint": "Share Cap", + "Shadow Price ($/t)": share_val, + }) + if abs(cap_val) <= 1e-9 and abs(share_val) <= 1e-9: + cap_rows.append({ + "Supplier": short, + "Constraint": "Capacity", + "Shadow Price ($/t)": 0.0, + }) + + df_c = pd.DataFrame(cap_rows) + fig_c = px.bar( + df_c, + x="Supplier", + y="Shadow Price ($/t)", + color="Constraint", + barmode="relative", + title="Capacity Shadow Prices (Marginal Value of +1t)", + template="plotly_dark", + color_discrete_map={ + "Capacity": "#ef553b", + "Share Cap": "#ffa15a", + }, + ) + fig_c.update_layout(**LAYOUT_DEFAULTS) + else: + fig_c = go.Figure() + fig_c.update_layout( + title="No capacity shadow prices", + **LAYOUT_DEFAULTS, + ) + + # Composition shadow prices (if present) + comp_sp = sp.get("composition", {}) + fig_comp: go.Figure | None = None + if any(abs(v) > 1e-9 for v in comp_sp.values()): + df_comp = pd.DataFrame([ + {"Component": k, "Shadow Price ($/unit)": v} + for k, v in comp_sp.items() + if abs(v) > 1e-9 + ]) + fig_comp = px.bar( + df_comp, + x="Component", + y="Shadow Price ($/unit)", + title="Composition Constraint Shadow Prices", + template="plotly_dark", + color_discrete_sequence=["#00d4aa"], + ) + fig_comp.update_layout(**LAYOUT_DEFAULTS) + + return fig_d, fig_c, fig_comp + + +def build_multi_scenario_shadow_comparison( + results: dict[str, dict[str, object]], + scenario_labels: dict[str, str], + supplier_short: dict[str, str], +) -> dict[str, go.Figure]: + """Build shadow price comparison charts across scenarios. + + Returns a dict with keys ``"demand"``, ``"supply"``, and + optionally ``"composition"``. Returns empty dict when no + shadow prices are present in any result. 
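+
+    For the ``"supply"`` chart, capacity and share-cap duals are summed per
+    supplier and plotted as absolute values.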
+ """ + demand_rows: list[dict[str, object]] = [] + supply_rows: list[dict[str, object]] = [] + comp_rows: list[dict[str, object]] = [] + any_shadow = False + + for key in results: + label = scenario_labels.get(key, key) + r = results[key] + sp = r.get("shadow_prices") + if not sp: + continue + any_shadow = True + + # Demand shadow prices per market + for mkt, val in sp.get("demand_balance", {}).items(): + demand_rows.append({ + "Scenario": label, + "Market": mkt, + "Shadow Price ($/t)": val, + }) + + # Supply shadow prices (capacity + share) per supplier + cap_sp = sp.get("supplier_capacity", {}) + share_sp = sp.get("supplier_share_cap", {}) + all_names = sorted(set(list(cap_sp.keys()) + list(share_sp.keys()))) + for name in all_names: + combined = cap_sp.get(name, 0.0) + share_sp.get(name, 0.0) + supply_rows.append({ + "Scenario": label, + "Supplier": supplier_short.get(name, name), + "Shadow Price ($/t)": abs(combined), + }) + + # Composition shadow prices + for comp_name, val in sp.get("composition", {}).items(): + if abs(val) > 1e-9: + comp_rows.append({ + "Scenario": label, + "Component": comp_name, + "Shadow Price ($/unit)": abs(val), + }) + + if not any_shadow: + return {} + + figs: dict[str, go.Figure] = {} + + # --- Demand shadow prices comparison --- + if demand_rows: + df_d = pd.DataFrame(demand_rows) + fig_d = px.bar( + df_d, + x="Market", + y="Shadow Price ($/t)", + color="Scenario", + barmode="group", + title=( + "Demand Shadow Prices Across Scenarios" + " (Marginal Cost per Market)" + ), + template="plotly_dark", + ) + fig_d.update_layout(**LAYOUT_DEFAULTS, height=400) + + # Use log scale if range exceeds 100x + vals = [ + r["Shadow Price ($/t)"] + for r in demand_rows + if r["Shadow Price ($/t)"] > 0 + ] + if vals and max(vals) / max(min(vals), 0.01) > 100: + fig_d.update_yaxes(type="log") + + figs["demand"] = fig_d + + # --- Supply constraint shadow prices comparison --- + if supply_rows: + df_s = pd.DataFrame(supply_rows) + fig_s = px.bar( + df_s, + x="Supplier", + y="Shadow Price ($/t)", + color="Scenario", + barmode="group", + title=( + "Supply Constraint Shadow Prices Across Scenarios" + " (|marginal value| of +1t capacity)" + ), + template="plotly_dark", + ) + fig_s.update_layout(**LAYOUT_DEFAULTS, height=400) + + # Use log scale if range exceeds 100x to prevent + # infeasible-scenario penalties from compressing + # feasible-scenario bars to invisibility. 
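+        # Ratio test below compares the largest positive dual to the
+        # smallest, floored at 0.01 to avoid division by zero.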
+ vals = [ + r["Shadow Price ($/t)"] + for r in supply_rows + if r["Shadow Price ($/t)"] > 0 + ] + if vals and max(vals) / max(min(vals), 0.01) > 100: + fig_s.update_yaxes(type="log") + + figs["supply"] = fig_s + + # --- Composition constraint shadow prices comparison --- + if comp_rows: + df_c = pd.DataFrame(comp_rows) + fig_c = px.bar( + df_c, + x="Component", + y="Shadow Price ($/unit)", + color="Scenario", + barmode="group", + title=( + "Composition Constraint Shadow Prices" + " (Cost of Tightening Tolerance by 1 Unit)" + ), + template="plotly_dark", + ) + fig_c.update_layout(**LAYOUT_DEFAULTS, height=400) + + vals = [ + r["Shadow Price ($/unit)"] + for r in comp_rows + if r["Shadow Price ($/unit)"] > 0 + ] + if vals and max(vals) / max(min(vals), 0.01) > 100: + fig_c.update_yaxes(type="log") + + figs["composition"] = fig_c + + return figs + + +# --------------------------------------------------------------------------- +# Main app +# --------------------------------------------------------------------------- + + +def main() -> None: + """Entry-point for the Streamlit dashboard.""" + st.set_page_config( + page_title="CMM Optimization Dashboard", + page_icon="\U0001f9f2", + layout="wide", + ) + + # --- Scenario file selector --- + scenario_files = sorted(_CONFIGS_DIR.glob("*_scenarios.json")) + if not scenario_files: + st.error( + "No scenario files found in configs/. " + "Expected *_scenarios.json files." + ) + return + + file_labels = {f.stem: f for f in scenario_files} + + st.sidebar.title("\U0001f9f2 CMM Dashboard") + selected_file_stem = st.sidebar.selectbox( + "Scenario File", + options=list(file_labels.keys()), + format_func=lambda s: s.replace("_", " ").title(), + ) + selected_path = file_labels[selected_file_stem] + + # --- Load & solve --- + scenarios = load_scenarios(str(selected_path)) + scenarios_json = json.dumps(scenarios) + results = run_all_scenarios(scenarios_json) + + # --- Resolve display config --- + ( + base_colors, + base_supplier_short, + base_scenario_labels, + ) = _detect_display_config(selected_file_stem) + + ( + colors, + supplier_short, + scenario_labels, + suppliers_list, + markets_list, + ) = _auto_display_config( + scenarios, + base_colors, + base_supplier_short, + base_scenario_labels, + ) + + # --- Sidebar scenario selector --- + scenario_key = st.sidebar.radio( + "Select Scenario", + options=list(scenarios.keys()), + format_func=lambda k: scenario_labels.get(k, k), + ) + selected_label = scenario_labels.get(scenario_key, scenario_key) + + with st.sidebar.expander("Scenario Parameters", expanded=False): + st.json(scenarios[scenario_key]["optimization_input"]) + + st.sidebar.markdown("---") + st.sidebar.caption("URSA \u2014 LANL / PNNL") + + # --- Title --- + st.title("Critical Minerals & Materials \u2014 Optimization Dashboard") + + # --- Current result / validation error check --- + result = results[scenario_key] + is_error = result.get("status") == "validation_error" + + # --- KPI metrics row --- + if is_error: + st.error( + f"**{selected_label}**: Validation error" + f" \u2014 {result.get('errors', [])}" + ) + else: + sens = result["sensitivity_summary"] + c1, c2, c3, c4, c5 = st.columns(5) + c1.metric( + "Status", + result["status"].replace("_", " ").title(), + ) + c2.metric( + "Objective", + f"${result['objective_value']:,.0f}", + ) + c3.metric( + "Avg Unit Cost", + f"${sens['average_unit_cost']:.1f}", + ) + c4.metric( + "Unmet Demand", + f"{sens['unmet_demand_total']:.1f} t", + ) + c5.metric( + "Feasible", + "\u2705 Yes" if result["feasible"] 
else "\u274c No", + ) + + # --- Tabs --- + tab_single, tab_multi = st.tabs([ + "\U0001f50d Single Scenario", + "\U0001f4ca Multi-Scenario Compare", + ]) + + # ===== Single-scenario tab ==================================== + with tab_single: + if is_error: + st.warning( + "Charts unavailable for this scenario due to validation errors." + ) + else: + scenario_input = scenarios[scenario_key]["optimization_input"] + + # Row 1: Sankey + Waterfall + col_left, col_right = st.columns(2) + with col_left: + st.plotly_chart( + build_sankey( + result, + suppliers_list, + markets_list, + colors, + supplier_short, + ), + use_container_width=True, + ) + with col_right: + st.plotly_chart( + build_cost_waterfall(result), + use_container_width=True, + ) + + # Row 2: Capacity gauges (full width) + st.plotly_chart( + build_capacity_gauges( + result, + scenario_input, + colors, + supplier_short, + ), + use_container_width=True, + ) + + # Row 3: Composition bullets (full width) + comp_fig = build_composition_bullets(result) + if comp_fig is not None: + st.plotly_chart( + comp_fig, + use_container_width=True, + ) + + # Row 4: Shadow prices (LP only) + sp_figs = build_shadow_price_bars( + result, + supplier_short, + ) + if sp_figs is not None: + fig_d, fig_c, fig_comp = sp_figs + st.subheader("Shadow Prices (LP Duals)") + sp_left, sp_right = st.columns(2) + with sp_left: + st.plotly_chart( + fig_d, + use_container_width=True, + ) + with sp_right: + st.plotly_chart( + fig_c, + use_container_width=True, + ) + if fig_comp is not None: + st.plotly_chart( + fig_comp, + use_container_width=True, + ) + st.caption( + "**Demand shadow prices** show the" + " marginal cost of supplying one" + " additional tonne to each market." + " **Capacity shadow prices** show" + " the cost reduction from adding" + " one tonne of supplier capacity" + " (split by capacity vs. share cap" + " constraint). A supplier with zero" + " capacity shadow price may be" + " bottlenecked by **composition**" + " constraints instead." 
+ ) + + # ===== Multi-scenario tab ===================================== + with tab_multi: + valid_keys = [ + k + for k in scenarios + if results[k].get("status") != "validation_error" + ] + if not valid_keys: + st.warning("No valid scenario results to compare.") + else: + # Summary table + st.subheader("Scenario Summary") + summary_df = build_summary_table( + scenarios, + results, + scenario_labels, + supplier_short, + ) + st.dataframe( + summary_df, + use_container_width=True, + hide_index=True, + ) + + # Row 1: Cost bar + Risk scatter + col_left, col_right = st.columns(2) + with col_left: + st.plotly_chart( + build_multi_scenario_cost_bar( + results, + scenario_labels, + ), + use_container_width=True, + ) + with col_right: + st.plotly_chart( + build_risk_cost_scatter( + scenarios, + results, + scenario_labels, + supplier_short, + colors, + ), + use_container_width=True, + ) + + # Row 2: Allocation comparison (full width) + st.plotly_chart( + build_allocation_comparison( + results, + scenario_labels, + supplier_short, + colors, + ), + use_container_width=True, + ) + + # Row 3: Shadow price comparison + shadow_figs = build_multi_scenario_shadow_comparison( + results, + scenario_labels, + supplier_short, + ) + if shadow_figs: + st.subheader("Shadow Price Comparison Across Scenarios") + + # Demand shadow prices — best indicator of system tightening + if "demand" in shadow_figs: + st.plotly_chart( + shadow_figs["demand"], + use_container_width=True, + ) + st.caption( + "**Demand shadow prices** are the" + " most direct indicator of market" + " tightness. Higher values mean it" + " costs more to serve one additional" + " tonne in that market. These should" + " increase monotonically as supply" + " tightens." + ) + + # Supply constraint shadow prices + if "supply" in shadow_figs: + st.plotly_chart( + shadow_figs["supply"], + use_container_width=True, + ) + st.caption( + "**Supply constraint shadow prices**" + " show the value of adding one tonne" + " of capacity from each supplier." + " A supplier with a low value under" + " tight supply may be bottlenecked" + " by composition or share constraints" + " rather than capacity." + " Log scale is used when infeasible" + " scenarios create large penalties." + ) + + # Composition constraint shadow prices + if "composition" in shadow_figs: + st.plotly_chart( + shadow_figs["composition"], + use_container_width=True, + ) + st.caption( + "**Composition constraint shadow" + " prices** show the cost of" + " tightening specification tolerances." + " Large values indicate that blend" + " quality requirements are the" + " dominant bottleneck, not raw" + " capacity." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/scripts/demo_healthcheck.py b/scripts/demo_healthcheck.py new file mode 100644 index 00000000..c5f5632d --- /dev/null +++ b/scripts/demo_healthcheck.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import json +import os +from dataclasses import dataclass +from pathlib import Path + +from langchain.chat_models import init_chat_model + +try: + from dotenv import load_dotenv +except Exception: # pragma: no cover - optional import + load_dotenv = None # type: ignore[assignment] + + +DEFAULT_CORPUS = ( + "/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus" +) + + +@dataclass +class CheckResult: + name: str + status: str + detail: str + + +def _check_env(required: list[str]) -> list[CheckResult]: + results = [] + for key in required: + value = os.getenv(key, "").strip() + if value: + results.append(CheckResult(key, "PASS", "set")) + else: + results.append(CheckResult(key, "FAIL", "missing")) + return results + + +def _count_manifest_ids(path: Path) -> int: + if not path.exists(): + return 0 + return len([line for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]) + + +def _corpus_scan(corpus_path: Path, sample_limit: int = 5000) -> tuple[int, int]: + visited = 0 + matched = 0 + valid_ext = {".pdf", ".txt", ".md", ".csv", ".json", ".xml"} + for path in corpus_path.rglob("*"): + if not path.is_file(): + continue + visited += 1 + if path.suffix.lower() in valid_ext: + matched += 1 + if visited >= sample_limit: + break + return visited, matched + + +def _model_connectivity_check(model_name: str) -> CheckResult: + kwargs = { + "model": model_name, + "temperature": 0, + } + base_url = os.getenv("OPENAI_BASE_URL") + api_key = os.getenv("OPENAI_API_KEY") + if base_url: + kwargs["base_url"] = base_url + if api_key: + kwargs["api_key"] = api_key + + try: + model = init_chat_model(**kwargs) + response = model.invoke("Reply exactly with: connectivity_ok") + text = getattr(response, "content", "") + if isinstance(text, list): + text = str(text) + text = str(text) + if "connectivity_ok" in text: + return CheckResult("model_connectivity", "PASS", text) + return CheckResult("model_connectivity", "WARN", text) + except Exception as exc: + return CheckResult("model_connectivity", "FAIL", str(exc)) + + +def _print_results(results: list[CheckResult]) -> tuple[int, int, int]: + passed = sum(1 for result in results if result.status == "PASS") + warned = sum(1 for result in results if result.status == "WARN") + failed = sum(1 for result in results if result.status == "FAIL") + + for result in results: + print(f"[{result.status}] {result.name}: {result.detail}") + + print( + f"\nSummary: PASS={passed} WARN={warned} FAIL={failed}" + ) + return passed, warned, failed + + +def main() -> int: + parser = argparse.ArgumentParser( + description="CMM demo readiness healthcheck." 
+ ) + parser.add_argument( + "--corpus-path", + default=os.getenv("CMM_CORPUS_PATH", DEFAULT_CORPUS), + ) + parser.add_argument( + "--vectorstore-path", + default=os.getenv("CMM_VECTORSTORE_PATH", "cmm_vectorstore"), + ) + parser.add_argument( + "--scenarios-path", + default="configs/cmm_demo_scenarios.json", + ) + parser.add_argument( + "--model", + default=os.getenv("CMM_RAG_MODEL", "openai:gpt-5-nano"), + ) + parser.add_argument("--skip-model-check", action="store_true") + args = parser.parse_args() + + if load_dotenv is not None: + load_dotenv() + + corpus_path = Path(args.corpus_path).expanduser().resolve() + vectorstore_path = Path(args.vectorstore_path).expanduser().resolve() + scenarios_path = Path(args.scenarios_path).expanduser().resolve() + + results: list[CheckResult] = [] + results.extend(_check_env(["OPENAI_API_KEY", "OPENAI_BASE_URL"])) + + if corpus_path.exists() and corpus_path.is_dir(): + visited, matched = _corpus_scan(corpus_path) + results.append( + CheckResult( + "corpus_path", + "PASS" if matched > 0 else "WARN", + f"exists, sampled_files={visited}, sampled_cmm_ext_matches={matched}", + ) + ) + else: + results.append( + CheckResult("corpus_path", "FAIL", f"not found: {corpus_path}") + ) + + if scenarios_path.exists(): + try: + payload = json.loads(scenarios_path.read_text(encoding="utf-8")) + scenario_count = len(payload.keys()) if isinstance(payload, dict) else 0 + status = "PASS" if scenario_count > 0 else "WARN" + results.append( + CheckResult( + "scenarios_config", + status, + f"path={scenarios_path}, scenarios={scenario_count}", + ) + ) + except Exception as exc: + results.append( + CheckResult("scenarios_config", "FAIL", f"invalid json: {exc}") + ) + else: + results.append( + CheckResult( + "scenarios_config", + "FAIL", + f"missing file: {scenarios_path}", + ) + ) + + if vectorstore_path.exists() and vectorstore_path.is_dir(): + manifest_count = _count_manifest_ids(vectorstore_path / "_ingested_ids.txt") + has_index_files = any(vectorstore_path.iterdir()) + status = "PASS" if has_index_files else "WARN" + results.append( + CheckResult( + "vectorstore_path", + status, + f"exists={has_index_files}, manifest_docs={manifest_count}, path={vectorstore_path}", + ) + ) + if manifest_count == 0: + results.append( + CheckResult( + "vectorstore_manifest", + "WARN", + "manifest missing or empty; run scripts/reindex.py", + ) + ) + else: + results.append( + CheckResult( + "vectorstore_manifest", + "PASS", + f"manifest doc ids: {manifest_count}", + ) + ) + else: + results.append( + CheckResult( + "vectorstore_path", + "FAIL", + f"not found: {vectorstore_path}", + ) + ) + + if args.skip_model_check: + results.append( + CheckResult("model_connectivity", "WARN", "skipped by flag") + ) + else: + results.append(_model_connectivity_check(args.model)) + + _, _, failed = _print_results(results) + return 1 if failed > 0 else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/reindex.py b/scripts/reindex.py new file mode 100644 index 00000000..68827552 --- /dev/null +++ b/scripts/reindex.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import os +from collections import Counter +from pathlib import Path +from typing import Iterable + +from tqdm import tqdm + +from ursa.agents.cmm_chunker import CMMChunker +from ursa.agents.cmm_embeddings import init_embeddings +from ursa.agents.cmm_vectorstore import init_vectorstore +from ursa.util.parse import ( + OFFICE_EXTENSIONS, + SPECIAL_TEXT_FILENAMES, + 
TEXT_EXTENSIONS, + read_text_from_file, +) + +EXCLUDED_DIR_NAMES = { + ".git", + ".hg", + ".svn", + ".venv", + "venv", + "env", + "__pycache__", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".cache", + "node_modules", + "site-packages", +} + + +def _normalize_extension(value: str) -> str: + value = value.strip().lower() + if not value: + return "" + return value if value.startswith(".") else f".{value}" + + +def _normalize_extensions(values: Iterable[str] | None) -> set[str]: + if not values: + return set() + return {ext for ext in (_normalize_extension(v) for v in values) if ext} + + +def _iter_ingestible_files( + corpus_path: Path, + include_extensions: set[str] | None, + exclude_extensions: set[str], +) -> list[Path]: + files: list[Path] = [] + for root, dirnames, filenames in os.walk(corpus_path): + dirnames[:] = [ + d + for d in dirnames + if d.lower() not in EXCLUDED_DIR_NAMES + ] + root_path = Path(root) + for filename_raw in filenames: + filename = filename_raw.lower() + path = root_path / filename_raw + ext = path.suffix.lower() + ingestible = ( + ext == ".pdf" + or ext in TEXT_EXTENSIONS + or filename in SPECIAL_TEXT_FILENAMES + or ext in OFFICE_EXTENSIONS + ) + if not ingestible: + continue + if ext in exclude_extensions: + continue + if include_extensions is not None and ext not in include_extensions: + if filename not in SPECIAL_TEXT_FILENAMES: + continue + files.append(path) + return files + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Backend-agnostic CMM corpus reindex utility." + ) + parser.add_argument("--corpus-path", required=True) + parser.add_argument("--vectorstore-path", default="cmm_vectorstore") + parser.add_argument("--backend", default="chroma") + parser.add_argument( + "--embedding-model", + default="openai:text-embedding-3-large", + ) + parser.add_argument("--embedding-dimensions", type=int, default=3072) + parser.add_argument( + "--embedding-batch-size", + type=int, + default=20, + ) + parser.add_argument("--collection-name", default="cmm_chunks") + parser.add_argument("--chunk-size", type=int, default=1000) + parser.add_argument("--chunk-overlap", type=int, default=200) + parser.add_argument("--min-chars", type=int, default=30) + parser.add_argument( + "--include-extension", + action="append", + dest="include_extensions", + default=None, + ) + parser.add_argument( + "--exclude-extension", + action="append", + dest="exclude_extensions", + default=[], + ) + parser.add_argument("--max-docs", type=int, default=0) + parser.add_argument( + "--max-chunks-per-doc", + type=int, + default=0, + help="Cap chunks ingested from each source document (0 = no cap).", + ) + parser.add_argument( + "--skip-existing", + action=argparse.BooleanOptionalAction, + default=True, + help="Skip source files already present in _ingested_ids.txt.", + ) + parser.add_argument( + "--flush-docs", + type=int, + default=50, + help="Number of source documents to batch before vectorstore insert.", + ) + parser.add_argument("--reset", action="store_true") + args = parser.parse_args() + + corpus_path = Path(args.corpus_path).expanduser().resolve() + vectorstore_path = Path(args.vectorstore_path).expanduser().resolve() + vectorstore_path.mkdir(parents=True, exist_ok=True) + + include_extensions = ( + _normalize_extensions(args.include_extensions) + if args.include_extensions + else None + ) + exclude_extensions = _normalize_extensions(args.exclude_extensions) + + embedding = init_embeddings( + args.embedding_model, + 
dimensions=args.embedding_dimensions, + batch_size=max(1, args.embedding_batch_size), + ) + vectorstore = init_vectorstore( + backend=args.backend, + persist_directory=vectorstore_path, + embedding_model=embedding, + collection_name=args.collection_name, + ) + manifest_path = vectorstore_path / "_ingested_ids.txt" + if args.reset: + vectorstore.delete_collection() + if manifest_path.exists(): + manifest_path.unlink() + + chunker = CMMChunker( + max_tokens=max(64, args.chunk_size // 2), + overlap_tokens=max(0, args.chunk_overlap // 4), + min_tokens=max(20, min(args.chunk_size // 6, 120)), + ) + + files = _iter_ingestible_files( + corpus_path=corpus_path, + include_extensions=include_extensions, + exclude_extensions=exclude_extensions, + ) + existing_ids: set[str] = set() + if args.skip_existing and manifest_path.exists(): + existing_ids = { + line.strip() + for line in manifest_path.read_text(encoding="utf-8").splitlines() + if line.strip() + } + files = [path for path in files if str(path) not in existing_ids] + if args.max_docs > 0: + files = files[: args.max_docs] + + commodity_counts: Counter[str] = Counter() + subdomain_counts: Counter[str] = Counter() + docs_indexed = 0 + chunks_indexed = 0 + ingested_doc_ids: set[str] = set() + batched_docs = [] + batched_source_ids: list[str] = [] + + def flush_batch() -> None: + nonlocal batched_docs + nonlocal batched_source_ids + if not batched_docs: + return + vectorstore.add_documents(batched_docs) + batched_docs = [] + batched_source_ids = [] + + for path in tqdm(files, desc="Reindex corpus"): + text = read_text_from_file(path) + if len(text) < args.min_chars: + continue + docs = chunker.chunk_document( + text, + metadata={ + "source_doc_id": str(path), + "source_doc_title": path.name, + }, + ) + if not docs: + continue + if args.max_chunks_per_doc > 0: + docs = docs[: args.max_chunks_per_doc] + batched_docs.extend(docs) + batched_source_ids.append(str(path)) + if len(batched_source_ids) >= max(1, args.flush_docs): + flush_batch() + docs_indexed += 1 + chunks_indexed += len(docs) + ingested_doc_ids.add(str(path)) + for doc in docs: + commodity_counts.update(doc.metadata.get("commodity_tags", [])) + subdomain_counts.update(doc.metadata.get("subdomain_tags", [])) + + flush_batch() + + if manifest_path.exists(): + existing_ids = { + line.strip() + for line in manifest_path.read_text(encoding="utf-8").splitlines() + if line.strip() + } + merged_ids = sorted(existing_ids.union(ingested_doc_ids)) + manifest_path.write_text( + "\n".join(merged_ids) + ("\n" if merged_ids else ""), + encoding="utf-8", + ) + + print(f"Docs indexed: {docs_indexed}") + print(f"Chunks indexed: {chunks_indexed}") + print(f"Vectorstore count: {vectorstore.count()}") + print("Commodity tag counts:") + for tag, count in sorted(commodity_counts.items()): + print(f" {tag}: {count}") + print("Subdomain tag counts:") + for tag, count in sorted(subdomain_counts.items()): + print(f" {tag}: {count}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/run_cmm_demo.py b/scripts/run_cmm_demo.py new file mode 100644 index 00000000..98adafbb --- /dev/null +++ b/scripts/run_cmm_demo.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from langchain.chat_models import init_chat_model + +from ursa.agents import ExecutionAgent, PlanningAgent, RAGAgent +from ursa.workflows import 
CriticalMineralsWorkflow + +try: + from dotenv import load_dotenv +except Exception: # pragma: no cover - optional import + load_dotenv = None # type: ignore[assignment] + + +DEFAULT_CORPUS = ( + "/Users/wash198/Documents/Projects/Science_Projects/MPII_CMM/Corpus" +) + + +def _parse_bool_env(name: str, default: bool = False) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _json_safe(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set)): + return [_json_safe(item) for item in value] + if hasattr(value, "model_dump"): + return _json_safe(value.model_dump()) + if hasattr(value, "dict"): + return _json_safe(value.dict()) + return str(value) + + +def _load_scenario(path: Path, scenario_name: str) -> dict[str, Any]: + raw = json.loads(path.read_text(encoding="utf-8")) + if scenario_name not in raw: + available = ", ".join(sorted(raw.keys())) + raise KeyError( + f"Scenario '{scenario_name}' not found. Available: {available}" + ) + scenario = raw[scenario_name] + if not isinstance(scenario, dict): + raise ValueError("Scenario payload must be a mapping.") + return scenario + + +def _build_model(model_name: str): + kwargs: dict[str, Any] = { + "model": model_name, + "temperature": 0, + } + if base_url := os.getenv("OPENAI_BASE_URL"): + kwargs["base_url"] = base_url + if api_key := os.getenv("OPENAI_API_KEY"): + kwargs["api_key"] = api_key + return init_chat_model(**kwargs) + + +def _write_json(path: Path, payload: Any) -> None: + path.write_text( + json.dumps(_json_safe(payload), indent=2, sort_keys=True), + encoding="utf-8", + ) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Run end-to-end CMM demo workflow and emit artifacts." 
+ ) + parser.add_argument("--scenario", default="ndfeb_la_y_5pct_baseline") + parser.add_argument( + "--scenarios-path", + default="configs/cmm_demo_scenarios.json", + ) + parser.add_argument( + "--corpus-path", + default=os.getenv("CMM_CORPUS_PATH", DEFAULT_CORPUS), + ) + parser.add_argument( + "--vectorstore-path", + default=os.getenv("CMM_VECTORSTORE_PATH", "cmm_vectorstore"), + ) + parser.add_argument( + "--summaries-path", + default=os.getenv("CMM_SUMMARIES_PATH", "cmm_summaries"), + ) + parser.add_argument( + "--workspace", + default=os.getenv("CMM_WORKSPACE_PATH", "cmm_demo_workspace"), + ) + parser.add_argument( + "--output-dir", + default=os.getenv("CMM_OUTPUT_DIR", "cmm_demo_outputs"), + ) + parser.add_argument( + "--planner-model", + default=os.getenv("CMM_PLANNER_MODEL", "openai:gpt-5"), + ) + parser.add_argument( + "--executor-model", + default=os.getenv("CMM_EXECUTOR_MODEL", "openai:gpt-5"), + ) + parser.add_argument( + "--rag-model", + default=os.getenv("CMM_RAG_MODEL", "openai:gpt-5-nano"), + ) + parser.add_argument("--print-result-json", action="store_true") + args = parser.parse_args() + + if load_dotenv is not None: + load_dotenv() + + scenario_path = Path(args.scenarios_path).expanduser().resolve() + scenario = _load_scenario(scenario_path, args.scenario) + + workspace = Path(args.workspace).expanduser().resolve() + output_root = Path(args.output_dir).expanduser().resolve() + corpus_path = Path(args.corpus_path).expanduser().resolve() + vectorstore_path = Path(args.vectorstore_path).expanduser().resolve() + summaries_path = Path(args.summaries_path).expanduser().resolve() + + workspace.mkdir(parents=True, exist_ok=True) + + planner = PlanningAgent(llm=_build_model(args.planner_model), workspace=workspace) + executor = ExecutionAgent( + llm=_build_model(args.executor_model), + workspace=workspace, + ) + + rag_agent = RAGAgent( + llm=_build_model(args.rag_model), + workspace=workspace, + database_path=corpus_path, + vectorstore_path=vectorstore_path, + summaries_path=summaries_path, + vectorstore_backend=os.getenv("CMM_VECTORSTORE_BACKEND", "chroma"), + retrieval_k=int(os.getenv("CMM_RETRIEVAL_K", "20")), + return_k=int(os.getenv("CMM_RETURN_K", "5")), + use_reranker=_parse_bool_env("CMM_USE_RERANKER", default=False), + reranker_provider=os.getenv("CMM_RERANKER_PROVIDER", "none"), + ) + + workflow = CriticalMineralsWorkflow( + planner=planner, + executor=executor, + rag_agent=rag_agent, + workspace=workspace, + ) + + payload = { + "task": scenario["task"], + "local_corpus_path": str(corpus_path), + "rag_context": scenario.get("rag_context", scenario["task"]), + "source_queries": scenario.get("source_queries", {}), + "optimization_input": scenario.get("optimization_input"), + "execution_instruction": scenario.get( + "execution_instruction", + "Produce a source-grounded synthesis with uncertainty notes.", + ), + } + + result = workflow.invoke(payload) + + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + run_dir = output_root / args.scenario / timestamp + run_dir.mkdir(parents=True, exist_ok=True) + + _write_json(run_dir / "input_payload.json", payload) + _write_json(run_dir / "workflow_result.json", result) + _write_json(run_dir / "optimization_output.json", result.get("optimization")) + _write_json(run_dir / "rag_metadata.json", result.get("rag", {})) + + final_summary = str(result.get("final_summary", "")).strip() + (run_dir / "final_summary.md").write_text( + final_summary + "\n", + encoding="utf-8", + ) + + latest_dir = output_root / args.scenario / 
"latest" + latest_dir.mkdir(parents=True, exist_ok=True) + (latest_dir / "final_summary.md").write_text( + final_summary + "\n", encoding="utf-8" + ) + _write_json(latest_dir / "workflow_result.json", result) + + print(f"Scenario: {args.scenario}") + print(f"Run artifacts: {run_dir}") + print("Final summary preview:") + print(final_summary[:1200]) + + if args.print_result_json: + print("\nFull result JSON:") + print(json.dumps(_json_safe(result), indent=2, sort_keys=True)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/ursa/agents/__init__.py b/src/ursa/agents/__init__.py index 3b44c8b9..8e8916be 100644 --- a/src/ursa/agents/__init__.py +++ b/src/ursa/agents/__init__.py @@ -6,7 +6,6 @@ "ArxivAgent": (".acquisition_agents", "ArxivAgent"), "OSTIAgent": (".acquisition_agents", "OSTIAgent"), "WebSearchAgent": (".acquisition_agents", "WebSearchAgent"), - "ArxivAgentLegacy": (".arxiv_agent", "ArxivAgentLegacy"), "BaseAgent": (".base", "BaseAgent"), "BaseChatModel": (".base", "BaseChatModel"), "ChatAgent": (".chat_agent", "ChatAgent"), @@ -18,7 +17,11 @@ "PlanningAgent": (".planning_agent", "PlanningAgent"), "RAGAgent": (".rag_agent", "RAGAgent"), "RecallAgent": (".recall_agent", "RecallAgent"), - "WebSearchAgentLegacy": (".websearch_agent", "WebSearchAgentLegacy"), +} + +_retired_attrs: dict[str, str] = { + "ArxivAgentLegacy": "Use ArxivAgent from ursa.agents instead.", + "WebSearchAgentLegacy": "Use WebSearchAgent from ursa.agents instead.", } __all__ = list(_lazy_attrs.keys()) @@ -30,6 +33,11 @@ def __getattr__(name: str) -> Any: This avoids importing all agent modules at package import time, so a failure in one agent does not prevent using others. """ + if name in _retired_attrs: + raise AttributeError( + f"{name} has been retired. 
{_retired_attrs[name]}" + ) from None + try: module_name, attr_name = _lazy_attrs[name] except KeyError: diff --git a/src/ursa/agents/acquisition_agents.py b/src/ursa/agents/acquisition_agents.py index 8b5cd21f..cc6e6f4d 100644 --- a/src/ursa/agents/acquisition_agents.py +++ b/src/ursa/agents/acquisition_agents.py @@ -43,6 +43,8 @@ except Exception: OpenAI = None +VISION_MODEL = os.environ.get("URSA_VISION_MODEL", "gpt-4o-mini") + # ---------- Shared State / Types ---------- @@ -120,7 +122,7 @@ def describe_image(image: Image.Image) -> str: img_b64 = base64.b64encode(buf.getvalue()).decode() resp = client.chat.completions.create( - model="gpt-4-vision-preview", + model=VISION_MODEL, messages=[ { "role": "system", diff --git a/src/ursa/agents/arxiv_agent.py b/src/ursa/agents/arxiv_agent.py deleted file mode 100644 index 5055bbd9..00000000 --- a/src/ursa/agents/arxiv_agent.py +++ /dev/null @@ -1,384 +0,0 @@ -import base64 -import os -import re -from concurrent.futures import ThreadPoolExecutor, as_completed -from io import BytesIO -from typing import TypedDict -from urllib.parse import quote - -import feedparser -import pymupdf -import requests -from langchain.chat_models import BaseChatModel -from langchain_community.document_loaders import PyPDFLoader -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import ChatPromptTemplate -from PIL import Image -from tqdm import tqdm - -from ursa.agents.base import BaseAgent -from ursa.agents.rag_agent import RAGAgent - -try: - from openai import OpenAI -except Exception: - pass - - -class PaperMetadata(TypedDict): - arxiv_id: str - full_text: str - - -class PaperState(TypedDict, total=False): - query: str - context: str - papers: list[PaperMetadata] - summaries: list[str] - final_summary: str - - -def describe_image(image: Image.Image) -> str: - if "OpenAI" not in globals(): - print( - "Vision transformer for summarizing images currently only implemented for OpenAI API." 
- ) - return "" - client = OpenAI() - - buffered = BytesIO() - image.save(buffered, format="PNG") - img_base64 = base64.b64encode(buffered.getvalue()).decode() - - response = client.chat.completions.create( - model="gpt-4-vision-preview", - messages=[ - { - "role": "system", - "content": "You are a scientific assistant who explains plots and scientific diagrams.", - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe this scientific image or plot in detail.", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{img_base64}" - }, - }, - ], - }, - ], - max_tokens=500, - ) - return response.choices[0].message.content.strip() - - -def extract_and_describe_images( - pdf_path: str, max_images: int = 5 -) -> list[str]: - doc = pymupdf.open(pdf_path) - descriptions = [] - image_count = 0 - - for page_index in range(len(doc)): - if image_count >= max_images: - break - page = doc[page_index] - images = page.get_images(full=True) - - for img_index, img in enumerate(images): - if image_count >= max_images: - break - xref = img[0] - base_image = doc.extract_image(xref) - image_bytes = base_image["image"] - image = Image.open(BytesIO(image_bytes)) - - try: - desc = describe_image(image) - descriptions.append( - f"Page {page_index + 1}, Image {img_index + 1}: {desc}" - ) - except Exception as e: - descriptions.append( - f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]" - ) - image_count += 1 - - return descriptions - - -def remove_surrogates(text: str) -> str: - return re.sub(r"[\ud800-\udfff]", "", text) - - -class ArxivAgentLegacy(BaseAgent): - def __init__( - self, - llm: BaseChatModel, - summarize: bool = True, - process_images=True, - max_results: int = 3, - download_papers: bool = True, - rag_embedding=None, - database_path="arxiv_papers", - summaries_path="arxiv_generated_summaries", - vectorstore_path="arxiv_vectorstores", - **kwargs, - ): - super().__init__(llm, **kwargs) - self.summarize = summarize - self.process_images = process_images - self.max_results = max_results - self.database_path = self.workspace / database_path - self.summaries_path = self.workspace / summaries_path - self.vectorstore_path = self.workspace / vectorstore_path - self.download_papers = download_papers - self.rag_embedding = rag_embedding - - self.database_path.mkdir(exist_ok=True, parents=True) - self.summaries_path.mkdir(exist_ok=True, parents=True) - - def _fetch_papers(self, query: str) -> list[PaperMetadata]: - if self.download_papers: - encoded_query = quote(query) - url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}" - # print(f"URL is {url}") # if verbose - entries = [] - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - - feed = feedparser.parse(response.content) - # print(f"parsed response status is {feed.status}") # if verbose - entries = feed.entries - if feed.bozo: - raise Exception("Feed from arXiv looks like garbage =(") - except requests.exceptions.Timeout: - print("Request timed out while fetching papers.") - except requests.exceptions.RequestException as e: - print(f"Request error encountered while fetching papers: {e}") - except ValueError as ve: - print(f"Value error occurred while fetching papers: {ve}") - except Exception as e: - print( - f"An unexpected error occurred while fetching papers: {e}" - ) - - for i, entry in enumerate(entries): - full_id = entry.id.split("/abs/")[-1] - arxiv_id = full_id.split("/")[-1] - title = 
entry.title.strip() - # authors = ", ".join(author.name for author in entry.authors) - pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf" - pdf_filename = os.path.join( - self.database_path, f"{arxiv_id}.pdf" - ) - - if os.path.exists(pdf_filename): - print( - f"Paper # {i + 1}, Title: {title}, already exists in database" - ) - else: - print(f"Downloading paper # {i + 1}, Title: {title}") - response = requests.get(pdf_url) - with open(pdf_filename, "wb") as f: - f.write(response.content) - - papers = [] - - pdf_files = [ - f - for f in os.listdir(self.database_path) - if f.lower().endswith(".pdf") - ] - - for i, pdf_filename in enumerate(pdf_files): - full_text = "" - arxiv_id = pdf_filename.split(".pdf")[0] - vec_save_loc = self.vectorstore_path / arxiv_id - - if self.summarize and not vec_save_loc.exists(): - try: - loader = PyPDFLoader(self.database_path / pdf_filename) - pages = loader.load() - full_text = "\n".join([p.page_content for p in pages]) - - if self.process_images: - image_descriptions = extract_and_describe_images( - self.database_path / pdf_filename - ) - full_text += ( - "\n\n[Image Interpretations]\n" - + "\n".join(image_descriptions) - ) - - except Exception as e: - full_text = f"Error loading paper: {e}" - - papers.append({ - "arxiv_id": arxiv_id, - "full_text": full_text, - }) - - return papers - - def _fetch_node(self, state: PaperState) -> PaperState: - papers = self._fetch_papers(state["query"]) - return {**state, "papers": papers} - - def _summarize_node(self, state: PaperState) -> PaperState: - prompt = ChatPromptTemplate.from_template(""" - You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context} - - Summarize the retrieved scientific content below. - - {retrieved_content} - """) - - chain = prompt | self.llm | StrOutputParser() - - summaries = [None] * len(state["papers"]) - relevancy_scores = [0.0] * len(state["papers"]) - - def process_paper(i, paper): - arxiv_id = paper["arxiv_id"] - summary_filename = os.path.join( - self.summaries_path, f"{arxiv_id}_summary.txt" - ) - - try: - cleaned_text = remove_surrogates(paper["full_text"]) - summary = chain.invoke( - { - "retrieved_content": cleaned_text, - "context": state["context"], - }, - config=self.build_config(tags=["arxiv", "summarize_each"]), - ) - - except Exception as e: - summary = f"Error summarizing paper: {e}" - relevancy_scores[i] = 0.0 - - with open(summary_filename, "w") as f: - f.write(summary) - - return i, summary - - if "papers" not in state or len(state["papers"]) == 0: - print( - "No papers retrieved - bad query or network connection to ArXiv?" 
- ) - return {**state, "summaries": None} - - with ThreadPoolExecutor( - max_workers=min(32, len(state["papers"])) - ) as executor: - futures = [ - executor.submit(process_paper, i, paper) - for i, paper in enumerate(state["papers"]) - ] - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc="Summarizing Papers", - ): - i, result = future.result() - summaries[i] = result - - return {**state, "summaries": summaries} - - def _rag_node(self, state: PaperState) -> PaperState: - new_state = state.copy() - rag_agent = RAGAgent( - llm=self.llm, - embedding=self.rag_embedding, - database_path=self.database_path, - ) - new_state["final_summary"] = rag_agent.invoke(context=state["context"])[ - "summary" - ] - return new_state - - def _aggregate_node(self, state: PaperState) -> PaperState: - summaries = state["summaries"] - papers = state["papers"] - formatted = [] - - if ( - "summaries" not in state - or state["summaries"] is None - or "papers" not in state - or state["papers"] is None - ): - return {**state, "final_summary": None} - - for i, (paper, summary) in enumerate(zip(papers, summaries)): - citation = f"[{i + 1}] Arxiv ID: {paper['arxiv_id']}" - formatted.append(f"{citation}\n\nSummary:\n{summary}") - - combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted) - - with open(self.summaries_path + "/summaries_combined.txt", "w") as f: - f.write(combined) - - prompt = ChatPromptTemplate.from_template(""" - You are a scientific assistant helping extract insights from summaries of research papers. - - Here are the summaries of a large number of extracts from scientific papers: - - {Summaries} - - Your task is to read all the summaries and provide a response to this task: {context} - """) - - chain = prompt | self.llm | StrOutputParser() - - final_summary = chain.invoke( - { - "Summaries": combined, - "context": state["context"], - }, - config=self.build_config(tags=["arxiv", "aggregate"]), - ) - - with open(self.summaries_path + "/final_summary.txt", "w") as f: - f.write(final_summary) - - return {**state, "final_summary": final_summary} - - def _build_graph(self): - self.add_node(self._fetch_node) - if self.summarize: - if self.rag_embedding: - self.add_node(self._rag_node) - self.graph.set_entry_point("_fetch_node") - self.graph.add_edge("_fetch_node", "_rag_node") - self.graph.set_finish_point("_rag_node") - else: - self.add_node(self._summarize_node) - self.add_node(self._aggregate_node) - - self.graph.set_entry_point("_fetch_node") - self.graph.add_edge("_fetch_node", "_summarize_node") - self.graph.add_edge("_summarize_node", "_aggregate_node") - self.graph.set_finish_point("_aggregate_node") - else: - self.graph.set_entry_point("_fetch_node") - self.graph.set_finish_point("_fetch_node") - - -# NOTE: Run test in `tests/agents/test_arxiv_agent/test_arxiv_agent.py` via: -# -# pytest -s tests/agents/test_arxiv_agent -# -# OR -# -# uv run pytest -s tests/agents/test_arxiv_agent diff --git a/src/ursa/agents/cmm_chunker.py b/src/ursa/agents/cmm_chunker.py new file mode 100644 index 00000000..87637a66 --- /dev/null +++ b/src/ursa/agents/cmm_chunker.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from langchain_core.documents import Document + +from ursa.agents.cmm_taxonomy import ( + detect_commodity_tags, + detect_subdomain_tags, + first_temporal_indicator, + has_numerical_data, +) + +_HEADING_RE = re.compile(r"^(#{1,6})\s+(.*)$") +_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+") 
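+# Notes on the heading/sentence patterns above and the table pattern below:
+# _HEADING_RE captures markdown headings such as "## Supply Outlook" as
+# (hash run, title) so chunk_document can build a breadcrumb-style section_path;
+# _SENTENCE_SPLIT_RE is used by the size guard to break oversized prose on
+# sentence boundaries before re-packing it under the token budget with overlap;
+# _TABLE_SEPARATOR_RE recognizes markdown separator rows (e.g. "| --- | --- |")
+# so pipe tables are emitted as dedicated "table" chunks rather than split as prose.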
+_TABLE_SEPARATOR_RE = re.compile(r"^\s*\|?\s*[:\-]+\s*(\|\s*[:\-]+\s*)+\|?\s*$") + + +@dataclass +class _Section: + section_path: str + text: str + + +class CMMChunker: + def __init__( + self, + max_tokens: int = 512, + overlap_tokens: int = 50, + min_tokens: int = 50, + ): + self.max_tokens = max_tokens + self.overlap_tokens = overlap_tokens + self.min_tokens = min_tokens + + def chunk_document(self, text: str, metadata: dict[str, Any]) -> list[Document]: + if not text.strip(): + return [] + + source_doc_id = str( + metadata.get("source_doc_id") + or metadata.get("id") + or metadata.get("doc_id") + or metadata.get("source") + or "unknown_doc" + ) + source_doc_title = str( + metadata.get("source_doc_title") + or metadata.get("title") + or metadata.get("filename") + or source_doc_id + ) + sensitivity = str(metadata.get("sensitivity_level", "public")) + doc_level_commodity = detect_commodity_tags(text) + doc_level_subdomain = detect_subdomain_tags(text) + doc_temporal = str( + metadata.get("temporal_indicator") or first_temporal_indicator(text) + ) + data_vintage = str(metadata.get("data_vintage") or doc_temporal) + + sections = ( + self._chunk_markdown(text) + if "#" in text + else [_Section(section_path="", text=text)] + ) + docs: list[Document] = [] + chunk_index = 0 + for section in sections: + blocks = self._extract_tables(section.text) + for block_text, block_type in blocks: + guarded_blocks = self._apply_size_guard([block_text]) + for chunk in guarded_blocks: + if not chunk.strip(): + continue + commodity_tags = detect_commodity_tags( + chunk, fallback=doc_level_commodity + ) + subdomain_tags = detect_subdomain_tags( + chunk, fallback=doc_level_subdomain + ) + temporal_indicator = first_temporal_indicator(chunk) or doc_temporal + doc = Document( + page_content=chunk, + metadata={ + "source_doc_id": source_doc_id, + "source_doc_title": source_doc_title, + "section_path": section.section_path, + "chunk_index": chunk_index, + "chunk_type": block_type, + "commodity_tags": commodity_tags, + "subdomain_tags": subdomain_tags, + "temporal_indicator": temporal_indicator, + "data_vintage": data_vintage, + "sensitivity_level": sensitivity, + "char_count": len(chunk), + "has_numerical_data": has_numerical_data(chunk), + }, + ) + docs.append(doc) + chunk_index += 1 + return docs + + def _chunk_markdown(self, text: str) -> list[_Section]: + sections: list[_Section] = [] + heading_stack: list[str] = [] + current_lines: list[str] = [] + current_path = "" + + def flush(): + nonlocal current_lines, current_path + body = "\n".join(current_lines).strip() + if body: + sections.append(_Section(section_path=current_path, text=body)) + current_lines = [] + + for raw_line in text.splitlines(): + m = _HEADING_RE.match(raw_line.strip()) + if m: + flush() + level = len(m.group(1)) + title = m.group(2).strip() + while len(heading_stack) >= level: + heading_stack.pop() + heading_stack.append(title) + current_path = " > ".join(heading_stack) + else: + current_lines.append(raw_line) + flush() + if not sections: + return [_Section(section_path="", text=text)] + return sections + + def _chunk_plain_text(self, text: str) -> list[str]: + paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()] + if not paras: + return [text] + return paras + + def _extract_tables(self, text: str) -> list[tuple[str, str]]: + lines = text.splitlines() + blocks: list[tuple[str, str]] = [] + prose_buf: list[str] = [] + i = 0 + + def flush_prose(): + nonlocal prose_buf + prose = "\n".join(prose_buf).strip() + if prose: + for 
para in self._chunk_plain_text(prose): + blocks.append((para, "prose")) + prose_buf = [] + + while i < len(lines): + line = lines[i] + if "|" in line and line.count("|") >= 2: + next_line = lines[i + 1] if i + 1 < len(lines) else "" + if _TABLE_SEPARATOR_RE.match(next_line.strip()): + flush_prose() + table_lines = [line, next_line] + i += 2 + while i < len(lines): + if "|" in lines[i] and lines[i].count("|") >= 2: + table_lines.append(lines[i]) + i += 1 + else: + break + blocks.append(("\n".join(table_lines).strip(), "table")) + continue + prose_buf.append(line) + i += 1 + + flush_prose() + return blocks or [(text, "prose")] + + def _apply_size_guard(self, chunks: list[str]) -> list[str]: + out: list[str] = [] + for chunk in chunks: + words = chunk.split() + if len(words) <= self.max_tokens: + out.append(chunk) + continue + + sentences = _SENTENCE_SPLIT_RE.split(chunk) + cur: list[str] = [] + cur_tokens = 0 + for sentence in sentences: + s_tokens = len(sentence.split()) + if s_tokens > self.max_tokens: + if cur: + out.append(" ".join(cur).strip()) + cur = [] + cur_tokens = 0 + long_words = sentence.split() + step = max(1, self.max_tokens - self.overlap_tokens) + start = 0 + while start < len(long_words): + end = min(len(long_words), start + self.max_tokens) + out.append(" ".join(long_words[start:end]).strip()) + if end >= len(long_words): + break + start += step + continue + if cur and cur_tokens + s_tokens > self.max_tokens: + out.append(" ".join(cur).strip()) + overlap = " ".join(cur).split()[-self.overlap_tokens :] + cur = [" ".join(overlap), sentence] + cur_tokens = len(overlap) + s_tokens + else: + cur.append(sentence) + cur_tokens += s_tokens + if cur: + out.append(" ".join(cur).strip()) + + if not out: + return [] + + merged: list[str] = [] + for chunk in out: + tokens = len(chunk.split()) + if merged and tokens < self.min_tokens: + merged[-1] = merged[-1].rstrip() + "\n\n" + chunk.lstrip() + else: + merged.append(chunk) + return merged diff --git a/src/ursa/agents/cmm_embeddings.py b/src/ursa/agents/cmm_embeddings.py new file mode 100644 index 00000000..e49aa734 --- /dev/null +++ b/src/ursa/agents/cmm_embeddings.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from pathlib import Path + +import httpx +from langchain.embeddings import Embeddings + +try: + from openai import OpenAI +except Exception: # pragma: no cover - optional import guard + OpenAI = None # type: ignore[assignment] + + +def _default_dim_for_model(model: str) -> int: + if model == "text-embedding-3-large": + return 3072 + if model == "text-embedding-3-small": + return 1536 + return 1024 + + +class CMMEmbeddingsBase(Embeddings, ABC): + @property + @abstractmethod + def embedding_dim(self) -> int: + ... 
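+
+# Illustrative usage only (a sketch, not exercised by this module). Assuming
+# OPENAI_API_KEY (and optionally OPENAI_BASE_URL) are set, callers can treat
+# every provider uniformly through this interface:
+#
+#     emb = init_embeddings("openai:text-embedding-3-large", dimensions=3072)
+#     emb.embedding_dim                              # -> 3072
+#     emb.embed_documents(["chunk one", "chunk two"])
+#
+# init_embeddings and the concrete adapters are defined later in this module.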
+ + +class LangChainEmbeddingsAdapter(CMMEmbeddingsBase): + def __init__(self, inner: Embeddings, embedding_dim: int | None = None): + self.inner = inner + self._embedding_dim = embedding_dim + + @property + def embedding_dim(self) -> int: + if self._embedding_dim: + return self._embedding_dim + probe = self.inner.embed_query("dimension probe") + return len(probe) + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self.inner.embed_documents(texts) + + def embed_query(self, text: str) -> list[float]: + return self.inner.embed_query(text) + + +class OpenAIEmbeddings(CMMEmbeddingsBase): + def __init__( + self, + model: str = "text-embedding-3-large", + dimensions: int | None = None, + batch_size: int = 100, + api_key: str | None = None, + base_url: str | None = None, + ): + if OpenAI is None: + raise ImportError( + "openai package is required for OpenAIEmbeddings" + ) + self.model = model + self.dimensions = dimensions + self.batch_size = batch_size + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.base_url = base_url or os.getenv("OPENAI_BASE_URL") + if not self.api_key: + raise ValueError("OPENAI_API_KEY is required for OpenAIEmbeddings") + kwargs = {"api_key": self.api_key} + if self.base_url: + kwargs["base_url"] = self.base_url + timeout_seconds = float(os.getenv("CMM_EMBEDDING_TIMEOUT_SECONDS", "90")) + kwargs["timeout"] = timeout_seconds + kwargs["http_client"] = httpx.Client( + timeout=httpx.Timeout( + timeout_seconds, + connect=timeout_seconds, + read=timeout_seconds, + write=timeout_seconds, + ) + ) + self.client = OpenAI(**kwargs) + # Defensive bound for OpenAI-compatible gateways with stricter context + # limits on embedding endpoints. + self.max_input_chars = int( + os.getenv("CMM_EMBEDDING_MAX_INPUT_CHARS", "6000") + ) + + @property + def embedding_dim(self) -> int: + return self.dimensions or _default_dim_for_model(self.model) + + def _embed_batch(self, texts: list[str]) -> list[list[float]]: + bounded = [self._bounded_text(text) for text in texts] + kwargs = {"model": self.model, "input": bounded} + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions + response = self.client.embeddings.create(**kwargs) + return [list(item.embedding) for item in response.data] + + def _bounded_text(self, text: str) -> str: + if len(text) <= self.max_input_chars: + return text + return text[: self.max_input_chars] + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + vectors: list[list[float]] = [] + for i in range(0, len(texts), self.batch_size): + batch = texts[i : i + self.batch_size] + vectors.extend(self._embed_batch(batch)) + return vectors + + def embed_query(self, text: str) -> list[float]: + return self._embed_batch([text])[0] + + +class LocalEmbeddings(CMMEmbeddingsBase): + def __init__( + self, + model_name_or_path: str = "BAAI/bge-large-en-v1.5", + device: str | None = None, + batch_size: int = 32, + ): + try: + from sentence_transformers import SentenceTransformer + except Exception as exc: # pragma: no cover - optional dependency + raise ImportError( + "sentence-transformers is required for LocalEmbeddings" + ) from exc + self.model_name_or_path = model_name_or_path + self.batch_size = batch_size + kwargs = {} + if device: + kwargs["device"] = device + self.model = SentenceTransformer(model_name_or_path, **kwargs) + + @property + def embedding_dim(self) -> int: + return int(self.model.get_sentence_embedding_dimension()) + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + vectors = 
self.model.encode( + texts, + batch_size=self.batch_size, + normalize_embeddings=True, + show_progress_bar=False, + ) + return [list(v) for v in vectors] + + def embed_query(self, text: str) -> list[float]: + prefixed = ( + "Represent this sentence for searching relevant passages: " + text + ) + vector = self.model.encode( + [prefixed], normalize_embeddings=True, show_progress_bar=False + )[0] + return list(vector) + + +def parse_embedding_model_spec(spec: str) -> tuple[str, str, int | None]: + if spec.startswith("openai:"): + tail = spec.split("openai:", 1)[1] + if ":" in tail: + maybe_model, maybe_dim = tail.rsplit(":", 1) + if maybe_dim.isdigit(): + return ("openai", maybe_model, int(maybe_dim)) + return ("openai", tail, None) + + if spec.startswith("local:"): + model_path = spec.split("local:", 1)[1] + return ("local", model_path, None) + + if Path(spec).exists() or "/" in spec: + return ("local", spec, None) + + if ":" in spec: + provider, model = spec.split(":", 1) + if provider in {"openai", "local"}: + return (provider, model, None) + + return ("openai", spec, None) + + +def init_embeddings( + model_spec: str, + *, + dimensions: int | None = None, + batch_size: int = 100, + api_key: str | None = None, + base_url: str | None = None, +) -> CMMEmbeddingsBase: + provider, model, parsed_dim = parse_embedding_model_spec(model_spec) + dim = dimensions if dimensions is not None else parsed_dim + + if provider == "openai": + return OpenAIEmbeddings( + model=model, + dimensions=dim, + batch_size=batch_size, + api_key=api_key, + base_url=base_url, + ) + if provider == "local": + return LocalEmbeddings(model_name_or_path=model, batch_size=32) + raise ValueError(f"Unsupported embedding provider '{provider}'") diff --git a/src/ursa/agents/cmm_query_classifier.py b/src/ursa/agents/cmm_query_classifier.py new file mode 100644 index 00000000..3df140d4 --- /dev/null +++ b/src/ursa/agents/cmm_query_classifier.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +from ursa.agents.cmm_taxonomy import ( + detect_commodity_tags, + detect_subdomain_tags, + extract_temporal_indicators, +) + +_FACTOID_PREFIXES = { + "what", + "which", + "who", + "when", + "where", + "how", + "is", + "are", + "does", + "do", + "can", +} +_COMPARE_MARKERS = {"compare", "versus", "vs", "difference between", "relative to"} +_CAUSAL_MARKERS = {"affect", "impact", "consequence", "lead to", "because"} + + +@dataclass +class QueryProfile: + query_type: str + commodity_hints: list[str] + subdomain_hints: list[str] + temporal_hints: list[str] + retrieval_k: int + return_k: int + alpha: float + filters: dict + + +class CMMQueryClassifier: + def classify(self, query: str) -> QueryProfile: + q = query.strip() + lower_q = q.lower() + tokens = q.split() + + commodity_hints = detect_commodity_tags(q) + subdomain_hints = detect_subdomain_tags(q) + temporal_hints = extract_temporal_indicators(q) + + is_factoid = ( + len(tokens) < 15 + and bool(tokens) + and tokens[0].lower().strip("?:!,.") in _FACTOID_PREFIXES + ) + is_comparative = any(marker in lower_q for marker in _COMPARE_MARKERS) + is_multi_hop = ( + any(marker in lower_q for marker in _CAUSAL_MARKERS) + or len(subdomain_hints) >= 2 + ) + is_temporal = bool( + temporal_hints + or re.search(r"\b(recent|current|trend|historical)\b", lower_q) + ) + + query_type = "general" + retrieval_k = 20 + return_k = 5 + alpha = 0.7 + + if is_comparative: + query_type = "comparative" + retrieval_k = 30 + return_k = 8 + alpha = 0.7 + elif 
is_multi_hop: + query_type = "multi_hop" + retrieval_k = 30 + return_k = 10 + alpha = 0.8 + elif is_temporal: + query_type = "temporal" + retrieval_k = 24 + return_k = 6 + alpha = 0.8 + elif is_factoid: + query_type = "factoid" + retrieval_k = 10 + return_k = 3 + alpha = 0.5 + + filters: dict = {} + if commodity_hints: + filters["commodity_tags"] = commodity_hints + if subdomain_hints: + filters["subdomain_tags"] = subdomain_hints + if temporal_hints: + years = sorted( + {hint[:4] if "-Q" in hint else hint for hint in temporal_hints} + ) + filters["temporal_indicator_gte"] = years[0] + filters["temporal_indicator_lte"] = years[-1] + + return QueryProfile( + query_type=query_type, + commodity_hints=commodity_hints, + subdomain_hints=subdomain_hints, + temporal_hints=temporal_hints, + retrieval_k=retrieval_k, + return_k=return_k, + alpha=alpha, + filters=filters, + ) diff --git a/src/ursa/agents/cmm_reranker.py b/src/ursa/agents/cmm_reranker.py new file mode 100644 index 00000000..c3709af0 --- /dev/null +++ b/src/ursa/agents/cmm_reranker.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import os +from abc import ABC, abstractmethod + +from langchain_core.documents import Document + + +class CMMRerankerBase(ABC): + @abstractmethod + def rerank( + self, + query: str, + documents: list[tuple[Document, float]], + top_k: int, + ) -> list[tuple[Document, float]]: + ... + + +class NoOpReranker(CMMRerankerBase): + def rerank( + self, + query: str, + documents: list[tuple[Document, float]], + top_k: int, + ) -> list[tuple[Document, float]]: + del query + return documents[:top_k] + + +class CohereReranker(CMMRerankerBase): + def __init__(self, model: str = "rerank-english-v3.0"): + try: + import cohere + except Exception as exc: # pragma: no cover - optional dependency + raise ImportError("cohere package is required for CohereReranker") from exc + self.model = model + api_key = os.getenv("COHERE_API_KEY") + if not api_key: + raise ValueError("COHERE_API_KEY is required for CohereReranker") + self.client = cohere.Client(api_key) + + def rerank( + self, + query: str, + documents: list[tuple[Document, float]], + top_k: int, + ) -> list[tuple[Document, float]]: + if not documents: + return [] + texts = [doc.page_content for doc, _ in documents] + response = self.client.rerank( + model=self.model, + query=query, + documents=texts, + top_n=min(top_k, len(documents)), + ) + reranked: list[tuple[Document, float]] = [] + for item in response.results: + doc, retrieval_score = documents[item.index] + doc.metadata["retrieval_score"] = retrieval_score + reranked.append((doc, float(item.relevance_score))) + return reranked + + +class LocalCrossEncoderReranker(CMMRerankerBase): + def __init__( + self, + model_name: str = "BAAI/bge-reranker-v2-m3", + device: str | None = None, + batch_size: int = 16, + ): + try: + from sentence_transformers import CrossEncoder + except Exception as exc: # pragma: no cover - optional dependency + raise ImportError( + "sentence-transformers is required for LocalCrossEncoderReranker" + ) from exc + self.batch_size = batch_size + kwargs = {} + if device: + kwargs["device"] = device + self.model = CrossEncoder(model_name, **kwargs) + + def rerank( + self, + query: str, + documents: list[tuple[Document, float]], + top_k: int, + ) -> list[tuple[Document, float]]: + if not documents: + return [] + pairs = [[query, doc.page_content] for doc, _ in documents] + scores = self.model.predict(pairs, batch_size=self.batch_size) + scored = [] + for (doc, retrieval_score), rank_score in 
zip(documents, scores): + doc.metadata["retrieval_score"] = retrieval_score + scored.append((doc, float(rank_score))) + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:top_k] + + +def init_reranker(provider: str | None = None) -> CMMRerankerBase: + selected = (provider or os.getenv("CMM_RERANKER_PROVIDER", "none")).lower() + if selected in {"none", "false", "off"}: + return NoOpReranker() + if selected == "cohere": + return CohereReranker() + if selected == "local": + model = os.getenv("CMM_RERANKER_MODEL_PATH", "BAAI/bge-reranker-v2-m3") + return LocalCrossEncoderReranker(model_name=model) + raise ValueError(f"Unsupported reranker provider '{selected}'") diff --git a/src/ursa/agents/cmm_taxonomy.py b/src/ursa/agents/cmm_taxonomy.py new file mode 100644 index 00000000..cccc18b1 --- /dev/null +++ b/src/ursa/agents/cmm_taxonomy.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import re +from typing import Iterable + +COMMODITY_KEYWORDS: dict[str, set[str]] = { + "HREE": { + "dysprosium", + "terbium", + "heavy rare earth", + "hree", + "dy", + "tb", + "yttrium", + "y", + }, + "LREE": { + "neodymium", + "praseodymium", + "lanthanum", + "cerium", + "light rare earth", + "lree", + "nd", + "pr", + "la", + "ce", + "bastnasite", + "monazite", + }, + "CO": {"cobalt", "co", "li-ion cathode", "nmc", "lco"}, + "LI": { + "lithium", + "li", + "spodumene", + "lithium carbonate", + "li2co3", + "lithium hydroxide", + "brine", + }, + "GA": {"gallium", "ga", "gaas", "gallium arsenide", "ga2o3"}, + "GR": {"graphite", "gr", "anode", "natural graphite", "synthetic graphite"}, + "NI": {"nickel", "ni", "class i nickel", "sulfide nickel", "laterite"}, + "CU": {"copper", "cu", "porphyry", "cathode copper"}, + "GE": {"germanium", "ge", "optical fiber", "infrared optics"}, + "OTH": {"critical mineral", "critical materials", "cmm", "strategic mineral"}, +} + +SUBDOMAIN_KEYWORDS: dict[str, set[str]] = { + "T-EC": { + "extraction", + "leaching", + "solvent extraction", + "sx", + "ion exchange", + "separation", + "acid digestion", + "flotation", + "beneficiation", + }, + "T-PM": { + "processing", + "refining", + "purification", + "calcination", + "roasting", + "hydrometallurgy", + "pyrometallurgy", + }, + "T-GO": { + "grade", + "ore", + "resource estimate", + "reserve", + "geometallurgy", + "deposit", + }, + "Q-PS": { + "price", + "spot price", + "index", + "market price", + "premium", + "discount", + }, + "Q-TF": { + "import", + "export", + "trade flow", + "comtrade", + "hs code", + "tariff", + "trade balance", + "shipment", + }, + "Q-EP": { + "elasticity", + "supply shock", + "demand shock", + "substitution", + "cross-price", + }, + "G-PR": { + "policy", + "regulation", + "executive order", + "ira", + "inflation reduction act", + "chips", + "export control", + "dpa", + "sanction", + }, + "G-BM": { + "benchmark", + "best practice", + "governance", + "standards", + "traceability", + }, + "S-CC": { + "carbon", + "co2", + "emissions", + "scope 1", + "scope 2", + "decarbonization", + }, + "S-ST": { + "sustainability", + "esg", + "water usage", + "tailings", + "waste", + "environmental impact", + }, +} + +_YEAR_RE = re.compile(r"\b(19|20)\d{2}\b") +_QUARTER_RE = re.compile(r"\b((19|20)\d{2})[-\s]?Q([1-4])\b", re.IGNORECASE) +_NUMERIC_RE = re.compile(r"\d") + + +def _normalize(text: str) -> str: + return text.lower() + + +def _matches_keyword(text: str, keyword: str) -> bool: + k = keyword.strip().lower() + if not k: + return False + if " " in k or "-" in k: + return k in text + return 
re.search(rf"\b{re.escape(k)}\b", text) is not None + + +def detect_tags( + text: str, taxonomy: dict[str, set[str]], fallback: Iterable[str] | None = None +) -> list[str]: + normalized = _normalize(text) + tags: list[str] = [] + for code, words in taxonomy.items(): + if any(_matches_keyword(normalized, word) for word in words): + tags.append(code) + if tags: + return sorted(set(tags)) + if fallback: + return sorted(set(fallback)) + return [] + + +def detect_commodity_tags( + text: str, fallback: Iterable[str] | None = None +) -> list[str]: + return detect_tags(text, COMMODITY_KEYWORDS, fallback=fallback) + + +def detect_subdomain_tags( + text: str, fallback: Iterable[str] | None = None +) -> list[str]: + return detect_tags(text, SUBDOMAIN_KEYWORDS, fallback=fallback) + + +def extract_temporal_indicators(text: str) -> list[str]: + years = {match.group(0) for match in _YEAR_RE.finditer(text)} + quarters = { + f"{match.group(1)}-Q{match.group(3)}" for match in _QUARTER_RE.finditer(text) + } + return sorted(quarters) + sorted(years) + + +def first_temporal_indicator(text: str) -> str: + indicators = extract_temporal_indicators(text) + return indicators[0] if indicators else "" + + +def has_numerical_data(text: str) -> bool: + return _NUMERIC_RE.search(text) is not None diff --git a/src/ursa/agents/cmm_vectorstore.py b/src/ursa/agents/cmm_vectorstore.py new file mode 100644 index 00000000..2b644969 --- /dev/null +++ b/src/ursa/agents/cmm_vectorstore.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import os +import re +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any + +from langchain_chroma import Chroma +from langchain_core.documents import Document + +from ursa.agents.cmm_embeddings import CMMEmbeddingsBase + +try: + from rank_bm25 import BM25Okapi +except Exception: # pragma: no cover - optional dependency + BM25Okapi = None # type: ignore[assignment] + + +class CMMVectorStoreBase(ABC): + @abstractmethod + def add_documents(self, documents: list[Document]) -> None: + ... + + @abstractmethod + def hybrid_search( + self, + query: str, + k: int, + alpha: float, + filters: dict[str, Any] | None, + ) -> list[tuple[Document, float]]: + ... + + @abstractmethod + def delete_collection(self) -> None: + ... + + @abstractmethod + def count(self) -> int: + ... 
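+
+# Illustrative usage only (a sketch, not exercised by this module); the query
+# string and persist directory are examples, `emb` stands for any
+# CMMEmbeddingsBase instance, and the other names are defined in this file:
+#
+#     store = init_vectorstore(
+#         backend="chroma",
+#         persist_directory="cmm_vectorstore",   # example location
+#         embedding_model=emb,
+#     )
+#     hits = store.hybrid_search(
+#         "dysprosium export controls 2023",
+#         k=5,
+#         alpha=0.7,
+#         filters={"commodity_tags": ["HREE"]},
+#     )
+#
+# Each hit is a (Document, fused_score) pair, highest score first; filters match
+# the chunk metadata fields written by CMMChunker.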
+ + +def _tokenize(text: str) -> list[str]: + return re.findall(r"\b\w+\b", text.lower()) + + +_LIST_METADATA_FIELDS = {"commodity_tags", "subdomain_tags"} + + +def _sanitize_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + cleaned: dict[str, Any] = {} + for key, value in metadata.items(): + if value is None: + cleaned[key] = None + elif isinstance(value, (str, int, float, bool)): + cleaned[key] = value + elif isinstance(value, list): + cleaned[key] = "|".join(str(item) for item in value) + else: + cleaned[key] = str(value) + return cleaned + + +def _restore_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + restored = dict(metadata) + for key in _LIST_METADATA_FIELDS: + value = restored.get(key) + if isinstance(value, str): + if not value.strip(): + restored[key] = [] + elif "|" in value: + restored[key] = [item for item in value.split("|") if item] + else: + restored[key] = [value] + return restored + + +def _match_filter(metadata: dict[str, Any], filters: dict[str, Any]) -> bool: + if not filters: + return True + + def intersects(meta_val: Any, expected: list[str]) -> bool: + if meta_val is None: + return False + if isinstance(meta_val, list): + return bool(set(map(str, meta_val)).intersection(set(expected))) + return str(meta_val) in expected + + for field in ("commodity_tags", "subdomain_tags", "sensitivity_level"): + expected = filters.get(field) + if expected and not intersects(metadata.get(field), expected): + return False + + gte = filters.get("temporal_indicator_gte") + lte = filters.get("temporal_indicator_lte") + if gte or lte: + temporal = str(metadata.get("temporal_indicator", "")) + year = temporal[:4] if "-Q" in temporal else temporal + if gte and (not year or year < str(gte)): + return False + if lte and (not year or year > str(lte)): + return False + + return True + + +class ChromaBM25VectorStore(CMMVectorStoreBase): + def __init__( + self, + *, + persist_directory: str | Path, + embedding_model: CMMEmbeddingsBase, + collection_name: str = "cmm_chunks", + ): + self.persist_directory = Path(persist_directory) + self.persist_directory.mkdir(parents=True, exist_ok=True) + self.embedding_model = embedding_model + self.collection_name = collection_name + self._chroma = Chroma( + collection_name=collection_name, + persist_directory=str(self.persist_directory), + embedding_function=embedding_model, + collection_metadata={"hnsw:space": "cosine"}, + ) + self._docs_by_id: dict[str, Document] = {} + self._bm25 = None + self._tokenized_docs: list[list[str]] = [] + self._bm25_ids: list[str] = [] + self._rebuild_bm25_index() + + def _collection_ids(self) -> list[str]: + res = self._chroma._collection.get(include=[]) + ids = res.get("ids", []) + if ids and isinstance(ids[0], list): + return [i for sub in ids for i in sub] + return list(ids) + + def _rebuild_bm25_index(self) -> None: + self._docs_by_id = {} + self._tokenized_docs = [] + self._bm25_ids = [] + + res = self._chroma._collection.get(include=["documents", "metadatas"]) + ids = res.get("ids", []) + docs = res.get("documents", []) + metas = res.get("metadatas", []) + if ids and isinstance(ids[0], list): + ids = [i for sub in ids for i in sub] + for cid, text, meta in zip(ids, docs, metas): + metadata = _restore_metadata(dict(meta or {})) + metadata.setdefault("chunk_id", cid) + doc = Document(page_content=text or "", metadata=metadata) + self._docs_by_id[cid] = doc + self._bm25_ids.append(cid) + self._tokenized_docs.append(_tokenize(doc.page_content)) + + if BM25Okapi is not None and self._tokenized_docs: + 
self._bm25 = BM25Okapi(self._tokenized_docs) + else: + self._bm25 = None + + def add_documents(self, documents: list[Document]) -> None: + chroma_docs: list[Document] = [] + ids = [] + for i, doc in enumerate(documents): + chunk_id = ( + doc.metadata.get("chunk_id") + or f"{doc.metadata.get('source_doc_id', 'doc')}::" + f"{doc.metadata.get('chunk_index', i)}::{i}" + ) + metadata = _sanitize_metadata(dict(doc.metadata)) + metadata["chunk_id"] = str(chunk_id) + chroma_docs.append( + Document(page_content=doc.page_content, metadata=metadata) + ) + ids.append(str(chunk_id)) + try: + self._chroma.add_documents(chroma_docs, ids=ids) + except Exception as exc: + print( + "[ChromaBM25VectorStore] Batch insert failed; falling back to" + f" per-document insert. error={exc}" + ) + skipped = 0 + for doc, chunk_id in zip(chroma_docs, ids): + try: + self._chroma.add_documents([doc], ids=[chunk_id]) + except Exception: + skipped += 1 + if skipped: + print( + "[ChromaBM25VectorStore] Skipped" + f" {skipped} chunk(s) during fallback insert." + ) + self._rebuild_bm25_index() + + def _dense_search(self, query: str, k: int) -> list[tuple[Document, float]]: + dense = self._chroma.similarity_search_with_relevance_scores(query, k=k) + out: list[tuple[Document, float]] = [] + for i, (doc, score) in enumerate(dense): + metadata = _restore_metadata(dict(doc.metadata or {})) + metadata.setdefault( + "chunk_id", + f"{metadata.get('source_doc_id', 'dense')}::{metadata.get('chunk_index', i)}::{i}", + ) + out.append( + ( + Document(page_content=doc.page_content, metadata=metadata), + float(score), + ) + ) + return out + + def _bm25_search(self, query: str, k: int) -> list[tuple[Document, float]]: + if not self._bm25_ids: + return [] + + query_tokens = _tokenize(query) + if self._bm25 is not None: + scores = list(self._bm25.get_scores(query_tokens)) + else: + qset = set(query_tokens) + scores = [] + for toks in self._tokenized_docs: + scores.append(float(len(qset.intersection(set(toks))))) + + ranked = sorted( + enumerate(scores), key=lambda item: item[1], reverse=True + )[:k] + results: list[tuple[Document, float]] = [] + for idx, score in ranked: + cid = self._bm25_ids[idx] + doc = self._docs_by_id[cid] + results.append((doc, float(score))) + return results + + def hybrid_search( + self, + query: str, + k: int, + alpha: float = 0.7, + filters: dict[str, Any] | None = None, + ) -> list[tuple[Document, float]]: + retrieval_k = max(k, 20) + dense = self._dense_search(query, retrieval_k) + sparse = self._bm25_search(query, retrieval_k) + k_rrf = 60.0 + scores: dict[str, float] = defaultdict(float) + docs: dict[str, Document] = {} + + for rank, (doc, _) in enumerate(dense, start=1): + cid = str(doc.metadata.get("chunk_id")) + docs[cid] = doc + scores[cid] += alpha * (1.0 / (k_rrf + rank)) + + for rank, (doc, _) in enumerate(sparse, start=1): + cid = str(doc.metadata.get("chunk_id")) + docs[cid] = doc + scores[cid] += (1.0 - alpha) * (1.0 / (k_rrf + rank)) + + fused = sorted(scores.items(), key=lambda x: x[1], reverse=True) + results: list[tuple[Document, float]] = [] + for cid, score in fused: + doc = docs[cid] + if _match_filter(doc.metadata, filters or {}): + results.append((doc, score)) + if len(results) >= k: + break + return results + + def delete_collection(self) -> None: + ids = self._collection_ids() + if ids: + self._chroma._collection.delete(ids=ids) + self._rebuild_bm25_index() + + def count(self) -> int: + return int(self._chroma._collection.count()) + + +class WeaviateVectorStore(CMMVectorStoreBase): + def 
__init__( + self, + *, + embedding_model: CMMEmbeddingsBase, + collection_name: str = "CMMChunk", + weaviate_url: str | None = None, + weaviate_api_key: str | None = None, + ): + self.embedding_model = embedding_model + self.collection_name = collection_name + self.weaviate_url = weaviate_url or os.getenv("CMM_WEAVIATE_URL") + self.weaviate_api_key = weaviate_api_key or os.getenv( + "CMM_WEAVIATE_API_KEY" + ) + if not self.weaviate_url or not self.weaviate_api_key: + raise ValueError( + "CMM_WEAVIATE_URL and CMM_WEAVIATE_API_KEY are required" + ) + try: + import weaviate + except Exception as exc: # pragma: no cover - optional dependency + raise ImportError( + "weaviate-client package is required for WeaviateVectorStore" + ) from exc + + self._weaviate = weaviate + self.client = weaviate.connect_to_weaviate_cloud( + cluster_url=self.weaviate_url, + auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key), + ) + + def _collection(self): + return self.client.collections.get(self.collection_name) + + def add_documents(self, documents: list[Document]) -> None: + coll = self._collection() + texts = [doc.page_content for doc in documents] + vectors = self.embedding_model.embed_documents(texts) + with coll.batch.dynamic() as batch: + for doc, vector in zip(documents, vectors): + data = {"text": doc.page_content} + data.update(doc.metadata) + batch.add_object(properties=data, vector=vector) + + def hybrid_search( + self, + query: str, + k: int, + alpha: float = 0.7, + filters: dict[str, Any] | None = None, + ) -> list[tuple[Document, float]]: + coll = self._collection() + response = coll.query.hybrid(query=query, alpha=alpha, limit=k) + out = [] + for obj in response.objects: + props = dict(obj.properties) + text = str(props.pop("text", "")) + score = float(getattr(obj.metadata, "score", 0.0)) + if _match_filter(props, filters or {}): + out.append((Document(page_content=text, metadata=props), score)) + return out + + def delete_collection(self) -> None: + try: + self.client.collections.delete(self.collection_name) + except Exception: + # Collection may not exist yet. 
+ return + + def count(self) -> int: + try: + coll = self._collection() + return int(coll.aggregate.over_all(total_count=True).total_count) + except Exception: + return 0 + + +def init_vectorstore( + *, + backend: str | None = None, + persist_directory: str | Path, + embedding_model: CMMEmbeddingsBase, + collection_name: str = "cmm_chunks", +) -> CMMVectorStoreBase: + selected = ( + backend or os.getenv("CMM_VECTORSTORE_BACKEND", "chroma") + ).lower() + if selected == "chroma": + return ChromaBM25VectorStore( + persist_directory=persist_directory, + embedding_model=embedding_model, + collection_name=collection_name, + ) + if selected == "weaviate": + return WeaviateVectorStore( + embedding_model=embedding_model, + collection_name=collection_name, + ) + raise ValueError(f"Unsupported vectorstore backend '{selected}'") diff --git a/src/ursa/agents/code_review_agent.py b/src/ursa/agents/code_review_agent.py index c6fc085e..1fffd5b8 100644 --- a/src/ursa/agents/code_review_agent.py +++ b/src/ursa/agents/code_review_agent.py @@ -155,7 +155,7 @@ def safety_check(self, state: CodeReviewState) -> CodeReviewState: } print(f"{GREEN}[PASSED] the safety check: {RESET}" + query) - elif state["messages"][-1].tool_calls[0]["name"] == "write_code": + elif state["messages"][-1].tool_calls[0]["name"] == "write_file": fn = ( state["messages"][-1] .tool_calls[0]["args"] @@ -222,8 +222,9 @@ def run(self, prompt, workspace): "iteration": 0, "workspace": workspace, } - return self.action.invoke( - initial_state, {"configurable": {"thread_id": self.thread_id}} + return self.invoke( + initial_state, + config={"configurable": {"thread_id": self.thread_id}}, ) diff --git a/src/ursa/agents/execution_agent.py b/src/ursa/agents/execution_agent.py index 9402be53..977bede4 100644 --- a/src/ursa/agents/execution_agent.py +++ b/src/ursa/agents/execution_agent.py @@ -7,19 +7,16 @@ - Workspace management with optional symlinking for external sources. - Safety-checked shell execution via run_command with output size budgeting. - Code authoring and edits through write_code and edit_code with rich previews. -- Web search capability through DuckDuckGoSearchResults. +- Web/literature search tools through acquisition helpers. - Summarization of the session and optional memory logging. - Configurable graph with nodes for agent, action, and summarize. Implementation notes: - LLM prompts are sourced from prompt_library.execution_prompts. -- Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages. +- Outputs from subprocess are trimmed to the tool-character budget. - The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain. - Safety gates block unsafe shell commands and surface the rationale to the user. -Environment: -- MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses. - Entry points: - ExecutionAgent._invoke(...) runs the compiled graph. - main() shows a minimal demo that writes and runs a script. 
@@ -60,7 +57,13 @@ executor_prompt, recap_prompt, ) -from ursa.tools import edit_code, read_file, run_command, write_code +from ursa.tools import ( + edit_code, + read_file, + run_cmm_supply_chain_optimization, + run_command, + write_code, +) from ursa.tools.search_tools import ( run_arxiv_search, run_osti_search, @@ -204,6 +207,7 @@ def __init__( write_code, edit_code, read_file, + run_cmm_supply_chain_optimization, run_web_search, run_osti_search, run_arxiv_search, diff --git a/src/ursa/agents/mp_agent.py b/src/ursa/agents/mp_agent.py index 28a2b0eb..08bfa526 100644 --- a/src/ursa/agents/mp_agent.py +++ b/src/ursa/agents/mp_agent.py @@ -160,9 +160,15 @@ def _build_graph(self): if __name__ == "__main__": - agent = MaterialsProjectAgent() + from langchain.chat_models import init_chat_model + + agent = MaterialsProjectAgent(llm=init_chat_model("openai:gpt-5.2")) resp = agent.invoke( - mp_query="LiFePO4", + query={ + "elements": ["Li", "Fe", "P", "O"], + "band_gap_min": 0.0, + "band_gap_max": 6.0, + }, context="What is its band gap and stability, and any synthesis challenges?", ) print(resp) diff --git a/src/ursa/agents/rag_agent.py b/src/ursa/agents/rag_agent.py index 82b737cd..39106c6b 100644 --- a/src/ursa/agents/rag_agent.py +++ b/src/ursa/agents/rag_agent.py @@ -1,20 +1,34 @@ +from __future__ import annotations + import os import re import statistics -from functools import cached_property from pathlib import Path from threading import Lock -from typing import TypedDict +from typing import Any, Iterable, TypedDict from langchain.chat_models import BaseChatModel -from langchain.embeddings import Embeddings, init_embeddings +from langchain.embeddings import Embeddings +from langchain.embeddings import init_embeddings as init_lc_embeddings from langchain_chroma import Chroma +from langchain_core.documents import Document from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_text_splitters import RecursiveCharacterTextSplitter from tqdm import tqdm from ursa.agents.base import BaseAgent +from ursa.agents.cmm_chunker import CMMChunker +from ursa.agents.cmm_embeddings import ( + CMMEmbeddingsBase, + LangChainEmbeddingsAdapter, +) +from ursa.agents.cmm_embeddings import ( + init_embeddings as init_cmm_embeddings, +) +from ursa.agents.cmm_query_classifier import CMMQueryClassifier +from ursa.agents.cmm_reranker import CMMRerankerBase, init_reranker +from ursa.agents.cmm_vectorstore import CMMVectorStoreBase, init_vectorstore from ursa.util.parse import ( OFFICE_EXTENSIONS, SPECIAL_TEXT_FILENAMES, @@ -22,10 +36,6 @@ read_text_from_file, ) -# Set a minimum number of characters in a file to -# to ingest it. Avoids files with minimal content -# that would be unlikely to give meaningful -# information to perform RAG on. 
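+# Default minimum character count for ingesting a file; shorter files are
+# skipped during _read_docs_node (overridable via RAGAgent's min_chars argument).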
MIN_CHARS = 30 @@ -33,6 +43,10 @@ class RAGMetadata(TypedDict): k: int num_results: int relevance_scores: list[float] + query_type: str + retrieval_k: int + backend: str + filter_fallback_used: bool class RAGState(TypedDict, total=False): @@ -47,48 +61,93 @@ def remove_surrogates(text: str) -> str: return re.sub(r"[\ud800-\udfff]", "", text) -def _is_meaningful(text: str) -> bool: - return len(text) >= MIN_CHARS - - class RAGAgent(BaseAgent[RAGState]): - agent_state = RAGState + state_type = RAGState def __init__( self, llm: BaseChatModel, - embedding: Embeddings | None = None, - return_k: int = 10, + embedding: Embeddings | str | None = None, + embedding_dimensions: int | None = None, + retrieval_k: int = 20, + return_k: int = 5, + vectorstore_backend: str | None = None, + hybrid_alpha: float | None = None, + use_reranker: bool = False, + reranker_provider: str = "none", chunk_size: int = 1000, chunk_overlap: int = 200, - database_path: str = "database", - summaries_path: str = "database", - vectorstore_path: str = "vectorstore", - **kwargs, + database_path: str | Path = "database", + summaries_path: str | Path = "database", + vectorstore_path: str | Path = "vectorstore", + include_extensions: set[str] | None = None, + exclude_extensions: set[str] | None = None, + max_docs_per_ingest: int | None = None, + min_chars: int = MIN_CHARS, + **kwargs: Any, ): super().__init__(llm, **kwargs) - self.retriever = None self._vs_lock = Lock() - self.return_k = return_k - self.embedding = embedding or init_embeddings( - "openai:text-embedding-3-small" + + self.retrieval_k = max(1, int(retrieval_k)) + self.return_k = max(1, int(return_k)) + self._adaptive_retrieval_k = self.retrieval_k == 20 + self._adaptive_return_k = self.return_k == 5 + self.hybrid_alpha = float(hybrid_alpha or os.getenv("CMM_HYBRID_ALPHA", "0.7")) + self.use_reranker = bool(use_reranker) + self.reranker_provider = reranker_provider + self.vectorstore_backend = ( + vectorstore_backend + or os.getenv("CMM_VECTORSTORE_BACKEND", "chroma") + ).lower() + self.legacy_mode = ( + os.getenv("URSA_RAG_LEGACY_MODE", "false").strip().lower() + in {"1", "true", "yes", "on"} ) + self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - self.database_path = self.workspace / database_path - self.summaries_path = self.workspace / summaries_path - self.vectorstore_path = self.workspace / vectorstore_path + self.database_path = self._resolve_path(database_path) + self.summaries_path = self._resolve_path(summaries_path) + self.vectorstore_path = self._resolve_path(vectorstore_path) + self.include_extensions = ( + self._normalize_extensions(include_extensions) + if include_extensions is not None + else None + ) + self.exclude_extensions = self._normalize_extensions(exclude_extensions) + self.max_docs_per_ingest = max_docs_per_ingest + self.min_chars = max(1, int(min_chars)) + + self.embedding = self._init_embedding(embedding, embedding_dimensions) + self.chunker = CMMChunker( + max_tokens=max(64, chunk_size // 2), + overlap_tokens=max(0, chunk_overlap // 4), + min_tokens=max(20, min(chunk_size // 6, 120)), + ) + self.classifier = CMMQueryClassifier() - self.vectorstore_path.mkdir(exist_ok=True, parents=True) - self.vectorstore = self._open_global_vectorstore() + provider = reranker_provider or os.getenv( + "CMM_RERANKER_PROVIDER", "none" + ) + self.reranker: CMMRerankerBase = ( + init_reranker(provider) if self.use_reranker else init_reranker("none") + ) - @cached_property - def graph(self): - return self._build_graph() + 
self.vectorstore_path.mkdir(exist_ok=True, parents=True) + self._ingested_manifest = self._load_manifest_ids() - @property - def _action(self): - return self.graph + if self.legacy_mode: + self.vectorstore = self._open_legacy_vectorstore() + else: + self.vectorstore = init_vectorstore( + backend=self.vectorstore_backend, + persist_directory=self.vectorstore_path, + embedding_model=self.embedding, + collection_name=os.getenv( + "CMM_VECTORSTORE_COLLECTION", "cmm_chunks" + ), + ) @property def manifest_path(self) -> str: @@ -98,170 +157,320 @@ def manifest_path(self) -> str: def manifest_exists(self) -> bool: return os.path.exists(self.manifest_path) - def _open_global_vectorstore(self) -> Chroma: + def _resolve_path(self, value: str | Path) -> Path: + p = Path(value) + if p.is_absolute(): + return p + return self.workspace / p + + def _normalize_extensions(self, values: Iterable[str] | None) -> set[str]: + if not values: + return set() + normalized = set() + for value in values: + ext = self._normalize_extension(value) + if ext: + normalized.add(ext) + return normalized + + def _normalize_extension(self, value: str) -> str: + v = value.strip().lower() + if not v: + return "" + return v if v.startswith(".") else f".{v}" + + def _init_embedding( + self, + embedding: Embeddings | str | None, + embedding_dimensions: int | None, + ) -> CMMEmbeddingsBase: + if isinstance(embedding, str): + return init_cmm_embeddings( + embedding, + dimensions=embedding_dimensions, + base_url=os.getenv("OPENAI_BASE_URL"), + ) + if isinstance(embedding, Embeddings): + return LangChainEmbeddingsAdapter( + embedding, + embedding_dim=embedding_dimensions, + ) + + model = os.getenv("CMM_EMBEDDING_MODEL", "openai:text-embedding-3-large") + if self.legacy_mode: + legacy = init_lc_embeddings("openai:text-embedding-3-small") + return LangChainEmbeddingsAdapter(legacy, embedding_dim=1536) + + return init_cmm_embeddings( + model, + dimensions=embedding_dimensions + or _safe_int(os.getenv("CMM_EMBEDDING_DIMENSIONS")), + base_url=os.getenv("OPENAI_BASE_URL"), + ) + + def _open_legacy_vectorstore(self) -> Chroma: return Chroma( - persist_directory=self.vectorstore_path, + persist_directory=str(self.vectorstore_path), embedding_function=self.embedding, collection_metadata={"hnsw:space": "cosine"}, ) + def _load_manifest_ids(self) -> set[str]: + if not self.manifest_exists: + return set() + with open(self.manifest_path, "r", encoding="utf-8") as handle: + return {line.strip() for line in handle if line.strip()} + def _paper_exists_in_vectorstore(self, doc_id: str) -> bool: - try: - col = self.vectorstore._collection - res = col.get(where={"id": doc_id}, limit=1) - return len(res.get("ids", [])) > 0 - except Exception: - if not self.manifest_exists: - return False - with open(self.manifest_path, "r") as f: - return any(line.strip() == doc_id for line in f) - - def _mark_paper_ingested(self, arxiv_id: str) -> None: - with open(self.manifest_path, "a") as f: - f.write(f"{arxiv_id}\n") - - def _ensure_doc_in_vectorstore(self, paper_text: str, doc_id: str) -> None: - splitter = RecursiveCharacterTextSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap - ) - docs = splitter.create_documents( - [paper_text], metadatas=[{"id": doc_id}] - ) - with self._vs_lock: - if not self._paper_exists_in_vectorstore(doc_id): - ids = [f"{doc_id}::{i}" for i, _ in enumerate(docs)] - self.vectorstore.add_documents(docs, ids=ids) - self._mark_paper_ingested(doc_id) + return doc_id in self._ingested_manifest - def 
_get_global_retriever(self, k: int = 5): - return self.vectorstore, self.vectorstore.as_retriever( - search_kwargs={"k": k} - ) + def _mark_paper_ingested(self, doc_id: str) -> None: + if doc_id in self._ingested_manifest: + return + self._ingested_manifest.add(doc_id) + with open(self.manifest_path, "a", encoding="utf-8") as handle: + handle.write(f"{doc_id}\n") def _read_docs_node(self, state: RAGState) -> RAGState: print("[RAG Agent] Reading Documents....") new_state = state.copy() - custom_extensions = [ - item.strip() + custom_extensions = { + self._normalize_extension(item) for item in os.environ.get("URSA_TEXT_EXTENSIONS", "").split(",") - ] - custom_readable_files = [ - item.strip() - for item in os.environ.get("URSA_SPECIAL_TEXT_FILENAMES", "").split( - "," - ) - ] + if item.strip() + } + custom_readable_files = { + item.strip().lower() + for item in os.environ.get( + "URSA_SPECIAL_TEXT_FILENAMES", "" + ).split(",") + if item.strip() + } base_dir = Path(self.database_path) ingestible_paths: list[Path] = [] - for p in base_dir.rglob("*"): - if not p.is_file(): + for path in base_dir.rglob("*"): + if not path.is_file(): continue - ext = p.suffix.lower() + ext = path.suffix.lower() + file_name = path.name.lower() - if ( + base_ingestible = ( ext == ".pdf" or ext in TEXT_EXTENSIONS or ext in custom_extensions - or p.name.lower() in SPECIAL_TEXT_FILENAMES - or p.name.lower() in custom_readable_files + or file_name in SPECIAL_TEXT_FILENAMES + or file_name in custom_readable_files or ext in OFFICE_EXTENSIONS - ): - ingestible_paths.append(p) + ) + if not base_ingestible: + continue + + if ext in self.exclude_extensions: + continue + + if self.include_extensions is not None: + if ( + ext not in self.include_extensions + and file_name not in SPECIAL_TEXT_FILENAMES + and file_name not in custom_readable_files + ): + continue + + ingestible_paths.append(path) candidates: list[tuple[Path, str]] = [] - for p in ingestible_paths: - doc_id = str(p) + for path in ingestible_paths: + doc_id = str(path) if not self._paper_exists_in_vectorstore(doc_id): - candidates.append((p, doc_id)) + candidates.append((path, doc_id)) + + if self.max_docs_per_ingest is not None and self.max_docs_per_ingest > 0: + candidates = candidates[: self.max_docs_per_ingest] papers: list[str] = [] doc_ids: list[str] = [] for path, doc_id in tqdm(candidates, desc="RAG parsing text"): full_text = read_text_from_file(path) - # skip files with very few characters to - # avoid parsing/rag ingestion problems - if not _is_meaningful(full_text): + if len(full_text) < self.min_chars: continue papers.append(full_text) doc_ids.append(doc_id) new_state["doc_texts"] = papers new_state["doc_ids"] = doc_ids - return new_state - def _ingest_docs_node(self, state: RAGState) -> RAGState: - splitter = RecursiveCharacterTextSplitter( - chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap + def _build_docs_for_ingest(self, paper: str, doc_id: str) -> list[Document]: + cleaned_text = remove_surrogates(paper) + + if self.legacy_mode: + splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + ) + docs = splitter.create_documents( + [cleaned_text], metadatas=[{"id": doc_id}] + ) + for i, doc in enumerate(docs): + doc.metadata.setdefault("source_doc_id", doc_id) + doc.metadata.setdefault("chunk_index", i) + return docs + + title = Path(doc_id).name + docs = self.chunker.chunk_document( + cleaned_text, + metadata={"source_doc_id": doc_id, "source_doc_title": title}, ) + for i, doc in 
enumerate(docs): + if "chunk_id" not in doc.metadata: + chunk_id = f"{doc_id}::{doc.metadata.get('chunk_index', i)}" + doc.metadata["chunk_id"] = chunk_id + doc.metadata.setdefault("id", doc_id) + return docs + def _ingest_docs_node(self, state: RAGState) -> RAGState: if "doc_texts" not in state: - raise RuntimeError("Unexpected error: doc_ids not in state!") - - if "doc_ids" not in state: raise RuntimeError("Unexpected error: doc_texts not in state!") + if "doc_ids" not in state: + raise RuntimeError("Unexpected error: doc_ids not in state!") - batch_docs, batch_ids = [], [] - - for paper, id in tqdm( + batch_docs: list[Document] = [] + ingest_ids: list[str] = [] + for paper, doc_id in tqdm( zip(state["doc_texts"], state["doc_ids"]), total=len(state["doc_texts"]), desc="RAG Ingesting", ): - cleaned_text = remove_surrogates(paper) - docs = splitter.create_documents( - [cleaned_text], metadatas=[{"id": id}] - ) - ids = [f"{id}::{i}" for i, _ in enumerate(docs)] - batch_docs.extend(docs) - batch_ids.extend(ids) + docs = self._build_docs_for_ingest(paper, doc_id) + if docs: + batch_docs.extend(docs) + ingest_ids.append(doc_id) - if state["doc_texts"]: + if batch_docs: print("[RAG Agent] Ingesting Documents Into RAG Database....") with self._vs_lock: - self.vectorstore.add_documents(batch_docs, ids=batch_ids) - for id in batch_ids: - self._mark_paper_ingested(id) + if self.legacy_mode: + assert isinstance(self.vectorstore, Chroma) + ids = [ + str(doc.metadata.get("chunk_id") or f"chunk::{i}") + for i, doc in enumerate(batch_docs) + ] + self.vectorstore.add_documents(batch_docs, ids=ids) + else: + assert isinstance(self.vectorstore, CMMVectorStoreBase) + self.vectorstore.add_documents(batch_docs) + + for doc_id in ingest_ids: + self._mark_paper_ingested(doc_id) return state - def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState: - print( - "[RAG Agent] Retrieving Contextually Relevant Information From Database..." + def _retrieve(self, query: str) -> tuple[list[tuple[Document, float]], dict[str, Any]]: + if self.legacy_mode: + assert isinstance(self.vectorstore, Chroma) + retrieved = self.vectorstore.similarity_search_with_relevance_scores( + query, + k=self.return_k, + ) + params = { + "query_type": "legacy", + "retrieval_k": self.return_k, + "return_k": self.return_k, + "alpha": self.hybrid_alpha, + "backend": "legacy-chroma", + "filter_fallback_used": False, + } + return retrieved, params + + assert isinstance(self.vectorstore, CMMVectorStoreBase) + profile = self.classifier.classify(query) + retrieval_k = max(1, int(profile.retrieval_k)) + return_k = max(1, int(profile.return_k)) + alpha = float(profile.alpha) + effective_retrieval_k = ( + retrieval_k if self._adaptive_retrieval_k else self.retrieval_k + ) + effective_return_k = ( + return_k if self._adaptive_return_k else self.return_k ) - prompt = ChatPromptTemplate.from_template(""" - You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context} - Summarize the retrieved scientific content below. - Cite sources by ID when relevant: {source_ids} + dense_sparse = self.vectorstore.hybrid_search( + query=query, + k=effective_retrieval_k, + alpha=self.hybrid_alpha if self.hybrid_alpha is not None else alpha, + filters=profile.filters, + ) + filter_fallback_used = False + if not dense_sparse and profile.filters: + # Retry without metadata filters when strict tagging yields no hits. 
+ dense_sparse = self.vectorstore.hybrid_search( + query=query, + k=effective_retrieval_k, + alpha=( + self.hybrid_alpha + if self.hybrid_alpha is not None + else alpha + ), + filters=None, + ) + filter_fallback_used = bool(dense_sparse) + top_docs = dense_sparse[:effective_return_k] + + if self.use_reranker: + top_docs = self.reranker.rerank( + query=query, + documents=top_docs, + top_k=effective_return_k, + ) - {retrieved_content} - """) - chain = prompt | self.llm | StrOutputParser() + params = { + "query_type": profile.query_type, + "retrieval_k": effective_retrieval_k, + "return_k": effective_return_k, + "alpha": self.hybrid_alpha if self.hybrid_alpha is not None else alpha, + "backend": self.vectorstore_backend, + "filter_fallback_used": filter_fallback_used, + } + return top_docs, params - # 2) One retrieval over the global DB with the task context - try: - if "context" not in state: - raise RuntimeError("Unexpected error: context not in state!") + def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState: + print("[RAG Agent] Retrieving Contextually Relevant Information...") + if "context" not in state: + raise RuntimeError("Unexpected error: context not in state!") - results = self.vectorstore.similarity_search_with_relevance_scores( - state["context"], k=self.return_k - ) + prompt = ChatPromptTemplate.from_template( + """ +You are a scientific assistant responsible for summarizing extracts from +research papers in the context of: {context} - relevance_scores = [score for _, score in results] - except Exception as e: - print(f"RAG failed due to: {e}") +Summarize the retrieved scientific content below. +Cite source IDs when relevant: {source_ids} + +{retrieved_content} +""" + ) + chain = prompt | self.llm | StrOutputParser() + + try: + results, params = self._retrieve(state["context"]) + relevance_scores = [float(score) for _, score in results] + except Exception as exc: + print(f"RAG failed due to: {exc}") return {**state, "summary": ""} - source_ids_list = [] + source_ids_list: list[str] = [] for doc, _ in results: - aid = doc.metadata.get("id") - if aid and aid not in source_ids_list: - source_ids_list.append(aid) + source_id = ( + doc.metadata.get("source_doc_id") + or doc.metadata.get("id") + or doc.metadata.get("source") + ) + if source_id and source_id not in source_ids_list: + source_ids_list.append(str(source_id)) source_ids = ", ".join(source_ids_list) retrieved_content = ( @@ -270,59 +479,66 @@ def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState: else "" ) - print("[RAG Agent] Summarizing Retrieved Information From Database...") - # 3) One summary based on retrieved chunks - rag_summary = chain.invoke({ - "retrieved_content": retrieved_content, - "context": state["context"], - "source_ids": source_ids, - }) + print("[RAG Agent] Summarizing Retrieved Information...") + rag_summary = chain.invoke( + { + "retrieved_content": retrieved_content, + "context": state["context"], + "source_ids": source_ids, + } + ) - # Persist a single file for the batch (optional) - batch_name = "RAG_summary.txt" os.makedirs(self.summaries_path, exist_ok=True) - with open(os.path.join(self.summaries_path, batch_name), "w") as f: - f.write(rag_summary) + with open( + os.path.join(self.summaries_path, "RAG_summary.txt"), + "w", + encoding="utf-8", + ) as handle: + handle.write(rag_summary) - # Diagnostics if relevance_scores: print(f"\nMax Relevance Score: {max(relevance_scores):.4f}") print(f"Min Relevance Score: {min(relevance_scores):.4f}") - print( - f"Median 
Relevance Score: {statistics.median(relevance_scores):.4f}\n" - ) + median = statistics.median(relevance_scores) + print(f"Median Relevance Score: {median:.4f}\n") else: print("\nNo RAG results retrieved (score list empty).\n") - # Return a single-element list by default (preferred) return { **state, "summary": rag_summary, "rag_metadata": { - "k": self.return_k, + "k": params["return_k"], "num_results": len(results), "relevance_scores": relevance_scores, + "query_type": params["query_type"], + "retrieval_k": params["retrieval_k"], + "backend": params["backend"], + "filter_fallback_used": params["filter_fallback_used"], }, } - def _build_graph(self): + def _build_graph(self) -> None: self.add_node(self._read_docs_node) self.add_node(self._ingest_docs_node) self.add_node(self._retrieve_and_summarize_node) self.graph.add_edge("_read_docs_node", "_ingest_docs_node") - self.graph.add_edge("_ingest_docs_node", "_retrieve_and_summarize_node") + self.graph.add_edge( + "_ingest_docs_node", "_retrieve_and_summarize_node" + ) self.graph.set_entry_point("_read_docs_node") self.graph.set_finish_point("_retrieve_and_summarize_node") -# NOTE: Run test in `tests/agents/test_rag_agent/test_rag_agent.py` via: -# -# pytest -s tests/agents/test_rag_agent -# -# OR -# -# uv run pytest -s tests/agents/test_rag_agent -# -# NOTE: You may need to `rm -rf workspace/rag-agent` to remove the vectorstore. +def _safe_int(value: str | None) -> int | None: + if value is None: + return None + value = value.strip() + if not value: + return None + try: + return int(value) + except ValueError: + return None diff --git a/src/ursa/agents/websearch_agent.py b/src/ursa/agents/websearch_agent.py deleted file mode 100644 index cc0bfe75..00000000 --- a/src/ursa/agents/websearch_agent.py +++ /dev/null @@ -1,182 +0,0 @@ -# from langchain_community.tools import TavilySearchResults -# from langchain_core.runnables.graph import MermaidDrawMethod -from typing import Annotated, Any, TypedDict - -import requests -from bs4 import BeautifulSoup -from langchain.agents import create_agent -from langchain.chat_models import BaseChatModel -from langchain.messages import HumanMessage, SystemMessage -from langchain_community.tools import DuckDuckGoSearchResults -from langchain_core.output_parsers import StrOutputParser -from langgraph.graph.message import add_messages -from langgraph.prebuilt import InjectedState -from pydantic import Field - -from ursa.prompt_library.websearch_prompts import ( - reflection_prompt, - summarize_prompt, - websearch_prompt, -) - -from .base import BaseAgent - - -class WebSearchState(TypedDict): - websearch_query: str - messages: Annotated[list, add_messages] - urls_visited: list[str] - max_websearch_steps: Annotated[ - int, Field(default=100, description="Maximum number of websearch steps") - ] - remaining_steps: int - is_last_step: bool - model: Any - thread_id: Any - - -# Adding the model to the state clumsily so that all "read" sources arent in the -# context window. That eats a ton of tokens because each `llm.invoke` passes -# all the tokens of all the sources. 
- - -class WebSearchAgentLegacy(BaseAgent): - def __init__(self, llm: BaseChatModel, **kwargs): - super().__init__(llm, **kwargs) - self.websearch_prompt = websearch_prompt - self.reflection_prompt = reflection_prompt - self.tools = [search_tool, process_content] # + cb_tools - self.has_internet = self._check_for_internet( - kwargs.get("url", "http://www.lanl.gov") - ) - - def _review_node(self, state: WebSearchState) -> WebSearchState: - if not self.has_internet: - return { - "messages": [ - HumanMessage( - content="No internet for WebSearch Agent so no research to review." - ) - ], - "urls_visited": [], - } - - translated = [SystemMessage(content=reflection_prompt)] + state[ - "messages" - ] - res = StrOutputParser().invoke( - self.llm.invoke( - translated, {"configurable": {"thread_id": self.thread_id}} - ) - ) - return {"messages": [HumanMessage(content=res)]} - - def _response_node(self, state: WebSearchState) -> WebSearchState: - if not self.has_internet: - return { - "messages": [ - HumanMessage( - content="No internet for WebSearch Agent. No research carried out." - ) - ], - "urls_visited": [], - } - - messages = state["messages"] + [SystemMessage(content=summarize_prompt)] - response = StrOutputParser().invoke( - self.llm.invoke( - messages, {"configurable": {"thread_id": self.thread_id}} - ) - ) - - urls_visited = [] - for message in messages: - if message.model_dump().get("tool_calls", []): - if "url" in message.tool_calls[0]["args"]: - urls_visited.append(message.tool_calls[0]["args"]["url"]) - return {"messages": [response], "urls_visited": urls_visited} - - def _check_for_internet(self, url, timeout=2): - """ - Checks for internet connectivity by attempting an HTTP GET request. - """ - try: - requests.get(url, timeout=timeout) - return True - except (requests.ConnectionError, requests.Timeout): - return False - - def _state_store_node(self, state: WebSearchState) -> WebSearchState: - state["thread_id"] = self.thread_id - return state - # return dict(**state, thread_id=self.thread_id) - - def _create_react(self, state: WebSearchState) -> WebSearchState: - react_agent = create_agent( - self.llm, - self.tools, - state_schema=WebSearchState, - system_prompt=self.websearch_prompt, - ) - return react_agent.invoke(state) - - def _build_graph(self): - self.add_node(self._state_store_node) - self.add_node(self._create_react) - self.add_node(self._review_node) - self.add_node(self._response_node) - - self.graph.set_entry_point("_state_store_node") - self.graph.add_edge("_state_store_node", "_create_react") - self.graph.add_edge("_create_react", "_review_node") - self.graph.set_finish_point("_response_node") - - self.graph.add_conditional_edges( - "_review_node", - should_continue, - { - "_create_react": "_create_react", - "_response_node": "_response_node", - }, - ) - - -def process_content( - url: str, context: str, state: Annotated[dict, InjectedState] -) -> str: - """ - Processes content from a given webpage. - - Args: - url: string with the url to obtain text content from. - context: string summary of the information the agent wants from the url for summarizing salient information. 
- """ - print("Parsing information from ", url) - response = requests.get(url) - soup = BeautifulSoup(response.content, "html.parser") - - content_prompt = f""" - Here is the full content: - {soup.get_text()} - - Carefully summarize the content in full detail, given the following context: - {context} - """ - summarized_information = StrOutputParser().invoke( - state["model"].invoke( - content_prompt, {"configurable": {"thread_id": state["thread_id"]}} - ) - ) - return summarized_information - - -search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10) -# search_tool = TavilySearchResults(max_results=10, search_depth="advanced",include_answer=True) - - -def should_continue(state: WebSearchState): - if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3): - return "_response_node" - if "[APPROVED]" in state["messages"][-1].text: - return "_response_node" - return "_create_react" diff --git a/src/ursa/tools/__init__.py b/src/ursa/tools/__init__.py index 73d20f08..743ac604 100644 --- a/src/ursa/tools/__init__.py +++ b/src/ursa/tools/__init__.py @@ -1,3 +1,9 @@ +from .cmm_supply_chain_optimization_tool import ( + run_cmm_supply_chain_optimization as run_cmm_supply_chain_optimization, +) +from .cmm_supply_chain_optimization_tool import ( + solve_cmm_supply_chain_optimization as solve_cmm_supply_chain_optimization, +) from .read_file_tool import read_file as read_file from .run_command_tool import run_command as run_command from .write_code_tool import edit_code as edit_code diff --git a/src/ursa/tools/_lp_solver.py b/src/ursa/tools/_lp_solver.py new file mode 100644 index 00000000..4e0deb5b --- /dev/null +++ b/src/ursa/tools/_lp_solver.py @@ -0,0 +1,294 @@ +"""LP solver backend for CMM supply-chain optimization. + +Uses ``scipy.optimize.linprog`` with the HiGHS method to produce +provably optimal allocations and dual (shadow) prices. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +try: + from scipy.optimize import linprog + + _HAS_SCIPY = True +except ImportError: # pragma: no cover + _HAS_SCIPY = False + + +def scipy_available() -> bool: + """Return ``True`` when scipy is importable.""" + return _HAS_SCIPY + + +@dataclass(frozen=True) +class LPResult: + """Result container returned by :func:`solve_lp`.""" + + success: bool + status: str + objective: float = 0.0 + allocation: dict[tuple[str, str], float] = field( + default_factory=dict, + ) + unmet: dict[str, float] = field(default_factory=dict) + shadow_prices_demand: dict[str, float] = field( + default_factory=dict, + ) + shadow_prices_capacity: dict[str, float] = field( + default_factory=dict, + ) + shadow_prices_share: dict[str, float] = field( + default_factory=dict, + ) + shadow_prices_composition: dict[str, float] = field( + default_factory=dict, + ) + reduced_costs: dict[str, float] = field( + default_factory=dict, + ) + + +def solve_lp( + *, + suppliers: list[Any], + markets: list[str], + demand: dict[str, float], + shipping_cost: dict[str, dict[str, float]], + risk_weight: float, + unmet_penalty: float, + max_supplier_share: float, + composition_targets: dict[str, float] | None = None, + composition_tolerance: float = 0.0, + composition_profiles: dict[str, dict[str, float]] | None = None, +) -> LPResult: + """Solve the supply-chain allocation as a linear programme. + + Parameters + ---------- + suppliers + List of supplier objects (must have ``.name``, ``.capacity``, + ``.unit_cost``, ``.risk_score`` attributes). 
+ markets + Sorted list of market names. + demand + ``{market: required_quantity}``. + shipping_cost + ``{supplier_name: {market: cost_per_unit}}``. + risk_weight + Multiplier for supplier risk in the objective. + unmet_penalty + Per-unit penalty for unmet demand. + max_supplier_share + Maximum fraction of total demand from any single supplier. + composition_targets + Optional ``{component: target_fraction}``. + composition_tolerance + Allowed deviation from each composition target. + composition_profiles + ``{supplier_name: {component: fraction}}``. + + Returns + ------- + LPResult + Solver output including allocations, unmet demand, and + shadow prices for all constraint groups. + """ + if not _HAS_SCIPY: + return LPResult( + success=False, + status="scipy_unavailable", + ) + + S = len(suppliers) + M = len(markets) + n_x = S * M # allocation variables + n_u = M # unmet-demand slack variables + n_vars = n_x + n_u + + total_demand = sum(demand.values()) + + # --- Objective: c^T x ------------------------------------------- + c = [0.0] * n_vars + for s_idx, sup in enumerate(suppliers): + for m_idx, mkt in enumerate(markets): + cost = ( + sup.unit_cost + + shipping_cost.get(sup.name, {}).get(mkt, 0.0) + + risk_weight * sup.risk_score + ) + c[s_idx * M + m_idx] = cost + for m_idx in range(M): + c[n_x + m_idx] = unmet_penalty + + # --- Equality: demand balance ----------------------------------- + # Σ_s x[s,m] + u[m] = demand[m] ∀ m + A_eq: list[list[float]] = [] + b_eq: list[float] = [] + for m_idx, mkt in enumerate(markets): + row = [0.0] * n_vars + for s_idx in range(S): + row[s_idx * M + m_idx] = 1.0 + row[n_x + m_idx] = 1.0 + A_eq.append(row) + b_eq.append(demand[mkt]) + + # --- Inequality: capacity + share + composition ----------------- + A_ub: list[list[float]] = [] + b_ub: list[float] = [] + + # (2) Capacity: Σ_m x[s,m] ≤ capacity[s] ∀ s + for s_idx, sup in enumerate(suppliers): + row = [0.0] * n_vars + for m_idx in range(M): + row[s_idx * M + m_idx] = 1.0 + A_ub.append(row) + b_ub.append(sup.capacity) + + # (3) Share: Σ_m x[s,m] ≤ max_share × total_demand ∀ s + share_cap = max_supplier_share * total_demand + for s_idx in range(S): + row = [0.0] * n_vars + for m_idx in range(M): + row[s_idx * M + m_idx] = 1.0 + A_ub.append(row) + b_ub.append(share_cap) + + # (4-5) Composition constraints (optional) + comp_targets = composition_targets or {} + comp_profiles = composition_profiles or {} + comp_components = sorted(comp_targets.keys()) + + for component in comp_components: + target = comp_targets[component] + tol = composition_tolerance + + # Upper: Σ_{s,m} (profile[s,c] - target - tol) × x[s,m] ≤ 0 + row_upper = [0.0] * n_vars + for s_idx, sup in enumerate(suppliers): + coeff = ( + comp_profiles.get(sup.name, {}).get(component, 0.0) + - target + - tol + ) + for m_idx in range(M): + row_upper[s_idx * M + m_idx] = coeff + A_ub.append(row_upper) + b_ub.append(0.0) + + # Lower: Σ_{s,m} (target - tol - profile[s,c]) × x[s,m] ≤ 0 + row_lower = [0.0] * n_vars + for s_idx, sup in enumerate(suppliers): + coeff = ( + target + - tol + - comp_profiles.get(sup.name, {}).get(component, 0.0) + ) + for m_idx in range(M): + row_lower[s_idx * M + m_idx] = coeff + A_ub.append(row_lower) + b_ub.append(0.0) + + # --- Variable bounds: x ≥ 0, u ≥ 0 ----------------------------- + bounds = [(0.0, None)] * n_vars + + # --- Solve with HiGHS ------------------------------------------- + result = linprog( + c, + A_ub=A_ub if A_ub else None, + b_ub=b_ub if b_ub else None, + A_eq=A_eq, + b_eq=b_eq, + bounds=bounds, + 
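+ # "highs" is used here because the HiGHS interface reports constraint
+ # marginals via result.eqlin / result.ineqlin, which are read back below
+ # as shadow prices; other linprog methods may not expose those duals.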
method="highs", + ) + + if not result.success: + return LPResult( + success=False, + status=f"lp_{result.status}", + objective=0.0, + ) + + # --- Extract solution ------------------------------------------- + x = result.x + + allocation: dict[tuple[str, str], float] = {} + for s_idx, sup in enumerate(suppliers): + for m_idx, mkt in enumerate(markets): + val = x[s_idx * M + m_idx] + if val > 1e-9: + allocation[(sup.name, mkt)] = round(val, 6) + + unmet: dict[str, float] = {} + for m_idx, mkt in enumerate(markets): + val = x[n_x + m_idx] + unmet[mkt] = round(val, 6) if val > 1e-9 else 0.0 + + # --- Extract shadow prices (duals) ------------------------------ + shadow_demand: dict[str, float] = {} + shadow_capacity: dict[str, float] = {} + shadow_share: dict[str, float] = {} + shadow_composition: dict[str, float] = {} + + # Equality constraint duals (demand balance) + if hasattr(result, "eqlin") and result.eqlin is not None: + eq_duals = result.eqlin.marginals + for m_idx, mkt in enumerate(markets): + shadow_demand[mkt] = round(float(eq_duals[m_idx]), 6) + + # Inequality constraint duals + if hasattr(result, "ineqlin") and result.ineqlin is not None: + ineq_duals = result.ineqlin.marginals + offset = 0 + + # Capacity duals (S constraints) + for s_idx, sup in enumerate(suppliers): + shadow_capacity[sup.name] = round( + float(ineq_duals[offset + s_idx]), + 6, + ) + offset += S + + # Share duals (S constraints) + for s_idx, sup in enumerate(suppliers): + shadow_share[sup.name] = round( + float(ineq_duals[offset + s_idx]), + 6, + ) + offset += S + + # Composition duals (2 per component) + for component in comp_components: + upper_dual = float(ineq_duals[offset]) + lower_dual = float(ineq_duals[offset + 1]) + shadow_composition[component] = round( + upper_dual - lower_dual, + 6, + ) + offset += 2 + + # --- Reduced costs ---------------------------------------------- + reduced: dict[str, float] = {} + if hasattr(result, "x"): + for s_idx, sup in enumerate(suppliers): + for m_idx, mkt in enumerate(markets): + idx = s_idx * M + m_idx + # reduced cost from fun gradient minus dual + rc = c[idx] - shadow_demand.get(mkt, 0.0) + if abs(rc) > 1e-9: + reduced[f"{sup.name}->{mkt}"] = round(rc, 6) + + return LPResult( + success=True, + status="optimal", + objective=round(float(result.fun), 6), + allocation=allocation, + unmet=unmet, + shadow_prices_demand=shadow_demand, + shadow_prices_capacity=shadow_capacity, + shadow_prices_share=shadow_share, + shadow_prices_composition=shadow_composition, + reduced_costs=reduced, + ) diff --git a/src/ursa/tools/cmm_supply_chain_optimization_tool.py b/src/ursa/tools/cmm_supply_chain_optimization_tool.py new file mode 100755 index 00000000..9b39ea4e --- /dev/null +++ b/src/ursa/tools/cmm_supply_chain_optimization_tool.py @@ -0,0 +1,983 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Annotated, Any + +from langchain_core.tools import tool +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) + +from ursa.tools._lp_solver import LPResult, scipy_available, solve_lp + +_log = logging.getLogger(__name__) + +NonNegativeFloat = Annotated[float, Field(ge=0.0)] +UnitFraction = Annotated[float, Field(ge=0.0, le=1.0)] + +_EPS = 1e-9 + + +# --------------------------------------------------------------------------- +# Pydantic input models +# --------------------------------------------------------------------------- + + +class SupplierInput(BaseModel): + 
"""Validated input for a single supplier.""" + + model_config = ConfigDict(extra="forbid") + + name: str | None = None + capacity: NonNegativeFloat + unit_cost: NonNegativeFloat + risk_score: NonNegativeFloat = 0.0 + composition_profile: dict[str, UnitFraction] | None = None + + @field_validator("composition_profile", mode="before") + @classmethod + def _normalize_composition_keys( + cls, + v: dict[str, Any] | None, + ) -> dict[str, Any] | None: + if v is None: + return None + return {_normalize_component_name(k): val for k, val in v.items()} + + +class OptimizationInput(BaseModel): + """Validated top-level optimization input.""" + + model_config = ConfigDict(extra="forbid") + + commodity: str = "CMM" + demand: dict[str, float] + suppliers: list[SupplierInput] + shipping_cost: dict[str, dict[str, float]] | None = None + risk_weight: NonNegativeFloat = 0.0 + unmet_demand_penalty: Annotated[float, Field(ge=1.0)] = 10000.0 + max_supplier_share: UnitFraction = 1.0 + composition_targets: dict[str, UnitFraction] | None = None + composition_tolerance: UnitFraction = 0.0 + solver_backend: str | None = None + + @field_validator("demand", mode="after") + @classmethod + def _demand_non_empty(cls, v: dict[str, float]) -> dict[str, float]: + if not v: + msg = "demand must be a non-empty mapping" + raise ValueError(msg) + return v + + @field_validator("suppliers", mode="after") + @classmethod + def _suppliers_non_empty( + cls, v: list[SupplierInput] + ) -> list[SupplierInput]: + if not v: + msg = "suppliers must be a non-empty list" + raise ValueError(msg) + return v + + @field_validator("composition_targets", mode="before") + @classmethod + def _normalize_target_keys( + cls, + v: dict[str, Any] | None, + ) -> dict[str, Any] | None: + if v is None: + return None + return {_normalize_component_name(k): val for k, val in v.items()} + + @model_validator(mode="after") + def _auto_name_suppliers(self) -> OptimizationInput: + for idx, supplier in enumerate(self.suppliers): + if supplier.name is None: + supplier.name = f"supplier_{idx + 1}" + return self + + +# --------------------------------------------------------------------------- +# Pydantic output models +# --------------------------------------------------------------------------- + + +class AllocationItem(BaseModel): + """A single supplier-to-market allocation.""" + + model_config = ConfigDict(extra="forbid") + + supplier: str + market: str + amount: float + unit_total_cost: float + + +class ObjectiveBreakdown(BaseModel): + """Cost breakdown of the objective function.""" + + model_config = ConfigDict(extra="forbid") + + procurement: float + shipping: float + risk_penalty: float + unmet_penalty: float + + +class CompositionResult(BaseModel): + """Composition constraint evaluation results.""" + + model_config = ConfigDict(extra="forbid") + + targets: dict[str, float] + actual: dict[str, float] + residuals: dict[str, float] + tolerance: float + feasible: bool + + +class ConstraintResiduals(BaseModel): + """Residuals for all constraint groups.""" + + model_config = ConfigDict(extra="forbid") + + demand_balance: dict[str, float] + supplier_capacity: dict[str, float] + supplier_share_cap: dict[str, float] + composition: dict[str, float] + + +class SensitivitySummary(BaseModel): + """Summary of binding constraints and bottlenecks.""" + + model_config = ConfigDict(extra="forbid") + + active_capacity_constraints: list[str] + bottleneck_markets: list[str] + average_unit_cost: float + unmet_demand_total: float + composition_binding_components: list[str] + 
composition_feasible: bool + + +class ShadowPrices(BaseModel): + """Dual values from LP solver indicating marginal costs.""" + + model_config = ConfigDict(extra="forbid") + + demand_balance: dict[str, float] + supplier_capacity: dict[str, float] + supplier_share_cap: dict[str, float] + composition: dict[str, float] + + +class OptimizationOutput(BaseModel): + """Validated optimization result.""" + + model_config = ConfigDict(extra="forbid") + + commodity: str + status: str + feasible: bool + objective_value: float + objective_breakdown: ObjectiveBreakdown + allocations: list[AllocationItem] + unmet_demand: dict[str, float] + constraint_residuals: ConstraintResiduals + composition: CompositionResult | None + sensitivity_summary: SensitivitySummary + shadow_prices: ShadowPrices | None = None + + +# --------------------------------------------------------------------------- +# Validation error response +# --------------------------------------------------------------------------- + + +class ValidationErrorDetail(BaseModel): + """A single validation error.""" + + loc: list[str | int] + msg: str + type: str + + +class OptimizationErrorResponse(BaseModel): + """Structured error response matching output schema conventions.""" + + commodity: str = "CMM" + status: str = "validation_error" + feasible: bool = False + errors: list[ValidationErrorDetail] + + +@dataclass(frozen=True) +class _Supplier: + name: str + capacity: float + unit_cost: float + risk_score: float + + +def _normalize_component_name(name: Any) -> str: + return str(name).strip().upper() + + +def _sum_allocated_by_supplier( + allocation: dict[tuple[str, str], float], + suppliers: list[_Supplier], +) -> dict[str, float]: + totals = {supplier.name: 0.0 for supplier in suppliers} + for (supplier_name, _market), amount in allocation.items(): + totals[supplier_name] = totals.get(supplier_name, 0.0) + amount + return totals + + +def _compute_costs( + *, + allocation: dict[tuple[str, str], float], + suppliers: list[_Supplier], + shipping_cost: dict[str, dict[str, float]], + risk_weight: float, + unmet: dict[str, float], + unmet_penalty: float, +) -> dict[str, float]: + suppliers_by_name = {supplier.name: supplier for supplier in suppliers} + procurement_cost = 0.0 + shipping_total = 0.0 + risk_total = 0.0 + + for (supplier_name, market), amount in allocation.items(): + supplier = suppliers_by_name[supplier_name] + procurement_cost += amount * supplier.unit_cost + shipping_total += amount * shipping_cost.get(supplier_name, {}).get( + market, 0.0 + ) + risk_total += amount * (risk_weight * supplier.risk_score) + + unmet_cost = sum(unmet.values()) * unmet_penalty + return { + "procurement": procurement_cost, + "shipping": shipping_total, + "risk_penalty": risk_total, + "unmet_penalty": unmet_cost, + } + + +def _compute_composition_metrics( + *, + totals_by_supplier: dict[str, float], + composition_targets: dict[str, float], + composition_profiles: dict[str, dict[str, float]], + composition_tolerance: float, +) -> dict[str, Any] | None: + if not composition_targets: + return None + + total_allocated = sum(totals_by_supplier.values()) + actual: dict[str, float] = {} + residuals: dict[str, float] = {} + + for component, target in composition_targets.items(): + if total_allocated <= _EPS: + actual_value = 0.0 + else: + weighted = 0.0 + for supplier_name, amount in totals_by_supplier.items(): + profile = composition_profiles.get(supplier_name, {}) + weighted += amount * profile.get(component, 0.0) + actual_value = weighted / total_allocated + 
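+ # Illustrative (hypothetical) numbers: 100 units from a supplier at 6% LA
+ # plus 50 units from a supplier at 3% LA blend to
+ # (100 * 0.06 + 50 * 0.03) / 150 = 0.05, i.e. a 5% LA fraction that is
+ # then compared against the target for that component.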
actual[component] = actual_value + residuals[component] = actual_value - target + + feasible = all( + abs(residual) <= composition_tolerance + _EPS + for residual in residuals.values() + ) + + return { + "targets": composition_targets, + "actual": actual, + "residuals": residuals, + "tolerance": composition_tolerance, + "feasible": feasible, + } + + +def _shift_between_suppliers( + *, + allocation: dict[tuple[str, str], float], + supplier_remaining: dict[str, float], + donor: str, + receiver: str, + markets: list[str], + max_shift: float, +) -> float: + if max_shift <= _EPS: + return 0.0 + + shifted_total = 0.0 + for market in markets: + if shifted_total >= max_shift - _EPS: + break + + donor_key = (donor, market) + receiver_key = (receiver, market) + donor_amount = allocation.get(donor_key, 0.0) + receiver_cap = supplier_remaining.get(receiver, 0.0) + + shift = min( + max_shift - shifted_total, + donor_amount, + receiver_cap, + ) + if shift <= _EPS: + continue + + new_donor_amount = donor_amount - shift + if new_donor_amount <= _EPS: + allocation.pop(donor_key, None) + else: + allocation[donor_key] = new_donor_amount + + allocation[receiver_key] = allocation.get(receiver_key, 0.0) + shift + supplier_remaining[receiver] = receiver_cap - shift + supplier_remaining[donor] = supplier_remaining.get(donor, 0.0) + shift + shifted_total += shift + + return shifted_total + + +def _rebalance_for_composition( + *, + allocation: dict[tuple[str, str], float], + supplier_remaining: dict[str, float], + suppliers: list[_Supplier], + markets: list[str], + composition_targets: dict[str, float], + composition_profiles: dict[str, dict[str, float]], + composition_tolerance: float, + max_iterations: int = 200, +) -> None: + if not composition_targets: + return + + for _ in range(max_iterations): + totals = _sum_allocated_by_supplier(allocation, suppliers) + metrics = _compute_composition_metrics( + totals_by_supplier=totals, + composition_targets=composition_targets, + composition_profiles=composition_profiles, + composition_tolerance=composition_tolerance, + ) + if metrics is None or metrics["feasible"]: + return + + total_allocated = sum(totals.values()) + if total_allocated <= _EPS: + return + + residuals = metrics["residuals"] + pending_components = [ + (component, float(residual)) + for component, residual in residuals.items() + if abs(float(residual)) > composition_tolerance + _EPS + ] + pending_components.sort(key=lambda item: abs(item[1]), reverse=True) + + progressed = False + for component, residual in pending_components: + too_high = residual > 0.0 + if too_high: + donor_candidates = sorted( + suppliers, + key=lambda supplier: ( + -composition_profiles[supplier.name].get( + component, 0.0 + ), + supplier.name, + ), + ) + receiver_candidates = sorted( + suppliers, + key=lambda supplier: ( + composition_profiles[supplier.name].get(component, 0.0), + supplier.name, + ), + ) + else: + donor_candidates = sorted( + suppliers, + key=lambda supplier: ( + composition_profiles[supplier.name].get(component, 0.0), + supplier.name, + ), + ) + receiver_candidates = sorted( + suppliers, + key=lambda supplier: ( + -composition_profiles[supplier.name].get( + component, 0.0 + ), + supplier.name, + ), + ) + + for donor in donor_candidates: + donor_name = donor.name + donor_amount = totals.get(donor_name, 0.0) + if donor_amount <= _EPS: + continue + + donor_profile = composition_profiles[donor_name].get( + component, 0.0 + ) + for receiver in receiver_candidates: + receiver_name = receiver.name + if receiver_name == 
donor_name: + continue + if supplier_remaining.get(receiver_name, 0.0) <= _EPS: + continue + + receiver_profile = composition_profiles[receiver_name].get( + component, 0.0 + ) + profile_delta = ( + donor_profile - receiver_profile + if too_high + else receiver_profile - donor_profile + ) + if profile_delta <= _EPS: + continue + + needed_shift = ( + (abs(residual) - composition_tolerance) + * total_allocated + / profile_delta + ) + if needed_shift <= _EPS: + continue + + shifted = _shift_between_suppliers( + allocation=allocation, + supplier_remaining=supplier_remaining, + donor=donor_name, + receiver=receiver_name, + markets=markets, + max_shift=needed_shift, + ) + if shifted > _EPS: + progressed = True + break + if progressed: + break + if progressed: + break + + if not progressed: + return + + +def _prepare_solver_input( + inp: OptimizationInput, +) -> tuple[ + str, + list[str], + list[_Supplier], + dict[str, float], + dict[str, dict[str, float]], + float, + float, + float, + dict[str, float], + float, + dict[str, dict[str, float]], +]: + """Reshape a validated *OptimizationInput* into the tuple the solver expects.""" + commodity = inp.commodity + markets = sorted(inp.demand.keys()) + demand = dict(inp.demand) + + suppliers: list[_Supplier] = [] + composition_profiles_raw: dict[str, dict[str, float]] = {} + for supplier_inp in inp.suppliers: + assert supplier_inp.name is not None # guaranteed by model_validator + suppliers.append( + _Supplier( + name=supplier_inp.name, + capacity=supplier_inp.capacity, + unit_cost=supplier_inp.unit_cost, + risk_score=supplier_inp.risk_score, + ) + ) + composition_profiles_raw[supplier_inp.name] = ( + dict(supplier_inp.composition_profile) + if supplier_inp.composition_profile + else {} + ) + + suppliers = sorted(suppliers, key=lambda supplier: supplier.name) + + composition_targets: dict[str, float] = ( + dict(inp.composition_targets) if inp.composition_targets else {} + ) + + components: set[str] = set(composition_targets.keys()) + for profile in composition_profiles_raw.values(): + components.update(profile.keys()) + + normalized_profiles: dict[str, dict[str, float]] = {} + for supplier in suppliers: + profile = composition_profiles_raw.get(supplier.name, {}) + normalized_profiles[supplier.name] = { + component: profile.get(component, 0.0) + for component in sorted(components) + } + + shipping_cost: dict[str, dict[str, float]] = ( + { + supplier_name: dict(market_costs) + for supplier_name, market_costs in inp.shipping_cost.items() + } + if inp.shipping_cost + else {} + ) + + return ( + commodity, + markets, + suppliers, + demand, + shipping_cost, + inp.risk_weight, + inp.unmet_demand_penalty, + inp.max_supplier_share, + composition_targets, + inp.composition_tolerance, + normalized_profiles, + ) + + +def _validation_error_response( + payload: dict[str, Any], + exc: ValidationError, +) -> dict[str, Any]: + """Convert a *ValidationError* into a dict matching the output schema conventions.""" + commodity = str(payload.get("commodity", "CMM")) + return OptimizationErrorResponse( + commodity=commodity, + errors=[ + ValidationErrorDetail( + loc=[str(x) for x in e["loc"]], + msg=e["msg"], + type=e["type"], + ) + for e in exc.errors() + ], + ).model_dump() + + +def _greedy_fallback( + *, + markets: list[str], + suppliers: list[_Supplier], + demand: dict[str, float], + shipping_cost: dict[str, dict[str, float]], + risk_weight: float, + max_supplier_share: float, +) -> tuple[ + dict[tuple[str, str], float], + dict[str, float], + dict[str, float], + dict[str, 
float], +]: + total_demand = sum(demand.values()) + supplier_max: dict[str, float] = {} + for supplier in suppliers: + share_cap = max_supplier_share * total_demand + supplier_max[supplier.name] = min(supplier.capacity, share_cap) + + supplier_remaining = dict(supplier_max) + demand_remaining = {market: float(demand[market]) for market in markets} + allocation: dict[tuple[str, str], float] = {} + + for market in markets: + candidates: list[tuple[float, _Supplier]] = [] + for supplier in suppliers: + landed_cost = ( + supplier.unit_cost + + shipping_cost.get(supplier.name, {}).get(market, 0.0) + + risk_weight * supplier.risk_score + ) + candidates.append((landed_cost, supplier)) + candidates.sort(key=lambda item: (item[0], item[1].name)) + + for _landed_cost, supplier in candidates: + if demand_remaining[market] <= _EPS: + break + available = supplier_remaining[supplier.name] + if available <= _EPS: + continue + flow = min(available, demand_remaining[market]) + if flow <= _EPS: + continue + allocation[(supplier.name, market)] = flow + supplier_remaining[supplier.name] -= flow + demand_remaining[market] -= flow + + unmet = {market: max(0.0, demand_remaining[market]) for market in markets} + return allocation, unmet, supplier_max, supplier_remaining + + +def _build_output( + *, + commodity: str, + allocation: dict[tuple[str, str], float], + unmet: dict[str, float], + suppliers: list[_Supplier], + markets: list[str], + demand: dict[str, float], + shipping_cost: dict[str, dict[str, float]], + risk_weight: float, + unmet_penalty: float, + supplier_max: dict[str, float], + composition_targets: dict[str, float], + composition_profiles: dict[str, dict[str, float]], + composition_tolerance: float, + shadow_prices: dict[str, Any] | None = None, + status_override: str | None = None, +) -> dict[str, Any]: + """Assemble the result dict and validate it through *OptimizationOutput*.""" + totals_by_supplier = _sum_allocated_by_supplier(allocation, suppliers) + costs = _compute_costs( + allocation=allocation, + suppliers=suppliers, + shipping_cost=shipping_cost, + risk_weight=risk_weight, + unmet=unmet, + unmet_penalty=unmet_penalty, + ) + + allocation_items: list[dict[str, Any]] = [] + suppliers_by_name = {supplier.name: supplier for supplier in suppliers} + for (supplier_name, market), amount in sorted( + allocation.items(), key=lambda item: (item[0][0], item[0][1]) + ): + supplier = suppliers_by_name[supplier_name] + unit_total = ( + supplier.unit_cost + + shipping_cost.get(supplier_name, {}).get(market, 0.0) + + risk_weight * supplier.risk_score + ) + allocation_items.append({ + "supplier": supplier_name, + "market": market, + "amount": round(amount, 6), + "unit_total_cost": round(unit_total, 6), + }) + + demand_residual: dict[str, float] = {} + for market in markets: + allocated = sum( + amount + for (supplier_name, mkt), amount in allocation.items() + if mkt == market and supplier_name + ) + demand_residual[market] = round( + allocated + unmet[market] - demand[market], 9 + ) + + supplier_capacity_residual: dict[str, float] = {} + supplier_share_residual: dict[str, float] = {} + for supplier in suppliers: + used = totals_by_supplier[supplier.name] + supplier_capacity_residual[supplier.name] = round( + supplier.capacity - used, + 9, + ) + supplier_share_residual[supplier.name] = round( + supplier_max[supplier.name] - used, + 9, + ) + + unmet_total = sum(unmet.values()) + objective_value = sum(costs.values()) + + composition = _compute_composition_metrics( + totals_by_supplier=totals_by_supplier, + 
composition_targets=composition_targets, + composition_profiles=composition_profiles, + composition_tolerance=composition_tolerance, + ) + composition_feasible = True + composition_residuals: dict[str, float] = {} + composition_binding: list[str] = [] + if composition is not None: + composition_feasible = bool(composition["feasible"]) + composition_residuals = { + component: round(float(residual), 9) + for component, residual in composition["residuals"].items() + } + composition_binding = [ + component + for component, residual in composition["residuals"].items() + if abs(float(residual)) >= composition_tolerance - _EPS + ] + + feasible = unmet_total <= _EPS and composition_feasible + if status_override is not None: + status = status_override + elif unmet_total > _EPS and not composition_feasible: + status = "infeasible_unmet_and_composition" + elif unmet_total > _EPS: + status = "infeasible_unmet_demand" + elif not composition_feasible: + status = "infeasible_composition_constraints" + else: + status = "optimal_greedy" + + active_capacity = [ + supplier.name + for supplier in suppliers + if abs(supplier_capacity_residual[supplier.name]) <= 1e-6 + ] + bottleneck_markets = [market for market in markets if unmet[market] > _EPS] + allocated_total = sum(totals_by_supplier.values()) + avg_unit = objective_value / allocated_total if allocated_total else 0.0 + + composition_output = None + if composition is not None: + composition_output = CompositionResult( + targets={ + component: round(float(target), 9) + for component, target in composition["targets"].items() + }, + actual={ + component: round(float(actual), 9) + for component, actual in composition["actual"].items() + }, + residuals=composition_residuals, + tolerance=round(float(composition["tolerance"]), 9), + feasible=composition_feasible, + ) + + shadow_prices_model = None + if shadow_prices is not None: + shadow_prices_model = ShadowPrices( + demand_balance=shadow_prices.get( + "demand_balance", + {}, + ), + supplier_capacity=shadow_prices.get( + "supplier_capacity", + {}, + ), + supplier_share_cap=shadow_prices.get( + "supplier_share_cap", + {}, + ), + composition=shadow_prices.get("composition", {}), + ) + + return OptimizationOutput( + commodity=commodity, + status=status, + feasible=feasible, + objective_value=round(objective_value, 6), + objective_breakdown=ObjectiveBreakdown( + **{k: round(v, 6) for k, v in costs.items()}, + ), + allocations=[AllocationItem(**item) for item in allocation_items], + unmet_demand={k: round(v, 6) for k, v in unmet.items()}, + constraint_residuals=ConstraintResiduals( + demand_balance=demand_residual, + supplier_capacity=supplier_capacity_residual, + supplier_share_cap=supplier_share_residual, + composition=composition_residuals, + ), + composition=composition_output, + sensitivity_summary=SensitivitySummary( + active_capacity_constraints=active_capacity, + bottleneck_markets=bottleneck_markets, + average_unit_cost=round(avg_unit, 6), + unmet_demand_total=round(unmet_total, 6), + composition_binding_components=sorted(composition_binding), + composition_feasible=composition_feasible, + ), + shadow_prices=shadow_prices_model, + ).model_dump() + + +def solve_cmm_supply_chain_optimization( + optimization_input: dict[str, Any], +) -> dict[str, Any]: + try: + inp = OptimizationInput.model_validate(optimization_input) + except ValidationError as exc: + return _validation_error_response(optimization_input, exc) + + ( + commodity, + markets, + suppliers, + demand, + shipping_cost, + risk_weight, + unmet_penalty, 
+ max_supplier_share, + composition_targets, + composition_tolerance, + composition_profiles, + ) = _prepare_solver_input(inp) + + backend = inp.solver_backend + total_demand = sum(demand.values()) + + # --- Try LP solver when appropriate --- + if backend != "greedy": + lp_ok = scipy_available() + if not lp_ok and backend == "lp": + return OptimizationErrorResponse( + commodity=commodity, + errors=[ + ValidationErrorDetail( + loc=["solver_backend"], + msg=("LP backend requested but scipy is not installed"), + type="import_error", + ), + ], + ).model_dump() + + if lp_ok: + lp_result: LPResult = solve_lp( + suppliers=suppliers, + markets=markets, + demand=demand, + shipping_cost=shipping_cost, + risk_weight=risk_weight, + unmet_penalty=unmet_penalty, + max_supplier_share=max_supplier_share, + composition_targets=( + composition_targets if composition_targets else None + ), + composition_tolerance=composition_tolerance, + composition_profiles=( + composition_profiles if composition_profiles else None + ), + ) + + if lp_result.success: + # Compute supplier_max for _build_output + share_cap = max_supplier_share * total_demand + lp_supplier_max = { + sup.name: min(sup.capacity, share_cap) for sup in suppliers + } + + lp_shadow = { + "demand_balance": (lp_result.shadow_prices_demand), + "supplier_capacity": (lp_result.shadow_prices_capacity), + "supplier_share_cap": (lp_result.shadow_prices_share), + "composition": (lp_result.shadow_prices_composition), + } + + lp_status = "optimal_lp" + if sum(lp_result.unmet.values()) > _EPS: + lp_status = "infeasible_unmet_demand" + + return _build_output( + commodity=commodity, + allocation=lp_result.allocation, + unmet=lp_result.unmet, + suppliers=suppliers, + markets=markets, + demand=demand, + shipping_cost=shipping_cost, + risk_weight=risk_weight, + unmet_penalty=unmet_penalty, + supplier_max=lp_supplier_max, + composition_targets=composition_targets, + composition_profiles=composition_profiles, + composition_tolerance=composition_tolerance, + shadow_prices=lp_shadow, + status_override=lp_status, + ) + + if backend == "lp": + return OptimizationErrorResponse( + commodity=commodity, + errors=[ + ValidationErrorDetail( + loc=["solver_backend"], + msg=(f"LP solver failed: {lp_result.status}"), + type="solver_error", + ), + ], + ).model_dump() + + _log.info( + "LP solver failed (%s), falling back to greedy", + lp_result.status, + ) + + # --- Greedy fallback --- + allocation, unmet, supplier_max, supplier_remaining = _greedy_fallback( + markets=markets, + suppliers=suppliers, + demand=demand, + shipping_cost=shipping_cost, + risk_weight=risk_weight, + max_supplier_share=max_supplier_share, + ) + + _rebalance_for_composition( + allocation=allocation, + supplier_remaining=supplier_remaining, + suppliers=suppliers, + markets=markets, + composition_targets=composition_targets, + composition_profiles=composition_profiles, + composition_tolerance=composition_tolerance, + ) + + return _build_output( + commodity=commodity, + allocation=allocation, + unmet=unmet, + suppliers=suppliers, + markets=markets, + demand=demand, + shipping_cost=shipping_cost, + risk_weight=risk_weight, + unmet_penalty=unmet_penalty, + supplier_max=supplier_max, + composition_targets=composition_targets, + composition_profiles=composition_profiles, + composition_tolerance=composition_tolerance, + ) + + +@tool +def run_cmm_supply_chain_optimization( + optimization_input: dict[str, Any], +) -> dict[str, Any]: + """Run a deterministic CMM supply allocation optimization. 
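+ 
+     A minimal illustrative payload (names and numbers here are hypothetical,
+     not drawn from any bundled scenario) would be:
+ 
+         {
+             "commodity": "LI2CO3",
+             "demand": {"US_EV": 100},
+             "suppliers": [
+                 {"name": "supplier_a", "capacity": 120, "unit_cost": 10.0}
+             ],
+         }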
+ + Expected optimization_input schema: + - commodity: str + - demand: mapping market -> required quantity + - suppliers: list of {name, capacity, unit_cost, risk_score} + - optional per-supplier composition_profile: mapping component -> fraction + - shipping_cost: optional mapping supplier -> market -> cost + - risk_weight: optional float + - unmet_demand_penalty: optional float + - max_supplier_share: optional float in [0, 1] + - composition_targets: optional mapping component -> target fraction + - composition_tolerance: optional fraction tolerance in [0, 1] + """ + return solve_cmm_supply_chain_optimization(optimization_input) diff --git a/src/ursa/workflows/__init__.py b/src/ursa/workflows/__init__.py index c7908bbd..501adb61 100644 --- a/src/ursa/workflows/__init__.py +++ b/src/ursa/workflows/__init__.py @@ -1,4 +1,7 @@ from .base_workflow import BaseWorkflow as BaseWorkflow +from .critical_minerals_workflow import ( + CriticalMineralsWorkflow as CriticalMineralsWorkflow, +) from .planning_execution_workflow import ( PlanningExecutorWorkflow as PlanningExecutorWorkflow, ) diff --git a/src/ursa/workflows/critical_minerals_workflow.py b/src/ursa/workflows/critical_minerals_workflow.py new file mode 100644 index 00000000..2ff18bcf --- /dev/null +++ b/src/ursa/workflows/critical_minerals_workflow.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Mapping + +from ursa.tools import solve_cmm_supply_chain_optimization +from ursa.workflows.base_workflow import BaseWorkflow + + +def _coerce_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, Mapping): + for key in ("final_summary", "summary", "result", "output"): + text = value.get(key) + if isinstance(text, str) and text.strip(): + return text + messages = value.get("messages") + if isinstance(messages, list) and messages: + last = messages[-1] + text = getattr(last, "text", None) + if isinstance(text, str) and text.strip(): + return text + content = getattr(last, "content", None) + if isinstance(content, str) and content.strip(): + return content + return str(value) + + +class CriticalMineralsWorkflow(BaseWorkflow): + """Compose planning, retrieval, domain tools, and execution for minerals tasks.""" + + def __init__( + self, + planner: Any, + executor: Any, + *, + acquisition_agents: Mapping[str, Any] | None = None, + rag_agent: Any | None = None, + materials_agent: Any | None = None, + simulation_agent: Any | None = None, + workspace: str | Path = "critical_minerals_workspace", + **kwargs: Any, + ): + super().__init__(**kwargs) + self.planner = planner + self.executor = executor + self.acquisition_agents = dict(acquisition_agents or {}) + self.rag_agent = rag_agent + self.materials_agent = materials_agent + self.simulation_agent = simulation_agent + self.workspace = Path(workspace) + self.workspace.mkdir(parents=True, exist_ok=True) + + def _invoke(self, inputs: Mapping[str, Any], **kw: Any) -> dict[str, Any]: + del kw + task = str(inputs["task"]) + domain_context = str( + inputs.get( + "domain_context", + "critical minerals and materials supply chains", + ) + ) + acquisition_context = str(inputs.get("acquisition_context", task)) + rag_context = str(inputs.get("rag_context", task)) + materials_context = str(inputs.get("materials_context", task)) + local_corpus_path = inputs.get("local_corpus_path") + source_queries = inputs.get("source_queries", {}) + optimization_input = 
inputs.get("optimization_input") + + if not isinstance(source_queries, Mapping): + raise TypeError("source_queries must be a mapping of source_name -> query") + + planning_prompt = ( + "Generate an implementation-ready plan for this technical task.\n" + f"Domain context: {domain_context}\n" + f"Task: {task}\n" + "Prioritize data provenance, uncertainty handling, and reproducibility." + ) + planning_output = self.planner.invoke(planning_prompt) + + acquisition_outputs: dict[str, Any] = {} + for source_name, agent in self.acquisition_agents.items(): + query = str(source_queries.get(source_name, task)) + payload = {"query": query, "context": acquisition_context} + try: + acquisition_outputs[source_name] = agent.invoke(payload) + except TypeError: + acquisition_outputs[source_name] = agent.invoke(**payload) + + rag_output: Any | None = None + if self.rag_agent is not None: + if local_corpus_path and hasattr(self.rag_agent, "database_path"): + self.rag_agent.database_path = Path(str(local_corpus_path)) + rag_output = self.rag_agent.invoke({"context": rag_context}) + + materials_output: Any | None = None + if self.materials_agent is not None and "materials_query" in inputs: + materials_output = self.materials_agent.invoke( + { + "query": inputs["materials_query"], + "context": materials_context, + } + ) + + simulation_output: Any | None = None + if self.simulation_agent is not None and "simulation_input" in inputs: + simulation_output = self.simulation_agent.invoke( + inputs["simulation_input"] + ) + + sections: list[str] = [ + f"Task:\n{task}", + f"Domain context:\n{domain_context}", + f"Planning output:\n{_coerce_text(planning_output)}", + ] + + if acquisition_outputs: + for source_name, output in acquisition_outputs.items(): + sections.append( + f"Acquisition ({source_name}) summary:\n{_coerce_text(output)}" + ) + + if rag_output is not None: + sections.append(f"RAG summary:\n{_coerce_text(rag_output)}") + + if materials_output is not None: + sections.append( + f"Materials intelligence summary:\n{_coerce_text(materials_output)}" + ) + + if simulation_output is not None: + sections.append( + f"Simulation summary:\n{_coerce_text(simulation_output)}" + ) + + optimization_output: dict[str, Any] | None = None + if optimization_input is not None: + if not isinstance(optimization_input, Mapping): + raise TypeError("optimization_input must be a mapping") + optimization_output = solve_cmm_supply_chain_optimization( + dict(optimization_input) + ) + sections.append( + "Optimization output (deterministic JSON):\n" + + json.dumps(optimization_output, indent=2, sort_keys=True) + ) + + execution_instruction = str( + inputs.get( + "execution_instruction", + "Produce a technically rigorous synthesis with explicit source" + " grounding, assumptions, uncertainty notes, and actionable next" + " steps for critical minerals decisions.", + ) + ) + sections.append(f"Execution instruction:\n{execution_instruction}") + + executor_output = self.executor.invoke("\n\n".join(sections)) + + return { + "task": task, + "plan": planning_output, + "acquisition": acquisition_outputs, + "rag": rag_output, + "materials": materials_output, + "simulation": simulation_output, + "optimization": optimization_output, + "executor_output": executor_output, + "final_summary": _coerce_text(executor_output), + } diff --git a/src/ursa/workflows/planning_execution_workflow.py b/src/ursa/workflows/planning_execution_workflow.py index 32748bc6..bdf7a71c 100644 --- a/src/ursa/workflows/planning_execution_workflow.py +++ 
b/src/ursa/workflows/planning_execution_workflow.py @@ -1,5 +1,6 @@ import sqlite3 from pathlib import Path +from typing import Any, Mapping from langgraph.checkpoint.sqlite import SqliteSaver from rich import get_console @@ -33,7 +34,8 @@ def __init__(self, planner, executor, workspace, **kwargs): self.planner.checkpointer = checkpointer self.executor.checkpointer = checkpointer - def _invoke(self, task: str, **kw): + def _invoke(self, inputs: Mapping[str, Any], **kw): + task = str(inputs["task"]) with console.status( "[bold deep_pink1]Planning overarching steps . . .", spinner="point", diff --git a/src/ursa/workflows/simulation_use_workflow.py b/src/ursa/workflows/simulation_use_workflow.py index 10a1e2d3..169849da 100644 --- a/src/ursa/workflows/simulation_use_workflow.py +++ b/src/ursa/workflows/simulation_use_workflow.py @@ -1,4 +1,6 @@ # planning_executor.py +from typing import Any, Mapping + from rich import get_console from rich.panel import Panel @@ -75,16 +77,23 @@ class SimulationUseWorkflow(BaseWorkflow): def __init__( - self, planner, executor, workspace, tool_description, **kwargs + self, + planner, + executor, + workspace, + tool_description, + tool_schema=code_schema_prompt, + **kwargs, ): super().__init__(**kwargs) self.planner = planner self.executor = executor self.workspace = workspace - self.tool_schema = code_schema_prompt + self.tool_schema = tool_schema self.tool_description = tool_description - def _invoke(self, task: str, **kw): + def _invoke(self, inputs: Mapping[str, Any], **kw): + task = str(inputs["task"]) with console.status( "[bold deep_pink1]Planning overarching steps . . .", spinner="point", diff --git a/tests/agents/test_arxiv_agent_legacy/test_arxiv_agent_legacy.py b/tests/agents/test_arxiv_agent_legacy/test_arxiv_agent_legacy.py deleted file mode 100644 index 975ef31b..00000000 --- a/tests/agents/test_arxiv_agent_legacy/test_arxiv_agent_legacy.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest - -from ursa.agents.arxiv_agent import ArxivAgentLegacy - - -@pytest.mark.asyncio -async def test_arxiv_agent_legacy_fetches_local_papers_without_network( - chat_model, tmpdir, monkeypatch: pytest.MonkeyPatch -): - monkeypatch.setattr( - "ursa.agents.arxiv_agent.requests.get", - lambda *args, **kwargs: pytest.fail( - "requests.get should not be called" - ), - ) - - agent = ArxivAgentLegacy( - llm=chat_model, - summarize=False, - process_images=False, - download_papers=False, - workspace=tmpdir, - database_path="papers", - summaries_path="summaries", - vectorstore_path="vectors", - ) - - local_pdf = agent.database_path / "2401.01234.pdf" - local_pdf.parent.mkdir(parents=True, exist_ok=True) - local_pdf.write_bytes(b"") - - query = "quantum error correction codes" - context = "Identify recent progress in near-term experiments." 
- result = await agent.ainvoke({"query": query, "context": context}) - - assert result["query"] == query - assert result["context"] == context - assert isinstance(result["papers"], list) - assert any( - paper["arxiv_id"] == "2401.01234" and paper["full_text"] == "" - for paper in result["papers"] - ) diff --git a/tests/agents/test_code_review_agent/test_code_review_agent.py b/tests/agents/test_code_review_agent/test_code_review_agent.py new file mode 100644 index 00000000..ae42a67f --- /dev/null +++ b/tests/agents/test_code_review_agent/test_code_review_agent.py @@ -0,0 +1,41 @@ +from typing import Iterator + +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage + + +class _ToolReadyFakeChatModel(GenericFakeChatModel): + def bind_tools(self, tools, **kwargs): + return self + + +def _message_stream(content: str) -> Iterator[AIMessage]: + while True: + yield AIMessage(content=content) + + +def test_code_review_run_delegates_to_invoke(tmp_path, monkeypatch): + from ursa.agents.code_review_agent import CodeReviewAgent + + (tmp_path / "main.py").write_text("print('hello')\n", encoding="utf-8") + + chat_model = _ToolReadyFakeChatModel(messages=_message_stream("ok")) + agent = CodeReviewAgent(llm=chat_model, workspace=tmp_path) + captured = {} + + def fake_invoke(inputs=None, **kwargs): + captured["inputs"] = inputs + captured["kwargs"] = kwargs + return {"status": "ok"} + + monkeypatch.setattr(agent, "invoke", fake_invoke) + + result = agent.run("Review the code", tmp_path) + + assert result == {"status": "ok"} + assert captured["inputs"]["project_prompt"] == "Review the code" + assert captured["inputs"]["code_files"] == ["main.py"] + assert ( + captured["kwargs"]["config"]["configurable"]["thread_id"] + == agent.thread_id + ) diff --git a/tests/agents/test_rag_agent/test_cmm_components.py b/tests/agents/test_rag_agent/test_cmm_components.py new file mode 100644 index 00000000..6847334c --- /dev/null +++ b/tests/agents/test_rag_agent/test_cmm_components.py @@ -0,0 +1,162 @@ +from langchain_core.documents import Document + +from ursa.agents.cmm_chunker import CMMChunker +from ursa.agents.cmm_embeddings import parse_embedding_model_spec +from ursa.agents.cmm_query_classifier import CMMQueryClassifier, QueryProfile +from ursa.agents.cmm_taxonomy import ( + detect_commodity_tags, + detect_subdomain_tags, + extract_temporal_indicators, +) +from ursa.agents.cmm_vectorstore import ( + ChromaBM25VectorStore, + CMMVectorStoreBase, +) +from ursa.agents.rag_agent import RAGAgent + + +def test_parse_embedding_model_spec_variants(): + assert parse_embedding_model_spec("openai:text-embedding-3-large:1024") == ( + "openai", + "text-embedding-3-large", + 1024, + ) + assert parse_embedding_model_spec("openai:text-embedding-3-small") == ( + "openai", + "text-embedding-3-small", + None, + ) + assert parse_embedding_model_spec("local:BAAI/bge-large-en-v1.5") == ( + "local", + "BAAI/bge-large-en-v1.5", + None, + ) + + +def test_chunker_preserves_markdown_tables_and_metadata(): + text = """ +# Lithium Supply +Q4 2024 update for lithium carbonate market. 
+ +| region | demand | +| --- | --- | +| NA | 120 | +| EU | 95 | +""" + chunker = CMMChunker(max_tokens=120, overlap_tokens=20, min_tokens=5) + docs = chunker.chunk_document( + text, + metadata={"source_doc_id": "doc-1", "source_doc_title": "demo"}, + ) + + assert docs + assert any(doc.metadata["chunk_type"] == "table" for doc in docs) + assert any("LI" in doc.metadata["commodity_tags"] for doc in docs) + assert any(doc.metadata["temporal_indicator"] for doc in docs) + + +def test_taxonomy_taggers_detect_domain_hints(): + query = "Compare lithium and cobalt policy impacts on 2025 trade flows." + commodities = detect_commodity_tags(query) + subdomains = detect_subdomain_tags(query) + temporal = extract_temporal_indicators(query) + + assert "LI" in commodities + assert "CO" in commodities + assert "G-PR" in subdomains + assert "Q-TF" in subdomains + assert "2025" in temporal + + +def test_query_classifier_adapts_retrieval(): + classifier = CMMQueryClassifier() + profile = classifier.classify( + "Compare lithium and cobalt supply shocks in 2025." + ) + + assert profile.query_type in {"comparative", "multi_hop"} + assert profile.retrieval_k >= 20 + assert profile.return_k >= 5 + assert "commodity_tags" in profile.filters + + +def test_hybrid_rrf_fusion_prefers_cross_signal_docs(): + store = ChromaBM25VectorStore.__new__(ChromaBM25VectorStore) + + doc_a = Document(page_content="A", metadata={"chunk_id": "A"}) + doc_b = Document(page_content="B", metadata={"chunk_id": "B"}) + doc_c = Document(page_content="C", metadata={"chunk_id": "C"}) + + store._dense_search = lambda query, k: [(doc_a, 0.9), (doc_b, 0.8)] + store._bm25_search = lambda query, k: [(doc_b, 2.0), (doc_c, 1.0)] + + results = ChromaBM25VectorStore.hybrid_search( + store, + query="demo", + k=3, + alpha=0.7, + filters=None, + ) + + ordered_ids = [doc.metadata["chunk_id"] for doc, _ in results] + assert ordered_ids[0] == "B" + assert set(ordered_ids) == {"A", "B", "C"} + + +class _StubVectorStore(CMMVectorStoreBase): + def __init__(self): + self.calls = [] + + def add_documents(self, documents: list[Document]) -> None: + del documents + + def hybrid_search(self, query: str, k: int, alpha: float, filters: dict | None): + del query, k, alpha + self.calls.append(filters) + if filters: + return [] + return [ + (Document(page_content="fallback-hit", metadata={"chunk_id": "doc-1"}), 0.9) + ] + + def delete_collection(self) -> None: + return None + + def count(self) -> int: + return 0 + + +class _StubClassifier: + def classify(self, query: str) -> QueryProfile: + del query + return QueryProfile( + query_type="general", + commodity_hints=["LREE"], + subdomain_hints=[], + temporal_hints=[], + retrieval_k=20, + return_k=5, + alpha=0.7, + filters={"commodity_tags": ["LREE"]}, + ) + + +def test_rag_retrieval_falls_back_when_filters_return_empty(): + agent = RAGAgent.__new__(RAGAgent) + agent.legacy_mode = False + agent.vectorstore = _StubVectorStore() + agent.classifier = _StubClassifier() + agent._adaptive_retrieval_k = True + agent._adaptive_return_k = True + agent.retrieval_k = 20 + agent.return_k = 5 + agent.hybrid_alpha = 0.7 + agent.use_reranker = False + agent.vectorstore_backend = "chroma" + + results, params = RAGAgent._retrieve(agent, "lanthanum yttrium ndfeb") + + assert len(results) == 1 + assert params["filter_fallback_used"] is True + assert agent.vectorstore.calls[0] == {"commodity_tags": ["LREE"]} + assert agent.vectorstore.calls[1] is None diff --git a/tests/agents/test_rag_agent/test_rag_agent.py 
b/tests/agents/test_rag_agent/test_rag_agent.py index cbbbc7e5..5d8a5418 100644 --- a/tests/agents/test_rag_agent/test_rag_agent.py +++ b/tests/agents/test_rag_agent/test_rag_agent.py @@ -58,3 +58,36 @@ def fakePDFLoader(path_name): manifest_path = vectors_dir / "_ingested_ids.txt" assert manifest_path.exists() + + +async def test_rag_agent_extension_filtering(chat_model, embedding_model, tmpdir): + workspace = Path(tmpdir) + database_dir = workspace / "database" + summaries_dir = workspace / "summaries" + vectors_dir = workspace / "vectors" + for path in (database_dir, summaries_dir, vectors_dir): + path.mkdir(parents=True, exist_ok=True) + + (database_dir / "report.txt").write_text( + "Critical mineral supply chain report content." + ) + (database_dir / "scratch.py").write_text( + "print('dev script that should be excluded')" + ) + + agent = RAGAgent( + llm=chat_model, + embedding=embedding_model, + workspace=tmpdir, + database_path="database", + summaries_path="summaries", + vectorstore_path="vectors", + include_extensions={".txt"}, + max_docs_per_ingest=10, + min_chars=5, + ) + + state = agent._read_docs_node({"context": "critical minerals"}) + doc_ids = state.get("doc_ids") or [] + assert any(doc_id.endswith("report.txt") for doc_id in doc_ids) + assert not any(doc_id.endswith("scratch.py") for doc_id in doc_ids) diff --git a/tests/agents/test_websearch_agent/test_websearch_agent.py b/tests/agents/test_websearch_agent/test_websearch_agent.py index d17decc4..2a277933 100644 --- a/tests/agents/test_websearch_agent/test_websearch_agent.py +++ b/tests/agents/test_websearch_agent/test_websearch_agent.py @@ -1,115 +1,70 @@ -from langchain.messages import HumanMessage -from langchain_core.messages import AIMessage, ToolMessage +from typing import Iterator -from ursa.agents import WebSearchAgentLegacy +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage +from ursa.agents.acquisition_agents import WebSearchAgent -async def test_websearch_agent_legacy_websearch_flow( - chat_model, monkeypatch, tmpdir + +class _FakeChatModel(GenericFakeChatModel): + pass + + +def _message_stream(content: str) -> Iterator[AIMessage]: + while True: + yield AIMessage(content=content) + + +async def test_websearch_agent_fetches_items_without_network_or_llm( + monkeypatch, tmpdir ): - """Ensure the legacy websearch agent wires the search tool into the graph.""" - query = "Who won the 2025 International Chopin Competition?" - search_url = "https://example.com/chopin-2025" + query = "test scientific query" + search_url = "https://example.com/paper" - # Prevent real network access during the test. - monkeypatch.setattr( - WebSearchAgentLegacy, - "_check_for_internet", - lambda self, url="http://www.lanl.gov", timeout=2: True, - ) + class FakeDDGS: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def text(self, query, max_results=10, backend="auto"): + return [ + { + "title": "Mock Result", + "href": search_url, + "body": "Mock snippet", + } + ] class FakeResponse: - def __init__(self, content: bytes): - self.content = content - self.text = str(content) + def __init__(self, text: str): + self.text = text - monkeypatch.setattr( - "ursa.agents.websearch_agent.requests.get", - lambda url, timeout=2: FakeResponse( - b"
Mock content for testing.
" - ), - ) + def raise_for_status(self): + return None + monkeypatch.setattr("ursa.agents.acquisition_agents.DDGS", FakeDDGS) monkeypatch.setattr( - "langchain_community.utilities.duckduckgo_search.DuckDuckGoSearchAPIWrapper.results", - lambda self, q, max_results, source=None: [ - { - "title": "Mock Chopin Coverage", - "href": search_url, - "body": "Summary of the winner and teachers.", - } - ], + "ursa.agents.acquisition_agents.requests.get", + lambda *args, **kwargs: FakeResponse( + "
mock extracted content with enough length to survive dedupe threshold in extraction
" + ), ) - class FakeReactAgent: - def __init__(self): - self.invocations = 0 - - def invoke(self, state): - self.invocations += 1 - tool_call_id = "tool_call_1" - tool_request = AIMessage( - content="Searching for official announcement.", - tool_calls=[ - { - "id": tool_call_id, - "name": "process_content", - "args": { - "url": search_url, - "context": "competition winner details", - }, - "type": "tool_call", - } - ], - ) - tool_result = ToolMessage( - content="Winner: Jane Doe. Teachers: John Smith and Alice Brown.", - tool_call_id=tool_call_id, - ) - researcher_summary = AIMessage( - content="Collected the winner and teacher information from the announcement." - ) - return { - "messages": [tool_request, tool_result, researcher_summary], - "urls_visited": [search_url], - } - - fake_react_agent = FakeReactAgent() - monkeypatch.setattr( - "ursa.agents.websearch_agent.create_agent", - lambda *args, **kwargs: fake_react_agent, + chat_model = _FakeChatModel(messages=_message_stream("summary")) + agent = WebSearchAgent( + llm=chat_model, + summarize=False, + max_results=1, + workspace=tmpdir, ) - agent = WebSearchAgentLegacy(llm=chat_model, workspace=tmpdir) - inputs = { - "messages": [HumanMessage(content=query)], - "model": chat_model, - "websearch_query": query, - "urls_visited": [], - "max_websearch_steps": 0, - } - - # Run once via ainvoke to satisfy the async API contract. - await agent.ainvoke(inputs) - assert fake_react_agent.invocations >= 1 - - # Collect a second run with astream so we can inspect intermediate node outputs. - create_react_state = None - response_state = None - async for step in agent.compiled_graph.astream(inputs): - create_react_state = step.get("_create_react") or create_react_state - response_state = step.get("_response_node") or response_state - - assert create_react_state is not None - assert search_url in create_react_state["urls_visited"] - - tool_messages = create_react_state["messages"] - assert any( - getattr(msg, "tool_calls", []) - and msg.tool_calls[0]["args"].get("url") == search_url - for msg in tool_messages + result = await agent.ainvoke( + {"query": query, "context": "summarize this query"} ) - assert response_state is not None - final_messages = response_state["messages"] - assert final_messages and isinstance(final_messages[0], str) + assert "items" in result + assert len(result["items"]) == 1 + assert result["items"][0]["url"] == search_url + assert "mock extracted content" in result["items"][0]["full_text"].lower() diff --git a/tests/conftest.py b/tests/conftest.py index 083a551c..f0b0fb9f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest from dotenv import load_dotenv from langchain.chat_models import init_chat_model @@ -7,6 +9,8 @@ @pytest.fixture(scope="session", autouse=True) def _load_dotenv(): load_dotenv() + if not os.path.exists(".env") and os.path.exists("env.txt"): + load_dotenv("env.txt") def bind_kwargs(func, **kwargs): @@ -18,17 +22,46 @@ def bind_kwargs(func, **kwargs): @pytest.fixture(scope="function") def chat_model(): + base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_API_BASE") + model_name = os.getenv("URSA_TEST_CHAT_MODEL", "openai:gpt-5-nano") + if base_url and ":" not in model_name: + # OpenAI-compatible endpoints require the OpenAI provider in init_chat_model. 
+ model_name = f"openai:{model_name}" + + kwargs = { + "model": model_name, + "max_tokens": int(os.getenv("URSA_TEST_MAX_TOKENS", "3000")), + "temperature": float(os.getenv("URSA_TEST_TEMPERATURE", "0.0")), + } + if base_url: + kwargs["base_url"] = base_url + if api_key := os.getenv("OPENAI_API_KEY"): + kwargs["api_key"] = api_key + return bind_kwargs( init_chat_model, - model="openai:gpt-5-nano", - max_tokens=3000, - temperature=0.0, + **kwargs, ) @pytest.fixture(scope="function") def embedding_model(): + base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_API_BASE") + model_name = os.getenv( + "URSA_TEST_EMBED_MODEL", "openai:text-embedding-3-small" + ) + if base_url and ":" not in model_name: + model_name = f"openai:{model_name}" + + kwargs = { + "model": model_name, + } + if base_url: + kwargs["base_url"] = base_url + if api_key := os.getenv("OPENAI_API_KEY"): + kwargs["api_key"] = api_key + return bind_kwargs( init_embeddings, - model="openai:text-embedding-3-small", + **kwargs, ) diff --git a/tests/test_base_interface.py b/tests/test_base_interface.py index 7044cbdf..c2c74f73 100644 --- a/tests/test_base_interface.py +++ b/tests/test_base_interface.py @@ -36,7 +36,6 @@ def load_class(path: str): "ursa.agents.acquisition_agents.ArxivAgent", "ursa.agents.acquisition_agents.WebSearchAgent", "ursa.agents.acquisition_agents.OSTIAgent", - "ursa.agents.arxiv_agent.ArxivAgentLegacy", "ursa.agents.chat_agent.ChatAgent", "ursa.agents.code_review_agent.CodeReviewAgent", "ursa.agents.execution_agent.ExecutionAgent", @@ -45,7 +44,6 @@ def load_class(path: str): "ursa.agents.planning_agent.PlanningAgent", "ursa.agents.rag_agent.RAGAgent", "ursa.agents.recall_agent.RecallAgent", - "ursa.agents.websearch_agent.WebSearchAgentLegacy", ], ids=lambda agent_import: agent_import.rsplit(".", 1)[-1], ) diff --git a/tests/tools/test_cmm_supply_chain_optimization_tool.py b/tests/tools/test_cmm_supply_chain_optimization_tool.py new file mode 100644 index 00000000..a4368f9e --- /dev/null +++ b/tests/tools/test_cmm_supply_chain_optimization_tool.py @@ -0,0 +1,307 @@ +from ursa.tools.cmm_supply_chain_optimization_tool import ( + OptimizationOutput, + solve_cmm_supply_chain_optimization, +) + + +def _base_input(): + return { + "commodity": "CO", + "demand": {"NA": 100, "EU": 80}, + "suppliers": [ + { + "name": "US_mine", + "capacity": 120, + "unit_cost": 8.0, + "risk_score": 0.2, + }, + { + "name": "Allied_import", + "capacity": 90, + "unit_cost": 9.2, + "risk_score": 0.1, + }, + ], + "shipping_cost": { + "US_mine": {"NA": 1.0, "EU": 2.0}, + "Allied_import": {"NA": 1.4, "EU": 1.2}, + }, + "risk_weight": 2.0, + "max_supplier_share": 0.8, + } + + +def test_cmm_optimization_output_schema_and_determinism(): + payload = _base_input() + result_a = solve_cmm_supply_chain_optimization(payload) + result_b = solve_cmm_supply_chain_optimization(payload) + + assert result_a == result_b + assert "objective_value" in result_a + assert "allocations" in result_a + assert "constraint_residuals" in result_a + assert "sensitivity_summary" in result_a + assert isinstance(result_a["feasible"], bool) + + +def test_cmm_optimization_reports_infeasible_with_unmet_demand(): + payload = _base_input() + payload["suppliers"] = [ + { + "name": "low_capacity_1", + "capacity": 50, + "unit_cost": 8.0, + "risk_score": 0.2, + }, + { + "name": "low_capacity_2", + "capacity": 30, + "unit_cost": 9.2, + "risk_score": 0.1, + }, + ] + + result = solve_cmm_supply_chain_optimization(payload) + + assert result["feasible"] is False + assert 
result["status"] == "infeasible_unmet_demand" + assert sum(result["unmet_demand"].values()) > 0 + + +def test_cmm_optimization_enforces_composition_targets_when_possible(): + payload = { + "commodity": "ND2FE14B_LA5_Y5", + "demand": {"US": 100}, + "suppliers": [ + { + "name": "la_rich", + "capacity": 100, + "unit_cost": 1.0, + "risk_score": 0.1, + "composition_profile": {"LA": 0.10, "Y": 0.00}, + }, + { + "name": "y_rich", + "capacity": 100, + "unit_cost": 5.0, + "risk_score": 0.1, + "composition_profile": {"LA": 0.00, "Y": 0.10}, + }, + ], + "shipping_cost": { + "la_rich": {"US": 0.0}, + "y_rich": {"US": 0.0}, + }, + "risk_weight": 0.0, + "max_supplier_share": 1.0, + "composition_targets": {"LA": 0.05, "Y": 0.05}, + "composition_tolerance": 0.001, + } + + result = solve_cmm_supply_chain_optimization(payload) + + assert result["feasible"] is True + assert result["status"] in ("optimal_greedy", "optimal_lp") + composition = result["composition"] + assert composition is not None + assert composition["feasible"] is True + assert abs(composition["actual"]["LA"] - 0.05) <= 0.002 + assert abs(composition["actual"]["Y"] - 0.05) <= 0.002 + + +def test_cmm_optimization_reports_infeasible_composition_constraints(): + payload = { + "commodity": "ND2FE14B_LA5_Y5", + "demand": {"US": 100}, + "suppliers": [ + { + "name": "supplier_a", + "capacity": 100, + "unit_cost": 1.0, + "risk_score": 0.1, + "composition_profile": {"LA": 0.02, "Y": 0.01}, + }, + { + "name": "supplier_b", + "capacity": 100, + "unit_cost": 1.1, + "risk_score": 0.1, + "composition_profile": {"LA": 0.03, "Y": 0.02}, + }, + ], + "shipping_cost": { + "supplier_a": {"US": 0.0}, + "supplier_b": {"US": 0.0}, + }, + "risk_weight": 0.0, + "max_supplier_share": 1.0, + "composition_targets": {"LA": 0.05, "Y": 0.05}, + "composition_tolerance": 0.001, + } + + result = solve_cmm_supply_chain_optimization(payload) + + # LP may report infeasibility as unmet_demand (composition + # constraints force zero allocation), greedy reports it as + # infeasible_composition_constraints. 
+ assert result["status"] in ( + "infeasible_composition_constraints", + "infeasible_unmet_demand", + ) + assert result["feasible"] is False + + +# --------------------------------------------------------------------------- +# Validation error tests +# --------------------------------------------------------------------------- + + +def test_validation_error_negative_capacity(): + payload = _base_input() + payload["suppliers"][0]["capacity"] = -50 + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + assert len(result["errors"]) >= 1 + + +def test_validation_error_fraction_above_one(): + payload = _base_input() + payload["max_supplier_share"] = 1.5 + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_non_numeric_cost(): + payload = _base_input() + payload["suppliers"][0]["unit_cost"] = "not_a_number" + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_empty_demand(): + payload = _base_input() + payload["demand"] = {} + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_empty_suppliers(): + payload = _base_input() + payload["suppliers"] = [] + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_unknown_field(): + payload = _base_input() + payload["bogus_field"] = 42 + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_composition_fraction_out_of_range(): + payload = _base_input() + payload["composition_targets"] = {"LA": 1.5} + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_unmet_demand_penalty_below_min(): + payload = _base_input() + payload["unmet_demand_penalty"] = 0.5 + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "validation_error" + assert result["feasible"] is False + + +def test_validation_error_preserves_commodity_in_response(): + payload = _base_input() + payload["suppliers"][0]["capacity"] = -10 + result = solve_cmm_supply_chain_optimization(payload) + + assert result["commodity"] == "CO" + + +# --------------------------------------------------------------------------- +# Backward-compatibility tests +# --------------------------------------------------------------------------- + + +def test_int_demand_values_coerced_to_float(): + payload = _base_input() + # _base_input already uses int demand values (100, 80) + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] in ("optimal_greedy", "optimal_lp") + assert "objective_value" in result + + +def test_supplier_name_auto_generated(): + payload = { + "commodity": "CMM", + "demand": {"NA": 50}, + "suppliers": [ + {"capacity": 100, "unit_cost": 5.0}, + {"capacity": 80, "unit_cost": 6.0}, + ], + } + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] in ("optimal_greedy", "optimal_lp") + supplier_names = {a["supplier"] for a in 
result["allocations"]} + assert "supplier_1" in supplier_names or "supplier_2" in supplier_names + + +def test_component_names_normalized_to_uppercase(): + payload = { + "commodity": "CMM", + "demand": {"US": 100}, + "suppliers": [ + { + "name": "s1", + "capacity": 100, + "unit_cost": 1.0, + "composition_profile": {"la": 0.05, "y": 0.05}, + }, + ], + "composition_targets": {"la": 0.05, "y": 0.05}, + "composition_tolerance": 0.01, + } + result = solve_cmm_supply_chain_optimization(payload) + + composition = result["composition"] + assert composition is not None + assert "LA" in composition["actual"] + assert "Y" in composition["actual"] + + +# --------------------------------------------------------------------------- +# Output model validation test +# --------------------------------------------------------------------------- + + +def test_output_has_correct_types(): + payload = _base_input() + result = solve_cmm_supply_chain_optimization(payload) + + # Re-validate through the output model to confirm structural correctness + validated = OptimizationOutput.model_validate(result) + assert validated.commodity == "CO" + assert isinstance(validated.feasible, bool) + assert isinstance(validated.objective_value, float) + assert len(validated.allocations) > 0 diff --git a/tests/tools/test_lp_solver.py b/tests/tools/test_lp_solver.py new file mode 100644 index 00000000..33508998 --- /dev/null +++ b/tests/tools/test_lp_solver.py @@ -0,0 +1,242 @@ +"""Tests for the LP solver backend and its integration with the +CMM supply-chain optimization tool.""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from ursa.tools.cmm_supply_chain_optimization_tool import ( + OptimizationOutput, + solve_cmm_supply_chain_optimization, +) + +_CONFIGS_DIR = Path(__file__).resolve().parents[2] / "configs" +_ND_SCENARIOS_PATH = _CONFIGS_DIR / "nd_china_2025_scenarios.json" + + +def _base_input() -> dict: + return { + "commodity": "CO", + "demand": {"NA": 100, "EU": 80}, + "suppliers": [ + { + "name": "US_mine", + "capacity": 120, + "unit_cost": 8.0, + "risk_score": 0.2, + }, + { + "name": "Allied_import", + "capacity": 90, + "unit_cost": 9.2, + "risk_score": 0.1, + }, + ], + "shipping_cost": { + "US_mine": {"NA": 1.0, "EU": 2.0}, + "Allied_import": {"NA": 1.4, "EU": 1.2}, + }, + "risk_weight": 2.0, + "max_supplier_share": 0.8, + } + + +def test_lp_optimal_status(): + payload = _base_input() + payload["solver_backend"] = "lp" + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "optimal_lp" + assert result["feasible"] is True + + +def test_lp_objective_leq_greedy(): + payload = _base_input() + + payload_greedy = {**payload, "solver_backend": "greedy"} + result_greedy = solve_cmm_supply_chain_optimization(payload_greedy) + + payload_lp = {**payload, "solver_backend": "lp"} + result_lp = solve_cmm_supply_chain_optimization(payload_lp) + + assert ( + result_lp["objective_value"] <= result_greedy["objective_value"] + 1e-6 + ) + + +def test_lp_shadow_prices_present(): + payload = _base_input() + payload["solver_backend"] = "lp" + result = solve_cmm_supply_chain_optimization(payload) + + sp = result.get("shadow_prices") + assert sp is not None + assert "demand_balance" in sp + assert "supplier_capacity" in sp + assert "supplier_share_cap" in sp + assert "composition" in sp + + +def test_lp_determinism(): + payload = _base_input() + payload["solver_backend"] = "lp" + result_a = 
solve_cmm_supply_chain_optimization(payload) + result_b = solve_cmm_supply_chain_optimization(payload) + + assert result_a == result_b + + +def test_lp_demand_balance(): + payload = _base_input() + payload["solver_backend"] = "lp" + result = solve_cmm_supply_chain_optimization(payload) + + demand = payload["demand"] + for market, required in demand.items(): + allocated = sum( + a["amount"] for a in result["allocations"] if a["market"] == market + ) + unmet = result["unmet_demand"].get(market, 0.0) + assert abs(allocated + unmet - required) < 1e-3 + + +def test_lp_capacity_respected(): + payload = _base_input() + payload["solver_backend"] = "lp" + result = solve_cmm_supply_chain_optimization(payload) + + for sup in payload["suppliers"]: + allocated = sum( + a["amount"] + for a in result["allocations"] + if a["supplier"] == sup["name"] + ) + assert allocated <= sup["capacity"] + 1e-3 + + +def test_lp_with_composition_constraints(): + payload = { + "commodity": "ALLOY", + "demand": {"US": 100}, + "suppliers": [ + { + "name": "la_rich", + "capacity": 100, + "unit_cost": 1.0, + "risk_score": 0.1, + "composition_profile": { + "LA": 0.10, + "Y": 0.00, + }, + }, + { + "name": "y_rich", + "capacity": 100, + "unit_cost": 5.0, + "risk_score": 0.1, + "composition_profile": { + "LA": 0.00, + "Y": 0.10, + }, + }, + ], + "shipping_cost": { + "la_rich": {"US": 0.0}, + "y_rich": {"US": 0.0}, + }, + "risk_weight": 0.0, + "max_supplier_share": 1.0, + "composition_targets": {"LA": 0.05, "Y": 0.05}, + "composition_tolerance": 0.001, + "solver_backend": "lp", + } + + result = solve_cmm_supply_chain_optimization(payload) + + assert result["status"] == "optimal_lp" + assert result["feasible"] is True + comp = result["composition"] + assert comp is not None + assert abs(comp["actual"]["LA"] - 0.05) <= 0.002 + assert abs(comp["actual"]["Y"] - 0.05) <= 0.002 + + +def test_greedy_fallback_when_scipy_unavailable(): + payload = _base_input() + # Default backend (auto) should fall back to greedy + + with ( + patch("ursa.tools._lp_solver._HAS_SCIPY", False), + patch( + "ursa.tools.cmm_supply_chain_optimization_tool.scipy_available", + return_value=False, + ), + ): + result = solve_cmm_supply_chain_optimization( + payload, + ) + + assert result["status"] == "optimal_greedy" + assert result["feasible"] is True + assert result.get("shadow_prices") is None + + +def test_lp_infeasible_capacity(): + payload = _base_input() + payload["solver_backend"] = "lp" + payload["suppliers"] = [ + { + "name": "tiny", + "capacity": 10, + "unit_cost": 8.0, + "risk_score": 0.0, + }, + ] + payload["shipping_cost"] = {"tiny": {"NA": 0.0, "EU": 0.0}} + + result = solve_cmm_supply_chain_optimization(payload) + + # LP should handle this via unmet demand slack + assert result["status"] == "infeasible_unmet_demand" + assert result["feasible"] is False + assert sum(result["unmet_demand"].values()) > 0 + + +@pytest.mark.skipif( + not _ND_SCENARIOS_PATH.exists(), + reason="Nd scenario config not found", +) +def test_nd_scenarios_solve(): + with open(_ND_SCENARIOS_PATH) as fh: + scenarios = json.load(fh) + + # Pre-shock: should be feasible + pre = solve_cmm_supply_chain_optimization( + scenarios["nd_preshock_baseline"]["optimization_input"], + ) + assert pre["feasible"] is True + + # Post-December: demand exceeds constrained capacity + post_dec = solve_cmm_supply_chain_optimization( + scenarios["nd_post_december_2025"]["optimization_input"], + ) + # Should have some unmet demand due to share caps + assert 
post_dec["sensitivity_summary"]["unmet_demand_total"] > 0 + + +def test_output_validates_through_pydantic(): + payload = _base_input() + payload["solver_backend"] = "lp" + result = solve_cmm_supply_chain_optimization(payload) + + validated = OptimizationOutput.model_validate(result) + assert validated.commodity == "CO" + assert validated.shadow_prices is not None + assert isinstance( + validated.shadow_prices.demand_balance, + dict, + ) diff --git a/tests/workflows/test_critical_minerals_workflow.py b/tests/workflows/test_critical_minerals_workflow.py new file mode 100644 index 00000000..fffa4b49 --- /dev/null +++ b/tests/workflows/test_critical_minerals_workflow.py @@ -0,0 +1,155 @@ +from pathlib import Path + +from langchain_core.messages import AIMessage + +from ursa.workflows import CriticalMineralsWorkflow + + +class _FakePlanner: + def __init__(self): + self.prompts = [] + + def invoke(self, prompt): + self.prompts.append(prompt) + return {"plan": "1) collect sources 2) synthesize findings"} + + +class _FakeExecutor: + def __init__(self): + self.prompts = [] + + def invoke(self, prompt): + self.prompts.append(prompt) + return {"messages": [AIMessage(content="final minerals synthesis")]} + + +class _FakeAgent: + def __init__(self, summary: str): + self.summary = summary + self.calls = [] + + def invoke(self, payload): + self.calls.append(payload) + return {"final_summary": self.summary} + + +class _FakeRAG: + def __init__(self): + self.calls = [] + self.database_path = None + + def invoke(self, payload): + self.calls.append(payload) + return {"summary": "rag evidence summary"} + + +def test_critical_minerals_workflow_orchestrates_modules(tmp_path): + planner = _FakePlanner() + executor = _FakeExecutor() + osti = _FakeAgent("osti signal") + arxiv = _FakeAgent("arxiv signal") + rag = _FakeRAG() + materials = _FakeAgent("materials candidate list") + + workflow = CriticalMineralsWorkflow( + planner=planner, + executor=executor, + acquisition_agents={"osti": osti, "arxiv": arxiv}, + rag_agent=rag, + materials_agent=materials, + workspace=tmp_path, + ) + + result = workflow.invoke( + { + "task": "Assess supply risk for scandium and gallium", + "source_queries": { + "osti": "scandium gallium supply chain policy", + "arxiv": "materials substitution for gallium compounds", + }, + "materials_query": { + "elements": ["Sc", "Ga", "Al", "O"], + "band_gap_min": 1.0, + "band_gap_max": 5.0, + }, + } + ) + + assert "Assess supply risk for scandium and gallium" in planner.prompts[0] + assert osti.calls[0]["query"] == "scandium gallium supply chain policy" + assert arxiv.calls[0]["query"] == "materials substitution for gallium compounds" + assert rag.calls[0]["context"] == "Assess supply risk for scandium and gallium" + assert executor.prompts + assert "Acquisition (osti) summary" in executor.prompts[0] + assert "Materials intelligence summary" in executor.prompts[0] + assert result["final_summary"] == "final minerals synthesis" + + +def test_critical_minerals_workflow_sets_local_corpus_on_rag(tmp_path): + planner = _FakePlanner() + executor = _FakeExecutor() + rag = _FakeRAG() + workflow = CriticalMineralsWorkflow( + planner=planner, + executor=executor, + rag_agent=rag, + workspace=tmp_path, + ) + + corpus_path = "/tmp/cmm-corpus" + workflow.invoke( + { + "task": "Summarize domestic rare earth supply constraints", + "local_corpus_path": corpus_path, + } + ) + + assert rag.database_path == Path(corpus_path) + + +def test_critical_minerals_workflow_runs_optimization(tmp_path): + planner = _FakePlanner() 
+ executor = _FakeExecutor() + workflow = CriticalMineralsWorkflow( + planner=planner, + executor=executor, + workspace=tmp_path, + ) + + result = workflow.invoke( + { + "task": "Allocate cobalt supply for North America and Europe", + "optimization_input": { + "commodity": "CO", + "demand": {"NA": 100, "EU": 80}, + "suppliers": [ + { + "name": "US_mine", + "capacity": 120, + "unit_cost": 8.0, + "risk_score": 0.2, + }, + { + "name": "Allied_import", + "capacity": 90, + "unit_cost": 9.2, + "risk_score": 0.1, + }, + ], + "shipping_cost": { + "US_mine": {"NA": 1.0, "EU": 2.0}, + "Allied_import": {"NA": 1.4, "EU": 1.2}, + }, + "risk_weight": 2.0, + "max_supplier_share": 0.8, + }, + } + ) + + optimization = result["optimization"] + assert optimization is not None + assert "objective_value" in optimization + assert "allocations" in optimization + assert "constraint_residuals" in optimization + assert "sensitivity_summary" in optimization + assert "Optimization output (deterministic JSON)" in executor.prompts[0] diff --git a/tests/workflows/test_workflows.py b/tests/workflows/test_workflows.py new file mode 100644 index 00000000..d21a2d5d --- /dev/null +++ b/tests/workflows/test_workflows.py @@ -0,0 +1,79 @@ +from types import SimpleNamespace + +from langchain_core.messages import AIMessage + +from ursa.workflows import PlanningExecutorWorkflow, SimulationUseWorkflow + + +class _FakePlanner: + def __init__(self): + self.prompts = [] + + def invoke(self, prompt): + self.prompts.append(prompt) + step = SimpleNamespace( + name="Step 1", + description="Do a thing", + requires_code=False, + expected_outputs=["output"], + success_criteria=["criterion"], + ) + return {"plan": SimpleNamespace(steps=[step])} + + +class _FakeExecutor: + def __init__(self): + self.prompts = [] + + def invoke(self, prompt): + self.prompts.append(prompt) + return {"messages": [AIMessage(content="done")]} + + +def test_planning_executor_workflow_handles_string_input( + monkeypatch, tmp_path +): + monkeypatch.setattr( + "ursa.workflows.planning_execution_workflow.render_plan_steps_rich", + lambda _: None, + ) + planner = _FakePlanner() + executor = _FakeExecutor() + workflow = PlanningExecutorWorkflow( + planner=planner, + executor=executor, + workspace=tmp_path, + ) + + result = workflow.invoke("Solve this") + + assert result == "done" + assert planner.prompts + assert "Solve this" in planner.prompts[0] + assert executor.prompts + + +def test_simulation_use_workflow_handles_string_input_and_tool_schema( + monkeypatch, tmp_path +): + monkeypatch.setattr( + "ursa.workflows.simulation_use_workflow.render_plan_steps_rich", + lambda _: None, + ) + planner = _FakePlanner() + executor = _FakeExecutor() + workflow = SimulationUseWorkflow( + planner=planner, + executor=executor, + workspace=tmp_path, + tool_description="tool desc", + tool_schema="custom schema", + ) + + result = workflow.invoke("Run simulation sweep") + + assert result == "done" + assert workflow.tool_schema == "custom schema" + assert planner.prompts + assert "Run simulation sweep" in planner.prompts[0] + assert executor.prompts diff --git a/uv.lock b/uv.lock index 9c8d6e06..2e36a4a9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -191,6 +191,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = 
"sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, ] +[[package]] +name = "altair" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "jsonschema" }, + { name = "narwhals" }, + { name = "packaging" }, + { name = "typing-extensions", marker = "python_full_version < '3.15'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/c0/184a89bd5feba14ff3c41cfaf1dd8a82c05f5ceedbc92145e17042eb08a4/altair-6.0.0.tar.gz", hash = "sha256:614bf5ecbe2337347b590afb111929aa9c16c9527c4887d96c9bc7f6640756b4", size = 763834, upload-time = "2025-11-12T08:59:11.519Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/33/ef2f2409450ef6daa61459d5de5c08128e7d3edb773fefd0a324d1310238/altair-6.0.0-py3-none-any.whl", hash = "sha256:09ae95b53d5fe5b16987dccc785a7af8588f2dca50de1e7a156efa8a461515f8", size = 795410, upload-time = "2025-11-12T08:59:09.804Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -530,6 +546,15 @@ css = [ { name = "tinycss2" }, ] +[[package]] +name = "blinker" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460, upload-time = "2024-11-08T17:25:47.436Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, +] + [[package]] name = "boto3" version = "1.40.47" @@ -910,6 +935,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" }, ] +[[package]] +name = "cohere" +version = "5.20.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastavro" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "pydantic-core" }, + { name = "requests" }, + { name = "tokenizers" }, + { name = "types-requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/7c/415e9b150843d879427ad4760c2331443d3f4e6860d17a3c3b3841357898/cohere-5.20.6.tar.gz", hash = "sha256:96b53fafcca97d7345646b66caafb79d6d92fa144c44b6d7fd63fbeade2a5155", size = 185110, upload-time = "2026-02-18T15:57:38.19Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/0e/175613bbd3a16465b93e429edd3a44e2f76d67367131206baa951a82925d/cohere-5.20.6-py3-none-any.whl", hash = "sha256:f67050be06c437fb3b330f1326e4fc1974cdcddbb6c25afd57c2bd94897feaa6", size = 323373, upload-time = "2026-02-18T15:57:36.761Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -1374,6 +1418,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "deprecation" 
+version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/d3/8ae2869247df154b64c1884d7346d412fed0c49df84db635aab2d1c40e62/deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff", size = 173788, upload-time = "2020-04-20T14:23:38.738Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a", size = 11178, upload-time = "2020-04-20T14:23:36.581Z" }, +] + [[package]] name = "diskcache" version = "5.6.3" @@ -1498,6 +1554,53 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fastavro" +version = "1.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/8b/fa2d3287fd2267be6261d0177c6809a7fa12c5600ddb33490c8dc29e77b2/fastavro-1.12.1.tar.gz", hash = "sha256:2f285be49e45bc047ab2f6bed040bb349da85db3f3c87880e4b92595ea093b2b", size = 1025661, upload-time = "2025-10-10T15:40:55.41Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/a0/077fd7cbfc143152cb96780cb592ed6cb6696667d8bc1b977745eb2255a8/fastavro-1.12.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:00650ca533907361edda22e6ffe8cf87ab2091c5d8aee5c8000b0f2dcdda7ed3", size = 1000335, upload-time = "2025-10-10T15:40:59.834Z" }, + { url = "https://files.pythonhosted.org/packages/a0/ae/a115e027f3a75df237609701b03ecba0b7f0aa3d77fe0161df533fde1eb7/fastavro-1.12.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac76d6d95f909c72ee70d314b460b7e711d928845771531d823eb96a10952d26", size = 3221067, upload-time = "2025-10-10T15:41:04.399Z" }, + { url = "https://files.pythonhosted.org/packages/94/4e/c4991c3eec0175af9a8a0c161b88089cb7bf7fe353b3e3be1bc4cf9036b2/fastavro-1.12.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f55eef18c41d4476bd32a82ed5dd86aabc3f614e1b66bdb09ffa291612e1670", size = 3228979, upload-time = "2025-10-10T15:41:06.738Z" }, + { url = "https://files.pythonhosted.org/packages/21/0c/f2afb8eaea38799ccb1ed07d68bf2659f2e313f1902bbd36774cf6a1bef9/fastavro-1.12.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81563e1f93570e6565487cdb01ba241a36a00e58cff9c5a0614af819d1155d8f", size = 3160740, upload-time = "2025-10-10T15:41:08.731Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1a/f4d367924b40b86857862c1fa65f2afba94ddadf298b611e610a676a29e5/fastavro-1.12.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bec207360f76f0b3de540758a297193c5390e8e081c43c3317f610b1414d8c8f", size = 3235787, upload-time = "2025-10-10T15:41:10.869Z" }, + { url = "https://files.pythonhosted.org/packages/90/ec/8db9331896e3dfe4f71b2b3c23f2e97fbbfd90129777467ca9f8bafccb74/fastavro-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:c0390bfe4a9f8056a75ac6785fbbff8f5e317f5356481d2e29ec980877d2314b", size = 449350, upload-time = "2025-10-10T15:41:12.104Z" }, + { url = 
"https://files.pythonhosted.org/packages/a0/e9/31c64b47cefc0951099e7c0c8c8ea1c931edd1350f34d55c27cbfbb08df1/fastavro-1.12.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6b632b713bc5d03928a87d811fa4a11d5f25cd43e79c161e291c7d3f7aa740fd", size = 1016585, upload-time = "2025-10-10T15:41:13.717Z" }, + { url = "https://files.pythonhosted.org/packages/10/76/111560775b548f5d8d828c1b5285ff90e2d2745643fb80ecbf115344eea4/fastavro-1.12.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaa7ab3769beadcebb60f0539054c7755f63bd9cf7666e2c15e615ab605f89a8", size = 3404629, upload-time = "2025-10-10T15:41:15.642Z" }, + { url = "https://files.pythonhosted.org/packages/b0/07/6bb93cb963932146c2b6c5c765903a0a547ad9f0f8b769a4a9aad8c06369/fastavro-1.12.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:123fb221df3164abd93f2d042c82f538a1d5a43ce41375f12c91ce1355a9141e", size = 3428594, upload-time = "2025-10-10T15:41:17.779Z" }, + { url = "https://files.pythonhosted.org/packages/d1/67/8115ec36b584197ea737ec79e3499e1f1b640b288d6c6ee295edd13b80f6/fastavro-1.12.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:632a4e3ff223f834ddb746baae0cc7cee1068eb12c32e4d982c2fee8a5b483d0", size = 3344145, upload-time = "2025-10-10T15:41:19.89Z" }, + { url = "https://files.pythonhosted.org/packages/9e/9e/a7cebb3af967e62539539897c10138fa0821668ec92525d1be88a9cd3ee6/fastavro-1.12.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:83e6caf4e7a8717d932a3b1ff31595ad169289bbe1128a216be070d3a8391671", size = 3431942, upload-time = "2025-10-10T15:41:22.076Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d1/7774ddfb8781c5224294c01a593ebce2ad3289b948061c9701bd1903264d/fastavro-1.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:b91a0fe5a173679a6c02d53ca22dcaad0a2c726b74507e0c1c2e71a7c3f79ef9", size = 450542, upload-time = "2025-10-10T15:41:23.333Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f0/10bd1a3d08667fa0739e2b451fe90e06df575ec8b8ba5d3135c70555c9bd/fastavro-1.12.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:509818cb24b98a804fc80be9c5fed90f660310ae3d59382fc811bfa187122167", size = 1009057, upload-time = "2025-10-10T15:41:24.556Z" }, + { url = "https://files.pythonhosted.org/packages/78/ad/0d985bc99e1fa9e74c636658000ba38a5cd7f5ab2708e9c62eaf736ecf1a/fastavro-1.12.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:089e155c0c76e0d418d7e79144ce000524dd345eab3bc1e9c5ae69d500f71b14", size = 3391866, upload-time = "2025-10-10T15:41:26.882Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9e/b4951dc84ebc34aac69afcbfbb22ea4a91080422ec2bfd2c06076ff1d419/fastavro-1.12.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44cbff7518901c91a82aab476fcab13d102e4999499df219d481b9e15f61af34", size = 3458005, upload-time = "2025-10-10T15:41:29.017Z" }, + { url = "https://files.pythonhosted.org/packages/af/f8/5a8df450a9f55ca8441f22ea0351d8c77809fc121498b6970daaaf667a21/fastavro-1.12.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a275e48df0b1701bb764b18a8a21900b24cf882263cb03d35ecdba636bbc830b", size = 3295258, upload-time = "2025-10-10T15:41:31.564Z" }, + { url = "https://files.pythonhosted.org/packages/99/b2/40f25299111d737e58b85696e91138a66c25b7334f5357e7ac2b0e8966f8/fastavro-1.12.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:2de72d786eb38be6b16d556b27232b1bf1b2797ea09599507938cdb7a9fe3e7c", size = 3430328, upload-time = "2025-10-10T15:41:33.689Z" }, + { url = "https://files.pythonhosted.org/packages/e0/07/85157a7c57c5f8b95507d7829b5946561e5ee656ff80e9dd9a757f53ddaf/fastavro-1.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:9090f0dee63fe022ee9cc5147483366cc4171c821644c22da020d6b48f576b4f", size = 444140, upload-time = "2025-10-10T15:41:34.902Z" }, + { url = "https://files.pythonhosted.org/packages/bb/57/26d5efef9182392d5ac9f253953c856ccb66e4c549fd3176a1e94efb05c9/fastavro-1.12.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:78df838351e4dff9edd10a1c41d1324131ffecbadefb9c297d612ef5363c049a", size = 1000599, upload-time = "2025-10-10T15:41:36.554Z" }, + { url = "https://files.pythonhosted.org/packages/33/cb/8ab55b21d018178eb126007a56bde14fd01c0afc11d20b5f2624fe01e698/fastavro-1.12.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:780476c23175d2ae457c52f45b9ffa9d504593499a36cd3c1929662bf5b7b14b", size = 3335933, upload-time = "2025-10-10T15:41:39.07Z" }, + { url = "https://files.pythonhosted.org/packages/fe/03/9c94ec9bf873eb1ffb0aa694f4e71940154e6e9728ddfdc46046d7e8ced4/fastavro-1.12.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0714b285160fcd515eb0455540f40dd6dac93bdeacdb03f24e8eac3d8aa51f8d", size = 3402066, upload-time = "2025-10-10T15:41:41.608Z" }, + { url = "https://files.pythonhosted.org/packages/75/c8/cb472347c5a584ccb8777a649ebb28278fccea39d005fc7df19996f41df8/fastavro-1.12.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a8bc2dcec5843d499f2489bfe0747999108f78c5b29295d877379f1972a3d41a", size = 3240038, upload-time = "2025-10-10T15:41:43.743Z" }, + { url = "https://files.pythonhosted.org/packages/e1/77/569ce9474c40304b3a09e109494e020462b83e405545b78069ddba5f614e/fastavro-1.12.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3b1921ac35f3d89090a5816b626cf46e67dbecf3f054131f84d56b4e70496f45", size = 3369398, upload-time = "2025-10-10T15:41:45.719Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1f/9589e35e9ea68035385db7bdbf500d36b8891db474063fb1ccc8215ee37c/fastavro-1.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:5aa777b8ee595b50aa084104cd70670bf25a7bbb9fd8bb5d07524b0785ee1699", size = 444220, upload-time = "2025-10-10T15:41:47.39Z" }, + { url = "https://files.pythonhosted.org/packages/6c/d2/78435fe737df94bd8db2234b2100f5453737cffd29adee2504a2b013de84/fastavro-1.12.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:c3d67c47f177e486640404a56f2f50b165fe892cc343ac3a34673b80cc7f1dd6", size = 1086611, upload-time = "2025-10-10T15:41:48.818Z" }, + { url = "https://files.pythonhosted.org/packages/b6/be/428f99b10157230ddac77ec8cc167005b29e2bd5cbe228345192bb645f30/fastavro-1.12.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5217f773492bac43dae15ff2931432bce2d7a80be7039685a78d3fab7df910bd", size = 3541001, upload-time = "2025-10-10T15:41:50.871Z" }, + { url = "https://files.pythonhosted.org/packages/16/08/a2eea4f20b85897740efe44887e1ac08f30dfa4bfc3de8962bdcbb21a5a1/fastavro-1.12.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:469fecb25cba07f2e1bfa4c8d008477cd6b5b34a59d48715e1b1a73f6160097d", size = 3432217, upload-time = "2025-10-10T15:41:53.149Z" }, + { url = 
"https://files.pythonhosted.org/packages/87/bb/b4c620b9eb6e9838c7f7e4b7be0762834443adf9daeb252a214e9ad3178c/fastavro-1.12.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:d71c8aa841ef65cfab709a22bb887955f42934bced3ddb571e98fdbdade4c609", size = 3366742, upload-time = "2025-10-10T15:41:55.237Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d1/e69534ccdd5368350646fea7d93be39e5f77c614cca825c990bd9ca58f67/fastavro-1.12.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:b81fc04e85dfccf7c028e0580c606e33aa8472370b767ef058aae2c674a90746", size = 3383743, upload-time = "2025-10-10T15:41:57.68Z" }, + { url = "https://files.pythonhosted.org/packages/58/54/b7b4a0c3fb5fcba38128542da1b26c4e6d69933c923f493548bdfd63ab6a/fastavro-1.12.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:9445da127751ba65975d8e4bdabf36bfcfdad70fc35b2d988e3950cce0ec0e7c", size = 1001377, upload-time = "2025-10-10T15:41:59.241Z" }, + { url = "https://files.pythonhosted.org/packages/1e/4f/0e589089c7df0d8f57d7e5293fdc34efec9a3b758a0d4d0c99a7937e2492/fastavro-1.12.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ed924233272719b5d5a6a0b4d80ef3345fc7e84fc7a382b6232192a9112d38a6", size = 3320401, upload-time = "2025-10-10T15:42:01.682Z" }, + { url = "https://files.pythonhosted.org/packages/f9/19/260110d56194ae29d7e423a336fccea8bcd103196d00f0b364b732bdb84e/fastavro-1.12.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3616e2f0e1c9265e92954fa099db79c6e7817356d3ff34f4bcc92699ae99697c", size = 3350894, upload-time = "2025-10-10T15:42:04.073Z" }, + { url = "https://files.pythonhosted.org/packages/d0/96/58b0411e8be9694d5972bee3167d6c1fd1fdfdf7ce253c1a19a327208f4f/fastavro-1.12.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cb0337b42fd3c047fcf0e9b7597bd6ad25868de719f29da81eabb6343f08d399", size = 3229644, upload-time = "2025-10-10T15:42:06.221Z" }, + { url = "https://files.pythonhosted.org/packages/5b/db/38660660eac82c30471d9101f45b3acfdcbadfe42d8f7cdb129459a45050/fastavro-1.12.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:64961ab15b74b7c168717bbece5660e0f3d457837c3cc9d9145181d011199fa7", size = 3329704, upload-time = "2025-10-10T15:42:08.384Z" }, + { url = "https://files.pythonhosted.org/packages/9d/a9/1672910f458ecb30b596c9e59e41b7c00309b602a0494341451e92e62747/fastavro-1.12.1-cp314-cp314-win_amd64.whl", hash = "sha256:792356d320f6e757e89f7ac9c22f481e546c886454a6709247f43c0dd7058004", size = 452911, upload-time = "2025-10-10T15:42:09.795Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/2e15d0938ded1891b33eff252e8500605508b799c2e57188a933f0bd744c/fastavro-1.12.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:120aaf82ac19d60a1016afe410935fe94728752d9c2d684e267e5b7f0e70f6d9", size = 3541999, upload-time = "2025-10-10T15:42:11.794Z" }, + { url = "https://files.pythonhosted.org/packages/a7/1c/6dfd082a205be4510543221b734b1191299e6a1810c452b6bc76dfa6968e/fastavro-1.12.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6a3462934b20a74f9ece1daa49c2e4e749bd9a35fa2657b53bf62898fba80f5", size = 3433972, upload-time = "2025-10-10T15:42:14.485Z" }, + { url = "https://files.pythonhosted.org/packages/24/90/9de694625a1a4b727b1ad0958d220cab25a9b6cf7f16a5c7faa9ea7b2261/fastavro-1.12.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = 
"sha256:1f81011d54dd47b12437b51dd93a70a9aa17b61307abf26542fc3c13efbc6c51", size = 3368752, upload-time = "2025-10-10T15:42:16.618Z" }, + { url = "https://files.pythonhosted.org/packages/fa/93/b44f67589e4d439913dab6720f7e3507b0fa8b8e56d06f6fc875ced26afb/fastavro-1.12.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:43ded16b3f4a9f1a42f5970c2aa618acb23ea59c4fcaa06680bdf470b255e5a8", size = 3386636, upload-time = "2025-10-10T15:42:18.974Z" }, +] + [[package]] name = "fastjsonschema" version = "2.21.2" @@ -1783,6 +1886,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.46" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/b5/59d16470a1f0dfe8c793f9ef56fd3826093fc52b3bd96d6b9d6c26c7e27b/gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f", size = 215371, upload-time = "2026-01-01T15:37:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" }, +] + [[package]] name = "google-auth" version = "2.42.0" @@ -2164,7 +2291,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "1.0.1" +version = "1.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2178,9 +2305,9 @@ dependencies = [ { name = "typer-slim" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/e0/308849e8ff9590505815f4a300cb8941a21c5889fb94c955d992539b5bef/huggingface_hub-1.0.1.tar.gz", hash = "sha256:87b506d5b45f0d1af58df7cf8bab993ded25d6077c2e959af58444df8b9589f3", size = 419291, upload-time = "2025-10-28T12:48:43.526Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/fc/eb9bc06130e8bbda6a616e1b80a7aa127681c448d6b49806f61db2670b61/huggingface_hub-1.4.1.tar.gz", hash = "sha256:b41131ec35e631e7383ab26d6146b8d8972abc8b6309b963b306fbcca87f5ed5", size = 642156, upload-time = "2026-02-06T09:20:03.013Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/fb/d71f914bc69e6357cbde04db62ef15497cd27926d95f03b4930997c4c390/huggingface_hub-1.0.1-py3-none-any.whl", hash = "sha256:7e255cd9b3432287a34a86933057abb1b341d20b97fb01c40cbd4e053764ae13", size = 503841, upload-time = 
"2025-10-28T12:48:41.821Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ae/2f6d96b4e6c5478d87d606a1934b5d436c4a2bce6bb7c6fdece891c128e3/huggingface_hub-1.4.1-py3-none-any.whl", hash = "sha256:9931d075fb7a79af5abc487106414ec5fba2c0ae86104c0c62fd6cae38873d18", size = 553326, upload-time = "2026-02-06T09:20:00.728Z" }, ] [[package]] @@ -5216,15 +5343,15 @@ wheels = [ [[package]] name = "plotly" -version = "6.3.1" +version = "5.24.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "narwhals" }, { name = "packaging" }, + { name = "tenacity" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0c/63/961d47c9ffd592a575495891cdcf7875dc0903ebb33ac238935714213789/plotly-6.3.1.tar.gz", hash = "sha256:dd896e3d940e653a7ce0470087e82c2bd903969a55e30d1b01bb389319461bb0", size = 6956460, upload-time = "2025-10-02T16:10:34.16Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/4f/428f6d959818d7425a94c190a6b26fbc58035cbef40bf249be0b62a9aedd/plotly-5.24.1.tar.gz", hash = "sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae", size = 9479398, upload-time = "2024-09-12T15:36:31.068Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/93/023955c26b0ce614342d11cc0652f1e45e32393b6ab9d11a664a60e9b7b7/plotly-6.3.1-py3-none-any.whl", hash = "sha256:8b4420d1dcf2b040f5983eed433f95732ed24930e496d36eb70d211923532e64", size = 9833698, upload-time = "2025-10-02T16:10:22.584Z" }, + { url = "https://files.pythonhosted.org/packages/e5/ae/580600f441f6fc05218bd6c9d5794f4aef072a7d9093b291f1c50a9db8bc/plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089", size = 19054220, upload-time = "2024-09-12T15:36:24.08Z" }, ] [[package]] @@ -5538,6 +5665,63 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/ee/a878f2ad010cbccb311f947f0f2f09d38f613938ee28c34e60fceecc75a1/pyaml-25.7.0-py3-none-any.whl", hash = "sha256:ce5d7867cc2b455efdb9b0448324ff7b9f74d99f64650f12ca570102db6b985f", size = 26418, upload-time = "2025-07-10T18:44:50.679Z" }, ] +[[package]] +name = "pyarrow" +version = "23.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/22/134986a4cc224d593c1afde5494d18ff629393d74cc2eddb176669f234a4/pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019", size = 1167336, upload-time = "2026-02-16T10:14:12.39Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/a8/24e5dc6855f50a62936ceb004e6e9645e4219a8065f304145d7fb8a79d5d/pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56", size = 34307390, upload-time = "2026-02-16T10:08:08.654Z" }, + { url = "https://files.pythonhosted.org/packages/bc/8e/4be5617b4aaae0287f621ad31c6036e5f63118cfca0dc57d42121ff49b51/pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c", size = 35853761, upload-time = "2026-02-16T10:08:17.811Z" }, + { url = "https://files.pythonhosted.org/packages/2e/08/3e56a18819462210432ae37d10f5c8eed3828be1d6c751b6e6a2e93c286a/pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258", size = 44493116, upload-time = "2026-02-16T10:08:25.792Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/82/c40b68001dbec8a3faa4c08cd8c200798ac732d2854537c5449dc859f55a/pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2", size = 47564532, upload-time = "2026-02-16T10:08:34.27Z" }, + { url = "https://files.pythonhosted.org/packages/20/bc/73f611989116b6f53347581b02177f9f620efdf3cd3f405d0e83cdf53a83/pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5", size = 48183685, upload-time = "2026-02-16T10:08:42.889Z" }, + { url = "https://files.pythonhosted.org/packages/b0/cc/6c6b3ecdae2a8c3aced99956187e8302fc954cc2cca2a37cf2111dad16ce/pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222", size = 50605582, upload-time = "2026-02-16T10:08:51.641Z" }, + { url = "https://files.pythonhosted.org/packages/8d/94/d359e708672878d7638a04a0448edf7c707f9e5606cee11e15aaa5c7535a/pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d", size = 27521148, upload-time = "2026-02-16T10:08:58.077Z" }, + { url = "https://files.pythonhosted.org/packages/b0/41/8e6b6ef7e225d4ceead8459427a52afdc23379768f54dd3566014d7618c1/pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb", size = 34302230, upload-time = "2026-02-16T10:09:03.859Z" }, + { url = "https://files.pythonhosted.org/packages/bf/4a/1472c00392f521fea03ae93408bf445cc7bfa1ab81683faf9bc188e36629/pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350", size = 35850050, upload-time = "2026-02-16T10:09:11.877Z" }, + { url = "https://files.pythonhosted.org/packages/0c/b2/bd1f2f05ded56af7f54d702c8364c9c43cd6abb91b0e9933f3d77b4f4132/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd", size = 44491918, upload-time = "2026-02-16T10:09:18.144Z" }, + { url = "https://files.pythonhosted.org/packages/0b/62/96459ef5b67957eac38a90f541d1c28833d1b367f014a482cb63f3b7cd2d/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9", size = 47562811, upload-time = "2026-02-16T10:09:25.792Z" }, + { url = "https://files.pythonhosted.org/packages/7d/94/1170e235add1f5f45a954e26cd0e906e7e74e23392dcb560de471f7366ec/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701", size = 48183766, upload-time = "2026-02-16T10:09:34.645Z" }, + { url = "https://files.pythonhosted.org/packages/0e/2d/39a42af4570377b99774cdb47f63ee6c7da7616bd55b3d5001aa18edfe4f/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78", size = 50607669, upload-time = "2026-02-16T10:09:44.153Z" }, + { url = "https://files.pythonhosted.org/packages/00/ca/db94101c187f3df742133ac837e93b1f269ebdac49427f8310ee40b6a58f/pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919", size = 27527698, upload-time = "2026-02-16T10:09:50.263Z" }, + { url = 
"https://files.pythonhosted.org/packages/9a/4b/4166bb5abbfe6f750fc60ad337c43ecf61340fa52ab386da6e8dbf9e63c4/pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f", size = 34214575, upload-time = "2026-02-16T10:09:56.225Z" }, + { url = "https://files.pythonhosted.org/packages/e1/da/3f941e3734ac8088ea588b53e860baeddac8323ea40ce22e3d0baa865cc9/pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7", size = 35832540, upload-time = "2026-02-16T10:10:03.428Z" }, + { url = "https://files.pythonhosted.org/packages/88/7c/3d841c366620e906d54430817531b877ba646310296df42ef697308c2705/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9", size = 44470940, upload-time = "2026-02-16T10:10:10.704Z" }, + { url = "https://files.pythonhosted.org/packages/2c/a5/da83046273d990f256cb79796a190bbf7ec999269705ddc609403f8c6b06/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05", size = 47586063, upload-time = "2026-02-16T10:10:17.95Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3c/b7d2ebcff47a514f47f9da1e74b7949138c58cfeb108cdd4ee62f43f0cf3/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67", size = 48173045, upload-time = "2026-02-16T10:10:25.363Z" }, + { url = "https://files.pythonhosted.org/packages/43/b2/b40961262213beaba6acfc88698eb773dfce32ecdf34d19291db94c2bd73/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730", size = 50621741, upload-time = "2026-02-16T10:10:33.477Z" }, + { url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" }, + { url = "https://files.pythonhosted.org/packages/47/10/2cbe4c6f0fb83d2de37249567373d64327a5e4d8db72f486db42875b08f6/pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8", size = 34210066, upload-time = "2026-02-16T10:10:45.487Z" }, + { url = "https://files.pythonhosted.org/packages/cb/4f/679fa7e84dadbaca7a65f7cdba8d6c83febbd93ca12fa4adf40ba3b6362b/pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f", size = 35825526, upload-time = "2026-02-16T10:10:52.266Z" }, + { url = "https://files.pythonhosted.org/packages/f9/63/d2747d930882c9d661e9398eefc54f15696547b8983aaaf11d4a2e8b5426/pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677", size = 44473279, upload-time = "2026-02-16T10:11:01.557Z" }, + { url = "https://files.pythonhosted.org/packages/b3/93/10a48b5e238de6d562a411af6467e71e7aedbc9b87f8d3a35f1560ae30fb/pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2", size = 47585798, upload-time = "2026-02-16T10:11:09.401Z" }, + { url = 
"https://files.pythonhosted.org/packages/5c/20/476943001c54ef078dbf9542280e22741219a184a0632862bca4feccd666/pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37", size = 48179446, upload-time = "2026-02-16T10:11:17.781Z" }, + { url = "https://files.pythonhosted.org/packages/4b/b6/5dd0c47b335fcd8edba9bfab78ad961bd0fd55ebe53468cc393f45e0be60/pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2", size = 50623972, upload-time = "2026-02-16T10:11:26.185Z" }, + { url = "https://files.pythonhosted.org/packages/d5/09/a532297c9591a727d67760e2e756b83905dd89adb365a7f6e9c72578bcc1/pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a", size = 27540749, upload-time = "2026-02-16T10:12:23.297Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8e/38749c4b1303e6ae76b3c80618f84861ae0c55dd3c2273842ea6f8258233/pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1", size = 34471544, upload-time = "2026-02-16T10:11:32.535Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/f237b2bc8c669212f842bcfd842b04fc8d936bfc9d471630569132dc920d/pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500", size = 35949911, upload-time = "2026-02-16T10:11:39.813Z" }, + { url = "https://files.pythonhosted.org/packages/0c/86/b912195eee0903b5611bf596833def7d146ab2d301afeb4b722c57ffc966/pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41", size = 44520337, upload-time = "2026-02-16T10:11:47.764Z" }, + { url = "https://files.pythonhosted.org/packages/69/c2/f2a717fb824f62d0be952ea724b4f6f9372a17eed6f704b5c9526f12f2f1/pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07", size = 47548944, upload-time = "2026-02-16T10:11:56.607Z" }, + { url = "https://files.pythonhosted.org/packages/84/a7/90007d476b9f0dc308e3bc57b832d004f848fd6c0da601375d20d92d1519/pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83", size = 48236269, upload-time = "2026-02-16T10:12:04.47Z" }, + { url = "https://files.pythonhosted.org/packages/b0/3f/b16fab3e77709856eb6ac328ce35f57a6d4a18462c7ca5186ef31b45e0e0/pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125", size = 50604794, upload-time = "2026-02-16T10:12:11.797Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a1/22df0620a9fac31d68397a75465c344e83c3dfe521f7612aea33e27ab6c0/pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8", size = 27660642, upload-time = "2026-02-16T10:12:17.746Z" }, + { url = "https://files.pythonhosted.org/packages/8d/1b/6da9a89583ce7b23ac611f183ae4843cd3a6cf54f079549b0e8c14031e73/pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca", size = 34238755, upload-time = "2026-02-16T10:12:32.819Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/b5/d58a241fbe324dbaeb8df07be6af8752c846192d78d2272e551098f74e88/pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1", size = 35847826, upload-time = "2026-02-16T10:12:38.949Z" }, + { url = "https://files.pythonhosted.org/packages/54/a5/8cbc83f04aba433ca7b331b38f39e000efd9f0c7ce47128670e737542996/pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb", size = 44536859, upload-time = "2026-02-16T10:12:45.467Z" }, + { url = "https://files.pythonhosted.org/packages/36/2e/c0f017c405fcdc252dbccafbe05e36b0d0eb1ea9a958f081e01c6972927f/pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1", size = 47614443, upload-time = "2026-02-16T10:12:55.525Z" }, + { url = "https://files.pythonhosted.org/packages/af/6b/2314a78057912f5627afa13ba43809d9d653e6630859618b0fd81a4e0759/pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886", size = 48232991, upload-time = "2026-02-16T10:13:04.729Z" }, + { url = "https://files.pythonhosted.org/packages/40/f2/1bcb1d3be3460832ef3370d621142216e15a2c7c62602a4ea19ec240dd64/pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f", size = 50645077, upload-time = "2026-02-16T10:13:14.147Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3f/b1da7b61cd66566a4d4c8383d376c606d1c34a906c3f1cb35c479f59d1aa/pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5", size = 28234271, upload-time = "2026-02-16T10:14:09.397Z" }, + { url = "https://files.pythonhosted.org/packages/b5/78/07f67434e910a0f7323269be7bfbf58699bd0c1d080b18a1ab49ba943fe8/pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d", size = 34488692, upload-time = "2026-02-16T10:13:21.541Z" }, + { url = "https://files.pythonhosted.org/packages/50/76/34cf7ae93ece1f740a04910d9f7e80ba166b9b4ab9596a953e9e62b90fe1/pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f", size = 35964383, upload-time = "2026-02-16T10:13:28.63Z" }, + { url = "https://files.pythonhosted.org/packages/46/90/459b827238936d4244214be7c684e1b366a63f8c78c380807ae25ed92199/pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814", size = 44538119, upload-time = "2026-02-16T10:13:35.506Z" }, + { url = "https://files.pythonhosted.org/packages/28/a1/93a71ae5881e99d1f9de1d4554a87be37da11cd6b152239fb5bd924fdc64/pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d", size = 47571199, upload-time = "2026-02-16T10:13:42.504Z" }, + { url = "https://files.pythonhosted.org/packages/88/a3/d2c462d4ef313521eaf2eff04d204ac60775263f1fb08c374b543f79f610/pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7", size = 48259435, upload-time = "2026-02-16T10:13:49.226Z" }, + { url = 
"https://files.pythonhosted.org/packages/cc/f1/11a544b8c3d38a759eb3fbb022039117fd633e9a7b19e4841cc3da091915/pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690", size = 50629149, upload-time = "2026-02-16T10:13:57.238Z" }, + { url = "https://files.pythonhosted.org/packages/50/f2/c0e76a0b451ffdf0cf788932e182758eb7558953f4f27f1aff8e2518b653/pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce", size = 28365807, upload-time = "2026-02-16T10:14:03.892Z" }, +] + [[package]] name = "pyasn1" version = "0.6.1" @@ -5903,6 +6087,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/86/e74c978800131c657fc5145f2c1c63e0cea01a49b6216f729cf77a2e1edf/pydash-8.0.5-py3-none-any.whl", hash = "sha256:b2625f8981862e19911daa07f80ed47b315ce20d9b5eb57aaf97aaf570c3892f", size = 102077, upload-time = "2025-01-17T16:08:47.91Z" }, ] +[[package]] +name = "pydeck" +version = "0.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/ca/40e14e196864a0f61a92abb14d09b3d3da98f94ccb03b49cf51688140dab/pydeck-0.9.1.tar.gz", hash = "sha256:f74475ae637951d63f2ee58326757f8d4f9cd9f2a457cf42950715003e2cb605", size = 3832240, upload-time = "2024-05-10T15:36:21.153Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/4c/b888e6cf58bd9db9c93f40d1c6be8283ff49d88919231afe93a6bcf61626/pydeck-0.9.1-py2.py3-none-any.whl", hash = "sha256:b3f75ba0d273fc917094fa61224f3f6076ca8752b93d46faf3bcfd9f9d59b038", size = 6900403, upload-time = "2024-05-10T15:36:17.36Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -6480,6 +6678,19 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/e8/c2/525e9e9b458c3ca493d9bd0871f3ed9b51446d26fe82d462494de188f848/randomname-0.2.1.tar.gz", hash = "sha256:b79b98302ba4479164b0a4f87995b7bebbd1d91012aeda483341e3e58ace520e", size = 64242, upload-time = "2023-01-29T02:42:26.469Z" } +[[package]] +name = "rank-bm25" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/0a/f9579384aa017d8b4c15613f86954b92a95a93d641cc849182467cf0bb3b/rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d", size = 8347, upload-time = "2022-02-16T12:10:52.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/21/f691fb2613100a62b3fa91e9988c991e9ca5b89ea31c0d3152a3210344f9/rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae", size = 8584, upload-time = "2022-02-16T12:10:50.626Z" }, +] + [[package]] name = "referencing" version = "0.36.2" @@ -6926,6 +7137,32 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/48/f0/ae7ca09223a81a1d890b2557186ea015f6e0502e9b8cb8e1813f1d8cfa4e/s3transfer-0.14.0-py3-none-any.whl", hash = "sha256:ea3b790c7077558ed1f02a3072fb3cb992bbbd253392f4b6e9e8976941c7d456", size = 85712, upload-time = "2025-09-09T19:23:30.041Z" }, ] +[[package]] +name = "safetensors" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" }, + { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" }, + { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" }, + { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" }, + { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" }, + { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = 
"sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" }, + { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" }, + { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" }, + { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" }, + { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, + { url = "https://files.pythonhosted.org/packages/a7/6a/4d08d89a6fcbe905c5ae68b8b34f0791850882fc19782d0d02c65abbdf3b/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737", size = 492430, upload-time = "2025-11-19T15:18:11.884Z" }, + { url = "https://files.pythonhosted.org/packages/dd/29/59ed8152b30f72c42d00d241e58eaca558ae9dbfa5695206e2e0f54c7063/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd", size = 503977, upload-time = "2025-11-19T15:18:17.523Z" }, + { url = "https://files.pythonhosted.org/packages/d3/0b/4811bfec67fa260e791369b16dab105e4bae82686120554cc484064e22b4/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2", size = 623890, upload-time = "2025-11-19T15:18:22.666Z" }, + { url = "https://files.pythonhosted.org/packages/58/5b/632a58724221ef03d78ab65062e82a1010e1bef8e8e0b9d7c6d7b8044841/safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3", size = 531885, upload-time = "2025-11-19T15:18:27.146Z" }, +] + [[package]] name = "scikit-learn" version = "1.7.2" @@ -7208,6 +7445,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/40/b0/4562db6223154aa4e22f939003cb92514c79f3d4dccca3444253fd17f902/Send2Trash-1.8.3-py3-none-any.whl", hash = "sha256:0c31227e0bd08961c7665474a3d1ef7193929fedda4233843689baa056be46c9", size = 18072, upload-time = "2024-04-07T00:01:07.438Z" }, ] +[[package]] +name = "sentence-transformers" +version = "5.2.3" +source = { registry = "https://pypi.org/simple" 
} +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scikit-learn" }, + { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.16.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/30/21664028fc0776eb1ca024879480bbbab36f02923a8ff9e4cae5a150fa35/sentence_transformers-5.2.3.tar.gz", hash = "sha256:3cd3044e1f3fe859b6a1b66336aac502eaae5d3dd7d5c8fc237f37fbf58137c7", size = 381623, upload-time = "2026-02-17T14:05:20.238Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/9f/dba4b3e18ebbe1eaa29d9f1764fbc7da0cd91937b83f2b7928d15c5d2d36/sentence_transformers-5.2.3-py3-none-any.whl", hash = "sha256:6437c62d4112b615ddebda362dfc16a4308d604c5b68125ed586e3e95d5b2e30", size = 494225, upload-time = "2026-02-17T14:05:18.596Z" }, +] + [[package]] name = "sentinels" version = "1.1.1" @@ -7262,6 +7520,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/d9/460cf1d58945dd771c228c29d5664f431dfc4060d3d092fed40546b11472/smart_open-7.3.1-py3-none-any.whl", hash = "sha256:e243b2e7f69d6c0c96dd763d6fbbedbb4e0e4fc6d74aa007acc5b018d523858c", size = 61722, upload-time = "2025-09-08T10:03:52.02Z" }, ] +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -7430,6 +7697,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, ] +[[package]] +name = "streamlit" +version = "1.54.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "altair" }, + { name = "blinker" }, + { name = "cachetools" }, + { name = "click" }, + { name = "gitpython" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pillow" }, + { name = "protobuf" }, + { name = "pyarrow" }, + { name = "pydeck" }, + { name = "requests" }, + { name = "tenacity" }, + { name = "toml" }, + { name = "tornado" }, + { name = 
"typing-extensions" }, + { name = "watchdog", marker = "sys_platform != 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/66/d887ee80ea85f035baee607c60af024994e17ae9b921277fca9675e76ecf/streamlit-1.54.0.tar.gz", hash = "sha256:09965e6ae7eb0357091725de1ce2a3f7e4be155c2464c505c40a3da77ab69dd8", size = 8662292, upload-time = "2026-02-04T16:37:54.734Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/1d/40de1819374b4f0507411a60f4d2de0d620a9b10c817de5925799132b6c9/streamlit-1.54.0-py3-none-any.whl", hash = "sha256:a7b67d6293a9f5f6b4d4c7acdbc4980d7d9f049e78e404125022ecb1712f79fc", size = 9119730, upload-time = "2026-02-04T16:37:52.199Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -7599,6 +7896,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, ] +[[package]] +name = "toml" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253, upload-time = "2020-11-01T01:40:22.204Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -7803,6 +8109,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] +[[package]] +name = "transformers" +version = "5.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, + { name = "typer-slim" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/7e/8a0c57d562015e5b16c97c1f0b8e0e92ead2c7c20513225dc12c2043ba9f/transformers-5.2.0.tar.gz", hash = "sha256:0088b8b46ccc9eff1a1dca72b5d618a5ee3b1befc3e418c9512b35dea9f9a650", size = 8618176, upload-time = "2026-02-16T18:54:02.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/93/79754b0ca486e556c2b95d4f5afc66aaf4b260694f3d6e1b51da2d036691/transformers-5.2.0-py3-none-any.whl", hash = "sha256:9ecaf243dc45bee11a7d93f8caf03746accc0cb069181bbf4ad8566c53e854b4", size = 10403304, upload-time = "2026-02-16T18:53:59.699Z" }, +] + [[package]] name = "triton" version = "3.5.1" @@ -7854,6 +8181,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/ab/d9/a29dfa84363e88b053bf85a8b7f212a04f0d7343a4d24933baa45c06e08b/types_python_dateutil-2.9.0.20250822-py3-none-any.whl", hash = "sha256:849d52b737e10a6dc6621d2bd7940ec7c65fcb69e6aa2882acf4e56b2b508ddc", size = 17892, upload-time = "2025-08-22T03:01:59.436Z" }, ] +[[package]] +name = "types-requests" +version = "2.32.4.20260107" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/f3/a0663907082280664d745929205a89d41dffb29e89a50f753af7d57d0a96/types_requests-2.32.4.20260107.tar.gz", hash = "sha256:018a11ac158f801bfa84857ddec1650750e393df8a004a8a9ae2a9bec6fcb24f", size = 23165, upload-time = "2026-01-07T03:20:54.091Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/12/709ea261f2bf91ef0a26a9eed20f2623227a8ed85610c1e54c5805692ecb/types_requests-2.32.4.20260107-py3-none-any.whl", hash = "sha256:b703fe72f8ce5b31ef031264fe9395cac8f46a04661a79f7ed31a80fb308730d", size = 20676, upload-time = "2026-01-07T03:20:52.929Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -7964,6 +8303,7 @@ dependencies = [ { name = "langgraph-checkpoint-sqlite" }, { name = "mcp" }, { name = "mp-api" }, + { name = "openai" }, { name = "pandas" }, { name = "pillow" }, { name = "pydantic" }, @@ -7971,6 +8311,7 @@ dependencies = [ { name = "pypdf" }, { name = "pyyaml" }, { name = "randomname" }, + { name = "rank-bm25" }, { name = "rich" }, { name = "selectolax" }, { name = "trafilatura", version = "1.6.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -7979,6 +8320,15 @@ dependencies = [ ] [package.optional-dependencies] +cmm = [ + { name = "cohere" }, + { name = "sentence-transformers" }, + { name = "weaviate-client" }, +] +dashboard = [ + { name = "plotly" }, + { name = "streamlit" }, +] fm = [ { name = "torch" }, ] @@ -8019,6 +8369,7 @@ requires-dist = [ { name = "arxiv", specifier = ">=2.2.0,<3.0" }, { name = "atomman", marker = "extra == 'lammps'", specifier = ">=1.5.2" }, { name = "beautifulsoup4", specifier = ">=4.13.4,<5.0" }, + { name = "cohere", marker = "extra == 'cmm'", specifier = ">=5.11.0" }, { name = "ddgs", specifier = ">=9.5.5" }, { name = "fastmcp", specifier = ">=2.13.3" }, { name = "jsonargparse", specifier = ">=4.45.0" }, @@ -8032,11 +8383,13 @@ requires-dist = [ { name = "langgraph-checkpoint-sqlite", specifier = ">=3.0.0" }, { name = "mcp", specifier = ">=1.20.0,<2.0" }, { name = "mp-api", specifier = ">=0.45.8,<0.45.13" }, + { name = "openai", specifier = ">=1.0.0,<3.0" }, { name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.39.0" }, { name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.38.0" }, { name = "ortools", marker = "extra == 'opt'", specifier = ">=9.14,<9.15" }, { name = "pandas", specifier = ">=2.3.1,<3.0" }, { name = "pillow", specifier = ">=11.3.0,<12.0" }, + { name = "plotly", marker = "extra == 'dashboard'", specifier = ">=5.18,<6.0" }, { name = "pydantic", specifier = ">=2.12.0,<3.0" }, { name = "pymupdf", specifier = ">=1.26.0,<2.0" }, { name = "pypdf", specifier = ">=5.9.0,<6.0" }, @@ -8044,13 +8397,17 @@ requires-dist = [ { name = "python-pptx", marker = "extra == 'office-readers'", specifier = ">=1.0.2" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "randomname", specifier = ">=0.2.1,<0.3" }, + { name = "rank-bm25", specifier = ">=0.2.2,<0.3" }, { name = "rich", specifier = ">=13.9.4,<14.0" 
}, { name = "selectolax", specifier = ">=0.4.0,<0.5" }, + { name = "sentence-transformers", marker = "extra == 'cmm'", specifier = ">=3.2.0" }, + { name = "streamlit", marker = "extra == 'dashboard'", specifier = ">=1.40,<2.0" }, { name = "torch", marker = "extra == 'fm'", specifier = ">=2.9.0" }, { name = "trafilatura", specifier = ">=1.6.1,<1.7" }, { name = "typer", specifier = ">=0.16.1" }, + { name = "weaviate-client", marker = "extra == 'cmm'", specifier = ">=4.9.0" }, ] -provides-extras = ["fm", "lammps", "office-readers", "opt", "otel"] +provides-extras = ["fm", "lammps", "office-readers", "opt", "otel", "cmm", "dashboard"] [package.metadata.requires-dev] dev = [ @@ -8138,6 +8495,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e4/16/c1fd27e9549f3c4baf1dc9c20c456cd2f822dbf8de9f463824b0c0357e06/uvloop-0.22.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6cde23eeda1a25c75b2e07d39970f3374105d5eafbaab2a4482be82f272d5a5e", size = 4296730, upload-time = "2025-10-16T22:17:00.744Z" }, ] +[[package]] +name = "validators" +version = "0.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/53/66/a435d9ae49850b2f071f7ebd8119dd4e84872b01630d6736761e6e7fd847/validators-0.35.0.tar.gz", hash = "sha256:992d6c48a4e77c81f1b4daba10d16c3a9bb0dbb79b3a19ea847ff0928e70497a", size = 73399, upload-time = "2025-05-01T05:42:06.7Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" }, +] + [[package]] name = "virtualenv" version = "20.34.0" @@ -8297,6 +8663,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, ] +[[package]] +name = "weaviate-client" +version = "4.19.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "authlib" }, + { name = "deprecation" }, + { name = "grpcio" }, + { name = "httpx" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "validators" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/34/c4587c49255cb6310d9ef395422f3b8eb88f54fb344e6dfebe97746b279f/weaviate_client-4.19.4.tar.gz", hash = "sha256:9b448df63a40461c6e20153eb5f2dd8f58f18bdf7c0c60d168c1e628c1088e0b", size = 788828, upload-time = "2026-02-18T15:05:31.565Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/ad/c8cf1f52686988deeb2775e3e1bf78023405769e1c2ff9abd7b39a11222b/weaviate_client-4.19.4-py3-none-any.whl", hash = "sha256:440ec600d702c88ee807c2382d4a3493c877c6db8dc0a22767b4acbc004d0e8b", size = 604242, upload-time = "2026-02-18T15:05:29.931Z" }, +] + [[package]] name = "webcolors" version = "24.11.1"